Mirror of https://github.com/microsoft/autogen.git
Add * before keyword args for ChatCompletionClient (#4822)
add * before keyword args

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>

parent edad1b6065
commit 9a2dbb4fba
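For context: in a Python signature, a bare `*` marks the end of the positional parameters, so everything after it becomes keyword-only and call sites must spell the argument name out. A minimal sketch of the effect, using toy names rather than the real API:

    from typing import Sequence


    def count_tokens(messages: Sequence[str], *, tools: Sequence[str] = ()) -> int:
        # Toy stand-in for the real method; the bare * makes `tools` keyword-only.
        return len(messages) + len(tools)


    count_tokens(["hello"], tools=["search"])  # OK: tools passed by keyword
    count_tokens(["hello"])                    # OK: tools uses its default
    # count_tokens(["hello"], ["search"])      # TypeError: takes 1 positional argument

This is a breaking change for any caller that passed `tools`, `json_output`, or `extra_create_args` positionally; the hunks below update the abstract interface, the OpenAI-based clients, and the replay/test clients in lockstep.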
@@ -29,6 +29,7 @@ class ChatCompletionClient(ABC, ComponentLoader):
     async def create(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         # None means do not override the default
         # A value means to override the client default - often specified in the constructor
@@ -41,6 +42,7 @@ class ChatCompletionClient(ABC, ComponentLoader):
     def create_stream(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         # None means do not override the default
         # A value means to override the client default - often specified in the constructor
@@ -56,10 +58,10 @@ class ChatCompletionClient(ABC, ComponentLoader):
     def total_usage(self) -> RequestUsage: ...

     @abstractmethod
-    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...

     @abstractmethod
-    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...

     @property
     @abstractmethod
@@ -92,6 +92,7 @@ async def test_caller_loop() -> None:
     async def create(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -116,6 +117,7 @@ async def test_caller_loop() -> None:
     def create_stream(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -129,10 +131,10 @@ async def test_caller_loop() -> None:
     def total_usage(self) -> RequestUsage:
         return RequestUsage(prompt_tokens=0, completion_tokens=0)

-    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         return 0

-    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         return 0

     @property
@@ -355,6 +355,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     def __init__(
         self,
         client: Union[AsyncOpenAI, AsyncAzureOpenAI],
+        *,
         create_args: Dict[str, Any],
         model_capabilities: Optional[ModelCapabilities] = None,
     ):
@@ -389,6 +390,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     async def create(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -581,11 +583,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     async def create_stream(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
-        *,
         max_consecutive_empty_chunk_tolerance: int = 0,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:
         """
@@ -800,7 +802,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     def total_usage(self) -> RequestUsage:
         return self._total_usage

-    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         model = self._create_args["model"]
         try:
             encoding = tiktoken.encoding_for_model(model)
@@ -889,9 +891,9 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             num_tokens += 12
         return num_tokens

-    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         token_limit = _model_info.get_token_limit(self._create_args["model"])
-        return token_limit - self.count_tokens(messages, tools)
+        return token_limit - self.count_tokens(messages, tools=tools)

     @property
     def capabilities(self) -> ModelCapabilities:
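Note the second change in this hunk: once `tools` is keyword-only, the client's own internal call `self.count_tokens(messages, tools)` would itself raise a TypeError, so it is rewritten as `self.count_tokens(messages, tools=tools)`. A minimal sketch of that failure mode, with toy names:

    def count_tokens(messages, *, tools=()):
        return len(messages) + len(tools)


    def remaining_tokens(messages, *, tools=(), token_limit=8192):
        # count_tokens(messages, tools) would raise TypeError here;
        # keyword-only parameters must be forwarded by name.
        return token_limit - count_tokens(messages, tools=tools)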
@@ -974,7 +976,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
         client = _openai_client_from_config(copied_args)
         create_args = _create_args_from_config(copied_args)
         self._raw_config: Dict[str, Any] = copied_args
-        super().__init__(client, create_args, model_capabilities)
+        super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities)

     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
@@ -1059,7 +1061,7 @@ class AzureOpenAIChatCompletionClient(
         client = _azure_openai_client_from_config(copied_args)
         create_args = _create_args_from_config(copied_args)
         self._raw_config: Dict[str, Any] = copied_args
-        super().__init__(client, create_args, model_capabilities)
+        super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities)

     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
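The keyword-only `create_args` in `BaseOpenAIChatCompletionClient.__init__` has the same ripple effect on the two subclasses above: their positional `super().__init__(...)` calls must switch to keywords. A minimal sketch with toy classes:

    from typing import Any, Dict, Optional


    class Base:
        def __init__(self, client: str, *, create_args: Dict[str, Any], model_capabilities: Optional[Dict[str, Any]] = None) -> None:
            self.client = client
            self.create_args = create_args
            self.model_capabilities = model_capabilities


    class Derived(Base):
        def __init__(self) -> None:
            # super().__init__("client", {}, None) would now raise TypeError;
            # everything after * has to be passed by name.
            super().__init__("client", create_args={}, model_capabilities=None)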
@@ -128,6 +128,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
     async def create(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -155,6 +156,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
     async def create_stream(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -191,11 +193,11 @@ class ReplayChatCompletionClient(ChatCompletionClient):
     def total_usage(self) -> RequestUsage:
         return self._total_usage

-    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         _, token_count = self._tokenize(messages)
         return token_count

-    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         return max(
             0, self._total_available_tokens - self._total_usage.prompt_tokens - self._total_usage.completion_tokens
         )