Load previous conversation
This commit adds the additional CLI command `load`, which accepts the path to a previously saved conversation from which to load previous messages. The previous cost and token counts are not loaded as theses are deemed functions of the chat session rathr than the conversation and are stored with conversation purely for future reference.
This commit is contained in:
parent
8c17d4165a
commit
f5c7c40cc8
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "llm-chat"
|
name = "llm-chat"
|
||||||
version = "0.4.0"
|
version = "0.5.0"
|
||||||
description = "A general CLI interface for large language models."
|
description = "A general CLI interface for large language models."
|
||||||
authors = ["Paul Harrison <paul@harrison.sh>"]
|
authors = ["Paul Harrison <paul@harrison.sh>"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@ -23,7 +23,7 @@ mypy = "^1.5.0"
|
||||||
pydocstyle = "^6.3.0"
|
pydocstyle = "^6.3.0"
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
chat = "llm_chat.cli:app"
|
llm = "llm_chat.cli:app"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol, Type
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
from openai import ChatCompletion
|
from openai import ChatCompletion
|
||||||
|
@ -45,10 +45,16 @@ class Token(StrEnum):
|
||||||
class ChatProtocol(Protocol):
|
class ChatProtocol(Protocol):
|
||||||
"""Protocol for chat classes."""
|
"""Protocol for chat classes."""
|
||||||
|
|
||||||
|
conversation: Conversation
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cost(self) -> float:
|
def cost(self) -> float:
|
||||||
"""Get the cost of the conversation."""
|
"""Get the cost of the conversation."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def settings(self) -> OpenAISettings:
|
||||||
|
"""Get OpenAI chat settings."""
|
||||||
|
|
||||||
def save(self) -> None:
|
def save(self) -> None:
|
||||||
"""Save the conversation to the history directory."""
|
"""Save the conversation to the history directory."""
|
||||||
|
|
||||||
|
@ -80,15 +86,41 @@ class Chat:
|
||||||
self,
|
self,
|
||||||
settings: OpenAISettings | None = None,
|
settings: OpenAISettings | None = None,
|
||||||
context: list[Message] = [],
|
context: list[Message] = [],
|
||||||
|
initial_system_messages: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._settings = settings
|
self._settings = settings
|
||||||
self.conversation = Conversation(
|
self.conversation = Conversation(
|
||||||
messages=INITIAL_SYSTEM_MESSAGES + context,
|
messages=INITIAL_SYSTEM_MESSAGES + context
|
||||||
|
if initial_system_messages
|
||||||
|
else context,
|
||||||
model=self.settings.model,
|
model=self.settings.model,
|
||||||
temperature=self.settings.temperature,
|
temperature=self.settings.temperature,
|
||||||
)
|
)
|
||||||
self._start_time = datetime.now(tz=ZoneInfo("UTC"))
|
self._start_time = datetime.now(tz=ZoneInfo("UTC"))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
|
||||||
|
) -> ChatProtocol:
|
||||||
|
"""Load a chat from a file."""
|
||||||
|
with path.open() as f:
|
||||||
|
conversation = Conversation.model_validate_json(f.read())
|
||||||
|
args: dict[str, Any] = {
|
||||||
|
"model": conversation.model,
|
||||||
|
"temperature": conversation.temperature,
|
||||||
|
}
|
||||||
|
if api_key is not None:
|
||||||
|
args["api_key"] = api_key
|
||||||
|
if history_dir is not None:
|
||||||
|
args["history_dir"] = history_dir
|
||||||
|
|
||||||
|
settings = OpenAISettings(**args)
|
||||||
|
return cls(
|
||||||
|
settings=settings,
|
||||||
|
context=conversation.messages,
|
||||||
|
initial_system_messages=False,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def settings(self) -> OpenAISettings:
|
def settings(self) -> OpenAISettings:
|
||||||
"""Get OpenAI chat settings."""
|
"""Get OpenAI chat settings."""
|
||||||
|
@ -158,3 +190,8 @@ def get_chat(
|
||||||
) -> ChatProtocol:
|
) -> ChatProtocol:
|
||||||
"""Get a chat object."""
|
"""Get a chat object."""
|
||||||
return Chat(settings=settings, context=context)
|
return Chat(settings=settings, context=context)
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_class() -> Type[Chat]:
|
||||||
|
"""Get the chat class."""
|
||||||
|
return Chat
|
||||||
|
|
|
@ -6,7 +6,7 @@ from prompt_toolkit import PromptSession
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
from llm_chat.chat import ChatProtocol, get_chat
|
from llm_chat.chat import ChatProtocol, get_chat, get_chat_class
|
||||||
from llm_chat.models import Message, Role
|
from llm_chat.models import Message, Role
|
||||||
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
||||||
|
|
||||||
|
@ -60,7 +60,30 @@ def display_cost(console: Console, chat: ChatProtocol) -> None:
|
||||||
console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n")
|
console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n")
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
def run_conversation(current_chat: ChatProtocol) -> None:
|
||||||
|
"""Run a conversation."""
|
||||||
|
console = get_console()
|
||||||
|
session = get_session()
|
||||||
|
|
||||||
|
finished = False
|
||||||
|
|
||||||
|
console.print(f"[bold green]Model:[/bold green] {current_chat.settings.model}")
|
||||||
|
console.print(
|
||||||
|
f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}"
|
||||||
|
)
|
||||||
|
|
||||||
|
while not finished:
|
||||||
|
prompt = read_user_input(session)
|
||||||
|
if prompt.strip() == "/q":
|
||||||
|
finished = True
|
||||||
|
current_chat.save()
|
||||||
|
else:
|
||||||
|
response = current_chat.send_message(prompt.strip())
|
||||||
|
console.print(Markdown(response))
|
||||||
|
display_cost(console, current_chat)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("chat")
|
||||||
def chat(
|
def chat(
|
||||||
api_key: Annotated[
|
api_key: Annotated[
|
||||||
Optional[str],
|
Optional[str],
|
||||||
|
@ -102,7 +125,6 @@ def chat(
|
||||||
] = [],
|
] = [],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a chat session."""
|
"""Start a chat session."""
|
||||||
# TODO: Add option to load context from file.
|
|
||||||
# TODO: Add option to provide context string as an argument.
|
# TODO: Add option to provide context string as an argument.
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
|
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
|
||||||
|
@ -112,20 +134,44 @@ def chat(
|
||||||
context_messages = [load_context(path) for path in context]
|
context_messages = [load_context(path) for path in context]
|
||||||
|
|
||||||
current_chat = get_chat(settings=settings, context=context_messages)
|
current_chat = get_chat(settings=settings, context=context_messages)
|
||||||
console = get_console()
|
|
||||||
session = get_session()
|
|
||||||
|
|
||||||
finished = False
|
run_conversation(current_chat)
|
||||||
|
|
||||||
console.print(f"[bold green]Model:[/bold green] {settings.model}")
|
|
||||||
console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}")
|
|
||||||
|
|
||||||
while not finished:
|
@app.command("load")
|
||||||
prompt = read_user_input(session)
|
def load(
|
||||||
if prompt.strip() == "/q":
|
path: Annotated[
|
||||||
finished = True
|
Path,
|
||||||
current_chat.save()
|
typer.Argument(
|
||||||
else:
|
help="Path to a conversation file.",
|
||||||
response = current_chat.send_message(prompt.strip())
|
exists=True,
|
||||||
console.print(Markdown(response))
|
file_okay=True,
|
||||||
display_cost(console, current_chat)
|
dir_okay=False,
|
||||||
|
readable=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
api_key: Annotated[
|
||||||
|
Optional[str],
|
||||||
|
typer.Option(
|
||||||
|
...,
|
||||||
|
"--api-key",
|
||||||
|
"-k",
|
||||||
|
help=(
|
||||||
|
"API key. Will read from the environment variable OPENAI_API_KEY "
|
||||||
|
"if not provided."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Load a conversation from a file."""
|
||||||
|
Chat = get_chat_class()
|
||||||
|
if api_key is not None:
|
||||||
|
current_chat = Chat.load(path, api_key=api_key)
|
||||||
|
else:
|
||||||
|
current_chat = Chat.load(path)
|
||||||
|
|
||||||
|
run_conversation(current_chat)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
|
|
|
@ -46,6 +46,42 @@ def test_save_conversation(tmp_path: Path) -> None:
|
||||||
assert conversation == conversation_from_file
|
assert conversation == conversation_from_file
|
||||||
|
|
||||||
|
|
||||||
|
def test_load(tmp_path: Path) -> None:
|
||||||
|
# Create a conversation object to save
|
||||||
|
conversation = Conversation(
|
||||||
|
messages=[
|
||||||
|
Message(role=Role.SYSTEM, content="Hello!"),
|
||||||
|
Message(role=Role.USER, content="Hi!"),
|
||||||
|
Message(role=Role.ASSISTANT, content="How are you?"),
|
||||||
|
],
|
||||||
|
model=Model.GPT3,
|
||||||
|
temperature=0.5,
|
||||||
|
completion_tokens=10,
|
||||||
|
prompt_tokens=15,
|
||||||
|
cost=0.000043,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the conversation to a file
|
||||||
|
file_path = tmp_path / "conversation.json"
|
||||||
|
with file_path.open("w") as f:
|
||||||
|
f.write(conversation.model_dump_json())
|
||||||
|
|
||||||
|
# Load the conversation from the file
|
||||||
|
loaded_chat = Chat.load(file_path, api_key="foo", history_dir=tmp_path)
|
||||||
|
|
||||||
|
# Check that the loaded conversation matches the original conversation
|
||||||
|
assert loaded_chat.settings.model == conversation.model
|
||||||
|
assert loaded_chat.settings.temperature == conversation.temperature
|
||||||
|
assert loaded_chat.conversation.messages == conversation.messages
|
||||||
|
assert loaded_chat.settings.api_key == "foo"
|
||||||
|
assert loaded_chat.settings.history_dir == tmp_path
|
||||||
|
|
||||||
|
# We don't want to load the tokens or cost from the previous session
|
||||||
|
assert loaded_chat.conversation.completion_tokens == 0
|
||||||
|
assert loaded_chat.conversation.prompt_tokens == 0
|
||||||
|
assert loaded_chat.cost == 0
|
||||||
|
|
||||||
|
|
||||||
def test_send_message() -> None:
|
def test_send_message() -> None:
|
||||||
with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
|
with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
|
||||||
mock_make_request.return_value = {
|
mock_make_request.return_value = {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Type
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -11,7 +11,8 @@ from typer.testing import CliRunner
|
||||||
import llm_chat
|
import llm_chat
|
||||||
from llm_chat.chat import ChatProtocol
|
from llm_chat.chat import ChatProtocol
|
||||||
from llm_chat.cli import app
|
from llm_chat.cli import app
|
||||||
from llm_chat.models import Message, Role
|
from llm_chat.models import Conversation, Message, Role
|
||||||
|
from llm_chat.settings import Model, OpenAISettings
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
||||||
|
@ -20,9 +21,15 @@ class ChatFake:
|
||||||
"""Fake chat class for testing."""
|
"""Fake chat class for testing."""
|
||||||
|
|
||||||
args: dict[str, Any]
|
args: dict[str, Any]
|
||||||
|
conversation: Conversation
|
||||||
received_messages: list[str]
|
received_messages: list[str]
|
||||||
|
settings: OpenAISettings
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, settings: OpenAISettings | None = None) -> None:
|
||||||
|
if settings is not None:
|
||||||
|
self.settings = settings
|
||||||
|
else:
|
||||||
|
self.settings = OpenAISettings()
|
||||||
self.args = {}
|
self.args = {}
|
||||||
self.received_messages = []
|
self.received_messages = []
|
||||||
|
|
||||||
|
@ -34,6 +41,13 @@ class ChatFake:
|
||||||
"""Get the cost of the conversation."""
|
"""Get the cost of the conversation."""
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
|
||||||
|
) -> ChatProtocol:
|
||||||
|
"""Load a chat from a file."""
|
||||||
|
return cls()
|
||||||
|
|
||||||
def save(self) -> None:
|
def save(self) -> None:
|
||||||
"""Dummy save method."""
|
"""Dummy save method."""
|
||||||
pass
|
pass
|
||||||
|
@ -61,7 +75,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
|
||||||
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
||||||
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
||||||
|
|
||||||
result = runner.invoke(app)
|
result = runner.invoke(app, ["chat"])
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert chat_fake.received_messages == ["Hello"]
|
assert chat_fake.received_messages == ["Hello"]
|
||||||
|
|
||||||
|
@ -91,8 +105,49 @@ def test_chat_with_context(
|
||||||
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
||||||
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
||||||
|
|
||||||
result = runner.invoke(app, [argument, str(tmp_file)])
|
result = runner.invoke(app, ["chat", argument, str(tmp_file)])
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert chat_fake.received_messages == ["Hello"]
|
assert chat_fake.received_messages == ["Hello"]
|
||||||
assert "context" in chat_fake.args
|
assert "context" in chat_fake.args
|
||||||
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
|
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
|
||||||
|
# Create a conversation object to save
|
||||||
|
conversation = Conversation(
|
||||||
|
messages=[
|
||||||
|
Message(role=Role.SYSTEM, content="Hello!"),
|
||||||
|
Message(role=Role.USER, content="Hi!"),
|
||||||
|
Message(role=Role.ASSISTANT, content="How are you?"),
|
||||||
|
],
|
||||||
|
model=Model.GPT3,
|
||||||
|
temperature=0.5,
|
||||||
|
completion_tokens=10,
|
||||||
|
prompt_tokens=15,
|
||||||
|
cost=0.000043,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the conversation to a file
|
||||||
|
file_path = tmp_path / "conversation.json"
|
||||||
|
with file_path.open("w") as f:
|
||||||
|
f.write(conversation.model_dump_json())
|
||||||
|
|
||||||
|
output = StringIO()
|
||||||
|
console = Console(file=output)
|
||||||
|
|
||||||
|
def mock_get_chat() -> Type[ChatFake]:
|
||||||
|
return ChatFake
|
||||||
|
|
||||||
|
def mock_get_console() -> Console:
|
||||||
|
return console
|
||||||
|
|
||||||
|
mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
|
||||||
|
|
||||||
|
monkeypatch.setattr(llm_chat.cli, "get_chat_class", mock_get_chat)
|
||||||
|
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
||||||
|
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
||||||
|
|
||||||
|
# Load the conversation from the file
|
||||||
|
result = runner.invoke(app, ["load", str(file_path)])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
Loading…
Reference in New Issue