Generalise base path for config and chat history

This commit is contained in:
Paul Harrison 2024-02-23 11:49:48 +00:00
parent 76bee1aed9
commit 2868c78184
4 changed files with 44 additions and 19 deletions

View File

@ -119,7 +119,7 @@ class Chat:
@classmethod @classmethod
def load( def load(
cls, path: Path, api_key: str | None = None, history_dir: Path | None = None cls, path: Path, api_key: str | None = None, base_dir: Path | None = None
) -> ChatProtocol: ) -> ChatProtocol:
"""Load a chat from a file.""" """Load a chat from a file."""
with path.open() as f: with path.open() as f:
@ -130,8 +130,8 @@ class Chat:
} }
if api_key is not None: if api_key is not None:
args["api_key"] = api_key args["api_key"] = api_key
if history_dir is not None: if base_dir is not None:
args["history_dir"] = history_dir args["base_dir"] = base_dir
settings = OpenAISettings(**args) settings = OpenAISettings(**args)
return cls( return cls(

View File

@ -125,13 +125,16 @@ def chat(
readable=True, readable=True,
), ),
] = [], ] = [],
history_dir: Annotated[ base_dir: Annotated[
Optional[Path], Optional[Path],
typer.Option( typer.Option(
..., ...,
"--history-dir", "--base-dir",
"-d", "-d",
help="Path to the directory where conversation history will be saved.", help=(
"Path to the base directory in which conversation "
"configuration and history will be saved."
),
), ),
] = None, ] = None,
name: Annotated[ name: Annotated[
@ -153,8 +156,8 @@ def chat(
args |= {"model": model} args |= {"model": model}
if temperature is not None: if temperature is not None:
args |= {"temperature": temperature} args |= {"temperature": temperature}
if history_dir is not None: if base_dir is not None:
args |= {"history_dir": history_dir} args |= {"history_dir": base_dir}
settings = OpenAISettings(**args) settings = OpenAISettings(**args)
context_messages = [load_context(path) for path in context] context_messages = [load_context(path) for path in context]

View File

@ -15,7 +15,9 @@ class Model(StrEnum):
DEFAULT_MODEL = Model.GPT3 DEFAULT_MODEL = Model.GPT3
DEFAULT_TEMPERATURE = 0.7 DEFAULT_TEMPERATURE = 0.7
DEFAULT_HISTORY_DIR = Path.home() / ".llm_chat" / "history" DEFAULT_BASE_DIR = Path.home() / ".llm_chat"
DEFAULT_BOT_PATH = "bots"
DEFAULT_HISTORY_PATH = "history"
class OpenAISettings(BaseSettings): class OpenAISettings(BaseSettings):
@ -24,7 +26,9 @@ class OpenAISettings(BaseSettings):
api_key: str = "" api_key: str = ""
model: Model = DEFAULT_MODEL model: Model = DEFAULT_MODEL
temperature: float = DEFAULT_TEMPERATURE temperature: float = DEFAULT_TEMPERATURE
history_dir: Path = DEFAULT_HISTORY_DIR base_dir: Path = DEFAULT_BASE_DIR
bot_path: str = DEFAULT_BOT_PATH
history_path: str = DEFAULT_HISTORY_PATH
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc] model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
env_file=".env", env_file=".env",
@ -34,9 +38,25 @@ class OpenAISettings(BaseSettings):
use_enum_values=True, use_enum_values=True,
) )
@field_validator("history_dir") @field_validator("base_dir")
def history_dir_must_exist(cls, history_dir: Path) -> Path: def base_dir_must_exist(cls, base_dir: Path) -> Path:
"""Ensure that the history directory exists.""" """Ensure that the base directory exists."""
if not history_dir.exists(): if not base_dir.exists():
history_dir.mkdir(parents=True) base_dir.mkdir(parents=True)
return history_dir return base_dir
@property
def bot_dir(self) -> Path:
"""Return bot directory Path object, creating if required."""
path = self.base_dir / self.bot_path
if not path.exists():
path.mkdir()
return path
@property
def history_dir(self) -> Path:
"""Return history directory Path object, creating if required."""
path = self.base_dir / self.history_path
if not path.exists():
path.mkdir()
return path

View File

@ -70,19 +70,21 @@ def test_load(tmp_path: Path) -> None:
) )
# Save the conversation to a file # Save the conversation to a file
file_path = tmp_path / "conversation.json" history_dir = tmp_path / "history"
history_dir.mkdir()
file_path = history_dir / "conversation.json"
with file_path.open("w") as f: with file_path.open("w") as f:
f.write(conversation.model_dump_json()) f.write(conversation.model_dump_json())
# Load the conversation from the file # Load the conversation from the file
loaded_chat = Chat.load(file_path, api_key="foo", history_dir=tmp_path) loaded_chat = Chat.load(file_path, api_key="foo", base_dir=tmp_path)
# Check that the loaded conversation matches the original conversation # Check that the loaded conversation matches the original conversation
assert loaded_chat.settings.model == conversation.model assert loaded_chat.settings.model == conversation.model
assert loaded_chat.settings.temperature == conversation.temperature assert loaded_chat.settings.temperature == conversation.temperature
assert loaded_chat.conversation.messages == conversation.messages assert loaded_chat.conversation.messages == conversation.messages
assert loaded_chat.settings.api_key == "foo" assert loaded_chat.settings.api_key == "foo"
assert loaded_chat.settings.history_dir == tmp_path assert loaded_chat.settings.base_dir == tmp_path
# We don't want to load the tokens or cost from the previous session # We don't want to load the tokens or cost from the previous session
assert loaded_chat.conversation.completion_tokens == 0 assert loaded_chat.conversation.completion_tokens == 0