Mirror of https://github.com/microsoft/autogen.git
Add * before keyword args for ChatCompletionClient (#4822)

add * before keyword args

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
parent edad1b6065
commit 9a2dbb4fba
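This is a breaking change for positional callers: once a bare `*` appears in a signature, every parameter after it must be passed by keyword. A minimal sketch of the effect at a call site, assuming `client` is any `ChatCompletionClient` implementation and that `messages` and `my_tools` are built elsewhere:

    # A minimal sketch, assuming `client` is any ChatCompletionClient
    # implementation; `ask`, `messages`, and `my_tools` are invented names.
    async def ask(client, messages, my_tools):
        # Before this change, tools could slip in positionally:
        #     await client.create(messages, my_tools)
        # After it, that call raises TypeError, because the bare `*`
        # makes every parameter after `messages` keyword-only:
        return await client.create(messages, tools=my_tools)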
@@ -29,6 +29,7 @@ class ChatCompletionClient(ABC, ComponentLoader):
     async def create(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         # None means do not override the default
         # A value means to override the client default - often specified in the constructor
@@ -41,6 +42,7 @@ class ChatCompletionClient(ABC, ComponentLoader):
     def create_stream(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         # None means do not override the default
         # A value means to override the client default - often specified in the constructor
@@ -56,10 +58,10 @@ class ChatCompletionClient(ABC, ComponentLoader):
     def total_usage(self) -> RequestUsage: ...

     @abstractmethod
-    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...

     @abstractmethod
-    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...

     @property
     @abstractmethod
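For readers unfamiliar with the bare `*` marker used throughout this diff: it consumes no argument itself, it only ends the positional section of the signature. A standalone illustration with invented names:

    def count(messages, *, tools=()):
        # `tools` can only be supplied as count(..., tools=...).
        return len(messages) + len(tools)

    count(["hi"], tools=["search"])   # ok -> 2
    count(["hi"], ["search"])         # TypeError: count() takes 1 positional argument but 2 were given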
@@ -92,6 +92,7 @@ async def test_caller_loop() -> None:
     async def create(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -116,6 +117,7 @@ async def test_caller_loop() -> None:
     def create_stream(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -129,10 +131,10 @@ async def test_caller_loop() -> None:
     def total_usage(self) -> RequestUsage:
         return RequestUsage(prompt_tokens=0, completion_tokens=0)

-    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         return 0

-    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         return 0

     @property
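The test's stub client mirrors the abstract interface, so it is updated in lockstep: an override should keep the same keyword-only shape as the base method, otherwise type checkers flag it as an incompatible override. A reduced sketch of the pattern, with invented `Base`/`Stub` names:

    from typing import Sequence

    class Base:
        def count_tokens(self, messages: Sequence[str], *, tools: Sequence[str] = ()) -> int:
            raise NotImplementedError

    class Stub(Base):
        # Keeping the `*` makes this a signature-compatible override.
        def count_tokens(self, messages: Sequence[str], *, tools: Sequence[str] = ()) -> int:
            return 0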
@@ -355,6 +355,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     def __init__(
         self,
         client: Union[AsyncOpenAI, AsyncAzureOpenAI],
+        *,
         create_args: Dict[str, Any],
         model_capabilities: Optional[ModelCapabilities] = None,
     ):
@@ -389,6 +390,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     async def create(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -581,11 +583,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     async def create_stream(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
-        *,
         max_consecutive_empty_chunk_tolerance: int = 0,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:
         """
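Note why the pre-existing `*` before `max_consecutive_empty_chunk_tolerance` is deleted rather than kept: Python allows only one bare `*` per signature, so a second one would be a SyntaxError. Moving the marker up before `tools` still leaves `max_consecutive_empty_chunk_tolerance` keyword-only, because everything after the marker is. A quick illustration with invented names:

    def stream(messages, *, tools=(), tolerance=0):
        return (messages, tools, tolerance)

    stream([], tolerance=3)   # ok: `tolerance` is still keyword-only
    stream([], (), 3)         # TypeError: stream() takes 1 positional argument but 3 were given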
@@ -800,7 +802,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     def total_usage(self) -> RequestUsage:
         return self._total_usage

-    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         model = self._create_args["model"]
         try:
             encoding = tiktoken.encoding_for_model(model)
@@ -889,9 +891,9 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             num_tokens += 12
         return num_tokens

-    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         token_limit = _model_info.get_token_limit(self._create_args["model"])
-        return token_limit - self.count_tokens(messages, tools)
+        return token_limit - self.count_tokens(messages, tools=tools)

     @property
     def capabilities(self) -> ModelCapabilities:
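The second hunk above also shows the knock-on effect inside the class itself: with `tools` keyword-only, the internal `self.count_tokens(messages, tools)` call would raise a TypeError at runtime, so it becomes `tools=tools`. Every external caller needs the same audit, as in this hypothetical helper:

    # Hypothetical helper; shows the call-site shape the change requires.
    def fits_in_context(client, messages, tools):
        # `client.remaining_tokens(messages, tools)` would now raise a
        # TypeError, so `tools` must be passed by keyword.
        return client.remaining_tokens(messages, tools=tools) > 0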
@@ -974,7 +976,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
         client = _openai_client_from_config(copied_args)
         create_args = _create_args_from_config(copied_args)
         self._raw_config: Dict[str, Any] = copied_args
-        super().__init__(client, create_args, model_capabilities)
+        super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities)

     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
@@ -1059,7 +1061,7 @@ class AzureOpenAIChatCompletionClient(
         client = _azure_openai_client_from_config(copied_args)
         create_args = _create_args_from_config(copied_args)
         self._raw_config: Dict[str, Any] = copied_args
-        super().__init__(client, create_args, model_capabilities)
+        super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities)

     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
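Because `BaseOpenAIChatCompletionClient.__init__` now marks everything after `client` as keyword-only, both concrete subclasses must spell out their `super().__init__` arguments, which is what the two hunks above do. A reduced sketch of the constraint, with invented class names:

    class Base:
        def __init__(self, client, *, create_args, capabilities=None):
            self._client = client
            self._create_args = create_args
            self._capabilities = capabilities

    class Concrete(Base):
        def __init__(self, client, create_args):
            # super().__init__(client, create_args) would now raise a
            # TypeError; keyword arguments are required past `client`.
            super().__init__(client=client, create_args=create_args)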
@@ -128,6 +128,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
     async def create(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -155,6 +156,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
     async def create_stream(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -191,11 +193,11 @@ class ReplayChatCompletionClient(ChatCompletionClient):
     def total_usage(self) -> RequestUsage:
         return self._total_usage

-    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         _, token_count = self._tokenize(messages)
         return token_count

-    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         return max(
             0, self._total_available_tokens - self._total_usage.prompt_tokens - self._total_usage.completion_tokens
         )