Fix test bug introduced with conversation naming
The ChatFake object used in CLI tests never gets a `conversation` attribute defined. This resulted in multiple tests failing since the `save_conversation` function accessed the `name` attribute via `current_chat.conversation.name`. This was resolved by adding a `name` property to the `Chat` class, which can be easily faked in tests. Additionally, formatting was fixed.
This commit is contained in:
parent
68fc11c450
commit
618423c0e8
|
@ -28,7 +28,9 @@ def save_conversation(
|
||||||
if conversation.prompt_tokens == 0:
|
if conversation.prompt_tokens == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
filename = f"{dt.strftime('%Y%m%d%H%M%S')}{'-' + conversation.name if conversation.name else ''}"
|
dt_str = dt.strftime("%Y%m%d%H%M%S")
|
||||||
|
name_str = f"-{conversation.name}" if conversation.name else ""
|
||||||
|
filename = f"{dt_str}{name_str}"
|
||||||
|
|
||||||
history_dir.mkdir(parents=True, exist_ok=True)
|
history_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
@ -53,6 +55,10 @@ class ChatProtocol(Protocol):
|
||||||
def cost(self) -> float:
|
def cost(self) -> float:
|
||||||
"""Get the cost of the conversation."""
|
"""Get the cost of the conversation."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Get the name of the conversation."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def settings(self) -> OpenAISettings:
|
def settings(self) -> OpenAISettings:
|
||||||
"""Get OpenAI chat settings."""
|
"""Get OpenAI chat settings."""
|
||||||
|
@ -137,6 +143,11 @@ class Chat:
|
||||||
"""Get the cost of the conversation."""
|
"""Get the cost of the conversation."""
|
||||||
return self.conversation.cost
|
return self.conversation.cost
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Get the name of the conversation."""
|
||||||
|
return self.conversation.name
|
||||||
|
|
||||||
def _make_request(self, message: str) -> dict[str, Any]:
|
def _make_request(self, message: str) -> dict[str, Any]:
|
||||||
"""Send a request to the OpenAI API.
|
"""Send a request to the OpenAI API.
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,13 @@ from rich.markdown import Markdown
|
||||||
|
|
||||||
from llm_chat.chat import ChatProtocol, get_chat, get_chat_class
|
from llm_chat.chat import ChatProtocol, get_chat, get_chat_class
|
||||||
from llm_chat.models import Message, Role
|
from llm_chat.models import Message, Role
|
||||||
from llm_chat.settings import DEFAULT_HISTORY_DIR, DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
from llm_chat.settings import (
|
||||||
|
DEFAULT_HISTORY_DIR,
|
||||||
|
DEFAULT_MODEL,
|
||||||
|
DEFAULT_TEMPERATURE,
|
||||||
|
Model,
|
||||||
|
OpenAISettings,
|
||||||
|
)
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
@ -71,8 +77,8 @@ def run_conversation(current_chat: ChatProtocol) -> None:
|
||||||
console.print(
|
console.print(
|
||||||
f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}"
|
f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}"
|
||||||
)
|
)
|
||||||
if current_chat.conversation.name:
|
if current_chat.name:
|
||||||
console.print(f"[bold green]Name:[/bold green] {current_chat.conversation.name}")
|
console.print(f"[bold green]Name:[/bold green] {current_chat.name}")
|
||||||
|
|
||||||
while not finished:
|
while not finished:
|
||||||
prompt = read_user_input(session)
|
prompt = read_user_input(session)
|
||||||
|
@ -150,9 +156,16 @@ def chat(
|
||||||
"""Start a chat session."""
|
"""Start a chat session."""
|
||||||
# TODO: Add option to provide context string as an argument.
|
# TODO: Add option to provide context string as an argument.
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature, history_dir=history_dir)
|
settings = OpenAISettings(
|
||||||
|
api_key=api_key,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
history_dir=history_dir,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
settings = OpenAISettings(model=model, temperature=temperature, history_dir=history_dir)
|
settings = OpenAISettings(
|
||||||
|
model=model, temperature=temperature, history_dir=history_dir
|
||||||
|
)
|
||||||
|
|
||||||
context_messages = [load_context(path) for path in context]
|
context_messages = [load_context(path) for path in context]
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,10 @@ from llm_chat.models import Conversation, Message, Role
|
||||||
from llm_chat.settings import Model, OpenAISettings
|
from llm_chat.settings import Model, OpenAISettings
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name,expected_filename", [("", "20210101120000.json"), ("foo", "20210101120000-foo.json")])
|
@pytest.mark.parametrize(
|
||||||
|
"name,expected_filename",
|
||||||
|
[("", "20210101120000.json"), ("foo", "20210101120000-foo.json")],
|
||||||
|
)
|
||||||
def test_save_conversation(name: str, expected_filename: str, tmp_path: Path) -> None:
|
def test_save_conversation(name: str, expected_filename: str, tmp_path: Path) -> None:
|
||||||
conversation = Conversation(
|
conversation = Conversation(
|
||||||
messages=[
|
messages=[
|
||||||
|
|
|
@ -1,19 +1,16 @@
|
||||||
from datetime import datetime
|
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Type
|
from typing import Any, Type
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
from zoneinfo import ZoneInfo
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import time_machine
|
|
||||||
from pytest import MonkeyPatch
|
from pytest import MonkeyPatch
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from typer.testing import CliRunner
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
import llm_chat
|
import llm_chat
|
||||||
from llm_chat.chat import ChatProtocol, save_conversation
|
from llm_chat.chat import ChatProtocol
|
||||||
from llm_chat.cli import app
|
from llm_chat.cli import app
|
||||||
from llm_chat.models import Conversation, Message, Role
|
from llm_chat.models import Conversation, Message, Role
|
||||||
from llm_chat.settings import Model, OpenAISettings
|
from llm_chat.settings import Model, OpenAISettings
|
||||||
|
@ -45,6 +42,11 @@ class ChatFake:
|
||||||
"""Get the cost of the conversation."""
|
"""Get the cost of the conversation."""
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Get the name of the conversation."""
|
||||||
|
return self.args.get("name", "")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(
|
||||||
cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
|
cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
|
||||||
|
@ -116,7 +118,11 @@ def test_chat_with_context(
|
||||||
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
|
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("argument,name", list(product(("--name", "-n"), ("", "foo"))), ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))])
|
@pytest.mark.parametrize(
|
||||||
|
"argument,name",
|
||||||
|
list(product(("--name", "-n"), ("", "foo"))),
|
||||||
|
ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))],
|
||||||
|
)
|
||||||
def test_chat_with_name(
|
def test_chat_with_name(
|
||||||
argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path
|
argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
Loading…
Reference in New Issue