mirror of https://github.com/microsoft/autogen.git
synced 2025-11-02 02:40:21 +00:00

[feat] token-limited message context (#6087)

This commit is contained in:
parent 29485ef85b
commit 7487687cdc
@@ -15,6 +15,7 @@ from autogen_core.model_context import (
     BufferedChatCompletionContext,
     HeadAndTailChatCompletionContext,
     UnboundedChatCompletionContext,
+    TokenLimitedChatCompletionContext,
 )


@@ -104,6 +105,7 @@ async def test_chat_completion_context_declarative() -> None:
     unbounded_context = UnboundedChatCompletionContext()
     buffered_context = BufferedChatCompletionContext(buffer_size=5)
     head_tail_context = HeadAndTailChatCompletionContext(head_size=3, tail_size=2)
+    token_limited_context = TokenLimitedChatCompletionContext(token_limit=5, model="gpt-4o")

     # Test serialization
     unbounded_config = unbounded_context.dump_component()
@@ -118,6 +120,11 @@ async def test_chat_completion_context_declarative() -> None:
     assert head_tail_config.config["head_size"] == 3
     assert head_tail_config.config["tail_size"] == 2

+    token_limited_config = token_limited_context.dump_component()
+    assert token_limited_config.provider == "autogen_core.model_context.TokenLimitedChatCompletionContext"
+    assert token_limited_config.config["token_limit"] == 5
+    assert token_limited_config.config["model"] == "gpt-4o"
+
     # Test deserialization
     loaded_unbounded = ComponentLoader.load_component(unbounded_config, UnboundedChatCompletionContext)
     assert isinstance(loaded_unbounded, UnboundedChatCompletionContext)
@@ -129,3 +136,6 @@ async def test_chat_completion_context_declarative() -> None:
     loaded_head_tail = ComponentLoader.load_component(head_tail_config, HeadAndTailChatCompletionContext)

     assert isinstance(loaded_head_tail, HeadAndTailChatCompletionContext)
+
+    loaded_token_limited = ComponentLoader.load_component(token_limited_config, TokenLimitedChatCompletionContext)
+    assert isinstance(loaded_token_limited, TokenLimitedChatCompletionContext)
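For illustration only, not part of the diff: the declarative round trip exercised by this test could be run standalone roughly as below. The `ComponentLoader` import path is assumed; the test file imports it outside the hunks shown here.

# Sketch of the declarative round trip for the new context (import paths assumed).
from autogen_core import ComponentLoader
from autogen_core.model_context import TokenLimitedChatCompletionContext

ctx = TokenLimitedChatCompletionContext(token_limit=5, model="gpt-4o")
config = ctx.dump_component()  # serializable component config
print(config.provider)  # "autogen_core.model_context.TokenLimitedChatCompletionContext"

restored = ComponentLoader.load_component(config, TokenLimitedChatCompletionContext)
assert isinstance(restored, TokenLimitedChatCompletionContext)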
@@ -1,4 +1,5 @@
 from ._buffered_chat_completion_context import BufferedChatCompletionContext
+from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
 from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
 from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
 from ._unbounded_chat_completion_context import (
@@ -10,5 +11,6 @@ __all__ = [
     "ChatCompletionContextState",
     "UnboundedChatCompletionContext",
     "BufferedChatCompletionContext",
+    "TokenLimitedChatCompletionContext",
     "HeadAndTailChatCompletionContext",
 ]
@@ -0,0 +1,83 @@
+from typing import List, Sequence
+from autogen_core.tools import Tool, ToolSchema
+
+from pydantic import BaseModel
+from typing_extensions import Self
+import tiktoken
+
+from .._component_config import Component
+from ..models import FunctionExecutionResultMessage, LLMMessage
+from ._chat_completion_context import ChatCompletionContext
+
+from autogen_ext.models.ollama._ollama_client import count_tokens_ollama
+from autogen_ext.models.openai._openai_client import count_tokens_openai
+
+
+class TokenLimitedChatCompletionContextConfig(BaseModel):
+    token_limit: int
+    model: str
+    initial_messages: List[LLMMessage] | None = None
+
+
+class TokenLimitedChatCompletionContext(ChatCompletionContext, Component[TokenLimitedChatCompletionContextConfig]):
+    """A token-based chat completion context that maintains a view of the context
+    up to a token limit. The token limit is set at initialization.
+
+    Args:
+        token_limit (int): Max tokens for context.
+        initial_messages (List[LLMMessage] | None): The initial messages.
+    """
+
+    component_config_schema = TokenLimitedChatCompletionContextConfig
+    component_provider_override = "autogen_core.model_context.TokenLimitedChatCompletionContext"
+
+    def __init__(self, token_limit: int, model: str, initial_messages: List[LLMMessage] | None = None) -> None:
+        super().__init__(initial_messages)
+        if token_limit <= 0:
+            raise ValueError("token_limit must be greater than 0.")
+        self._token_limit = token_limit
+        self._model = model
+
+    async def get_messages(self) -> List[LLMMessage]:
+        """Get at most `token_limit` tokens in recent messages."""
+        token_count = count_chat_tokens(self._messages, self._model)
+        while token_count > self._token_limit:
+            middle_index = len(self._messages) // 2
+            self._messages.pop(middle_index)
+            token_count = count_chat_tokens(self._messages, self._model)
+        messages = self._messages
+        # Handle the case where the first message is a function execution result message.
+        if messages and isinstance(messages[0], FunctionExecutionResultMessage):
+            # Remove the first message from the list.
+            messages = messages[1:]
+        return messages
+
+    def _to_config(self) -> TokenLimitedChatCompletionContextConfig:
+        return TokenLimitedChatCompletionContextConfig(
+            token_limit=self._token_limit, model=self._model, initial_messages=self._messages
+        )
+
+    @classmethod
+    def _from_config(cls, config: TokenLimitedChatCompletionContextConfig) -> Self:
+        return cls(**config.model_dump())
+
+
+def count_chat_tokens(
+    messages: Sequence[LLMMessage], model: str = "gpt-4o", *, tools: Sequence[Tool | ToolSchema] = []
+) -> int:
+    """Count tokens for a list of messages using the appropriate client based on the model."""
+    # Check if the model is an OpenAI model
+    if "openai" in model.lower():
+        return count_tokens_openai(messages, model)
+
+    # Check if the model is an Ollama model
+    elif "llama" in model.lower():
+        return count_tokens_ollama(messages, model)
+
+    # Fallback to cl100k_base encoding if the model is unrecognized
+    else:
+        encoding = tiktoken.get_encoding("cl100k_base")
+        total_tokens = 0
+        for message in messages:
+            total_tokens += len(encoding.encode(str(message.content)))
+        return total_tokens
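For illustration only, not part of the commit: a minimal usage sketch of the new context, mirroring the tests added further down in this diff. When the total reported by `count_chat_tokens` exceeds `token_limit`, `get_messages()` repeatedly drops the middle message until the view fits; the model name, limit, and message strings below are taken from those tests.

# Minimal usage sketch (assumed to run as a standalone script).
import asyncio

from autogen_core.model_context import TokenLimitedChatCompletionContext
from autogen_core.models import AssistantMessage, UserMessage


async def main() -> None:
    context = TokenLimitedChatCompletionContext(token_limit=20, model="gpt-4o")
    await context.add_message(UserMessage(content="Hello!", source="user"))
    await context.add_message(AssistantMessage(content="What can I do for you?", source="assistant"))
    await context.add_message(UserMessage(content="Tell what are some fun things to do in seattle.", source="user"))

    # With such a small limit, the middle message is dropped before the view is returned.
    view = await context.get_messages()
    print(len(view))  # 2, per the test added by this commit


asyncio.run(main())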
@@ -5,8 +5,9 @@ from autogen_core.model_context import (
     BufferedChatCompletionContext,
     HeadAndTailChatCompletionContext,
     UnboundedChatCompletionContext,
+    TokenLimitedChatCompletionContext,
 )
-from autogen_core.models import AssistantMessage, LLMMessage, UserMessage
+from autogen_core.models import AssistantMessage, LLMMessage, UserMessage, FunctionExecutionResultMessage


 @pytest.mark.asyncio
@@ -104,3 +105,82 @@ async def test_unbounded_model_context() -> None:
     retrieved = await model_context.get_messages()
     assert len(retrieved) == 3
     assert retrieved == messages
+
+
+@pytest.mark.asyncio
+async def test_token_limited_model_context_openai() -> None:
+    model_context = TokenLimitedChatCompletionContext(token_limit=20, model="gpt-4o")
+    messages: List[LLMMessage] = [
+        UserMessage(content="Hello!", source="user"),
+        AssistantMessage(content="What can I do for you?", source="assistant"),
+        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
+    ]
+    for msg in messages:
+        await model_context.add_message(msg)
+
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 2  # Token limit set very low, will remove 1 of the messages
+    assert retrieved != messages  # Will not be equal to the original messages
+
+    await model_context.clear()
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 0
+
+    # Test saving and loading state.
+    for msg in messages:
+        await model_context.add_message(msg)
+    state = await model_context.save_state()
+    await model_context.clear()
+    await model_context.load_state(state)
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 2
+    assert retrieved != messages
+
+
+@pytest.mark.asyncio
+async def test_token_limited_model_context_llama() -> None:
+    model_context = TokenLimitedChatCompletionContext(token_limit=20, model="llama2-7b")
+    messages: List[LLMMessage] = [
+        UserMessage(content="Hello!", source="user"),
+        AssistantMessage(content="What can I do for you?", source="assistant"),
+        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
+    ]
+    for msg in messages:
+        await model_context.add_message(msg)
+
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 1  # Token limit set very low, will remove two of the messages
+    assert retrieved != messages  # Will not be equal to the original messages
+
+    await model_context.clear()
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 0
+
+    # Test saving and loading state.
+    for msg in messages:
+        await model_context.add_message(msg)
+    state = await model_context.save_state()
+    await model_context.clear()
+    await model_context.load_state(state)
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 1
+    assert retrieved != messages
+
+
+@pytest.mark.asyncio
+async def test_token_limited_model_context_openai_with_function_result() -> None:
+    model_context = TokenLimitedChatCompletionContext(token_limit=1000, model="gpt-4o")
+    messages: List[LLMMessage] = [
+        FunctionExecutionResultMessage(content=[]),
+        UserMessage(content="Hello!", source="user"),
+        AssistantMessage(content="What can I do for you?", source="assistant"),
+        UserMessage(content="Tell what are some fun things to do in seattle.", source="user"),
+    ]
+    for msg in messages:
+        await model_context.add_message(msg)
+
+    retrieved = await model_context.get_messages()
+    assert len(retrieved) == 3  # The leading function execution result message is dropped
+    assert type(retrieved[0]) == UserMessage  # Function result should be removed
+    assert type(retrieved[1]) == AssistantMessage
+    assert type(retrieved[2]) == UserMessage
@@ -378,6 +378,66 @@ def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
     return KNOWN_STOP_MAPPINGS.get(stop_reason, "unknown")


+# TODO: probably needs work
+def count_tokens_ollama(messages: Sequence[LLMMessage], model: str, *, tools: Sequence[Tool | ToolSchema] = []) -> int:
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    tokens_per_message = 3
+    num_tokens = 0
+
+    # Message tokens.
+    for message in messages:
+        num_tokens += tokens_per_message
+        ollama_message = to_ollama_type(message)
+        for ollama_message_part in ollama_message:
+            if isinstance(message.content, Image):
+                num_tokens += calculate_vision_tokens(message.content)
+            elif ollama_message_part.content is not None:
+                num_tokens += len(encoding.encode(ollama_message_part.content))
+    # TODO: every model family has its own message sequence.
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+
+    # Tool tokens.
+    ollama_tools = convert_tools(tools)
+    for tool in ollama_tools:
+        function = tool["function"]
+        tool_tokens = len(encoding.encode(function["name"]))
+        if "description" in function:
+            tool_tokens += len(encoding.encode(function["description"]))
+        tool_tokens -= 2
+        if "parameters" in function:
+            parameters = function["parameters"]
+            if "properties" in parameters:
+                assert isinstance(parameters["properties"], dict)
+                for propertiesKey in parameters["properties"]:  # pyright: ignore
+                    assert isinstance(propertiesKey, str)
+                    tool_tokens += len(encoding.encode(propertiesKey))
+                    v = parameters["properties"][propertiesKey]  # pyright: ignore
+                    for field in v:  # pyright: ignore
+                        if field == "type":
+                            tool_tokens += 2
+                            tool_tokens += len(encoding.encode(v["type"]))  # pyright: ignore
+                        elif field == "description":
+                            tool_tokens += 2
+                            tool_tokens += len(encoding.encode(v["description"]))  # pyright: ignore
+                        elif field == "enum":
+                            tool_tokens -= 3
+                            for o in v["enum"]:  # pyright: ignore
+                                tool_tokens += 3
+                                tool_tokens += len(encoding.encode(o))  # pyright: ignore
+                        else:
+                            trace_logger.warning(f"Not supported field {field}")
+                tool_tokens += 11
+                if len(parameters["properties"]) == 0:  # pyright: ignore
+                    tool_tokens -= 2
+        num_tokens += tool_tokens
+    num_tokens += 12
+    return num_tokens
+
+
 @dataclass
 class CreateParams:
     messages: Sequence[Message]
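Not part of the diff: hoisting this logic out of the client method into a module-level `count_tokens_ollama` is what lets the new `TokenLimitedChatCompletionContext` import it directly. A rough sketch of a direct call follows; the model name is arbitrary, so tiktoken should fall back to the cl100k_base encoding, and no running Ollama server should be needed since only local conversions and tiktoken are involved.

# Illustrative only; "llama3" is an arbitrary model name chosen for this sketch.
from autogen_core.models import UserMessage
from autogen_ext.models.ollama._ollama_client import count_tokens_ollama

messages = [UserMessage(content="Hello!", source="user")]
print(count_tokens_ollama(messages, "llama3"))  # rough token estimate for the message list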
@@ -796,65 +856,8 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
     def total_usage(self) -> RequestUsage:
         return self._total_usage

-    # TODO: probably needs work
     def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
-        model = self._create_args["model"]
-        try:
-            encoding = tiktoken.encoding_for_model(model)
-        except KeyError:
-            trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
-            encoding = tiktoken.get_encoding("cl100k_base")
-        tokens_per_message = 3
-        num_tokens = 0
-
-        # Message tokens.
-        for message in messages:
-            num_tokens += tokens_per_message
-            ollama_message = to_ollama_type(message)
-            for ollama_message_part in ollama_message:
-                if isinstance(message.content, Image):
-                    num_tokens += calculate_vision_tokens(message.content)
-                elif ollama_message_part.content is not None:
-                    num_tokens += len(encoding.encode(ollama_message_part.content))
-        # TODO: every model family has its own message sequence.
-        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
-
-        # Tool tokens.
-        ollama_tools = convert_tools(tools)
-        for tool in ollama_tools:
-            function = tool["function"]
-            tool_tokens = len(encoding.encode(function["name"]))
-            if "description" in function:
-                tool_tokens += len(encoding.encode(function["description"]))
-            tool_tokens -= 2
-            if "parameters" in function:
-                parameters = function["parameters"]
-                if "properties" in parameters:
-                    assert isinstance(parameters["properties"], dict)
-                    for propertiesKey in parameters["properties"]:  # pyright: ignore
-                        assert isinstance(propertiesKey, str)
-                        tool_tokens += len(encoding.encode(propertiesKey))
-                        v = parameters["properties"][propertiesKey]  # pyright: ignore
-                        for field in v:  # pyright: ignore
-                            if field == "type":
-                                tool_tokens += 2
-                                tool_tokens += len(encoding.encode(v["type"]))  # pyright: ignore
-                            elif field == "description":
-                                tool_tokens += 2
-                                tool_tokens += len(encoding.encode(v["description"]))  # pyright: ignore
-                            elif field == "enum":
-                                tool_tokens -= 3
-                                for o in v["enum"]:  # pyright: ignore
-                                    tool_tokens += 3
-                                    tool_tokens += len(encoding.encode(o))  # pyright: ignore
-                            else:
-                                trace_logger.warning(f"Not supported field {field}")
-                    tool_tokens += 11
-                    if len(parameters["properties"]) == 0:  # pyright: ignore
-                        tool_tokens -= 2
-            num_tokens += tool_tokens
-        num_tokens += 12
-        return num_tokens
+        return count_tokens_ollama(messages, self._create_args["model"], tools=tools)

     def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         token_limit = _model_info.get_token_limit(self._create_args["model"])
@@ -373,6 +373,101 @@ def assert_valid_name(name: str) -> str:
     return name


+def count_tokens_openai(
+    messages: Sequence[LLMMessage],
+    model: str,
+    *,
+    add_name_prefixes: bool = False,
+    tools: Sequence[Tool | ToolSchema] = [],
+) -> int:
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    tokens_per_message = 3
+    tokens_per_name = 1
+    num_tokens = 0
+
+    # Message tokens.
+    for message in messages:
+        num_tokens += tokens_per_message
+        oai_message = to_oai_type(message, prepend_name=add_name_prefixes)
+        for oai_message_part in oai_message:
+            for key, value in oai_message_part.items():
+                if value is None:
+                    continue
+
+                if isinstance(message, UserMessage) and isinstance(value, list):
+                    typed_message_value = cast(List[ChatCompletionContentPartParam], value)
+
+                    assert len(typed_message_value) == len(
+                        message.content
+                    ), "Mismatch in message content and typed message value"
+
+                    # We need image properties that are only in the original message
+                    for part, content_part in zip(typed_message_value, message.content, strict=False):
+                        if isinstance(content_part, Image):
+                            # TODO: add detail parameter
+                            num_tokens += calculate_vision_tokens(content_part)
+                        elif isinstance(part, str):
+                            num_tokens += len(encoding.encode(part))
+                        else:
+                            try:
+                                serialized_part = json.dumps(part)
+                                num_tokens += len(encoding.encode(serialized_part))
+                            except TypeError:
+                                trace_logger.warning(f"Could not convert {part} to string, skipping.")
+                else:
+                    if not isinstance(value, str):
+                        try:
+                            value = json.dumps(value)
+                        except TypeError:
+                            trace_logger.warning(f"Could not convert {value} to string, skipping.")
+                            continue
+                    num_tokens += len(encoding.encode(value))
+                    if key == "name":
+                        num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+
+    # Tool tokens.
+    oai_tools = convert_tools(tools)
+    for tool in oai_tools:
+        function = tool["function"]
+        tool_tokens = len(encoding.encode(function["name"]))
+        if "description" in function:
+            tool_tokens += len(encoding.encode(function["description"]))
+        tool_tokens -= 2
+        if "parameters" in function:
+            parameters = function["parameters"]
+            if "properties" in parameters:
+                assert isinstance(parameters["properties"], dict)
+                for propertiesKey in parameters["properties"]:  # pyright: ignore
+                    assert isinstance(propertiesKey, str)
+                    tool_tokens += len(encoding.encode(propertiesKey))
+                    v = parameters["properties"][propertiesKey]  # pyright: ignore
+                    for field in v:  # pyright: ignore
+                        if field == "type":
+                            tool_tokens += 2
+                            tool_tokens += len(encoding.encode(v["type"]))  # pyright: ignore
+                        elif field == "description":
+                            tool_tokens += 2
+                            tool_tokens += len(encoding.encode(v["description"]))  # pyright: ignore
+                        elif field == "enum":
+                            tool_tokens -= 3
+                            for o in v["enum"]:  # pyright: ignore
+                                tool_tokens += 3
+                                tool_tokens += len(encoding.encode(o))  # pyright: ignore
+                        else:
+                            trace_logger.warning(f"Not supported field {field}")
+                tool_tokens += 11
+                if len(parameters["properties"]) == 0:  # pyright: ignore
+                    tool_tokens -= 2
+        num_tokens += tool_tokens
+    num_tokens += 12
+    return num_tokens
+
+
 @dataclass
 class CreateParams:
     messages: List[ChatCompletionMessageParam]
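Not part of the diff: a hedged sketch of calling the new module-level `count_tokens_openai` directly, including the `tools` parameter. The tool schema below is made up for illustration; only the `ToolSchema` shape from `autogen_core.tools` and the function signature added in this hunk are relied on.

# Illustrative only; the get_weather schema is hypothetical.
from autogen_core.models import UserMessage
from autogen_core.tools import ToolSchema
from autogen_ext.models.openai._openai_client import count_tokens_openai

schema: ToolSchema = {
    "name": "get_weather",
    "description": "Look up the weather for a city.",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string", "description": "City name"}},
        "required": ["city"],
    },
}
msgs = [UserMessage(content="What's the weather in Seattle?", source="user")]
print(count_tokens_openai(msgs, "gpt-4o", tools=[schema]))  # message plus tool token estimate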
@@ -1002,93 +1097,12 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         return self._total_usage

     def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
-        model = self._create_args["model"]
-        try:
-            encoding = tiktoken.encoding_for_model(model)
-        except KeyError:
-            trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
-            encoding = tiktoken.get_encoding("cl100k_base")
-        tokens_per_message = 3
-        tokens_per_name = 1
-        num_tokens = 0
-
-        # Message tokens.
-        for message in messages:
-            num_tokens += tokens_per_message
-            oai_message = to_oai_type(message, prepend_name=self._add_name_prefixes)
-            for oai_message_part in oai_message:
-                for key, value in oai_message_part.items():
-                    if value is None:
-                        continue
-
-                    if isinstance(message, UserMessage) and isinstance(value, list):
-                        typed_message_value = cast(List[ChatCompletionContentPartParam], value)
-
-                        assert len(typed_message_value) == len(
-                            message.content
-                        ), "Mismatch in message content and typed message value"
-
-                        # We need image properties that are only in the original message
-                        for part, content_part in zip(typed_message_value, message.content, strict=False):
-                            if isinstance(content_part, Image):
-                                # TODO: add detail parameter
-                                num_tokens += calculate_vision_tokens(content_part)
-                            elif isinstance(part, str):
-                                num_tokens += len(encoding.encode(part))
-                            else:
-                                try:
-                                    serialized_part = json.dumps(part)
-                                    num_tokens += len(encoding.encode(serialized_part))
-                                except TypeError:
-                                    trace_logger.warning(f"Could not convert {part} to string, skipping.")
-                    else:
-                        if not isinstance(value, str):
-                            try:
-                                value = json.dumps(value)
-                            except TypeError:
-                                trace_logger.warning(f"Could not convert {value} to string, skipping.")
-                                continue
-                        num_tokens += len(encoding.encode(value))
-                        if key == "name":
-                            num_tokens += tokens_per_name
-        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
-
-        # Tool tokens.
-        oai_tools = convert_tools(tools)
-        for tool in oai_tools:
-            function = tool["function"]
-            tool_tokens = len(encoding.encode(function["name"]))
-            if "description" in function:
-                tool_tokens += len(encoding.encode(function["description"]))
-            tool_tokens -= 2
-            if "parameters" in function:
-                parameters = function["parameters"]
-                if "properties" in parameters:
-                    assert isinstance(parameters["properties"], dict)
-                    for propertiesKey in parameters["properties"]:  # pyright: ignore
-                        assert isinstance(propertiesKey, str)
-                        tool_tokens += len(encoding.encode(propertiesKey))
-                        v = parameters["properties"][propertiesKey]  # pyright: ignore
-                        for field in v:  # pyright: ignore
-                            if field == "type":
-                                tool_tokens += 2
-                                tool_tokens += len(encoding.encode(v["type"]))  # pyright: ignore
-                            elif field == "description":
-                                tool_tokens += 2
-                                tool_tokens += len(encoding.encode(v["description"]))  # pyright: ignore
-                            elif field == "enum":
-                                tool_tokens -= 3
-                                for o in v["enum"]:  # pyright: ignore
-                                    tool_tokens += 3
-                                    tool_tokens += len(encoding.encode(o))  # pyright: ignore
-                            else:
-                                trace_logger.warning(f"Not supported field {field}")
-                    tool_tokens += 11
-                    if len(parameters["properties"]) == 0:  # pyright: ignore
-                        tool_tokens -= 2
-            num_tokens += tool_tokens
-        num_tokens += 12
-        return num_tokens
+        return count_tokens_openai(
+            messages,
+            self._create_args["model"],
+            add_name_prefixes=self._add_name_prefixes,
+            tools=tools,
+        )

     def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         token_limit = _model_info.get_token_limit(self._create_args["model"])