"""Tests for the ``llm_chat`` command-line interface."""

from collections.abc import Callable
from io import StringIO
from itertools import product
from pathlib import Path
from typing import Any, Type
from unittest.mock import MagicMock

import pytest
from pytest import MonkeyPatch
from rich.console import Console
from typer.testing import CliRunner

import llm_chat.cli
from llm_chat.chat import ChatProtocol
from llm_chat.cli.main import app
from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings

runner = CliRunner()

# Every (flag, value) combination exercised by ``test_chat_with_name``;
# computed once so the parameters and their ids cannot drift apart.
_NAME_CASES = list(product(("--name", "-n"), ("", "foo")))


class ChatFake:
    """Fake chat class for testing.

    Stores the keyword arguments handed to ``_set_args`` in ``args`` and
    every message passed to ``send_message`` in ``received_messages`` so
    tests can assert on them.
    """

    args: dict[str, Any]
    # NOTE: ``conversation`` is declared but never assigned by this fake.
    conversation: Conversation
    received_messages: list[str]
    settings: OpenAISettings

    def __init__(self, settings: OpenAISettings | None = None) -> None:
        """Create the fake, building default settings when none are given."""
        self.settings = settings if settings is not None else OpenAISettings()
        self.args = {}
        self.received_messages = []

    def _set_args(self, **kwargs: Any) -> None:
        """Record the keyword arguments the fake was requested with."""
        self.args = kwargs

    @property
    def bot(self) -> str:
        """Get the name of the bot the conversation is with."""
        return self.args.get("bot", "")

    @property
    def cost(self) -> float:
        """Get the cost of the conversation."""
        return 0.0

    @property
    def name(self) -> str:
        """Get the name of the conversation."""
        return self.args.get("name", "")

    @classmethod
    def load(
        cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
    ) -> ChatProtocol:
        """Load a chat from a file.

        The fake ignores every argument and returns a fresh instance.
        """
        return cls()

    def save(self) -> None:
        """Dummy save method."""

    def send_message(self, message: str) -> str:
        """Echo the received message."""
        self.received_messages.append(message)
        return message


def _patch_cli(
    monkeypatch: MonkeyPatch, target: str, replacement: Callable[..., Any]
) -> None:
    """Patch ``llm_chat.cli.main`` so a command runs without real I/O.

    Args:
        monkeypatch: The pytest fixture used to apply the patches.
        target: Name of the chat factory attribute to replace
            (``"get_chat"`` or ``"get_chat_class"``).
        replacement: Callable installed in place of ``target``.

    Besides ``target``, ``get_console`` is replaced by a function returning
    a console that writes to an in-memory buffer, and ``read_user_input``
    by a mock yielding ``"Hello"`` followed by ``"/q"``.
    """
    console = Console(file=StringIO())

    def fake_get_console() -> Console:
        return console

    fake_read_user_input = MagicMock(side_effect=["Hello", "/q"])

    monkeypatch.setattr(llm_chat.cli.main, target, replacement)
    monkeypatch.setattr(llm_chat.cli.main, "get_console", fake_get_console)
    monkeypatch.setattr(llm_chat.cli.main, "read_user_input", fake_read_user_input)


def _recording_factory(chat_fake: ChatFake) -> Callable[..., ChatProtocol]:
    """Return a ``get_chat`` stand-in that stores its kwargs on ``chat_fake``."""

    def fake_get_chat(**kwargs: Any) -> ChatProtocol:
        chat_fake._set_args(**kwargs)
        return chat_fake

    return fake_get_chat


def test_chat(monkeypatch: MonkeyPatch) -> None:
    """The ``chat`` command forwards user input to the chat object."""
    chat_fake = ChatFake()

    def fake_get_chat(**_: Any) -> ChatProtocol:
        return chat_fake

    _patch_cli(monkeypatch, "get_chat", fake_get_chat)

    result = runner.invoke(app, ["chat"])

    assert result.exit_code == 0
    assert chat_fake.received_messages == ["Hello"]


@pytest.mark.parametrize("argument", ["--context", "-c"], ids=["--context", "-c"])
def test_chat_with_context(
    argument: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
    """A context file is passed to ``get_chat`` as a system message."""
    context = "Hello, world!"
    tmp_file = tmp_path / "context.txt"
    tmp_file.write_text(context)

    chat_fake = ChatFake()
    _patch_cli(monkeypatch, "get_chat", _recording_factory(chat_fake))

    result = runner.invoke(app, ["chat", argument, str(tmp_file)])

    assert result.exit_code == 0
    assert chat_fake.received_messages == ["Hello"]
    assert "context" in chat_fake.args
    assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]


@pytest.mark.parametrize(
    "argument,name",
    _NAME_CASES,
    ids=[f"{arg} {name}" for arg, name in _NAME_CASES],
)
def test_chat_with_name(
    argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
    """The conversation name given on the command line reaches ``get_chat``."""
    chat_fake = ChatFake()
    _patch_cli(monkeypatch, "get_chat", _recording_factory(chat_fake))

    result = runner.invoke(app, ["chat", argument, name])

    assert result.exit_code == 0
    assert chat_fake.args["name"] == name


def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
    """The ``load`` command exits cleanly for a saved conversation file."""
    # Create a conversation object to save
    conversation = Conversation(
        messages=[
            Message(role=Role.SYSTEM, content="Hello!"),
            Message(role=Role.USER, content="Hi!"),
            Message(role=Role.ASSISTANT, content="How are you?"),
        ],
        model=Model.GPT3,
        temperature=0.5,
        completion_tokens=10,
        prompt_tokens=15,
        cost=0.000043,
    )

    # Save the conversation to a file
    file_path = tmp_path / "conversation.json"
    file_path.write_text(conversation.model_dump_json())

    def fake_get_chat_class() -> Type[ChatFake]:
        return ChatFake

    _patch_cli(monkeypatch, "get_chat_class", fake_get_chat_class)

    # Load the conversation from the file
    result = runner.invoke(app, ["load", str(file_path)])

    assert result.exit_code == 0