Implement basic command line interface

Implements a basic CLI using Typer, Rich, and Prompt Toolkit. I couldn't
work out how to properly mock/fake the Rich console or the prompt
session from Prompt Toolkit, so I created dummy factory functions to work
around them and test only the basic application flow, rather than actual
user input and output at the terminal.
This commit is contained in:
Paul Harrison 2023-08-19 15:58:54 +01:00
parent 0015ae4bff
commit 7cbde55aac
6 changed files with 297 additions and 5 deletions

139
poetry.lock generated
View File

@ -313,7 +313,7 @@ files = [
name = "click"
version = "8.1.6"
description = "Composable command line interface toolkit"
category = "dev"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -449,6 +449,43 @@ pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"
plugins = ["setuptools"]
requirements-deprecated-finder = ["pip-api", "pipreqs"]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
description = "Python port of markdown-it. Markdown parsing, done right!"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
{file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
{file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
]
[package.dependencies]
mdurl = ">=0.1,<1.0"
[package.extras]
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
code-style = ["pre-commit (>=3.0,<4.0)"]
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
linkify = ["linkify-it-py (>=1,<3)"]
plugins = ["mdit-py-plugins"]
profiling = ["gprof2dot"]
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
[[package]]
name = "mdurl"
version = "0.1.2"
description = "Markdown URL utilities"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
]
[[package]]
name = "multidict"
version = "6.0.4"
@ -665,6 +702,21 @@ files = [
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "prompt-toolkit"
version = "3.0.39"
description = "Library for building powerful interactive command lines in Python"
category = "main"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"},
{file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"},
]
[package.dependencies]
wcwidth = "*"
[[package]]
name = "pydantic"
version = "2.1.1"
@ -833,6 +885,21 @@ snowballstemmer = ">=2.2.0"
[package.extras]
toml = ["tomli (>=1.2.3)"]
[[package]]
name = "pygments"
version = "2.16.1"
description = "Pygments is a syntax highlighting package written in Python."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "Pygments-2.16.1-py3-none-any.whl", hash = "sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692"},
{file = "Pygments-2.16.1.tar.gz", hash = "sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29"},
]
[package.extras]
plugins = ["importlib-metadata"]
[[package]]
name = "pytest"
version = "7.4.0"
@ -891,6 +958,25 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rich"
version = "13.5.2"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
category = "main"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "rich-13.5.2-py3-none-any.whl", hash = "sha256:146a90b3b6b47cac4a73c12866a499e9817426423f57c5a66949c086191a8808"},
{file = "rich-13.5.2.tar.gz", hash = "sha256:fb9d6c0a0f643c99eed3875b5377a184132ba9be4d61516a55273d3554d75a39"},
]
[package.dependencies]
markdown-it-py = ">=2.2.0"
pygments = ">=2.13.0,<3.0.0"
[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]]
name = "ruff"
version = "0.0.284"
@ -918,6 +1004,18 @@ files = [
{file = "ruff-0.0.284.tar.gz", hash = "sha256:ebd3cc55cd499d326aac17a331deaea29bea206e01c08862f9b5c6e93d77a491"},
]
[[package]]
name = "shellingham"
version = "1.5.3"
description = "Tool to Detect Surrounding Shell"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "shellingham-1.5.3-py2.py3-none-any.whl", hash = "sha256:419c6a164770c9c7cfcaeddfacb3d31ac7a8db0b0f3e9c1287679359734107e9"},
{file = "shellingham-1.5.3.tar.gz", hash = "sha256:cb4a6fec583535bc6da17b647dd2330cf7ef30239e05d547d99ae3705fd0f7f8"},
]
[[package]]
name = "snowballstemmer"
version = "2.2.0"
@ -951,6 +1049,31 @@ notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"]
telegram = ["requests"]
[[package]]
name = "typer"
version = "0.9.0"
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
category = "main"
optional = false
python-versions = ">=3.6"
files = [
{file = "typer-0.9.0-py3-none-any.whl", hash = "sha256:5d96d986a21493606a358cae4461bd8cdf83cbf33a5aa950ae629ca3b51467ee"},
{file = "typer-0.9.0.tar.gz", hash = "sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2"},
]
[package.dependencies]
click = ">=7.1.1,<9.0.0"
colorama = {version = ">=0.4.3,<0.5.0", optional = true, markers = "extra == \"all\""}
rich = {version = ">=10.11.0,<14.0.0", optional = true, markers = "extra == \"all\""}
shellingham = {version = ">=1.3.0,<2.0.0", optional = true, markers = "extra == \"all\""}
typing-extensions = ">=3.7.4.3"
[package.extras]
all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"]
doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"]
test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
[[package]]
name = "typing-extensions"
version = "4.7.1"
@ -981,6 +1104,18 @@ secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "wcwidth"
version = "0.2.6"
description = "Measures the displayed width of unicode strings in a terminal"
category = "main"
optional = false
python-versions = "*"
files = [
{file = "wcwidth-0.2.6-py2.py3-none-any.whl", hash = "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e"},
{file = "wcwidth-0.2.6.tar.gz", hash = "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"},
]
[[package]]
name = "yarl"
version = "1.9.2"
@ -1072,4 +1207,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = ">=3.11,<3.12"
content-hash = "544019c132db9541be699bb3aac24f07cd85ee90773412b581f509e273f9df51"
content-hash = "5b81fe2c2ca1756c2c446ab727fd30fd976ed7d8c82ffb8f7d3870c38755bce6"

View File

@ -11,6 +11,8 @@ python = ">=3.11,<3.12"
openai = "^0.27.8"
pydantic = "^2.1.1"
pydantic-settings = "^2.0.2"
typer = {extras = ["all"], version = "^0.9.0"}
prompt-toolkit = "^3.0.39"
[tool.poetry.group.test.dependencies]
pytest = "^7.4.0"
@ -20,6 +22,9 @@ ruff = "^0.0.284"
mypy = "^1.5.0"
pydocstyle = "^6.3.0"
[tool.poetry.scripts]
chat = "llm_chat.cli:app"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Protocol
from openai import ChatCompletion
from openai.openai_object import OpenAIObject
@ -17,6 +17,13 @@ INITIAL_SYSTEM_MESSAGES = [
]
class ChatProtocol(Protocol):
    """Structural interface that any chat backend must satisfy."""

    def send_message(self, message: str) -> str:
        """Send *message* to the assistant and return its reply."""
        ...
class Chat:
"""Interface class for OpenAI's ChatGPT chat API.
@ -68,3 +75,8 @@ class Chat:
message = response["choices"][0]["message"]["content"]
self.conversation.messages.append(Message(role=Role.ASSISTANT, content=message))
return message
def get_chat(settings: OpenAISettings | None = None) -> ChatProtocol:
    """Return a concrete chat implementation for the given settings.

    Indirection point so tests can swap in a fake chat object.
    """
    chat_instance = Chat(settings=settings)
    return chat_instance

90
src/llm_chat/cli.py Normal file
View File

@ -0,0 +1,90 @@
from typing import Annotated, Any, Optional
import typer
from prompt_toolkit import PromptSession
from rich.console import Console
from rich.markdown import Markdown
from llm_chat.chat import get_chat
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
app = typer.Typer()
def prompt_continuation(width: int, *args: Any) -> str:
    """Return the continuation prompt shown on wrapped/multiline input lines.

    Produces a run of dots padded to *width* so continuation lines align
    with the primary ``>>> `` prompt. Extra positional arguments from
    Prompt Toolkit are accepted and ignored.
    """
    dots = "." * (width - 1)
    return f"{dots} "
def get_console() -> Console:
    """Build the Rich console used for all terminal output.

    Kept as a factory so tests can monkeypatch it with a console that
    writes to an in-memory buffer instead of the real terminal.
    """
    console = Console()
    return console
def get_session() -> PromptSession[Any]:
    """Build the Prompt Toolkit session used to read user input.

    Kept as a factory so tests can monkeypatch it and avoid attaching a
    real interactive session to the terminal.
    """
    session: PromptSession[Any] = PromptSession()
    return session
def read_user_input(session: PromptSession[Any]) -> str:
    """Read one (possibly multiline) message from the user.

    Exists as a standalone function so tests can monkeypatch it and feed
    scripted input: injecting a fake PromptSession into ``chat`` directly
    caused the test to hang.
    """
    text: str = session.prompt(
        ">>> ",
        multiline=True,
        prompt_continuation=prompt_continuation,
    )
    return text
@app.command()
def chat(
    api_key: Annotated[
        Optional[str],
        typer.Option(
            ...,
            "--api-key",
            "-k",
            help=(
                "API key. Will read from the environment variable OPENAI_API_KEY "
                "if not provided."
            ),
        ),
    ] = None,
    model: Annotated[
        Model,
        typer.Option(..., "--model", "-m", help="Model to use.", show_choices=True),
    ] = DEFAULT_MODEL,
    temperature: Annotated[
        float,
        typer.Option(
            ..., "--temperature", "-t", help="Model temperature (i.e. creativeness)."
        ),
    ] = DEFAULT_TEMPERATURE,
) -> None:
    """Start a chat session."""
    # TODO: Add option to load context from file.
    # TODO: Add option to provide context string as an argument.

    # Only pass the API key through when given explicitly, so the settings
    # object can fall back to the OPENAI_API_KEY environment variable.
    settings_kwargs: dict[str, Any] = {"model": model, "temperature": temperature}
    if api_key is not None:
        settings_kwargs["api_key"] = api_key
    settings = OpenAISettings(**settings_kwargs)

    current_chat = get_chat(settings)
    console = get_console()
    session = get_session()

    console.print(f"[bold green]Model:[/bold green] {settings.model}")
    console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}")

    # Read-eval loop: render each assistant reply as Markdown until the
    # user enters the quit command.
    while True:
        user_message = read_user_input(session).strip()
        if user_message == "/q":
            break
        reply = current_chat.send_message(user_message)
        console.print(Markdown(reply))

View File

@ -10,12 +10,16 @@ class Model(StrEnum):
GPT4 = "gpt-4"
DEFAULT_MODEL = Model.GPT3
DEFAULT_TEMPERATURE = 0.7
class OpenAISettings(BaseSettings):
"""Settings for the LLM Chat application."""
api_key: str = ""
model: Model = Model.GPT3
temperature: float = 0.7
model: Model = DEFAULT_MODEL
temperature: float = DEFAULT_TEMPERATURE
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
env_file=".env",

46
tests/test_cli.py Normal file
View File

@ -0,0 +1,46 @@
from io import StringIO
from unittest.mock import MagicMock
from pytest import MonkeyPatch
from rich.console import Console
from typer.testing import CliRunner
import llm_chat
from llm_chat.chat import ChatProtocol
from llm_chat.cli import app
from llm_chat.settings import OpenAISettings
runner = CliRunner()
class ChatFake:
    """Fake chat class for testing.

    Records every message it receives and echoes it back, so tests can
    assert on the exact sequence of messages sent to the "assistant".
    """

    def __init__(self) -> None:
        # The original version declared ``received_messages`` as a mutable
        # CLASS attribute, so every ChatFake instance shared one list and
        # state leaked between tests. Initialise it per instance instead.
        self.received_messages: list[str] = []

    def send_message(self, message: str) -> str:
        """Record *message* and echo it back as the assistant's reply."""
        self.received_messages.append(message)
        return message
def test_chat(monkeypatch: MonkeyPatch) -> None:
    """Drive one chat round-trip through the CLI with all I/O faked out."""
    fake_chat = ChatFake()
    buffer = StringIO()
    fake_console = Console(file=buffer)

    # Replace the factory functions so the command uses our fakes, and
    # script the user typing "Hello" followed by the quit command.
    monkeypatch.setattr(llm_chat.cli, "get_chat", lambda _: fake_chat)
    monkeypatch.setattr(llm_chat.cli, "get_console", lambda: fake_console)
    monkeypatch.setattr(
        llm_chat.cli,
        "read_user_input",
        MagicMock(side_effect=["Hello", "/q"]),
    )

    result = runner.invoke(app)

    assert result.exit_code == 0
    assert fake_chat.received_messages == ["Hello"]