diff --git a/poetry.lock b/poetry.lock index ac6d7b4..ad0d2cf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -313,7 +313,7 @@ files = [ name = "click" version = "8.1.6" description = "Composable command line interface toolkit" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -449,6 +449,43 @@ pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib" plugins = ["setuptools"] requirements-deprecated-finder = ["pip-api", "pipreqs"] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + [[package]] name = "multidict" version = "6.0.4" @@ -665,6 +702,21 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "prompt-toolkit" +version = "3.0.39" +description = "Library for building powerful interactive command lines in Python" +category = "main" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"}, + {file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"}, +] + +[package.dependencies] +wcwidth = "*" + [[package]] name = "pydantic" version = "2.1.1" @@ -833,6 +885,21 @@ snowballstemmer = ">=2.2.0" [package.extras] toml = ["tomli (>=1.2.3)"] +[[package]] +name = "pygments" +version = "2.16.1" +description = "Pygments is a syntax highlighting package written in Python." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "Pygments-2.16.1-py3-none-any.whl", hash = "sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692"}, + {file = "Pygments-2.16.1.tar.gz", hash = "sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29"}, +] + +[package.extras] +plugins = ["importlib-metadata"] + [[package]] name = "pytest" version = "7.4.0" @@ -891,6 +958,25 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rich" +version = "13.5.2" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "rich-13.5.2-py3-none-any.whl", hash = "sha256:146a90b3b6b47cac4a73c12866a499e9817426423f57c5a66949c086191a8808"}, + {file = "rich-13.5.2.tar.gz", hash = "sha256:fb9d6c0a0f643c99eed3875b5377a184132ba9be4d61516a55273d3554d75a39"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + [[package]] name = "ruff" version = "0.0.284" @@ -918,6 +1004,18 @@ files = [ {file = "ruff-0.0.284.tar.gz", hash = "sha256:ebd3cc55cd499d326aac17a331deaea29bea206e01c08862f9b5c6e93d77a491"}, ] +[[package]] +name = "shellingham" +version = "1.5.3" +description = "Tool to Detect Surrounding Shell" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "shellingham-1.5.3-py2.py3-none-any.whl", hash = "sha256:419c6a164770c9c7cfcaeddfacb3d31ac7a8db0b0f3e9c1287679359734107e9"}, + {file = "shellingham-1.5.3.tar.gz", hash = "sha256:cb4a6fec583535bc6da17b647dd2330cf7ef30239e05d547d99ae3705fd0f7f8"}, +] + [[package]] name = "snowballstemmer" version = "2.2.0" @@ -951,6 +1049,31 @@ notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] +[[package]] +name = "typer" +version = "0.9.0" +description = "Typer, build great CLIs. Easy to code. Based on Python type hints." +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "typer-0.9.0-py3-none-any.whl", hash = "sha256:5d96d986a21493606a358cae4461bd8cdf83cbf33a5aa950ae629ca3b51467ee"}, + {file = "typer-0.9.0.tar.gz", hash = "sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2"}, +] + +[package.dependencies] +click = ">=7.1.1,<9.0.0" +colorama = {version = ">=0.4.3,<0.5.0", optional = true, markers = "extra == \"all\""} +rich = {version = ">=10.11.0,<14.0.0", optional = true, markers = "extra == \"all\""} +shellingham = {version = ">=1.3.0,<2.0.0", optional = true, markers = "extra == \"all\""} +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] +doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] + [[package]] name = "typing-extensions" version = "4.7.1" @@ -981,6 +1104,18 @@ secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17. socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "wcwidth" +version = "0.2.6" +description = "Measures the displayed width of unicode strings in a terminal" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.6-py2.py3-none-any.whl", hash = "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e"}, + {file = "wcwidth-0.2.6.tar.gz", hash = "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"}, +] + [[package]] name = "yarl" version = "1.9.2" @@ -1072,4 +1207,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "544019c132db9541be699bb3aac24f07cd85ee90773412b581f509e273f9df51" +content-hash = "5b81fe2c2ca1756c2c446ab727fd30fd976ed7d8c82ffb8f7d3870c38755bce6" diff --git a/pyproject.toml b/pyproject.toml index 93fcb9b..2849184 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,8 @@ python = ">=3.11,<3.12" openai = "^0.27.8" pydantic = "^2.1.1" pydantic-settings = "^2.0.2" +typer = {extras = ["all"], version = "^0.9.0"} +prompt-toolkit = "^3.0.39" [tool.poetry.group.test.dependencies] pytest = "^7.4.0" @@ -20,6 +22,9 @@ ruff = "^0.0.284" mypy = "^1.5.0" pydocstyle = "^6.3.0" +[tool.poetry.scripts] +chat = "llm_chat.cli:app" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/src/llm_chat/chat.py b/src/llm_chat/chat.py index 3bc8c86..35cf640 100644 --- a/src/llm_chat/chat.py +++ b/src/llm_chat/chat.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Protocol from openai import ChatCompletion from openai.openai_object import OpenAIObject @@ -17,6 +17,13 @@ INITIAL_SYSTEM_MESSAGES = [ ] +class ChatProtocol(Protocol): + """Protocol for chat classes.""" + + def send_message(self, message: str) -> str: + """Send a message to the assistant.""" + + class Chat: """Interface class for OpenAI's ChatGPT chat API. @@ -68,3 +75,8 @@ class Chat: message = response["choices"][0]["message"]["content"] self.conversation.messages.append(Message(role=Role.ASSISTANT, content=message)) return message + + +def get_chat(settings: OpenAISettings | None = None) -> ChatProtocol: + """Get a chat object.""" + return Chat(settings=settings) diff --git a/src/llm_chat/cli.py b/src/llm_chat/cli.py new file mode 100644 index 0000000..7d8cff1 --- /dev/null +++ b/src/llm_chat/cli.py @@ -0,0 +1,90 @@ +from typing import Annotated, Any, Optional + +import typer +from prompt_toolkit import PromptSession +from rich.console import Console +from rich.markdown import Markdown + +from llm_chat.chat import get_chat +from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings + +app = typer.Typer() + + +def prompt_continuation(width: int, *args: Any) -> str: + """Prompt continuation for multiline input.""" + return "." * (width - 1) + " " + + +def get_console() -> Console: + """Get a console.""" + return Console() + + +def get_session() -> PromptSession[Any]: + """Get a prompt session.""" + return PromptSession() + + +def read_user_input(session: PromptSession[Any]) -> str: + """Read user input. + + The main purpose of this function is to enable the injection of user input during + tests, since trying to inject a fake PromptSession into the chat function led to + the test hanging. + """ + prompt: str = session.prompt( + ">>> ", multiline=True, prompt_continuation=prompt_continuation + ) + return prompt + + +@app.command() +def chat( + api_key: Annotated[ + Optional[str], + typer.Option( + ..., + "--api-key", + "-k", + help=( + "API key. Will read from the environment variable OPENAI_API_KEY " + "if not provided." + ), + ), + ] = None, + model: Annotated[ + Model, + typer.Option(..., "--model", "-m", help="Model to use.", show_choices=True), + ] = DEFAULT_MODEL, + temperature: Annotated[ + float, + typer.Option( + ..., "--temperature", "-t", help="Model temperature (i.e. creativeness)." + ), + ] = DEFAULT_TEMPERATURE, +) -> None: + """Start a chat session.""" + # TODO: Add option to load context from file. + # TODO: Add option to provide context string as an argument. + if api_key is not None: + settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature) + else: + settings = OpenAISettings(model=model, temperature=temperature) + + current_chat = get_chat(settings) + console = get_console() + session = get_session() + + finished = False + + console.print(f"[bold green]Model:[/bold green] {settings.model}") + console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}") + + while not finished: + prompt = read_user_input(session) + if prompt.strip() == "/q": + finished = True + else: + response = current_chat.send_message(prompt.strip()) + console.print(Markdown(response)) diff --git a/src/llm_chat/settings.py b/src/llm_chat/settings.py index d01801a..f266183 100644 --- a/src/llm_chat/settings.py +++ b/src/llm_chat/settings.py @@ -10,12 +10,16 @@ class Model(StrEnum): GPT4 = "gpt-4" +DEFAULT_MODEL = Model.GPT3 +DEFAULT_TEMPERATURE = 0.7 + + class OpenAISettings(BaseSettings): """Settings for the LLM Chat application.""" api_key: str = "" - model: Model = Model.GPT3 - temperature: float = 0.7 + model: Model = DEFAULT_MODEL + temperature: float = DEFAULT_TEMPERATURE model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc] env_file=".env", diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..a64b4fc --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,46 @@ +from io import StringIO +from unittest.mock import MagicMock + +from pytest import MonkeyPatch +from rich.console import Console +from typer.testing import CliRunner + +import llm_chat +from llm_chat.chat import ChatProtocol +from llm_chat.cli import app +from llm_chat.settings import OpenAISettings + +runner = CliRunner() + + +class ChatFake: + """Fake chat class for testing.""" + + received_messages: list[str] = [] + + def send_message(self, message: str) -> str: + """Echo the received message.""" + self.received_messages.append(message) + return message + + +def test_chat(monkeypatch: MonkeyPatch) -> None: + chat_fake = ChatFake() + output = StringIO() + console = Console(file=output) + + def mock_get_chat(_: OpenAISettings) -> ChatProtocol: + return chat_fake + + def mock_get_console() -> Console: + return console + + mock_read_user_input = MagicMock(side_effect=["Hello", "/q"]) + + monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat) + monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console) + monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input) + + result = runner.invoke(app) + assert result.exit_code == 0 + assert chat_fake.received_messages == ["Hello"]