Fix failing cost calculation test

When updating the model pricing in commit 50fa0cc5ae I forgot to update
the associated test. As well as fixing this test, this commit also
updates the calculation to use `math.floor` instead of `round` to round
to six decimal places. This is because the `round` function appeared to
round incorrectly. For example, when running the test, 0.0000275 was
rounded to 0.000028 instead of the expected 0.000028.
This commit is contained in:
Paul Harrison 2024-02-23 11:12:53 +00:00
parent aade152486
commit b859c8bb95
3 changed files with 16 additions and 8 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "llm-chat"
version = "1.1.2"
version = "1.1.3"
description = "A general CLI interface for large language models."
authors = ["Paul Harrison <paul@harrison.sh>"]
readme = "README.md"

View File

@ -1,3 +1,4 @@
import math
from datetime import datetime
from enum import StrEnum, auto
from pathlib import Path
@ -177,12 +178,17 @@ class Chat:
"""Calculate the cost of a request."""
self.conversation.completion_tokens += usage.completion_tokens
self.conversation.prompt_tokens += usage.prompt_tokens
self.conversation.cost = round(
(self.conversation.completion_tokens / 1000)
* self._pricing[self.settings.model][Token.COMPLETION]
+ (self.conversation.prompt_tokens / 1000)
* self._pricing[self.settings.model][Token.PROMPT],
6,
self.conversation.cost = (
math.floor(
1000000
* (
(self.conversation.completion_tokens / 1000)
* self._pricing[self.settings.model][Token.COMPLETION]
+ (self.conversation.prompt_tokens / 1000)
* self._pricing[self.settings.model][Token.PROMPT]
)
)
/ 1000000
)
def save(self) -> None:

View File

@ -116,7 +116,9 @@ def test_send_message() -> None:
assert response == "Hello!"
@pytest.mark.parametrize("model,cost", [(Model.GPT3, 0.000043), (Model.GPT4, 0.00105)])
@pytest.mark.parametrize(
"model,cost", [(Model.GPT3, round(0.000027, 6)), (Model.GPT4, 0.00105)]
)
def test_calculate_cost(model: Model, cost: float) -> None:
with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
mock_make_request.return_value = ChatCompletion(