Store conversation history on session end

At the end of each sesssion the conversation is stored to a directory
(defaulting to `.history` in the currrent working directory) as a JSON
object. Note, the session must be ended by sending the quit message (/q)
for the conversation to be saved. Ctrl+C will not work.
This commit is contained in:
Paul Harrison 2023-08-24 18:01:34 +01:00
parent 1670383fee
commit 667b9ebfc3
8 changed files with 84 additions and 3 deletions

1
.gitignore vendored
View File

@ -1,5 +1,6 @@
.idea .idea
.vscode .vscode
.history
# Created by https://www.toptal.com/developers/gitignore/api/python,vim,asdf # Created by https://www.toptal.com/developers/gitignore/api/python,vim,asdf
# Edit at https://www.toptal.com/developers/gitignore?templates=python,vim,asdf # Edit at https://www.toptal.com/developers/gitignore?templates=python,vim,asdf

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "llm-chat" name = "llm-chat"
version = "0.3.0" version = "0.4.0"
description = "A general CLI interface for large language models." description = "A general CLI interface for large language models."
authors = ["Paul Harrison <paul@harrison.sh>"] authors = ["Paul Harrison <paul@harrison.sh>"]
readme = "README.md" readme = "README.md"

View File

@ -1,5 +1,8 @@
from datetime import datetime
from enum import StrEnum, auto from enum import StrEnum, auto
from pathlib import Path
from typing import Any, Protocol from typing import Any, Protocol
from zoneinfo import ZoneInfo
from openai import ChatCompletion from openai import ChatCompletion
from openai.openai_object import OpenAIObject from openai.openai_object import OpenAIObject
@ -18,6 +21,18 @@ INITIAL_SYSTEM_MESSAGES = [
] ]
def save_conversation(conversation: Conversation, history_dir: Path, dt: datetime) -> None:
"""Store a conversation in the history directory."""
if conversation.prompt_tokens == 0:
return
history_dir.mkdir(parents=True, exist_ok=True)
path = history_dir / f"{dt.strftime('%Y%m%d%H%M%S')}.json"
with path.open(mode="w") as f:
f.write(conversation.model_dump_json(indent=2))
class Token(StrEnum): class Token(StrEnum):
"""Token type for the OpenAI chat.""" """Token type for the OpenAI chat."""
@ -32,6 +47,9 @@ class ChatProtocol(Protocol):
def cost(self) -> float: def cost(self) -> float:
"""Get the cost of the conversation.""" """Get the cost of the conversation."""
def save(self) -> None:
"""Save the conversation to the history directory."""
def send_message(self, message: str) -> str: def send_message(self, message: str) -> str:
"""Send a message to the assistant.""" """Send a message to the assistant."""
@ -57,12 +75,17 @@ class Chat:
} }
def __init__( def __init__(
self, settings: OpenAISettings | None = None, context: list[Message] = [] self,
settings: OpenAISettings | None = None,
context: list[Message] = [],
store_conversation: bool = True,
) -> None: ) -> None:
self._settings = settings self._settings = settings
self.conversation = Conversation( self.conversation = Conversation(
messages=INITIAL_SYSTEM_MESSAGES + context, messages=INITIAL_SYSTEM_MESSAGES + context,
model=self.settings.model,
) )
self._start_time = datetime.now(tz=ZoneInfo("UTC"))
@property @property
def settings(self) -> OpenAISettings: def settings(self) -> OpenAISettings:
@ -106,6 +129,14 @@ class Chat:
6, 6,
) )
def save(self) -> None:
"""Save the conversation to the history directory."""
save_conversation(
conversation=self.conversation,
history_dir=self.settings.history_dir,
dt=self._start_time,
)
def send_message(self, prompt: str) -> str: def send_message(self, prompt: str) -> str:
"""Send a message to the assistant. """Send a message to the assistant.

View File

@ -124,6 +124,7 @@ def chat(
prompt = read_user_input(session) prompt = read_user_input(session)
if prompt.strip() == "/q": if prompt.strip() == "/q":
finished = True finished = True
current_chat.save()
else: else:
response = current_chat.send_message(prompt.strip()) response = current_chat.send_message(prompt.strip())
console.print(Markdown(response)) console.print(Markdown(response))

View File

@ -2,6 +2,8 @@ from enum import StrEnum, auto
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from llm_chat.settings import Model
class Role(StrEnum): class Role(StrEnum):
"""Role of a user in the chat.""" """Role of a user in the chat."""
@ -27,6 +29,7 @@ class Conversation(BaseModel):
"""Conversation in the chat.""" """Conversation in the chat."""
messages: list[Message] messages: list[Message]
model: Model
completion_tokens: int = 0 completion_tokens: int = 0
prompt_tokens: int = 0 prompt_tokens: int = 0
cost: float = 0.0 cost: float = 0.0

View File

@ -1,4 +1,5 @@
from enum import StrEnum from enum import StrEnum
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
@ -20,6 +21,7 @@ class OpenAISettings(BaseSettings):
api_key: str = "" api_key: str = ""
model: Model = DEFAULT_MODEL model: Model = DEFAULT_MODEL
temperature: float = DEFAULT_TEMPERATURE temperature: float = DEFAULT_TEMPERATURE
history_dir: Path = Path().absolute() / ".history"
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc] model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
env_file=".env", env_file=".env",

View File

@ -1,11 +1,50 @@
from datetime import datetime
from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
from zoneinfo import ZoneInfo
import pytest import pytest
from llm_chat.chat import Chat from llm_chat.chat import Chat, save_conversation
from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings from llm_chat.settings import Model, OpenAISettings
def test_save_conversation(tmp_path: Path) -> None:
conversation = Conversation(
messages=[
Message(role=Role.SYSTEM, content="Hello!"),
Message(role=Role.USER, content="Hi!"),
Message(role=Role.ASSISTANT, content="How are you?"),
],
model=Model.GPT3,
completion_tokens=10,
prompt_tokens=15,
cost=0.000043,
)
path = tmp_path / ".history"
expected_file_path = path / "20210101120000.json"
dt = datetime(2021, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("UTC"))
assert not path.exists()
save_conversation(
conversation=conversation,
history_dir=path,
dt=dt,
)
assert path.exists()
assert path.is_dir()
assert expected_file_path in path.iterdir()
with expected_file_path.open() as f:
conversation_from_file = Conversation.model_validate_json(f.read())
assert conversation == conversation_from_file
def test_send_message() -> None: def test_send_message() -> None:
with patch("llm_chat.chat.Chat._make_request") as mock_make_request: with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
mock_make_request.return_value = { mock_make_request.return_value = {

View File

@ -34,6 +34,10 @@ class ChatFake:
"""Get the cost of the conversation.""" """Get the cost of the conversation."""
return 0.0 return 0.0
def save(self) -> None:
"""Dummy save method."""
pass
def send_message(self, message: str) -> str: def send_message(self, message: str) -> str:
"""Echo the received message.""" """Echo the received message."""
self.received_messages.append(message) self.received_messages.append(message)