llm-chat/src/llm_chat/cli.py

91 lines
2.5 KiB
Python
Raw Normal View History

from typing import Annotated, Any, Optional
import typer
from prompt_toolkit import PromptSession
from rich.console import Console
from rich.markdown import Markdown
from llm_chat.chat import get_chat
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
# The Typer application object; CLI commands in this module are attached to it
# with the @app.command() decorator.
app: typer.Typer = typer.Typer()
def prompt_continuation(width: int, *args: Any) -> str:
    """Return the prefix drawn in front of continuation lines of multiline input.

    For ``width >= 1`` the result is exactly ``width`` characters long: dots
    followed by a single trailing space. For ``width < 1`` it is just ``" "``.
    Any additional positional arguments are accepted and ignored.
    """
    # Left-pad a lone space with dots up to `width`. rjust leaves the string
    # unchanged when width <= 1, which matches "." * (width - 1) + " ".
    return " ".rjust(width, ".")
def get_console() -> Console:
    """Build a new rich ``Console`` with default settings.

    Each call constructs and returns a fresh instance.
    """
    console = Console()
    return console
def get_session() -> PromptSession[Any]:
    """Build a new prompt_toolkit ``PromptSession`` with default settings.

    Each call constructs and returns a fresh instance.
    """
    session: PromptSession[Any] = PromptSession()
    return session
def read_user_input(session: PromptSession[Any]) -> str:
    """Prompt for one (possibly multi-line) user message and return it.

    This wrapper exists so tests can inject user input by replacing it.
    Injecting a fake ``PromptSession`` into ``chat`` directly caused the tests
    to hang.

    Args:
        session: The prompt_toolkit session used to display the ``>>> ``
            prompt and read the input.

    Returns:
        The text exactly as returned by ``session.prompt``; it is not stripped.
    """
    user_text: str = session.prompt(
        ">>> ",
        prompt_continuation=prompt_continuation,
        multiline=True,
    )
    return user_text
@app.command()
def chat(
    api_key: Annotated[
        Optional[str],
        typer.Option(
            "--api-key",
            "-k",
            help=(
                "API key. Will read from the environment variable OPENAI_API_KEY "
                "if not provided."
            ),
        ),
    ] = None,
    model: Annotated[
        Model,
        typer.Option("--model", "-m", help="Model to use.", show_choices=True),
    ] = DEFAULT_MODEL,
    temperature: Annotated[
        float,
        typer.Option(
            "--temperature", "-t", help="Model temperature (i.e. creativeness)."
        ),
    ] = DEFAULT_TEMPERATURE,
) -> None:
    """Start a chat session.

    Enter /q, or press Ctrl-D or Ctrl-C at the prompt, to end the session.
    """
    # Note: with Annotated-style Typer parameters the defaults come from the
    # function signature, so typer.Option() only receives the option names.
    # TODO: Add option to load context from file.
    # TODO: Add option to provide context string as an argument.

    # Only forward api_key when it was given explicitly. Otherwise
    # OpenAISettings resolves it itself; per the --api-key help text it is then
    # read from the OPENAI_API_KEY environment variable.
    if api_key is not None:
        settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
    else:
        settings = OpenAISettings(model=model, temperature=temperature)

    current_chat = get_chat(settings)
    console = get_console()
    session = get_session()

    console.print(f"[bold green]Model:[/bold green] {settings.model}")
    console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}")

    while True:
        try:
            user_input = read_user_input(session)
        except (EOFError, KeyboardInterrupt):
            # prompt_toolkit raises EOFError on Ctrl-D and KeyboardInterrupt on
            # Ctrl-C. Treat both as a request to quit rather than letting them
            # escape as a traceback.
            break

        message = user_input.strip()
        if message == "/q":
            break
        if not message:
            # Nothing to send; don't make a model call for blank input.
            continue

        response = current_chat.send_message(message)
        console.print(Markdown(response))