2023-08-19 14:58:54 +00:00
|
|
|
from io import StringIO
|
2023-08-22 16:08:41 +00:00
|
|
|
from pathlib import Path
|
|
|
|
from typing import Any
|
2023-08-19 14:58:54 +00:00
|
|
|
from unittest.mock import MagicMock
|
|
|
|
|
2023-08-22 16:08:41 +00:00
|
|
|
import pytest
|
2023-08-19 14:58:54 +00:00
|
|
|
from pytest import MonkeyPatch
|
|
|
|
from rich.console import Console
|
|
|
|
from typer.testing import CliRunner
|
|
|
|
|
|
|
|
import llm_chat
|
|
|
|
from llm_chat.chat import ChatProtocol
|
|
|
|
from llm_chat.cli import app
|
2023-08-22 16:08:41 +00:00
|
|
|
from llm_chat.models import Message, Role
|
2023-08-24 17:35:27 +00:00
|
|
|
from llm_chat.settings import OpenAISettings
|
2023-08-19 14:58:54 +00:00
|
|
|
|
|
|
|
# Shared Typer test runner; the tests in this module call runner.invoke(app, ...).
runner = CliRunner()
|
|
|
|
|
|
|
|
|
|
|
|
class ChatFake:
    """Test double for a chat object.

    It records the keyword arguments handed to ``_set_args`` and every
    message passed to ``send_message`` so tests can inspect them afterwards.
    """

    # Keyword arguments most recently stored through ``_set_args``.
    args: dict[str, Any]
    # Messages passed to ``send_message``, in call order.
    received_messages: list[str]
    # Settings supplied at construction, or a default ``OpenAISettings()``.
    settings: OpenAISettings

    def __init__(self, settings: OpenAISettings | None = None) -> None:
        """Store ``settings`` (or a default instance) and reset the recorders."""
        self.settings = OpenAISettings() if settings is None else settings
        self.args = {}
        self.received_messages = []

    def _set_args(self, **kwargs: Any) -> None:
        """Replace the recorded arguments with ``kwargs``."""
        self.args = kwargs

    @property
    def cost(self) -> float:
        """Return the conversation cost, which is always zero for the fake."""
        return 0.0

    @classmethod
    def load(
        cls,
        path: Path,
        api_key: str | None = None,
        history_dir: Path | None = None,
    ) -> ChatProtocol:
        """Return a fresh fake; all arguments are accepted but ignored."""
        return cls()

    def save(self) -> None:
        """Do nothing; present so the fake offers a ``save`` method."""

    def send_message(self, message: str) -> str:
        """Record ``message`` and return it unchanged."""
        self.received_messages.append(message)
        return message
|
|
|
|
|
|
|
|
|
|
|
|
def test_chat(monkeypatch: MonkeyPatch) -> None:
    """Run the ``chat`` command with stubbed I/O and check the message sent."""
    fake_chat = ChatFake()
    # Rich output goes to an in-memory buffer instead of the terminal.
    fake_console = Console(file=StringIO())
    # First call yields a user message, second call yields the quit command.
    scripted_input = MagicMock(side_effect=["Hello", "/q"])

    # Swap the CLI module's collaborators for the stubs above.
    monkeypatch.setattr(llm_chat.cli, "get_chat", lambda **_: fake_chat)
    monkeypatch.setattr(llm_chat.cli, "get_console", lambda: fake_console)
    monkeypatch.setattr(llm_chat.cli, "read_user_input", scripted_input)

    outcome = runner.invoke(app, ["chat"])

    assert outcome.exit_code == 0
    assert fake_chat.received_messages == ["Hello"]
|
2023-08-22 16:08:41 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("argument", ["--context", "-c"], ids=["--context", "-c"])
def test_chat_with_context(
    argument: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
    """Check that a context file given via ``--context``/``-c`` reaches ``get_chat``.

    The file's text is expected to arrive as a single system ``Message`` under
    the ``context`` keyword argument.

    Args:
        argument: The option spelling under test (long or short form).
        monkeypatch: Pytest fixture used to replace the CLI's collaborators.
        tmp_path: Pytest-provided temporary directory holding the context file.
    """
    context = "Hello, world!"
    tmp_file = tmp_path / "context.txt"
    tmp_file.write_text(context)

    chat_fake = ChatFake()
    output = StringIO()
    console = Console(file=output)

    def mock_get_chat(**kwargs: Any) -> ChatProtocol:
        # Capture what the CLI passes so the assertions below can inspect it.
        chat_fake._set_args(**kwargs)
        return chat_fake

    def mock_get_console() -> Console:
        return console

    # First call yields a user message, second call yields the quit command.
    mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])

    monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
    monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
    monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)

    # Name the `chat` sub-command explicitly, as test_chat does. Typer only
    # allows the command name to be omitted when the app has a single command.
    result = runner.invoke(app, ["chat", argument, str(tmp_file)])

    assert result.exit_code == 0
    assert chat_fake.received_messages == ["Hello"]
    assert "context" in chat_fake.args
    assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
|