diff --git a/pyproject.toml b/pyproject.toml index 61598ed..f156cc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "llm-chat" -version = "0.4.0" +version = "0.5.0" description = "A general CLI interface for large language models." authors = ["Paul Harrison "] readme = "README.md" @@ -23,7 +23,7 @@ mypy = "^1.5.0" pydocstyle = "^6.3.0" [tool.poetry.scripts] -chat = "llm_chat.cli:app" +llm = "llm_chat.cli:app" [build-system] requires = ["poetry-core"] diff --git a/src/llm_chat/chat.py b/src/llm_chat/chat.py index 0d0d989..0f9313b 100644 --- a/src/llm_chat/chat.py +++ b/src/llm_chat/chat.py @@ -1,7 +1,7 @@ from datetime import datetime from enum import StrEnum, auto from pathlib import Path -from typing import Any, Protocol +from typing import Any, Protocol, Type from zoneinfo import ZoneInfo from openai import ChatCompletion @@ -45,10 +45,16 @@ class Token(StrEnum): class ChatProtocol(Protocol): """Protocol for chat classes.""" + conversation: Conversation + @property def cost(self) -> float: """Get the cost of the conversation.""" + @property + def settings(self) -> OpenAISettings: + """Get OpenAI chat settings.""" + def save(self) -> None: """Save the conversation to the history directory.""" @@ -80,15 +86,41 @@ class Chat: self, settings: OpenAISettings | None = None, context: list[Message] = [], + initial_system_messages: bool = True, ) -> None: self._settings = settings self.conversation = Conversation( - messages=INITIAL_SYSTEM_MESSAGES + context, + messages=INITIAL_SYSTEM_MESSAGES + context + if initial_system_messages + else context, model=self.settings.model, temperature=self.settings.temperature, ) self._start_time = datetime.now(tz=ZoneInfo("UTC")) + @classmethod + def load( + cls, path: Path, api_key: str | None = None, history_dir: Path | None = None + ) -> ChatProtocol: + """Load a chat from a file.""" + with path.open() as f: + conversation = Conversation.model_validate_json(f.read()) + args: dict[str, Any] = { + "model": conversation.model, + "temperature": conversation.temperature, + } + if api_key is not None: + args["api_key"] = api_key + if history_dir is not None: + args["history_dir"] = history_dir + + settings = OpenAISettings(**args) + return cls( + settings=settings, + context=conversation.messages, + initial_system_messages=False, + ) + @property def settings(self) -> OpenAISettings: """Get OpenAI chat settings.""" @@ -158,3 +190,8 @@ def get_chat( ) -> ChatProtocol: """Get a chat object.""" return Chat(settings=settings, context=context) + + +def get_chat_class() -> Type[Chat]: + """Get the chat class.""" + return Chat diff --git a/src/llm_chat/cli.py b/src/llm_chat/cli.py index ae3bf5c..4a8281b 100644 --- a/src/llm_chat/cli.py +++ b/src/llm_chat/cli.py @@ -6,7 +6,7 @@ from prompt_toolkit import PromptSession from rich.console import Console from rich.markdown import Markdown -from llm_chat.chat import ChatProtocol, get_chat +from llm_chat.chat import ChatProtocol, get_chat, get_chat_class from llm_chat.models import Message, Role from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings @@ -60,7 +60,30 @@ def display_cost(console: Console, chat: ChatProtocol) -> None: console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n") -@app.command() +def run_conversation(current_chat: ChatProtocol) -> None: + """Run a conversation.""" + console = get_console() + session = get_session() + + finished = False + + console.print(f"[bold green]Model:[/bold green] {current_chat.settings.model}") + console.print( + f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}" + ) + + while not finished: + prompt = read_user_input(session) + if prompt.strip() == "/q": + finished = True + current_chat.save() + else: + response = current_chat.send_message(prompt.strip()) + console.print(Markdown(response)) + display_cost(console, current_chat) + + +@app.command("chat") def chat( api_key: Annotated[ Optional[str], @@ -102,7 +125,6 @@ def chat( ] = [], ) -> None: """Start a chat session.""" - # TODO: Add option to load context from file. # TODO: Add option to provide context string as an argument. if api_key is not None: settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature) @@ -112,20 +134,44 @@ def chat( context_messages = [load_context(path) for path in context] current_chat = get_chat(settings=settings, context=context_messages) - console = get_console() - session = get_session() - finished = False + run_conversation(current_chat) - console.print(f"[bold green]Model:[/bold green] {settings.model}") - console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}") - while not finished: - prompt = read_user_input(session) - if prompt.strip() == "/q": - finished = True - current_chat.save() - else: - response = current_chat.send_message(prompt.strip()) - console.print(Markdown(response)) - display_cost(console, current_chat) +@app.command("load") +def load( + path: Annotated[ + Path, + typer.Argument( + help="Path to a conversation file.", + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + ), + ], + api_key: Annotated[ + Optional[str], + typer.Option( + ..., + "--api-key", + "-k", + help=( + "API key. Will read from the environment variable OPENAI_API_KEY " + "if not provided." + ), + ), + ] = None, +) -> None: + """Load a conversation from a file.""" + Chat = get_chat_class() + if api_key is not None: + current_chat = Chat.load(path, api_key=api_key) + else: + current_chat = Chat.load(path) + + run_conversation(current_chat) + + +if __name__ == "__main__": + app() diff --git a/tests/test_chat.py b/tests/test_chat.py index ed07726..0c06f61 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -46,6 +46,42 @@ def test_save_conversation(tmp_path: Path) -> None: assert conversation == conversation_from_file +def test_load(tmp_path: Path) -> None: + # Create a conversation object to save + conversation = Conversation( + messages=[ + Message(role=Role.SYSTEM, content="Hello!"), + Message(role=Role.USER, content="Hi!"), + Message(role=Role.ASSISTANT, content="How are you?"), + ], + model=Model.GPT3, + temperature=0.5, + completion_tokens=10, + prompt_tokens=15, + cost=0.000043, + ) + + # Save the conversation to a file + file_path = tmp_path / "conversation.json" + with file_path.open("w") as f: + f.write(conversation.model_dump_json()) + + # Load the conversation from the file + loaded_chat = Chat.load(file_path, api_key="foo", history_dir=tmp_path) + + # Check that the loaded conversation matches the original conversation + assert loaded_chat.settings.model == conversation.model + assert loaded_chat.settings.temperature == conversation.temperature + assert loaded_chat.conversation.messages == conversation.messages + assert loaded_chat.settings.api_key == "foo" + assert loaded_chat.settings.history_dir == tmp_path + + # We don't want to load the tokens or cost from the previous session + assert loaded_chat.conversation.completion_tokens == 0 + assert loaded_chat.conversation.prompt_tokens == 0 + assert loaded_chat.cost == 0 + + def test_send_message() -> None: with patch("llm_chat.chat.Chat._make_request") as mock_make_request: mock_make_request.return_value = { diff --git a/tests/test_cli.py b/tests/test_cli.py index d57098c..6c0521c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,6 @@ from io import StringIO from pathlib import Path -from typing import Any +from typing import Any, Type from unittest.mock import MagicMock import pytest @@ -11,7 +11,8 @@ from typer.testing import CliRunner import llm_chat from llm_chat.chat import ChatProtocol from llm_chat.cli import app -from llm_chat.models import Message, Role +from llm_chat.models import Conversation, Message, Role +from llm_chat.settings import Model, OpenAISettings runner = CliRunner() @@ -20,9 +21,15 @@ class ChatFake: """Fake chat class for testing.""" args: dict[str, Any] + conversation: Conversation received_messages: list[str] + settings: OpenAISettings - def __init__(self) -> None: + def __init__(self, settings: OpenAISettings | None = None) -> None: + if settings is not None: + self.settings = settings + else: + self.settings = OpenAISettings() self.args = {} self.received_messages = [] @@ -34,6 +41,13 @@ class ChatFake: """Get the cost of the conversation.""" return 0.0 + @classmethod + def load( + cls, path: Path, api_key: str | None = None, history_dir: Path | None = None + ) -> ChatProtocol: + """Load a chat from a file.""" + return cls() + def save(self) -> None: """Dummy save method.""" pass @@ -61,7 +75,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None: monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console) monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input) - result = runner.invoke(app) + result = runner.invoke(app, ["chat"]) assert result.exit_code == 0 assert chat_fake.received_messages == ["Hello"] @@ -91,8 +105,49 @@ def test_chat_with_context( monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console) monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input) - result = runner.invoke(app, [argument, str(tmp_file)]) + result = runner.invoke(app, ["chat", argument, str(tmp_file)]) assert result.exit_code == 0 assert chat_fake.received_messages == ["Hello"] assert "context" in chat_fake.args assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)] + + +def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None: + # Create a conversation object to save + conversation = Conversation( + messages=[ + Message(role=Role.SYSTEM, content="Hello!"), + Message(role=Role.USER, content="Hi!"), + Message(role=Role.ASSISTANT, content="How are you?"), + ], + model=Model.GPT3, + temperature=0.5, + completion_tokens=10, + prompt_tokens=15, + cost=0.000043, + ) + + # Save the conversation to a file + file_path = tmp_path / "conversation.json" + with file_path.open("w") as f: + f.write(conversation.model_dump_json()) + + output = StringIO() + console = Console(file=output) + + def mock_get_chat() -> Type[ChatFake]: + return ChatFake + + def mock_get_console() -> Console: + return console + + mock_read_user_input = MagicMock(side_effect=["Hello", "/q"]) + + monkeypatch.setattr(llm_chat.cli, "get_chat_class", mock_get_chat) + monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console) + monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input) + + # Load the conversation from the file + result = runner.invoke(app, ["load", str(file_path)]) + + assert result.exit_code == 0