Optionally provide context files to chat session

This commit enables the user to provide one or more text files as
context for their chat session. Each file is sent to the OpenAI API as
a separate system message, one message per file.
This commit is contained in:
Paul Harrison 2023-08-22 17:08:41 +01:00
parent 7cbde55aac
commit 2cd623e960
4 changed files with 87 additions and 7 deletions
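
For illustration only, a session started with two context files would begin with a
message list roughly like the one below before any user input is sent. The dict-style
payload and the variable names are assumptions for this sketch, not taken from the
commit:

    # Hypothetical shape of the conversation sent to the chat completions endpoint.
    first_file_text = "Project brief..."   # contents of the first --context file
    second_file_text = "Style guide..."    # contents of the second --context file
    messages = [
        {"role": "system", "content": first_file_text},
        {"role": "system", "content": second_file_text},
        {"role": "user", "content": "First message typed at the prompt"},
    ]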

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "llm-chat"
version = "0.1.0"
version = "0.2.0"
description = "A general CLI interface for large language models."
authors = ["Paul Harrison <paul@harrison.sh>"]
readme = "README.md"

View File

@@ -77,6 +77,8 @@ class Chat:
        return message


-def get_chat(settings: OpenAISettings | None = None) -> ChatProtocol:
+def get_chat(
+    settings: OpenAISettings | None = None, context: list[Message] = []
+) -> ChatProtocol:
    """Get a chat object."""
-    return Chat(settings=settings)
+    return Chat(settings=settings, context=context)
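
For reference, a hedged sketch of calling the new signature directly. The system-prompt
text is made up, and send_message is assumed from ChatProtocol as exercised in the tests:

    from llm_chat.chat import get_chat
    from llm_chat.models import Message, Role

    context = [Message(role=Role.SYSTEM, content="Answer in one sentence.")]
    # settings=None, so Chat presumably falls back to default OpenAISettings.
    chat = get_chat(context=context)
    # send_message would call the OpenAI API, so an API key must be configured.
    reply = chat.send_message("What does this package do?")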

View File

@@ -1,3 +1,4 @@
+from pathlib import Path
from typing import Annotated, Any, Optional

import typer
@@ -6,6 +7,7 @@ from rich.console import Console
from rich.markdown import Markdown

from llm_chat.chat import get_chat
+from llm_chat.models import Message, Role
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings

app = typer.Typer()
@@ -39,6 +41,20 @@ def read_user_input(session: PromptSession[Any]) -> str:
    return prompt


+def load_context(path: Path) -> Message:
+    """Load context text from file."""
+    if not path.exists():
+        raise typer.BadParameter(f"File {path} does not exist.")
+    if not path.is_file():
+        raise typer.BadParameter(f"Path {path} is not a file.")
+    with path.open() as f:
+        content = f.read()
+    return Message(role=Role.SYSTEM, content=content)


@app.command()
def chat(
    api_key: Annotated[
@@ -63,6 +79,22 @@ def chat(
            ..., "--temperature", "-t", help="Model temperature (i.e. creativeness)."
        ),
    ] = DEFAULT_TEMPERATURE,
+    context: Annotated[
+        list[Path],
+        typer.Option(
+            ...,
+            "--context",
+            "-c",
+            help=(
+                "Path to a file containing context text. "
+                "Can be provided multiple times for multiple files."
+            ),
+            exists=True,
+            file_okay=True,
+            dir_okay=False,
+            readable=True,
+        ),
+    ] = [],
) -> None:
    """Start a chat session."""
    # TODO: Add option to load context from file.
@@ -72,7 +104,10 @@
    else:
        settings = OpenAISettings(model=model, temperature=temperature)

-    current_chat = get_chat(settings)
+    context_messages = [load_context(path) for path in context]
+    current_chat = get_chat(settings=settings, context=context_messages)
+    # current_chat = get_chat(settings=settings)

    console = get_console()
    session = get_session()
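
On the command line the new option can be repeated, once per file (--context or -c). As a
quick hedged illustration of the load_context helper on its own (the file name here is
hypothetical):

    from pathlib import Path

    from llm_chat.cli import load_context
    from llm_chat.models import Message, Role

    # "notes.txt" stands in for any readable text file.
    path = Path("notes.txt")
    assert load_context(path) == Message(role=Role.SYSTEM, content=path.read_text())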

View File

@@ -1,6 +1,9 @@
from io import StringIO
+from pathlib import Path
from typing import Any
+from unittest.mock import MagicMock

+import pytest
from pytest import MonkeyPatch
from rich.console import Console
from typer.testing import CliRunner
@@ -8,7 +11,7 @@ from typer.testing import CliRunner
import llm_chat
from llm_chat.chat import ChatProtocol
from llm_chat.cli import app
-from llm_chat.settings import OpenAISettings
+from llm_chat.models import Message, Role

runner = CliRunner()
@@ -16,7 +19,15 @@ runner = CliRunner()
class ChatFake:
    """Fake chat class for testing."""

-    received_messages: list[str] = []
+    args: dict[str, Any]
+    received_messages: list[str]
+
+    def __init__(self) -> None:
+        self.args = {}
+        self.received_messages = []
+
+    def _set_args(self, **kwargs: Any) -> None:
+        self.args = kwargs

    def send_message(self, message: str) -> str:
        """Echo the received message."""
@@ -29,7 +40,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
    output = StringIO()
    console = Console(file=output)

-    def mock_get_chat(_: OpenAISettings) -> ChatProtocol:
+    def mock_get_chat(**_: Any) -> ChatProtocol:
        return chat_fake

    def mock_get_console() -> Console:
@@ -44,3 +55,35 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
    result = runner.invoke(app)
    assert result.exit_code == 0
    assert chat_fake.received_messages == ["Hello"]
+
+
+@pytest.mark.parametrize("argument", ["--context", "-c"], ids=["--context", "-c"])
+def test_chat_with_context(
+    argument: str, monkeypatch: MonkeyPatch, tmp_path: Path
+) -> None:
+    context = "Hello, world!"
+    tmp_file = tmp_path / "context.txt"
+    tmp_file.write_text(context)
+
+    chat_fake = ChatFake()
+    output = StringIO()
+    console = Console(file=output)
+
+    def mock_get_chat(**kwargs: Any) -> ChatProtocol:
+        chat_fake._set_args(**kwargs)
+        return chat_fake
+
+    def mock_get_console() -> Console:
+        return console
+
+    mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
+
+    monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
+    monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
+    monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
+
+    result = runner.invoke(app, [argument, str(tmp_file)])
+
+    assert result.exit_code == 0
+    assert chat_fake.received_messages == ["Hello"]
+    assert "context" in chat_fake.args
+    assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]