"""Tests for the ``llm_chat`` command-line interface."""
from io import StringIO
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock

import pytest
from pytest import MonkeyPatch
from rich.console import Console
from typer.testing import CliRunner

import llm_chat
from llm_chat.chat import ChatProtocol
from llm_chat.cli import app
from llm_chat.models import Message, Role
from llm_chat.settings import OpenAISettings

runner = CliRunner()


class ChatFake:
    """Fake chat class for testing.

    Records the keyword arguments handed to the (mocked) chat factory and every
    message sent to it, and echoes each message back as the reply.
    """

    # Keyword arguments recorded via ``_set_args`` (empty until set).
    args: dict[str, Any]
    # Messages passed to ``send_message``, in call order.
    received_messages: list[str]
    # Settings the fake was built with; defaults to ``OpenAISettings()``.
    settings: OpenAISettings

    def __init__(self, settings: OpenAISettings | None = None) -> None:
        if settings is not None:
            self.settings = settings
        else:
            self.settings = OpenAISettings()
        self.args = {}
        self.received_messages = []

    def _set_args(self, **kwargs: Any) -> None:
        """Record the keyword arguments the chat factory was called with."""
        self.args = kwargs

    @property
    def cost(self) -> float:
        """Get the cost of the conversation (always zero for the fake)."""
        return 0.0

    @classmethod
    def load(
        cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
    ) -> ChatProtocol:
        """Load a chat from a file.

        The fake ignores all arguments and returns a fresh instance.
        """
        return cls()

    def save(self) -> None:
        """Dummy save method; does nothing."""

    def send_message(self, message: str) -> str:
        """Record the received message and echo it back as the reply."""
        self.received_messages.append(message)
        return message


def _patch_cli(
    monkeypatch: MonkeyPatch, chat_fake: ChatFake, user_inputs: list[str]
) -> StringIO:
    """Patch the CLI's chat factory, console and user-input reader.

    Args:
        monkeypatch: pytest fixture used to apply (and later undo) the patches.
        chat_fake: Fake returned by the patched ``get_chat``; it records the
            factory's keyword arguments.
        user_inputs: Successive values returned by the patched
            ``read_user_input``.

    Returns:
        The buffer the patched console writes its output to.
    """
    output = StringIO()
    console = Console(file=output)

    def mock_get_chat(**kwargs: Any) -> ChatProtocol:
        # Record the factory kwargs so tests can assert on e.g. ``context``.
        chat_fake._set_args(**kwargs)
        return chat_fake

    def mock_get_console() -> Console:
        return console

    monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
    monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
    monkeypatch.setattr(
        llm_chat.cli, "read_user_input", MagicMock(side_effect=user_inputs)
    )
    return output


def test_chat(monkeypatch: MonkeyPatch) -> None:
    """A typed message is sent to the chat and ``/q`` ends the session."""
    chat_fake = ChatFake()
    _patch_cli(monkeypatch, chat_fake, ["Hello", "/q"])

    result = runner.invoke(app, ["chat"])

    assert result.exit_code == 0
    assert chat_fake.received_messages == ["Hello"]


@pytest.mark.parametrize("argument", ["--context", "-c"], ids=["--context", "-c"])
def test_chat_with_context(
    argument: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
    """A context file is passed to the chat factory as a system message."""
    context = "Hello, world!"
    tmp_file = tmp_path / "context.txt"
    tmp_file.write_text(context)
    chat_fake = ChatFake()
    _patch_cli(monkeypatch, chat_fake, ["Hello", "/q"])

    # Invoke the ``chat`` subcommand explicitly, matching ``test_chat``.
    # NOTE(review): assumes ``--context``/``-c`` are options of the ``chat``
    # command rather than of the top-level app — confirm against llm_chat.cli.
    result = runner.invoke(app, ["chat", argument, str(tmp_file)])

    assert result.exit_code == 0
    assert chat_fake.received_messages == ["Hello"]
    assert "context" in chat_fake.args
    assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]