Include one or more context files in chat session
This commit is contained in:
parent
7cbde55aac
commit
d1a7cd68aa
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "llm-chat"
|
name = "llm-chat"
|
||||||
version = "0.1.0"
|
version = "0.2.0"
|
||||||
description = "A general CLI interface for large language models."
|
description = "A general CLI interface for large language models."
|
||||||
authors = ["Paul Harrison <paul@harrison.sh>"]
|
authors = ["Paul Harrison <paul@harrison.sh>"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
|
@ -77,6 +77,8 @@ class Chat:
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
def get_chat(settings: OpenAISettings | None = None) -> ChatProtocol:
|
def get_chat(
|
||||||
|
settings: OpenAISettings | None = None, context: list[Message] = []
|
||||||
|
) -> ChatProtocol:
|
||||||
"""Get a chat object."""
|
"""Get a chat object."""
|
||||||
return Chat(settings=settings)
|
return Chat(settings=settings, context=context)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from pathlib import Path
|
||||||
from typing import Annotated, Any, Optional
|
from typing import Annotated, Any, Optional
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
@ -6,6 +7,7 @@ from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
from llm_chat.chat import get_chat
|
from llm_chat.chat import get_chat
|
||||||
|
from llm_chat.models import Message, Role
|
||||||
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
@ -39,6 +41,20 @@ def read_user_input(session: PromptSession[Any]) -> str:
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def load_context(path: Path) -> Message:
|
||||||
|
"""Load context text from file."""
|
||||||
|
if not path.exists():
|
||||||
|
raise typer.BadParameter(f"File {path} does not exist.")
|
||||||
|
|
||||||
|
if not path.is_file():
|
||||||
|
raise typer.BadParameter(f"Path {path} is not a file.")
|
||||||
|
|
||||||
|
with path.open() as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
return Message(role=Role.SYSTEM, content=content)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def chat(
|
def chat(
|
||||||
api_key: Annotated[
|
api_key: Annotated[
|
||||||
|
@ -63,6 +79,22 @@ def chat(
|
||||||
..., "--temperature", "-t", help="Model temperature (i.e. creativeness)."
|
..., "--temperature", "-t", help="Model temperature (i.e. creativeness)."
|
||||||
),
|
),
|
||||||
] = DEFAULT_TEMPERATURE,
|
] = DEFAULT_TEMPERATURE,
|
||||||
|
context: Annotated[
|
||||||
|
list[Path],
|
||||||
|
typer.Option(
|
||||||
|
...,
|
||||||
|
"--context",
|
||||||
|
"-c",
|
||||||
|
help=(
|
||||||
|
"Path to a file containing context text. "
|
||||||
|
"Can provide multiple time for multiple files."
|
||||||
|
),
|
||||||
|
exists=True,
|
||||||
|
file_okay=True,
|
||||||
|
dir_okay=False,
|
||||||
|
readable=True,
|
||||||
|
),
|
||||||
|
] = [],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a chat session."""
|
"""Start a chat session."""
|
||||||
# TODO: Add option to load context from file.
|
# TODO: Add option to load context from file.
|
||||||
|
@ -72,7 +104,10 @@ def chat(
|
||||||
else:
|
else:
|
||||||
settings = OpenAISettings(model=model, temperature=temperature)
|
settings = OpenAISettings(model=model, temperature=temperature)
|
||||||
|
|
||||||
current_chat = get_chat(settings)
|
context_messages = [load_context(path) for path in context]
|
||||||
|
|
||||||
|
current_chat = get_chat(settings=settings, context=context_messages)
|
||||||
|
# current_chat = get_chat(settings=settings)
|
||||||
console = get_console()
|
console = get_console()
|
||||||
session = get_session()
|
session = get_session()
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
from pytest import MonkeyPatch
|
from pytest import MonkeyPatch
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from typer.testing import CliRunner
|
from typer.testing import CliRunner
|
||||||
|
@ -8,7 +11,7 @@ from typer.testing import CliRunner
|
||||||
import llm_chat
|
import llm_chat
|
||||||
from llm_chat.chat import ChatProtocol
|
from llm_chat.chat import ChatProtocol
|
||||||
from llm_chat.cli import app
|
from llm_chat.cli import app
|
||||||
from llm_chat.settings import OpenAISettings
|
from llm_chat.models import Message, Role
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
||||||
|
@ -16,7 +19,15 @@ runner = CliRunner()
|
||||||
class ChatFake:
|
class ChatFake:
|
||||||
"""Fake chat class for testing."""
|
"""Fake chat class for testing."""
|
||||||
|
|
||||||
received_messages: list[str] = []
|
args: dict[str, Any]
|
||||||
|
received_messages: list[str]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.args = {}
|
||||||
|
self.received_messages = []
|
||||||
|
|
||||||
|
def _set_args(self, **kwargs: Any) -> None:
|
||||||
|
self.args = kwargs
|
||||||
|
|
||||||
def send_message(self, message: str) -> str:
|
def send_message(self, message: str) -> str:
|
||||||
"""Echo the received message."""
|
"""Echo the received message."""
|
||||||
|
@ -29,7 +40,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
|
||||||
output = StringIO()
|
output = StringIO()
|
||||||
console = Console(file=output)
|
console = Console(file=output)
|
||||||
|
|
||||||
def mock_get_chat(_: OpenAISettings) -> ChatProtocol:
|
def mock_get_chat(**_: Any) -> ChatProtocol:
|
||||||
return chat_fake
|
return chat_fake
|
||||||
|
|
||||||
def mock_get_console() -> Console:
|
def mock_get_console() -> Console:
|
||||||
|
@ -44,3 +55,35 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
|
||||||
result = runner.invoke(app)
|
result = runner.invoke(app)
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert chat_fake.received_messages == ["Hello"]
|
assert chat_fake.received_messages == ["Hello"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("argument", ["--context", "-c"], ids=["--context", "-c"])
|
||||||
|
def test_chat_with_context(
|
||||||
|
argument: str, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
context = "Hello, world!"
|
||||||
|
tmp_file = tmp_path / "context.txt"
|
||||||
|
tmp_file.write_text(context)
|
||||||
|
|
||||||
|
chat_fake = ChatFake()
|
||||||
|
output = StringIO()
|
||||||
|
console = Console(file=output)
|
||||||
|
|
||||||
|
def mock_get_chat(**kwargs: Any) -> ChatProtocol:
|
||||||
|
chat_fake._set_args(**kwargs)
|
||||||
|
return chat_fake
|
||||||
|
|
||||||
|
def mock_get_console() -> Console:
|
||||||
|
return console
|
||||||
|
|
||||||
|
mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
|
||||||
|
|
||||||
|
monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
|
||||||
|
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
||||||
|
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
||||||
|
|
||||||
|
result = runner.invoke(app, [argument, str(tmp_file)])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert chat_fake.received_messages == ["Hello"]
|
||||||
|
assert "context" in chat_fake.args
|
||||||
|
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
|
||||||
|
|
Loading…
Reference in New Issue