Include one or more context files in chat session

This commit is contained in:
Paul Harrison 2023-08-22 17:08:41 +01:00
parent 7cbde55aac
commit d1a7cd68aa
4 changed files with 87 additions and 7 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "llm-chat" name = "llm-chat"
version = "0.1.0" version = "0.2.0"
description = "A general CLI interface for large language models." description = "A general CLI interface for large language models."
authors = ["Paul Harrison <paul@harrison.sh>"] authors = ["Paul Harrison <paul@harrison.sh>"]
readme = "README.md" readme = "README.md"

View File

@ -77,6 +77,8 @@ class Chat:
return message return message
def get_chat(settings: OpenAISettings | None = None) -> ChatProtocol: def get_chat(
settings: OpenAISettings | None = None, context: list[Message] = []
) -> ChatProtocol:
"""Get a chat object.""" """Get a chat object."""
return Chat(settings=settings) return Chat(settings=settings, context=context)

View File

@ -1,3 +1,4 @@
from pathlib import Path
from typing import Annotated, Any, Optional from typing import Annotated, Any, Optional
import typer import typer
@ -6,6 +7,7 @@ from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from llm_chat.chat import get_chat from llm_chat.chat import get_chat
from llm_chat.models import Message, Role
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
app = typer.Typer() app = typer.Typer()
@ -39,6 +41,20 @@ def read_user_input(session: PromptSession[Any]) -> str:
return prompt return prompt
def load_context(path: Path) -> Message:
"""Load context text from file."""
if not path.exists():
raise typer.BadParameter(f"File {path} does not exist.")
if not path.is_file():
raise typer.BadParameter(f"Path {path} is not a file.")
with path.open() as f:
content = f.read()
return Message(role=Role.SYSTEM, content=content)
@app.command() @app.command()
def chat( def chat(
api_key: Annotated[ api_key: Annotated[
@ -63,6 +79,22 @@ def chat(
..., "--temperature", "-t", help="Model temperature (i.e. creativeness)." ..., "--temperature", "-t", help="Model temperature (i.e. creativeness)."
), ),
] = DEFAULT_TEMPERATURE, ] = DEFAULT_TEMPERATURE,
context: Annotated[
list[Path],
typer.Option(
...,
"--context",
"-c",
help=(
"Path to a file containing context text. "
"Can provide multiple time for multiple files."
),
exists=True,
file_okay=True,
dir_okay=False,
readable=True,
),
] = [],
) -> None: ) -> None:
"""Start a chat session.""" """Start a chat session."""
# TODO: Add option to load context from file. # TODO: Add option to load context from file.
@ -72,7 +104,10 @@ def chat(
else: else:
settings = OpenAISettings(model=model, temperature=temperature) settings = OpenAISettings(model=model, temperature=temperature)
current_chat = get_chat(settings) context_messages = [load_context(path) for path in context]
current_chat = get_chat(settings=settings, context=context_messages)
# current_chat = get_chat(settings=settings)
console = get_console() console = get_console()
session = get_session() session = get_session()

View File

@ -1,6 +1,9 @@
from io import StringIO from io import StringIO
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest
from pytest import MonkeyPatch from pytest import MonkeyPatch
from rich.console import Console from rich.console import Console
from typer.testing import CliRunner from typer.testing import CliRunner
@ -8,7 +11,7 @@ from typer.testing import CliRunner
import llm_chat import llm_chat
from llm_chat.chat import ChatProtocol from llm_chat.chat import ChatProtocol
from llm_chat.cli import app from llm_chat.cli import app
from llm_chat.settings import OpenAISettings from llm_chat.models import Message, Role
runner = CliRunner() runner = CliRunner()
@ -16,7 +19,15 @@ runner = CliRunner()
class ChatFake: class ChatFake:
"""Fake chat class for testing.""" """Fake chat class for testing."""
received_messages: list[str] = [] args: dict[str, Any]
received_messages: list[str]
def __init__(self) -> None:
self.args = {}
self.received_messages = []
def _set_args(self, **kwargs: Any) -> None:
self.args = kwargs
def send_message(self, message: str) -> str: def send_message(self, message: str) -> str:
"""Echo the received message.""" """Echo the received message."""
@ -29,7 +40,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
output = StringIO() output = StringIO()
console = Console(file=output) console = Console(file=output)
def mock_get_chat(_: OpenAISettings) -> ChatProtocol: def mock_get_chat(**_: Any) -> ChatProtocol:
return chat_fake return chat_fake
def mock_get_console() -> Console: def mock_get_console() -> Console:
@ -44,3 +55,35 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
result = runner.invoke(app) result = runner.invoke(app)
assert result.exit_code == 0 assert result.exit_code == 0
assert chat_fake.received_messages == ["Hello"] assert chat_fake.received_messages == ["Hello"]
@pytest.mark.parametrize("argument", ["--context", "-c"], ids=["--context", "-c"])
def test_chat_with_context(
argument: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
context = "Hello, world!"
tmp_file = tmp_path / "context.txt"
tmp_file.write_text(context)
chat_fake = ChatFake()
output = StringIO()
console = Console(file=output)
def mock_get_chat(**kwargs: Any) -> ChatProtocol:
chat_fake._set_args(**kwargs)
return chat_fake
def mock_get_console() -> Console:
return console
mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
result = runner.invoke(app, [argument, str(tmp_file)])
assert result.exit_code == 0
assert chat_fake.received_messages == ["Hello"]
assert "context" in chat_fake.args
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]