From 618423c0e8f93a93690488e4e8a37205e24378df Mon Sep 17 00:00:00 2001 From: Paul Harrison Date: Thu, 14 Sep 2023 21:28:29 +0100 Subject: [PATCH] Fix test bug introduced with conversation naming The ChatFake object used in CLI tests never gets a `conversation` attribute defined. This resulted in multiple tests failing since the `save_conversation` function accessed the `name` attribute via `current_chat.conversation.name`. This was resolved by adding a `name` property to the `Chat` class, which can be easily faked in tests. Additionally, formatting was fixed. --- src/llm_chat/chat.py | 15 +++++++++++++-- src/llm_chat/cli.py | 23 ++++++++++++++++++----- tests/test_chat.py | 5 ++++- tests/test_cli.py | 16 +++++++++++----- 4 files changed, 46 insertions(+), 13 deletions(-) diff --git a/src/llm_chat/chat.py b/src/llm_chat/chat.py index 511536f..60c7c18 100644 --- a/src/llm_chat/chat.py +++ b/src/llm_chat/chat.py @@ -27,8 +27,10 @@ def save_conversation( """Store a conversation in the history directory.""" if conversation.prompt_tokens == 0: return - - filename = f"{dt.strftime('%Y%m%d%H%M%S')}{'-' + conversation.name if conversation.name else ''}" + + dt_str = dt.strftime("%Y%m%d%H%M%S") + name_str = f"-{conversation.name}" if conversation.name else "" + filename = f"{dt_str}{name_str}" history_dir.mkdir(parents=True, exist_ok=True) @@ -53,6 +55,10 @@ class ChatProtocol(Protocol): def cost(self) -> float: """Get the cost of the conversation.""" + @property + def name(self) -> str: + """Get the name of the conversation.""" + @property def settings(self) -> OpenAISettings: """Get OpenAI chat settings.""" @@ -137,6 +143,11 @@ class Chat: """Get the cost of the conversation.""" return self.conversation.cost + @property + def name(self) -> str: + """Get the name of the conversation.""" + return self.conversation.name + def _make_request(self, message: str) -> dict[str, Any]: """Send a request to the OpenAI API. diff --git a/src/llm_chat/cli.py b/src/llm_chat/cli.py index b392c66..3f9662c 100644 --- a/src/llm_chat/cli.py +++ b/src/llm_chat/cli.py @@ -8,7 +8,13 @@ from rich.markdown import Markdown from llm_chat.chat import ChatProtocol, get_chat, get_chat_class from llm_chat.models import Message, Role -from llm_chat.settings import DEFAULT_HISTORY_DIR, DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings +from llm_chat.settings import ( + DEFAULT_HISTORY_DIR, + DEFAULT_MODEL, + DEFAULT_TEMPERATURE, + Model, + OpenAISettings, +) app = typer.Typer() @@ -71,8 +77,8 @@ def run_conversation(current_chat: ChatProtocol) -> None: console.print( f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}" ) - if current_chat.conversation.name: - console.print(f"[bold green]Name:[/bold green] {current_chat.conversation.name}") + if current_chat.name: + console.print(f"[bold green]Name:[/bold green] {current_chat.name}") while not finished: prompt = read_user_input(session) @@ -150,9 +156,16 @@ def chat( """Start a chat session.""" # TODO: Add option to provide context string as an argument. if api_key is not None: - settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature, history_dir=history_dir) + settings = OpenAISettings( + api_key=api_key, + model=model, + temperature=temperature, + history_dir=history_dir, + ) else: - settings = OpenAISettings(model=model, temperature=temperature, history_dir=history_dir) + settings = OpenAISettings( + model=model, temperature=temperature, history_dir=history_dir + ) context_messages = [load_context(path) for path in context] diff --git a/tests/test_chat.py b/tests/test_chat.py index 6946ba9..3ba5603 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -10,7 +10,10 @@ from llm_chat.models import Conversation, Message, Role from llm_chat.settings import Model, OpenAISettings -@pytest.mark.parametrize("name,expected_filename", [("", "20210101120000.json"), ("foo", "20210101120000-foo.json")]) +@pytest.mark.parametrize( + "name,expected_filename", + [("", "20210101120000.json"), ("foo", "20210101120000-foo.json")], +) def test_save_conversation(name: str, expected_filename: str, tmp_path: Path) -> None: conversation = Conversation( messages=[ diff --git a/tests/test_cli.py b/tests/test_cli.py index 7b22a05..c039be5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,19 +1,16 @@ -from datetime import datetime from io import StringIO from itertools import product from pathlib import Path from typing import Any, Type from unittest.mock import MagicMock -from zoneinfo import ZoneInfo import pytest -import time_machine from pytest import MonkeyPatch from rich.console import Console from typer.testing import CliRunner import llm_chat -from llm_chat.chat import ChatProtocol, save_conversation +from llm_chat.chat import ChatProtocol from llm_chat.cli import app from llm_chat.models import Conversation, Message, Role from llm_chat.settings import Model, OpenAISettings @@ -45,6 +42,11 @@ class ChatFake: """Get the cost of the conversation.""" return 0.0 + @property + def name(self) -> str: + """Get the name of the conversation.""" + return self.args.get("name", "") + @classmethod def load( cls, path: Path, api_key: str | None = None, history_dir: Path | None = None @@ -116,7 +118,11 @@ def test_chat_with_context( assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)] -@pytest.mark.parametrize("argument,name", list(product(("--name", "-n"), ("", "foo"))), ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))]) +@pytest.mark.parametrize( + "argument,name", + list(product(("--name", "-n"), ("", "foo"))), + ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))], +) def test_chat_with_name( argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path ) -> None: