add model client component (#2)

This commit is contained in:
Jack Gerrits 2024-05-17 11:09:59 -04:00 committed by GitHub
parent 02675126be
commit ebee8b856f
14 changed files with 867 additions and 10 deletions

View File

@ -1,5 +1,11 @@
# AutoGenNext
## Package layering
- `core` are the the foundational generic interfaces upon which all else is built. This module must not depend on any other module.
- `agent_components` are the building blocks for creating agents
- `core_components` are implementations of core components that are used to compose an application
## Development
### Setup

View File

@ -1,11 +1,11 @@
import asyncio
from dataclasses import dataclass
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, event_handler
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
from agnext.core.agent import Agent
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.message import Message
from agnext.single_threaded_agent_runtime import SingleThreadedAgentRuntime
from agnext.type_routed_agent import TypeRoutedAgent, event_handler
@dataclass

View File

@ -13,9 +13,15 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
"openai>=1.3",
"pillow",
"aiohttp",
"typing-extensions"
]
[project.optional-dependencies]
dev = ["ruff", "pyright", "mypy", "pytest"]
dev = ["ruff", "pyright", "mypy", "pytest", "types-Pillow"]
[tool.setuptools.package-data]
agnext = ["py.typed"]

View File

View File

@ -0,0 +1,78 @@
from __future__ import annotations
import base64
import re
from io import BytesIO
from pathlib import Path
import aiohttp
from openai.types.chat import ChatCompletionContentPartImageParam
from PIL import Image as PILImage
from typing_extensions import Literal
class Image:
def __init__(self, image: PILImage.Image):
self.image: PILImage.Image = image.convert("RGB")
@classmethod
def from_pil(cls, pil_image: PILImage.Image) -> Image:
return cls(pil_image)
@classmethod
def from_uri(cls, uri: str) -> Image:
if not re.match(r"data:image/(?:png|jpeg);base64,", uri):
raise ValueError("Invalid URI format. It should be a base64 encoded image URI.")
# A URI. Remove the prefix and decode the base64 string.
base64_data = re.sub(r"data:image/(?:png|jpeg);base64,", "", uri)
return cls.from_base64(base64_data)
@classmethod
async def from_url(cls, url: str) -> Image:
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
content = await response.read()
return cls(PILImage.open(content))
@classmethod
def from_base64(cls, base64_str: str) -> Image:
return cls(PILImage.open(BytesIO(base64.b64decode(base64_str))))
@classmethod
def from_file(cls, file_path: Path) -> Image:
return cls(PILImage.open(file_path))
def _repr_html_(self) -> str:
# Show the image in Jupyter notebook
return f'<img src="{self.data_uri}"/>'
@property
def data_uri(self) -> str:
buffered = BytesIO()
self.image.save(buffered, format="PNG")
content = buffered.getvalue()
return _convert_base64_to_data_uri(base64.b64encode(content).decode("utf-8"))
def to_openai_format(self, detail: Literal["auto", "low", "high"] = "auto") -> ChatCompletionContentPartImageParam:
return {"type": "image_url", "image_url": {"url": self.data_uri, "detail": detail}}
def _convert_base64_to_data_uri(base64_image: str) -> str:
def _get_mime_type_from_data_uri(base64_image: str) -> str:
# Decode the base64 string
image_data = base64.b64decode(base64_image)
# Check the first few bytes for known signatures
if image_data.startswith(b"\xff\xd8\xff"):
return "image/jpeg"
elif image_data.startswith(b"\x89PNG\r\n\x1a\n"):
return "image/png"
elif image_data.startswith(b"GIF87a") or image_data.startswith(b"GIF89a"):
return "image/gif"
elif image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP":
return "image/webp"
return "image/jpeg" # use jpeg for unknown formats, best guess.
mime_type = _get_mime_type_from_data_uri(base64_image)
data_uri = f"data:{mime_type};base64,{base64_image}"
return data_uri

View File

@ -0,0 +1,44 @@
from __future__ import annotations
from typing import Mapping, Optional, Sequence, runtime_checkable
from typing_extensions import Any, AsyncGenerator, List, Protocol, Required, TypedDict, Union
from .types import CreateResult, FunctionDefinition, LLMMessage, RequestUsage
class ModelCapabilities(TypedDict, total=False):
vision: Required[bool]
function_calling: Required[bool]
json_output: Required[bool]
@runtime_checkable
class ModelClient(Protocol):
# Caching has to be handled internally as they can depend on the create args that were stored in the constructor
async def create(
self,
messages: List[LLMMessage],
functions: Sequence[FunctionDefinition] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> CreateResult: ...
def create_stream(
self,
messages: List[LLMMessage],
functions: Sequence[FunctionDefinition] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> AsyncGenerator[Union[str, CreateResult], None]: ...
def actual_usage(self) -> RequestUsage: ...
def total_usage(self) -> RequestUsage: ...
@property
def capabilities(self) -> ModelCapabilities: ...

View File

@ -0,0 +1,3 @@
from .openai_client import AzureOpenAI, OpenAI
__all__ = ("OpenAI", "AzureOpenAI")

View File

@ -0,0 +1,83 @@
from typing import Dict
from ..model_client import ModelCapabilities
# Based on: https://platform.openai.com/docs/models/continuous-model-upgrades
# This is a moving target, so correctness is checked by the model value returned by openai against expected values at runtime``
_MODEL_POINTERS = {
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview": "gpt-4-0125-preview",
"gpt-4": "gpt-4-0613",
"gpt-4-32k": "gpt-4-32k-0613",
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
}
_MODEL_CAPABILITIES: Dict[str, ModelCapabilities] = {
"gpt-4-turbo-2024-04-09": {
"vision": True,
"function_calling": True,
"json_output": True,
},
"gpt-4-0125-preview": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-4-1106-preview": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-4-1106-vision-preview": {
"vision": True,
"function_calling": False,
"json_output": False,
},
"gpt-4-0613": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-4-32k-0613": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-3.5-turbo-0125": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-3.5-turbo-1106": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-3.5-turbo-instruct": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-3.5-turbo-0613": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-3.5-turbo-16k-0613": {
"vision": False,
"function_calling": True,
"json_output": True,
},
}
def resolve_model(model: str) -> str:
if model in _MODEL_POINTERS:
return _MODEL_POINTERS[model]
return model
def get_capabilties(model: str) -> ModelCapabilities:
resolved_model = resolve_model(model)
return _MODEL_CAPABILITIES[resolved_model]

View File

@ -0,0 +1,562 @@
import inspect
import warnings
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Union,
)
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionRole,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
completion_create_params,
)
from typing_extensions import Required, TypedDict, Unpack
# from ..._pydantic import type2schema
from ..image import Image
from ..model_client import ModelCapabilities, ModelClient
from ..types import (
AssistantMessage,
CreateResult,
FunctionCall,
FunctionDefinition,
FunctionExecutionResultMessage,
LLMMessage,
RequestUsage,
SystemMessage,
UserMessage,
)
from . import model_info
openai_init_kwargs = set(inspect.getfullargspec(AsyncOpenAI.__init__).kwonlyargs)
aopenai_init_kwargs = set(inspect.getfullargspec(AsyncAzureOpenAI.__init__).kwonlyargs)
create_kwargs = set(completion_create_params.CompletionCreateParamsBase.__annotations__.keys()) | set(
("timeout", "stream")
)
# Only single choice allowed
disallowed_create_args = set(["stream", "messages", "function_call", "functions", "n"])
required_create_args: Set[str] = set(["model"])
def _azure_openai_client_from_config(config: Mapping[str, Any]) -> AsyncAzureOpenAI:
# Take a copy
copied_config = dict(config).copy()
# Do some fixups
copied_config["azure_deployment"] = copied_config.get("azure_deployment", config.get("model"))
if copied_config["azure_deployment"] is not None:
copied_config["azure_deployment"] = copied_config["azure_deployment"].replace(".", "")
copied_config["azure_endpoint"] = copied_config.get("azure_endpoint", copied_config.pop("base_url", None))
# Shave down the config to just the AzureOpenAI kwargs
azure_config = {k: v for k, v in copied_config.items() if k in aopenai_init_kwargs}
return AsyncAzureOpenAI(**azure_config)
def _openai_client_from_config(config: Mapping[str, Any]) -> AsyncOpenAI:
# Shave down the config to just the OpenAI kwargs
openai_config = {k: v for k, v in config.items() if k in openai_init_kwargs}
return AsyncOpenAI(**openai_config)
def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
create_args = {k: v for k, v in config.items() if k in create_kwargs}
create_args_keys = set(create_args.keys())
if not required_create_args.issubset(create_args_keys):
raise ValueError(f"Required create args are missing: {required_create_args - create_args_keys}")
if disallowed_create_args.intersection(create_args_keys):
raise ValueError(f"Disallowed create args are present: {disallowed_create_args.intersection(create_args_keys)}")
return create_args
# TODO check types
# oai_system_message_schema = type2schema(ChatCompletionSystemMessageParam)
# oai_user_message_schema = type2schema(ChatCompletionUserMessageParam)
# oai_assistant_message_schema = type2schema(ChatCompletionAssistantMessageParam)
# oai_tool_message_schema = type2schema(ChatCompletionToolMessageParam)
def type_to_role(message: LLMMessage) -> ChatCompletionRole:
if isinstance(message, SystemMessage):
return "system"
elif isinstance(message, UserMessage):
return "user"
elif isinstance(message, AssistantMessage):
return "assistant"
else:
return "tool"
def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam:
if isinstance(message.content, str):
return ChatCompletionUserMessageParam(
content=message.content,
role="user",
name=message.source,
)
else:
parts: List[ChatCompletionContentPartParam] = []
for part in message.content:
if isinstance(part, str):
oai_part = ChatCompletionContentPartTextParam(
text=part,
type="text",
)
parts.append(oai_part)
elif isinstance(part, Image):
# TODO: support url based images
# TODO: support specifying details
parts.append(part.to_openai_format())
else:
raise ValueError(f"Unknown content type: {part}")
return ChatCompletionUserMessageParam(
content=parts,
role="user",
name=message.source,
)
def system_message_to_oai(message: SystemMessage) -> ChatCompletionSystemMessageParam:
return ChatCompletionSystemMessageParam(
content=message.content,
role="system",
)
def func_call_to_oai(message: FunctionCall) -> ChatCompletionMessageToolCallParam:
return ChatCompletionMessageToolCallParam(
id=message.id,
function={
"arguments": message.arguments,
"name": message.name,
},
type="function",
)
def tool_message_to_oai(message: FunctionExecutionResultMessage) -> Sequence[ChatCompletionToolMessageParam]:
return [
ChatCompletionToolMessageParam(content=x.content, role="tool", tool_call_id=x.call_id) for x in message.content
]
def assistant_message_to_oai(message: AssistantMessage) -> ChatCompletionAssistantMessageParam:
if isinstance(message.content, list):
return ChatCompletionAssistantMessageParam(
tool_calls=[func_call_to_oai(x) for x in message.content],
role="assistant",
name=message.source,
)
else:
return ChatCompletionAssistantMessageParam(
content=message.content,
role="assistant",
name=message.source,
)
def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]:
if isinstance(message, SystemMessage):
return [system_message_to_oai(message)]
elif isinstance(message, UserMessage):
return [user_message_to_oai(message)]
elif isinstance(message, AssistantMessage):
return [assistant_message_to_oai(message)]
else:
return tool_message_to_oai(message)
def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
return RequestUsage(
prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens,
completion_tokens=usage1.completion_tokens + usage2.completion_tokens,
)
class ResponseFormat(TypedDict):
type: Literal["text", "json_object"]
class CreateArguments(TypedDict, total=False):
frequency_penalty: Optional[float]
logit_bias: Optional[Dict[str, int]]
max_tokens: Optional[int]
n: Optional[int]
presence_penalty: Optional[float]
response_format: ResponseFormat
seed: Optional[int]
stop: Union[Optional[str], List[str]]
temperature: Optional[float]
top_p: Optional[float]
user: str
AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]]
class BaseOpenAIClientConfiguration(CreateArguments, total=False):
model: str
api_key: str
timeout: Union[float, None]
max_retries: int
# See OpenAI docs for explanation of these parameters
class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
organization: str
base_url: str
# Not required
model_capabilities: ModelCapabilities
class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
# Azure specific
azure_endpoint: Required[str]
azure_deployment: str
api_version: Required[str]
azure_ad_token: str
azure_ad_token_provider: AsyncAzureADTokenProvider
# Must be provided
model_capabilities: Required[ModelCapabilities]
def convert_functions(functions: Sequence[FunctionDefinition]) -> List[ChatCompletionToolParam]:
result: List[ChatCompletionToolParam] = []
for func in functions:
result.append(
{
"type": "function",
"function": {
"name": func.name,
"parameters": func.parameters,
"description": func.description,
},
}
)
return result
class BaseOpenAI(ModelClient):
def __init__(
self,
client: Union[AsyncOpenAI, AsyncAzureOpenAI],
create_args: Dict[str, Any],
model_capabilities: Optional[ModelCapabilities] = None,
):
self._client = client
if model_capabilities is None and isinstance(client, AsyncAzureOpenAI):
raise ValueError("AzureOpenAI requires explicit model capabilities")
elif model_capabilities is None:
self._model_capabilities = model_info.get_capabilties(create_args["model"])
else:
self._model_capabilities = model_capabilities
self._resolved_model: Optional[str] = None
if "model" in create_args:
self._resolved_model = model_info.resolve_model(create_args["model"])
if (
"response_format" in create_args
and create_args["response_format"]["type"] == "json_object"
and not self._model_capabilities["json_output"]
):
raise ValueError("Model does not support JSON output")
self._create_args = create_args
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
@classmethod
def create_from_config(cls, config: Dict[str, Any]) -> ModelClient:
return OpenAI(**config)
async def create(
self,
messages: List[LLMMessage],
functions: Sequence[FunctionDefinition] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> CreateResult:
# Make sure all extra_create_args are valid
extra_create_args_keys = set(extra_create_args.keys())
if not create_kwargs.issuperset(extra_create_args_keys):
raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
# Copy the create args and overwrite anything in extra_create_args
create_args = self._create_args.copy()
create_args.update(extra_create_args)
# TODO: allow custom handling.
# For now we raise an error if images are present and vision is not supported
if self.capabilities["vision"] is False:
for message in messages:
if isinstance(message, UserMessage):
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
raise ValueError("Model does not support vision and image was provided")
if json_output is not None:
if self.capabilities["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output")
if json_output is True:
create_args["response_format"] = {"type": "json_object"}
else:
create_args["response_format"] = {"type": "text"}
if self.capabilities["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output")
oai_messages_nested = [to_oai_type(m) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
if self.capabilities["function_calling"] is False and len(functions) > 0:
raise ValueError("Model does not support function calling")
if len(functions) > 0:
tools = convert_functions(functions)
result = await self._client.chat.completions.create(
messages=oai_messages, stream=False, tools=tools, **create_args
)
else:
result = await self._client.chat.completions.create(messages=oai_messages, stream=False, **create_args)
usage = RequestUsage(
# TODO backup token counting
prompt_tokens=result.usage.prompt_tokens if result.usage is not None else 0,
completion_tokens=result.usage.completion_tokens if result.usage is not None else 0,
)
if self._resolved_model is not None:
if self._resolved_model != result.model:
warnings.warn(
f"Resolved model mismatch: {self._resolved_model} != {result.model}. AutoGen model mapping may be incorrect.",
stacklevel=2,
)
# Limited to a single choice currently.
choice = result.choices[0]
if choice.finish_reason == "function_call":
raise ValueError("Function calls are not supported in this context")
content: Union[str, List[FunctionCall]]
if choice.finish_reason == "tool_calls":
assert choice.message.tool_calls is not None
assert choice.message.function_call is None
# NOTE: If OAI response type changes, this will need to be updated
content = [
FunctionCall(id=x.id, arguments=x.function.arguments, name=x.function.name)
for x in choice.message.tool_calls
]
finish_reason = "function_calls"
else:
finish_reason = choice.finish_reason
content = choice.message.content or ""
response = CreateResult(finish_reason=finish_reason, content=content, usage=usage, cached=False) # type: ignore
_add_usage(self._actual_usage, usage)
_add_usage(self._total_usage, usage)
# TODO - why is this cast needed?
return response
async def create_stream(
self,
messages: List[LLMMessage],
functions: Sequence[FunctionDefinition] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> AsyncGenerator[Union[str, CreateResult], None]:
# Make sure all extra_create_args are valid
extra_create_args_keys = set(extra_create_args.keys())
if not create_kwargs.issuperset(extra_create_args_keys):
raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
# Copy the create args and overwrite anything in extra_create_args
create_args = self._create_args.copy()
create_args.update(extra_create_args)
oai_messages_nested = [to_oai_type(m) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
# TODO: allow custom handling.
# For now we raise an error if images are present and vision is not supported
if self.capabilities["vision"] is False:
for message in messages:
if isinstance(message, UserMessage):
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
raise ValueError("Model does not support vision and image was provided")
if json_output is not None:
if self.capabilities["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output")
if json_output is True:
create_args["response_format"] = {"type": "json_object"}
else:
create_args["response_format"] = {"type": "text"}
if len(functions) > 0:
tools = convert_functions(functions)
stream = await self._client.chat.completions.create(
messages=oai_messages, stream=True, tools=tools, **create_args
)
else:
stream = await self._client.chat.completions.create(messages=oai_messages, stream=True, **create_args)
stop_reason = None
maybe_model = None
content_deltas: List[str] = []
full_tool_calls: Dict[int, FunctionCall] = {}
completion_tokens = 0
async for chunk in stream:
choice = chunk.choices[0]
stop_reason = choice.finish_reason
maybe_model = chunk.model
# First try get content
if choice.delta.content is not None:
content_deltas.append(choice.delta.content)
if len(choice.delta.content) > 0:
yield choice.delta.content
continue
# Otherwise, get tool calls
if choice.delta.tool_calls is not None:
for tool_call_chunk in choice.delta.tool_calls:
idx = tool_call_chunk.index
if idx not in full_tool_calls:
# We ignore the type hint here because we want to fill in type when the delta provides it
full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
if tool_call_chunk.id is not None:
full_tool_calls[idx].id += tool_call_chunk.id
if tool_call_chunk.function is not None:
if tool_call_chunk.function.name is not None:
full_tool_calls[idx].name += tool_call_chunk.function.name
if tool_call_chunk.function.arguments is not None:
full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
model = maybe_model or create_args["model"]
model = model.replace("gpt-35", "gpt-3.5") # hack for Azure API
# TODO fix count token
prompt_tokens = 0
# prompt_tokens = count_token(messages, model=model)
if stop_reason is None:
raise ValueError("No stop reason found")
content: Union[str, List[FunctionCall]]
if len(content_deltas) > 1:
content = "".join(content_deltas)
completion_tokens = 0
# completion_tokens = count_token(content, model=model)
else:
completion_tokens = 0
# TODO: fix assumption that dict values were added in order and actually order by int index
# for tool_call in full_tool_calls.values():
# # value = json.dumps(tool_call)
# # completion_tokens += count_token(value, model=model)
# completion_tokens += 0
content = list(full_tool_calls.values())
usage = RequestUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
if stop_reason == "function_call":
raise ValueError("Function calls are not supported in this context")
if stop_reason == "tool_calls":
stop_reason = "function_calls"
result = CreateResult(finish_reason=stop_reason, content=content, usage=usage, cached=False)
_add_usage(self._actual_usage, usage)
_add_usage(self._total_usage, usage)
yield result
def actual_usage(self) -> RequestUsage:
return self._actual_usage
def total_usage(self) -> RequestUsage:
return self._total_usage
@property
def capabilities(self) -> ModelCapabilities:
return self._model_capabilities
class OpenAI(BaseOpenAI):
def __init__(self, **kwargs: Unpack[OpenAIClientConfiguration]):
if "model" not in kwargs:
raise ValueError("model is required for OpenAI")
model_capabilities: Optional[ModelCapabilities] = None
copied_args = dict(kwargs).copy()
if "model_capabilities" in kwargs:
model_capabilities = kwargs["model_capabilities"]
del copied_args["model_capabilities"]
client = _openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)
self._raw_config = copied_args
super().__init__(client, create_args, model_capabilities)
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["_client"] = None
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
self._client = _openai_client_from_config(state["_raw_config"])
class AzureOpenAI(BaseOpenAI):
def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]):
if "model" not in kwargs:
raise ValueError("model is required for OpenAI")
model_capabilities: Optional[ModelCapabilities] = None
copied_args = dict(kwargs).copy()
if "model_capabilities" in kwargs:
model_capabilities = kwargs["model_capabilities"]
del copied_args["model_capabilities"]
client = _azure_openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)
self._raw_config = copied_args
super().__init__(client, create_args, model_capabilities)
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["_client"] = None
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
self._client = _azure_openai_client_from_config(state["_raw_config"])

View File

@ -4,7 +4,7 @@ from agnext.core.agent_runtime import AgentRuntime
from agnext.core.base_agent import BaseAgent
from agnext.core.exceptions import CantHandleException
from .core.message import Message
from ..core.message import Message
T = TypeVar("T", bound=Message)

View File

@ -0,0 +1,76 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from typing_extensions import Literal
from .image import Image
@dataclass
class FunctionCall:
id: str
# JSON args
arguments: str
# Function to call
name: str
@dataclass
class FunctionDefinition:
name: str
parameters: Dict[str, Any]
description: str
@dataclass
class RequestUsage:
prompt_tokens: int
completion_tokens: int
@dataclass
class SystemMessage:
content: str
@dataclass
class UserMessage:
content: Union[str, List[Union[str, Image]]]
# Name of the agent that sent this message
source: str
@dataclass
class AssistantMessage:
content: Union[str, List[FunctionCall]]
# Name of the agent that sent this message
source: str
@dataclass
class FunctionExecutionResult:
content: str
call_id: str
@dataclass
class FunctionExecutionResultMessage:
content: List[FunctionExecutionResult]
LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
FinishReasons = Literal["stop", "length", "function_calls", "content_filter"]
@dataclass
class CreateResult:
finish_reason: FinishReasons
content: Union[str, List[FunctionCall]]
usage: RequestUsage
cached: bool

View File

@ -3,10 +3,9 @@ from asyncio import Future
from dataclasses import dataclass
from typing import Dict, Generic, List, Set, Type, TypeVar
from agnext.core.agent import Agent
from .core.agent_runtime import AgentRuntime
from .core.message import Message
from ..core.agent import Agent
from ..core.agent_runtime import AgentRuntime
from ..core.message import Message
T = TypeVar("T", bound=Message)

View File

@ -15,7 +15,7 @@ echo "--- Running ruff format ---"
ruff format
echo "--- Running ruff check ---"
ruff check
echo "--- Running ruff lint ---"
echo "--- Running pyright ---"
pyright
echo "--- Running ruff test ---"
echo "--- Running mypy ---"
mypy