Fix test bug introduced with conversation naming

The `ChatFake` object used in the CLI tests never has a `conversation`
attribute defined. This caused multiple tests to fail because the CLI's
`run_conversation` function read the conversation name via
`current_chat.conversation.name`. This was resolved by adding a `name`
property to the `Chat` class (and to `ChatProtocol`), which can easily
be faked in tests.

Additionally, formatting was fixed.
Paul Harrison 2023-09-14 21:28:29 +01:00
parent 68fc11c450
commit 618423c0e8
4 changed files with 46 additions and 13 deletions
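
As context for the diff below, a condensed sketch of the pattern the commit
message describes: `Chat` exposes the conversation name through a property that
delegates to its `Conversation`, and a test fake can satisfy the same
`ChatProtocol` without owning a `Conversation` at all. The constructors here are
simplified for illustration, and this fake stores the name directly, whereas the
real `ChatFake` reads it from its `args` dict (see the last file in the diff).

    from typing import Protocol


    class ChatProtocol(Protocol):
        @property
        def name(self) -> str:
            """Get the name of the conversation."""
            ...


    class Chat:
        def __init__(self, conversation) -> None:  # the real __init__ takes more arguments
            self.conversation = conversation

        @property
        def name(self) -> str:
            # Delegate to the underlying Conversation, so callers no longer
            # reach through chat.conversation themselves.
            return self.conversation.name


    class ChatFake:
        """Test double: no Conversation object required."""

        def __init__(self, name: str = "") -> None:
            self._name = name

        @property
        def name(self) -> str:
            return self._name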

View File

@@ -27,8 +27,10 @@ def save_conversation(
    """Store a conversation in the history directory."""
    if conversation.prompt_tokens == 0:
        return
    filename = f"{dt.strftime('%Y%m%d%H%M%S')}{'-' + conversation.name if conversation.name else ''}"
    dt_str = dt.strftime("%Y%m%d%H%M%S")
    name_str = f"-{conversation.name}" if conversation.name else ""
    filename = f"{dt_str}{name_str}"
    history_dir.mkdir(parents=True, exist_ok=True)
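
The refactored filename logic in this hunk is easiest to read pulled out into a
small helper. The function below is an illustration of the same expressions only
(`build_filename` is hypothetical, and the `.json` extension the tests expect is
presumably appended where the file is actually written, outside this hunk):

    from datetime import datetime


    def build_filename(dt: datetime, name: str) -> str:
        # Same expressions as the new code above, extracted for illustration only.
        dt_str = dt.strftime("%Y%m%d%H%M%S")
        name_str = f"-{name}" if name else ""
        return f"{dt_str}{name_str}"


    assert build_filename(datetime(2021, 1, 1, 12, 0), "") == "20210101120000"
    assert build_filename(datetime(2021, 1, 1, 12, 0), "foo") == "20210101120000-foo"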
@@ -53,6 +55,10 @@ class ChatProtocol(Protocol):
    def cost(self) -> float:
        """Get the cost of the conversation."""

    @property
    def name(self) -> str:
        """Get the name of the conversation."""

    @property
    def settings(self) -> OpenAISettings:
        """Get OpenAI chat settings."""
@@ -137,6 +143,11 @@ class Chat:
        """Get the cost of the conversation."""
        return self.conversation.cost

    @property
    def name(self) -> str:
        """Get the name of the conversation."""
        return self.conversation.name

    def _make_request(self, message: str) -> dict[str, Any]:
        """Send a request to the OpenAI API.

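One reason the protocol change works for both the real class and the fakes:
`typing.Protocol` uses structural typing, and a read-only property declared on
the protocol can be satisfied by either a property or a plain attribute on any
class, with no inheritance involved. A minimal illustration (the class names
here are made up, not part of the project):

    from typing import Protocol


    class HasName(Protocol):
        @property
        def name(self) -> str: ...


    class WithProperty:
        @property
        def name(self) -> str:
            return "from a property"


    class WithAttribute:
        def __init__(self) -> None:
            self.name = "from an attribute"


    def show(obj: HasName) -> str:
        return obj.name


    # Both classes satisfy HasName structurally; a type checker accepts either call.
    print(show(WithProperty()), show(WithAttribute()))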
View File

@@ -8,7 +8,13 @@ from rich.markdown import Markdown

from llm_chat.chat import ChatProtocol, get_chat, get_chat_class
from llm_chat.models import Message, Role
from llm_chat.settings import DEFAULT_HISTORY_DIR, DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
from llm_chat.settings import (
    DEFAULT_HISTORY_DIR,
    DEFAULT_MODEL,
    DEFAULT_TEMPERATURE,
    Model,
    OpenAISettings,
)

app = typer.Typer()
@@ -71,8 +77,8 @@ def run_conversation(current_chat: ChatProtocol) -> None:
    console.print(
        f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}"
    )
    if current_chat.conversation.name:
        console.print(f"[bold green]Name:[/bold green] {current_chat.conversation.name}")
    if current_chat.name:
        console.print(f"[bold green]Name:[/bold green] {current_chat.name}")

    while not finished:
        prompt = read_user_input(session)
@@ -150,9 +156,16 @@ def chat(
    """Start a chat session."""
    # TODO: Add option to provide context string as an argument.
    if api_key is not None:
        settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature, history_dir=history_dir)
        settings = OpenAISettings(
            api_key=api_key,
            model=model,
            temperature=temperature,
            history_dir=history_dir,
        )
    else:
        settings = OpenAISettings(model=model, temperature=temperature, history_dir=history_dir)
        settings = OpenAISettings(
            model=model, temperature=temperature, history_dir=history_dir
        )

    context_messages = [load_context(path) for path in context]

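The if/else around `OpenAISettings` above likely exists because passing
`api_key=None` explicitly would override whatever default the settings class
derives when no key is supplied (for example, reading it from the environment).
That reading of `OpenAISettings` is an assumption; the sketch below shows the
general pattern with a stand-in class, not the real one:

    import os
    from dataclasses import dataclass, field


    @dataclass
    class FakeSettings:  # stand-in for illustration; not the project's OpenAISettings
        api_key: str = field(default_factory=lambda: os.environ.get("OPENAI_API_KEY", ""))
        model: str = "some-model"


    def make_settings(api_key: str | None) -> FakeSettings:
        if api_key is not None:
            return FakeSettings(api_key=api_key)  # explicit key wins
        return FakeSettings()  # fall back to the environment-derived default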
View File

@@ -10,7 +10,10 @@ from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings


@pytest.mark.parametrize("name,expected_filename", [("", "20210101120000.json"), ("foo", "20210101120000-foo.json")])
@pytest.mark.parametrize(
    "name,expected_filename",
    [("", "20210101120000.json"), ("foo", "20210101120000-foo.json")],
)
def test_save_conversation(name: str, expected_filename: str, tmp_path: Path) -> None:
    conversation = Conversation(
        messages=[

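The expected filenames in the new parametrization imply the test observes a
clock pinned at 2021-01-01 12:00:00; whether the real test freezes time or
passes `dt` into `save_conversation` directly is not visible in this hunk. A
hedged sketch of the time-freezing approach with `time_machine`, which appears
in the test_cli imports in this diff:

    from datetime import datetime, timezone

    import time_machine


    @time_machine.travel(datetime(2021, 1, 1, 12, 0, tzinfo=timezone.utc), tick=False)
    def test_clock_is_pinned() -> None:
        # With the clock frozen, any timestamp taken inside the test formats
        # to the prefix used in the expected filenames.
        now = datetime.now(tz=timezone.utc)
        assert now.strftime("%Y%m%d%H%M%S") == "20210101120000"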
View File

@@ -1,19 +1,16 @@
from datetime import datetime
from io import StringIO
from itertools import product
from pathlib import Path
from typing import Any, Type
from unittest.mock import MagicMock
from zoneinfo import ZoneInfo
import pytest
import time_machine
from pytest import MonkeyPatch
from rich.console import Console
from typer.testing import CliRunner
import llm_chat
from llm_chat.chat import ChatProtocol, save_conversation
from llm_chat.chat import ChatProtocol
from llm_chat.cli import app
from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings
@@ -45,6 +42,11 @@ class ChatFake:
        """Get the cost of the conversation."""
        return 0.0

    @property
    def name(self) -> str:
        """Get the name of the conversation."""
        return self.args.get("name", "")

    @classmethod
    def load(
        cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
@@ -116,7 +118,11 @@ def test_chat_with_context(
    assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]


@pytest.mark.parametrize("argument,name", list(product(("--name", "-n"), ("", "foo"))), ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))])
@pytest.mark.parametrize(
    "argument,name",
    list(product(("--name", "-n"), ("", "foo"))),
    ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))],
)
def test_chat_with_name(
    argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
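
For reference, the `product` call in the new decorator expands to four
argument/name combinations, so `test_chat_with_name` covers the long and short
option spellings each with an empty and a non-empty name, and the explicit
`ids` keep the empty-name cases legible in test output:

    from itertools import product

    params = list(product(("--name", "-n"), ("", "foo")))
    # [('--name', ''), ('--name', 'foo'), ('-n', ''), ('-n', 'foo')]

    ids = [f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))]
    # ['--name ', '--name foo', '-n ', '-n foo']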