Name conversation #7
|
@ -1,10 +1,9 @@
|
|||
# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohttp"
|
||||
version = "3.8.5"
|
||||
description = "Async http client/server framework (asyncio)"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
|
@ -113,7 +112,6 @@ speedups = ["Brotli", "aiodns", "cchardet"]
|
|||
name = "aiosignal"
|
||||
version = "1.3.1"
|
||||
description = "aiosignal: a list of registered asynchronous callbacks"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -128,7 +126,6 @@ frozenlist = ">=1.1.0"
|
|||
name = "annotated-types"
|
||||
version = "0.5.0"
|
||||
description = "Reusable constraint types to use with typing.Annotated"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -140,7 +137,6 @@ files = [
|
|||
name = "async-timeout"
|
||||
version = "4.0.3"
|
||||
description = "Timeout context manager for asyncio programs"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -152,7 +148,6 @@ files = [
|
|||
name = "attrs"
|
||||
version = "23.1.0"
|
||||
description = "Classes Without Boilerplate"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -171,7 +166,6 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte
|
|||
name = "black"
|
||||
version = "23.7.0"
|
||||
description = "The uncompromising code formatter."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
|
@ -216,7 +210,6 @@ uvloop = ["uvloop (>=0.15.2)"]
|
|||
name = "certifi"
|
||||
version = "2023.7.22"
|
||||
description = "Python package for providing Mozilla's CA Bundle."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
|
@ -228,7 +221,6 @@ files = [
|
|||
name = "charset-normalizer"
|
||||
version = "3.2.0"
|
||||
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
|
@ -313,7 +305,6 @@ files = [
|
|||
name = "click"
|
||||
version = "8.1.6"
|
||||
description = "Composable command line interface toolkit"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -328,7 +319,6 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
|||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
description = "Cross-platform colored terminal text."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
files = [
|
||||
|
@ -340,7 +330,6 @@ files = [
|
|||
name = "frozenlist"
|
||||
version = "1.4.0"
|
||||
description = "A list-like structure which implements collections.abc.MutableSequence"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
|
@ -411,7 +400,6 @@ files = [
|
|||
name = "idna"
|
||||
version = "3.4"
|
||||
description = "Internationalized Domain Names in Applications (IDNA)"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
|
@ -423,7 +411,6 @@ files = [
|
|||
name = "iniconfig"
|
||||
version = "2.0.0"
|
||||
description = "brain-dead simple config-ini parsing"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -435,7 +422,6 @@ files = [
|
|||
name = "isort"
|
||||
version = "5.12.0"
|
||||
description = "A Python utility / library to sort Python imports."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
|
@ -453,7 +439,6 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"]
|
|||
name = "markdown-it-py"
|
||||
version = "3.0.0"
|
||||
description = "Python port of markdown-it. Markdown parsing, done right!"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
|
@ -478,7 +463,6 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
|
|||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
description = "Markdown URL utilities"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -490,7 +474,6 @@ files = [
|
|||
name = "multidict"
|
||||
version = "6.0.4"
|
||||
description = "multidict implementation"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -574,7 +557,6 @@ files = [
|
|||
name = "mypy"
|
||||
version = "1.5.0"
|
||||
description = "Optional static typing for Python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
|
@ -615,7 +597,6 @@ reports = ["lxml"]
|
|||
name = "mypy-extensions"
|
||||
version = "1.0.0"
|
||||
description = "Type system extensions for programs checked with the mypy type checker."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
|
@ -627,7 +608,6 @@ files = [
|
|||
name = "openai"
|
||||
version = "0.27.8"
|
||||
description = "Python client library for the OpenAI API"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7.1"
|
||||
files = [
|
||||
|
@ -642,7 +622,7 @@ tqdm = "*"
|
|||
|
||||
[package.extras]
|
||||
datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
||||
dev = ["black (>=21.6b0,<22.0)", "pytest (>=6.0.0,<7.0.0)", "pytest-asyncio", "pytest-mock"]
|
||||
dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-mock"]
|
||||
embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"]
|
||||
wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"]
|
||||
|
||||
|
@ -650,7 +630,6 @@ wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1
|
|||
name = "packaging"
|
||||
version = "23.1"
|
||||
description = "Core utilities for Python packages"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -662,7 +641,6 @@ files = [
|
|||
name = "pathspec"
|
||||
version = "0.11.2"
|
||||
description = "Utility library for gitignore style pattern matching of file paths."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -674,7 +652,6 @@ files = [
|
|||
name = "platformdirs"
|
||||
version = "3.10.0"
|
||||
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -690,7 +667,6 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-co
|
|||
name = "pluggy"
|
||||
version = "1.2.0"
|
||||
description = "plugin and hook calling mechanisms for python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -706,7 +682,6 @@ testing = ["pytest", "pytest-benchmark"]
|
|||
name = "prompt-toolkit"
|
||||
version = "3.0.39"
|
||||
description = "Library for building powerful interactive command lines in Python"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
|
@ -721,7 +696,6 @@ wcwidth = "*"
|
|||
name = "pydantic"
|
||||
version = "2.1.1"
|
||||
description = "Data validation using Python type hints"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -741,7 +715,6 @@ email = ["email-validator (>=2.0.0)"]
|
|||
name = "pydantic-core"
|
||||
version = "2.4.0"
|
||||
description = ""
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -855,7 +828,6 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
|
|||
name = "pydantic-settings"
|
||||
version = "2.0.2"
|
||||
description = "Settings management using Pydantic"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -871,7 +843,6 @@ python-dotenv = ">=0.21.0"
|
|||
name = "pydocstyle"
|
||||
version = "6.3.0"
|
||||
description = "Python docstring style checker"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
|
@ -889,7 +860,6 @@ toml = ["tomli (>=1.2.3)"]
|
|||
name = "pygments"
|
||||
version = "2.16.1"
|
||||
description = "Pygments is a syntax highlighting package written in Python."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -904,7 +874,6 @@ plugins = ["importlib-metadata"]
|
|||
name = "pytest"
|
||||
version = "7.4.0"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -925,7 +894,6 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no
|
|||
name = "python-dotenv"
|
||||
version = "1.0.0"
|
||||
description = "Read key-value pairs from a .env file and set them as environment variables"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
|
@ -940,7 +908,6 @@ cli = ["click (>=5.0)"]
|
|||
name = "requests"
|
||||
version = "2.31.0"
|
||||
description = "Python HTTP for Humans."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -962,7 +929,6 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
|||
name = "rich"
|
||||
version = "13.5.2"
|
||||
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
|
@ -981,7 +947,6 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
|||
name = "ruff"
|
||||
version = "0.0.284"
|
||||
description = "An extremely fast Python linter, written in Rust."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -1008,7 +973,6 @@ files = [
|
|||
name = "shellingham"
|
||||
version = "1.5.3"
|
||||
description = "Tool to Detect Surrounding Shell"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -1020,7 +984,6 @@ files = [
|
|||
name = "snowballstemmer"
|
||||
version = "2.2.0"
|
||||
description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
|
@ -1032,7 +995,6 @@ files = [
|
|||
name = "tqdm"
|
||||
version = "4.66.1"
|
||||
description = "Fast, Extensible Progress Meter"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -1053,7 +1015,6 @@ telegram = ["requests"]
|
|||
name = "typer"
|
||||
version = "0.9.0"
|
||||
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
|
@ -1078,7 +1039,6 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.
|
|||
name = "typing-extensions"
|
||||
version = "4.7.1"
|
||||
description = "Backported and Experimental Type Hints for Python 3.7+"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -1090,7 +1050,6 @@ files = [
|
|||
name = "urllib3"
|
||||
version = "2.0.4"
|
||||
description = "HTTP library with thread-safe connection pooling, file post, and more."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
@ -1108,7 +1067,6 @@ zstd = ["zstandard (>=0.18.0)"]
|
|||
name = "wcwidth"
|
||||
version = "0.2.6"
|
||||
description = "Measures the displayed width of unicode strings in a terminal"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
|
@ -1120,7 +1078,6 @@ files = [
|
|||
name = "yarl"
|
||||
version = "1.9.2"
|
||||
description = "Yet another URL library"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "llm-chat"
|
||||
version = "0.5.0"
|
||||
version = "0.6.0"
|
||||
description = "A general CLI interface for large language models."
|
||||
authors = ["Paul Harrison <paul@harrison.sh>"]
|
||||
readme = "README.md"
|
||||
|
|
|
@ -27,10 +27,12 @@ def save_conversation(
|
|||
"""Store a conversation in the history directory."""
|
||||
if conversation.prompt_tokens == 0:
|
||||
return
|
||||
|
||||
filename = f"{dt.strftime('%Y%m%d%H%M%S')}{'-' + conversation.name if conversation.name else ''}"
|
||||
|
||||
history_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
path = history_dir / f"{dt.strftime('%Y%m%d%H%M%S')}.json"
|
||||
path = history_dir / f"{filename}.json"
|
||||
with path.open(mode="w") as f:
|
||||
f.write(conversation.model_dump_json(indent=2))
|
||||
|
||||
|
@ -86,6 +88,7 @@ class Chat:
|
|||
self,
|
||||
settings: OpenAISettings | None = None,
|
||||
context: list[Message] = [],
|
||||
name: str = "",
|
||||
initial_system_messages: bool = True,
|
||||
) -> None:
|
||||
self._settings = settings
|
||||
|
@ -95,6 +98,7 @@ class Chat:
|
|||
else context,
|
||||
model=self.settings.model,
|
||||
temperature=self.settings.temperature,
|
||||
name=name,
|
||||
)
|
||||
self._start_time = datetime.now(tz=ZoneInfo("UTC"))
|
||||
|
||||
|
@ -186,10 +190,10 @@ class Chat:
|
|||
|
||||
|
||||
def get_chat(
|
||||
settings: OpenAISettings | None = None, context: list[Message] = []
|
||||
settings: OpenAISettings | None = None, context: list[Message] = [], name: str = ""
|
||||
) -> ChatProtocol:
|
||||
"""Get a chat object."""
|
||||
return Chat(settings=settings, context=context)
|
||||
return Chat(settings=settings, context=context, name=name)
|
||||
|
||||
|
||||
def get_chat_class() -> Type[Chat]:
|
||||
|
|
|
@ -8,7 +8,7 @@ from rich.markdown import Markdown
|
|||
|
||||
from llm_chat.chat import ChatProtocol, get_chat, get_chat_class
|
||||
from llm_chat.models import Message, Role
|
||||
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
||||
from llm_chat.settings import DEFAULT_HISTORY_DIR, DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
@ -71,6 +71,8 @@ def run_conversation(current_chat: ChatProtocol) -> None:
|
|||
console.print(
|
||||
f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}"
|
||||
)
|
||||
if current_chat.conversation.name:
|
||||
console.print(f"[bold green]Name:[/bold green] {current_chat.conversation.name}")
|
||||
|
||||
while not finished:
|
||||
prompt = read_user_input(session)
|
||||
|
@ -123,17 +125,38 @@ def chat(
|
|||
readable=True,
|
||||
),
|
||||
] = [],
|
||||
history_dir: Annotated[
|
||||
Path,
|
||||
typer.Option(
|
||||
...,
|
||||
"--history-dir",
|
||||
"-d",
|
||||
help="Path to the directory where conversation history will be saved.",
|
||||
exists=True,
|
||||
dir_okay=True,
|
||||
file_okay=False,
|
||||
),
|
||||
] = DEFAULT_HISTORY_DIR,
|
||||
name: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
...,
|
||||
"--name",
|
||||
"-n",
|
||||
help="Name of the chat.",
|
||||
),
|
||||
] = "",
|
||||
) -> None:
|
||||
"""Start a chat session."""
|
||||
# TODO: Add option to provide context string as an argument.
|
||||
if api_key is not None:
|
||||
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
|
||||
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature, history_dir=history_dir)
|
||||
else:
|
||||
settings = OpenAISettings(model=model, temperature=temperature)
|
||||
settings = OpenAISettings(model=model, temperature=temperature, history_dir=history_dir)
|
||||
|
||||
context_messages = [load_context(path) for path in context]
|
||||
|
||||
current_chat = get_chat(settings=settings, context=context_messages)
|
||||
current_chat = get_chat(settings=settings, context=context_messages, name=name)
|
||||
|
||||
run_conversation(current_chat)
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ class Conversation(BaseModel):
|
|||
messages: list[Message]
|
||||
model: Model
|
||||
temperature: float = DEFAULT_TEMPERATURE
|
||||
name: str = ""
|
||||
completion_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
cost: float = 0.0
|
||||
|
|
|
@ -13,6 +13,7 @@ class Model(StrEnum):
|
|||
|
||||
DEFAULT_MODEL = Model.GPT3
|
||||
DEFAULT_TEMPERATURE = 0.7
|
||||
DEFAULT_HISTORY_DIR = Path().absolute() / ".history"
|
||||
|
||||
|
||||
class OpenAISettings(BaseSettings):
|
||||
|
@ -21,7 +22,7 @@ class OpenAISettings(BaseSettings):
|
|||
api_key: str = ""
|
||||
model: Model = DEFAULT_MODEL
|
||||
temperature: float = DEFAULT_TEMPERATURE
|
||||
history_dir: Path = Path().absolute() / ".history"
|
||||
history_dir: Path = DEFAULT_HISTORY_DIR
|
||||
|
||||
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
|
||||
env_file=".env",
|
||||
|
|
|
@ -10,7 +10,8 @@ from llm_chat.models import Conversation, Message, Role
|
|||
from llm_chat.settings import Model, OpenAISettings
|
||||
|
||||
|
||||
def test_save_conversation(tmp_path: Path) -> None:
|
||||
@pytest.mark.parametrize("name,expected_filename", [("", "20210101120000.json"), ("foo", "20210101120000-foo.json")])
|
||||
def test_save_conversation(name: str, expected_filename: str, tmp_path: Path) -> None:
|
||||
conversation = Conversation(
|
||||
messages=[
|
||||
Message(role=Role.SYSTEM, content="Hello!"),
|
||||
|
@ -22,10 +23,11 @@ def test_save_conversation(tmp_path: Path) -> None:
|
|||
completion_tokens=10,
|
||||
prompt_tokens=15,
|
||||
cost=0.000043,
|
||||
name=name,
|
||||
)
|
||||
|
||||
path = tmp_path / ".history"
|
||||
expected_file_path = path / "20210101120000.json"
|
||||
expected_file_path = path / expected_filename
|
||||
dt = datetime(2021, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("UTC"))
|
||||
|
||||
assert not path.exists()
|
||||
|
|
|
@ -1,15 +1,19 @@
|
|||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
from typing import Any, Type
|
||||
from unittest.mock import MagicMock
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
import time_machine
|
||||
from pytest import MonkeyPatch
|
||||
from rich.console import Console
|
||||
from typer.testing import CliRunner
|
||||
|
||||
import llm_chat
|
||||
from llm_chat.chat import ChatProtocol
|
||||
from llm_chat.chat import ChatProtocol, save_conversation
|
||||
from llm_chat.cli import app
|
||||
from llm_chat.models import Conversation, Message, Role
|
||||
from llm_chat.settings import Model, OpenAISettings
|
||||
|
@ -112,6 +116,32 @@ def test_chat_with_context(
|
|||
assert chat_fake.args["context"] == [Message(role=Role.SYSTEM, content=context)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("argument,name", list(product(("--name", "-n"), ("", "foo"))), ids=[f"{arg} {name}" for arg, name in product(("--name", "-n"), ("", "foo"))])
|
||||
def test_chat_with_name(
|
||||
argument: str, name: str, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> None:
|
||||
chat_fake = ChatFake()
|
||||
output = StringIO()
|
||||
console = Console(file=output)
|
||||
|
||||
def mock_get_chat(**kwargs: Any) -> ChatProtocol:
|
||||
chat_fake._set_args(**kwargs)
|
||||
return chat_fake
|
||||
|
||||
def mock_get_console() -> Console:
|
||||
return console
|
||||
|
||||
mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
|
||||
|
||||
monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
|
||||
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
|
||||
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
|
||||
|
||||
result = runner.invoke(app, ["chat", argument, name])
|
||||
assert result.exit_code == 0
|
||||
assert chat_fake.args["name"] == name
|
||||
|
||||
|
||||
def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
|
||||
# Create a conversation object to save
|
||||
conversation = Conversation(
|
||||
|
|
Loading…
Reference in New Issue