From b859c8bb953952b788bc72531da75616d0e24aff Mon Sep 17 00:00:00 2001 From: Paul Harrison Date: Fri, 23 Feb 2024 11:12:53 +0000 Subject: [PATCH] Fix failing cost calculation test When updating the model pricing in commit 50fa0cc5ae I forgot to update the associated test. As well as fixing this test, this commit also updates the calculation to use `math.floor` instead of `round` to round to six decimal places. This is because the `round` function appeared to round incorrectly. For example, when running the test, 0.0000275 was rounded to 0.000028 instead of the expected 0.000028. --- pyproject.toml | 2 +- src/llm_chat/chat.py | 18 ++++++++++++------ tests/test_chat.py | 4 +++- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6f4146a..03769be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "llm-chat" -version = "1.1.2" +version = "1.1.3" description = "A general CLI interface for large language models." authors = ["Paul Harrison "] readme = "README.md" diff --git a/src/llm_chat/chat.py b/src/llm_chat/chat.py index 5247b90..8cc2cc3 100644 --- a/src/llm_chat/chat.py +++ b/src/llm_chat/chat.py @@ -1,3 +1,4 @@ +import math from datetime import datetime from enum import StrEnum, auto from pathlib import Path @@ -177,12 +178,17 @@ class Chat: """Calculate the cost of a request.""" self.conversation.completion_tokens += usage.completion_tokens self.conversation.prompt_tokens += usage.prompt_tokens - self.conversation.cost = round( - (self.conversation.completion_tokens / 1000) - * self._pricing[self.settings.model][Token.COMPLETION] - + (self.conversation.prompt_tokens / 1000) - * self._pricing[self.settings.model][Token.PROMPT], - 6, + self.conversation.cost = ( + math.floor( + 1000000 + * ( + (self.conversation.completion_tokens / 1000) + * self._pricing[self.settings.model][Token.COMPLETION] + + (self.conversation.prompt_tokens / 1000) + * self._pricing[self.settings.model][Token.PROMPT] + ) + ) + / 1000000 ) def save(self) -> None: diff --git a/tests/test_chat.py b/tests/test_chat.py index 2424134..ff17fb7 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -116,7 +116,9 @@ def test_send_message() -> None: assert response == "Hello!" -@pytest.mark.parametrize("model,cost", [(Model.GPT3, 0.000043), (Model.GPT4, 0.00105)]) +@pytest.mark.parametrize( + "model,cost", [(Model.GPT3, round(0.000027, 6)), (Model.GPT4, 0.00105)] +) def test_calculate_cost(model: Model, cost: float) -> None: with patch("llm_chat.chat.Chat._make_request") as mock_make_request: mock_make_request.return_value = ChatCompletion(