diff --git a/poetry.lock b/poetry.lock index ad0d2cf..362f1d4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,10 +1,9 @@ -# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "aiohttp" version = "3.8.5" description = "Async http client/server framework (asyncio)" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -113,7 +112,6 @@ speedups = ["Brotli", "aiodns", "cchardet"] name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -128,7 +126,6 @@ frozenlist = ">=1.1.0" name = "annotated-types" version = "0.5.0" description = "Reusable constraint types to use with typing.Annotated" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -140,7 +137,6 @@ files = [ name = "async-timeout" version = "4.0.3" description = "Timeout context manager for asyncio programs" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -152,7 +148,6 @@ files = [ name = "attrs" version = "23.1.0" description = "Classes Without Boilerplate" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -171,7 +166,6 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte name = "black" version = "23.7.0" description = "The uncompromising code formatter." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -216,7 +210,6 @@ uvloop = ["uvloop (>=0.15.2)"] name = "certifi" version = "2023.7.22" description = "Python package for providing Mozilla's CA Bundle." -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -228,7 +221,6 @@ files = [ name = "charset-normalizer" version = "3.2.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." -category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -313,7 +305,6 @@ files = [ name = "click" version = "8.1.6" description = "Composable command line interface toolkit" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -328,7 +319,6 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -340,7 +330,6 @@ files = [ name = "frozenlist" version = "1.4.0" description = "A list-like structure which implements collections.abc.MutableSequence" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -411,7 +400,6 @@ files = [ name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" -category = "main" optional = false python-versions = ">=3.5" files = [ @@ -423,7 +411,6 @@ files = [ name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -435,7 +422,6 @@ files = [ name = "isort" version = "5.12.0" description = "A Python utility / library to sort Python imports." -category = "dev" optional = false python-versions = ">=3.8.0" files = [ @@ -453,7 +439,6 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"] name = "markdown-it-py" version = "3.0.0" description = "Python port of markdown-it. Markdown parsing, done right!" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -478,7 +463,6 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "mdurl" version = "0.1.2" description = "Markdown URL utilities" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -490,7 +474,6 @@ files = [ name = "multidict" version = "6.0.4" description = "multidict implementation" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -574,7 +557,6 @@ files = [ name = "mypy" version = "1.5.0" description = "Optional static typing for Python" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -615,7 +597,6 @@ reports = ["lxml"] name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -627,7 +608,6 @@ files = [ name = "openai" version = "0.27.8" description = "Python client library for the OpenAI API" -category = "main" optional = false python-versions = ">=3.7.1" files = [ @@ -642,7 +622,7 @@ tqdm = "*" [package.extras] datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] -dev = ["black (>=21.6b0,<22.0)", "pytest (>=6.0.0,<7.0.0)", "pytest-asyncio", "pytest-mock"] +dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-mock"] embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"] wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"] @@ -650,7 +630,6 @@ wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1 name = "packaging" version = "23.1" description = "Core utilities for Python packages" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -662,7 +641,6 @@ files = [ name = "pathspec" version = "0.11.2" description = "Utility library for gitignore style pattern matching of file paths." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -674,7 +652,6 @@ files = [ name = "platformdirs" version = "3.10.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -690,7 +667,6 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-co name = "pluggy" version = "1.2.0" description = "plugin and hook calling mechanisms for python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -706,7 +682,6 @@ testing = ["pytest", "pytest-benchmark"] name = "prompt-toolkit" version = "3.0.39" description = "Library for building powerful interactive command lines in Python" -category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -721,7 +696,6 @@ wcwidth = "*" name = "pydantic" version = "2.1.1" description = "Data validation using Python type hints" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -741,7 +715,6 @@ email = ["email-validator (>=2.0.0)"] name = "pydantic-core" version = "2.4.0" description = "" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -855,7 +828,6 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" name = "pydantic-settings" version = "2.0.2" description = "Settings management using Pydantic" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -871,7 +843,6 @@ python-dotenv = ">=0.21.0" name = "pydocstyle" version = "6.3.0" description = "Python docstring style checker" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -889,7 +860,6 @@ toml = ["tomli (>=1.2.3)"] name = "pygments" version = "2.16.1" description = "Pygments is a syntax highlighting package written in Python." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -904,7 +874,6 @@ plugins = ["importlib-metadata"] name = "pytest" version = "7.4.0" description = "pytest: simple powerful testing with Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -925,7 +894,6 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no name = "python-dotenv" version = "1.0.0" description = "Read key-value pairs from a .env file and set them as environment variables" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -940,7 +908,6 @@ cli = ["click (>=5.0)"] name = "requests" version = "2.31.0" description = "Python HTTP for Humans." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -962,7 +929,6 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "rich" version = "13.5.2" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -981,7 +947,6 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] name = "ruff" version = "0.0.284" description = "An extremely fast Python linter, written in Rust." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1008,7 +973,6 @@ files = [ name = "shellingham" version = "1.5.3" description = "Tool to Detect Surrounding Shell" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1020,7 +984,6 @@ files = [ name = "snowballstemmer" version = "2.2.0" description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." -category = "dev" optional = false python-versions = "*" files = [ @@ -1032,7 +995,6 @@ files = [ name = "tqdm" version = "4.66.1" description = "Fast, Extensible Progress Meter" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1053,7 +1015,6 @@ telegram = ["requests"] name = "typer" version = "0.9.0" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -1078,7 +1039,6 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6. name = "typing-extensions" version = "4.7.1" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1090,7 +1050,6 @@ files = [ name = "urllib3" version = "2.0.4" description = "HTTP library with thread-safe connection pooling, file post, and more." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1108,7 +1067,6 @@ zstd = ["zstandard (>=0.18.0)"] name = "wcwidth" version = "0.2.6" description = "Measures the displayed width of unicode strings in a terminal" -category = "main" optional = false python-versions = "*" files = [ @@ -1120,7 +1078,6 @@ files = [ name = "yarl" version = "1.9.2" description = "Yet another URL library" -category = "main" optional = false python-versions = ">=3.7" files = [ diff --git a/pyproject.toml b/pyproject.toml index f156cc3..2455a40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "llm-chat" -version = "0.5.0" +version = "0.6.0" description = "A general CLI interface for large language models." authors = ["Paul Harrison "] readme = "README.md" diff --git a/src/llm_chat/chat.py b/src/llm_chat/chat.py index 0f9313b..511536f 100644 --- a/src/llm_chat/chat.py +++ b/src/llm_chat/chat.py @@ -27,10 +27,12 @@ def save_conversation( """Store a conversation in the history directory.""" if conversation.prompt_tokens == 0: return + + filename = f"{dt.strftime('%Y%m%d%H%M%S')}{'-' + conversation.name if conversation.name else ''}" history_dir.mkdir(parents=True, exist_ok=True) - path = history_dir / f"{dt.strftime('%Y%m%d%H%M%S')}.json" + path = history_dir / f"{filename}.json" with path.open(mode="w") as f: f.write(conversation.model_dump_json(indent=2)) @@ -86,6 +88,7 @@ class Chat: self, settings: OpenAISettings | None = None, context: list[Message] = [], + name: str = "", initial_system_messages: bool = True, ) -> None: self._settings = settings @@ -95,6 +98,7 @@ class Chat: else context, model=self.settings.model, temperature=self.settings.temperature, + name=name, ) self._start_time = datetime.now(tz=ZoneInfo("UTC")) @@ -186,10 +190,10 @@ class Chat: def get_chat( - settings: OpenAISettings | None = None, context: list[Message] = [] + settings: OpenAISettings | None = None, context: list[Message] = [], name: str = "" ) -> ChatProtocol: """Get a chat object.""" - return Chat(settings=settings, context=context) + return Chat(settings=settings, context=context, name=name) def get_chat_class() -> Type[Chat]: diff --git a/src/llm_chat/cli.py b/src/llm_chat/cli.py index 4a8281b..b392c66 100644 --- a/src/llm_chat/cli.py +++ b/src/llm_chat/cli.py @@ -8,7 +8,7 @@ from rich.markdown import Markdown from llm_chat.chat import ChatProtocol, get_chat, get_chat_class from llm_chat.models import Message, Role -from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings +from llm_chat.settings import DEFAULT_HISTORY_DIR, DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings app = typer.Typer() @@ -71,6 +71,8 @@ def run_conversation(current_chat: ChatProtocol) -> None: console.print( f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}" ) + if current_chat.conversation.name: + console.print(f"[bold green]Name:[/bold green] {current_chat.conversation.name}") while not finished: prompt = read_user_input(session) @@ -123,17 +125,38 @@ def chat( readable=True, ), ] = [], + history_dir: Annotated[ + Path, + typer.Option( + ..., + "--history-dir", + "-d", + help="Path to the directory where conversation history will be saved.", + exists=True, + dir_okay=True, + file_okay=False, + ), + ] = DEFAULT_HISTORY_DIR, + name: Annotated[ + str, + typer.Option( + ..., + "--name", + "-n", + help="Name of the chat.", + ), + ] = "", ) -> None: """Start a chat session.""" # TODO: Add option to provide context string as an argument. if api_key is not None: - settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature) + settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature, history_dir=history_dir) else: - settings = OpenAISettings(model=model, temperature=temperature) + settings = OpenAISettings(model=model, temperature=temperature, history_dir=history_dir) context_messages = [load_context(path) for path in context] - current_chat = get_chat(settings=settings, context=context_messages) + current_chat = get_chat(settings=settings, context=context_messages, name=name) run_conversation(current_chat) diff --git a/src/llm_chat/models.py b/src/llm_chat/models.py index 94e6f57..d3262ca 100644 --- a/src/llm_chat/models.py +++ b/src/llm_chat/models.py @@ -31,6 +31,7 @@ class Conversation(BaseModel): messages: list[Message] model: Model temperature: float = DEFAULT_TEMPERATURE + name: str = "" completion_tokens: int = 0 prompt_tokens: int = 0 cost: float = 0.0 diff --git a/src/llm_chat/settings.py b/src/llm_chat/settings.py index 1e024c9..783c850 100644 --- a/src/llm_chat/settings.py +++ b/src/llm_chat/settings.py @@ -13,6 +13,7 @@ class Model(StrEnum): DEFAULT_MODEL = Model.GPT3 DEFAULT_TEMPERATURE = 0.7 +DEFAULT_HISTORY_DIR = Path().absolute() / ".history" class OpenAISettings(BaseSettings): @@ -21,7 +22,7 @@ class OpenAISettings(BaseSettings): api_key: str = "" model: Model = DEFAULT_MODEL temperature: float = DEFAULT_TEMPERATURE - history_dir: Path = Path().absolute() / ".history" + history_dir: Path = DEFAULT_HISTORY_DIR model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc] env_file=".env", diff --git a/tests/test_chat.py b/tests/test_chat.py index 0c06f61..6946ba9 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -10,7 +10,8 @@ from llm_chat.models import Conversation, Message, Role from llm_chat.settings import Model, OpenAISettings -def test_save_conversation(tmp_path: Path) -> None: +@pytest.mark.parametrize("name,expected_filename", [("", "20210101120000.json"), ("foo", "20210101120000-foo.json")]) +def test_save_conversation(name: str, expected_filename: str, tmp_path: Path) -> None: conversation = Conversation( messages=[ Message(role=Role.SYSTEM, content="Hello!"), @@ -22,10 +23,11 @@ def test_save_conversation(tmp_path: Path) -> None: completion_tokens=10, prompt_tokens=15, cost=0.000043, + name=name, ) path = tmp_path / ".history" - expected_file_path = path / "20210101120000.json" + expected_file_path = path / expected_filename dt = datetime(2021, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("UTC")) assert not path.exists() diff --git a/tests/test_cli.py b/tests/test_cli.py index 6c0521c..7b22a05 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,15 +1,19 @@ +from datetime import datetime from io import StringIO +from itertools import product from pathlib import Path from typing import Any, Type from unittest.mock import MagicMock +from zoneinfo import ZoneInfo import pytest +import time_machine from pytest import MonkeyPatch from rich.console import Console from typer.testing import CliRunner import llm_chat -from llm_chat.chat import ChatProtocol +from llm_chat.chat import ChatProtocol, save_conversation from llm_chat.cli import app from llm_chat.models import Conversation, Message, Role from llm_chat.settings import Model, OpenAISettings @@ -112,6 +116,32 @@ def test_chat_with_context( assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)] +@pytest.mark.parametrize("argument,name", list(product(("--name", "-n"), ("", "foo"))), ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))]) +def test_chat_with_name( + argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path +) -> None: + chat_fake = ChatFake() + output = StringIO() + console = Console(file=output) + + def mock_get_chat(**kwargs: Any) -> ChatProtocol: + chat_fake._set_args(**kwargs) + return chat_fake + + def mock_get_console() -> Console: + return console + + mock_read_user_input = MagicMock(side_effect=["Hello", "/q"]) + + monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat) + monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console) + monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input) + + result = runner.invoke(app, ["chat", argument, name]) + assert result.exit_code == 0 + assert chat_fake.args["name"] == name + + def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None: # Create a conversation object to save conversation = Conversation(