# llm-chat/tests/test_chat.py
#
# 165 lines
# 5.2 KiB
# Python

from datetime import datetime
from pathlib import Path
from unittest.mock import patch
from zoneinfo import ZoneInfo
import pytest
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from openai.types.completion_usage import CompletionUsage
from llm_chat.bot import Bot
from llm_chat.chat import Chat, save_conversation
from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import Model, OpenAISettings
@pytest.mark.parametrize(
    ("name", "expected_filename"),
    [("", "20210101120000.json"), ("foo", "20210101120000-foo.json")],
)
def test_save_conversation(name: str, expected_filename: str, tmp_path: Path) -> None:
    """Saving a conversation writes a round-trippable JSON file to the history dir.

    The expected file name is the save timestamp (YYYYMMDDhhmmss), followed by
    ``-<name>`` when the conversation has a non-empty name. The history
    directory does not exist beforehand and must be created by
    ``save_conversation``.
    """
    saved_at = datetime(2021, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("UTC"))
    history_dir = tmp_path / ".history"
    target = history_dir / expected_filename

    original = Conversation(
        messages=[
            Message(role=Role.SYSTEM, content="Hello!"),
            Message(role=Role.USER, content="Hi!"),
            Message(role=Role.ASSISTANT, content="How are you?"),
        ],
        model=Model.GPT3,
        temperature=0.5,
        completion_tokens=10,
        prompt_tokens=15,
        cost=0.000043,
        name=name,
    )

    # Start without the directory so we can observe save_conversation creating it.
    assert not history_dir.exists()
    save_conversation(conversation=original, history_dir=history_dir, dt=saved_at)
    assert history_dir.exists()
    assert history_dir.is_dir()
    assert target in history_dir.iterdir()

    # The written JSON must deserialize back into an equal Conversation.
    restored = Conversation.model_validate_json(target.read_text())
    assert original == restored
def test_load(tmp_path: Path) -> None:
    """``Chat.load`` restores settings and messages but resets usage counters.

    A conversation is serialized to disk by hand and loaded back through
    ``Chat.load``. Model, temperature and messages must survive the round
    trip, and the ``api_key``/``base_dir`` arguments must be applied to the
    settings. Token counts and cost must start again from zero.
    """
    saved = Conversation(
        messages=[
            Message(role=Role.SYSTEM, content="Hello!"),
            Message(role=Role.USER, content="Hi!"),
            Message(role=Role.ASSISTANT, content="How are you?"),
        ],
        model=Model.GPT3,
        temperature=0.5,
        completion_tokens=10,
        prompt_tokens=15,
        cost=0.000043,
    )

    history_dir = tmp_path / "history"
    history_dir.mkdir()
    conversation_file = history_dir / "conversation.json"
    conversation_file.write_text(saved.model_dump_json())

    chat = Chat.load(conversation_file, api_key="foo", base_dir=tmp_path)

    # Settings and messages come from the file and the load() arguments.
    assert chat.settings.model == saved.model
    assert chat.settings.temperature == saved.temperature
    assert chat.conversation.messages == saved.messages
    assert chat.settings.api_key == "foo"
    assert chat.settings.base_dir == tmp_path

    # Usage from the previous session must not carry over into the new one.
    assert chat.conversation.completion_tokens == 0
    assert chat.conversation.prompt_tokens == 0
    assert chat.cost == 0
def test_send_message() -> None:
    """``send_message`` returns the assistant's reply text as a ``str``.

    ``Chat._make_request`` is patched to return a canned completion whose
    single choice carries the reply ``"Hello!"``.
    """
    canned_completion = ChatCompletion(
        id="foo",
        object="chat.completion",
        created=0,
        model="gpt-3.5-turbo-0613",
        choices=[
            Choice(
                index=0,
                finish_reason="stop",
                message=ChatCompletionMessage(role="assistant", content="Hello!"),
            ),
        ],
        usage=CompletionUsage(
            prompt_tokens=1,
            completion_tokens=1,
            total_tokens=2,
        ),
    )

    with patch("llm_chat.chat.Chat._make_request", return_value=canned_completion):
        chat = Chat()
        reply = chat.send_message("Hello")
        assert isinstance(reply, str)
        assert reply == "Hello!"
@pytest.mark.parametrize(
    "model,cost", [(Model.GPT3, 0.000027), (Model.GPT4, 0.00105)]
)
def test_calculate_cost(model: Model, cost: float) -> None:
    """The session cost after one request matches the configured model's price.

    The mocked completion reports 15 prompt tokens and 10 completion tokens in
    every case. Only ``settings.model`` varies between cases (the response's
    ``model`` field is always ``"gpt-3.5-turbo-0613"``), so the expected cost
    is keyed on the configured model.
    """
    with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
        mock_make_request.return_value = ChatCompletion(
            choices=[
                Choice(
                    message=ChatCompletionMessage(content="Hello!", role="assistant"),
                    finish_reason="stop",
                    index=0,
                ),
            ],
            id="foo",
            created=0,
            model="gpt-3.5-turbo-0613",
            object="chat.completion",
            usage=CompletionUsage(
                completion_tokens=10,
                prompt_tokens=15,
                total_tokens=25,
            ),
        )
        settings = OpenAISettings(model=model)
        chat = Chat(settings=settings)
        _ = chat.send_message("Hello")
        # The cost is float arithmetic over token counts and per-token prices.
        # Compare with a tight relative tolerance instead of exact ``==`` so
        # harmless floating-point rounding cannot fail the test.
        assert chat.cost == pytest.approx(cost)
def test_chat_with_bot(tmp_path: Path) -> None:
    """A chat started with a bot ends its conversation with the bot's context.

    A bot is created from a single context file, and the last message of a
    new ``Chat`` using that bot must equal the file's text.
    """
    settings = OpenAISettings()
    bot_name = "Test Bot"
    context_text = "Hello, world!"
    context_path = tmp_path / "context.md"
    context_path.write_text(context_text)

    # NOTE(review): the bot is written to the default ``settings.bot_dir``,
    # not under ``tmp_path``. Confirm that directory is redirected for tests
    # (e.g. by a conftest fixture) so the suite never touches a real bot dir.
    Bot.create(name=bot_name, bot_dir=settings.bot_dir, context_files=[context_path])

    chat = Chat(settings=settings, bot=bot_name)
    assert chat.conversation.messages[-1].content == context_text