Generalise base path for config and chat history
This commit is contained in:
parent
76bee1aed9
commit
2868c78184
|
@ -119,7 +119,7 @@ class Chat:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(
|
||||||
cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
|
cls, path: Path, api_key: str | None = None, base_dir: Path | None = None
|
||||||
) -> ChatProtocol:
|
) -> ChatProtocol:
|
||||||
"""Load a chat from a file."""
|
"""Load a chat from a file."""
|
||||||
with path.open() as f:
|
with path.open() as f:
|
||||||
|
@ -130,8 +130,8 @@ class Chat:
|
||||||
}
|
}
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
args["api_key"] = api_key
|
args["api_key"] = api_key
|
||||||
if history_dir is not None:
|
if base_dir is not None:
|
||||||
args["history_dir"] = history_dir
|
args["base_dir"] = base_dir
|
||||||
|
|
||||||
settings = OpenAISettings(**args)
|
settings = OpenAISettings(**args)
|
||||||
return cls(
|
return cls(
|
||||||
|
|
|
@ -125,13 +125,16 @@ def chat(
|
||||||
readable=True,
|
readable=True,
|
||||||
),
|
),
|
||||||
] = [],
|
] = [],
|
||||||
history_dir: Annotated[
|
base_dir: Annotated[
|
||||||
Optional[Path],
|
Optional[Path],
|
||||||
typer.Option(
|
typer.Option(
|
||||||
...,
|
...,
|
||||||
"--history-dir",
|
"--base-dir",
|
||||||
"-d",
|
"-d",
|
||||||
help="Path to the directory where conversation history will be saved.",
|
help=(
|
||||||
|
"Path to the base directory in which conversation "
|
||||||
|
"configuration and history will be saved."
|
||||||
|
),
|
||||||
),
|
),
|
||||||
] = None,
|
] = None,
|
||||||
name: Annotated[
|
name: Annotated[
|
||||||
|
@ -153,8 +156,8 @@ def chat(
|
||||||
args |= {"model": model}
|
args |= {"model": model}
|
||||||
if temperature is not None:
|
if temperature is not None:
|
||||||
args |= {"temperature": temperature}
|
args |= {"temperature": temperature}
|
||||||
if history_dir is not None:
|
if base_dir is not None:
|
||||||
args |= {"history_dir": history_dir}
|
args |= {"history_dir": base_dir}
|
||||||
settings = OpenAISettings(**args)
|
settings = OpenAISettings(**args)
|
||||||
|
|
||||||
context_messages = [load_context(path) for path in context]
|
context_messages = [load_context(path) for path in context]
|
||||||
|
|
|
@ -15,7 +15,9 @@ class Model(StrEnum):
|
||||||
|
|
||||||
DEFAULT_MODEL = Model.GPT3
|
DEFAULT_MODEL = Model.GPT3
|
||||||
DEFAULT_TEMPERATURE = 0.7
|
DEFAULT_TEMPERATURE = 0.7
|
||||||
DEFAULT_HISTORY_DIR = Path.home() / ".llm_chat" / "history"
|
DEFAULT_BASE_DIR = Path.home() / ".llm_chat"
|
||||||
|
DEFAULT_BOT_PATH = "bots"
|
||||||
|
DEFAULT_HISTORY_PATH = "history"
|
||||||
|
|
||||||
|
|
||||||
class OpenAISettings(BaseSettings):
|
class OpenAISettings(BaseSettings):
|
||||||
|
@ -24,7 +26,9 @@ class OpenAISettings(BaseSettings):
|
||||||
api_key: str = ""
|
api_key: str = ""
|
||||||
model: Model = DEFAULT_MODEL
|
model: Model = DEFAULT_MODEL
|
||||||
temperature: float = DEFAULT_TEMPERATURE
|
temperature: float = DEFAULT_TEMPERATURE
|
||||||
history_dir: Path = DEFAULT_HISTORY_DIR
|
base_dir: Path = DEFAULT_BASE_DIR
|
||||||
|
bot_path: str = DEFAULT_BOT_PATH
|
||||||
|
history_path: str = DEFAULT_HISTORY_PATH
|
||||||
|
|
||||||
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
|
model_config: SettingsConfigDict = SettingsConfigDict( # type: ignore[misc]
|
||||||
env_file=".env",
|
env_file=".env",
|
||||||
|
@ -34,9 +38,25 @@ class OpenAISettings(BaseSettings):
|
||||||
use_enum_values=True,
|
use_enum_values=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("history_dir")
|
@field_validator("base_dir")
|
||||||
def history_dir_must_exist(cls, history_dir: Path) -> Path:
|
def base_dir_must_exist(cls, base_dir: Path) -> Path:
|
||||||
"""Ensure that the history directory exists."""
|
"""Ensure that the base directory exists."""
|
||||||
if not history_dir.exists():
|
if not base_dir.exists():
|
||||||
history_dir.mkdir(parents=True)
|
base_dir.mkdir(parents=True)
|
||||||
return history_dir
|
return base_dir
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bot_dir(self) -> Path:
|
||||||
|
"""Return bot directory Path object, creating if required."""
|
||||||
|
path = self.base_dir / self.bot_path
|
||||||
|
if not path.exists():
|
||||||
|
path.mkdir()
|
||||||
|
return path
|
||||||
|
|
||||||
|
@property
|
||||||
|
def history_dir(self) -> Path:
|
||||||
|
"""Return history directory Path object, creating if required."""
|
||||||
|
path = self.base_dir / self.history_path
|
||||||
|
if not path.exists():
|
||||||
|
path.mkdir()
|
||||||
|
return path
|
||||||
|
|
|
@ -70,19 +70,21 @@ def test_load(tmp_path: Path) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save the conversation to a file
|
# Save the conversation to a file
|
||||||
file_path = tmp_path / "conversation.json"
|
history_dir = tmp_path / "history"
|
||||||
|
history_dir.mkdir()
|
||||||
|
file_path = history_dir / "conversation.json"
|
||||||
with file_path.open("w") as f:
|
with file_path.open("w") as f:
|
||||||
f.write(conversation.model_dump_json())
|
f.write(conversation.model_dump_json())
|
||||||
|
|
||||||
# Load the conversation from the file
|
# Load the conversation from the file
|
||||||
loaded_chat = Chat.load(file_path, api_key="foo", history_dir=tmp_path)
|
loaded_chat = Chat.load(file_path, api_key="foo", base_dir=tmp_path)
|
||||||
|
|
||||||
# Check that the loaded conversation matches the original conversation
|
# Check that the loaded conversation matches the original conversation
|
||||||
assert loaded_chat.settings.model == conversation.model
|
assert loaded_chat.settings.model == conversation.model
|
||||||
assert loaded_chat.settings.temperature == conversation.temperature
|
assert loaded_chat.settings.temperature == conversation.temperature
|
||||||
assert loaded_chat.conversation.messages == conversation.messages
|
assert loaded_chat.conversation.messages == conversation.messages
|
||||||
assert loaded_chat.settings.api_key == "foo"
|
assert loaded_chat.settings.api_key == "foo"
|
||||||
assert loaded_chat.settings.history_dir == tmp_path
|
assert loaded_chat.settings.base_dir == tmp_path
|
||||||
|
|
||||||
# We don't want to load the tokens or cost from the previous session
|
# We don't want to load the tokens or cost from the previous session
|
||||||
assert loaded_chat.conversation.completion_tokens == 0
|
assert loaded_chat.conversation.completion_tokens == 0
|
||||||
|
|
Loading…
Reference in New Issue