# Source: llm-chat/tests/test_cli.py (95 lines, 2.6 KiB, Python)

from io import StringIO
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
import pytest
from pytest import MonkeyPatch
from rich.console import Console
from typer.testing import CliRunner
import llm_chat
from llm_chat.chat import ChatProtocol
from llm_chat.cli import app
from llm_chat.models import Message, Role
# Shared Typer CliRunner that the tests in this module use to invoke `app`.
runner = CliRunner()
class ChatFake:
    """Test double standing in for a chat object.

    It remembers the keyword arguments handed to ``_set_args`` and every
    message passed to ``send_message``, and replies by echoing the message.
    """

    # Keyword arguments most recently recorded through ``_set_args``.
    args: dict[str, Any]
    # Messages passed to ``send_message``, in the order they arrived.
    received_messages: list[str]

    def __init__(self) -> None:
        self.received_messages = []
        self.args = {}

    @property
    def cost(self) -> float:
        """Return the cost of the conversation; the fake always reports 0.0."""
        return 0.0

    def send_message(self, message: str) -> str:
        """Record ``message`` and return it unchanged as the reply."""
        self.received_messages.append(message)
        return message

    def _set_args(self, **kwargs: Any) -> None:
        """Replace ``args`` with the supplied keyword arguments."""
        self.args = kwargs
def test_chat(monkeypatch: MonkeyPatch) -> None:
    """Invoke the CLI with scripted input and check the chat sees the message.

    ``get_chat``, ``get_console`` and ``read_user_input`` in ``llm_chat.cli``
    are patched with fakes. The scripted input yields ``"Hello"`` and then
    ``"/q"`` (presumably the quit command -- confirm against the CLI). The
    test expects exit code 0 and that the fake chat received only ``"Hello"``.
    """
    fake_chat = ChatFake()
    # Rich console writing into an in-memory buffer instead of the terminal.
    fake_console = Console(file=StringIO())
    scripted_input = MagicMock(side_effect=["Hello", "/q"])

    def provide_chat(**_: Any) -> ChatProtocol:
        return fake_chat

    def provide_console() -> Console:
        return fake_console

    replacements = {
        "get_chat": provide_chat,
        "get_console": provide_console,
        "read_user_input": scripted_input,
    }
    for attribute, replacement in replacements.items():
        monkeypatch.setattr(llm_chat.cli, attribute, replacement)

    outcome = runner.invoke(app)

    assert outcome.exit_code == 0
    assert fake_chat.received_messages == ["Hello"]
@pytest.mark.parametrize("argument", ["--context", "-c"], ids=["--context", "-c"])
def test_chat_with_context(
    argument: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
    """Invoke the CLI with a context file and check it reaches ``get_chat``.

    A temporary file holding the context text is passed through ``argument``
    (the long or the short option spelling). ``get_chat``, ``get_console`` and
    ``read_user_input`` in ``llm_chat.cli`` are patched with fakes, and the
    fake ``get_chat`` records the keyword arguments it is called with. The
    test expects exit code 0, that the fake chat received only ``"Hello"``,
    and that the recorded ``context`` argument is a single system message
    holding the file's text.
    """
    context = "Hello, world!"
    context_file = tmp_path / "context.txt"
    context_file.write_text(context)

    fake_chat = ChatFake()
    # Rich console writing into an in-memory buffer instead of the terminal.
    fake_console = Console(file=StringIO())
    scripted_input = MagicMock(side_effect=["Hello", "/q"])

    def provide_chat(**kwargs: Any) -> ChatProtocol:
        # Keep the keyword arguments so they can be asserted on below.
        fake_chat._set_args(**kwargs)
        return fake_chat

    def provide_console() -> Console:
        return fake_console

    replacements = {
        "get_chat": provide_chat,
        "get_console": provide_console,
        "read_user_input": scripted_input,
    }
    for attribute, replacement in replacements.items():
        monkeypatch.setattr(llm_chat.cli, attribute, replacement)

    outcome = runner.invoke(app, [argument, str(context_file)])

    assert outcome.exit_code == 0
    assert fake_chat.received_messages == ["Hello"]
    assert "context" in fake_chat.args
    expected_context = [Message(role=Role.SYSTEM, content=context)]
    assert fake_chat.args["context"] == expected_context