[feat] token-limited message context (#6087)
Commit 7487687cdc (parent 29485ef85b)
Tests for declarative (component) serialization of chat completion contexts (file path not shown in this view):

@@ -15,6 +15,7 @@ from autogen_core.model_context import (
     BufferedChatCompletionContext,
     HeadAndTailChatCompletionContext,
     UnboundedChatCompletionContext,
+    TokenLimitedChatCompletionContext,
 )


@@ -104,6 +105,7 @@ async def test_chat_completion_context_declarative() -> None:
     unbounded_context = UnboundedChatCompletionContext()
     buffered_context = BufferedChatCompletionContext(buffer_size=5)
     head_tail_context = HeadAndTailChatCompletionContext(head_size=3, tail_size=2)
+    token_limited_context = TokenLimitedChatCompletionContext(token_limit=5, model="gpt-4o")

     # Test serialization
     unbounded_config = unbounded_context.dump_component()
@@ -118,6 +120,11 @@ async def test_chat_completion_context_declarative() -> None:
     assert head_tail_config.config["head_size"] == 3
     assert head_tail_config.config["tail_size"] == 2

+    token_limited_config = token_limited_context.dump_component()
+    assert token_limited_config.provider == "autogen_core.model_context.TokenLimitedChatCompletionContext"
+    assert token_limited_config.config["token_limit"] == 5
+    assert token_limited_config.config["model"] == "gpt-4o"
+
     # Test deserialization
     loaded_unbounded = ComponentLoader.load_component(unbounded_config, UnboundedChatCompletionContext)
     assert isinstance(loaded_unbounded, UnboundedChatCompletionContext)
@@ -129,3 +136,6 @@ async def test_chat_completion_context_declarative() -> None:
     loaded_head_tail = ComponentLoader.load_component(head_tail_config, HeadAndTailChatCompletionContext)

     assert isinstance(loaded_head_tail, HeadAndTailChatCompletionContext)
+
+    loaded_token_limited = ComponentLoader.load_component(token_limited_config, TokenLimitedChatCompletionContext)
+    assert isinstance(loaded_token_limited, TokenLimitedChatCompletionContext)
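The assertions above inspect the declarative spec returned by dump_component(). A minimal round-trip sketch outside the test, assuming ComponentLoader is imported as in the test module (its exact import path is not shown in this diff) and that the returned component model is a pydantic model as elsewhere in autogen_core:

from autogen_core.model_context import TokenLimitedChatCompletionContext

ctx = TokenLimitedChatCompletionContext(token_limit=5, model="gpt-4o")
config = ctx.dump_component()
config_json = config.model_dump_json()  # assumption: the component model is a pydantic BaseModel
restored = ComponentLoader.load_component(config, TokenLimitedChatCompletionContext)
assert isinstance(restored, TokenLimitedChatCompletionContext)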
The model_context package __init__ (autogen_core.model_context):

@@ -1,4 +1,5 @@
 from ._buffered_chat_completion_context import BufferedChatCompletionContext
+from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
 from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
 from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
 from ._unbounded_chat_completion_context import (

@@ -10,5 +11,6 @@ __all__ = [
     "ChatCompletionContextState",
     "UnboundedChatCompletionContext",
     "BufferedChatCompletionContext",
+    "TokenLimitedChatCompletionContext",
     "HeadAndTailChatCompletionContext",
 ]
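With this export in place, the new context is importable from the public package. A one-line sketch; the token limit here is illustrative:

from autogen_core.model_context import TokenLimitedChatCompletionContext

context = TokenLimitedChatCompletionContext(token_limit=4096, model="gpt-4o")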
New module _token_limited_chat_completion_context.py in autogen_core.model_context:

@@ -0,0 +1,83 @@
+from typing import List, Sequence
+from autogen_core.tools import Tool, ToolSchema
+
+from pydantic import BaseModel
+from typing_extensions import Self
+import tiktoken
+
+from .._component_config import Component
+from ..models import FunctionExecutionResultMessage, LLMMessage
+from ._chat_completion_context import ChatCompletionContext
+
+from autogen_ext.models.ollama._ollama_client import count_tokens_ollama
+from autogen_ext.models.openai._openai_client import count_tokens_openai
+
+
+class TokenLimitedChatCompletionContextConfig(BaseModel):
+    token_limit: int
+    model: str
+    initial_messages: List[LLMMessage] | None = None
+
+
+class TokenLimitedChatCompletionContext(ChatCompletionContext, Component[TokenLimitedChatCompletionContextConfig]):
+    """A token-limited chat completion context that maintains a view of the context up to a token limit. The token limit is set at initialization.
+
+    Args:
+        token_limit (int): Max tokens for the context.
+        model (str): The model name, used to select the token-counting implementation.
+        initial_messages (List[LLMMessage] | None): The initial messages.
+    """
+
+    component_config_schema = TokenLimitedChatCompletionContextConfig
+    component_provider_override = "autogen_core.model_context.TokenLimitedChatCompletionContext"
+
+    def __init__(self, token_limit: int, model: str, initial_messages: List[LLMMessage] | None = None) -> None:
+        super().__init__(initial_messages)
+        if token_limit <= 0:
+            raise ValueError("token_limit must be greater than 0.")
+        self._token_limit = token_limit
+        self._model = model
+
+    async def get_messages(self) -> List[LLMMessage]:
+        """Get at most `token_limit` tokens of the most recent messages."""
+        token_count = count_chat_tokens(self._messages, self._model)
+        while token_count > self._token_limit:
+            middle_index = len(self._messages) // 2
+            self._messages.pop(middle_index)
+            token_count = count_chat_tokens(self._messages, self._model)
+        messages = self._messages
+        # Handle the case where the first message is a function execution result message.
+        if messages and isinstance(messages[0], FunctionExecutionResultMessage):
+            # Remove the first message from the list.
+            messages = messages[1:]
+        return messages
+
+    def _to_config(self) -> TokenLimitedChatCompletionContextConfig:
+        return TokenLimitedChatCompletionContextConfig(
+            token_limit=self._token_limit, model=self._model, initial_messages=self._messages
+        )
+
+    @classmethod
+    def _from_config(cls, config: TokenLimitedChatCompletionContextConfig) -> Self:
+        return cls(**config.model_dump())
+
+
+def count_chat_tokens(
+    messages: Sequence[LLMMessage], model: str = "gpt-4o", *, tools: Sequence[Tool | ToolSchema] = []
+) -> int:
+    """Count tokens for a list of messages using the appropriate client based on the model."""
+    # Check if the model is an OpenAI model.
+    if "openai" in model.lower():
+        return count_tokens_openai(messages, model)
+
+    # Check if the model is an Ollama model.
+    elif "llama" in model.lower():
+        return count_tokens_ollama(messages, model)
+
+    # Fallback to cl100k_base encoding if the model is unrecognized.
+    else:
+        encoding = tiktoken.get_encoding("cl100k_base")
+        total_tokens = 0
+        for message in messages:
+            total_tokens += len(encoding.encode(str(message.content)))
+        return total_tokens
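For reference, a minimal usage sketch of the new context based on the behaviour implemented above: get_messages() drops messages from the middle of the history until the token count fits the budget, and strips a leading function execution result. The limit and the messages are illustrative:

import asyncio

from autogen_core.model_context import TokenLimitedChatCompletionContext
from autogen_core.models import AssistantMessage, UserMessage


async def main() -> None:
    # Illustrative budget; a real budget would normally track the target model's context window.
    context = TokenLimitedChatCompletionContext(token_limit=50, model="gpt-4o")
    await context.add_message(UserMessage(content="Hello!", source="user"))
    await context.add_message(AssistantMessage(content="Hi! How can I help?", source="assistant"))
    trimmed = await context.get_messages()  # at most roughly token_limit tokens of recent history
    print(len(trimmed))


asyncio.run(main())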
Model context behaviour tests (file path not shown in this view):

@@ -5,8 +5,9 @@ from autogen_core.model_context import (
     BufferedChatCompletionContext,
     HeadAndTailChatCompletionContext,
     UnboundedChatCompletionContext,
+    TokenLimitedChatCompletionContext,
 )
-from autogen_core.models import AssistantMessage, LLMMessage, UserMessage
+from autogen_core.models import AssistantMessage, LLMMessage, UserMessage, FunctionExecutionResultMessage


 @pytest.mark.asyncio
@@ -104,3 +105,82 @@ async def test_unbounded_model_context() -> None:
     retrieved = await model_context.get_messages()
     assert len(retrieved) == 3
     assert retrieved == messages
+
+
+@pytest.mark.asyncio
+async def test_token_limited_model_context_openai() -> None:
+    model_context = TokenLimitedChatCompletionContext(token_limit=20, model="gpt-4o")
+    messages: List[LLMMessage] = [
+        UserMessage(content="Hello!", source="user"),
+        AssistantMessage(content="What can I do for you?", source="assistant"),
+        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
+    ]
+    for msg in messages:
+        await model_context.add_message(msg)
+
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 2  # Token limit set very low, will remove 1 of the messages
+    assert retrieved != messages  # Will not be equal to the original messages
+
+    await model_context.clear()
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 0
+
+    # Test saving and loading state.
+    for msg in messages:
+        await model_context.add_message(msg)
+    state = await model_context.save_state()
+    await model_context.clear()
+    await model_context.load_state(state)
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 2
+    assert retrieved != messages
+
+
+@pytest.mark.asyncio
+async def test_token_limited_model_context_llama() -> None:
+    model_context = TokenLimitedChatCompletionContext(token_limit=20, model="llama2-7b")
+    messages: List[LLMMessage] = [
+        UserMessage(content="Hello!", source="user"),
+        AssistantMessage(content="What can I do for you?", source="assistant"),
+        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
+    ]
+    for msg in messages:
+        await model_context.add_message(msg)
+
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 1  # Token limit set very low, will remove two of the messages
+    assert retrieved != messages  # Will not be equal to the original messages
+
+    await model_context.clear()
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 0
+
+    # Test saving and loading state.
+    for msg in messages:
+        await model_context.add_message(msg)
+    state = await model_context.save_state()
+    await model_context.clear()
+    await model_context.load_state(state)
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 1
+    assert retrieved != messages
+
+
+@pytest.mark.asyncio
+async def test_token_limited_model_context_openai_with_function_result() -> None:
+    model_context = TokenLimitedChatCompletionContext(token_limit=1000, model="gpt-4o")
+    messages: List[LLMMessage] = [
+        FunctionExecutionResultMessage(content=[]),
+        UserMessage(content="Hello!", source="user"),
+        AssistantMessage(content="What can I do for you?", source="assistant"),
+        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
+    ]
+    for msg in messages:
+        await model_context.add_message(msg)
+
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 3  # Token limit is high; only the leading function execution result is dropped
+    assert type(retrieved[0]) == UserMessage  # Function result should be removed
+    assert type(retrieved[1]) == AssistantMessage
+    assert type(retrieved[2]) == UserMessage
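The three tests above exercise the dispatch inside count_chat_tokens: model names containing "openai" are routed to count_tokens_openai, names containing "llama" to count_tokens_ollama, and anything else falls back to a plain cl100k_base count (note that "gpt-4o" matches neither branch, so it uses the fallback). A direct call looks like the sketch below; the helper lives in the private module added above and is not re-exported, so the import path is an assumption:

# Assumed import path: count_chat_tokens is defined in the new private module.
from autogen_core.model_context._token_limited_chat_completion_context import count_chat_tokens
from autogen_core.models import UserMessage

msgs = [UserMessage(content="Hello!", source="user")]
print(count_chat_tokens(msgs, model="gpt-4o"))     # falls through to the cl100k_base branch
print(count_chat_tokens(msgs, model="llama2-7b"))  # routed to count_tokens_ollama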
The Ollama client module (autogen_ext.models.ollama._ollama_client):

@@ -378,6 +378,66 @@ def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
     return KNOWN_STOP_MAPPINGS.get(stop_reason, "unknown")


+# TODO: probably needs work
+def count_tokens_ollama(messages: Sequence[LLMMessage], model: str, *, tools: Sequence[Tool | ToolSchema] = []) -> int:
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    tokens_per_message = 3
+    num_tokens = 0
+
+    # Message tokens.
+    for message in messages:
+        num_tokens += tokens_per_message
+        ollama_message = to_ollama_type(message)
+        for ollama_message_part in ollama_message:
+            if isinstance(message.content, Image):
+                num_tokens += calculate_vision_tokens(message.content)
+            elif ollama_message_part.content is not None:
+                num_tokens += len(encoding.encode(ollama_message_part.content))
+    # TODO: every model family has its own message sequence.
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+
+    # Tool tokens.
+    ollama_tools = convert_tools(tools)
+    for tool in ollama_tools:
+        function = tool["function"]
+        tool_tokens = len(encoding.encode(function["name"]))
+        if "description" in function:
+            tool_tokens += len(encoding.encode(function["description"]))
+        tool_tokens -= 2
+        if "parameters" in function:
+            parameters = function["parameters"]
+            if "properties" in parameters:
+                assert isinstance(parameters["properties"], dict)
+                for propertiesKey in parameters["properties"]:  # pyright: ignore
+                    assert isinstance(propertiesKey, str)
+                    tool_tokens += len(encoding.encode(propertiesKey))
+                    v = parameters["properties"][propertiesKey]  # pyright: ignore
+                    for field in v:  # pyright: ignore
+                        if field == "type":
+                            tool_tokens += 2
+                            tool_tokens += len(encoding.encode(v["type"]))  # pyright: ignore
+                        elif field == "description":
+                            tool_tokens += 2
+                            tool_tokens += len(encoding.encode(v["description"]))  # pyright: ignore
+                        elif field == "enum":
+                            tool_tokens -= 3
+                            for o in v["enum"]:  # pyright: ignore
+                                tool_tokens += 3
+                                tool_tokens += len(encoding.encode(o))  # pyright: ignore
+                        else:
+                            trace_logger.warning(f"Not supported field {field}")
+                tool_tokens += 11
+                if len(parameters["properties"]) == 0:  # pyright: ignore
+                    tool_tokens -= 2
+        num_tokens += tool_tokens
+    num_tokens += 12
+    return num_tokens
+
+
 @dataclass
 class CreateParams:
     messages: Sequence[Message]
@@ -796,65 +856,8 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
     def total_usage(self) -> RequestUsage:
         return self._total_usage

-    # TODO: probably needs work
     def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
-        model = self._create_args["model"]
-        try:
-            encoding = tiktoken.encoding_for_model(model)
-        except KeyError:
-            trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
-            encoding = tiktoken.get_encoding("cl100k_base")
-        tokens_per_message = 3
-        num_tokens = 0
-
-        # Message tokens.
-        for message in messages:
-            num_tokens += tokens_per_message
-            ollama_message = to_ollama_type(message)
-            for ollama_message_part in ollama_message:
-                if isinstance(message.content, Image):
-                    num_tokens += calculate_vision_tokens(message.content)
-                elif ollama_message_part.content is not None:
-                    num_tokens += len(encoding.encode(ollama_message_part.content))
-        # TODO: every model family has its own message sequence.
-        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
-
-        # Tool tokens.
-        ollama_tools = convert_tools(tools)
-        for tool in ollama_tools:
-            function = tool["function"]
-            tool_tokens = len(encoding.encode(function["name"]))
-            if "description" in function:
-                tool_tokens += len(encoding.encode(function["description"]))
-            tool_tokens -= 2
-            if "parameters" in function:
-                parameters = function["parameters"]
-                if "properties" in parameters:
-                    assert isinstance(parameters["properties"], dict)
-                    for propertiesKey in parameters["properties"]:  # pyright: ignore
-                        assert isinstance(propertiesKey, str)
-                        tool_tokens += len(encoding.encode(propertiesKey))
-                        v = parameters["properties"][propertiesKey]  # pyright: ignore
-                        for field in v:  # pyright: ignore
-                            if field == "type":
-                                tool_tokens += 2
-                                tool_tokens += len(encoding.encode(v["type"]))  # pyright: ignore
-                            elif field == "description":
-                                tool_tokens += 2
-                                tool_tokens += len(encoding.encode(v["description"]))  # pyright: ignore
-                            elif field == "enum":
-                                tool_tokens -= 3
-                                for o in v["enum"]:  # pyright: ignore
-                                    tool_tokens += 3
-                                    tool_tokens += len(encoding.encode(o))  # pyright: ignore
-                            else:
-                                trace_logger.warning(f"Not supported field {field}")
-                    tool_tokens += 11
-                    if len(parameters["properties"]) == 0:  # pyright: ignore
-                        tool_tokens -= 2
-            num_tokens += tool_tokens
-        num_tokens += 12
-        return num_tokens
+        return count_tokens_ollama(messages, self._create_args["model"], tools=tools)

     def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         token_limit = _model_info.get_token_limit(self._create_args["model"])
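Since the counting logic is now a module-level function, it can be called without constructing a client. A sketch, assuming the private import path shown in the new context module; the model name is illustrative, and unknown names fall back to cl100k_base with a warning:

from autogen_core.models import UserMessage
from autogen_ext.models.ollama._ollama_client import count_tokens_ollama

msgs = [UserMessage(content="Hello!", source="user")]
# Uses tiktoken as an approximation of the Ollama model's tokenizer.
print(count_tokens_ollama(msgs, "llama3.2"))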
The OpenAI client module (autogen_ext.models.openai._openai_client):

@@ -373,6 +373,101 @@ def assert_valid_name(name: str) -> str:
     return name


+def count_tokens_openai(
+    messages: Sequence[LLMMessage],
+    model: str,
+    *,
+    add_name_prefixes: bool = False,
+    tools: Sequence[Tool | ToolSchema] = [],
+) -> int:
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    tokens_per_message = 3
+    tokens_per_name = 1
+    num_tokens = 0
+
+    # Message tokens.
+    for message in messages:
+        num_tokens += tokens_per_message
+        oai_message = to_oai_type(message, prepend_name=add_name_prefixes)
+        for oai_message_part in oai_message:
+            for key, value in oai_message_part.items():
+                if value is None:
+                    continue
+
+                if isinstance(message, UserMessage) and isinstance(value, list):
+                    typed_message_value = cast(List[ChatCompletionContentPartParam], value)
+
+                    assert len(typed_message_value) == len(
+                        message.content
+                    ), "Mismatch in message content and typed message value"
+
+                    # We need image properties that are only in the original message
+                    for part, content_part in zip(typed_message_value, message.content, strict=False):
+                        if isinstance(content_part, Image):
+                            # TODO: add detail parameter
+                            num_tokens += calculate_vision_tokens(content_part)
+                        elif isinstance(part, str):
+                            num_tokens += len(encoding.encode(part))
+                        else:
+                            try:
+                                serialized_part = json.dumps(part)
+                                num_tokens += len(encoding.encode(serialized_part))
+                            except TypeError:
+                                trace_logger.warning(f"Could not convert {part} to string, skipping.")
+                else:
+                    if not isinstance(value, str):
+                        try:
+                            value = json.dumps(value)
+                        except TypeError:
+                            trace_logger.warning(f"Could not convert {value} to string, skipping.")
+                            continue
+                    num_tokens += len(encoding.encode(value))
+                if key == "name":
+                    num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+
+    # Tool tokens.
+    oai_tools = convert_tools(tools)
+    for tool in oai_tools:
+        function = tool["function"]
+        tool_tokens = len(encoding.encode(function["name"]))
+        if "description" in function:
+            tool_tokens += len(encoding.encode(function["description"]))
+        tool_tokens -= 2
+        if "parameters" in function:
+            parameters = function["parameters"]
+            if "properties" in parameters:
+                assert isinstance(parameters["properties"], dict)
+                for propertiesKey in parameters["properties"]:  # pyright: ignore
+                    assert isinstance(propertiesKey, str)
+                    tool_tokens += len(encoding.encode(propertiesKey))
+                    v = parameters["properties"][propertiesKey]  # pyright: ignore
+                    for field in v:  # pyright: ignore
+                        if field == "type":
+                            tool_tokens += 2
+                            tool_tokens += len(encoding.encode(v["type"]))  # pyright: ignore
+                        elif field == "description":
+                            tool_tokens += 2
+                            tool_tokens += len(encoding.encode(v["description"]))  # pyright: ignore
+                        elif field == "enum":
+                            tool_tokens -= 3
+                            for o in v["enum"]:  # pyright: ignore
+                                tool_tokens += 3
+                                tool_tokens += len(encoding.encode(o))  # pyright: ignore
+                        else:
+                            trace_logger.warning(f"Not supported field {field}")
+                tool_tokens += 11
+                if len(parameters["properties"]) == 0:  # pyright: ignore
+                    tool_tokens -= 2
+        num_tokens += tool_tokens
+    num_tokens += 12
+    return num_tokens
+
+
 @dataclass
 class CreateParams:
     messages: List[ChatCompletionMessageParam]
@@ -1002,93 +1097,12 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         return self._total_usage

     def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
-        model = self._create_args["model"]
-        try:
-            encoding = tiktoken.encoding_for_model(model)
-        except KeyError:
-            trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
-            encoding = tiktoken.get_encoding("cl100k_base")
-        tokens_per_message = 3
-        tokens_per_name = 1
-        num_tokens = 0
-
-        # Message tokens.
-        for message in messages:
-            num_tokens += tokens_per_message
-            oai_message = to_oai_type(message, prepend_name=self._add_name_prefixes)
-            for oai_message_part in oai_message:
-                for key, value in oai_message_part.items():
-                    if value is None:
-                        continue
-
-                    if isinstance(message, UserMessage) and isinstance(value, list):
-                        typed_message_value = cast(List[ChatCompletionContentPartParam], value)
-
-                        assert len(typed_message_value) == len(
-                            message.content
-                        ), "Mismatch in message content and typed message value"
-
-                        # We need image properties that are only in the original message
-                        for part, content_part in zip(typed_message_value, message.content, strict=False):
-                            if isinstance(content_part, Image):
-                                # TODO: add detail parameter
-                                num_tokens += calculate_vision_tokens(content_part)
-                            elif isinstance(part, str):
-                                num_tokens += len(encoding.encode(part))
-                            else:
-                                try:
-                                    serialized_part = json.dumps(part)
-                                    num_tokens += len(encoding.encode(serialized_part))
-                                except TypeError:
-                                    trace_logger.warning(f"Could not convert {part} to string, skipping.")
-                    else:
-                        if not isinstance(value, str):
-                            try:
-                                value = json.dumps(value)
-                            except TypeError:
-                                trace_logger.warning(f"Could not convert {value} to string, skipping.")
-                                continue
-                        num_tokens += len(encoding.encode(value))
-                    if key == "name":
-                        num_tokens += tokens_per_name
-        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
-
-        # Tool tokens.
-        oai_tools = convert_tools(tools)
-        for tool in oai_tools:
-            function = tool["function"]
-            tool_tokens = len(encoding.encode(function["name"]))
-            if "description" in function:
-                tool_tokens += len(encoding.encode(function["description"]))
-            tool_tokens -= 2
-            if "parameters" in function:
-                parameters = function["parameters"]
-                if "properties" in parameters:
-                    assert isinstance(parameters["properties"], dict)
-                    for propertiesKey in parameters["properties"]:  # pyright: ignore
-                        assert isinstance(propertiesKey, str)
-                        tool_tokens += len(encoding.encode(propertiesKey))
-                        v = parameters["properties"][propertiesKey]  # pyright: ignore
-                        for field in v:  # pyright: ignore
-                            if field == "type":
-                                tool_tokens += 2
-                                tool_tokens += len(encoding.encode(v["type"]))  # pyright: ignore
-                            elif field == "description":
-                                tool_tokens += 2
-                                tool_tokens += len(encoding.encode(v["description"]))  # pyright: ignore
-                            elif field == "enum":
-                                tool_tokens -= 3
-                                for o in v["enum"]:  # pyright: ignore
-                                    tool_tokens += 3
-                                    tool_tokens += len(encoding.encode(o))  # pyright: ignore
-                            else:
-                                trace_logger.warning(f"Not supported field {field}")
-                    tool_tokens += 11
-                    if len(parameters["properties"]) == 0:  # pyright: ignore
-                        tool_tokens -= 2
-            num_tokens += tool_tokens
-        num_tokens += 12
-        return num_tokens
+        return count_tokens_openai(
+            messages,
+            self._create_args["model"],
+            add_name_prefixes=self._add_name_prefixes,
+            tools=tools,
+        )

     def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         token_limit = _model_info.get_token_limit(self._create_args["model"])
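count_tokens_openai mirrors the previous method body, with add_name_prefixes passed explicitly instead of read from the client instance. A sketch of calling it directly; the message content is illustrative:

from autogen_core.models import AssistantMessage, UserMessage
from autogen_ext.models.openai._openai_client import count_tokens_openai

msgs = [
    UserMessage(content="Hello!", source="user"),
    AssistantMessage(content="What can I do for you?", source="assistant"),
]
print(count_tokens_openai(msgs, "gpt-4o", add_name_prefixes=False))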