diff --git a/.gitignore b/.gitignore index b9865d2..5b9b8c1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .idea .vscode +.history # Created by https://www.toptal.com/developers/gitignore/api/python,vim,asdf # Edit at https://www.toptal.com/developers/gitignore?templates=python,vim,asdf diff --git a/pyproject.toml b/pyproject.toml index 35dd17d..61598ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "llm-chat" -version = "0.3.0" +version = "0.4.0" description = "A general CLI interface for large language models." authors = ["Paul Harrison "] readme = "README.md" diff --git a/src/llm_chat/chat.py b/src/llm_chat/chat.py index 328d3b6..0d0d989 100644 --- a/src/llm_chat/chat.py +++ b/src/llm_chat/chat.py @@ -1,5 +1,8 @@ +from datetime import datetime from enum import StrEnum, auto +from pathlib import Path from typing import Any, Protocol +from zoneinfo import ZoneInfo from openai import ChatCompletion from openai.openai_object import OpenAIObject @@ -18,6 +21,20 @@ INITIAL_SYSTEM_MESSAGES = [ ] +def save_conversation( + conversation: Conversation, history_dir: Path, dt: datetime +) -> None: + """Store a conversation in the history directory.""" + if conversation.prompt_tokens == 0: + return + + history_dir.mkdir(parents=True, exist_ok=True) + + path = history_dir / f"{dt.strftime('%Y%m%d%H%M%S')}.json" + with path.open(mode="w") as f: + f.write(conversation.model_dump_json(indent=2)) + + class Token(StrEnum): """Token type for the OpenAI chat.""" @@ -32,6 +49,9 @@ class ChatProtocol(Protocol): def cost(self) -> float: """Get the cost of the conversation.""" + def save(self) -> None: + """Save the conversation to the history directory.""" + def send_message(self, message: str) -> str: """Send a message to the assistant.""" @@ -57,12 +77,17 @@ class Chat: } def __init__( - self, settings: OpenAISettings | None = None, context: list[Message] = [] + self, + settings: OpenAISettings | None = None, + context: list[Message] = [], ) -> None: self._settings = settings self.conversation = Conversation( messages=INITIAL_SYSTEM_MESSAGES + context, + model=self.settings.model, + temperature=self.settings.temperature, ) + self._start_time = datetime.now(tz=ZoneInfo("UTC")) @property def settings(self) -> OpenAISettings: @@ -106,6 +131,14 @@ class Chat: 6, ) + def save(self) -> None: + """Save the conversation to the history directory.""" + save_conversation( + conversation=self.conversation, + history_dir=self.settings.history_dir, + dt=self._start_time, + ) + def send_message(self, prompt: str) -> str: """Send a message to the assistant. diff --git a/src/llm_chat/cli.py b/src/llm_chat/cli.py index 5d5c1b5..ae3bf5c 100644 --- a/src/llm_chat/cli.py +++ b/src/llm_chat/cli.py @@ -124,6 +124,7 @@ def chat( prompt = read_user_input(session) if prompt.strip() == "/q": finished = True + current_chat.save() else: response = current_chat.send_message(prompt.strip()) console.print(Markdown(response)) diff --git a/src/llm_chat/models.py b/src/llm_chat/models.py index 4b5a74c..94e6f57 100644 --- a/src/llm_chat/models.py +++ b/src/llm_chat/models.py @@ -2,6 +2,8 @@ from enum import StrEnum, auto from pydantic import BaseModel, ConfigDict +from llm_chat.settings import DEFAULT_TEMPERATURE, Model + class Role(StrEnum): """Role of a user in the chat.""" @@ -27,6 +29,8 @@ class Conversation(BaseModel): """Conversation in the chat.""" messages: list[Message] + model: Model + temperature: float = DEFAULT_TEMPERATURE completion_tokens: int = 0 prompt_tokens: int = 0 cost: float = 0.0 diff --git a/src/llm_chat/settings.py b/src/llm_chat/settings.py index f266183..1e024c9 100644 --- a/src/llm_chat/settings.py +++ b/src/llm_chat/settings.py @@ -1,4 +1,5 @@ from enum import StrEnum +from pathlib import Path from pydantic_settings import BaseSettings, SettingsConfigDict @@ -20,6 +21,7 @@ class OpenAISettings(BaseSettings): api_key: str = "" model: Model = DEFAULT_MODEL temperature: float = DEFAULT_TEMPERATURE + history_dir: Path = Path().absolute() / ".history" model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc] env_file=".env", diff --git a/tests/test_chat.py b/tests/test_chat.py index a5b6edf..ed07726 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -1,11 +1,51 @@ +from datetime import datetime +from pathlib import Path from unittest.mock import patch +from zoneinfo import ZoneInfo import pytest -from llm_chat.chat import Chat +from llm_chat.chat import Chat, save_conversation +from llm_chat.models import Conversation, Message, Role from llm_chat.settings import Model, OpenAISettings +def test_save_conversation(tmp_path: Path) -> None: + conversation = Conversation( + messages=[ + Message(role=Role.SYSTEM, content="Hello!"), + Message(role=Role.USER, content="Hi!"), + Message(role=Role.ASSISTANT, content="How are you?"), + ], + model=Model.GPT3, + temperature=0.5, + completion_tokens=10, + prompt_tokens=15, + cost=0.000043, + ) + + path = tmp_path / ".history" + expected_file_path = path / "20210101120000.json" + dt = datetime(2021, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("UTC")) + + assert not path.exists() + + save_conversation( + conversation=conversation, + history_dir=path, + dt=dt, + ) + + assert path.exists() + assert path.is_dir() + assert expected_file_path in path.iterdir() + + with expected_file_path.open() as f: + conversation_from_file = Conversation.model_validate_json(f.read()) + + assert conversation == conversation_from_file + + def test_send_message() -> None: with patch("llm_chat.chat.Chat._make_request") as mock_make_request: mock_make_request.return_value = { diff --git a/tests/test_cli.py b/tests/test_cli.py index 675178d..d57098c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -34,6 +34,10 @@ class ChatFake: """Get the cost of the conversation.""" return 0.0 + def save(self) -> None: + """Dummy save method.""" + pass + def send_message(self, message: str) -> str: """Echo the received message.""" self.received_messages.append(message)