diff --git a/src/llm_chat/bot.py b/src/llm_chat/bot.py index 6d9decd..147684b 100644 --- a/src/llm_chat/bot.py +++ b/src/llm_chat/bot.py @@ -15,6 +15,28 @@ def _bot_id_from_name(name: str) -> str: return kebab_case(name) +def _load_context(config: BotConfig, bot_dir: Path) -> list[Message]: + """Load text from context files. + + Args: + config: Bot configuration. + + Returns: + List of system messages to provide as context. + """ + context: list[Message] = [] + for context_file in config.context_files: + path = bot_dir / config.bot_id / "context" / context_file + if not path.exists(): + raise ValueError(f"{path} does not exist.") + if not path.is_file(): + raise ValueError(f"{path} is not a file") + with path.open("r") as f: + content = f.read() + context.append(Message(role=Role.SYSTEM, content=content)) + return context + + class BotConfig(BaseModel): """Bot configuration class.""" @@ -30,26 +52,22 @@ class BotExists(Exception): class BotDoesNotExists(Exception): - """Bot already exists error.""" + """Bot does not exist error.""" pass class Bot: - """Custom bot interface.""" + """Custom bot interface. + + Args: + config: Bot configuration instance. + bot_dir: Path to directory of bot configurations. + """ def __init__(self, config: BotConfig, bot_dir: Path) -> None: self.config = config - self.context: list[Message] = [] - for context_file in config.context_files: - path = bot_dir / "context" / context_file - if not path.exists(): - raise ValueError(f"{path} does not exist.") - if not path.is_file(): - raise ValueError(f"{path} is not a file") - with path.open("r") as f: - content = f.read() - self.context.append(Message(role=Role.SYSTEM, content=content)) + self.context = _load_context(config, bot_dir) @property def id(self) -> str: @@ -75,6 +93,9 @@ class Bot: name: Name of the custom bot. bot_dir: Path to where custom bot contexts are stored. context_files: Paths to context files. + + Returns: + Instantiated Bot instance. """ bot_id = _bot_id_from_name(name) path = bot_dir / bot_id @@ -96,7 +117,15 @@ class Bot: @classmethod def load(cls, name: str, bot_dir: Path) -> Bot: - """Load existing bot.""" + """Load existing bot. + + Args: + name: Name of the custom bot. + bot_dir: Path to where custom bot contexts are stored. + + Returns: + Instantiated Bot instance. + """ bot_id = _bot_id_from_name(name) bot_path = bot_dir / bot_id if not bot_path.exists(): diff --git a/src/llm_chat/chat.py b/src/llm_chat/chat.py index a76d3db..1b684c8 100644 --- a/src/llm_chat/chat.py +++ b/src/llm_chat/chat.py @@ -9,6 +9,7 @@ from openai import OpenAI from openai.types.chat import ChatCompletion from openai.types.completion_usage import CompletionUsage +from llm_chat.bot import Bot from llm_chat.models import Conversation, Message, Role from llm_chat.settings import Model, OpenAISettings @@ -53,6 +54,10 @@ class ChatProtocol(Protocol): conversation: Conversation + @property + def bot(self) -> str: + """Get the name of the bot the conversation is with.""" + @property def cost(self) -> float: """Get the cost of the conversation.""" @@ -75,10 +80,14 @@ class ChatProtocol(Protocol): class Chat: """Interface class for OpenAI's ChatGPT chat API. - Arguments: - settings (optional): Settings for the chat. Defaults to reading from - environment variables. - context (optional): Context for the chat. Defaults to an empty list. + Args: + settings: Settings for the chat. Defaults to reading from environment + variables. + context: Context for the chat. Defaults to an empty list. + name: Name of the chat. + bot: Name of bot to chat with. + initial_system_messages: Whether to include the standard initial system + messages. """ _pricing: dict[Model, dict[Token, float]] = { @@ -101,16 +110,23 @@ class Chat: settings: OpenAISettings | None = None, context: list[Message] = [], name: str = "", + bot: str = "", initial_system_messages: bool = True, ) -> None: self._settings = settings + + if bot: + context = Bot.load(bot, self.settings.bot_dir).context + context + + if initial_system_messages: + context = INITIAL_SYSTEM_MESSAGES + context + self.conversation = Conversation( - messages=INITIAL_SYSTEM_MESSAGES + context - if initial_system_messages - else context, + messages=context, model=self.settings.model, temperature=self.settings.temperature, name=name, + bot=bot, ) self._start_time = datetime.now(tz=ZoneInfo("UTC")) self._client = OpenAI( @@ -147,6 +163,11 @@ class Chat: self._settings = OpenAISettings() return self._settings + @property + def bot(self) -> str: + """Get the name of the bot the conversation is with.""" + return self.conversation.bot + @property def cost(self) -> float: """Get the cost of the conversation.""" @@ -216,10 +237,10 @@ class Chat: def get_chat( - settings: OpenAISettings | None = None, context: list[Message] = [], name: str = "" + settings: OpenAISettings | None = None, context: list[Message] = [], name: str = "", bot: str = "" ) -> ChatProtocol: """Get a chat object.""" - return Chat(settings=settings, context=context, name=name) + return Chat(settings=settings, context=context, name=name, bot=bot) def get_chat_class() -> Type[Chat]: diff --git a/src/llm_chat/cli/main.py b/src/llm_chat/cli/main.py index 9f7b375..df8f37c 100644 --- a/src/llm_chat/cli/main.py +++ b/src/llm_chat/cli/main.py @@ -75,6 +75,8 @@ def run_conversation(current_chat: ChatProtocol) -> None: ) if current_chat.name: console.print(f"[bold green]Name:[/bold green] {current_chat.name}") + if current_chat.bot: + console.print(f"[bold green]Bot:[/bold green] {current_chat.bot}") while not finished: prompt = read_user_input(session) @@ -148,6 +150,15 @@ def chat( help="Name of the chat.", ), ] = "", + bot: Annotated[ + str, + typer.Option( + ..., + "--bot", + "-b", + help="Name of bot with whom you want to chat." + ) + ] = "" ) -> None: """Start a chat session.""" # TODO: Add option to provide context string as an argument. @@ -164,7 +175,7 @@ def chat( context_messages = [load_context(path) for path in context] - current_chat = get_chat(settings=settings, context=context_messages, name=name) + current_chat = get_chat(settings=settings, context=context_messages, name=name, bot=bot) run_conversation(current_chat) diff --git a/src/llm_chat/models.py b/src/llm_chat/models.py index d3262ca..b820241 100644 --- a/src/llm_chat/models.py +++ b/src/llm_chat/models.py @@ -32,6 +32,7 @@ class Conversation(BaseModel): model: Model temperature: float = DEFAULT_TEMPERATURE name: str = "" + bot: str = "" completion_tokens: int = 0 prompt_tokens: int = 0 cost: float = 0.0