Fix test bug introduced with conversation naming

The ChatFake object used in CLI tests never gets a `conversation`
attribute defined. This resulted in multiple tests failing since the
`save_conversation` function accessed the `name` attribute via
`current_chat.conversation.name`. This was resolved by adding a `name`
property to the `Chat` class, which can be easily faked in tests.

Additionally, formatting was fixed.
This commit is contained in:
Paul Harrison 2023-09-14 21:28:29 +01:00
parent 68fc11c450
commit 618423c0e8
4 changed files with 46 additions and 13 deletions

View File

@ -28,7 +28,9 @@ def save_conversation(
if conversation.prompt_tokens == 0: if conversation.prompt_tokens == 0:
return return
filename = f"{dt.strftime('%Y%m%d%H%M%S')}{'-' + conversation.name if conversation.name else ''}" dt_str = dt.strftime("%Y%m%d%H%M%S")
name_str = f"-{conversation.name}" if conversation.name else ""
filename = f"{dt_str}{name_str}"
history_dir.mkdir(parents=True, exist_ok=True) history_dir.mkdir(parents=True, exist_ok=True)
@ -53,6 +55,10 @@ class ChatProtocol(Protocol):
def cost(self) -> float: def cost(self) -> float:
"""Get the cost of the conversation.""" """Get the cost of the conversation."""
@property
def name(self) -> str:
"""Get the name of the conversation."""
@property @property
def settings(self) -> OpenAISettings: def settings(self) -> OpenAISettings:
"""Get OpenAI chat settings.""" """Get OpenAI chat settings."""
@ -137,6 +143,11 @@ class Chat:
"""Get the cost of the conversation.""" """Get the cost of the conversation."""
return self.conversation.cost return self.conversation.cost
@property
def name(self) -> str:
"""Get the name of the conversation."""
return self.conversation.name
def _make_request(self, message: str) -> dict[str, Any]: def _make_request(self, message: str) -> dict[str, Any]:
"""Send a request to the OpenAI API. """Send a request to the OpenAI API.

View File

@ -8,7 +8,13 @@ from rich.markdown import Markdown
from llm_chat.chat import ChatProtocol, get_chat, get_chat_class from llm_chat.chat import ChatProtocol, get_chat, get_chat_class
from llm_chat.models import Message, Role from llm_chat.models import Message, Role
from llm_chat.settings import DEFAULT_HISTORY_DIR, DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings from llm_chat.settings import (
DEFAULT_HISTORY_DIR,
DEFAULT_MODEL,
DEFAULT_TEMPERATURE,
Model,
OpenAISettings,
)
app = typer.Typer() app = typer.Typer()
@ -71,8 +77,8 @@ def run_conversation(current_chat: ChatProtocol) -> None:
console.print( console.print(
f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}" f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}"
) )
if current_chat.conversation.name: if current_chat.name:
console.print(f"[bold green]Name:[/bold green] {current_chat.conversation.name}") console.print(f"[bold green]Name:[/bold green] {current_chat.name}")
while not finished: while not finished:
prompt = read_user_input(session) prompt = read_user_input(session)
@ -150,9 +156,16 @@ def chat(
"""Start a chat session.""" """Start a chat session."""
# TODO: Add option to provide context string as an argument. # TODO: Add option to provide context string as an argument.
if api_key is not None: if api_key is not None:
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature, history_dir=history_dir) settings = OpenAISettings(
api_key=api_key,
model=model,
temperature=temperature,
history_dir=history_dir,
)
else: else:
settings = OpenAISettings(model=model, temperature=temperature, history_dir=history_dir) settings = OpenAISettings(
model=model, temperature=temperature, history_dir=history_dir
)
context_messages = [load_context(path) for path in context] context_messages = [load_context(path) for path in context]

View File

@ -10,7 +10,10 @@ from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings from llm_chat.settings import Model, OpenAISettings
@pytest.mark.parametrize("name,expected_filename", [("", "20210101120000.json"), ("foo", "20210101120000-foo.json")]) @pytest.mark.parametrize(
"name,expected_filename",
[("", "20210101120000.json"), ("foo", "20210101120000-foo.json")],
)
def test_save_conversation(name: str, expected_filename: str, tmp_path: Path) -> None: def test_save_conversation(name: str, expected_filename: str, tmp_path: Path) -> None:
conversation = Conversation( conversation = Conversation(
messages=[ messages=[

View File

@ -1,19 +1,16 @@
from datetime import datetime
from io import StringIO from io import StringIO
from itertools import product from itertools import product
from pathlib import Path from pathlib import Path
from typing import Any, Type from typing import Any, Type
from unittest.mock import MagicMock from unittest.mock import MagicMock
from zoneinfo import ZoneInfo
import pytest import pytest
import time_machine
from pytest import MonkeyPatch from pytest import MonkeyPatch
from rich.console import Console from rich.console import Console
from typer.testing import CliRunner from typer.testing import CliRunner
import llm_chat import llm_chat
from llm_chat.chat import ChatProtocol, save_conversation from llm_chat.chat import ChatProtocol
from llm_chat.cli import app from llm_chat.cli import app
from llm_chat.models import Conversation, Message, Role from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings from llm_chat.settings import Model, OpenAISettings
@ -45,6 +42,11 @@ class ChatFake:
"""Get the cost of the conversation.""" """Get the cost of the conversation."""
return 0.0 return 0.0
@property
def name(self) -> str:
"""Get the name of the conversation."""
return self.args.get("name", "")
@classmethod @classmethod
def load( def load(
cls, path: Path, api_key: str | None = None, history_dir: Path | None = None cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
@ -116,7 +118,11 @@ def test_chat_with_context(
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)] assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
@pytest.mark.parametrize("argument,name", list(product(("--name", "-n"), ("", "foo"))), ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))]) @pytest.mark.parametrize(
"argument,name",
list(product(("--name", "-n"), ("", "foo"))),
ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))],
)
def test_chat_with_name( def test_chat_with_name(
argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None: ) -> None: