148 lines
4.8 KiB
Python
148 lines
4.8 KiB
Python
from datetime import datetime
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
from zoneinfo import ZoneInfo
|
|
|
|
import pytest
|
|
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
|
from openai.types.chat.chat_completion import Choice
|
|
from openai.types.completion_usage import CompletionUsage
|
|
|
|
from llm_chat.chat import Chat, save_conversation
|
|
from llm_chat.models import Conversation, Message, Role
|
|
from llm_chat.settings import Model, OpenAISettings
|
|
|
|
|
|
@pytest.mark.parametrize(
    "name,expected_filename",
    [("", "20210101120000.json"), ("foo", "20210101120000-foo.json")],
)
def test_save_conversation(name: str, expected_filename: str, tmp_path: Path) -> None:
    """Check that ``save_conversation`` writes a reloadable JSON file.

    The history directory is absent before the call and must exist as a
    directory afterwards, containing a file named ``expected_filename``.
    Parsing that file back must yield a conversation equal to the one saved.

    Args:
        name: Conversation name; the parametrized cases cover an empty name
            and ``"foo"``.
        expected_filename: File name expected inside the history directory
            for the fixed timestamp used below.
        tmp_path: pytest-provided temporary directory.
    """
    history_dir = tmp_path / ".history"
    target_file = history_dir / expected_filename
    saved_at = datetime(2021, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("UTC"))

    transcript = [
        Message(role=Role.SYSTEM, content="Hello!"),
        Message(role=Role.USER, content="Hi!"),
        Message(role=Role.ASSISTANT, content="How are you?"),
    ]
    original = Conversation(
        messages=transcript,
        model=Model.GPT3,
        temperature=0.5,
        completion_tokens=10,
        prompt_tokens=15,
        cost=0.000043,
        name=name,
    )

    # Precondition: nothing has created the history directory yet, so its
    # existence after the call is attributable to save_conversation.
    assert not history_dir.exists()

    save_conversation(conversation=original, history_dir=history_dir, dt=saved_at)

    assert history_dir.exists()
    assert history_dir.is_dir()
    assert target_file in history_dir.iterdir()

    # Round-trip: the file content must parse back into an equal model.
    restored = Conversation.model_validate_json(target_file.read_text())
    assert original == restored
|
|
|
|
|
|
def test_load(tmp_path: Path) -> None:
    """Check that ``Chat.load`` restores a saved conversation from disk.

    A conversation is serialized to ``history/conversation.json`` under
    ``tmp_path`` and loaded back. Model, temperature and messages must match
    the saved conversation, the explicitly passed ``api_key`` and ``base_dir``
    must appear on the settings, and token counts and cost must be zero.
    """
    saved = Conversation(
        messages=[
            Message(role=Role.SYSTEM, content="Hello!"),
            Message(role=Role.USER, content="Hi!"),
            Message(role=Role.ASSISTANT, content="How are you?"),
        ],
        model=Model.GPT3,
        temperature=0.5,
        completion_tokens=10,
        prompt_tokens=15,
        cost=0.000043,
    )

    # Persist the conversation by hand so the test exercises loading only.
    history_dir = tmp_path / "history"
    history_dir.mkdir()
    file_path = history_dir / "conversation.json"
    file_path.write_text(saved.model_dump_json())

    loaded_chat = Chat.load(file_path, api_key="foo", base_dir=tmp_path)

    # Values taken from the saved conversation.
    assert loaded_chat.settings.model == saved.model
    assert loaded_chat.settings.temperature == saved.temperature
    assert loaded_chat.conversation.messages == saved.messages

    # Values taken from the arguments passed to Chat.load.
    assert loaded_chat.settings.api_key == "foo"
    assert loaded_chat.settings.base_dir == tmp_path

    # Usage counters from the previous session must not carry over.
    assert loaded_chat.conversation.completion_tokens == 0
    assert loaded_chat.conversation.prompt_tokens == 0
    assert loaded_chat.cost == 0
|
|
|
|
|
|
def test_send_message() -> None:
    """Check that ``Chat.send_message`` returns the assistant's reply text.

    ``Chat._make_request`` is patched to return a canned ``ChatCompletion``
    whose single choice carries the content ``"Hello!"``; the call must
    return exactly that string.
    """
    assistant_reply = Choice(
        message=ChatCompletionMessage(content="Hello!", role="assistant"),
        finish_reason="stop",
        index=0,
    )
    token_usage = CompletionUsage(
        completion_tokens=1,
        prompt_tokens=1,
        total_tokens=2,
    )
    canned_completion = ChatCompletion(
        choices=[assistant_reply],
        id="foo",
        created=0,
        model="gpt-3.5-turbo-0613",
        object="chat.completion",
        usage=token_usage,
    )

    # The patch must stay active while send_message runs so the canned
    # completion is returned in place of the real _make_request.
    with patch("llm_chat.chat.Chat._make_request", return_value=canned_completion):
        chat = Chat()
        reply = chat.send_message("Hello")
        assert isinstance(reply, str)
        assert reply == "Hello!"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "model,cost", [(Model.GPT3, 0.000027), (Model.GPT4, 0.00105)]
)
def test_calculate_cost(model: Model, cost: float) -> None:
    """Check that ``Chat.cost`` reflects token usage for the configured model.

    ``Chat._make_request`` is patched to return a canned completion reporting
    15 prompt tokens and 10 completion tokens. After one ``send_message``
    call, ``conversation.cost`` must match the expected value for ``model``.

    Args:
        model: Model placed in the ``OpenAISettings`` used to build the chat.
        cost: Expected accumulated cost after the single mocked exchange.
    """
    with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
        # NOTE(review): the mocked response always reports
        # "gpt-3.5-turbo-0613", even for the GPT4 case; the differing expected
        # costs suggest pricing is keyed on the settings model, not the
        # response model. Confirm against Chat.
        mock_make_request.return_value = ChatCompletion(
            choices=[
                Choice(
                    message=ChatCompletionMessage(content="Hello!", role="assistant"),
                    finish_reason="stop",
                    index=0,
                ),
            ],
            id="foo",
            created=0,
            model="gpt-3.5-turbo-0613",
            object="chat.completion",
            usage=CompletionUsage(
                completion_tokens=10,
                prompt_tokens=15,
                total_tokens=25,
            ),
        )
        settings = OpenAISettings(model=model)
        conversation = Chat(settings=settings)
        _ = conversation.send_message("Hello")
        # Cost is a computed float: compare within a tight tolerance rather
        # than bit-for-bit, so harmless changes in arithmetic order do not
        # break the test. The default tolerance (relative 1e-6) is still far
        # below the sixth decimal place these expectations are stated to.
        assert conversation.cost == pytest.approx(cost)
|