diff --git a/pyproject.toml b/pyproject.toml index 92ad0c8..35dd17d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "llm-chat" -version = "0.2.0" +version = "0.3.0" description = "A general CLI interface for large language models." authors = ["Paul Harrison "] readme = "README.md" diff --git a/src/llm_chat/chat.py b/src/llm_chat/chat.py index 7ab0adf..328d3b6 100644 --- a/src/llm_chat/chat.py +++ b/src/llm_chat/chat.py @@ -1,10 +1,11 @@ +from enum import StrEnum, auto from typing import Any, Protocol from openai import ChatCompletion from openai.openai_object import OpenAIObject from llm_chat.models import Conversation, Message, Role -from llm_chat.settings import OpenAISettings +from llm_chat.settings import Model, OpenAISettings INITIAL_SYSTEM_MESSAGES = [ Message( @@ -17,9 +18,20 @@ INITIAL_SYSTEM_MESSAGES = [ ] +class Token(StrEnum): + """Token type for the OpenAI chat.""" + + COMPLETION = auto() + PROMPT = auto() + + class ChatProtocol(Protocol): """Protocol for chat classes.""" + @property + def cost(self) -> float: + """Get the cost of the conversation.""" + def send_message(self, message: str) -> str: """Send a message to the assistant.""" @@ -33,6 +45,17 @@ class Chat: context (optional): Context for the chat. Defaults to an empty list. """ + _pricing: dict[Model, dict[Token, float]] = { + Model.GPT3: { + Token.COMPLETION: 0.002, + Token.PROMPT: 0.0015, + }, + Model.GPT4: { + Token.COMPLETION: 0.06, + Token.PROMPT: 0.03, + }, + } + def __init__( self, settings: OpenAISettings | None = None, context: list[Message] = [] ) -> None: @@ -48,6 +71,11 @@ class Chat: self._settings = OpenAISettings() return self._settings + @property + def cost(self) -> float: + """Get the cost of the conversation.""" + return self.conversation.cost + def _make_request(self, message: str) -> dict[str, Any]: """Send a request to the OpenAI API. @@ -66,15 +94,30 @@ class Chat: out: dict[str, Any] = response.to_dict() return out - def send_message(self, message: str) -> str: + def _calculate_cost(self, usage: dict[str, int]) -> None: + """Calculate the cost of a request.""" + self.conversation.completion_tokens += usage["completion_tokens"] + self.conversation.prompt_tokens += usage["prompt_tokens"] + self.conversation.cost = round( + (self.conversation.completion_tokens / 1000) + * self._pricing[self.settings.model][Token.COMPLETION] + + (self.conversation.prompt_tokens / 1000) + * self._pricing[self.settings.model][Token.PROMPT], + 6, + ) + + def send_message(self, prompt: str) -> str: """Send a message to the assistant. TODO: Add error handling. """ - response = self._make_request(message) - message = response["choices"][0]["message"]["content"] - self.conversation.messages.append(Message(role=Role.ASSISTANT, content=message)) - return message + request_response = self._make_request(prompt) + self._calculate_cost(request_response["usage"]) + response: str = request_response["choices"][0]["message"]["content"] + self.conversation.messages.append( + Message(role=Role.ASSISTANT, content=response) + ) + return response def get_chat( diff --git a/src/llm_chat/cli.py b/src/llm_chat/cli.py index a6cc02f..5d5c1b5 100644 --- a/src/llm_chat/cli.py +++ b/src/llm_chat/cli.py @@ -6,7 +6,7 @@ from prompt_toolkit import PromptSession from rich.console import Console from rich.markdown import Markdown -from llm_chat.chat import get_chat +from llm_chat.chat import ChatProtocol, get_chat from llm_chat.models import Message, Role from llm_chat.settings import DEFAULT_MODEL, DEFAULT_TEMPERATURE, Model, OpenAISettings @@ -55,6 +55,11 @@ def load_context(path: Path) -> Message: return Message(role=Role.SYSTEM, content=content) +def display_cost(console: Console, chat: ChatProtocol) -> None: + """Display the cost of the conversation so far.""" + console.print(f"\n[bold green]Cost:[/bold green] ${chat.cost}\n") + + @app.command() def chat( api_key: Annotated[ @@ -107,7 +112,6 @@ def chat( context_messages = [load_context(path) for path in context] current_chat = get_chat(settings=settings, context=context_messages) - # current_chat = get_chat(settings=settings) console = get_console() session = get_session() @@ -123,3 +127,4 @@ def chat( else: response = current_chat.send_message(prompt.strip()) console.print(Markdown(response)) + display_cost(console, current_chat) diff --git a/src/llm_chat/models.py b/src/llm_chat/models.py index bea6317..4b5a74c 100644 --- a/src/llm_chat/models.py +++ b/src/llm_chat/models.py @@ -27,6 +27,9 @@ class Conversation(BaseModel): """Conversation in the chat.""" messages: list[Message] + completion_tokens: int = 0 + prompt_tokens: int = 0 + cost: float = 0.0 model_config: ConfigDict = ConfigDict( # type: ignore[misc] frozen=False, diff --git a/tests/test_chat.py b/tests/test_chat.py index d0129d3..a5b6edf 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -1,14 +1,31 @@ from unittest.mock import patch +import pytest + from llm_chat.chat import Chat +from llm_chat.settings import Model, OpenAISettings def test_send_message() -> None: with patch("llm_chat.chat.Chat._make_request") as mock_make_request: mock_make_request.return_value = { - "choices": [{"message": {"content": "Hello!"}}] + "choices": [{"message": {"content": "Hello!"}}], + "usage": {"completion_tokens": 1, "prompt_tokens": 1}, } conversation = Chat() response = conversation.send_message("Hello") assert isinstance(response, str) assert response == "Hello!" + + +@pytest.mark.parametrize("model,cost", [(Model.GPT3, 0.000043), (Model.GPT4, 0.00105)]) +def test_calculate_cost(model: Model, cost: float) -> None: + with patch("llm_chat.chat.Chat._make_request") as mock_make_request: + mock_make_request.return_value = { + "choices": [{"message": {"content": "Hello!"}}], + "usage": {"completion_tokens": 10, "prompt_tokens": 15}, + } + settings = OpenAISettings(model=model) + conversation = Chat(settings=settings) + _ = conversation.send_message("Hello") + assert conversation.cost == cost diff --git a/tests/test_cli.py b/tests/test_cli.py index 5861cf6..675178d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -29,6 +29,11 @@ class ChatFake: def _set_args(self, **kwargs: Any) -> None: self.args = kwargs + @property + def cost(self) -> float: + """Get the cost of the conversation.""" + return 0.0 + def send_message(self, message: str) -> str: """Echo the received message.""" self.received_messages.append(message)