Incorporate custom bots into chat CLI

This commit adds the ability to create custom bots, in the form of
automatically prepending additional bot-specific context to your chat
session. A new `chat` CLI option, `--bot`/`-b`, lets users provide the
name of an existing bot. For example:

```
llm chat -b "My Bot"
```

An additional `bot` command group is added for creating and removing
bots.
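
For instance, a bot can be created from a local context file and later
removed (a usage sketch; `context.md` is a placeholder file):

```
llm bot create "My Bot" -c context.md
llm bot remove "My Bot"
```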

Closes #13
Paul Harrison 2024-02-23 11:49:48 +00:00
parent 76bee1aed9
commit 0698919cb6
15 changed files with 471 additions and 48 deletions

poetry.lock generated

@@ -764,4 +764,4 @@ files = [
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.11,<3.12"
-content-hash = "e8c5d78c5c95eaadb03e603c5b4ceada8aa27aaa049e6c0d72129c1f2dc53ed9"
+content-hash = "8d76898eeb53fd3848f3be2f6aa1662517f9dbd80146db8dfd6f2932021ace48"

pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "llm-chat"
-version = "1.1.5"
+version = "2.0.0"
 description = "A general CLI interface for large language models."
 authors = ["Paul Harrison <paul@harrison.sh>"]
 readme = "README.md"

src/llm_chat/bot.py Normal file

@@ -0,0 +1,162 @@
from __future__ import annotations

import json
import shutil
from pathlib import Path
from typing import Iterable

from pydantic import BaseModel

from llm_chat.models import Message, Role
from llm_chat.utils import kebab_case


class BotExists(Exception):
    """Bot already exists error."""

    pass


class BotDoesNotExist(Exception):
    """Bot does not exist error."""

    pass


class BotConfig(BaseModel):
    """Bot configuration class."""

    bot_id: str
    name: str
    context_files: list[str]


def _bot_id_from_name(name: str) -> str:
    """Create bot ID from name.

    Args:
        name: Bot name in full prose (e.g. My Amazing Bot).
    """
    return kebab_case(name)


def _load_context(config: BotConfig, bot_dir: Path) -> list[Message]:
    """Load text from context files.

    Args:
        config: Bot configuration.
        bot_dir: Path to directory of bot configurations.

    Returns:
        List of system messages to provide as context.
    """
    context: list[Message] = []
    for context_file in config.context_files:
        path = bot_dir / config.bot_id / "context" / context_file
        if not path.exists():
            raise ValueError(f"{path} does not exist.")
        if not path.is_file():
            raise ValueError(f"{path} is not a file.")
        with path.open("r") as f:
            content = f.read()
        context.append(Message(role=Role.SYSTEM, content=content))
    return context


class Bot:
    """Custom bot interface.

    Args:
        config: Bot configuration instance.
        bot_dir: Path to directory of bot configurations.
    """

    def __init__(self, config: BotConfig, bot_dir: Path) -> None:
        self.config = config
        self.context = _load_context(config, bot_dir)

    @property
    def id(self) -> str:
        """Return the bot ID."""
        return self.config.bot_id

    @property
    def name(self) -> str:
        """Return the bot name."""
        return self.config.name

    @classmethod
    def create(
        cls,
        name: str,
        bot_dir: Path,
        context_files: Iterable[Path] = tuple(),
    ) -> None:
        """Create a custom bot.

        This method creates the directory structure for the custom bot and
        copies in the context files. The bot directory is stored within the
        base bot directory (e.g. `~/.llm-chat/bots/<bot-id>`), where the bot
        ID is the kebab-case version of the name argument. The directory
        contains a settings file `<bot-id>.json` and a directory of context
        files.

        Args:
            name: Name of the custom bot.
            bot_dir: Path to where custom bot contexts are stored.
            context_files: Paths to context files.
        """
        bot_id = _bot_id_from_name(name)
        path = bot_dir / bot_id
        if path.exists():
            raise BotExists(f"The bot {name} already exists.")
        (path / "context").mkdir(parents=True)
        config = BotConfig(
            bot_id=bot_id,
            name=name,
            context_files=[context.name for context in context_files],
        )
        with (path / f"{bot_id}.json").open("w") as f:
            f.write(config.model_dump_json() + "\n")
        for context in context_files:
            shutil.copy(context, path / "context" / context.name)

    @classmethod
    def load(cls, name: str, bot_dir: Path) -> Bot:
        """Load an existing bot.

        Args:
            name: Name of the custom bot.
            bot_dir: Path to where custom bot contexts are stored.

        Returns:
            Instantiated Bot instance.
        """
        bot_id = _bot_id_from_name(name)
        bot_path = bot_dir / bot_id
        if not bot_path.exists():
            raise BotDoesNotExist(f"Bot {name} does not exist.")
        with (bot_path / f"{bot_id}.json").open("r") as f:
            config = BotConfig(**json.load(f))
        return cls(config, bot_dir)

    @classmethod
    def remove(cls, name: str, bot_dir: Path) -> None:
        """Remove an existing bot.

        Args:
            name: Name of the custom bot.
            bot_dir: Path to where custom bot contexts are stored.
        """
        bot_id = _bot_id_from_name(name)
        bot_path = bot_dir / bot_id
        if not bot_path.exists():
            raise BotDoesNotExist(f"Bot {name} does not exist.")
        shutil.rmtree(bot_path)
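
A minimal usage sketch of the resulting create/load/remove lifecycle
(assuming a `context.md` file in the working directory; the `bot_dir`
path mirrors the defaults in `settings.py`):

```
from pathlib import Path

from llm_chat.bot import Bot
from llm_chat.models import Role

bot_dir = Path.home() / ".llm-chat" / "bots"

# Create the bot: writes ~/.llm-chat/bots/my-bot/my-bot.json and copies
# context.md into ~/.llm-chat/bots/my-bot/context/.
Bot.create(name="My Bot", bot_dir=bot_dir, context_files=[Path("context.md")])

# Load it back; the name is normalised to the kebab-case ID "my-bot".
bot = Bot.load(name="My Bot", bot_dir=bot_dir)
assert bot.id == "my-bot"
assert bot.context[0].role == Role.SYSTEM  # context files become system messages

# Delete the bot's directory tree.
Bot.remove(name="My Bot", bot_dir=bot_dir)
```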

src/llm_chat/chat.py

@@ -9,6 +9,7 @@ from openai import OpenAI
 from openai.types.chat import ChatCompletion
 from openai.types.completion_usage import CompletionUsage
 
+from llm_chat.bot import Bot
 from llm_chat.models import Conversation, Message, Role
 from llm_chat.settings import Model, OpenAISettings
@@ -53,6 +54,10 @@ class ChatProtocol(Protocol):
     conversation: Conversation
 
+    @property
+    def bot(self) -> str:
+        """Get the name of the bot the conversation is with."""
+
     @property
     def cost(self) -> float:
         """Get the cost of the conversation."""
@@ -75,10 +80,14 @@ class ChatProtocol(Protocol):
 class Chat:
     """Interface class for OpenAI's ChatGPT chat API.
 
-    Arguments:
-        settings (optional): Settings for the chat. Defaults to reading from
-            environment variables.
-        context (optional): Context for the chat. Defaults to an empty list.
+    Args:
+        settings: Settings for the chat. Defaults to reading from environment
+            variables.
+        context: Context for the chat. Defaults to an empty list.
+        name: Name of the chat.
+        bot: Name of bot to chat with.
+        initial_system_messages: Whether to include the standard initial system
+            messages.
     """
 
     _pricing: dict[Model, dict[Token, float]] = {
@@ -101,16 +110,23 @@ class Chat:
         settings: OpenAISettings | None = None,
         context: list[Message] = [],
         name: str = "",
+        bot: str = "",
         initial_system_messages: bool = True,
     ) -> None:
         self._settings = settings
+        if bot:
+            context = Bot.load(bot, self.settings.bot_dir).context + context
+        if initial_system_messages:
+            context = INITIAL_SYSTEM_MESSAGES + context
         self.conversation = Conversation(
-            messages=INITIAL_SYSTEM_MESSAGES + context
-            if initial_system_messages
-            else context,
+            messages=context,
             model=self.settings.model,
             temperature=self.settings.temperature,
             name=name,
+            bot=bot,
         )
         self._start_time = datetime.now(tz=ZoneInfo("UTC"))
         self._client = OpenAI(
@@ -119,7 +135,7 @@
     @classmethod
     def load(
-        cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
+        cls, path: Path, api_key: str | None = None, base_dir: Path | None = None
     ) -> ChatProtocol:
         """Load a chat from a file."""
         with path.open() as f:
@@ -130,8 +146,8 @@
         }
         if api_key is not None:
             args["api_key"] = api_key
-        if history_dir is not None:
-            args["history_dir"] = history_dir
+        if base_dir is not None:
+            args["base_dir"] = base_dir
         settings = OpenAISettings(**args)
 
         return cls(
@@ -147,6 +163,11 @@
             self._settings = OpenAISettings()
         return self._settings
 
+    @property
+    def bot(self) -> str:
+        """Get the name of the bot the conversation is with."""
+        return self.conversation.bot
+
     @property
     def cost(self) -> float:
         """Get the cost of the conversation."""
@@ -216,10 +237,13 @@
 
 def get_chat(
-    settings: OpenAISettings | None = None, context: list[Message] = [], name: str = ""
+    settings: OpenAISettings | None = None,
+    context: list[Message] = [],
+    name: str = "",
+    bot: str = "",
 ) -> ChatProtocol:
     """Get a chat object."""
-    return Chat(settings=settings, context=context, name=name)
+    return Chat(settings=settings, context=context, name=name, bot=bot)
 
 
 def get_chat_class() -> Type[Chat]:
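
The net effect of the constructor change: when `bot` is given, the
bot's context files are loaded as system messages and slotted between
the standard initial system messages and any per-chat context. A sketch
(assuming a bot named "My Bot" was created beforehand):

```
from llm_chat.chat import Chat
from llm_chat.settings import OpenAISettings

settings = OpenAISettings()  # reads OPENAI_* environment variables / .env
chat = Chat(settings=settings, bot="My Bot")

# chat.conversation.messages is now ordered as:
#   1. INITIAL_SYSTEM_MESSAGES (the standard system prompts)
#   2. system messages built from the bot's context files
#   3. any `context` passed to the constructor (none here)
```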

src/llm_chat/cli/__init__.py Normal file

@@ -0,0 +1,3 @@
from llm_chat.cli.main import app
__all__ = ["app"]

src/llm_chat/cli/bot.py Normal file

@@ -0,0 +1,79 @@
from pathlib import Path
from typing import Annotated, Any, Optional

import typer

from llm_chat.bot import Bot
from llm_chat.settings import OpenAISettings

app = typer.Typer()


@app.command("create")
def create(
    name: Annotated[
        str,
        typer.Argument(help="Name of bot to create."),
    ],
    base_dir: Annotated[
        Optional[Path],
        typer.Option(
            ...,
            "--base-dir",
            "-d",
            help=(
                "Path to the base directory in which conversation "
                "configuration and history will be saved."
            ),
        ),
    ] = None,
    context_files: Annotated[
        list[Path],
        typer.Option(
            ...,
            "--context",
            "-c",
            help=(
                "Path to a file containing context text. "
                "Can provide multiple times for multiple files."
            ),
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ] = [],
) -> None:
    """Create a new bot."""
    args: dict[str, Any] = {}
    if base_dir is not None:
        args |= {"base_dir": base_dir}
    settings = OpenAISettings(**args)
    Bot.create(name, settings.bot_dir, context_files=context_files)


@app.command("remove")
def remove(
    name: Annotated[
        str,
        typer.Argument(help="Name of bot to remove."),
    ],
    base_dir: Annotated[
        Optional[Path],
        typer.Option(
            ...,
            "--base-dir",
            "-d",
            help=(
                "Path to the base directory in which conversation "
                "configuration and history will be saved."
            ),
        ),
    ] = None,
) -> None:
    """Remove an existing bot."""
    args: dict[str, Any] = {}
    if base_dir is not None:
        args |= {"base_dir": base_dir}
    settings = OpenAISettings(**args)
    Bot.remove(name, settings.bot_dir)
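
As the help text notes, `--context`/`-c` may be given multiple times,
and `--base-dir`/`-d` relocates where the bot is stored. A hypothetical
invocation combining both (file names are placeholders):

```
llm bot create "My Bot" -c persona.md -c style.md -d /tmp/llm-chat
```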

src/llm_chat/cli/main.py

@@ -7,10 +7,12 @@ from rich.console import Console
 from rich.markdown import Markdown
 
 from llm_chat.chat import ChatProtocol, get_chat, get_chat_class
+from llm_chat.cli import bot
 from llm_chat.models import Message, Role
 from llm_chat.settings import Model, OpenAISettings
 
 app = typer.Typer()
+app.add_typer(bot.app, name="bot", help="Manage custom bots.")
 
 
 def prompt_continuation(width: int, *args: Any) -> str:
@@ -73,6 +75,8 @@ def run_conversation(current_chat: ChatProtocol) -> None:
     )
     if current_chat.name:
         console.print(f"[bold green]Name:[/bold green] {current_chat.name}")
+    if current_chat.bot:
+        console.print(f"[bold green]Bot:[/bold green] {current_chat.bot}")
 
     while not finished:
         prompt = read_user_input(session)
@@ -117,7 +121,7 @@ def chat(
             "-c",
             help=(
                 "Path to a file containing context text. "
-                "Can provide multiple time for multiple files."
+                "Can provide multiple times for multiple files."
             ),
             exists=True,
             file_okay=True,
@@ -125,13 +129,16 @@ def chat(
             readable=True,
         ),
     ] = [],
-    history_dir: Annotated[
+    base_dir: Annotated[
         Optional[Path],
         typer.Option(
             ...,
-            "--history-dir",
+            "--base-dir",
             "-d",
-            help="Path to the directory where conversation history will be saved.",
+            help=(
+                "Path to the base directory in which conversation "
+                "configuration and history will be saved."
+            ),
         ),
     ] = None,
     name: Annotated[
@@ -143,6 +150,12 @@ def chat(
             help="Name of the chat.",
         ),
     ] = "",
+    bot: Annotated[
+        str,
+        typer.Option(
+            ..., "--bot", "-b", help="Name of bot with whom you want to chat."
+        ),
+    ] = "",
 ) -> None:
     """Start a chat session."""
     # TODO: Add option to provide context string as an argument.
@@ -153,13 +166,15 @@ def chat(
         args |= {"model": model}
     if temperature is not None:
         args |= {"temperature": temperature}
-    if history_dir is not None:
-        args |= {"history_dir": history_dir}
+    if base_dir is not None:
+        args |= {"base_dir": base_dir}
     settings = OpenAISettings(**args)
 
     context_messages = [load_context(path) for path in context]
-    current_chat = get_chat(settings=settings, context=context_messages, name=name)
+    current_chat = get_chat(
+        settings=settings, context=context_messages, name=name, bot=bot
+    )
     run_conversation(current_chat)
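
A session can now combine a saved bot with ad-hoc context files, e.g.
(`notes.md` is a placeholder):

```
llm chat -b "My Bot" -c notes.md
```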

src/llm_chat/models.py

@@ -32,6 +32,7 @@ class Conversation(BaseModel):
     model: Model
     temperature: float = DEFAULT_TEMPERATURE
     name: str = ""
+    bot: str = ""
     completion_tokens: int = 0
     prompt_tokens: int = 0
     cost: float = 0.0

src/llm_chat/settings.py

@@ -15,7 +15,9 @@ class Model(StrEnum):
 DEFAULT_MODEL = Model.GPT3
 DEFAULT_TEMPERATURE = 0.7
-DEFAULT_HISTORY_DIR = Path.home() / ".llm_chat" / "history"
+DEFAULT_BASE_DIR = Path.home() / ".llm-chat"
+DEFAULT_BOT_PATH = "bots"
+DEFAULT_HISTORY_PATH = "history"
 
 
 class OpenAISettings(BaseSettings):
@@ -24,7 +26,9 @@ class OpenAISettings(BaseSettings):
     api_key: str = ""
     model: Model = DEFAULT_MODEL
     temperature: float = DEFAULT_TEMPERATURE
-    history_dir: Path = DEFAULT_HISTORY_DIR
+    base_dir: Path = DEFAULT_BASE_DIR
+    bot_path: str = DEFAULT_BOT_PATH
+    history_path: str = DEFAULT_HISTORY_PATH
 
     model_config: SettingsConfigDict = SettingsConfigDict(  # type: ignore[misc]
         env_file=".env",
@@ -34,9 +38,25 @@ class OpenAISettings(BaseSettings):
         use_enum_values=True,
     )
 
-    @field_validator("history_dir")
-    def history_dir_must_exist(cls, history_dir: Path) -> Path:
-        """Ensure that the history directory exists."""
-        if not history_dir.exists():
-            history_dir.mkdir(parents=True)
-        return history_dir
+    @field_validator("base_dir")
+    def base_dir_must_exist(cls, base_dir: Path) -> Path:
+        """Ensure that the base directory exists."""
+        if not base_dir.exists():
+            base_dir.mkdir(parents=True)
+        return base_dir
+
+    @property
+    def bot_dir(self) -> Path:
+        """Return bot directory Path object, creating if required."""
+        path = self.base_dir / self.bot_path
+        if not path.exists():
+            path.mkdir()
+        return path
+
+    @property
+    def history_dir(self) -> Path:
+        """Return history directory Path object, creating if required."""
+        path = self.base_dir / self.history_path
+        if not path.exists():
+            path.mkdir()
+        return path
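
With these properties, callers only ever configure `base_dir`; the bot
and history directories are derived from it and created on first
access. A small sketch of the behaviour (default paths in comments):

```
from llm_chat.settings import OpenAISettings

settings = OpenAISettings()  # base_dir defaults to ~/.llm-chat
settings.bot_dir             # ~/.llm-chat/bots, created if missing
settings.history_dir         # ~/.llm-chat/history, created if missing

# Overriding the base relocates both derived directories.
custom = OpenAISettings(base_dir="/tmp/llm-chat")
assert str(custom.bot_dir) == "/tmp/llm-chat/bots"
```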

src/llm_chat/utils.py Normal file

@@ -0,0 +1,9 @@
import re


def kebab_case(string: str) -> str:
    """Convert a string to kebab case."""
    string = string.replace("-", " ")
    string = re.sub("([A-Z][a-z]+)", r" \1", string)
    string = re.sub("([A-Z]+)", r" \1", string)
    return "-".join(string.split()).lower()

tests/conftest.py

@@ -11,6 +11,6 @@ def mock_openai_api_key() -> None:
 
 @pytest.fixture(autouse=True)
-def mock_history_dir(tmp_path: Path) -> None:
+def mock_base_dir(tmp_path: Path) -> None:
     """Set a fake base directory."""
-    os.environ["OPENAI_HISTORY_DIR"] = str(tmp_path / ".llm_chat")
+    os.environ["OPENAI_BASE_DIR"] = str(tmp_path / ".llm_chat")

tests/test_bot.py Normal file

@@ -0,0 +1,64 @@
from pathlib import Path

import pytest

from llm_chat.bot import Bot, BotConfig, BotDoesNotExist, BotExists


def test_create_load_remove_bot(tmp_path: Path) -> None:
    bot_name = "Test Bot"
    bot_id = "test-bot"

    with (tmp_path / "context.md").open("w") as f:
        f.write("Hello, world!")

    assert not (tmp_path / bot_id).exists()

    Bot.create(
        name=bot_name,
        bot_dir=tmp_path,
        context_files=[tmp_path / "context.md"],
    )

    assert (tmp_path / bot_id).exists()
    assert (tmp_path / bot_id / "context").exists()
    assert (tmp_path / bot_id / "context" / "context.md").exists()
    assert (tmp_path / bot_id / f"{bot_id}.json").exists()

    with (tmp_path / bot_id / f"{bot_id}.json").open() as f:
        config = BotConfig.model_validate_json(f.read(), strict=True)

    assert config.name == bot_name
    assert config.bot_id == bot_id
    assert config.context_files == ["context.md"]

    with (tmp_path / bot_id / "context" / "context.md").open() as f:
        assert f.read() == "Hello, world!"

    bot = Bot.load(name=bot_name, bot_dir=tmp_path)

    assert bot.config == config
    assert bot.id == bot_id
    assert bot.name == bot_name

    Bot.remove(name=bot_name, bot_dir=tmp_path)

    assert not (tmp_path / bot_id).exists()


def test_bot_does_not_exist(tmp_path: Path) -> None:
    with pytest.raises(BotDoesNotExist):
        Bot.load(name="Test Bot", bot_dir=tmp_path)


def test_bot_already_exists(tmp_path: Path) -> None:
    bot_name = "Test Bot"
    Bot.create(
        name=bot_name,
        bot_dir=tmp_path,
    )
    with pytest.raises(BotExists):
        Bot.create(
            name="Test Bot",
            bot_dir=tmp_path,
        )

tests/test_chat.py

@@ -8,6 +8,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
 from openai.types.completion_usage import CompletionUsage
 
+from llm_chat.bot import Bot
 from llm_chat.chat import Chat, save_conversation
 from llm_chat.models import Conversation, Message, Role
 from llm_chat.settings import Model, OpenAISettings
@@ -70,19 +71,21 @@ def test_load(tmp_path: Path) -> None:
     )
 
     # Save the conversation to a file
-    file_path = tmp_path / "conversation.json"
+    history_dir = tmp_path / "history"
+    history_dir.mkdir()
+    file_path = history_dir / "conversation.json"
     with file_path.open("w") as f:
         f.write(conversation.model_dump_json())
 
     # Load the conversation from the file
-    loaded_chat = Chat.load(file_path, api_key="foo", history_dir=tmp_path)
+    loaded_chat = Chat.load(file_path, api_key="foo", base_dir=tmp_path)
 
     # Check that the loaded conversation matches the original conversation
     assert loaded_chat.settings.model == conversation.model
     assert loaded_chat.settings.temperature == conversation.temperature
     assert loaded_chat.conversation.messages == conversation.messages
     assert loaded_chat.settings.api_key == "foo"
-    assert loaded_chat.settings.history_dir == tmp_path
+    assert loaded_chat.settings.base_dir == tmp_path
 
     # We don't want to load the tokens or cost from the previous session
     assert loaded_chat.conversation.completion_tokens == 0
@@ -143,3 +146,19 @@ def test_calculate_cost(model: Model, cost: float) -> None:
     conversation = Chat(settings=settings)
     _ = conversation.send_message("Hello")
     assert conversation.cost == cost
+
+
+def test_chat_with_bot(tmp_path: Path) -> None:
+    settings = OpenAISettings()
+    bot_name = "Test Bot"
+    context = "Hello, world!"
+
+    with (tmp_path / "context.md").open("w") as f:
+        f.write(context)
+
+    Bot.create(
+        name=bot_name, bot_dir=settings.bot_dir, context_files=[tmp_path / "context.md"]
+    )
+
+    chat = Chat(settings=settings, bot=bot_name)
+
+    assert chat.conversation.messages[-1].content == context

tests/test_cli.py

@@ -9,9 +9,9 @@ from pytest import MonkeyPatch
 from rich.console import Console
 from typer.testing import CliRunner
 
-import llm_chat
+import llm_chat.cli
 from llm_chat.chat import ChatProtocol
-from llm_chat.cli import app
+from llm_chat.cli.main import app
 from llm_chat.models import Conversation, Message, Role
 from llm_chat.settings import Model, OpenAISettings
@@ -37,6 +37,11 @@ class ChatFake:
     def _set_args(self, **kwargs: Any) -> None:
         self.args = kwargs
 
+    @property
+    def bot(self) -> str:
+        """Get the name of the bot the conversation is with."""
+        return self.args.get("bot", "")
+
     @property
     def cost(self) -> float:
         """Get the cost of the conversation."""
@@ -77,9 +82,9 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
     mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
 
-    monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
-    monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
-    monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
+    monkeypatch.setattr(llm_chat.cli.main, "get_chat", mock_get_chat)
+    monkeypatch.setattr(llm_chat.cli.main, "get_console", mock_get_console)
+    monkeypatch.setattr(llm_chat.cli.main, "read_user_input", mock_read_user_input)
 
     result = runner.invoke(app, ["chat"])
     assert result.exit_code == 0
@@ -107,9 +112,9 @@ def test_chat_with_context(
     mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
 
-    monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
-    monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
-    monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
+    monkeypatch.setattr(llm_chat.cli.main, "get_chat", mock_get_chat)
+    monkeypatch.setattr(llm_chat.cli.main, "get_console", mock_get_console)
+    monkeypatch.setattr(llm_chat.cli.main, "read_user_input", mock_read_user_input)
 
     result = runner.invoke(app, ["chat", argument, str(tmp_file)])
     assert result.exit_code == 0
@@ -139,9 +144,9 @@ def test_chat_with_name(
     mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
 
-    monkeypatch.setattr(llm_chat.cli, "get_chat", mock_get_chat)
-    monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
-    monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
+    monkeypatch.setattr(llm_chat.cli.main, "get_chat", mock_get_chat)
+    monkeypatch.setattr(llm_chat.cli.main, "get_console", mock_get_console)
+    monkeypatch.setattr(llm_chat.cli.main, "read_user_input", mock_read_user_input)
 
     result = runner.invoke(app, ["chat", argument, name])
     assert result.exit_code == 0
@@ -179,9 +184,9 @@ def test_load(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
     mock_read_user_input = MagicMock(side_effect=["Hello", "/q"])
 
-    monkeypatch.setattr(llm_chat.cli, "get_chat_class", mock_get_chat)
-    monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
-    monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
+    monkeypatch.setattr(llm_chat.cli.main, "get_chat_class", mock_get_chat)
+    monkeypatch.setattr(llm_chat.cli.main, "get_console", mock_get_console)
+    monkeypatch.setattr(llm_chat.cli.main, "read_user_input", mock_read_user_input)
 
     # Load the conversation from the file
     result = runner.invoke(app, ["load", str(file_path)])

tests/test_utils.py Normal file

@@ -0,0 +1,22 @@
import pytest

from llm_chat.utils import kebab_case


@pytest.mark.parametrize(
    "string,expected",
    [
        ("fooBar", "foo-bar"),
        ("FooBar", "foo-bar"),
        ("Foo Bar", "foo-bar"),
        ("1Foo2Bar3", "1-foo2-bar3"),
    ],
    ids=[
        "fooBar",
        "FooBar",
        "Foo Bar",
        "1Foo2Bar3",
    ],
)
def test_kebab_case(string: str, expected: str) -> None:
    assert kebab_case(string) == expected