Implement command line interface

This commit is contained in:
Paul Harrison 2023-08-19 15:58:54 +01:00
parent 0015ae4bff
commit f8aa9d4676
4 changed files with 221 additions and 4 deletions

139
poetry.lock generated
View File

@ -313,7 +313,7 @@ files = [
name = "click" name = "click"
version = "8.1.6" version = "8.1.6"
description = "Composable command line interface toolkit" description = "Composable command line interface toolkit"
category = "dev" category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
@ -449,6 +449,43 @@ pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"
plugins = ["setuptools"] plugins = ["setuptools"]
requirements-deprecated-finder = ["pip-api", "pipreqs"] requirements-deprecated-finder = ["pip-api", "pipreqs"]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
description = "Python port of markdown-it. Markdown parsing, done right!"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
{file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
{file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
]
[package.dependencies]
mdurl = ">=0.1,<1.0"
[package.extras]
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
code-style = ["pre-commit (>=3.0,<4.0)"]
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
linkify = ["linkify-it-py (>=1,<3)"]
plugins = ["mdit-py-plugins"]
profiling = ["gprof2dot"]
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
[[package]]
name = "mdurl"
version = "0.1.2"
description = "Markdown URL utilities"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
]
[[package]] [[package]]
name = "multidict" name = "multidict"
version = "6.0.4" version = "6.0.4"
@ -665,6 +702,21 @@ files = [
dev = ["pre-commit", "tox"] dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"] testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "prompt-toolkit"
version = "3.0.39"
description = "Library for building powerful interactive command lines in Python"
category = "main"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"},
{file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"},
]
[package.dependencies]
wcwidth = "*"
[[package]] [[package]]
name = "pydantic" name = "pydantic"
version = "2.1.1" version = "2.1.1"
@ -833,6 +885,21 @@ snowballstemmer = ">=2.2.0"
[package.extras] [package.extras]
toml = ["tomli (>=1.2.3)"] toml = ["tomli (>=1.2.3)"]
[[package]]
name = "pygments"
version = "2.16.1"
description = "Pygments is a syntax highlighting package written in Python."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "Pygments-2.16.1-py3-none-any.whl", hash = "sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692"},
{file = "Pygments-2.16.1.tar.gz", hash = "sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29"},
]
[package.extras]
plugins = ["importlib-metadata"]
[[package]] [[package]]
name = "pytest" name = "pytest"
version = "7.4.0" version = "7.4.0"
@ -891,6 +958,25 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"] socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rich"
version = "13.5.2"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
category = "main"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "rich-13.5.2-py3-none-any.whl", hash = "sha256:146a90b3b6b47cac4a73c12866a499e9817426423f57c5a66949c086191a8808"},
{file = "rich-13.5.2.tar.gz", hash = "sha256:fb9d6c0a0f643c99eed3875b5377a184132ba9be4d61516a55273d3554d75a39"},
]
[package.dependencies]
markdown-it-py = ">=2.2.0"
pygments = ">=2.13.0,<3.0.0"
[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]] [[package]]
name = "ruff" name = "ruff"
version = "0.0.284" version = "0.0.284"
@ -918,6 +1004,18 @@ files = [
{file = "ruff-0.0.284.tar.gz", hash = "sha256:ebd3cc55cd499d326aac17a331deaea29bea206e01c08862f9b5c6e93d77a491"}, {file = "ruff-0.0.284.tar.gz", hash = "sha256:ebd3cc55cd499d326aac17a331deaea29bea206e01c08862f9b5c6e93d77a491"},
] ]
[[package]]
name = "shellingham"
version = "1.5.3"
description = "Tool to Detect Surrounding Shell"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "shellingham-1.5.3-py2.py3-none-any.whl", hash = "sha256:419c6a164770c9c7cfcaeddfacb3d31ac7a8db0b0f3e9c1287679359734107e9"},
{file = "shellingham-1.5.3.tar.gz", hash = "sha256:cb4a6fec583535bc6da17b647dd2330cf7ef30239e05d547d99ae3705fd0f7f8"},
]
[[package]] [[package]]
name = "snowballstemmer" name = "snowballstemmer"
version = "2.2.0" version = "2.2.0"
@ -951,6 +1049,31 @@ notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"] slack = ["slack-sdk"]
telegram = ["requests"] telegram = ["requests"]
[[package]]
name = "typer"
version = "0.9.0"
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
category = "main"
optional = false
python-versions = ">=3.6"
files = [
{file = "typer-0.9.0-py3-none-any.whl", hash = "sha256:5d96d986a21493606a358cae4461bd8cdf83cbf33a5aa950ae629ca3b51467ee"},
{file = "typer-0.9.0.tar.gz", hash = "sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2"},
]
[package.dependencies]
click = ">=7.1.1,<9.0.0"
colorama = {version = ">=0.4.3,<0.5.0", optional = true, markers = "extra == \"all\""}
rich = {version = ">=10.11.0,<14.0.0", optional = true, markers = "extra == \"all\""}
shellingham = {version = ">=1.3.0,<2.0.0", optional = true, markers = "extra == \"all\""}
typing-extensions = ">=3.7.4.3"
[package.extras]
all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"]
doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"]
test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
[[package]] [[package]]
name = "typing-extensions" name = "typing-extensions"
version = "4.7.1" version = "4.7.1"
@ -981,6 +1104,18 @@ secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"] zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "wcwidth"
version = "0.2.6"
description = "Measures the displayed width of unicode strings in a terminal"
category = "main"
optional = false
python-versions = "*"
files = [
{file = "wcwidth-0.2.6-py2.py3-none-any.whl", hash = "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e"},
{file = "wcwidth-0.2.6.tar.gz", hash = "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"},
]
[[package]] [[package]]
name = "yarl" name = "yarl"
version = "1.9.2" version = "1.9.2"
@ -1072,4 +1207,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.11,<3.12" python-versions = ">=3.11,<3.12"
content-hash = "544019c132db9541be699bb3aac24f07cd85ee90773412b581f509e273f9df51" content-hash = "5b81fe2c2ca1756c2c446ab727fd30fd976ed7d8c82ffb8f7d3870c38755bce6"

View File

@ -11,6 +11,8 @@ python = ">=3.11,<3.12"
openai = "^0.27.8" openai = "^0.27.8"
pydantic = "^2.1.1" pydantic = "^2.1.1"
pydantic-settings = "^2.0.2" pydantic-settings = "^2.0.2"
typer = {extras = ["all"], version = "^0.9.0"}
prompt-toolkit = "^3.0.39"
[tool.poetry.group.test.dependencies] [tool.poetry.group.test.dependencies]
pytest = "^7.4.0" pytest = "^7.4.0"
@ -20,6 +22,9 @@ ruff = "^0.0.284"
mypy = "^1.5.0" mypy = "^1.5.0"
pydocstyle = "^6.3.0" pydocstyle = "^6.3.0"
[tool.poetry.scripts]
chat = "llm_chat.cli:app"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"

73
src/llm_chat/cli.py Normal file
View File

@ -0,0 +1,73 @@
from typing import Annotated, Any, Optional
import typer
from prompt_toolkit import PromptSession
from rich.console import Console
from rich.markdown import Markdown
from llm_chat.chat import Chat
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
app = typer.Typer()
def prompt_continuation(width: int, *args: Any) -> str:
"""Prompt continuation for multiline input."""
return "." * (width - 1) + " "
@app.command()
def main(
api_key: Annotated[
Optional[str],
typer.Option(
...,
"--api-key",
"-k",
help=(
"API key. Will read from the environment variable OPENAI_API_KEY "
"if not provided."
),
),
] = None,
model: Annotated[
Model,
typer.Option(..., "--model", "-m", help="Model to use.", show_choices=True),
] = DEFAULT_MODEL,
temperature: Annotated[
float,
typer.Option(
..., "--temperature", "-t", help="Model temperature (i.e. creativeness)."
),
] = DEFAULT_TEMPERATURE,
) -> None:
"""Start a chat session."""
# TODO: Add option to load context from file.
# TODO: Add option to provide context string as an argument.
# TODO: Function to create settings and chat object to allow fakes in tests.
if api_key is not None:
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
else:
settings = OpenAISettings(model=model, temperature=temperature)
chat = Chat(settings=settings)
console = Console()
# TODO: Can we properly type hint this class?
session = PromptSession() # type: ignore[var-annotated]
finished = False
console.print(f"[bold green]Model:[/bold green] {settings.model}")
console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}")
while not finished:
# TODO: Can we style the prompt initial message and continuation?
prompt = session.prompt(
">>> ", multiline=True, prompt_continuation=prompt_continuation
)
if prompt.strip() == "/q":
finished = True
else:
response = chat.send_message(prompt.strip())
console.print(Markdown(response))

View File

@ -10,12 +10,16 @@ class Model(StrEnum):
GPT4 = "gpt-4" GPT4 = "gpt-4"
DEFAULT_MODEL = Model.GPT3
DEFAULT_TEMPERATURE = 0.7
class OpenAISettings(BaseSettings): class OpenAISettings(BaseSettings):
"""Settings for the LLM Chat application.""" """Settings for the LLM Chat application."""
api_key: str = "" api_key: str = ""
model: Model = Model.GPT3 model: Model = DEFAULT_MODEL
temperature: float = 0.7 temperature: float = DEFAULT_TEMPERATURE
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc] model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
env_file=".env", env_file=".env",