Display running cost in USD after each response
This commit is contained in:
parent
2cd623e960
commit
1670383fee
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "llm-chat"
|
||||
version = "0.2.0"
|
||||
version = "0.3.0"
|
||||
description = "A general CLI interface for large language models."
|
||||
authors = ["Paul Harrison <paul@harrison.sh>"]
|
||||
readme = "README.md"
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from enum import StrEnum, auto
|
||||
from typing import Any, Protocol
|
||||
|
||||
from openai import ChatCompletion
|
||||
from openai.openai_object import OpenAIObject
|
||||
|
||||
from llm_chat.models import Conversation, Message, Role
|
||||
from llm_chat.settings import OpenAISettings
|
||||
from llm_chat.settings import Model, OpenAISettings
|
||||
|
||||
INITIAL_SYSTEM_MESSAGES = [
|
||||
Message(
|
||||
|
@ -17,9 +18,20 @@ INITIAL_SYSTEM_MESSAGES = [
|
|||
]
|
||||
|
||||
|
||||
class Token(StrEnum):
|
||||
"""Token type for the OpenAI chat."""
|
||||
|
||||
COMPLETION = auto()
|
||||
PROMPT = auto()
|
||||
|
||||
|
||||
class ChatProtocol(Protocol):
|
||||
"""Protocol for chat classes."""
|
||||
|
||||
@property
|
||||
def cost(self) -> float:
|
||||
"""Get the cost of the conversation."""
|
||||
|
||||
def send_message(self, message: str) -> str:
|
||||
"""Send a message to the assistant."""
|
||||
|
||||
|
@ -33,6 +45,17 @@ class Chat:
|
|||
context (optional): Context for the chat. Defaults to an empty list.
|
||||
"""
|
||||
|
||||
_pricing: dict[Model, dict[Token, float]] = {
|
||||
Model.GPT3: {
|
||||
Token.COMPLETION: 0.002,
|
||||
Token.PROMPT: 0.0015,
|
||||
},
|
||||
Model.GPT4: {
|
||||
Token.COMPLETION: 0.06,
|
||||
Token.PROMPT: 0.03,
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, settings: OpenAISettings | None = None, context: list[Message] = []
|
||||
) -> None:
|
||||
|
@ -48,6 +71,11 @@ class Chat:
|
|||
self._settings = OpenAISettings()
|
||||
return self._settings
|
||||
|
||||
@property
|
||||
def cost(self) -> float:
|
||||
"""Get the cost of the conversation."""
|
||||
return self.conversation.cost
|
||||
|
||||
def _make_request(self, message: str) -> dict[str, Any]:
|
||||
"""Send a request to the OpenAI API.
|
||||
|
||||
|
@ -66,15 +94,30 @@ class Chat:
|
|||
out: dict[str, Any] = response.to_dict()
|
||||
return out
|
||||
|
||||
def send_message(self, message: str) -> str:
|
||||
def _calculate_cost(self, usage: dict[str, int]) -> None:
|
||||
"""Calculate the cost of a request."""
|
||||
self.conversation.completion_tokens += usage["completion_tokens"]
|
||||
self.conversation.prompt_tokens += usage["prompt_tokens"]
|
||||
self.conversation.cost = round(
|
||||
(self.conversation.completion_tokens / 1000)
|
||||
* self._pricing[self.settings.model][Token.COMPLETION]
|
||||
+ (self.conversation.prompt_tokens / 1000)
|
||||
* self._pricing[self.settings.model][Token.PROMPT],
|
||||
6,
|
||||
)
|
||||
|
||||
def send_message(self, prompt: str) -> str:
|
||||
"""Send a message to the assistant.
|
||||
|
||||
TODO: Add error handling.
|
||||
"""
|
||||
response = self._make_request(message)
|
||||
message = response["choices"][0]["message"]["content"]
|
||||
self.conversation.messages.append(Message(role=Role.ASSISTANT, content=message))
|
||||
return message
|
||||
request_response = self._make_request(prompt)
|
||||
self._calculate_cost(request_response["usage"])
|
||||
response: str = request_response["choices"][0]["message"]["content"]
|
||||
self.conversation.messages.append(
|
||||
Message(role=Role.ASSISTANT, content=response)
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def get_chat(
|
||||
|
|
|
@ -6,7 +6,7 @@ from prompt_toolkit import PromptSession
|
|||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
from llm_chat.chat import get_chat
|
||||
from llm_chat.chat import ChatProtocol, get_chat
|
||||
from llm_chat.models import Message, Role
|
||||
from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
|
||||
|
||||
|
@ -55,6 +55,11 @@ def load_context(path: Path) -> Message:
|
|||
return Message(role=Role.SYSTEM, content=content)
|
||||
|
||||
|
||||
def display_cost(console: Console, chat: ChatProtocol) -> None:
|
||||
"""Display the cost of the conversation so far."""
|
||||
console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n")
|
||||
|
||||
|
||||
@app.command()
|
||||
def chat(
|
||||
api_key: Annotated[
|
||||
|
@ -107,7 +112,6 @@ def chat(
|
|||
context_messages = [load_context(path) for path in context]
|
||||
|
||||
current_chat = get_chat(settings=settings, context=context_messages)
|
||||
# current_chat = get_chat(settings=settings)
|
||||
console = get_console()
|
||||
session = get_session()
|
||||
|
||||
|
@ -123,3 +127,4 @@ def chat(
|
|||
else:
|
||||
response = current_chat.send_message(prompt.strip())
|
||||
console.print(Markdown(response))
|
||||
display_cost(console, current_chat)
|
||||
|
|
|
@ -27,6 +27,9 @@ class Conversation(BaseModel):
|
|||
"""Conversation in the chat."""
|
||||
|
||||
messages: list[Message]
|
||||
completion_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
cost: float = 0.0
|
||||
|
||||
model_config: ConfigDict = ConfigDict( # type: ignore[misc]
|
||||
frozen=False,
|
||||
|
|
|
@ -1,14 +1,31 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llm_chat.chat import Chat
|
||||
from llm_chat.settings import Model, OpenAISettings
|
||||
|
||||
|
||||
def test_send_message() -> None:
|
||||
with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
|
||||
mock_make_request.return_value = {
|
||||
"choices": [{"message": {"content": "Hello!"}}]
|
||||
"choices": [{"message": {"content": "Hello!"}}],
|
||||
"usage": {"completion_tokens": 1, "prompt_tokens": 1},
|
||||
}
|
||||
conversation = Chat()
|
||||
response = conversation.send_message("Hello")
|
||||
assert isinstance(response, str)
|
||||
assert response == "Hello!"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model,cost", [(Model.GPT3, 0.000043), (Model.GPT4, 0.00105)])
|
||||
def test_calculate_cost(model: Model, cost: float) -> None:
|
||||
with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
|
||||
mock_make_request.return_value = {
|
||||
"choices": [{"message": {"content": "Hello!"}}],
|
||||
"usage": {"completion_tokens": 10, "prompt_tokens": 15},
|
||||
}
|
||||
settings = OpenAISettings(model=model)
|
||||
conversation = Chat(settings=settings)
|
||||
_ = conversation.send_message("Hello")
|
||||
assert conversation.cost == cost
|
||||
|
|
|
@ -29,6 +29,11 @@ class ChatFake:
|
|||
def _set_args(self, **kwargs: Any) -> None:
|
||||
self.args = kwargs
|
||||
|
||||
@property
|
||||
def cost(self) -> float:
|
||||
"""Get the cost of the conversation."""
|
||||
return 0.0
|
||||
|
||||
def send_message(self, message: str) -> str:
|
||||
"""Echo the received message."""
|
||||
self.received_messages.append(message)
|
||||
|
|
Loading…
Reference in New Issue