Name conversation

Closes #4
This commit is contained in:
Paul Harrison 2023-09-14 17:58:40 +01:00
parent f5c7c40cc8
commit 68fc11c450
8 changed files with 75 additions and 57 deletions

47
poetry.lock generated
View File

@ -1,10 +1,9 @@
# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. # This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]] [[package]]
name = "aiohttp" name = "aiohttp"
version = "3.8.5" version = "3.8.5"
description = "Async http client/server framework (asyncio)" description = "Async http client/server framework (asyncio)"
category = "main"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
files = [ files = [
@ -113,7 +112,6 @@ speedups = ["Brotli", "aiodns", "cchardet"]
name = "aiosignal" name = "aiosignal"
version = "1.3.1" version = "1.3.1"
description = "aiosignal: a list of registered asynchronous callbacks" description = "aiosignal: a list of registered asynchronous callbacks"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -128,7 +126,6 @@ frozenlist = ">=1.1.0"
name = "annotated-types" name = "annotated-types"
version = "0.5.0" version = "0.5.0"
description = "Reusable constraint types to use with typing.Annotated" description = "Reusable constraint types to use with typing.Annotated"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -140,7 +137,6 @@ files = [
name = "async-timeout" name = "async-timeout"
version = "4.0.3" version = "4.0.3"
description = "Timeout context manager for asyncio programs" description = "Timeout context manager for asyncio programs"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -152,7 +148,6 @@ files = [
name = "attrs" name = "attrs"
version = "23.1.0" version = "23.1.0"
description = "Classes Without Boilerplate" description = "Classes Without Boilerplate"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -171,7 +166,6 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte
name = "black" name = "black"
version = "23.7.0" version = "23.7.0"
description = "The uncompromising code formatter." description = "The uncompromising code formatter."
category = "dev"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
@ -216,7 +210,6 @@ uvloop = ["uvloop (>=0.15.2)"]
name = "certifi" name = "certifi"
version = "2023.7.22" version = "2023.7.22"
description = "Python package for providing Mozilla's CA Bundle." description = "Python package for providing Mozilla's CA Bundle."
category = "main"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
files = [ files = [
@ -228,7 +221,6 @@ files = [
name = "charset-normalizer" name = "charset-normalizer"
version = "3.2.0" version = "3.2.0"
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
category = "main"
optional = false optional = false
python-versions = ">=3.7.0" python-versions = ">=3.7.0"
files = [ files = [
@ -313,7 +305,6 @@ files = [
name = "click" name = "click"
version = "8.1.6" version = "8.1.6"
description = "Composable command line interface toolkit" description = "Composable command line interface toolkit"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -328,7 +319,6 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""}
name = "colorama" name = "colorama"
version = "0.4.6" version = "0.4.6"
description = "Cross-platform colored terminal text." description = "Cross-platform colored terminal text."
category = "main"
optional = false optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
files = [ files = [
@ -340,7 +330,6 @@ files = [
name = "frozenlist" name = "frozenlist"
version = "1.4.0" version = "1.4.0"
description = "A list-like structure which implements collections.abc.MutableSequence" description = "A list-like structure which implements collections.abc.MutableSequence"
category = "main"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
@ -411,7 +400,6 @@ files = [
name = "idna" name = "idna"
version = "3.4" version = "3.4"
description = "Internationalized Domain Names in Applications (IDNA)" description = "Internationalized Domain Names in Applications (IDNA)"
category = "main"
optional = false optional = false
python-versions = ">=3.5" python-versions = ">=3.5"
files = [ files = [
@ -423,7 +411,6 @@ files = [
name = "iniconfig" name = "iniconfig"
version = "2.0.0" version = "2.0.0"
description = "brain-dead simple config-ini parsing" description = "brain-dead simple config-ini parsing"
category = "dev"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -435,7 +422,6 @@ files = [
name = "isort" name = "isort"
version = "5.12.0" version = "5.12.0"
description = "A Python utility / library to sort Python imports." description = "A Python utility / library to sort Python imports."
category = "dev"
optional = false optional = false
python-versions = ">=3.8.0" python-versions = ">=3.8.0"
files = [ files = [
@ -453,7 +439,6 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"]
name = "markdown-it-py" name = "markdown-it-py"
version = "3.0.0" version = "3.0.0"
description = "Python port of markdown-it. Markdown parsing, done right!" description = "Python port of markdown-it. Markdown parsing, done right!"
category = "main"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
@ -478,7 +463,6 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
name = "mdurl" name = "mdurl"
version = "0.1.2" version = "0.1.2"
description = "Markdown URL utilities" description = "Markdown URL utilities"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -490,7 +474,6 @@ files = [
name = "multidict" name = "multidict"
version = "6.0.4" version = "6.0.4"
description = "multidict implementation" description = "multidict implementation"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -574,7 +557,6 @@ files = [
name = "mypy" name = "mypy"
version = "1.5.0" version = "1.5.0"
description = "Optional static typing for Python" description = "Optional static typing for Python"
category = "dev"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
@ -615,7 +597,6 @@ reports = ["lxml"]
name = "mypy-extensions" name = "mypy-extensions"
version = "1.0.0" version = "1.0.0"
description = "Type system extensions for programs checked with the mypy type checker." description = "Type system extensions for programs checked with the mypy type checker."
category = "dev"
optional = false optional = false
python-versions = ">=3.5" python-versions = ">=3.5"
files = [ files = [
@ -627,7 +608,6 @@ files = [
name = "openai" name = "openai"
version = "0.27.8" version = "0.27.8"
description = "Python client library for the OpenAI API" description = "Python client library for the OpenAI API"
category = "main"
optional = false optional = false
python-versions = ">=3.7.1" python-versions = ">=3.7.1"
files = [ files = [
@ -642,7 +622,7 @@ tqdm = "*"
[package.extras] [package.extras]
datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
dev = ["black (>=21.6b0,<22.0)", "pytest (>=6.0.0,<7.0.0)", "pytest-asyncio", "pytest-mock"] dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-mock"]
embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"] embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"]
wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"] wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"]
@ -650,7 +630,6 @@ wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1
name = "packaging" name = "packaging"
version = "23.1" version = "23.1"
description = "Core utilities for Python packages" description = "Core utilities for Python packages"
category = "dev"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -662,7 +641,6 @@ files = [
name = "pathspec" name = "pathspec"
version = "0.11.2" version = "0.11.2"
description = "Utility library for gitignore style pattern matching of file paths." description = "Utility library for gitignore style pattern matching of file paths."
category = "dev"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -674,7 +652,6 @@ files = [
name = "platformdirs" name = "platformdirs"
version = "3.10.0" version = "3.10.0"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
category = "dev"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -690,7 +667,6 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-co
name = "pluggy" name = "pluggy"
version = "1.2.0" version = "1.2.0"
description = "plugin and hook calling mechanisms for python" description = "plugin and hook calling mechanisms for python"
category = "dev"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -706,7 +682,6 @@ testing = ["pytest", "pytest-benchmark"]
name = "prompt-toolkit" name = "prompt-toolkit"
version = "3.0.39" version = "3.0.39"
description = "Library for building powerful interactive command lines in Python" description = "Library for building powerful interactive command lines in Python"
category = "main"
optional = false optional = false
python-versions = ">=3.7.0" python-versions = ">=3.7.0"
files = [ files = [
@ -721,7 +696,6 @@ wcwidth = "*"
name = "pydantic" name = "pydantic"
version = "2.1.1" version = "2.1.1"
description = "Data validation using Python type hints" description = "Data validation using Python type hints"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -741,7 +715,6 @@ email = ["email-validator (>=2.0.0)"]
name = "pydantic-core" name = "pydantic-core"
version = "2.4.0" version = "2.4.0"
description = "" description = ""
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -855,7 +828,6 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
name = "pydantic-settings" name = "pydantic-settings"
version = "2.0.2" version = "2.0.2"
description = "Settings management using Pydantic" description = "Settings management using Pydantic"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -871,7 +843,6 @@ python-dotenv = ">=0.21.0"
name = "pydocstyle" name = "pydocstyle"
version = "6.3.0" version = "6.3.0"
description = "Python docstring style checker" description = "Python docstring style checker"
category = "dev"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
files = [ files = [
@ -889,7 +860,6 @@ toml = ["tomli (>=1.2.3)"]
name = "pygments" name = "pygments"
version = "2.16.1" version = "2.16.1"
description = "Pygments is a syntax highlighting package written in Python." description = "Pygments is a syntax highlighting package written in Python."
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -904,7 +874,6 @@ plugins = ["importlib-metadata"]
name = "pytest" name = "pytest"
version = "7.4.0" version = "7.4.0"
description = "pytest: simple powerful testing with Python" description = "pytest: simple powerful testing with Python"
category = "dev"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -925,7 +894,6 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no
name = "python-dotenv" name = "python-dotenv"
version = "1.0.0" version = "1.0.0"
description = "Read key-value pairs from a .env file and set them as environment variables" description = "Read key-value pairs from a .env file and set them as environment variables"
category = "main"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
@ -940,7 +908,6 @@ cli = ["click (>=5.0)"]
name = "requests" name = "requests"
version = "2.31.0" version = "2.31.0"
description = "Python HTTP for Humans." description = "Python HTTP for Humans."
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -962,7 +929,6 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
name = "rich" name = "rich"
version = "13.5.2" version = "13.5.2"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
category = "main"
optional = false optional = false
python-versions = ">=3.7.0" python-versions = ">=3.7.0"
files = [ files = [
@ -981,7 +947,6 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"]
name = "ruff" name = "ruff"
version = "0.0.284" version = "0.0.284"
description = "An extremely fast Python linter, written in Rust." description = "An extremely fast Python linter, written in Rust."
category = "dev"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -1008,7 +973,6 @@ files = [
name = "shellingham" name = "shellingham"
version = "1.5.3" version = "1.5.3"
description = "Tool to Detect Surrounding Shell" description = "Tool to Detect Surrounding Shell"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -1020,7 +984,6 @@ files = [
name = "snowballstemmer" name = "snowballstemmer"
version = "2.2.0" version = "2.2.0"
description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms."
category = "dev"
optional = false optional = false
python-versions = "*" python-versions = "*"
files = [ files = [
@ -1032,7 +995,6 @@ files = [
name = "tqdm" name = "tqdm"
version = "4.66.1" version = "4.66.1"
description = "Fast, Extensible Progress Meter" description = "Fast, Extensible Progress Meter"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -1053,7 +1015,6 @@ telegram = ["requests"]
name = "typer" name = "typer"
version = "0.9.0" version = "0.9.0"
description = "Typer, build great CLIs. Easy to code. Based on Python type hints." description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
category = "main"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
files = [ files = [
@ -1078,7 +1039,6 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.
name = "typing-extensions" name = "typing-extensions"
version = "4.7.1" version = "4.7.1"
description = "Backported and Experimental Type Hints for Python 3.7+" description = "Backported and Experimental Type Hints for Python 3.7+"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -1090,7 +1050,6 @@ files = [
name = "urllib3" name = "urllib3"
version = "2.0.4" version = "2.0.4"
description = "HTTP library with thread-safe connection pooling, file post, and more." description = "HTTP library with thread-safe connection pooling, file post, and more."
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -1108,7 +1067,6 @@ zstd = ["zstandard (>=0.18.0)"]
name = "wcwidth" name = "wcwidth"
version = "0.2.6" version = "0.2.6"
description = "Measures the displayed width of unicode strings in a terminal" description = "Measures the displayed width of unicode strings in a terminal"
category = "main"
optional = false optional = false
python-versions = "*" python-versions = "*"
files = [ files = [
@ -1120,7 +1078,6 @@ files = [
name = "yarl" name = "yarl"
version = "1.9.2" version = "1.9.2"
description = "Yet another URL library" description = "Yet another URL library"
category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "llm-chat" name = "llm-chat"
version = "0.5.0" version = "0.6.0"
description = "A general CLI interface for large language models." description = "A general CLI interface for large language models."
authors = ["Paul Harrison <paul@harrison.sh>"] authors = ["Paul Harrison <paul@harrison.sh>"]
readme = "README.md" readme = "README.md"

View File

@ -28,9 +28,11 @@ def save_conversation(
if conversation.prompt_tokens == 0: if conversation.prompt_tokens == 0:
return return
filename = f"{dt.strftime('%Y%m%d%H%M%S')}{'-' + conversation.name if conversation.name else ''}"
history_dir.mkdir(parents=True, exist_ok=True) history_dir.mkdir(parents=True, exist_ok=True)
path = history_dir / f"{dt.strftime('%Y%m%d%H%M%S')}.json" path = history_dir / f"{filename}.json"
with path.open(mode="w") as f: with path.open(mode="w") as f:
f.write(conversation.model_dump_json(indent=2)) f.write(conversation.model_dump_json(indent=2))
@ -86,6 +88,7 @@ class Chat:
self, self,
settings: OpenAISettings | None = None, settings: OpenAISettings | None = None,
context: list[Message] = [], context: list[Message] = [],
name: str = "",
initial_system_messages: bool = True, initial_system_messages: bool = True,
) -> None: ) -> None:
self._settings = settings self._settings = settings
@ -95,6 +98,7 @@ class Chat:
else context, else context,
model=self.settings.model, model=self.settings.model,
temperature=self.settings.temperature, temperature=self.settings.temperature,
name=name,
) )
self._start_time = datetime.now(tz=ZoneInfo("UTC")) self._start_time = datetime.now(tz=ZoneInfo("UTC"))
@ -186,10 +190,10 @@ class Chat:
def get_chat( def get_chat(
settings: OpenAISettings | None = None, context: list[Message] = [] settings: OpenAISettings | None = None, context: list[Message] = [], name: str = ""
) -> ChatProtocol: ) -> ChatProtocol:
"""Get a chat object.""" """Get a chat object."""
return Chat(settings=settings, context=context) return Chat(settings=settings, context=context, name=name)
def get_chat_class() -> Type[Chat]: def get_chat_class() -> Type[Chat]:

View File

@ -8,7 +8,7 @@ from rich.markdown import Markdown
from llm_chat.chat import ChatProtocol, get_chat, get_chat_class from llm_chat.chat import ChatProtocol, get_chat, get_chat_class
from llm_chat.models import Message, Role from llm_chat.models import Message, Role
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings from llm_chat.settings import DEFAULT_HISTORY_DIR, DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
app = typer.Typer() app = typer.Typer()
@ -71,6 +71,8 @@ def run_conversation(current_chat: ChatProtocol) -> None:
console.print( console.print(
f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}" f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}"
) )
if current_chat.conversation.name:
console.print(f"[bold green]Name:[/bold green] {current_chat.conversation.name}")
while not finished: while not finished:
prompt = read_user_input(session) prompt = read_user_input(session)
@ -123,17 +125,38 @@ def chat(
readable=True, readable=True,
), ),
] = [], ] = [],
history_dir: Annotated[
Path,
typer.Option(
...,
"--history-dir",
"-d",
help="Path to the directory where conversation history will be saved.",
exists=True,
dir_okay=True,
file_okay=False,
),
] = DEFAULT_HISTORY_DIR,
name: Annotated[
str,
typer.Option(
...,
"--name",
"-n",
help="Name of the chat.",
),
] = "",
) -> None: ) -> None:
"""Start a chat session.""" """Start a chat session."""
# TODO: Add option to provide context string as an argument. # TODO: Add option to provide context string as an argument.
if api_key is not None: if api_key is not None:
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature) settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature, history_dir=history_dir)
else: else:
settings = OpenAISettings(model=model, temperature=temperature) settings = OpenAISettings(model=model, temperature=temperature, history_dir=history_dir)
context_messages = [load_context(path) for path in context] context_messages = [load_context(path) for path in context]
current_chat = get_chat(settings=settings, context=context_messages) current_chat = get_chat(settings=settings, context=context_messages, name=name)
run_conversation(current_chat) run_conversation(current_chat)

View File

@ -31,6 +31,7 @@ class Conversation(BaseModel):
messages: list[Message] messages: list[Message]
model: Model model: Model
temperature: float = DEFAULT_TEMPERATURE temperature: float = DEFAULT_TEMPERATURE
name: str = ""
completion_tokens: int = 0 completion_tokens: int = 0
prompt_tokens: int = 0 prompt_tokens: int = 0
cost: float = 0.0 cost: float = 0.0

View File

@ -13,6 +13,7 @@ class Model(StrEnum):
DEFAULT_MODEL = Model.GPT3 DEFAULT_MODEL = Model.GPT3
DEFAULT_TEMPERATURE = 0.7 DEFAULT_TEMPERATURE = 0.7
DEFAULT_HISTORY_DIR = Path().absolute() / ".history"
class OpenAISettings(BaseSettings): class OpenAISettings(BaseSettings):
@ -21,7 +22,7 @@ class OpenAISettings(BaseSettings):
api_key: str = "" api_key: str = ""
model: Model = DEFAULT_MODEL model: Model = DEFAULT_MODEL
temperature: float = DEFAULT_TEMPERATURE temperature: float = DEFAULT_TEMPERATURE
history_dir: Path = Path().absolute() / ".history" history_dir: Path = DEFAULT_HISTORY_DIR
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc] model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
env_file=".env", env_file=".env",

View File

@ -10,7 +10,8 @@ from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings from llm_chat.settings import Model, OpenAISettings
def test_save_conversation(tmp_path: Path) -> None: @pytest.mark.parametrize("name,expected_filename", [("", "20210101120000.json"), ("foo", "20210101120000-foo.json")])
def test_save_conversation(name: str, expected_filename: str, tmp_path: Path) -> None:
conversation = Conversation( conversation = Conversation(
messages=[ messages=[
Message(role=Role.SYSTEM, content="Hello!"), Message(role=Role.SYSTEM, content="Hello!"),
@ -22,10 +23,11 @@ def test_save_conversation(tmp_path: Path) -> None:
completion_tokens=10, completion_tokens=10,
prompt_tokens=15, prompt_tokens=15,
cost=0.000043, cost=0.000043,
name=name,
) )
path = tmp_path / ".history" path = tmp_path / ".history"
expected_file_path = path / "20210101120000.json" expected_file_path = path / expected_filename
dt = datetime(2021, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("UTC")) dt = datetime(2021, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("UTC"))
assert not path.exists() assert not path.exists()

View File

@ -1,15 +1,19 @@
from datetime import datetime
from io import StringIO from io import StringIO
from itertools import product
from pathlib import Path from pathlib import Path
from typing import Any, Type from typing import Any, Type
from unittest.mock import MagicMock from unittest.mock import MagicMock
from zoneinfo import ZoneInfo
import pytest import pytest
import time_machine
from pytest import MonkeyPatch from pytest import MonkeyPatch
from rich.console import Console from rich.console import Console
from typer.testing import CliRunner from typer.testing import CliRunner
import llm_chat import llm_chat
from llm_chat.chat import ChatProtocol from llm_chat.chat import ChatProtocol, save_conversation
from llm_chat.cli import app from llm_chat.cli import app
from llm_chat.models import Conversation, Message, Role from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings from llm_chat.settings import Model, OpenAISettings
@ -112,6 +116,32 @@ def test_chat_with_context(
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)] assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
@pytest.mark.parametrize("argument,name", list(product(("--name", "-n"), ("", "foo"))), ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))])
def test_chat_with_name(
argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path
) -> None:
chat_fake = ChatFake()
output = StringIO()
console = Console(file=output)
def mock_get_chat(**kwargs: Any) -> ChatProtocol:
chat_fake._set_args(**kwargs)
return chat_fake
def mock_get_console() -> Console:
return console
mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
result = runner.invoke(app, ["chat", argument, name])
assert result.exit_code == 0
assert chat_fake.args["name"] == name
def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None: def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
# Create a conversation object to save # Create a conversation object to save
conversation = Conversation( conversation = Conversation(