Generalise base path for config and chat history

This commit is contained in:
Paul Harrison 2024-02-23 11:49:48 +00:00
parent 76bee1aed9
commit 2868c78184
4 changed files with 44 additions and 19 deletions

View File

@ -119,7 +119,7 @@ class Chat:
@classmethod
def load(
cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
cls, path: Path, api_key: str | None = None, base_dir: Path | None = None
) -> ChatProtocol:
"""Load a chat from a file."""
with path.open() as f:
@ -130,8 +130,8 @@ class Chat:
}
if api_key is not None:
args["api_key"] = api_key
if history_dir is not None:
args["history_dir"] = history_dir
if base_dir is not None:
args["base_dir"] = base_dir
settings = OpenAISettings(**args)
return cls(

View File

@ -125,13 +125,16 @@ def chat(
readable=True,
),
] = [],
history_dir: Annotated[
base_dir: Annotated[
Optional[Path],
typer.Option(
...,
"--history-dir",
"--base-dir",
"-d",
help="Path to the directory where conversation history will be saved.",
help=(
"Path to the base directory in which conversation "
"configuration and history will be saved."
),
),
] = None,
name: Annotated[
@ -153,8 +156,8 @@ def chat(
args |= {"model": model}
if temperature is not None:
args |= {"temperature": temperature}
if history_dir is not None:
args |= {"history_dir": history_dir}
if base_dir is not None:
args |= {"history_dir": base_dir}
settings = OpenAISettings(**args)
context_messages = [load_context(path) for path in context]

View File

@ -15,7 +15,9 @@ class Model(StrEnum):
DEFAULT_MODEL = Model.GPT3
DEFAULT_TEMPERATURE = 0.7
DEFAULT_HISTORY_DIR = Path.home() / ".llm_chat" / "history"
DEFAULT_BASE_DIR = Path.home() / ".llm_chat"
DEFAULT_BOT_PATH = "bots"
DEFAULT_HISTORY_PATH = "history"
class OpenAISettings(BaseSettings):
@ -24,7 +26,9 @@ class OpenAISettings(BaseSettings):
api_key: str = ""
model: Model = DEFAULT_MODEL
temperature: float = DEFAULT_TEMPERATURE
history_dir: Path = DEFAULT_HISTORY_DIR
base_dir: Path = DEFAULT_BASE_DIR
bot_path: str = DEFAULT_BOT_PATH
history_path: str = DEFAULT_HISTORY_PATH
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
env_file=".env",
@ -34,9 +38,25 @@ class OpenAISettings(BaseSettings):
use_enum_values=True,
)
@field_validator("history_dir")
def history_dir_must_exist(cls, history_dir: Path) -> Path:
"""Ensure that the history directory exists."""
if not history_dir.exists():
history_dir.mkdir(parents=True)
return history_dir
@field_validator("base_dir")
def base_dir_must_exist(cls, base_dir: Path) -> Path:
"""Ensure that the base directory exists."""
if not base_dir.exists():
base_dir.mkdir(parents=True)
return base_dir
@property
def bot_dir(self) -> Path:
"""Return bot directory Path object, creating if required."""
path = self.base_dir / self.bot_path
if not path.exists():
path.mkdir()
return path
@property
def history_dir(self) -> Path:
"""Return history directory Path object, creating if required."""
path = self.base_dir / self.history_path
if not path.exists():
path.mkdir()
return path

View File

@ -70,19 +70,21 @@ def test_load(tmp_path: Path) -> None:
)
# Save the conversation to a file
file_path = tmp_path / "conversation.json"
history_dir = tmp_path / "history"
history_dir.mkdir()
file_path = history_dir / "conversation.json"
with file_path.open("w") as f:
f.write(conversation.model_dump_json())
# Load the conversation from the file
loaded_chat = Chat.load(file_path, api_key="foo", history_dir=tmp_path)
loaded_chat = Chat.load(file_path, api_key="foo", base_dir=tmp_path)
# Check that the loaded conversation matches the original conversation
assert loaded_chat.settings.model == conversation.model
assert loaded_chat.settings.temperature == conversation.temperature
assert loaded_chat.conversation.messages == conversation.messages
assert loaded_chat.settings.api_key == "foo"
assert loaded_chat.settings.history_dir == tmp_path
assert loaded_chat.settings.base_dir == tmp_path
# We don't want to load the tokens or cost from the previous session
assert loaded_chat.conversation.completion_tokens == 0