llm-chat/tests/test_cli.py

184 lines
5.5 KiB
Python
Raw Normal View History

2023-09-14 16:58:40 +00:00
from datetime import datetime
from io import StringIO
2023-09-14 16:58:40 +00:00
from itertools import product
from pathlib import Path
from typing import Any, Type
from unittest.mock import MagicMock
2023-09-14 16:58:40 +00:00
from zoneinfo import ZoneInfo
import pytest
2023-09-14 16:58:40 +00:00
import time_machine
from pytest import MonkeyPatch
from rich.console import Console
from typer.testing import CliRunner
import llm_chat
2023-09-14 16:58:40 +00:00
from llm_chat.chat import ChatProtocol, save_conversation
from llm_chat.cli import app
from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings
# A single Typer CliRunner shared by every test below; each test calls
# ``runner.invoke(app, [...])`` to drive the CLI and inspect the result.
runner: CliRunner = CliRunner()
class ChatFake:
    """Stand-in chat object used by the CLI tests.

    It never talks to a model: every message passed to ``send_message`` is
    recorded and returned verbatim, and keyword arguments handed to
    ``_set_args`` are kept in ``args`` so tests can inspect them.
    """

    # Keyword arguments recorded by ``_set_args`` (empty dict until called).
    args: dict[str, Any]
    # Annotated only; this fake never assigns it.
    conversation: Conversation
    # Every message given to ``send_message``, in call order.
    received_messages: list[str]
    # Settings passed to the constructor, or a default ``OpenAISettings()``.
    settings: OpenAISettings

    def __init__(self, settings: OpenAISettings | None = None) -> None:
        self.settings = OpenAISettings() if settings is None else settings
        self.args = {}
        self.received_messages = []

    def _set_args(self, **kwargs: Any) -> None:
        """Store ``kwargs`` as ``args``, replacing anything stored before."""
        self.args = kwargs

    @property
    def cost(self) -> float:
        """Return the conversation cost, which is always zero here."""
        return 0.0

    @classmethod
    def load(
        cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
    ) -> ChatProtocol:
        """Ignore all arguments and return a brand-new default fake."""
        return cls()

    def save(self) -> None:
        """Do nothing; the fake has nothing to persist."""

    def send_message(self, message: str) -> str:
        """Record ``message`` and hand it straight back as the reply."""
        self.received_messages += [message]
        return message
def test_chat(monkeypatch: MonkeyPatch) -> None:
    """A bare ``chat`` session forwards typed input to the chat backend.

    User input is scripted as "Hello" followed by "/q" (which the CLI is
    expected to treat as a command, not a message). The command must exit
    cleanly, and only "Hello" may reach the fake chat.
    """
    fake = ChatFake()
    # Console output goes to an in-memory buffer rather than stdout.
    captured_console = Console(file=StringIO())
    scripted_input = MagicMock(side_effect=["Hello", "/q"])

    patches = {
        "get_chat": lambda **_: fake,
        "get_console": lambda: captured_console,
        "read_user_input": scripted_input,
    }
    for attr, replacement in patches.items():
        monkeypatch.setattr(llm_chat.cli, attr, replacement)

    result = runner.invoke(app, ["chat"])

    assert result.exit_code == 0
    assert fake.received_messages == ["Hello"]
@pytest.mark.parametrize("argument", ["--context", "-c"], ids=["--context", "-c"])
def test_chat_with_context(
    argument: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
    """A context file passed to ``chat`` becomes a SYSTEM message.

    The test exercises both the long and short option spellings. It checks
    that ``get_chat`` received a ``context`` argument holding exactly one
    SYSTEM message with the file's text, and that ordinary input still
    reaches the chat.
    """
    context_text = "Hello, world!"
    context_path = tmp_path / "context.txt"
    context_path.write_text(context_text)

    fake = ChatFake()
    quiet_console = Console(file=StringIO())

    def fake_get_chat(**kwargs: Any) -> ChatProtocol:
        # Keep the factory's keyword arguments so they can be asserted on.
        fake._set_args(**kwargs)
        return fake

    scripted_input = MagicMock(side_effect=["Hello", "/q"])
    monkeypatch.setattr(llm_chat.cli, "get_chat", fake_get_chat)
    monkeypatch.setattr(llm_chat.cli, "get_console", lambda: quiet_console)
    monkeypatch.setattr(llm_chat.cli, "read_user_input", scripted_input)

    result = runner.invoke(app, ["chat", argument, str(context_path)])

    assert result.exit_code == 0
    assert fake.received_messages == ["Hello"]
    assert "context" in fake.args
    expected_context = [Message(role=Role.SYSTEM, content=context_text)]
    assert fake.args["context"] == expected_context
2023-09-14 16:58:40 +00:00
# Every pairing of option spelling with a name value (including an empty
# name). Built once so the parameter values and their ids cannot drift apart.
_NAME_CASES = list(product(("--name", "-n"), ("", "foo")))


@pytest.mark.parametrize(
    "argument,name",
    _NAME_CASES,
    ids=[f"{arg} {name}" for arg, name in _NAME_CASES],
)
def test_chat_with_name(
    argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
    """``chat --name``/``-n`` forwards the given name to ``get_chat``.

    The test also checks that scripted user input still reaches the chat,
    mirroring the other ``chat`` tests.
    """
    # NOTE(review): tmp_path is not used here; it is kept so the test
    # signature is unchanged.
    chat_fake = ChatFake()
    output = StringIO()
    console = Console(file=output)

    def mock_get_chat(**kwargs: Any) -> ChatProtocol:
        # Record the factory arguments so the name can be asserted on.
        chat_fake._set_args(**kwargs)
        return chat_fake

    def mock_get_console() -> Console:
        return console

    mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
    monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
    monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
    monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)

    result = runner.invoke(app, ["chat", argument, name])

    assert result.exit_code == 0
    assert chat_fake.received_messages == ["Hello"]
    # Explicit presence check gives a clear failure instead of a KeyError.
    assert "name" in chat_fake.args
    assert chat_fake.args["name"] == name
def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
    """Run the ``load`` command on a conversation saved as JSON.

    A three-message conversation is written to a temporary file with
    ``model_dump_json``. The chat class, console and user input are replaced
    with fakes, ``load`` is invoked on the file, and the command is expected
    to exit with status 0.
    """
    # Create a conversation object to save
    conversation = Conversation(
        messages=[
            Message(role=Role.SYSTEM, content="Hello!"),
            Message(role=Role.USER, content="Hi!"),
            Message(role=Role.ASSISTANT, content="How are you?"),
        ],
        model=Model.GPT3,
        temperature=0.5,
        completion_tokens=10,
        prompt_tokens=15,
        cost=0.000043,
    )
    # Save the conversation to a file
    file_path = tmp_path / "conversation.json"
    with file_path.open("w") as f:
        f.write(conversation.model_dump_json())
    # Route console rendering into an in-memory buffer instead of stdout.
    output = StringIO()
    console = Console(file=output)

    # Despite its name, this replaces ``get_chat_class`` and returns the fake
    # class itself, not an instance. Presumably the CLI then calls
    # ``ChatFake.load(...)`` on it; TODO confirm against llm_chat.cli.
    def mock_get_chat() -> Type[ChatFake]:
        return ChatFake

    def mock_get_console() -> Console:
        return console

    # Scripted input: one message, then "/q" (expected to end the session).
    mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
    monkeypatch.setattr(llm_chat.cli, "get_chat_class", mock_get_chat)
    monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
    monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
    # Load the conversation from the file
    result = runner.invoke(app, ["load", str(file_path)])
    assert result.exit_code == 0