Load previous conversation

This commit is contained in:
Paul Harrison 2023-08-24 18:35:27 +01:00
parent 8c17d4165a
commit 0e59dcff22
4 changed files with 138 additions and 20 deletions

View File

@ -80,15 +80,41 @@ class Chat:
self,
settings: OpenAISettings | None = None,
context: list[Message] = [],
initial_system_messages: bool = True,
) -> None:
self._settings = settings
self.conversation = Conversation(
messages=INITIAL_SYSTEM_MESSAGES + context,
messages=INITIAL_SYSTEM_MESSAGES + context
if initial_system_messages
else context,
model=self.settings.model,
temperature=self.settings.temperature,
)
self._start_time = datetime.now(tz=ZoneInfo("UTC"))
@classmethod
def load(
    cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
) -> ChatProtocol:
    """Load a chat from a file.

    Settings are rebuilt from the stored model and temperature; the API
    key and history directory may be overridden by the caller.  The saved
    messages become the new chat's context, and the standard initial
    system messages are suppressed so they are not duplicated.
    """
    conversation = Conversation.model_validate_json(path.read_text())
    overrides = {
        "model": conversation.model,
        "temperature": conversation.temperature,
    }
    # Only forward optional arguments that were actually supplied so the
    # settings defaults (e.g. reading OPENAI_API_KEY) still apply.
    if api_key is not None:
        overrides["api_key"] = api_key
    if history_dir is not None:
        overrides["history_dir"] = history_dir
    return cls(
        settings=OpenAISettings(**overrides),
        context=conversation.messages,
        initial_system_messages=False,
    )
@property
def settings(self) -> OpenAISettings:
"""Get OpenAI chat settings."""

View File

@ -6,7 +6,7 @@ from prompt_toolkit import PromptSession
from rich.console import Console
from rich.markdown import Markdown
from llm_chat.chat import ChatProtocol, get_chat
from llm_chat.chat import Chat, ChatProtocol, get_chat
from llm_chat.models import Message, Role
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
@ -60,7 +60,28 @@ def display_cost(console: Console, chat: ChatProtocol) -> None:
console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n")
@app.command()
def run_conversation(current_chat: ChatProtocol) -> None:
    """Drive an interactive prompt/response loop until the user quits.

    Prints the active model and temperature, then repeatedly reads user
    input.  Entering ``/q`` saves the conversation and exits; any other
    input is sent to the chat and the rendered response plus running cost
    are displayed.
    """
    console = get_console()
    session = get_session()
    console.print(f"[bold green]Model:[/bold green] {current_chat.settings.model}")
    console.print(
        f"[bold green]Temperature:[/bold green] {current_chat.settings.temperature}"
    )
    while True:
        prompt = read_user_input(session).strip()
        if prompt == "/q":
            current_chat.save()
            break
        response = current_chat.send_message(prompt)
        console.print(Markdown(response))
        display_cost(console, current_chat)
@app.command("chat")
def chat(
api_key: Annotated[
Optional[str],
@ -102,7 +123,6 @@ def chat(
] = [],
) -> None:
"""Start a chat session."""
# TODO: Add option to load context from file.
# TODO: Add option to provide context string as an argument.
if api_key is not None:
settings = OpenAISettings(api_key=api_key, model=model, temperature=temperature)
@ -112,20 +132,43 @@ def chat(
context_messages = [load_context(path) for path in context]
current_chat = get_chat(settings=settings, context=context_messages)
console = get_console()
session = get_session()
finished = False
run_conversation(current_chat)
console.print(f"[bold green]Model:[/bold green] {settings.model}")
console.print(f"[bold green]Temperature:[/bold green] {settings.temperature}")
while not finished:
prompt = read_user_input(session)
if prompt.strip() == "/q":
finished = True
current_chat.save()
@app.command("load")
def load(
    path: Annotated[
        Path,
        typer.Argument(
            help="Path to a conversation file.",
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
    api_key: Annotated[
        Optional[str],
        typer.Option(
            ...,
            "--api-key",
            "-k",
            help=(
                "API key. Will read from the environment variable OPENAI_API_KEY "
                "if not provided."
            ),
        ),
    ] = None,
) -> None:
    """Load a conversation from a file and resume it interactively.

    Args:
        path: Path to a previously saved conversation JSON file.
        api_key: Optional API key override; when omitted, Chat.load falls
            back to the OPENAI_API_KEY environment variable.
    """
    # Only pass api_key through when explicitly provided so the settings
    # default (environment variable lookup) still applies.
    if api_key is not None:
        current_chat = Chat.load(path, api_key=api_key)
    else:
        current_chat = Chat.load(path)
    run_conversation(current_chat)
# Entry point: invoke the Typer CLI app when run directly as a script.
if __name__ == "__main__":
    app()

View File

@ -46,6 +46,42 @@ def test_save_conversation(tmp_path: Path) -> None:
assert conversation == conversation_from_file
def test_load(tmp_path: Path) -> None:
    """Chat.load restores messages and settings but resets usage counters."""
    original = Conversation(
        messages=[
            Message(role=Role.SYSTEM, content="Hello!"),
            Message(role=Role.USER, content="Hi!"),
            Message(role=Role.ASSISTANT, content="How are you?"),
        ],
        model=Model.GPT3,
        temperature=0.5,
        completion_tokens=10,
        prompt_tokens=15,
        cost=0.000043,
    )

    # Persist the conversation, then load it back through Chat.load.
    saved_file = tmp_path / "conversation.json"
    saved_file.write_text(original.model_dump_json())
    loaded_chat = Chat.load(saved_file, api_key="foo", history_dir=tmp_path)

    # Messages and settings round-trip, with caller overrides applied.
    assert loaded_chat.settings.model == original.model
    assert loaded_chat.settings.temperature == original.temperature
    assert loaded_chat.conversation.messages == original.messages
    assert loaded_chat.settings.api_key == "foo"
    assert loaded_chat.settings.history_dir == tmp_path
    # Tokens and cost belong to the previous session and must start at zero.
    assert loaded_chat.conversation.completion_tokens == 0
    assert loaded_chat.conversation.prompt_tokens == 0
    assert loaded_chat.cost == 0
def test_send_message() -> None:
with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
mock_make_request.return_value = {

View File

@ -12,6 +12,7 @@ import llm_chat
from llm_chat.chat import ChatProtocol
from llm_chat.cli import app
from llm_chat.models import Message, Role
from llm_chat.settings import OpenAISettings
runner = CliRunner()
@ -21,8 +22,13 @@ class ChatFake:
args: dict[str, Any]
received_messages: list[str]
settings: OpenAISettings
def __init__(self, settings: OpenAISettings | None = None) -> None:
    """Initialise the fake chat.

    Args:
        settings: Optional settings object; a default OpenAISettings is
            constructed when none is supplied.
    """
    self.settings = settings if settings is not None else OpenAISettings()
    self.args = {}
    self.received_messages = []
@ -34,6 +40,13 @@ class ChatFake:
"""Get the cost of the conversation."""
return 0.0
@classmethod
def load(
    cls, path: Path, api_key: str | None = None, history_dir: Path | None = None
) -> ChatProtocol:
    """Fake loader: ignores every argument and returns a fresh instance."""
    del path, api_key, history_dir  # accepted for interface parity, unused
    return cls()
def save(self) -> None:
    """No-op persistence hook for the fake chat."""
    return None
@ -61,7 +74,7 @@ def test_chat(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr(llm_chat.cli, "get_console", mock_get_console)
monkeypatch.setattr(llm_chat.cli, "read_user_input", mock_read_user_input)
result = runner.invoke(app)
result = runner.invoke(app, ["chat"])
assert result.exit_code == 0
assert chat_fake.received_messages == ["Hello"]