From f5c7c40cc8b0f754b2cfad760b1ed5f21fc3bcc8 Mon Sep 17 00:00:00 2001
From: Paul Harrison
Date: Thu, 24 Aug 2023 18:35:27 +0100
Subject: [PATCH] Load previous conversation

This commit adds a new CLI command, `load`, which accepts the path to a
previously saved conversation from which previous messages are loaded.
The previous cost and token counts are not loaded, as these are deemed
functions of the chat session rather than the conversation, and are
stored with the conversation purely for future reference.

---
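Example round trip with the new API (a usage sketch; the file name is
illustrative, and OPENAI_API_KEY is assumed to be set in the
environment):

    from pathlib import Path

    from llm_chat.chat import Chat

    # Start a fresh chat and persist it to the history directory.
    chat = Chat()
    chat.send_message("Hello!")
    chat.save()

    # Later: restore the conversation. Model and temperature are read
    # from the saved file; token counts and cost start again from zero.
    restored = Chat.load(Path("conversation.json"))
    restored.send_message("Where were we?")

The same flow is exposed on the command line as `llm load
conversation.json`.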
 pyproject.toml       |  4 +--
 src/llm_chat/chat.py | 41 +++++++++++++++++++++--
 src/llm_chat/cli.py  | 80 ++++++++++++++++++++++++++++++++++----------
 tests/test_chat.py   | 36 ++++++++++++++++++++
 tests/test_cli.py    | 65 ++++++++++++++++++++++++++++++++---
 5 files changed, 200 insertions(+), 26 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 61598ed..f156cc3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "llm-chat"
-version = "0.4.0"
+version = "0.5.0"
 description = "A general CLI interface for large language models."
 authors = ["Paul Harrison "]
 readme = "README.md"
@@ -23,7 +23,7 @@ mypy = "^1.5.0"
 pydocstyle = "^6.3.0"
 
 [tool.poetry.scripts]
-chat = "llm_chat.cli:app"
+llm = "llm_chat.cli:app"
 
 [build-system]
 requires = ["poetry-core"]

diff --git a/src/llm_chat/chat.py b/src/llm_chat/chat.py
index 0d0d989..0f9313b 100644
--- a/src/llm_chat/chat.py
+++ b/src/llm_chat/chat.py
@@ -1,7 +1,7 @@
 from datetime import datetime
 from enum import StrEnum, auto
 from pathlib import Path
-from typing import Any, Protocol
+from typing import Any, Protocol, Type
 from zoneinfo import ZoneInfo
 
 from openai import ChatCompletion
@@ -45,10 +45,16 @@ class Token(StrEnum):
 class ChatProtocol(Protocol):
     """Protocol for chat classes."""
 
+    conversation: Conversation
+
     @property
     def cost(self) -> float:
         """Get the cost of the conversation."""
 
+    @property
+    def settings(self) -> OpenAISettings:
+        """Get OpenAI chat settings."""
+
     def save(self) -> None:
         """Save the conversation to the history directory."""
 
@@ -80,15 +86,41 @@ class Chat:
         self,
         settings: OpenAISettings | None = None,
         context: list[Message] = [],
+        initial_system_messages: bool = True,
     ) -> None:
         self._settings = settings
         self.conversation = Conversation(
-            messages=INITIAL_SYSTEM_MESSAGES + context,
+            messages=INITIAL_SYSTEM_MESSAGES + context
+            if initial_system_messages
+            else context,
             model=self.settings.model,
             temperature=self.settings.temperature,
         )
         self._start_time = datetime.now(tz=ZoneInfo("UTC"))
 
+    @classmethod
+    def load(
+        cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
+    ) -> ChatProtocol:
+        """Load a chat from a file."""
+        with path.open() as f:
+            conversation = Conversation.model_validate_json(f.read())
+        args: dict[str, Any] = {
+            "model": conversation.model,
+            "temperature": conversation.temperature,
+        }
+        if api_key is not None:
+            args["api_key"] = api_key
+        if history_dir is not None:
+            args["history_dir"] = history_dir
+
+        settings = OpenAISettings(**args)
+        return cls(
+            settings=settings,
+            context=conversation.messages,
+            initial_system_messages=False,
+        )
+
     @property
     def settings(self) -> OpenAISettings:
         """Get OpenAI chat settings."""
@@ -158,3 +190,8 @@ def get_chat(
 ) -> ChatProtocol:
     """Get a chat object."""
     return Chat(settings=settings, context=context)
+
+
+def get_chat_class() -> Type[Chat]:
+    """Get the chat class."""
+    return Chat

diff --git a/src/llm_chat/cli.py b/src/llm_chat/cli.py
index ae3bf5c..4a8281b 100644
--- a/src/llm_chat/cli.py
+++ b/src/llm_chat/cli.py
@@ -6,7 +6,7 @@ from prompt_toolkit import PromptSession
 from rich.console import Console
 from rich.markdown import Markdown
 
-from llm_chat.chat import ChatProtocol, get_chat
+from llm_chat.chat import ChatProtocol, get_chat, get_chat_class
 from llm_chat.models import Message, Role
 from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
 
@@ -60,7 +60,30 @@ def display_cost(console: Console, chat: ChatProtocol) -> None:
     console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n")
 
 
-@app.command()
+def run_conversation(current_chat: ChatProtocol) -> None:
+    """Run a conversation."""
+    console = get_console()
+    session = get_session()
+
+    finished = False
+
+    console.print(f"[bold green]Model:[/bold green] {current_chat.settings.model}")
+    console.print(
+        f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}"
+    )
+
+    while not finished:
+        prompt = read_user_input(session)
+        if prompt.strip() == "/q":
+            finished = True
+            current_chat.save()
+        else:
+            response = current_chat.send_message(prompt.strip())
+            console.print(Markdown(response))
+            display_cost(console, current_chat)
+
+
+@app.command("chat")
 def chat(
     api_key: Annotated[
         Optional[str],
@@ -102,7 +125,6 @@ def chat(
     ] = [],
 ) -> None:
     """Start a chat session."""
-    # TODO: Add option to load context from file.
     # TODO: Add option to provide context string as an argument.
     if api_key is not None:
         settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
@@ -112,20 +134,44 @@ def chat(
     context_messages = [load_context(path) for path in context]
 
     current_chat = get_chat(settings=settings, context=context_messages)
-    console = get_console()
-    session = get_session()
 
-    finished = False
+    run_conversation(current_chat)
 
-    console.print(f"[bold green]Model:[/bold green] {settings.model}")
-    console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}")
 
-    while not finished:
-        prompt = read_user_input(session)
-        if prompt.strip() == "/q":
-            finished = True
-            current_chat.save()
-        else:
-            response = current_chat.send_message(prompt.strip())
-            console.print(Markdown(response))
-            display_cost(console, current_chat)
+@app.command("load")
+def load(
+    path: Annotated[
+        Path,
+        typer.Argument(
+            help="Path to a conversation file.",
+            exists=True,
+            file_okay=True,
+            dir_okay=False,
+            readable=True,
+        ),
+    ],
+    api_key: Annotated[
+        Optional[str],
+        typer.Option(
+            ...,
+            "--api-key",
+            "-k",
+            help=(
+                "API key. Will read from the environment variable OPENAI_API_KEY "
+                "if not provided."
+            ),
+        ),
+    ] = None,
+) -> None:
+    """Load a conversation from a file."""
+    Chat = get_chat_class()
+    if api_key is not None:
+        current_chat = Chat.load(path, api_key=api_key)
+    else:
+        current_chat = Chat.load(path)
+
+    run_conversation(current_chat)
+
+
+if __name__ == "__main__":
+    app()

diff --git a/tests/test_chat.py b/tests/test_chat.py
index ed07726..0c06f61 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -46,6 +46,42 @@ def test_save_conversation(tmp_path: Path) -> None:
     assert conversation == conversation_from_file
 
 
+def test_load(tmp_path: Path) -> None:
+    # Create a conversation object to save
+    conversation = Conversation(
+        messages=[
+            Message(role=Role.SYSTEM, content="Hello!"),
+            Message(role=Role.USER, content="Hi!"),
+            Message(role=Role.ASSISTANT, content="How are you?"),
+        ],
+        model=Model.GPT3,
+        temperature=0.5,
+        completion_tokens=10,
+        prompt_tokens=15,
+        cost=0.000043,
+    )
+
+    # Save the conversation to a file
+    file_path = tmp_path / "conversation.json"
+    with file_path.open("w") as f:
+        f.write(conversation.model_dump_json())
+
+    # Load the conversation from the file
+    loaded_chat = Chat.load(file_path, api_key="foo", history_dir=tmp_path)
+
+    # Check that the loaded conversation matches the original conversation
+    assert loaded_chat.settings.model == conversation.model
+    assert loaded_chat.settings.temperature == conversation.temperature
+    assert loaded_chat.conversation.messages == conversation.messages
+    assert loaded_chat.settings.api_key == "foo"
+    assert loaded_chat.settings.history_dir == tmp_path
+
+    # We don't want to load the tokens or cost from the previous session
+    assert loaded_chat.conversation.completion_tokens == 0
+    assert loaded_chat.conversation.prompt_tokens == 0
+    assert loaded_chat.cost == 0
+
+
 def test_send_message() -> None:
     with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
         mock_make_request.return_value = {

diff --git a/tests/test_cli.py b/tests/test_cli.py
index d57098c..6c0521c 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,6 +1,6 @@
 from io import StringIO
 from pathlib import Path
-from typing import Any
+from typing import Any, Type
 from unittest.mock import MagicMock
 
 import pytest
@@ -11,7 +11,8 @@
 import llm_chat
 from llm_chat.chat import ChatProtocol
 from llm_chat.cli import app
-from llm_chat.models import Message, Role
+from llm_chat.models import Conversation, Message, Role
+from llm_chat.settings import Model, OpenAISettings
 
 runner = CliRunner()
 
@@ -20,9 +21,15 @@ class ChatFake:
     """Fake chat class for testing."""
 
     args: dict[str, Any]
+    conversation: Conversation
     received_messages: list[str]
+    settings: OpenAISettings
 
-    def __init__(self) -> None:
+    def __init__(self, settings: OpenAISettings | None = None) -> None:
+        if settings is not None:
+            self.settings = settings
+        else:
+            self.settings = OpenAISettings()
         self.args = {}
         self.received_messages = []
 
@@ -34,6 +41,13 @@ class ChatFake:
         """Get the cost of the conversation."""
         return 0.0
 
+    @classmethod
+    def load(
+        cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
+    ) -> ChatProtocol:
+        """Load a chat from a file."""
+        return cls()
+
     def save(self) -> None:
         """Dummy save method."""
         pass
@@ -61,7 +75,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
     monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
     monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
 
-    result = runner.invoke(app)
+    result = runner.invoke(app, ["chat"])
 
     assert result.exit_code == 0
     assert chat_fake.received_messages == ["Hello"]
@@ -91,8 +105,49 @@ def test_chat_with_context(
     monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
     monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
 
-    result = runner.invoke(app, [argument, str(tmp_file)])
+    result = runner.invoke(app, ["chat", argument, str(tmp_file)])
 
     assert result.exit_code == 0
     assert chat_fake.received_messages == ["Hello"]
     assert "context" in chat_fake.args
     assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
+
+
+def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
+    # Create a conversation object to save
+    conversation = Conversation(
+        messages=[
+            Message(role=Role.SYSTEM, content="Hello!"),
+            Message(role=Role.USER, content="Hi!"),
+            Message(role=Role.ASSISTANT, content="How are you?"),
+        ],
+        model=Model.GPT3,
+        temperature=0.5,
+        completion_tokens=10,
+        prompt_tokens=15,
+        cost=0.000043,
+    )
+
+    # Save the conversation to a file
+    file_path = tmp_path / "conversation.json"
+    with file_path.open("w") as f:
+        f.write(conversation.model_dump_json())
+
+    output = StringIO()
+    console = Console(file=output)
+
+    def mock_get_chat() -> Type[ChatFake]:
+        return ChatFake
+
+    def mock_get_console() -> Console:
+        return console
+
+    mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
+
+    monkeypatch.setattr(llm_chat.cli, "get_chat_class", mock_get_chat)
+    monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
+    monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
+
+    # Load the conversation from the file
+    result = runner.invoke(app, ["load", str(file_path)])
+
+    assert result.exit_code == 0
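
For reference, the file consumed by `llm load` is the JSON produced by
`Conversation.model_dump_json()`. Based on the fields exercised in the
tests above, a saved conversation looks roughly like this; the exact
enum serialisations (for example the model string for `Model.GPT3`) are
assumptions:

    {
      "messages": [
        {"role": "system", "content": "Hello!"},
        {"role": "user", "content": "Hi!"},
        {"role": "assistant", "content": "How are you?"}
      ],
      "model": "gpt-3.5-turbo",
      "temperature": 0.5,
      "completion_tokens": 10,
      "prompt_tokens": 15,
      "cost": 4.3e-05
    }

On load, `model` and `temperature` seed the new session's settings,
while `completion_tokens`, `prompt_tokens`, and `cost` are ignored and
start again from zero.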