Load previous conversation #2

Merged
paul merged 1 commits from feat/load-previous-conversation into main 2023-09-14 06:23:06 +00:00
5 changed files with 200 additions and 26 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "llm-chat" name = "llm-chat"
version = "0.4.0" version = "0.5.0"
description = "A general CLI interface for large language models." description = "A general CLI interface for large language models."
authors = ["Paul Harrison <paul@harrison.sh>"] authors = ["Paul Harrison <paul@harrison.sh>"]
readme = "README.md" readme = "README.md"
@ -23,7 +23,7 @@ mypy = "^1.5.0"
pydocstyle = "^6.3.0" pydocstyle = "^6.3.0"
[tool.poetry.scripts] [tool.poetry.scripts]
chat = "llm_chat.cli:app" llm = "llm_chat.cli:app"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]

View File

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from enum import StrEnum, auto from enum import StrEnum, auto
from pathlib import Path from pathlib import Path
from typing import Any, Protocol from typing import Any, Protocol, Type
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
from openai import ChatCompletion from openai import ChatCompletion
@ -45,10 +45,16 @@ class Token(StrEnum):
class ChatProtocol(Protocol): class ChatProtocol(Protocol):
"""Protocol for chat classes.""" """Protocol for chat classes."""
conversation: Conversation
@property @property
def cost(self) -> float: def cost(self) -> float:
"""Get the cost of the conversation.""" """Get the cost of the conversation."""
@property
def settings(self) -> OpenAISettings:
"""Get OpenAI chat settings."""
def save(self) -> None: def save(self) -> None:
"""Save the conversation to the history directory.""" """Save the conversation to the history directory."""
@ -80,15 +86,41 @@ class Chat:
self, self,
settings: OpenAISettings | None = None, settings: OpenAISettings | None = None,
context: list[Message] = [], context: list[Message] = [],
initial_system_messages: bool = True,
) -> None: ) -> None:
self._settings = settings self._settings = settings
self.conversation = Conversation( self.conversation = Conversation(
messages=INITIAL_SYSTEM_MESSAGES + context, messages=INITIAL_SYSTEM_MESSAGES + context
if initial_system_messages
else context,
model=self.settings.model, model=self.settings.model,
temperature=self.settings.temperature, temperature=self.settings.temperature,
) )
self._start_time = datetime.now(tz=ZoneInfo("UTC")) self._start_time = datetime.now(tz=ZoneInfo("UTC"))
@classmethod
def load(
cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
) -> ChatProtocol:
"""Load a chat from a file."""
with path.open() as f:
conversation = Conversation.model_validate_json(f.read())
args: dict[str, Any] = {
"model": conversation.model,
"temperature": conversation.temperature,
}
if api_key is not None:
args["api_key"] = api_key
if history_dir is not None:
args["history_dir"] = history_dir
settings = OpenAISettings(**args)
return cls(
settings=settings,
context=conversation.messages,
initial_system_messages=False,
)
@property @property
def settings(self) -> OpenAISettings: def settings(self) -> OpenAISettings:
"""Get OpenAI chat settings.""" """Get OpenAI chat settings."""
@ -158,3 +190,8 @@ def get_chat(
) -> ChatProtocol: ) -> ChatProtocol:
"""Get a chat object.""" """Get a chat object."""
return Chat(settings=settings, context=context) return Chat(settings=settings, context=context)
def get_chat_class() -> Type[Chat]:
"""Get the chat class."""
return Chat

View File

@ -6,7 +6,7 @@ from prompt_toolkit import PromptSession
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from llm_chat.chat import ChatProtocol, get_chat from llm_chat.chat import ChatProtocol, get_chat, get_chat_class
from llm_chat.models import Message, Role from llm_chat.models import Message, Role
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
@ -60,7 +60,30 @@ def display_cost(console: Console, chat: ChatProtocol) -> None:
console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n") console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n")
@app.command() def run_conversation(current_chat: ChatProtocol) -> None:
"""Run a conversation."""
console = get_console()
session = get_session()
finished = False
console.print(f"[bold green]Model:[/bold green] {current_chat.settings.model}")
console.print(
f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}"
)
while not finished:
prompt = read_user_input(session)
if prompt.strip() == "/q":
finished = True
current_chat.save()
else:
response = current_chat.send_message(prompt.strip())
console.print(Markdown(response))
display_cost(console, current_chat)
@app.command("chat")
def chat( def chat(
api_key: Annotated[ api_key: Annotated[
Optional[str], Optional[str],
@ -102,7 +125,6 @@ def chat(
] = [], ] = [],
) -> None: ) -> None:
"""Start a chat session.""" """Start a chat session."""
# TODO: Add option to load context from file.
# TODO: Add option to provide context string as an argument. # TODO: Add option to provide context string as an argument.
if api_key is not None: if api_key is not None:
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature) settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
@ -112,20 +134,44 @@ def chat(
context_messages = [load_context(path) for path in context] context_messages = [load_context(path) for path in context]
current_chat = get_chat(settings=settings, context=context_messages) current_chat = get_chat(settings=settings, context=context_messages)
console = get_console()
session = get_session()
finished = False run_conversation(current_chat)
console.print(f"[bold green]Model:[/bold green] {settings.model}")
console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}")
while not finished: @app.command("load")
prompt = read_user_input(session) def load(
if prompt.strip() == "/q": path: Annotated[
finished = True Path,
current_chat.save() typer.Argument(
else: help="Path to a conversation file.",
response = current_chat.send_message(prompt.strip()) exists=True,
console.print(Markdown(response)) file_okay=True,
display_cost(console, current_chat) dir_okay=False,
readable=True,
),
],
api_key: Annotated[
Optional[str],
typer.Option(
...,
"--api-key",
"-k",
help=(
"API key. Will read from the environment variable OPENAI_API_KEY "
"if not provided."
),
),
] = None,
) -> None:
"""Load a conversation from a file."""
Chat = get_chat_class()
if api_key is not None:
current_chat = Chat.load(path, api_key=api_key)
else:
current_chat = Chat.load(path)
run_conversation(current_chat)
if __name__ == "__main__":
app()

View File

@ -46,6 +46,42 @@ def test_save_conversation(tmp_path: Path) -> None:
assert conversation == conversation_from_file assert conversation == conversation_from_file
def test_load(tmp_path: Path) -> None:
# Create a conversation object to save
conversation = Conversation(
messages=[
Message(role=Role.SYSTEM, content="Hello!"),
Message(role=Role.USER, content="Hi!"),
Message(role=Role.ASSISTANT, content="How are you?"),
],
model=Model.GPT3,
temperature=0.5,
completion_tokens=10,
prompt_tokens=15,
cost=0.000043,
)
# Save the conversation to a file
file_path = tmp_path / "conversation.json"
with file_path.open("w") as f:
f.write(conversation.model_dump_json())
# Load the conversation from the file
loaded_chat = Chat.load(file_path, api_key="foo", history_dir=tmp_path)
# Check that the loaded conversation matches the original conversation
assert loaded_chat.settings.model == conversation.model
assert loaded_chat.settings.temperature == conversation.temperature
assert loaded_chat.conversation.messages == conversation.messages
assert loaded_chat.settings.api_key == "foo"
assert loaded_chat.settings.history_dir == tmp_path
# We don't want to load the tokens or cost from the previous session
assert loaded_chat.conversation.completion_tokens == 0
assert loaded_chat.conversation.prompt_tokens == 0
assert loaded_chat.cost == 0
def test_send_message() -> None: def test_send_message() -> None:
with patch("llm_chat.chat.Chat._make_request") as mock_make_request: with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
mock_make_request.return_value = { mock_make_request.return_value = {

View File

@ -1,6 +1,6 @@
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Type
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
@ -11,7 +11,8 @@ from typer.testing import CliRunner
import llm_chat import llm_chat
from llm_chat.chat import ChatProtocol from llm_chat.chat import ChatProtocol
from llm_chat.cli import app from llm_chat.cli import app
from llm_chat.models import Message, Role from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings
runner = CliRunner() runner = CliRunner()
@ -20,9 +21,15 @@ class ChatFake:
"""Fake chat class for testing.""" """Fake chat class for testing."""
args: dict[str, Any] args: dict[str, Any]
conversation: Conversation
received_messages: list[str] received_messages: list[str]
settings: OpenAISettings
def __init__(self) -> None: def __init__(self, settings: OpenAISettings | None = None) -> None:
if settings is not None:
self.settings = settings
else:
self.settings = OpenAISettings()
self.args = {} self.args = {}
self.received_messages = [] self.received_messages = []
@ -34,6 +41,13 @@ class ChatFake:
"""Get the cost of the conversation.""" """Get the cost of the conversation."""
return 0.0 return 0.0
@classmethod
def load(
cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
) -> ChatProtocol:
"""Load a chat from a file."""
return cls()
def save(self) -> None: def save(self) -> None:
"""Dummy save method.""" """Dummy save method."""
pass pass
@ -61,7 +75,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console) monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input) monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
result = runner.invoke(app) result = runner.invoke(app, ["chat"])
assert result.exit_code == 0 assert result.exit_code == 0
assert chat_fake.received_messages == ["Hello"] assert chat_fake.received_messages == ["Hello"]
@ -91,8 +105,49 @@ def test_chat_with_context(
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console) monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input) monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
result = runner.invoke(app, [argument, str(tmp_file)]) result = runner.invoke(app, ["chat", argument, str(tmp_file)])
assert result.exit_code == 0 assert result.exit_code == 0
assert chat_fake.received_messages == ["Hello"] assert chat_fake.received_messages == ["Hello"]
assert "context" in chat_fake.args assert "context" in chat_fake.args
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)] assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
# Create a conversation object to save
conversation = Conversation(
messages=[
Message(role=Role.SYSTEM, content="Hello!"),
Message(role=Role.USER, content="Hi!"),
Message(role=Role.ASSISTANT, content="How are you?"),
],
model=Model.GPT3,
temperature=0.5,
completion_tokens=10,
prompt_tokens=15,
cost=0.000043,
)
# Save the conversation to a file
file_path = tmp_path / "conversation.json"
with file_path.open("w") as f:
f.write(conversation.model_dump_json())
output = StringIO()
console = Console(file=output)
def mock_get_chat() -> Type[ChatFake]:
return ChatFake
def mock_get_console() -> Console:
return console
mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
monkeypatch.setattr(llm_chat.cli, "get_chat_class", mock_get_chat)
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
# Load the conversation from the file
result = runner.invoke(app, ["load", str(file_path)])
assert result.exit_code == 0