llm-chat/tests/test_cli.py

112 lines
3.2 KiB
Python
Raw Normal View History

from io import StringIO
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
import pytest
from pytest import MonkeyPatch
from rich.console import Console
from typer.testing import CliRunner
import llm_chat
from llm_chat.chat import ChatProtocol
from llm_chat.cli import app
from llm_chat.models import Message, Role
from llm_chat.settings import OpenAISettings

# A single CliRunner, shared by the tests in this module to invoke ``app``.
runner: CliRunner = CliRunner()
class ChatFake:
    """Echoing stand-in for a chat session, used to drive the CLI in tests.

    Instead of contacting a model, ``send_message`` records each message it is
    given and returns it unchanged, so tests can inspect what the CLI sent.
    """

    # Keyword arguments captured by ``_set_args``; empty until it is called.
    args: dict[str, Any]
    # Every message passed to ``send_message``, in call order.
    received_messages: list[str]
    # Settings given to the constructor, or a default ``OpenAISettings()``.
    settings: OpenAISettings

    def __init__(self, settings: OpenAISettings | None = None) -> None:
        # Only build default settings when the caller did not supply any.
        self.settings = OpenAISettings() if settings is None else settings
        self.args = {}
        self.received_messages = []

    def _set_args(self, **kwargs: Any) -> None:
        """Remember ``kwargs`` so a test can assert on them, replacing earlier ones."""
        self.args = kwargs

    @property
    def cost(self) -> float:
        """Report the conversation cost; the fake always reports zero."""
        return 0.0

    @classmethod
    def load(
        cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
    ) -> ChatProtocol:
        """Return a brand-new fake, ignoring ``path``, ``api_key`` and ``history_dir``."""
        return cls()

    def save(self) -> None:
        """Intentionally a no-op: the fake never writes anything."""

    def send_message(self, message: str) -> str:
        """Record ``message`` and hand it straight back as the reply."""
        self.received_messages.append(message)
        return message
def test_chat(monkeypatch: MonkeyPatch) -> None:
    """The ``chat`` command forwards scripted user input to the chat session.

    User input is scripted as ``"Hello"`` followed by ``"/q"``. The run must
    succeed, and the fake chat must have received exactly ``["Hello"]``.
    """
    fake = ChatFake()
    # Route rich output into an in-memory buffer instead of the terminal.
    console = Console(file=StringIO())

    # Replacements for the CLI's factories and input reader, applied in order.
    patches: dict[str, Any] = {
        "get_chat": lambda **_: fake,
        "get_console": lambda: console,
        "read_user_input": MagicMock(side_effect=["Hello", "/q"]),
    }
    for attr, replacement in patches.items():
        monkeypatch.setattr(llm_chat.cli, attr, replacement)

    result = runner.invoke(app, ["chat"])

    assert result.exit_code == 0
    assert fake.received_messages == ["Hello"]
@pytest.mark.parametrize("argument", ["--context", "-c"], ids=["--context", "-c"])
def test_chat_with_context(
    argument: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
    """A context file given to ``chat`` reaches ``get_chat`` as a system message.

    Parametrized over both spellings of the option (``--context`` and ``-c``).
    The file's text must arrive in the ``context`` keyword argument of
    ``get_chat`` as a single ``Role.SYSTEM`` message. The scripted user input
    (``"Hello"`` then ``"/q"``) must still reach the chat session.
    """
    context = "Hello, world!"
    tmp_file = tmp_path / "context.txt"
    tmp_file.write_text(context)

    chat_fake = ChatFake()
    output = StringIO()
    console = Console(file=output)

    def mock_get_chat(**kwargs: Any) -> ChatProtocol:
        # Capture the arguments the CLI builds so they can be asserted below.
        chat_fake._set_args(**kwargs)
        return chat_fake

    def mock_get_console() -> Console:
        return console

    mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])

    monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
    monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
    monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)

    # ``test_chat`` invokes ``["chat"]``, i.e. ``app`` dispatches by subcommand
    # name. The context option must therefore follow ``chat``; passed alone, it
    # would be parsed against the top-level app rather than the chat command.
    result = runner.invoke(app, ["chat", argument, str(tmp_file)])

    # Include the CLI output so a failing invocation is diagnosable.
    assert result.exit_code == 0, result.output
    assert chat_fake.received_messages == ["Hello"]
    assert "context" in chat_fake.args
    assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]