diff --git a/pyproject.toml b/pyproject.toml index 2849184..92ad0c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "llm-chat" -version = "0.1.0" +version = "0.2.0" description = "A general CLI interface for large language models." authors = ["Paul Harrison "] readme = "README.md" diff --git a/src/llm_chat/chat.py b/src/llm_chat/chat.py index 35cf640..1c65b10 100644 --- a/src/llm_chat/chat.py +++ b/src/llm_chat/chat.py @@ -77,6 +77,6 @@ class Chat: return message -def get_chat(settings: OpenAISettings | None = None) -> ChatProtocol: +def get_chat(settings: OpenAISettings | None = None, context: list[Message] = []) -> ChatProtocol: """Get a chat object.""" - return Chat(settings=settings) + return Chat(settings=settings, context=context) diff --git a/src/llm_chat/cli.py b/src/llm_chat/cli.py index 7d8cff1..9487113 100644 --- a/src/llm_chat/cli.py +++ b/src/llm_chat/cli.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import Annotated, Any, Optional import typer @@ -6,6 +7,7 @@ from rich.console import Console from rich.markdown import Markdown from llm_chat.chat import get_chat +from llm_chat.models import Message, Role from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings app = typer.Typer() @@ -39,6 +41,20 @@ def read_user_input(session: PromptSession[Any]) -> str: return prompt +def load_context(path: Path) -> Message: + """Load context text from file.""" + if not path.exists(): + raise typer.BadParameter(f"File {path} does not exist.") + + if not path.is_file(): + raise typer.BadParameter(f"Path {path} is not a file.") + + with path.open() as f: + content = f.read() + + return Message(role=Role.SYSTEM, content=content) + + @app.command() def chat( api_key: Annotated[ @@ -63,6 +79,19 @@ def chat( ..., "--temperature", "-t", help="Model temperature (i.e. creativeness)." ), ] = DEFAULT_TEMPERATURE, + context: Annotated[ + Optional[list[Path]], + typer.Option( + ..., + "--context", + "-c", + help="Path to a file containing context text.", + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + ), + ] = [], ) -> None: """Start a chat session.""" # TODO: Add option to load context from file. @@ -71,8 +100,11 @@ def chat( settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature) else: settings = OpenAISettings(model=model, temperature=temperature) + + context_messages = [load_context(path) for path in context] - current_chat = get_chat(settings) + current_chat = get_chat(settings=settings, context=context_messages) + # current_chat = get_chat(settings=settings) console = get_console() session = get_session() diff --git a/tests/test_cli.py b/tests/test_cli.py index a64b4fc..c0c891d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,9 @@ from io import StringIO +from pathlib import Path +from typing import Any from unittest.mock import MagicMock +import pytest from pytest import MonkeyPatch from rich.console import Console from typer.testing import CliRunner @@ -8,6 +11,7 @@ from typer.testing import CliRunner import llm_chat from llm_chat.chat import ChatProtocol from llm_chat.cli import app +from llm_chat.models import Message, Role from llm_chat.settings import OpenAISettings runner = CliRunner() @@ -16,7 +20,15 @@ runner = CliRunner() class ChatFake: """Fake chat class for testing.""" - received_messages: list[str] = [] + args: dict[str, Any] + received_messages: list[str] + + def __init__(self) -> None: + self.args = {} + self.received_messages = [] + + def _set_args(self, **kwargs: Any) -> None: + self.args = kwargs def send_message(self, message: str) -> str: """Echo the received message.""" @@ -29,7 +41,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None: output = StringIO() console = Console(file=output) - def mock_get_chat(_: OpenAISettings) -> ChatProtocol: + def mock_get_chat(**_: Any) -> ChatProtocol: return chat_fake def mock_get_console() -> Console: @@ -44,3 +56,35 @@ def test_chat(monkeypatch: MonkeyPatch) -> None: result = runner.invoke(app) assert result.exit_code == 0 assert chat_fake.received_messages == ["Hello"] + + +@pytest.mark.parametrize( + "argument", ["--context", "-c"], ids=["--context", "-c"] +) +def test_chat_with_context(argument: str, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: + context = "Hello, world!" + tmp_file = tmp_path / "context.txt" + tmp_file.write_text(context) + + chat_fake = ChatFake() + output = StringIO() + console = Console(file=output) + + def mock_get_chat(**kwargs: Any) -> ChatProtocol: + chat_fake._set_args(**kwargs) + return chat_fake + + def mock_get_console() -> Console: + return console + + mock_read_user_input = MagicMock(side_effect=["Hello", "/q"]) + + monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat) + monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console) + monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input) + + result = runner.invoke(app, [argument, str(tmp_file)]) + assert result.exit_code == 0 + assert chat_fake.received_messages == ["Hello"] + assert "context" in chat_fake.args + assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]