"""Typer command-line interface for interactive chat sessions."""

from pathlib import Path
from typing import Annotated, Any, Optional

import typer
from prompt_toolkit import PromptSession
from rich.console import Console
from rich.markdown import Markdown

from llm_chat.chat import get_chat
from llm_chat.models import Message, Role
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings

app = typer.Typer()

# Input that ends the chat loop when entered on its own (surrounding
# whitespace is ignored).
QUIT_COMMAND = "/q"


def prompt_continuation(width: int, *args: Any) -> str:
    """Build the prompt shown at the start of each continuation line.

    The result is ``width - 1`` dots followed by a single space, so for
    ``width >= 1`` it is exactly ``width`` characters long.

    Args:
        width: Number of columns the continuation prompt should occupy.
        *args: Extra positional arguments passed by prompt_toolkit; ignored.

    Returns:
        The continuation prompt string.
    """
    return "." * (width - 1) + " "


def get_console() -> Console:
    """Return a new rich ``Console`` for writing output."""
    return Console()


def get_session() -> PromptSession[Any]:
    """Return a new prompt_toolkit ``PromptSession`` for reading input."""
    return PromptSession()


def read_user_input(session: PromptSession[Any]) -> str:
    """Read one (possibly multi-line) message from the user.

    The main purpose of this function is to enable the injection of user input
    during tests, since trying to inject a fake PromptSession into the chat
    function led to the test hanging.

    Args:
        session: Session used to display the ``>>> `` prompt.

    Returns:
        The raw text entered by the user (not stripped).
    """
    prompt: str = session.prompt(
        ">>> ", multiline=True, prompt_continuation=prompt_continuation
    )
    return prompt


def load_context(path: Path) -> Message:
    """Load a context file as a system message.

    Args:
        path: File whose full text becomes the message content.

    Returns:
        A ``Message`` with role ``Role.SYSTEM`` holding the file's text.

    Raises:
        typer.BadParameter: If ``path`` does not exist or is not a file.
    """
    if not path.exists():
        raise typer.BadParameter(f"File {path} does not exist.")
    if not path.is_file():
        raise typer.BadParameter(f"Path {path} is not a file.")
    # Decode explicitly so the result does not depend on the platform locale.
    content = path.read_text(encoding="utf-8")
    return Message(role=Role.SYSTEM, content=content)


@app.command()
def chat(
    api_key: Annotated[
        Optional[str],
        typer.Option(
            "--api-key",
            "-k",
            help=(
                "API key. Will read from the environment variable OPENAI_API_KEY "
                "if not provided."
            ),
        ),
    ] = None,
    model: Annotated[
        Model,
        typer.Option("--model", "-m", help="Model to use.", show_choices=True),
    ] = DEFAULT_MODEL,
    temperature: Annotated[
        float,
        typer.Option(
            "--temperature", "-t", help="Model temperature (i.e. creativeness)."
        ),
    ] = DEFAULT_TEMPERATURE,
    context: Annotated[
        Optional[list[Path]],
        typer.Option(
            "--context",
            "-c",
            help=(
                "Path to a file containing context text. "
                "Can be provided multiple times for multiple files."
            ),
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ] = None,
) -> None:
    """Start a chat session.

    Input is multi-line: submit a message with Esc followed by Enter.
    Type /q, or press Ctrl-D on an empty prompt, to quit.
    """
    # TODO: Add option to provide context string as an argument.

    # Only forward api_key when given, so OpenAISettings can fall back to the
    # OPENAI_API_KEY environment variable as the --api-key help text promises.
    if api_key is not None:
        settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
    else:
        settings = OpenAISettings(model=model, temperature=temperature)

    # `context` is None when --context was never given.
    context_messages = [load_context(path) for path in context or []]
    current_chat = get_chat(settings=settings, context=context_messages)
    console = get_console()
    session = get_session()

    console.print(f"[bold green]Model:[/bold green] {settings.model}")
    console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}")

    while True:
        try:
            prompt = read_user_input(session).strip()
        except KeyboardInterrupt:
            # Ctrl-C discards the message being typed but keeps the session open.
            continue
        except EOFError:
            # Ctrl-D ends the session, just like the quit command.
            break
        if prompt == QUIT_COMMAND:
            break
        if not prompt:
            # Skip blank input rather than sending an empty message.
            continue
        response = current_chat.send_message(prompt)
        console.print(Markdown(response))