
# Source-viewer metadata (not code): 195 lines, 5.8 KiB — Raw / Normal View / History.

from io import StringIO
# Last modified: 2023-09-14 16:58:40 +00:00
from itertools import product
from pathlib import Path
from typing import Any, Type
from unittest.mock import MagicMock
import pytest
from pytest import MonkeyPatch
from rich.console import Console
from typer.testing import CliRunner
import llm_chat.cli
from llm_chat.chat import ChatProtocol
from llm_chat.cli.main import app
from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings
# One Typer test runner, shared by every test in this module that invokes the CLI app.
runner: CliRunner = CliRunner()
class ChatFake:
    """Fake chat class for testing.

    Records every message passed to ``send_message`` and the keyword arguments
    handed to ``_set_args``, so tests can assert on how the CLI drove the chat.
    """

    args: dict[str, Any]
    conversation: Conversation
    received_messages: list[str]
    settings: OpenAISettings

    def __init__(self, settings: OpenAISettings | None = None) -> None:
        # Honour an explicitly supplied settings object; only fall back to a
        # default ``OpenAISettings()`` when none was given.
        self.settings = settings if settings is not None else OpenAISettings()
        self.args = {}
        self.received_messages = []

    def _set_args(self, **kwargs: Any) -> None:
        """Record the keyword arguments the chat was created with."""
        self.args = kwargs

    # NOTE(review): if ChatProtocol declares bot/cost/name as properties, these
    # three should be decorated with @property as well — confirm against it.
    def bot(self) -> str:
        """Get the name of the bot the conversation is with."""
        return self.args.get("bot", "")

    def cost(self) -> float:
        """Get the cost of the conversation (always zero for the fake)."""
        return 0.0

    def name(self) -> str:
        """Get the name of the conversation."""
        return self.args.get("name", "")

    @classmethod
    def load(
        cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
    ) -> ChatProtocol:
        """Load a chat from a file.

        The fake ignores ``path``, ``api_key`` and ``history_dir`` and returns
        a fresh instance. This must be a classmethod: tests hand the class
        itself to the CLI (via ``get_chat_class``), so ``load`` is presumably
        invoked on the class rather than on an instance.
        """
        return cls()

    def save(self) -> None:
        """Dummy save method."""

    def send_message(self, message: str) -> str:
        """Record the received message and echo it back."""
        # Tests assert on ``received_messages``, so every message must be kept.
        self.received_messages.append(message)
        return message
def test_chat(monkeypatch: MonkeyPatch) -> None:
    """Feed "Hello" then "/q" to the chat command.

    The CLI should exit cleanly, and the chat should have received only "Hello".
    """
    fake_chat = ChatFake()
    buffer = StringIO()
    fake_console = Console(file=buffer)
    scripted_input = MagicMock(side_effect=["Hello", "/q"])

    cli_main = llm_chat.cli.main
    monkeypatch.setattr(cli_main, "get_chat", lambda **_: fake_chat)
    monkeypatch.setattr(cli_main, "get_console", lambda: fake_console)
    monkeypatch.setattr(cli_main, "read_user_input", scripted_input)

    outcome = runner.invoke(app, ["chat"])

    assert outcome.exit_code == 0
    assert fake_chat.received_messages == ["Hello"]
@pytest.mark.parametrize("argument", ["--context", "-c"], ids=["--context", "-c"])
def test_chat_with_context(
    argument: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
    """A context file passed via --context/-c should reach the chat as a system message."""
    context = "Hello, world!"
    tmp_file = tmp_path / "context.txt"
    # The expected system message can only come from this file's contents,
    # so the file has to exist and hold ``context`` before the CLI runs.
    tmp_file.write_text(context)

    chat_fake = ChatFake()
    output = StringIO()
    console = Console(file=output)

    def mock_get_chat(**kwargs: Any) -> ChatProtocol:
        # Capture the kwargs the CLI builds the chat with; without this
        # ``chat_fake.args`` stays empty and the context assertions fail.
        chat_fake._set_args(**kwargs)
        return chat_fake

    def mock_get_console() -> Console:
        return console

    mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
    monkeypatch.setattr(llm_chat.cli.main, "get_chat", mock_get_chat)
    monkeypatch.setattr(llm_chat.cli.main, "get_console", mock_get_console)
    monkeypatch.setattr(llm_chat.cli.main, "read_user_input", mock_read_user_input)

    result = runner.invoke(app, ["chat", argument, str(tmp_file)])
    assert result.exit_code == 0
    assert chat_fake.received_messages == ["Hello"]
    assert "context" in chat_fake.args
    assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
@pytest.mark.parametrize(
    ("argument", "name"),
    list(product(("--name", "-n"), ("", "foo"))),
    ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))],
)
def test_chat_with_name(
    argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
    """The --name/-n value should be passed through to ``get_chat``.

    Covers both the long and the short flag, each with an empty and a
    non-empty name.
    """
    chat_fake = ChatFake()
    output = StringIO()
    console = Console(file=output)

    def mock_get_chat(**kwargs: Any) -> ChatProtocol:
        # Capture the kwargs the CLI passes when creating the chat; without
        # this ``chat_fake.args`` stays empty and the lookup below KeyErrors.
        chat_fake._set_args(**kwargs)
        return chat_fake

    def mock_get_console() -> Console:
        return console

    mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
    monkeypatch.setattr(llm_chat.cli.main, "get_chat", mock_get_chat)
    monkeypatch.setattr(llm_chat.cli.main, "get_console", mock_get_console)
    monkeypatch.setattr(llm_chat.cli.main, "read_user_input", mock_read_user_input)

    result = runner.invoke(app, ["chat", argument, name])
    assert result.exit_code == 0
    assert chat_fake.args["name"] == name
def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
    """Write a conversation to a JSON file and run the ``load`` command on it.

    NOTE(review): this test looks truncated in the visible source — several
    lines appear to have been dropped (see TODOs below) and it may continue
    past the last visible line. Restore it from version control.
    """
    # Create a conversation object to save
    # TODO(review): this call is missing its closing parenthesis. The Message
    # arguments presumably belong inside a ``messages=[...]`` keyword, and
    # other required fields may be missing too — confirm against Conversation.
    conversation = Conversation(
        Message(role=Role.SYSTEM, content="Hello!"),
        Message(role=Role.USER, content="Hi!"),
        Message(role=Role.ASSISTANT, content="How are you?"),
    # Save the conversation to a file
    file_path = tmp_path / "conversation.json"
    with file_path.open("w") as f:
        # TODO(review): the body of this ``with`` is missing. It presumably
        # writes the serialized ``conversation`` to ``f`` — restore it.
    output = StringIO()
    # Route CLI output into an in-memory buffer instead of the terminal.
    console = Console(file=output)

    # Patched in as ``get_chat_class``: hands the CLI the fake class itself,
    # so the CLI presumably calls ``ChatFake.load`` on the file path.
    def mock_get_chat() -> Type[ChatFake]:
        return ChatFake

    def mock_get_console() -> Console:
        return console

    # Scripted user input: one message, then "/q".
    mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
    monkeypatch.setattr(llm_chat.cli.main, "get_chat_class", mock_get_chat)
    monkeypatch.setattr(llm_chat.cli.main, "get_console", mock_get_console)
    monkeypatch.setattr(llm_chat.cli.main, "read_user_input", mock_read_user_input)
    # Load the conversation from the file
    result = runner.invoke(app, ["load", str(file_path)])
    assert result.exit_code == 0