Display running cost in USD after each response
parent 2cd623e960
commit 1670383fee
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "llm-chat"
-version = "0.2.0"
+version = "0.3.0"
 description = "A general CLI interface for large language models."
 authors = ["Paul Harrison <paul@harrison.sh>"]
 readme = "README.md"
--- a/llm_chat/chat.py
+++ b/llm_chat/chat.py
@@ -1,10 +1,11 @@
+from enum import StrEnum, auto
 from typing import Any, Protocol
 
 from openai import ChatCompletion
 from openai.openai_object import OpenAIObject
 
 from llm_chat.models import Conversation, Message, Role
-from llm_chat.settings import OpenAISettings
+from llm_chat.settings import Model, OpenAISettings
 
 INITIAL_SYSTEM_MESSAGES = [
     Message(
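(Both openai imports come from the pre-1.0 openai Python SDK, which exposed ChatCompletion at the top level and OpenAIObject with its to_dict method; the new StrEnum and Model imports feed the Token enum and pricing table introduced below.)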
@@ -17,9 +18,20 @@ INITIAL_SYSTEM_MESSAGES = [
 ]
 
 
+class Token(StrEnum):
+    """Token type for the OpenAI chat."""
+
+    COMPLETION = auto()
+    PROMPT = auto()
+
+
 class ChatProtocol(Protocol):
     """Protocol for chat classes."""
 
+    @property
+    def cost(self) -> float:
+        """Get the cost of the conversation."""
+
     def send_message(self, message: str) -> str:
         """Send a message to the assistant."""
 
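Token subclasses StrEnum, which is new in Python 3.11, so auto() yields the lowercased member name rather than an integer. A quick illustration, separate from the commit itself:

from enum import StrEnum, auto

class Token(StrEnum):
    COMPLETION = auto()
    PROMPT = auto()

# StrEnum members are real strings; auto() lowercases the member name.
assert Token.COMPLETION == "completion"
assert f"{Token.PROMPT}_tokens" == "prompt_tokens"  # mirrors the API's usage keys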
@@ -33,6 +45,17 @@ class Chat:
         context (optional): Context for the chat. Defaults to an empty list.
     """
 
+    _pricing: dict[Model, dict[Token, float]] = {
+        Model.GPT3: {
+            Token.COMPLETION: 0.002,
+            Token.PROMPT: 0.0015,
+        },
+        Model.GPT4: {
+            Token.COMPLETION: 0.06,
+            Token.PROMPT: 0.03,
+        },
+    }
+
     def __init__(
         self, settings: OpenAISettings | None = None, context: list[Message] = []
     ) -> None:
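The rates are USD per 1,000 tokens, with prompt and completion tokens priced separately; they correspond to OpenAI's 2023 list prices for gpt-3.5-turbo and gpt-4. As a worked example, 1,500 prompt tokens under Model.GPT4 cost 1500 / 1000 * 0.03 = $0.045 before any completion tokens are added.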
@@ -48,6 +71,11 @@ class Chat:
             self._settings = OpenAISettings()
         return self._settings
 
+    @property
+    def cost(self) -> float:
+        """Get the cost of the conversation."""
+        return self.conversation.cost
+
     def _make_request(self, message: str) -> dict[str, Any]:
         """Send a request to the OpenAI API.
 
@@ -66,15 +94,30 @@ class Chat:
         out: dict[str, Any] = response.to_dict()
         return out
 
-    def send_message(self, message: str) -> str:
+    def _calculate_cost(self, usage: dict[str, int]) -> None:
+        """Calculate the cost of a request."""
+        self.conversation.completion_tokens += usage["completion_tokens"]
+        self.conversation.prompt_tokens += usage["prompt_tokens"]
+        self.conversation.cost = round(
+            (self.conversation.completion_tokens / 1000)
+            * self._pricing[self.settings.model][Token.COMPLETION]
+            + (self.conversation.prompt_tokens / 1000)
+            * self._pricing[self.settings.model][Token.PROMPT],
+            6,
+        )
+
+    def send_message(self, prompt: str) -> str:
         """Send a message to the assistant.
 
         TODO: Add error handling.
         """
-        response = self._make_request(message)
-        message = response["choices"][0]["message"]["content"]
-        self.conversation.messages.append(Message(role=Role.ASSISTANT, content=message))
-        return message
+        request_response = self._make_request(prompt)
+        self._calculate_cost(request_response["usage"])
+        response: str = request_response["choices"][0]["message"]["content"]
+        self.conversation.messages.append(
+            Message(role=Role.ASSISTANT, content=response)
+        )
+        return response
 
 
 def get_chat(
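To sanity-check _calculate_cost, here is the arithmetic for a single exchange consuming 15 prompt tokens and 10 completion tokens, the same figures the parametrized test at the end of this commit uses:

# GPT-4 rates from _pricing: 0.06 (completion) and 0.03 (prompt) USD per 1K tokens.
completion_cost = (10 / 1000) * 0.06  # 0.0006
prompt_cost = (15 / 1000) * 0.03      # 0.00045
assert round(completion_cost + prompt_cost, 6) == 0.00105

# GPT-3.5: (10 / 1000) * 0.002 + (15 / 1000) * 0.0015 = 0.0000425,
# which round(..., 6) resolves to 0.000043 -- the value the test asserts.

The CLI module comes next; it imports ChatProtocol so the new display helper can be typed against the protocol rather than the concrete Chat: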
@@ -6,7 +6,7 @@ from prompt_toolkit import PromptSession
 from rich.console import Console
 from rich.markdown import Markdown
 
-from llm_chat.chat import get_chat
+from llm_chat.chat import ChatProtocol, get_chat
 from llm_chat.models import Message, Role
 from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings
 
@@ -55,6 +55,11 @@ def load_context(path: Path) -> Message:
     return Message(role=Role.SYSTEM, content=content)
 
 
+def display_cost(console: Console, chat: ChatProtocol) -> None:
+    """Display the cost of the conversation so far."""
+    console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n")
+
+
 @app.command()
 def chat(
     api_key: Annotated[
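One caveat worth noting: a bare f-string renders small floats in scientific notation, so early in a conversation this prints amounts like $4.3e-05. A fixed-point format is a possible alternative (hypothetical, not part of this commit):

f"${0.000043}"      # -> "$4.3e-05", what the committed f-string produces
f"${0.000043:.6f}"  # -> "$0.000043", a hypothetical fixed-point variant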
@@ -107,7 +112,6 @@ def chat(
     context_messages = [load_context(path) for path in context]
 
     current_chat = get_chat(settings=settings, context=context_messages)
-    # current_chat = get_chat(settings=settings)
     console = get_console()
     session = get_session()
 
@@ -123,3 +127,4 @@ def chat(
         else:
             response = current_chat.send_message(prompt.strip())
             console.print(Markdown(response))
+            display_cost(console, current_chat)
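With display_cost invoked after every rendered response, the CLI now does exactly what the commit title promises. The counters it ultimately reads are new fields on the Conversation model: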
--- a/llm_chat/models.py
+++ b/llm_chat/models.py
@@ -27,6 +27,9 @@ class Conversation(BaseModel):
     """Conversation in the chat."""
 
     messages: list[Message]
+    completion_tokens: int = 0
+    prompt_tokens: int = 0
+    cost: float = 0.0
 
     model_config: ConfigDict = ConfigDict(  # type: ignore[misc]
         frozen=False,
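The existing model_config already sets frozen=False, which is what lets _calculate_cost mutate these fields in place. A minimal sketch of that behaviour, assuming pydantic v2 and trimmed to the new fields:

from pydantic import BaseModel, ConfigDict

class Conversation(BaseModel):
    completion_tokens: int = 0
    prompt_tokens: int = 0
    cost: float = 0.0

    model_config = ConfigDict(frozen=False)  # frozen=True would make these read-only

conversation = Conversation()
conversation.completion_tokens += 10  # allowed only because the model is mutable
assert conversation.completion_tokens == 10

The unit tests then pin down both the message flow and the cost arithmetic: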
@@ -1,14 +1,31 @@
 from unittest.mock import patch
 
+import pytest
+
 from llm_chat.chat import Chat
+from llm_chat.settings import Model, OpenAISettings
 
 
 def test_send_message() -> None:
     with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
         mock_make_request.return_value = {
-            "choices": [{"message": {"content": "Hello!"}}]
+            "choices": [{"message": {"content": "Hello!"}}],
+            "usage": {"completion_tokens": 1, "prompt_tokens": 1},
         }
         conversation = Chat()
         response = conversation.send_message("Hello")
         assert isinstance(response, str)
         assert response == "Hello!"
 
 
+@pytest.mark.parametrize("model,cost", [(Model.GPT3, 0.000043), (Model.GPT4, 0.00105)])
+def test_calculate_cost(model: Model, cost: float) -> None:
+    with patch("llm_chat.chat.Chat._make_request") as mock_make_request:
+        mock_make_request.return_value = {
+            "choices": [{"message": {"content": "Hello!"}}],
+            "usage": {"completion_tokens": 10, "prompt_tokens": 15},
+        }
+        settings = OpenAISettings(model=model)
+        conversation = Chat(settings=settings)
+        _ = conversation.send_message("Hello")
+        assert conversation.cost == cost
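The exact == assertions are safe despite the floats involved: round(..., 6) returns the double closest to the six-decimal result, and the parametrized literals 0.000043 and 0.00105 parse to those same doubles, so the comparison holds exactly.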
@@ -29,6 +29,11 @@ class ChatFake:
     def _set_args(self, **kwargs: Any) -> None:
         self.args = kwargs
 
+    @property
+    def cost(self) -> float:
+        """Get the cost of the conversation."""
+        return 0.0
+
     def send_message(self, message: str) -> str:
         """Echo the received message."""
         self.received_messages.append(message)
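Finally, the ChatFake test double gains a stub cost property so it still structurally satisfies ChatProtocol now that display_cost reads chat.cost. Protocols are duck-typed; a small sketch using a hypothetical runtime_checkable mirror of the protocol (the committed ChatProtocol is not decorated this way) makes the structural check visible:

from typing import Protocol, runtime_checkable

@runtime_checkable
class HasCost(Protocol):
    """Hypothetical stand-in for the cost half of ChatProtocol."""

    @property
    def cost(self) -> float: ...

class CostOnlyFake:
    """Hypothetical minimal fake mirroring ChatFake's new property."""

    @property
    def cost(self) -> float:
        return 0.0

# isinstance on a runtime_checkable Protocol checks member presence only.
assert isinstance(CostOnlyFake(), HasCost)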