Store conversation history on session end
At the end of each sesssion the conversation is stored to a directory (defaulting to `.history` in the currrent working directory) as a JSON object. Note, the session must be ended by sending the quit message (/q) for the conversation to be saved. Ctrl+C will not work.
This commit is contained in:
parent
1670383fee
commit
667b9ebfc3
|
@ -1,5 +1,6 @@
|
|||
.idea
|
||||
.vscode
|
||||
.history
|
||||
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/python,vim,asdf
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=python,vim,asdf
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "llm-chat"
|
||||
version = "0.3.0"
|
||||
version = "0.4.0"
|
||||
description = "A general CLI interface for large language models."
|
||||
authors = ["Paul Harrison <paul@harrison.sh>"]
|
||||
readme = "README.md"
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
from datetime import datetime
|
||||
from enum import StrEnum, auto
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from openai import ChatCompletion
|
||||
from openai.openai_object import OpenAIObject
|
||||
|
@ -18,6 +21,18 @@ INITIAL_SYSTEM_MESSAGES = [
|
|||
]
|
||||
|
||||
|
||||
def save_conversation(conversation: Conversation, history_dir: Path, dt: datetime) -> None:
|
||||
"""Store a conversation in the history directory."""
|
||||
if conversation.prompt_tokens == 0:
|
||||
return
|
||||
|
||||
history_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
path = history_dir / f"{dt.strftime('%Y%m%d%H%M%S')}.json"
|
||||
with path.open(mode="w") as f:
|
||||
f.write(conversation.model_dump_json(indent=2))
|
||||
|
||||
|
||||
class Token(StrEnum):
|
||||
"""Token type for the OpenAI chat."""
|
||||
|
||||
|
@ -32,6 +47,9 @@ class ChatProtocol(Protocol):
|
|||
def cost(self) -> float:
|
||||
"""Get the cost of the conversation."""
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save the conversation to the history directory."""
|
||||
|
||||
def send_message(self, message: str) -> str:
|
||||
"""Send a message to the assistant."""
|
||||
|
||||
|
@ -57,12 +75,17 @@ class Chat:
|
|||
}
|
||||
|
||||
def __init__(
|
||||
self, settings: OpenAISettings | None = None, context: list[Message] = []
|
||||
self,
|
||||
settings: OpenAISettings | None = None,
|
||||
context: list[Message] = [],
|
||||
store_conversation: bool = True,
|
||||
) -> None:
|
||||
self._settings = settings
|
||||
self.conversation = Conversation(
|
||||
messages=INITIAL_SYSTEM_MESSAGES + context,
|
||||
model=self.settings.model,
|
||||
)
|
||||
self._start_time = datetime.now(tz=ZoneInfo("UTC"))
|
||||
|
||||
@property
|
||||
def settings(self) -> OpenAISettings:
|
||||
|
@ -106,6 +129,14 @@ class Chat:
|
|||
6,
|
||||
)
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save the conversation to the history directory."""
|
||||
save_conversation(
|
||||
conversation=self.conversation,
|
||||
history_dir=self.settings.history_dir,
|
||||
dt=self._start_time,
|
||||
)
|
||||
|
||||
def send_message(self, prompt: str) -> str:
|
||||
"""Send a message to the assistant.
|
||||
|
||||
|
|
|
@ -124,6 +124,7 @@ def chat(
|
|||
prompt = read_user_input(session)
|
||||
if prompt.strip() == "/q":
|
||||
finished = True
|
||||
current_chat.save()
|
||||
else:
|
||||
response = current_chat.send_message(prompt.strip())
|
||||
console.print(Markdown(response))
|
||||
|
|
|
@ -2,6 +2,8 @@ from enum import StrEnum, auto
|
|||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from llm_chat.settings import Model
|
||||
|
||||
|
||||
class Role(StrEnum):
|
||||
"""Role of a user in the chat."""
|
||||
|
@ -27,6 +29,7 @@ class Conversation(BaseModel):
|
|||
"""Conversation in the chat."""
|
||||
|
||||
messages: list[Message]
|
||||
model: Model
|
||||
completion_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
cost: float = 0.0
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
@ -20,6 +21,7 @@ class OpenAISettings(BaseSettings):
|
|||
api_key: str = ""
|
||||
model: Model = DEFAULT_MODEL
|
||||
temperature: float = DEFAULT_TEMPERATURE
|
||||
history_dir: Path = Path().absolute() / ".history"
|
||||
|
||||
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
|
||||
env_file=".env",
|
||||
|
|
|
@ -1,11 +1,50 @@
|
|||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
|
||||
from llm_chat.chat import Chat
|
||||
from llm_chat.chat import Chat, save_conversation
|
||||
from llm_chat.models import Conversation, Message, Role
|
||||
from llm_chat.settings import Model, OpenAISettings
|
||||
|
||||
|
||||
def test_save_conversation(tmp_path: Path) -> None:
|
||||
conversation = Conversation(
|
||||
messages=[
|
||||
Message(role=Role.SYSTEM, content="Hello!"),
|
||||
Message(role=Role.USER, content="Hi!"),
|
||||
Message(role=Role.ASSISTANT, content="How are you?"),
|
||||
],
|
||||
model=Model.GPT3,
|
||||
completion_tokens=10,
|
||||
prompt_tokens=15,
|
||||
cost=0.000043,
|
||||
)
|
||||
|
||||
path = tmp_path / ".history"
|
||||
expected_file_path = path / "20210101120000.json"
|
||||
dt = datetime(2021, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("UTC"))
|
||||
|
||||
assert not path.exists()
|
||||
|
||||
save_conversation(
|
||||
conversation=conversation,
|
||||
history_dir=path,
|
||||
dt=dt,
|
||||
)
|
||||
|
||||
assert path.exists()
|
||||
assert path.is_dir()
|
||||
assert expected_file_path in path.iterdir()
|
||||
|
||||
with expected_file_path.open() as f:
|
||||
conversation_from_file = Conversation.model_validate_json(f.read())
|
||||
|
||||
assert conversation == conversation_from_file
|
||||
|
||||
|
||||
def test_send_message() -> None:
|
||||
with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
|
||||
mock_make_request.return_value = {
|
||||
|
|
|
@ -34,6 +34,10 @@ class ChatFake:
|
|||
"""Get the cost of the conversation."""
|
||||
return 0.0
|
||||
|
||||
def save(self) -> None:
|
||||
"""Dummy save method."""
|
||||
pass
|
||||
|
||||
def send_message(self, message: str) -> str:
|
||||
"""Echo the received message."""
|
||||
self.received_messages.append(message)
|
||||
|
|
Loading…
Reference in New Issue