Name conversation

This commit is contained in:
Paul Harrison 2023-09-14 17:58:40 +01:00
parent f5c7c40cc8
commit 316953d6fd
8 changed files with 75 additions and 57 deletions

47
poetry.lock generated
View File

@ -1,10 +1,9 @@
# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "aiohttp"
version = "3.8.5"
description = "Async http client/server framework (asyncio)"
category = "main"
optional = false
python-versions = ">=3.6"
files = [
@ -113,7 +112,6 @@ speedups = ["Brotli", "aiodns", "cchardet"]
name = "aiosignal"
version = "1.3.1"
description = "aiosignal: a list of registered asynchronous callbacks"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -128,7 +126,6 @@ frozenlist = ">=1.1.0"
name = "annotated-types"
version = "0.5.0"
description = "Reusable constraint types to use with typing.Annotated"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -140,7 +137,6 @@ files = [
name = "async-timeout"
version = "4.0.3"
description = "Timeout context manager for asyncio programs"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -152,7 +148,6 @@ files = [
name = "attrs"
version = "23.1.0"
description = "Classes Without Boilerplate"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -171,7 +166,6 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte
name = "black"
version = "23.7.0"
description = "The uncompromising code formatter."
category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@ -216,7 +210,6 @@ uvloop = ["uvloop (>=0.15.2)"]
name = "certifi"
version = "2023.7.22"
description = "Python package for providing Mozilla's CA Bundle."
category = "main"
optional = false
python-versions = ">=3.6"
files = [
@ -228,7 +221,6 @@ files = [
name = "charset-normalizer"
version = "3.2.0"
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
category = "main"
optional = false
python-versions = ">=3.7.0"
files = [
@ -313,7 +305,6 @@ files = [
name = "click"
version = "8.1.6"
description = "Composable command line interface toolkit"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -328,7 +319,6 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""}
name = "colorama"
version = "0.4.6"
description = "Cross-platform colored terminal text."
category = "main"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
files = [
@ -340,7 +330,6 @@ files = [
name = "frozenlist"
version = "1.4.0"
description = "A list-like structure which implements collections.abc.MutableSequence"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@ -411,7 +400,6 @@ files = [
name = "idna"
version = "3.4"
description = "Internationalized Domain Names in Applications (IDNA)"
category = "main"
optional = false
python-versions = ">=3.5"
files = [
@ -423,7 +411,6 @@ files = [
name = "iniconfig"
version = "2.0.0"
description = "brain-dead simple config-ini parsing"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@ -435,7 +422,6 @@ files = [
name = "isort"
version = "5.12.0"
description = "A Python utility / library to sort Python imports."
category = "dev"
optional = false
python-versions = ">=3.8.0"
files = [
@ -453,7 +439,6 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"]
name = "markdown-it-py"
version = "3.0.0"
description = "Python port of markdown-it. Markdown parsing, done right!"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@ -478,7 +463,6 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
name = "mdurl"
version = "0.1.2"
description = "Markdown URL utilities"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -490,7 +474,6 @@ files = [
name = "multidict"
version = "6.0.4"
description = "multidict implementation"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -574,7 +557,6 @@ files = [
name = "mypy"
version = "1.5.0"
description = "Optional static typing for Python"
category = "dev"
optional = false
python-versions = ">=3.8"
files = [
@ -615,7 +597,6 @@ reports = ["lxml"]
name = "mypy-extensions"
version = "1.0.0"
description = "Type system extensions for programs checked with the mypy type checker."
category = "dev"
optional = false
python-versions = ">=3.5"
files = [
@ -627,7 +608,6 @@ files = [
name = "openai"
version = "0.27.8"
description = "Python client library for the OpenAI API"
category = "main"
optional = false
python-versions = ">=3.7.1"
files = [
@ -642,7 +622,7 @@ tqdm = "*"
[package.extras]
datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
dev = ["black (>=21.6b0,<22.0)", "pytest (>=6.0.0,<7.0.0)", "pytest-asyncio", "pytest-mock"]
dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-mock"]
embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"]
wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"]
@ -650,7 +630,6 @@ wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1
name = "packaging"
version = "23.1"
description = "Core utilities for Python packages"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@ -662,7 +641,6 @@ files = [
name = "pathspec"
version = "0.11.2"
description = "Utility library for gitignore style pattern matching of file paths."
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@ -674,7 +652,6 @@ files = [
name = "platformdirs"
version = "3.10.0"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@ -690,7 +667,6 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-co
name = "pluggy"
version = "1.2.0"
description = "plugin and hook calling mechanisms for python"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@ -706,7 +682,6 @@ testing = ["pytest", "pytest-benchmark"]
name = "prompt-toolkit"
version = "3.0.39"
description = "Library for building powerful interactive command lines in Python"
category = "main"
optional = false
python-versions = ">=3.7.0"
files = [
@ -721,7 +696,6 @@ wcwidth = "*"
name = "pydantic"
version = "2.1.1"
description = "Data validation using Python type hints"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -741,7 +715,6 @@ email = ["email-validator (>=2.0.0)"]
name = "pydantic-core"
version = "2.4.0"
description = ""
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -855,7 +828,6 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
name = "pydantic-settings"
version = "2.0.2"
description = "Settings management using Pydantic"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -871,7 +843,6 @@ python-dotenv = ">=0.21.0"
name = "pydocstyle"
version = "6.3.0"
description = "Python docstring style checker"
category = "dev"
optional = false
python-versions = ">=3.6"
files = [
@ -889,7 +860,6 @@ toml = ["tomli (>=1.2.3)"]
name = "pygments"
version = "2.16.1"
description = "Pygments is a syntax highlighting package written in Python."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -904,7 +874,6 @@ plugins = ["importlib-metadata"]
name = "pytest"
version = "7.4.0"
description = "pytest: simple powerful testing with Python"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@ -925,7 +894,6 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no
name = "python-dotenv"
version = "1.0.0"
description = "Read key-value pairs from a .env file and set them as environment variables"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@ -940,7 +908,6 @@ cli = ["click (>=5.0)"]
name = "requests"
version = "2.31.0"
description = "Python HTTP for Humans."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -962,7 +929,6 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
name = "rich"
version = "13.5.2"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
category = "main"
optional = false
python-versions = ">=3.7.0"
files = [
@ -981,7 +947,6 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"]
name = "ruff"
version = "0.0.284"
description = "An extremely fast Python linter, written in Rust."
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
@ -1008,7 +973,6 @@ files = [
name = "shellingham"
version = "1.5.3"
description = "Tool to Detect Surrounding Shell"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -1020,7 +984,6 @@ files = [
name = "snowballstemmer"
version = "2.2.0"
description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms."
category = "dev"
optional = false
python-versions = "*"
files = [
@ -1032,7 +995,6 @@ files = [
name = "tqdm"
version = "4.66.1"
description = "Fast, Extensible Progress Meter"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -1053,7 +1015,6 @@ telegram = ["requests"]
name = "typer"
version = "0.9.0"
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
category = "main"
optional = false
python-versions = ">=3.6"
files = [
@ -1078,7 +1039,6 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.
name = "typing-extensions"
version = "4.7.1"
description = "Backported and Experimental Type Hints for Python 3.7+"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -1090,7 +1050,6 @@ files = [
name = "urllib3"
version = "2.0.4"
description = "HTTP library with thread-safe connection pooling, file post, and more."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -1108,7 +1067,6 @@ zstd = ["zstandard (>=0.18.0)"]
name = "wcwidth"
version = "0.2.6"
description = "Measures the displayed width of unicode strings in a terminal"
category = "main"
optional = false
python-versions = "*"
files = [
@ -1120,7 +1078,6 @@ files = [
name = "yarl"
version = "1.9.2"
description = "Yet another URL library"
category = "main"
optional = false
python-versions = ">=3.7"
files = [

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "llm-chat"
version = "0.5.0"
version = "0.6.0"
description = "A general CLI interface for large language models."
authors = ["Paul Harrison <paul@harrison.sh>"]
readme = "README.md"

View File

@ -28,9 +28,11 @@ def save_conversation(
if conversation.prompt_tokens == 0:
return
filename = f"{dt.strftime('%Y%m%d%H%M%S')}{'-' + conversation.name if conversation.name else ''}"
history_dir.mkdir(parents=True, exist_ok=True)
path = history_dir / f"{dt.strftime('%Y%m%d%H%M%S')}.json"
path = history_dir / f"{filename}.json"
with path.open(mode="w") as f:
f.write(conversation.model_dump_json(indent=2))
@ -86,6 +88,7 @@ class Chat:
self,
settings: OpenAISettings | None = None,
context: list[Message] = [],
name: str = "",
initial_system_messages: bool = True,
) -> None:
self._settings = settings
@ -95,6 +98,7 @@ class Chat:
else context,
model=self.settings.model,
temperature=self.settings.temperature,
name=name,
)
self._start_time = datetime.now(tz=ZoneInfo("UTC"))
@ -186,10 +190,10 @@ class Chat:
def get_chat(
settings: OpenAISettings | None = None, context: list[Message] = []
settings: OpenAISettings | None = None, context: list[Message] = [], name: str = ""
) -> ChatProtocol:
"""Get a chat object."""
return Chat(settings=settings, context=context)
return Chat(settings=settings, context=context, name=name)
def get_chat_class() -> Type[Chat]:

View File

@ -8,7 +8,7 @@ from rich.markdown import Markdown
from llm_chat.chat import ChatProtocol, get_chat, get_chat_class
from llm_chat.models import Message, Role
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
from llm_chat.settings import DEFAULT_HISTORY_DIR, DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
app = typer.Typer()
@ -71,6 +71,8 @@ def run_conversation(current_chat: ChatProtocol) -> None:
console.print(
f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}"
)
if current_chat.conversation.name:
console.print(f"[bold green]Name:[/bold green] {current_chat.conversation.name}")
while not finished:
prompt = read_user_input(session)
@ -123,17 +125,38 @@ def chat(
readable=True,
),
] = [],
history_dir: Annotated[
Path,
typer.Option(
...,
"--history-dir",
"-d",
help="Path to the directory where conversation history will be saved.",
exists=True,
dir_okay=True,
file_okay=False,
),
] = DEFAULT_HISTORY_DIR,
name: Annotated[
str,
typer.Option(
...,
"--name",
"-n",
help="Name of the chat.",
),
] = "",
) -> None:
"""Start a chat session."""
# TODO: Add option to provide context string as an argument.
if api_key is not None:
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature, history_dir=history_dir)
else:
settings = OpenAISettings(model=model, temperature=temperature)
settings = OpenAISettings(model=model, temperature=temperature, history_dir=history_dir)
context_messages = [load_context(path) for path in context]
current_chat = get_chat(settings=settings, context=context_messages)
current_chat = get_chat(settings=settings, context=context_messages, name=name)
run_conversation(current_chat)

View File

@ -31,6 +31,7 @@ class Conversation(BaseModel):
messages: list[Message]
model: Model
temperature: float = DEFAULT_TEMPERATURE
name: str = ""
completion_tokens: int = 0
prompt_tokens: int = 0
cost: float = 0.0

View File

@ -13,6 +13,7 @@ class Model(StrEnum):
DEFAULT_MODEL = Model.GPT3
DEFAULT_TEMPERATURE = 0.7
DEFAULT_HISTORY_DIR = Path().absolute() / ".history"
class OpenAISettings(BaseSettings):
@ -21,7 +22,7 @@ class OpenAISettings(BaseSettings):
api_key: str = ""
model: Model = DEFAULT_MODEL
temperature: float = DEFAULT_TEMPERATURE
history_dir: Path = Path().absolute() / ".history"
history_dir: Path = DEFAULT_HISTORY_DIR
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
env_file=".env",

View File

@ -10,7 +10,8 @@ from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings
def test_save_conversation(tmp_path: Path) -> None:
@pytest.mark.parametrize("name,expected_filename", [("", "20210101120000.json"), ("foo", "20210101120000-foo.json")])
def test_save_conversation(name: str, expected_filename: str, tmp_path: Path) -> None:
conversation = Conversation(
messages=[
Message(role=Role.SYSTEM, content="Hello!"),
@ -22,10 +23,11 @@ def test_save_conversation(tmp_path: Path) -> None:
completion_tokens=10,
prompt_tokens=15,
cost=0.000043,
name=name,
)
path = tmp_path / ".history"
expected_file_path = path / "20210101120000.json"
expected_file_path = path / expected_filename
dt = datetime(2021, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("UTC"))
assert not path.exists()

View File

@ -1,15 +1,19 @@
from datetime import datetime
from io import StringIO
from itertools import product
from pathlib import Path
from typing import Any, Type
from unittest.mock import MagicMock
from zoneinfo import ZoneInfo
import pytest
import time_machine
from pytest import MonkeyPatch
from rich.console import Console
from typer.testing import CliRunner
import llm_chat
from llm_chat.chat import ChatProtocol
from llm_chat.chat import ChatProtocol, save_conversation
from llm_chat.cli import app
from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings
@ -112,6 +116,32 @@ def test_chat_with_context(
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
@pytest.mark.parametrize("argument,name", list(product(("--name", "-n"), ("", "foo"))), ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))])
def test_chat_with_name(
argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
chat_fake = ChatFake()
output = StringIO()
console = Console(file=output)
def mock_get_chat(**kwargs: Any) -> ChatProtocol:
chat_fake._set_args(**kwargs)
return chat_fake
def mock_get_console() -> Console:
return console
mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
result = runner.invoke(app, ["chat", argument, name])
assert result.exit_code == 0
assert chat_fake.args["name"] == name
def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
# Create a conversation object to save
conversation = Conversation(