126 lines
3.5 KiB
Python
126 lines
3.5 KiB
Python
from pathlib import Path
|
|
from typing import Annotated, Any, Optional
|
|
|
|
import typer
|
|
from prompt_toolkit import PromptSession
|
|
from rich.console import Console
|
|
from rich.markdown import Markdown
|
|
|
|
from llm_chat.chat import get_chat
|
|
from llm_chat.models import Message, Role
|
|
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
|
|
|
# Typer CLI application; the `chat` command below is registered on it via
# the `@app.command()` decorator.
app = typer.Typer()
|
|
|
|
|
|
def prompt_continuation(width: int, *args: Any) -> str:
    """Build the marker shown at the start of continuation lines in multiline input.

    The marker is ``width - 1`` dots followed by a single space, so for any
    ``width >= 1`` it occupies exactly ``width`` columns. Any extra positional
    arguments passed by the caller (prompt_toolkit supplies line/wrap info)
    are ignored.
    """
    dots = "." * (width - 1)
    return f"{dots} "
|
|
|
|
|
|
def get_console() -> Console:
    """Create and return a new rich ``Console`` used for all terminal output."""
    console = Console()
    return console
|
|
|
|
|
|
def get_session() -> PromptSession[Any]:
    """Create and return a new prompt_toolkit ``PromptSession`` for reading input."""
    session: PromptSession[Any] = PromptSession()
    return session
|
|
|
|
|
|
def read_user_input(session: PromptSession[Any]) -> str:
    """Prompt once for a (possibly multiline) message and return the raw text.

    This lives in its own function so tests can substitute user input at this
    seam; injecting a fake PromptSession into the chat function directly made
    the test hang.

    Args:
        session: The prompt session to read from.

    Returns:
        The text entered at the ``>>> `` prompt, unmodified.
    """
    return session.prompt(
        ">>> ",
        multiline=True,
        prompt_continuation=prompt_continuation,
    )
|
|
|
|
|
|
def load_context(path: Path) -> Message:
    """Load a context file and wrap its text in a system message.

    Args:
        path: Path to a UTF-8 text file whose contents become the context.

    Returns:
        A ``Message`` with role ``Role.SYSTEM`` whose content is the file text.

    Raises:
        typer.BadParameter: If ``path`` does not exist, is not a file, or is
            not valid UTF-8 text.
    """
    if not path.exists():
        raise typer.BadParameter(f"File {path} does not exist.")
    if not path.is_file():
        raise typer.BadParameter(f"Path {path} is not a file.")

    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding (the default for open()/read_text without `encoding`), so the
    # same context file loads identically on every OS.
    try:
        content = path.read_text(encoding="utf-8")
    except UnicodeDecodeError as err:
        # Report as a CLI parameter error rather than an unhandled traceback.
        raise typer.BadParameter(f"File {path} is not valid UTF-8 text.") from err

    return Message(role=Role.SYSTEM, content=content)
|
|
|
|
|
|
@app.command()
def chat(
    api_key: Annotated[
        Optional[str],
        typer.Option(
            ...,
            "--api-key",
            "-k",
            help=(
                "API key. Will read from the environment variable OPENAI_API_KEY "
                "if not provided."
            ),
        ),
    ] = None,
    model: Annotated[
        Model,
        typer.Option(..., "--model", "-m", help="Model to use.", show_choices=True),
    ] = DEFAULT_MODEL,
    temperature: Annotated[
        float,
        typer.Option(
            ..., "--temperature", "-t", help="Model temperature (i.e. creativeness)."
        ),
    ] = DEFAULT_TEMPERATURE,
    context: Annotated[
        list[Path],
        typer.Option(
            ...,
            "--context",
            "-c",
            help=(
                "Path to a file containing context text. "
                "Can provide multiple times for multiple files."
            ),
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ] = [],  # only read below, never mutated, so the shared default is safe
) -> None:
    """Start an interactive chat session.

    Builds ``OpenAISettings`` from the CLI options, loads each ``--context``
    file as a system message, prints the active model and temperature, then
    repeatedly reads a multiline prompt, sends it to the chat, and renders
    the reply as Markdown. Entering ``/q`` or pressing Ctrl-D (EOF) ends the
    session; blank input is ignored.
    """
    # TODO: Add option to provide context string as an argument.
    if api_key is not None:
        settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
    else:
        # No explicit key: leave it to OpenAISettings to resolve (per the
        # --api-key help text, from the OPENAI_API_KEY environment variable).
        settings = OpenAISettings(model=model, temperature=temperature)

    context_messages = [load_context(path) for path in context]
    current_chat = get_chat(settings=settings, context=context_messages)

    console = get_console()
    session = get_session()

    console.print(f"[bold green]Model:[/bold green] {settings.model}")
    console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}")

    while True:
        try:
            prompt = read_user_input(session).strip()
        except EOFError:
            # prompt_toolkit raises EOFError on Ctrl-D; treat it like "/q"
            # instead of exiting with a traceback.
            break

        if prompt == "/q":
            break
        if not prompt:
            # Nothing to ask: don't send an empty message to the model.
            continue

        response = current_chat.send_message(prompt)
        console.print(Markdown(response))
|