Generalise base path for config and chat history
parent 76bee1aed9
commit 2868c78184

@@ -119,7 +119,7 @@ class Chat:

     @classmethod
     def load(
-        cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
+        cls, path: Path, api_key: str | None = None, base_dir: Path | None = None
     ) -> ChatProtocol:
         """Load a chat from a file."""
         with path.open() as f:

@@ -130,8 +130,8 @@ class Chat:
         }
         if api_key is not None:
             args["api_key"] = api_key
-        if history_dir is not None:
-            args["history_dir"] = history_dir
+        if base_dir is not None:
+            args["base_dir"] = base_dir

         settings = OpenAISettings(**args)
         return cls(
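
As a quick illustration of the renamed parameter, here is a minimal sketch of loading a saved conversation with the new signature. The import of Chat, the API key, and the file locations are assumptions for the example and are not part of this diff.

    from pathlib import Path

    # Chat is assumed to be importable from the project's chat module.
    base_dir = Path.home() / ".llm_chat"                # matches DEFAULT_BASE_DIR below
    saved = base_dir / "history" / "conversation.json"  # a previously saved conversation

    # history_dir= has been replaced by base_dir=; history now lives under base_dir.
    chat = Chat.load(saved, api_key="sk-example", base_dir=base_dir)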

@@ -125,13 +125,16 @@ def chat(
             readable=True,
         ),
     ] = [],
-    history_dir: Annotated[
+    base_dir: Annotated[
         Optional[Path],
         typer.Option(
             ...,
-            "--history-dir",
+            "--base-dir",
             "-d",
-            help="Path to the directory where conversation history will be saved.",
+            help=(
+                "Path to the base directory in which conversation "
+                "configuration and history will be saved."
+            ),
         ),
     ] = None,
     name: Annotated[

@@ -153,8 +156,8 @@ def chat(
         args |= {"model": model}
     if temperature is not None:
         args |= {"temperature": temperature}
-    if history_dir is not None:
-        args |= {"history_dir": history_dir}
+    if base_dir is not None:
+        args |= {"base_dir": base_dir}
     settings = OpenAISettings(**args)

     context_messages = [load_context(path) for path in context]
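
For reference, a minimal sketch of how the value collected for the renamed --base-dir/-d option can flow into the settings object. The concrete path and the dict-building shown here are illustrative only, with OpenAISettings assumed to be importable from the project's settings module.

    from pathlib import Path
    from typing import Optional

    base_dir: Optional[Path] = Path("/tmp/llm_chat_demo")  # stand-in for the --base-dir value

    args: dict = {}
    if base_dir is not None:
        # The settings field is base_dir; history_dir is now derived from it
        # (see the OpenAISettings changes below).
        args |= {"base_dir": base_dir}
    settings = OpenAISettings(**args)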

@@ -15,7 +15,9 @@ class Model(StrEnum):

 DEFAULT_MODEL = Model.GPT3
 DEFAULT_TEMPERATURE = 0.7
-DEFAULT_HISTORY_DIR = Path.home() / ".llm_chat" / "history"
+DEFAULT_BASE_DIR = Path.home() / ".llm_chat"
+DEFAULT_BOT_PATH = "bots"
+DEFAULT_HISTORY_PATH = "history"


 class OpenAISettings(BaseSettings):

@@ -24,7 +26,9 @@ class OpenAISettings(BaseSettings):
     api_key: str = ""
     model: Model = DEFAULT_MODEL
     temperature: float = DEFAULT_TEMPERATURE
-    history_dir: Path = DEFAULT_HISTORY_DIR
+    base_dir: Path = DEFAULT_BASE_DIR
+    bot_path: str = DEFAULT_BOT_PATH
+    history_path: str = DEFAULT_HISTORY_PATH

     model_config: SettingsConfigDict = SettingsConfigDict(  # type: ignore[misc]
         env_file=".env",

@@ -34,9 +38,25 @@ class OpenAISettings(BaseSettings):
         use_enum_values=True,
     )

-    @field_validator("history_dir")
-    def history_dir_must_exist(cls, history_dir: Path) -> Path:
-        """Ensure that the history directory exists."""
-        if not history_dir.exists():
-            history_dir.mkdir(parents=True)
-        return history_dir
+    @field_validator("base_dir")
+    def base_dir_must_exist(cls, base_dir: Path) -> Path:
+        """Ensure that the base directory exists."""
+        if not base_dir.exists():
+            base_dir.mkdir(parents=True)
+        return base_dir
+
+    @property
+    def bot_dir(self) -> Path:
+        """Return bot directory Path object, creating if required."""
+        path = self.base_dir / self.bot_path
+        if not path.exists():
+            path.mkdir()
+        return path
+
+    @property
+    def history_dir(self) -> Path:
+        """Return history directory Path object, creating if required."""
+        path = self.base_dir / self.history_path
+        if not path.exists():
+            path.mkdir()
+        return path
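
To show the generalised layout in use, here is a short sketch, assuming OpenAISettings can be imported from the project's settings module (the module path is not shown in this diff) and using a throwaway base directory.

    from pathlib import Path

    # Only the base directory is configured; the bot and history directories are
    # derived from base_dir and created on first access by the properties above.
    settings = OpenAISettings(base_dir=Path("/tmp/llm_chat_demo"))

    assert settings.bot_dir == settings.base_dir / "bots"         # DEFAULT_BOT_PATH
    assert settings.history_dir == settings.base_dir / "history"  # DEFAULT_HISTORY_PATH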

@@ -70,19 +70,21 @@ def test_load(tmp_path: Path) -> None:
     )

     # Save the conversation to a file
-    file_path = tmp_path / "conversation.json"
+    history_dir = tmp_path / "history"
+    history_dir.mkdir()
+    file_path = history_dir / "conversation.json"
     with file_path.open("w") as f:
         f.write(conversation.model_dump_json())

     # Load the conversation from the file
-    loaded_chat = Chat.load(file_path, api_key="foo", history_dir=tmp_path)
+    loaded_chat = Chat.load(file_path, api_key="foo", base_dir=tmp_path)

     # Check that the loaded conversation matches the original conversation
     assert loaded_chat.settings.model == conversation.model
     assert loaded_chat.settings.temperature == conversation.temperature
     assert loaded_chat.conversation.messages == conversation.messages
     assert loaded_chat.settings.api_key == "foo"
-    assert loaded_chat.settings.history_dir == tmp_path
+    assert loaded_chat.settings.base_dir == tmp_path

     # We don't want to load the tokens or cost from the previous session
     assert loaded_chat.conversation.completion_tokens == 0