Load previous conversation #2
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "llm-chat"
|
name = "llm-chat"
|
||||||
version = "0.4.0"
|
version = "0.5.0"
|
||||||
description = "A general CLI interface for large language models."
|
description = "A general CLI interface for large language models."
|
||||||
authors = ["Paul Harrison <paul@harrison.sh>"]
|
authors = ["Paul Harrison <paul@harrison.sh>"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@ -23,7 +23,7 @@ mypy = "^1.5.0"
|
||||||
pydocstyle = "^6.3.0"
|
pydocstyle = "^6.3.0"
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
chat = "llm_chat.cli:app"
|
llm = "llm_chat.cli:app"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol, Type
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
from openai import ChatCompletion
|
from openai import ChatCompletion
|
||||||
|
@ -45,10 +45,16 @@ class Token(StrEnum):
|
||||||
class ChatProtocol(Protocol):
|
class ChatProtocol(Protocol):
|
||||||
"""Protocol for chat classes."""
|
"""Protocol for chat classes."""
|
||||||
|
|
||||||
|
conversation: Conversation
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cost(self) -> float:
|
def cost(self) -> float:
|
||||||
"""Get the cost of the conversation."""
|
"""Get the cost of the conversation."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def settings(self) -> OpenAISettings:
|
||||||
|
"""Get OpenAI chat settings."""
|
||||||
|
|
||||||
def save(self) -> None:
|
def save(self) -> None:
|
||||||
"""Save the conversation to the history directory."""
|
"""Save the conversation to the history directory."""
|
||||||
|
|
||||||
|
@ -80,15 +86,41 @@ class Chat:
|
||||||
self,
|
self,
|
||||||
settings: OpenAISettings | None = None,
|
settings: OpenAISettings | None = None,
|
||||||
context: list[Message] = [],
|
context: list[Message] = [],
|
||||||
|
initial_system_messages: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._settings = settings
|
self._settings = settings
|
||||||
self.conversation = Conversation(
|
self.conversation = Conversation(
|
||||||
messages=INITIAL_SYSTEM_MESSAGES + context,
|
messages=INITIAL_SYSTEM_MESSAGES + context
|
||||||
|
if initial_system_messages
|
||||||
|
else context,
|
||||||
model=self.settings.model,
|
model=self.settings.model,
|
||||||
temperature=self.settings.temperature,
|
temperature=self.settings.temperature,
|
||||||
)
|
)
|
||||||
self._start_time = datetime.now(tz=ZoneInfo("UTC"))
|
self._start_time = datetime.now(tz=ZoneInfo("UTC"))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
|
||||||
|
) -> ChatProtocol:
|
||||||
|
"""Load a chat from a file."""
|
||||||
|
with path.open() as f:
|
||||||
|
conversation = Conversation.model_validate_json(f.read())
|
||||||
|
args: dict[str, Any] = {
|
||||||
|
"model": conversation.model,
|
||||||
|
"temperature": conversation.temperature,
|
||||||
|
}
|
||||||
|
if api_key is not None:
|
||||||
|
args["api_key"] = api_key
|
||||||
|
if history_dir is not None:
|
||||||
|
args["history_dir"] = history_dir
|
||||||
|
|
||||||
|
settings = OpenAISettings(**args)
|
||||||
|
return cls(
|
||||||
|
settings=settings,
|
||||||
|
context=conversation.messages,
|
||||||
|
initial_system_messages=False,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def settings(self) -> OpenAISettings:
|
def settings(self) -> OpenAISettings:
|
||||||
"""Get OpenAI chat settings."""
|
"""Get OpenAI chat settings."""
|
||||||
|
@ -158,3 +190,8 @@ def get_chat(
|
||||||
) -> ChatProtocol:
|
) -> ChatProtocol:
|
||||||
"""Get a chat object."""
|
"""Get a chat object."""
|
||||||
return Chat(settings=settings, context=context)
|
return Chat(settings=settings, context=context)
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_class() -> Type[Chat]:
|
||||||
|
"""Get the chat class."""
|
||||||
|
return Chat
|
||||||
|
|
|
@ -6,7 +6,7 @@ from prompt_toolkit import PromptSession
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
from llm_chat.chat import ChatProtocol, get_chat
|
from llm_chat.chat import ChatProtocol, get_chat, get_chat_class
|
||||||
from llm_chat.models import Message, Role
|
from llm_chat.models import Message, Role
|
||||||
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
||||||
|
|
||||||
|
@ -60,7 +60,30 @@ def display_cost(console: Console, chat: ChatProtocol) -> None:
|
||||||
console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n")
|
console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n")
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
def run_conversation(current_chat: ChatProtocol) -> None:
|
||||||
|
"""Run a conversation."""
|
||||||
|
console = get_console()
|
||||||
|
session = get_session()
|
||||||
|
|
||||||
|
finished = False
|
||||||
|
|
||||||
|
console.print(f"[bold green]Model:[/bold green] {current_chat.settings.model}")
|
||||||
|
console.print(
|
||||||
|
f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}"
|
||||||
|
)
|
||||||
|
|
||||||
|
while not finished:
|
||||||
|
prompt = read_user_input(session)
|
||||||
|
if prompt.strip() == "/q":
|
||||||
|
finished = True
|
||||||
|
current_chat.save()
|
||||||
|
else:
|
||||||
|
response = current_chat.send_message(prompt.strip())
|
||||||
|
console.print(Markdown(response))
|
||||||
|
display_cost(console, current_chat)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("chat")
|
||||||
def chat(
|
def chat(
|
||||||
api_key: Annotated[
|
api_key: Annotated[
|
||||||
Optional[str],
|
Optional[str],
|
||||||
|
@ -102,7 +125,6 @@ def chat(
|
||||||
] = [],
|
] = [],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a chat session."""
|
"""Start a chat session."""
|
||||||
# TODO: Add option to load context from file.
|
|
||||||
# TODO: Add option to provide context string as an argument.
|
# TODO: Add option to provide context string as an argument.
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
|
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
|
||||||
|
@ -112,20 +134,44 @@ def chat(
|
||||||
context_messages = [load_context(path) for path in context]
|
context_messages = [load_context(path) for path in context]
|
||||||
|
|
||||||
current_chat = get_chat(settings=settings, context=context_messages)
|
current_chat = get_chat(settings=settings, context=context_messages)
|
||||||
console = get_console()
|
|
||||||
session = get_session()
|
|
||||||
|
|
||||||
finished = False
|
run_conversation(current_chat)
|
||||||
|
|
||||||
console.print(f"[bold green]Model:[/bold green] {settings.model}")
|
|
||||||
console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}")
|
|
||||||
|
|
||||||
while not finished:
|
@app.command("load")
|
||||||
prompt = read_user_input(session)
|
def load(
|
||||||
if prompt.strip() == "/q":
|
path: Annotated[
|
||||||
finished = True
|
Path,
|
||||||
current_chat.save()
|
typer.Argument(
|
||||||
else:
|
help="Path to a conversation file.",
|
||||||
response = current_chat.send_message(prompt.strip())
|
exists=True,
|
||||||
console.print(Markdown(response))
|
file_okay=True,
|
||||||
display_cost(console, current_chat)
|
dir_okay=False,
|
||||||
|
readable=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
api_key: Annotated[
|
||||||
|
Optional[str],
|
||||||
|
typer.Option(
|
||||||
|
...,
|
||||||
|
"--api-key",
|
||||||
|
"-k",
|
||||||
|
help=(
|
||||||
|
"API key. Will read from the environment variable OPENAI_API_KEY "
|
||||||
|
"if not provided."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Load a conversation from a file."""
|
||||||
|
Chat = get_chat_class()
|
||||||
|
if api_key is not None:
|
||||||
|
current_chat = Chat.load(path, api_key=api_key)
|
||||||
|
else:
|
||||||
|
current_chat = Chat.load(path)
|
||||||
|
|
||||||
|
run_conversation(current_chat)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
|
|
|
@ -46,6 +46,42 @@ def test_save_conversation(tmp_path: Path) -> None:
|
||||||
assert conversation == conversation_from_file
|
assert conversation == conversation_from_file
|
||||||
|
|
||||||
|
|
||||||
|
def test_load(tmp_path: Path) -> None:
|
||||||
|
# Create a conversation object to save
|
||||||
|
conversation = Conversation(
|
||||||
|
messages=[
|
||||||
|
Message(role=Role.SYSTEM, content="Hello!"),
|
||||||
|
Message(role=Role.USER, content="Hi!"),
|
||||||
|
Message(role=Role.ASSISTANT, content="How are you?"),
|
||||||
|
],
|
||||||
|
model=Model.GPT3,
|
||||||
|
temperature=0.5,
|
||||||
|
completion_tokens=10,
|
||||||
|
prompt_tokens=15,
|
||||||
|
cost=0.000043,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the conversation to a file
|
||||||
|
file_path = tmp_path / "conversation.json"
|
||||||
|
with file_path.open("w") as f:
|
||||||
|
f.write(conversation.model_dump_json())
|
||||||
|
|
||||||
|
# Load the conversation from the file
|
||||||
|
loaded_chat = Chat.load(file_path, api_key="foo", history_dir=tmp_path)
|
||||||
|
|
||||||
|
# Check that the loaded conversation matches the original conversation
|
||||||
|
assert loaded_chat.settings.model == conversation.model
|
||||||
|
assert loaded_chat.settings.temperature == conversation.temperature
|
||||||
|
assert loaded_chat.conversation.messages == conversation.messages
|
||||||
|
assert loaded_chat.settings.api_key == "foo"
|
||||||
|
assert loaded_chat.settings.history_dir == tmp_path
|
||||||
|
|
||||||
|
# We don't want to load the tokens or cost from the previous session
|
||||||
|
assert loaded_chat.conversation.completion_tokens == 0
|
||||||
|
assert loaded_chat.conversation.prompt_tokens == 0
|
||||||
|
assert loaded_chat.cost == 0
|
||||||
|
|
||||||
|
|
||||||
def test_send_message() -> None:
|
def test_send_message() -> None:
|
||||||
with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
|
with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
|
||||||
mock_make_request.return_value = {
|
mock_make_request.return_value = {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Type
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -11,7 +11,8 @@ from typer.testing import CliRunner
|
||||||
import llm_chat
|
import llm_chat
|
||||||
from llm_chat.chat import ChatProtocol
|
from llm_chat.chat import ChatProtocol
|
||||||
from llm_chat.cli import app
|
from llm_chat.cli import app
|
||||||
from llm_chat.models import Message, Role
|
from llm_chat.models import Conversation, Message, Role
|
||||||
|
from llm_chat.settings import Model, OpenAISettings
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
||||||
|
@ -20,9 +21,15 @@ class ChatFake:
|
||||||
"""Fake chat class for testing."""
|
"""Fake chat class for testing."""
|
||||||
|
|
||||||
args: dict[str, Any]
|
args: dict[str, Any]
|
||||||
|
conversation: Conversation
|
||||||
received_messages: list[str]
|
received_messages: list[str]
|
||||||
|
settings: OpenAISettings
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, settings: OpenAISettings | None = None) -> None:
|
||||||
|
if settings is not None:
|
||||||
|
self.settings = settings
|
||||||
|
else:
|
||||||
|
self.settings = OpenAISettings()
|
||||||
self.args = {}
|
self.args = {}
|
||||||
self.received_messages = []
|
self.received_messages = []
|
||||||
|
|
||||||
|
@ -34,6 +41,13 @@ class ChatFake:
|
||||||
"""Get the cost of the conversation."""
|
"""Get the cost of the conversation."""
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
|
||||||
|
) -> ChatProtocol:
|
||||||
|
"""Load a chat from a file."""
|
||||||
|
return cls()
|
||||||
|
|
||||||
def save(self) -> None:
|
def save(self) -> None:
|
||||||
"""Dummy save method."""
|
"""Dummy save method."""
|
||||||
pass
|
pass
|
||||||
|
@ -61,7 +75,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
|
||||||
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
||||||
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
||||||
|
|
||||||
result = runner.invoke(app)
|
result = runner.invoke(app, ["chat"])
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert chat_fake.received_messages == ["Hello"]
|
assert chat_fake.received_messages == ["Hello"]
|
||||||
|
|
||||||
|
@ -91,8 +105,49 @@ def test_chat_with_context(
|
||||||
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
||||||
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
||||||
|
|
||||||
result = runner.invoke(app, [argument, str(tmp_file)])
|
result = runner.invoke(app, ["chat", argument, str(tmp_file)])
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert chat_fake.received_messages == ["Hello"]
|
assert chat_fake.received_messages == ["Hello"]
|
||||||
assert "context" in chat_fake.args
|
assert "context" in chat_fake.args
|
||||||
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
|
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
|
||||||
|
# Create a conversation object to save
|
||||||
|
conversation = Conversation(
|
||||||
|
messages=[
|
||||||
|
Message(role=Role.SYSTEM, content="Hello!"),
|
||||||
|
Message(role=Role.USER, content="Hi!"),
|
||||||
|
Message(role=Role.ASSISTANT, content="How are you?"),
|
||||||
|
],
|
||||||
|
model=Model.GPT3,
|
||||||
|
temperature=0.5,
|
||||||
|
completion_tokens=10,
|
||||||
|
prompt_tokens=15,
|
||||||
|
cost=0.000043,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the conversation to a file
|
||||||
|
file_path = tmp_path / "conversation.json"
|
||||||
|
with file_path.open("w") as f:
|
||||||
|
f.write(conversation.model_dump_json())
|
||||||
|
|
||||||
|
output = StringIO()
|
||||||
|
console = Console(file=output)
|
||||||
|
|
||||||
|
def mock_get_chat() -> Type[ChatFake]:
|
||||||
|
return ChatFake
|
||||||
|
|
||||||
|
def mock_get_console() -> Console:
|
||||||
|
return console
|
||||||
|
|
||||||
|
mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
|
||||||
|
|
||||||
|
monkeypatch.setattr(llm_chat.cli, "get_chat_class", mock_get_chat)
|
||||||
|
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
||||||
|
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
||||||
|
|
||||||
|
# Load the conversation from the file
|
||||||
|
result = runner.invoke(app, ["load", str(file_path)])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
Loading…
Reference in New Issue