Add * before keyword args for ChatCompletionClient (#4822)

add * before keyword args

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
Leonardo Pinheiro 2024-12-27 23:41:16 +10:00 committed by GitHub
parent edad1b6065
commit 9a2dbb4fba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 20 additions and 12 deletions

View File

@@ -29,6 +29,7 @@ class ChatCompletionClient(ABC, ComponentLoader):
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
@@ -41,6 +42,7 @@ class ChatCompletionClient(ABC, ComponentLoader):
def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
@@ -56,10 +58,10 @@ class ChatCompletionClient(ABC, ComponentLoader):
def total_usage(self) -> RequestUsage: ...
@abstractmethod
def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...
@abstractmethod
def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...
@property
@abstractmethod

View File

@@ -92,6 +92,7 @@ async def test_caller_loop() -> None:
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
@@ -116,6 +117,7 @@ async def test_caller_loop() -> None:
def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
@@ -129,10 +131,10 @@ async def test_caller_loop() -> None:
def total_usage(self) -> RequestUsage:
return RequestUsage(prompt_tokens=0, completion_tokens=0)
def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
return 0
def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
return 0
@property

View File

@@ -355,6 +355,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
def __init__(
self,
client: Union[AsyncOpenAI, AsyncAzureOpenAI],
*,
create_args: Dict[str, Any],
model_capabilities: Optional[ModelCapabilities] = None,
):
@@ -389,6 +390,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
@@ -581,11 +583,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
async def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
*,
max_consecutive_empty_chunk_tolerance: int = 0,
) -> AsyncGenerator[Union[str, CreateResult], None]:
"""
@@ -800,7 +802,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
def total_usage(self) -> RequestUsage:
return self._total_usage
def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
model = self._create_args["model"]
try:
encoding = tiktoken.encoding_for_model(model)
@@ -889,9 +891,9 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
num_tokens += 12
return num_tokens
def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
token_limit = _model_info.get_token_limit(self._create_args["model"])
return token_limit - self.count_tokens(messages, tools)
return token_limit - self.count_tokens(messages, tools=tools)
@property
def capabilities(self) -> ModelCapabilities:
@@ -974,7 +976,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
client = _openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)
self._raw_config: Dict[str, Any] = copied_args
super().__init__(client, create_args, model_capabilities)
super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities)
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
@@ -1059,7 +1061,7 @@ class AzureOpenAIChatCompletionClient(
client = _azure_openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)
self._raw_config: Dict[str, Any] = copied_args
super().__init__(client, create_args, model_capabilities)
super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities)
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()

View File

@@ -128,6 +128,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
@@ -155,6 +156,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
async def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
@@ -191,11 +193,11 @@ class ReplayChatCompletionClient(ChatCompletionClient):
def total_usage(self) -> RequestUsage:
return self._total_usage
def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
_, token_count = self._tokenize(messages)
return token_count
def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
return max(
0, self._total_available_tokens - self._total_usage.prompt_tokens - self._total_usage.completion_tokens
)