Implement command line interface
This commit is contained in:
parent
0015ae4bff
commit
f8aa9d4676
|
@ -313,7 +313,7 @@ files = [
|
||||||
name = "click"
|
name = "click"
|
||||||
version = "8.1.6"
|
version = "8.1.6"
|
||||||
description = "Composable command line interface toolkit"
|
description = "Composable command line interface toolkit"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
|
@ -449,6 +449,43 @@ pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"
|
||||||
plugins = ["setuptools"]
|
plugins = ["setuptools"]
|
||||||
requirements-deprecated-finder = ["pip-api", "pipreqs"]
|
requirements-deprecated-finder = ["pip-api", "pipreqs"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "markdown-it-py"
|
||||||
|
version = "3.0.0"
|
||||||
|
description = "Python port of markdown-it. Markdown parsing, done right!"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
|
||||||
|
{file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
mdurl = ">=0.1,<1.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
|
||||||
|
code-style = ["pre-commit (>=3.0,<4.0)"]
|
||||||
|
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
|
||||||
|
linkify = ["linkify-it-py (>=1,<3)"]
|
||||||
|
plugins = ["mdit-py-plugins"]
|
||||||
|
profiling = ["gprof2dot"]
|
||||||
|
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
|
||||||
|
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mdurl"
|
||||||
|
version = "0.1.2"
|
||||||
|
description = "Markdown URL utilities"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
|
||||||
|
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "multidict"
|
name = "multidict"
|
||||||
version = "6.0.4"
|
version = "6.0.4"
|
||||||
|
@ -665,6 +702,21 @@ files = [
|
||||||
dev = ["pre-commit", "tox"]
|
dev = ["pre-commit", "tox"]
|
||||||
testing = ["pytest", "pytest-benchmark"]
|
testing = ["pytest", "pytest-benchmark"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "prompt-toolkit"
|
||||||
|
version = "3.0.39"
|
||||||
|
description = "Library for building powerful interactive command lines in Python"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7.0"
|
||||||
|
files = [
|
||||||
|
{file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"},
|
||||||
|
{file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
wcwidth = "*"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pydantic"
|
name = "pydantic"
|
||||||
version = "2.1.1"
|
version = "2.1.1"
|
||||||
|
@ -833,6 +885,21 @@ snowballstemmer = ">=2.2.0"
|
||||||
[package.extras]
|
[package.extras]
|
||||||
toml = ["tomli (>=1.2.3)"]
|
toml = ["tomli (>=1.2.3)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pygments"
|
||||||
|
version = "2.16.1"
|
||||||
|
description = "Pygments is a syntax highlighting package written in Python."
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "Pygments-2.16.1-py3-none-any.whl", hash = "sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692"},
|
||||||
|
{file = "Pygments-2.16.1.tar.gz", hash = "sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
plugins = ["importlib-metadata"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest"
|
name = "pytest"
|
||||||
version = "7.4.0"
|
version = "7.4.0"
|
||||||
|
@ -891,6 +958,25 @@ urllib3 = ">=1.21.1,<3"
|
||||||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rich"
|
||||||
|
version = "13.5.2"
|
||||||
|
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7.0"
|
||||||
|
files = [
|
||||||
|
{file = "rich-13.5.2-py3-none-any.whl", hash = "sha256:146a90b3b6b47cac4a73c12866a499e9817426423f57c5a66949c086191a8808"},
|
||||||
|
{file = "rich-13.5.2.tar.gz", hash = "sha256:fb9d6c0a0f643c99eed3875b5377a184132ba9be4d61516a55273d3554d75a39"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
markdown-it-py = ">=2.2.0"
|
||||||
|
pygments = ">=2.13.0,<3.0.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruff"
|
name = "ruff"
|
||||||
version = "0.0.284"
|
version = "0.0.284"
|
||||||
|
@ -918,6 +1004,18 @@ files = [
|
||||||
{file = "ruff-0.0.284.tar.gz", hash = "sha256:ebd3cc55cd499d326aac17a331deaea29bea206e01c08862f9b5c6e93d77a491"},
|
{file = "ruff-0.0.284.tar.gz", hash = "sha256:ebd3cc55cd499d326aac17a331deaea29bea206e01c08862f9b5c6e93d77a491"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "shellingham"
|
||||||
|
version = "1.5.3"
|
||||||
|
description = "Tool to Detect Surrounding Shell"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "shellingham-1.5.3-py2.py3-none-any.whl", hash = "sha256:419c6a164770c9c7cfcaeddfacb3d31ac7a8db0b0f3e9c1287679359734107e9"},
|
||||||
|
{file = "shellingham-1.5.3.tar.gz", hash = "sha256:cb4a6fec583535bc6da17b647dd2330cf7ef30239e05d547d99ae3705fd0f7f8"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "snowballstemmer"
|
name = "snowballstemmer"
|
||||||
version = "2.2.0"
|
version = "2.2.0"
|
||||||
|
@ -951,6 +1049,31 @@ notebook = ["ipywidgets (>=6)"]
|
||||||
slack = ["slack-sdk"]
|
slack = ["slack-sdk"]
|
||||||
telegram = ["requests"]
|
telegram = ["requests"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typer"
|
||||||
|
version = "0.9.0"
|
||||||
|
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6"
|
||||||
|
files = [
|
||||||
|
{file = "typer-0.9.0-py3-none-any.whl", hash = "sha256:5d96d986a21493606a358cae4461bd8cdf83cbf33a5aa950ae629ca3b51467ee"},
|
||||||
|
{file = "typer-0.9.0.tar.gz", hash = "sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
click = ">=7.1.1,<9.0.0"
|
||||||
|
colorama = {version = ">=0.4.3,<0.5.0", optional = true, markers = "extra == \"all\""}
|
||||||
|
rich = {version = ">=10.11.0,<14.0.0", optional = true, markers = "extra == \"all\""}
|
||||||
|
shellingham = {version = ">=1.3.0,<2.0.0", optional = true, markers = "extra == \"all\""}
|
||||||
|
typing-extensions = ">=3.7.4.3"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
|
||||||
|
dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"]
|
||||||
|
doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"]
|
||||||
|
test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typing-extensions"
|
name = "typing-extensions"
|
||||||
version = "4.7.1"
|
version = "4.7.1"
|
||||||
|
@ -981,6 +1104,18 @@ secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.
|
||||||
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
|
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
|
||||||
zstd = ["zstandard (>=0.18.0)"]
|
zstd = ["zstandard (>=0.18.0)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wcwidth"
|
||||||
|
version = "0.2.6"
|
||||||
|
description = "Measures the displayed width of unicode strings in a terminal"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "wcwidth-0.2.6-py2.py3-none-any.whl", hash = "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e"},
|
||||||
|
{file = "wcwidth-0.2.6.tar.gz", hash = "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "yarl"
|
name = "yarl"
|
||||||
version = "1.9.2"
|
version = "1.9.2"
|
||||||
|
@ -1072,4 +1207,4 @@ multidict = ">=4.0"
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.11,<3.12"
|
python-versions = ">=3.11,<3.12"
|
||||||
content-hash = "544019c132db9541be699bb3aac24f07cd85ee90773412b581f509e273f9df51"
|
content-hash = "5b81fe2c2ca1756c2c446ab727fd30fd976ed7d8c82ffb8f7d3870c38755bce6"
|
||||||
|
|
|
@ -11,6 +11,8 @@ python = ">=3.11,<3.12"
|
||||||
openai = "^0.27.8"
|
openai = "^0.27.8"
|
||||||
pydantic = "^2.1.1"
|
pydantic = "^2.1.1"
|
||||||
pydantic-settings = "^2.0.2"
|
pydantic-settings = "^2.0.2"
|
||||||
|
typer = {extras = ["all"], version = "^0.9.0"}
|
||||||
|
prompt-toolkit = "^3.0.39"
|
||||||
|
|
||||||
[tool.poetry.group.test.dependencies]
|
[tool.poetry.group.test.dependencies]
|
||||||
pytest = "^7.4.0"
|
pytest = "^7.4.0"
|
||||||
|
@ -20,6 +22,9 @@ ruff = "^0.0.284"
|
||||||
mypy = "^1.5.0"
|
mypy = "^1.5.0"
|
||||||
pydocstyle = "^6.3.0"
|
pydocstyle = "^6.3.0"
|
||||||
|
|
||||||
|
[tool.poetry.scripts]
|
||||||
|
chat = "llm_chat.cli:app"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
from typing import Annotated, Any, Optional
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from prompt_toolkit import PromptSession
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
|
from llm_chat.chat import Chat
|
||||||
|
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_continuation(width: int, *args: Any) -> str:
|
||||||
|
"""Prompt continuation for multiline input."""
|
||||||
|
return "." * (width - 1) + " "
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
api_key: Annotated[
|
||||||
|
Optional[str],
|
||||||
|
typer.Option(
|
||||||
|
...,
|
||||||
|
"--api-key",
|
||||||
|
"-k",
|
||||||
|
help=(
|
||||||
|
"API key. Will read from the environment variable OPENAI_API_KEY "
|
||||||
|
"if not provided."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
model: Annotated[
|
||||||
|
Model,
|
||||||
|
typer.Option(..., "--model", "-m", help="Model to use.", show_choices=True),
|
||||||
|
] = DEFAULT_MODEL,
|
||||||
|
temperature: Annotated[
|
||||||
|
float,
|
||||||
|
typer.Option(
|
||||||
|
..., "--temperature", "-t", help="Model temperature (i.e. creativeness)."
|
||||||
|
),
|
||||||
|
] = DEFAULT_TEMPERATURE,
|
||||||
|
) -> None:
|
||||||
|
"""Start a chat session."""
|
||||||
|
# TODO: Add option to load context from file.
|
||||||
|
# TODO: Add option to provide context string as an argument.
|
||||||
|
# TODO: Function to create settings and chat object to allow fakes in tests.
|
||||||
|
if api_key is not None:
|
||||||
|
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
|
||||||
|
else:
|
||||||
|
settings = OpenAISettings(model=model, temperature=temperature)
|
||||||
|
chat = Chat(settings=settings)
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
# TODO: Can we properly type hint this class?
|
||||||
|
session = PromptSession() # type: ignore[var-annotated]
|
||||||
|
|
||||||
|
finished = False
|
||||||
|
|
||||||
|
console.print(f"[bold green]Model:[/bold green] {settings.model}")
|
||||||
|
console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}")
|
||||||
|
|
||||||
|
while not finished:
|
||||||
|
# TODO: Can we style the prompt initial message and continuation?
|
||||||
|
prompt = session.prompt(
|
||||||
|
">>> ", multiline=True, prompt_continuation=prompt_continuation
|
||||||
|
)
|
||||||
|
if prompt.strip() == "/q":
|
||||||
|
finished = True
|
||||||
|
else:
|
||||||
|
response = chat.send_message(prompt.strip())
|
||||||
|
console.print(Markdown(response))
|
|
@ -10,12 +10,16 @@ class Model(StrEnum):
|
||||||
GPT4 = "gpt-4"
|
GPT4 = "gpt-4"
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_MODEL = Model.GPT3
|
||||||
|
DEFAULT_TEMPERATURE = 0.7
|
||||||
|
|
||||||
|
|
||||||
class OpenAISettings(BaseSettings):
|
class OpenAISettings(BaseSettings):
|
||||||
"""Settings for the LLM Chat application."""
|
"""Settings for the LLM Chat application."""
|
||||||
|
|
||||||
api_key: str = ""
|
api_key: str = ""
|
||||||
model: Model = Model.GPT3
|
model: Model = DEFAULT_MODEL
|
||||||
temperature: float = 0.7
|
temperature: float = DEFAULT_TEMPERATURE
|
||||||
|
|
||||||
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
|
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
|
||||||
env_file=".env",
|
env_file=".env",
|
||||||
|
|
Loading…
Reference in New Issue