From 2cd623e9600e4f8957731a5e54dff343a2780d00 Mon Sep 17 00:00:00 2001 From: Paul Harrison Date: Tue, 22 Aug 2023 17:08:41 +0100 Subject: [PATCH] Optionally provide context files to chat session This commit enables the user to provide one or more text files as context for their chat session. These will be provided as system messages to OpenAI's API, one message per file. --- pyproject.toml | 2 +- src/llm_chat/chat.py | 6 ++++-- src/llm_chat/cli.py | 37 ++++++++++++++++++++++++++++++++- tests/test_cli.py | 49 +++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 87 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2849184..92ad0c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "llm-chat" -version = "0.1.0" +version = "0.2.0" description = "A general CLI interface for large language models." authors = ["Paul Harrison "] readme = "README.md" diff --git a/src/llm_chat/chat.py b/src/llm_chat/chat.py index 35cf640..7ab0adf 100644 --- a/src/llm_chat/chat.py +++ b/src/llm_chat/chat.py @@ -77,6 +77,8 @@ class Chat: return message -def get_chat(settings: OpenAISettings | None = None) -> ChatProtocol: +def get_chat( + settings: OpenAISettings | None = None, context: list[Message] = [] +) -> ChatProtocol: """Get a chat object.""" - return Chat(settings=settings) + return Chat(settings=settings, context=context) diff --git a/src/llm_chat/cli.py b/src/llm_chat/cli.py index 7d8cff1..a6cc02f 100644 --- a/src/llm_chat/cli.py +++ b/src/llm_chat/cli.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import Annotated, Any, Optional import typer @@ -6,6 +7,7 @@ from rich.console import Console from rich.markdown import Markdown from llm_chat.chat import get_chat +from llm_chat.models import Message, Role from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings app = typer.Typer() @@ -39,6 +41,20 @@ def read_user_input(session: PromptSession[Any]) -> str: return prompt +def load_context(path: Path) -> Message: + """Load context text from file.""" + if not path.exists(): + raise typer.BadParameter(f"File {path} does not exist.") + + if not path.is_file(): + raise typer.BadParameter(f"Path {path} is not a file.") + + with path.open() as f: + content = f.read() + + return Message(role=Role.SYSTEM, content=content) + + @app.command() def chat( api_key: Annotated[ @@ -63,6 +79,22 @@ def chat( ..., "--temperature", "-t", help="Model temperature (i.e. creativeness)." ), ] = DEFAULT_TEMPERATURE, + context: Annotated[ + list[Path], + typer.Option( + ..., + "--context", + "-c", + help=( + "Path to a file containing context text. " + "Can provide multiple time for multiple files." + ), + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + ), + ] = [], ) -> None: """Start a chat session.""" # TODO: Add option to load context from file. @@ -72,7 +104,10 @@ def chat( else: settings = OpenAISettings(model=model, temperature=temperature) - current_chat = get_chat(settings) + context_messages = [load_context(path) for path in context] + + current_chat = get_chat(settings=settings, context=context_messages) + # current_chat = get_chat(settings=settings) console = get_console() session = get_session() diff --git a/tests/test_cli.py b/tests/test_cli.py index a64b4fc..5861cf6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,9 @@ from io import StringIO +from pathlib import Path +from typing import Any from unittest.mock import MagicMock +import pytest from pytest import MonkeyPatch from rich.console import Console from typer.testing import CliRunner @@ -8,7 +11,7 @@ from typer.testing import CliRunner import llm_chat from llm_chat.chat import ChatProtocol from llm_chat.cli import app -from llm_chat.settings import OpenAISettings +from llm_chat.models import Message, Role runner = CliRunner() @@ -16,7 +19,15 @@ runner = CliRunner() class ChatFake: """Fake chat class for testing.""" - received_messages: list[str] = [] + args: dict[str, Any] + received_messages: list[str] + + def __init__(self) -> None: + self.args = {} + self.received_messages = [] + + def _set_args(self, **kwargs: Any) -> None: + self.args = kwargs def send_message(self, message: str) -> str: """Echo the received message.""" @@ -29,7 +40,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None: output = StringIO() console = Console(file=output) - def mock_get_chat(_: OpenAISettings) -> ChatProtocol: + def mock_get_chat(**_: Any) -> ChatProtocol: return chat_fake def mock_get_console() -> Console: @@ -44,3 +55,35 @@ def test_chat(monkeypatch: MonkeyPatch) -> None: result = runner.invoke(app) assert result.exit_code == 0 assert chat_fake.received_messages == ["Hello"] + + +@pytest.mark.parametrize("argument", ["--context", "-c"], ids=["--context", "-c"]) +def test_chat_with_context( + argument: str, monkeypatch: MonkeyPatch, tmp_path: Path +) -> None: + context = "Hello, world!" + tmp_file = tmp_path / "context.txt" + tmp_file.write_text(context) + + chat_fake = ChatFake() + output = StringIO() + console = Console(file=output) + + def mock_get_chat(**kwargs: Any) -> ChatProtocol: + chat_fake._set_args(**kwargs) + return chat_fake + + def mock_get_console() -> Console: + return console + + mock_read_user_input = MagicMock(side_effect=["Hello", "/q"]) + + monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat) + monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console) + monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input) + + result = runner.invoke(app, [argument, str(tmp_file)]) + assert result.exit_code == 0 + assert chat_fake.received_messages == ["Hello"] + assert "context" in chat_fake.args + assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]