Mirror of https://github.com/microsoft/autogen.git (synced 2025-11-03 19:29:52 +00:00)
	Add * before keyword args for ChatCompletionClient (#4822)
add * before keyword args

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
parent edad1b6065
commit 9a2dbb4fba
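For readers unfamiliar with the bare `*` marker this commit adds: every parameter declared after it becomes keyword-only, so positional call sites stop working. A minimal standalone sketch (toy names, not from this commit) of the behavior change:

    from typing import Sequence

    def count_tokens(messages: Sequence[str], *, tools: Sequence[str] = ()) -> int:
        # Everything declared after the bare * must be passed by keyword.
        return len(messages) + len(tools)

    print(count_tokens(["hello"], tools=["search"]))  # OK: keyword form
    # count_tokens(["hello"], ["search"])  # TypeError: takes 1 positional argument but 2 were given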
@@ -29,6 +29,7 @@ class ChatCompletionClient(ABC, ComponentLoader):
     async def create(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         # None means do not override the default
         # A value means to override the client default - often specified in the constructor
@@ -41,6 +42,7 @@ class ChatCompletionClient(ABC, ComponentLoader):
     def create_stream(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         # None means do not override the default
         # A value means to override the client default - often specified in the constructor
@@ -56,10 +58,10 @@ class ChatCompletionClient(ABC, ComponentLoader):
     def total_usage(self) -> RequestUsage: ...
 
     @abstractmethod
-    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...
 
     @abstractmethod
-    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...
 
     @property
     @abstractmethod
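Caller-side effect of the hunk above, shown as a hedged sketch: code written against the ChatCompletionClient interface must now name the `tools` argument. `StubClient` below is a hypothetical stand-in, not the real class.

    from typing import Sequence

    class StubClient:
        """Hypothetical stand-in mirroring the post-change signatures."""

        def count_tokens(self, messages: Sequence[str], *, tools: Sequence[str] = ()) -> int:
            return sum(len(m.split()) for m in messages)

        def remaining_tokens(self, messages: Sequence[str], *, tools: Sequence[str] = ()) -> int:
            return 4096 - self.count_tokens(messages, tools=tools)

    client = StubClient()
    print(client.remaining_tokens(["hello world"], tools=[]))  # keyword-only call
    # client.count_tokens(["hello world"], [])  # TypeError: positional `tools` rejected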
@@ -92,6 +92,7 @@ async def test_caller_loop() -> None:
         async def create(
             self,
             messages: Sequence[LLMMessage],
+            *,
             tools: Sequence[Tool | ToolSchema] = [],
             json_output: Optional[bool] = None,
             extra_create_args: Mapping[str, Any] = {},
@@ -116,6 +117,7 @@ async def test_caller_loop() -> None:
         def create_stream(
             self,
             messages: Sequence[LLMMessage],
+            *,
             tools: Sequence[Tool | ToolSchema] = [],
             json_output: Optional[bool] = None,
             extra_create_args: Mapping[str, Any] = {},
@@ -129,10 +131,10 @@ async def test_caller_loop() -> None:
         def total_usage(self) -> RequestUsage:
             return RequestUsage(prompt_tokens=0, completion_tokens=0)
 
-        def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+        def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
             return 0
 
-        def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+        def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
             return 0
 
         @property
@@ -355,6 +355,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     def __init__(
         self,
         client: Union[AsyncOpenAI, AsyncAzureOpenAI],
+        *,
         create_args: Dict[str, Any],
         model_capabilities: Optional[ModelCapabilities] = None,
     ):
@@ -389,6 +390,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     async def create(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -581,11 +583,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     async def create_stream(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
-        *,
         max_consecutive_empty_chunk_tolerance: int = 0,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:
         """
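Note that this hunk moves the bare `*` rather than only adding one: `max_consecutive_empty_chunk_tolerance` was already keyword-only, and after the change `tools`, `json_output`, `extra_create_args`, and `cancellation_token` are too. A hedged, self-contained sketch of calling a signature shaped this way (toy class, not the real client):

    import asyncio
    from typing import AsyncGenerator

    class StubStreamingClient:
        """Hypothetical stand-in mirroring the post-change signature shape."""

        def create_stream(
            self,
            messages: list[str],
            *,
            tools: list[str] = [],
            max_consecutive_empty_chunk_tolerance: int = 0,
        ) -> AsyncGenerator[str, None]:
            async def _gen() -> AsyncGenerator[str, None]:
                for m in messages:
                    yield m
            return _gen()

    async def main() -> None:
        client = StubStreamingClient()
        # Everything after `messages` must now be named.
        async for chunk in client.create_stream(["hello"], tools=[]):
            print(chunk)

    asyncio.run(main())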
@@ -800,7 +802,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     def total_usage(self) -> RequestUsage:
         return self._total_usage
 
-    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         model = self._create_args["model"]
         try:
             encoding = tiktoken.encoding_for_model(model)
@@ -889,9 +891,9 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         num_tokens += 12
         return num_tokens
 
-    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         token_limit = _model_info.get_token_limit(self._create_args["model"])
-        return token_limit - self.count_tokens(messages, tools)
+        return token_limit - self.count_tokens(messages, tools=tools)
 
     @property
     def capabilities(self) -> ModelCapabilities:
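The second change in this hunk is the paired call-site fix: once `tools` is keyword-only on `count_tokens`, the class's own `remaining_tokens` must call `self.count_tokens(messages, tools=tools)`; the old positional form would raise a TypeError at runtime.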
@@ -974,7 +976,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
         client = _openai_client_from_config(copied_args)
         create_args = _create_args_from_config(copied_args)
         self._raw_config: Dict[str, Any] = copied_args
-        super().__init__(client, create_args, model_capabilities)
+        super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities)
 
     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
@@ -1059,7 +1061,7 @@ class AzureOpenAIChatCompletionClient(
         client = _azure_openai_client_from_config(copied_args)
         create_args = _create_args_from_config(copied_args)
         self._raw_config: Dict[str, Any] = copied_args
-        super().__init__(client, create_args, model_capabilities)
+        super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities)
 
     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
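Both subclasses switch their `super().__init__` calls to keyword form because the base `__init__` now declares `create_args` and `model_capabilities` after a bare `*`. A minimal standalone sketch of the same pattern, with hypothetical `Base`/`Sub` names:

    from typing import Any, Dict, Optional

    class Base:
        """Hypothetical parent mirroring the post-change __init__ shape."""

        def __init__(self, client: object, *, create_args: Dict[str, Any],
                     capabilities: Optional[Any] = None) -> None:
            self._client = client
            self._create_args = create_args
            self._capabilities = capabilities

    class Sub(Base):
        def __init__(self) -> None:
            # super().__init__(object(), {"model": "stub"}) would now raise
            # TypeError; everything after `client` must be named.
            super().__init__(object(), create_args={"model": "stub"})

    Sub()  # constructs fine with the keyword form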
@@ -128,6 +128,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
     async def create(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -155,6 +156,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
     async def create_stream(
         self,
         messages: Sequence[LLMMessage],
+        *,
         tools: Sequence[Tool | ToolSchema] = [],
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
@@ -191,11 +193,11 @@ class ReplayChatCompletionClient(ChatCompletionClient):
     def total_usage(self) -> RequestUsage:
         return self._total_usage
 
-    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         _, token_count = self._tokenize(messages)
         return token_count
 
-    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
         return max(
             0, self._total_available_tokens - self._total_usage.prompt_tokens - self._total_usage.completion_tokens
         )