Optionally provide context files to chat session
This commit enables the user to provide one or more text files as context for their chat session. These will be provided as system messages to OpenAI's API, one message per file.
This commit is contained in:
parent
7cbde55aac
commit
2cd623e960
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "llm-chat"
|
name = "llm-chat"
|
||||||
version = "0.1.0"
|
version = "0.2.0"
|
||||||
description = "A general CLI interface for large language models."
|
description = "A general CLI interface for large language models."
|
||||||
authors = ["Paul Harrison <paul@harrison.sh>"]
|
authors = ["Paul Harrison <paul@harrison.sh>"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
|
@ -77,6 +77,8 @@ class Chat:
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
def get_chat(settings: OpenAISettings | None = None) -> ChatProtocol:
|
def get_chat(
|
||||||
|
settings: OpenAISettings | None = None, context: list[Message] = []
|
||||||
|
) -> ChatProtocol:
|
||||||
"""Get a chat object."""
|
"""Get a chat object."""
|
||||||
return Chat(settings=settings)
|
return Chat(settings=settings, context=context)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from pathlib import Path
|
||||||
from typing import Annotated, Any, Optional
|
from typing import Annotated, Any, Optional
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
@ -6,6 +7,7 @@ from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
from llm_chat.chat import get_chat
|
from llm_chat.chat import get_chat
|
||||||
|
from llm_chat.models import Message, Role
|
||||||
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
@ -39,6 +41,20 @@ def read_user_input(session: PromptSession[Any]) -> str:
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def load_context(path: Path) -> Message:
|
||||||
|
"""Load context text from file."""
|
||||||
|
if not path.exists():
|
||||||
|
raise typer.BadParameter(f"File {path} does not exist.")
|
||||||
|
|
||||||
|
if not path.is_file():
|
||||||
|
raise typer.BadParameter(f"Path {path} is not a file.")
|
||||||
|
|
||||||
|
with path.open() as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
return Message(role=Role.SYSTEM, content=content)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def chat(
|
def chat(
|
||||||
api_key: Annotated[
|
api_key: Annotated[
|
||||||
|
@ -63,6 +79,22 @@ def chat(
|
||||||
..., "--temperature", "-t", help="Model temperature (i.e. creativeness)."
|
..., "--temperature", "-t", help="Model temperature (i.e. creativeness)."
|
||||||
),
|
),
|
||||||
] = DEFAULT_TEMPERATURE,
|
] = DEFAULT_TEMPERATURE,
|
||||||
|
context: Annotated[
|
||||||
|
list[Path],
|
||||||
|
typer.Option(
|
||||||
|
...,
|
||||||
|
"--context",
|
||||||
|
"-c",
|
||||||
|
help=(
|
||||||
|
"Path to a file containing context text. "
|
||||||
|
"Can provide multiple time for multiple files."
|
||||||
|
),
|
||||||
|
exists=True,
|
||||||
|
file_okay=True,
|
||||||
|
dir_okay=False,
|
||||||
|
readable=True,
|
||||||
|
),
|
||||||
|
] = [],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a chat session."""
|
"""Start a chat session."""
|
||||||
# TODO: Add option to load context from file.
|
# TODO: Add option to load context from file.
|
||||||
|
@ -72,7 +104,10 @@ def chat(
|
||||||
else:
|
else:
|
||||||
settings = OpenAISettings(model=model, temperature=temperature)
|
settings = OpenAISettings(model=model, temperature=temperature)
|
||||||
|
|
||||||
current_chat = get_chat(settings)
|
context_messages = [load_context(path) for path in context]
|
||||||
|
|
||||||
|
current_chat = get_chat(settings=settings, context=context_messages)
|
||||||
|
# current_chat = get_chat(settings=settings)
|
||||||
console = get_console()
|
console = get_console()
|
||||||
session = get_session()
|
session = get_session()
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
from pytest import MonkeyPatch
|
from pytest import MonkeyPatch
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from typer.testing import CliRunner
|
from typer.testing import CliRunner
|
||||||
|
@ -8,7 +11,7 @@ from typer.testing import CliRunner
|
||||||
import llm_chat
|
import llm_chat
|
||||||
from llm_chat.chat import ChatProtocol
|
from llm_chat.chat import ChatProtocol
|
||||||
from llm_chat.cli import app
|
from llm_chat.cli import app
|
||||||
from llm_chat.settings import OpenAISettings
|
from llm_chat.models import Message, Role
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
||||||
|
@ -16,7 +19,15 @@ runner = CliRunner()
|
||||||
class ChatFake:
|
class ChatFake:
|
||||||
"""Fake chat class for testing."""
|
"""Fake chat class for testing."""
|
||||||
|
|
||||||
received_messages: list[str] = []
|
args: dict[str, Any]
|
||||||
|
received_messages: list[str]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.args = {}
|
||||||
|
self.received_messages = []
|
||||||
|
|
||||||
|
def _set_args(self, **kwargs: Any) -> None:
|
||||||
|
self.args = kwargs
|
||||||
|
|
||||||
def send_message(self, message: str) -> str:
|
def send_message(self, message: str) -> str:
|
||||||
"""Echo the received message."""
|
"""Echo the received message."""
|
||||||
|
@ -29,7 +40,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
|
||||||
output = StringIO()
|
output = StringIO()
|
||||||
console = Console(file=output)
|
console = Console(file=output)
|
||||||
|
|
||||||
def mock_get_chat(_: OpenAISettings) -> ChatProtocol:
|
def mock_get_chat(**_: Any) -> ChatProtocol:
|
||||||
return chat_fake
|
return chat_fake
|
||||||
|
|
||||||
def mock_get_console() -> Console:
|
def mock_get_console() -> Console:
|
||||||
|
@ -44,3 +55,35 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
|
||||||
result = runner.invoke(app)
|
result = runner.invoke(app)
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert chat_fake.received_messages == ["Hello"]
|
assert chat_fake.received_messages == ["Hello"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("argument", ["--context", "-c"], ids=["--context", "-c"])
|
||||||
|
def test_chat_with_context(
|
||||||
|
argument: str, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
context = "Hello, world!"
|
||||||
|
tmp_file = tmp_path / "context.txt"
|
||||||
|
tmp_file.write_text(context)
|
||||||
|
|
||||||
|
chat_fake = ChatFake()
|
||||||
|
output = StringIO()
|
||||||
|
console = Console(file=output)
|
||||||
|
|
||||||
|
def mock_get_chat(**kwargs: Any) -> ChatProtocol:
|
||||||
|
chat_fake._set_args(**kwargs)
|
||||||
|
return chat_fake
|
||||||
|
|
||||||
|
def mock_get_console() -> Console:
|
||||||
|
return console
|
||||||
|
|
||||||
|
mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
|
||||||
|
|
||||||
|
monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
|
||||||
|
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
||||||
|
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
||||||
|
|
||||||
|
result = runner.invoke(app, [argument, str(tmp_file)])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert chat_fake.received_messages == ["Hello"]
|
||||||
|
assert "context" in chat_fake.args
|
||||||
|
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
|
||||||
|
|
Loading…
Reference in New Issue