Display running cost in USD after each response

This commit is contained in:
Paul Harrison 2023-08-24 15:23:32 +01:00
parent 2cd623e960
commit 1670383fee
6 changed files with 83 additions and 10 deletions

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "llm-chat" name = "llm-chat"
version = "0.2.0" version = "0.3.0"
description = "A general CLI interface for large language models." description = "A general CLI interface for large language models."
authors = ["Paul Harrison <paul@harrison.sh>"] authors = ["Paul Harrison <paul@harrison.sh>"]
readme = "README.md" readme = "README.md"

View File

@@ -1,10 +1,11 @@
from enum import StrEnum, auto
from typing import Any, Protocol from typing import Any, Protocol
from openai import ChatCompletion from openai import ChatCompletion
from openai.openai_object import OpenAIObject from openai.openai_object import OpenAIObject
from llm_chat.models import Conversation, Message, Role from llm_chat.models import Conversation, Message, Role
from llm_chat.settings import OpenAISettings from llm_chat.settings import Model, OpenAISettings
INITIAL_SYSTEM_MESSAGES = [ INITIAL_SYSTEM_MESSAGES = [
Message( Message(
@@ -17,9 +18,20 @@ INITIAL_SYSTEM_MESSAGES = [
] ]
class Token(StrEnum):
"""Token type for the OpenAI chat."""
COMPLETION = auto()
PROMPT = auto()
class ChatProtocol(Protocol): class ChatProtocol(Protocol):
"""Protocol for chat classes.""" """Protocol for chat classes."""
@property
def cost(self) -> float:
"""Get the cost of the conversation."""
def send_message(self, message: str) -> str: def send_message(self, message: str) -> str:
"""Send a message to the assistant.""" """Send a message to the assistant."""
@@ -33,6 +45,17 @@ class Chat:
context (optional): Context for the chat. Defaults to an empty list. context (optional): Context for the chat. Defaults to an empty list.
""" """
_pricing: dict[Model, dict[Token, float]] = {
Model.GPT3: {
Token.COMPLETION: 0.002,
Token.PROMPT: 0.0015,
},
Model.GPT4: {
Token.COMPLETION: 0.06,
Token.PROMPT: 0.03,
},
}
def __init__( def __init__(
self, settings: OpenAISettings | None = None, context: list[Message] = [] self, settings: OpenAISettings | None = None, context: list[Message] = []
) -> None: ) -> None:
@@ -48,6 +71,11 @@ class Chat:
self._settings = OpenAISettings() self._settings = OpenAISettings()
return self._settings return self._settings
@property
def cost(self) -> float:
"""Get the cost of the conversation."""
return self.conversation.cost
def _make_request(self, message: str) -> dict[str, Any]: def _make_request(self, message: str) -> dict[str, Any]:
"""Send a request to the OpenAI API. """Send a request to the OpenAI API.
@@ -66,15 +94,30 @@ class Chat:
out: dict[str, Any] = response.to_dict() out: dict[str, Any] = response.to_dict()
return out return out
def send_message(self, message: str) -> str: def _calculate_cost(self, usage: dict[str, int]) -> None:
"""Calculate the cost of a request."""
self.conversation.completion_tokens += usage["completion_tokens"]
self.conversation.prompt_tokens += usage["prompt_tokens"]
self.conversation.cost = round(
(self.conversation.completion_tokens / 1000)
* self._pricing[self.settings.model][Token.COMPLETION]
+ (self.conversation.prompt_tokens / 1000)
* self._pricing[self.settings.model][Token.PROMPT],
6,
)
def send_message(self, prompt: str) -> str:
"""Send a message to the assistant. """Send a message to the assistant.
TODO: Add error handling. TODO: Add error handling.
""" """
response = self._make_request(message) request_response = self._make_request(prompt)
message = response["choices"][0]["message"]["content"] self._calculate_cost(request_response["usage"])
self.conversation.messages.append(Message(role=Role.ASSISTANT, content=message)) response: str = request_response["choices"][0]["message"]["content"]
return message self.conversation.messages.append(
Message(role=Role.ASSISTANT, content=response)
)
return response
def get_chat( def get_chat(

View File

@@ -6,7 +6,7 @@ from prompt_toolkit import PromptSession
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from llm_chat.chat import get_chat from llm_chat.chat import ChatProtocol, get_chat
from llm_chat.models import Message, Role from llm_chat.models import Message, Role
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
@@ -55,6 +55,11 @@ def load_context(path: Path) -> Message:
return Message(role=Role.SYSTEM, content=content) return Message(role=Role.SYSTEM, content=content)
def display_cost(console: Console, chat: ChatProtocol) -> None:
"""Display the cost of the conversation so far."""
console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n")
@app.command() @app.command()
def chat( def chat(
api_key: Annotated[ api_key: Annotated[
@@ -107,7 +112,6 @@ def chat(
context_messages = [load_context(path) for path in context] context_messages = [load_context(path) for path in context]
current_chat = get_chat(settings=settings, context=context_messages) current_chat = get_chat(settings=settings, context=context_messages)
# current_chat = get_chat(settings=settings)
console = get_console() console = get_console()
session = get_session() session = get_session()
@@ -123,3 +127,4 @@ def chat(
else: else:
response = current_chat.send_message(prompt.strip()) response = current_chat.send_message(prompt.strip())
console.print(Markdown(response)) console.print(Markdown(response))
display_cost(console, current_chat)

View File

@@ -27,6 +27,9 @@ class Conversation(BaseModel):
"""Conversation in the chat.""" """Conversation in the chat."""
messages: list[Message] messages: list[Message]
completion_tokens: int = 0
prompt_tokens: int = 0
cost: float = 0.0
model_config: ConfigDict = ConfigDict( # type: ignore[misc] model_config: ConfigDict = ConfigDict( # type: ignore[misc]
frozen=False, frozen=False,

View File

@@ -1,14 +1,31 @@
from unittest.mock import patch from unittest.mock import patch
import pytest
from llm_chat.chat import Chat from llm_chat.chat import Chat
from llm_chat.settings import Model, OpenAISettings
def test_send_message() -> None: def test_send_message() -> None:
with patch("llm_chat.chat.Chat._make_request") as mock_make_request: with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
mock_make_request.return_value = { mock_make_request.return_value = {
"choices": [{"message": {"content": "Hello!"}}] "choices": [{"message": {"content": "Hello!"}}],
"usage": {"completion_tokens": 1, "prompt_tokens": 1},
} }
conversation = Chat() conversation = Chat()
response = conversation.send_message("Hello") response = conversation.send_message("Hello")
assert isinstance(response, str) assert isinstance(response, str)
assert response == "Hello!" assert response == "Hello!"
@pytest.mark.parametrize("model,cost", [(Model.GPT3, 0.000043), (Model.GPT4, 0.00105)])
def test_calculate_cost(model: Model, cost: float) -> None:
with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
mock_make_request.return_value = {
"choices": [{"message": {"content": "Hello!"}}],
"usage": {"completion_tokens": 10, "prompt_tokens": 15},
}
settings = OpenAISettings(model=model)
conversation = Chat(settings=settings)
_ = conversation.send_message("Hello")
assert conversation.cost == cost

View File

@@ -29,6 +29,11 @@ class ChatFake:
def _set_args(self, **kwargs: Any) -> None: def _set_args(self, **kwargs: Any) -> None:
self.args = kwargs self.args = kwargs
@property
def cost(self) -> float:
"""Get the cost of the conversation."""
return 0.0
def send_message(self, message: str) -> str: def send_message(self, message: str) -> str:
"""Echo the received message.""" """Echo the received message."""
self.received_messages.append(message) self.received_messages.append(message)