Split apart component infra to allow for abstract class integration (#5017)

* Split apart component infra to allow for abstract class integration

* fix is_component_class check

* make is_ functions type guards

* Simplify component creation

* undo changes

* Format
This commit is contained in:
Jack Gerrits 2025-01-13 15:58:38 -05:00 committed by GitHub
parent 70f7e998d2
commit 404522bd6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 172 additions and 106 deletions

View File

@ -20,7 +20,7 @@
"\n",
"## Usage\n",
"\n",
"If you have a component in Python and want to get the config for it, simply call {py:meth}`~autogen_core.ComponentConfig.dump_component` on it. The resulting object can be passed back into {py:meth}`~autogen_core.ComponentLoader.load_component` to get the component back.\n",
"If you have a component in Python and want to get the config for it, simply call {py:meth}`~autogen_core.ComponentToConfig.dump_component` on it. The resulting object can be passed back into {py:meth}`~autogen_core.ComponentLoader.load_component` to get the component back.\n",
"\n",
"### Loading a component from a config\n",
"\n",
@ -52,7 +52,7 @@
"To add component functionality to a given class:\n",
"\n",
"1. Add a call to {py:meth}`~autogen_core.Component` in the class inheritance list.\n",
"2. Implment the {py:meth}`~autogen_core.ComponentConfigImpl._to_config` and {py:meth}`~autogen_core.ComponentConfigImpl._from_config` methods\n",
"2. Implment the {py:meth}`~autogen_core.ComponentToConfig._to_config` and {py:meth}`~autogen_core.ComponentFromConfig._from_config` methods\n",
"\n",
"For example:"
]
@ -63,7 +63,7 @@
"metadata": {},
"outputs": [],
"source": [
"from autogen_core import Component\n",
"from autogen_core import Component, ComponentBase\n",
"from pydantic import BaseModel\n",
"\n",
"\n",
@ -71,7 +71,7 @@
" value: str\n",
"\n",
"\n",
"class MyComponent(Component[Config]):\n",
"class MyComponent(ComponentBase[Config], Component[Config]):\n",
" component_type = \"custom\"\n",
" component_config_schema = Config\n",
"\n",
@ -129,7 +129,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.12.5"
}
},
"nbformat": 4,

View File

@ -14,10 +14,15 @@ from ._cancellation_token import CancellationToken
from ._closure_agent import ClosureAgent, ClosureContext
from ._component_config import (
Component,
ComponentConfigImpl,
ComponentBase,
ComponentFromConfig,
ComponentLoader,
ComponentModel,
ComponentSchemaType,
ComponentToConfig,
ComponentType,
is_component_class,
is_component_instance,
)
from ._constants import (
EVENT_LOGGER_NAME as EVENT_LOGGER_NAME_ALIAS,
@ -112,10 +117,15 @@ __all__ = [
"EVENT_LOGGER_NAME",
"TRACE_LOGGER_NAME",
"Component",
"ComponentBase",
"ComponentFromConfig",
"ComponentLoader",
"ComponentConfigImpl",
"ComponentModel",
"ComponentSchemaType",
"ComponentToConfig",
"ComponentType",
"is_component_class",
"is_component_instance",
"DropMessage",
"InterventionHandler",
"DefaultInterventionHandler",

View File

@ -2,13 +2,15 @@ from __future__ import annotations
import importlib
import warnings
from typing import Any, ClassVar, Dict, Generic, List, Literal, Protocol, Type, cast, overload, runtime_checkable
from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast, overload
from pydantic import BaseModel
from typing_extensions import Self, TypeVar
ComponentType = Literal["model", "agent", "tool", "termination", "token_provider"] | str
ConfigT = TypeVar("ConfigT", bound=BaseModel)
FromConfigT = TypeVar("FromConfigT", bound=BaseModel, contravariant=True)
ToConfigT = TypeVar("ToConfigT", bound=BaseModel, covariant=True)
T = TypeVar("T", bound=BaseModel, covariant=True)
@ -47,36 +49,9 @@ WELL_KNOWN_PROVIDERS = {
}
@runtime_checkable
class ComponentConfigImpl(Protocol[ConfigT]):
# Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context)
component_config_schema: Type[ConfigT]
"""The Pydantic model class which represents the configuration of the component."""
component_type: ClassVar[ComponentType]
"""The logical type of the component."""
component_version: ClassVar[int] = 1
"""The version of the component, if schema incompatibilities are introduced this should be updated."""
component_provider_override: ClassVar[str | None] = None
"""Override the provider string for the component. This should be used to prevent internal module names being a part of the module name."""
"""The two methods a class must implement to be a component.
Args:
Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`.
"""
def _to_config(self) -> ConfigT:
"""Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance.
Returns:
T: The configuration of the component.
:meta public:
"""
...
class ComponentFromConfig(Generic[FromConfigT]):
@classmethod
def _from_config(cls, config: ConfigT) -> Self:
def _from_config(cls, config: FromConfigT) -> Self:
"""Create a new instance of the component from a configuration object.
Args:
@ -87,7 +62,7 @@ class ComponentConfigImpl(Protocol[ConfigT]):
:meta public:
"""
...
raise NotImplementedError("This component does not support dumping to config")
@classmethod
def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self:
@ -104,7 +79,69 @@ class ComponentConfigImpl(Protocol[ConfigT]):
:meta public:
"""
raise NotImplementedError()
raise NotImplementedError("This component does not support loading from past versions")
class ComponentToConfig(Generic[ToConfigT]):
"""The two methods a class must implement to be a component.
Args:
Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`.
"""
component_type: ClassVar[ComponentType]
"""The logical type of the component."""
component_version: ClassVar[int] = 1
"""The version of the component, if schema incompatibilities are introduced this should be updated."""
component_provider_override: ClassVar[str | None] = None
"""Override the provider string for the component. This should be used to prevent internal module names being a part of the module name."""
def _to_config(self) -> ToConfigT:
"""Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance.
Returns:
T: The configuration of the component.
:meta public:
"""
raise NotImplementedError("This component does not support dumping to config")
def dump_component(self) -> ComponentModel:
"""Dump the component to a model that can be loaded back in.
Raises:
TypeError: If the component is a local class.
Returns:
ComponentModel: The model representing the component.
"""
if self.component_provider_override is not None:
provider = self.component_provider_override
else:
provider = _type_to_provider_str(self.__class__)
# Warn if internal module name is used,
if "._" in provider:
warnings.warn(
"Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.",
stacklevel=2,
)
if "<locals>" in provider:
raise TypeError("Cannot dump component with local class")
if not hasattr(self, "component_type"):
raise AttributeError("component_type not defined")
obj_config = self._to_config().model_dump(exclude_none=True)
model = ComponentModel(
provider=provider,
component_type=self.component_type,
version=self.component_version,
component_version=self.component_version,
description=None,
config=obj_config,
)
return model
ExpectedType = TypeVar("ExpectedType")
@ -171,9 +208,9 @@ class ComponentLoader:
module_path, class_name = output
module = importlib.import_module(module_path)
component_class = cast(ComponentConfigImpl[BaseModel], module.__getattribute__(class_name))
component_class = module.__getattribute__(class_name)
if not isinstance(component_class, ComponentConfigImpl):
if not is_component_class(component_class):
raise TypeError("Invalid component class")
# We need to check the schema is valid
@ -192,7 +229,7 @@ class ComponentLoader:
f"Tried to load component {component_class} which is on version {component_class.component_version} with a config on version {loaded_config_version} but _from_config_past_version is not implemented"
) from e
else:
schema = component_class.component_config_schema
schema = component_class.component_config_schema # type: ignore
validated_config = schema.model_validate(loaded_model.config)
# We're allowed to use the private method here
@ -208,8 +245,35 @@ class ComponentLoader:
return cast(ExpectedType, instance)
class Component(ComponentConfigImpl[ConfigT], ComponentLoader, Generic[ConfigT]):
"""To create a component class, inherit from this class. Then implement two class variables:
class ComponentSchemaType(Generic[ConfigT]):
# Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context)
component_config_schema: Type[ConfigT]
"""The Pydantic model class which represents the configuration of the component."""
required_class_vars = ["component_config_schema", "component_type"]
def __init_subclass__(cls, **kwargs: Any):
super().__init_subclass__(**kwargs)
if cls.__name__ != "Component" and not cls.__name__ == "_ConcreteComponent":
# TODO: validate provider is loadable
for var in cls.required_class_vars:
if not hasattr(cls, var):
warnings.warn(
f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component",
stacklevel=2,
)
class ComponentBase(ComponentToConfig[ConfigT], ComponentLoader, Generic[ConfigT]): ...
class Component(
ComponentFromConfig[ConfigT],
ComponentSchemaType[ConfigT],
Generic[ConfigT],
):
"""To create a component class, inherit from this class for the concrete class and ComponentBase on the interface. Then implement two class variables:
- :py:attr:`component_config_schema` - A Pydantic model class which represents the configuration of the component. This is also the type parameter of Component.
- :py:attr:`component_type` - What is the logical type of the component.
@ -243,55 +307,39 @@ class Component(ComponentConfigImpl[ConfigT], ComponentLoader, Generic[ConfigT])
return cls(value=config.value)
"""
required_class_vars: ClassVar[List[str]] = ["component_config_schema", "component_type"]
def __init_subclass__(cls, **kwargs: Any) -> None:
def __init_subclass__(cls, **kwargs: Any):
super().__init_subclass__(**kwargs)
# TODO: validate provider is loadable
for var in cls.required_class_vars:
if not hasattr(cls, var):
warnings.warn(
f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component", stacklevel=2
)
if not is_component_class(cls):
warnings.warn(
f"Component class '{cls.__name__}' must subclass the following: ComponentFromConfig, ComponentToConfig, ComponentSchemaType, ComponentLoader, individually or with ComponentBase and Component. Look at the component config documentation or how OpenAIChatCompletionClient does it.",
stacklevel=2,
)
def dump_component(self) -> ComponentModel:
"""Dump the component to a model that can be loaded back in.
Raises:
TypeError: If the component is a local class.
# Should never be used directly, only for type checking
class _ConcreteComponent(
ComponentFromConfig[ConfigT],
ComponentSchemaType[ConfigT],
ComponentToConfig[ConfigT],
ComponentLoader,
Generic[ConfigT],
): ...
Returns:
ComponentModel: The model representing the component.
"""
if self.component_provider_override is not None:
provider = self.component_provider_override
else:
provider = _type_to_provider_str(self.__class__)
# Warn if internal module name is used,
if "._" in provider:
warnings.warn(
"Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.",
stacklevel=2,
)
if "<locals>" in provider:
raise TypeError("Cannot dump component with local class")
def is_component_instance(cls: Any) -> TypeGuard[_ConcreteComponent[BaseModel]]:
return (
isinstance(cls, ComponentFromConfig)
and isinstance(cls, ComponentToConfig)
and isinstance(cls, ComponentSchemaType)
and isinstance(cls, ComponentLoader)
)
if not hasattr(self, "component_type"):
raise AttributeError("component_type not defined")
obj_config = self._to_config().model_dump(exclude_none=True)
model = ComponentModel(
provider=provider,
component_type=self.component_type,
version=self.component_version,
component_version=self.component_version,
description=None,
config=obj_config,
)
return model
@classmethod
def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self:
raise NotImplementedError()
def is_component_class(cls: type) -> TypeGuard[Type[_ConcreteComponent[BaseModel]]]:
return (
issubclass(cls, ComponentFromConfig)
and issubclass(cls, ComponentToConfig)
and issubclass(cls, ComponentSchemaType)
and issubclass(cls, ComponentLoader)
)

View File

@ -4,10 +4,11 @@ import warnings
from abc import ABC, abstractmethod
from typing import Literal, Mapping, Optional, Sequence, TypeAlias
from pydantic import BaseModel
from typing_extensions import Any, AsyncGenerator, Required, TypedDict, Union, deprecated
from .. import CancellationToken
from .._component_config import ComponentLoader
from .._component_config import ComponentBase
from ..tools import Tool, ToolSchema
from ._types import CreateResult, LLMMessage, RequestUsage
@ -47,7 +48,7 @@ class ModelInfo(TypedDict, total=False):
"""Model family should be one of the constants from :py:class:`ModelFamily` or a string representing an unknown model family."""
class ChatCompletionClient(ABC, ComponentLoader):
class ChatCompletionClient(ComponentBase[BaseModel], ABC):
# Caching has to be handled internally as they can depend on the create args that were stored in the constructor
@abstractmethod
async def create(

View File

@ -4,7 +4,7 @@ import json
from typing import Any, Dict
import pytest
from autogen_core import Component, ComponentLoader, ComponentModel
from autogen_core import Component, ComponentBase, ComponentLoader, ComponentModel
from autogen_core._component_config import _type_to_provider_str # type: ignore
from autogen_core.models import ChatCompletionClient
from autogen_test_utils import MyInnerComponent, MyOuterComponent
@ -16,7 +16,7 @@ class MyConfig(BaseModel):
info: str
class MyComponent(Component[MyConfig]):
class MyComponent(ComponentBase[MyConfig], Component[MyConfig]):
component_config_schema = MyConfig
component_type = "custom"
@ -95,7 +95,7 @@ def test_cannot_import_locals() -> None:
class InvalidModelClientConfig(BaseModel):
info: str
class MyInvalidModelClient(Component[InvalidModelClientConfig]):
class MyInvalidModelClient(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]):
component_config_schema = InvalidModelClientConfig
component_type = "model"
@ -119,7 +119,7 @@ class InvalidModelClientConfig(BaseModel):
info: str
class MyInvalidModelClient(Component[InvalidModelClientConfig]):
class MyInvalidModelClient(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]):
component_config_schema = InvalidModelClientConfig
component_type = "model"
@ -143,7 +143,7 @@ def test_type_error_on_creation() -> None:
with pytest.warns(UserWarning):
class MyInvalidMissingAttrs(Component[InvalidModelClientConfig]):
class MyInvalidMissingAttrs(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]):
def __init__(self, info: str):
self.info = info
@ -189,7 +189,7 @@ def test_config_optional_values() -> None:
assert component.__class__ == MyComponent
class ConfigProviderOverrided(Component[MyConfig]):
class ConfigProviderOverrided(ComponentBase[MyConfig], Component[MyConfig]):
component_provider_override = "InvalidButStillOverridden"
component_config_schema = MyConfig
component_type = "custom"
@ -215,7 +215,7 @@ class MyConfig2(BaseModel):
info2: str
class ComponentNonOneVersion(Component[MyConfig2]):
class ComponentNonOneVersion(ComponentBase[MyConfig2], Component[MyConfig2]):
component_config_schema = MyConfig2
component_version = 2
component_type = "custom"
@ -231,7 +231,7 @@ class ComponentNonOneVersion(Component[MyConfig2]):
return cls(info=config.info2)
class ComponentNonOneVersionWithUpgrade(Component[MyConfig2]):
class ComponentNonOneVersionWithUpgrade(ComponentBase[MyConfig2], Component[MyConfig2]):
component_config_schema = MyConfig2
component_version = 2
component_type = "custom"

View File

@ -1,6 +1,6 @@
from typing import List
from autogen_core import Component
from autogen_core import Component, ComponentBase
from pydantic import BaseModel
from typing_extensions import Self
@ -13,7 +13,7 @@ class TokenProviderConfig(BaseModel):
scopes: List[str]
class AzureTokenProvider(Component[TokenProviderConfig]):
class AzureTokenProvider(ComponentBase[TokenProviderConfig], Component[TokenProviderConfig]):
component_type = "token_provider"
component_config_schema = TokenProviderConfig
component_provider_override = "autogen_ext.auth.azure.AzureTokenProvider"

View File

@ -6,13 +6,14 @@ from typing import Any
from autogen_core import (
BaseAgent,
Component,
ComponentBase,
ComponentModel,
DefaultTopicId,
MessageContext,
RoutedAgent,
default_subscription,
message_handler,
)
from autogen_core._component_config import ComponentModel
from pydantic import BaseModel
@ -76,7 +77,7 @@ class MyInnerConfig(BaseModel):
inner_message: str
class MyInnerComponent(Component[MyInnerConfig]):
class MyInnerComponent(ComponentBase[MyInnerConfig], Component[MyInnerConfig]):
component_config_schema = MyInnerConfig
component_type = "custom"
@ -96,7 +97,7 @@ class MyOuterConfig(BaseModel):
inner_class: ComponentModel
class MyOuterComponent(Component[MyOuterConfig]):
class MyOuterComponent(ComponentBase[MyOuterConfig], Component[MyOuterConfig]):
component_config_schema = MyOuterConfig
component_type = "custom"

View File

@ -5,7 +5,8 @@ from typing import Any, DefaultDict, Dict, List, TypeVar
from autogen_core import ComponentModel
from autogen_core._component_config import (
WELL_KNOWN_PROVIDERS,
ComponentConfigImpl,
ComponentSchemaType,
ComponentToConfig,
_type_to_provider_str, # type: ignore
)
from autogen_ext.auth.azure import AzureTokenProvider
@ -17,10 +18,13 @@ all_defs: Dict[str, Any] = {}
T = TypeVar("T", bound=BaseModel)
def build_specific_component_schema(component: type[ComponentConfigImpl[T]], provider_str: str) -> Dict[str, Any]:
def build_specific_component_schema(component: type[ComponentSchemaType[T]], provider_str: str) -> Dict[str, Any]:
model = component.component_config_schema # type: ignore
model_schema = model.model_json_schema()
# We can't specify component to be the union of two types, so we assert it here
assert issubclass(component, ComponentToConfig)
component_model_schema = ComponentModel.model_json_schema()
if "$defs" not in component_model_schema:
component_model_schema["$defs"] = {}
@ -70,7 +74,9 @@ def main() -> None:
for key, value in WELL_KNOWN_PROVIDERS.items():
reverse_provider_lookup_table[value].append(key)
def add_type(type: type[ComponentConfigImpl[T]]) -> None:
def add_type(type: type[ComponentSchemaType[T]]) -> None:
# We can't specify component to be the union of two types, so we assert it here
assert issubclass(type, ComponentToConfig)
canonical = type.component_provider_override or _type_to_provider_str(type)
reverse_provider_lookup_table[canonical].append(canonical)
for provider_str in reverse_provider_lookup_table[canonical]: