diff --git a/src/llm_chat/chat.py b/src/llm_chat/chat.py index 0d0d989..f63a004 100644 --- a/src/llm_chat/chat.py +++ b/src/llm_chat/chat.py @@ -80,15 +80,41 @@ class Chat: self, settings: OpenAISettings | None = None, context: list[Message] = [], + initial_system_messages: bool = True, ) -> None: self._settings = settings self.conversation = Conversation( - messages=INITIAL_SYSTEM_MESSAGES + context, + messages=INITIAL_SYSTEM_MESSAGES + context + if initial_system_messages + else context, model=self.settings.model, temperature=self.settings.temperature, ) self._start_time = datetime.now(tz=ZoneInfo("UTC")) + @classmethod + def load( + cls, path: Path, api_key: str | None = None, history_dir: Path | None = None + ) -> ChatProtocol: + """Load a chat from a file.""" + with path.open() as f: + conversation = Conversation.model_validate_json(f.read()) + args = { + "model": conversation.model, + "temperature": conversation.temperature, + } + if api_key is not None: + args["api_key"] = api_key + if history_dir is not None: + args["history_dir"] = history_dir + + settings = OpenAISettings(**args) + return cls( + settings=settings, + context=conversation.messages, + initial_system_messages=False, + ) + @property def settings(self) -> OpenAISettings: """Get OpenAI chat settings.""" diff --git a/src/llm_chat/cli.py b/src/llm_chat/cli.py index ae3bf5c..c42ff7f 100644 --- a/src/llm_chat/cli.py +++ b/src/llm_chat/cli.py @@ -6,7 +6,7 @@ from prompt_toolkit import PromptSession from rich.console import Console from rich.markdown import Markdown -from llm_chat.chat import ChatProtocol, get_chat +from llm_chat.chat import Chat, ChatProtocol, get_chat from llm_chat.models import Message, Role from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings @@ -60,7 +60,28 @@ def display_cost(console: Console, chat: ChatProtocol) -> None: console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n") -@app.command() +def run_conversation(current_chat: ChatProtocol) -> None: + """Run a conversation.""" + console = get_console() + session = get_session() + + finished = False + + console.print(f"[bold green]Model:[/bold green] {current_chat.settings.model}") + console.print(f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}") + + while not finished: + prompt = read_user_input(session) + if prompt.strip() == "/q": + finished = True + current_chat.save() + else: + response = current_chat.send_message(prompt.strip()) + console.print(Markdown(response)) + display_cost(console, current_chat) + + +@app.command("chat") def chat( api_key: Annotated[ Optional[str], @@ -102,7 +123,6 @@ def chat( ] = [], ) -> None: """Start a chat session.""" - # TODO: Add option to load context from file. # TODO: Add option to provide context string as an argument. if api_key is not None: settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature) @@ -112,20 +132,43 @@ def chat( context_messages = [load_context(path) for path in context] current_chat = get_chat(settings=settings, context=context_messages) - console = get_console() - session = get_session() - finished = False + run_conversation(current_chat) - console.print(f"[bold green]Model:[/bold green] {settings.model}") - console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}") - while not finished: - prompt = read_user_input(session) - if prompt.strip() == "/q": - finished = True - current_chat.save() - else: - response = current_chat.send_message(prompt.strip()) - console.print(Markdown(response)) - display_cost(console, current_chat) +@app.command("load") +def load( + path: Annotated[ + Path, + typer.Argument( + help="Path to a conversation file.", + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + ), + ], + api_key: Annotated[ + Optional[str], + typer.Option( + ..., + "--api-key", + "-k", + help=( + "API key. Will read from the environment variable OPENAI_API_KEY " + "if not provided." + ), + ), + ] = None, +) -> None: + """Load a conversation from a file.""" + if api_key is not None: + current_chat = Chat.load(path, api_key=api_key) + else: + current_chat = Chat.load(path) + + run_conversation(current_chat) + + +if __name__ == "__main__": + app() diff --git a/tests/test_chat.py b/tests/test_chat.py index ed07726..0c06f61 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -46,6 +46,42 @@ def test_save_conversation(tmp_path: Path) -> None: assert conversation == conversation_from_file +def test_load(tmp_path: Path) -> None: + # Create a conversation object to save + conversation = Conversation( + messages=[ + Message(role=Role.SYSTEM, content="Hello!"), + Message(role=Role.USER, content="Hi!"), + Message(role=Role.ASSISTANT, content="How are you?"), + ], + model=Model.GPT3, + temperature=0.5, + completion_tokens=10, + prompt_tokens=15, + cost=0.000043, + ) + + # Save the conversation to a file + file_path = tmp_path / "conversation.json" + with file_path.open("w") as f: + f.write(conversation.model_dump_json()) + + # Load the conversation from the file + loaded_chat = Chat.load(file_path, api_key="foo", history_dir=tmp_path) + + # Check that the loaded conversation matches the original conversation + assert loaded_chat.settings.model == conversation.model + assert loaded_chat.settings.temperature == conversation.temperature + assert loaded_chat.conversation.messages == conversation.messages + assert loaded_chat.settings.api_key == "foo" + assert loaded_chat.settings.history_dir == tmp_path + + # We don't want to load the tokens or cost from the previous session + assert loaded_chat.conversation.completion_tokens == 0 + assert loaded_chat.conversation.prompt_tokens == 0 + assert loaded_chat.cost == 0 + + def test_send_message() -> None: with patch("llm_chat.chat.Chat._make_request") as mock_make_request: mock_make_request.return_value = { diff --git a/tests/test_cli.py b/tests/test_cli.py index d57098c..11644da 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,6 +12,7 @@ import llm_chat from llm_chat.chat import ChatProtocol from llm_chat.cli import app from llm_chat.models import Message, Role +from llm_chat.settings import OpenAISettings runner = CliRunner() @@ -21,8 +22,13 @@ class ChatFake: args: dict[str, Any] received_messages: list[str] + settings: OpenAISettings - def __init__(self) -> None: + def __init__(self, settings: OpenAISettings | None = None) -> None: + if settings is not None: + self.settings = settings + else: + self.settings = OpenAISettings() self.args = {} self.received_messages = [] @@ -33,6 +39,13 @@ class ChatFake: def cost(self) -> float: """Get the cost of the conversation.""" return 0.0 + + @classmethod + def load( + cls, path: Path, api_key: str | None = None, history_dir: Path | None = None + ) -> ChatProtocol: + """Load a chat from a file.""" + return cls() def save(self) -> None: """Dummy save method.""" @@ -61,7 +74,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None: monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console) monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input) - result = runner.invoke(app) + result = runner.invoke(app, ["chat"]) assert result.exit_code == 0 assert chat_fake.received_messages == ["Hello"]