mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 23:18:59 +00:00
Split apart component infra to allow for abstract class integration (#5017)
* Split apart component infra to allow for abstract class integration * fix is_component_class check * make is_ functions type guards * Simplify component creation * undo changes * Format
This commit is contained in:
parent
70f7e998d2
commit
404522bd6b
@ -20,7 +20,7 @@
|
||||
"\n",
|
||||
"## Usage\n",
|
||||
"\n",
|
||||
"If you have a component in Python and want to get the config for it, simply call {py:meth}`~autogen_core.ComponentConfig.dump_component` on it. The resulting object can be passed back into {py:meth}`~autogen_core.ComponentLoader.load_component` to get the component back.\n",
|
||||
"If you have a component in Python and want to get the config for it, simply call {py:meth}`~autogen_core.ComponentToConfig.dump_component` on it. The resulting object can be passed back into {py:meth}`~autogen_core.ComponentLoader.load_component` to get the component back.\n",
|
||||
"\n",
|
||||
"### Loading a component from a config\n",
|
||||
"\n",
|
||||
@ -52,7 +52,7 @@
|
||||
"To add component functionality to a given class:\n",
|
||||
"\n",
|
||||
"1. Add a call to {py:meth}`~autogen_core.Component` in the class inheritance list.\n",
|
||||
"2. Implment the {py:meth}`~autogen_core.ComponentConfigImpl._to_config` and {py:meth}`~autogen_core.ComponentConfigImpl._from_config` methods\n",
|
||||
"2. Implment the {py:meth}`~autogen_core.ComponentToConfig._to_config` and {py:meth}`~autogen_core.ComponentFromConfig._from_config` methods\n",
|
||||
"\n",
|
||||
"For example:"
|
||||
]
|
||||
@ -63,7 +63,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_core import Component\n",
|
||||
"from autogen_core import Component, ComponentBase\n",
|
||||
"from pydantic import BaseModel\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@ -71,7 +71,7 @@
|
||||
" value: str\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class MyComponent(Component[Config]):\n",
|
||||
"class MyComponent(ComponentBase[Config], Component[Config]):\n",
|
||||
" component_type = \"custom\"\n",
|
||||
" component_config_schema = Config\n",
|
||||
"\n",
|
||||
@ -129,7 +129,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -14,10 +14,15 @@ from ._cancellation_token import CancellationToken
|
||||
from ._closure_agent import ClosureAgent, ClosureContext
|
||||
from ._component_config import (
|
||||
Component,
|
||||
ComponentConfigImpl,
|
||||
ComponentBase,
|
||||
ComponentFromConfig,
|
||||
ComponentLoader,
|
||||
ComponentModel,
|
||||
ComponentSchemaType,
|
||||
ComponentToConfig,
|
||||
ComponentType,
|
||||
is_component_class,
|
||||
is_component_instance,
|
||||
)
|
||||
from ._constants import (
|
||||
EVENT_LOGGER_NAME as EVENT_LOGGER_NAME_ALIAS,
|
||||
@ -112,10 +117,15 @@ __all__ = [
|
||||
"EVENT_LOGGER_NAME",
|
||||
"TRACE_LOGGER_NAME",
|
||||
"Component",
|
||||
"ComponentBase",
|
||||
"ComponentFromConfig",
|
||||
"ComponentLoader",
|
||||
"ComponentConfigImpl",
|
||||
"ComponentModel",
|
||||
"ComponentSchemaType",
|
||||
"ComponentToConfig",
|
||||
"ComponentType",
|
||||
"is_component_class",
|
||||
"is_component_instance",
|
||||
"DropMessage",
|
||||
"InterventionHandler",
|
||||
"DefaultInterventionHandler",
|
||||
|
||||
@ -2,13 +2,15 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import warnings
|
||||
from typing import Any, ClassVar, Dict, Generic, List, Literal, Protocol, Type, cast, overload, runtime_checkable
|
||||
from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast, overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self, TypeVar
|
||||
|
||||
ComponentType = Literal["model", "agent", "tool", "termination", "token_provider"] | str
|
||||
ConfigT = TypeVar("ConfigT", bound=BaseModel)
|
||||
FromConfigT = TypeVar("FromConfigT", bound=BaseModel, contravariant=True)
|
||||
ToConfigT = TypeVar("ToConfigT", bound=BaseModel, covariant=True)
|
||||
|
||||
T = TypeVar("T", bound=BaseModel, covariant=True)
|
||||
|
||||
@ -47,36 +49,9 @@ WELL_KNOWN_PROVIDERS = {
|
||||
}
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ComponentConfigImpl(Protocol[ConfigT]):
|
||||
# Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context)
|
||||
component_config_schema: Type[ConfigT]
|
||||
"""The Pydantic model class which represents the configuration of the component."""
|
||||
component_type: ClassVar[ComponentType]
|
||||
"""The logical type of the component."""
|
||||
component_version: ClassVar[int] = 1
|
||||
"""The version of the component, if schema incompatibilities are introduced this should be updated."""
|
||||
component_provider_override: ClassVar[str | None] = None
|
||||
"""Override the provider string for the component. This should be used to prevent internal module names being a part of the module name."""
|
||||
|
||||
"""The two methods a class must implement to be a component.
|
||||
|
||||
Args:
|
||||
Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`.
|
||||
"""
|
||||
|
||||
def _to_config(self) -> ConfigT:
|
||||
"""Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance.
|
||||
|
||||
Returns:
|
||||
T: The configuration of the component.
|
||||
|
||||
:meta public:
|
||||
"""
|
||||
...
|
||||
|
||||
class ComponentFromConfig(Generic[FromConfigT]):
|
||||
@classmethod
|
||||
def _from_config(cls, config: ConfigT) -> Self:
|
||||
def _from_config(cls, config: FromConfigT) -> Self:
|
||||
"""Create a new instance of the component from a configuration object.
|
||||
|
||||
Args:
|
||||
@ -87,7 +62,7 @@ class ComponentConfigImpl(Protocol[ConfigT]):
|
||||
|
||||
:meta public:
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError("This component does not support dumping to config")
|
||||
|
||||
@classmethod
|
||||
def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self:
|
||||
@ -104,7 +79,69 @@ class ComponentConfigImpl(Protocol[ConfigT]):
|
||||
|
||||
:meta public:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError("This component does not support loading from past versions")
|
||||
|
||||
|
||||
class ComponentToConfig(Generic[ToConfigT]):
|
||||
"""The two methods a class must implement to be a component.
|
||||
|
||||
Args:
|
||||
Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`.
|
||||
"""
|
||||
|
||||
component_type: ClassVar[ComponentType]
|
||||
"""The logical type of the component."""
|
||||
component_version: ClassVar[int] = 1
|
||||
"""The version of the component, if schema incompatibilities are introduced this should be updated."""
|
||||
component_provider_override: ClassVar[str | None] = None
|
||||
"""Override the provider string for the component. This should be used to prevent internal module names being a part of the module name."""
|
||||
|
||||
def _to_config(self) -> ToConfigT:
|
||||
"""Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance.
|
||||
|
||||
Returns:
|
||||
T: The configuration of the component.
|
||||
|
||||
:meta public:
|
||||
"""
|
||||
raise NotImplementedError("This component does not support dumping to config")
|
||||
|
||||
def dump_component(self) -> ComponentModel:
|
||||
"""Dump the component to a model that can be loaded back in.
|
||||
|
||||
Raises:
|
||||
TypeError: If the component is a local class.
|
||||
|
||||
Returns:
|
||||
ComponentModel: The model representing the component.
|
||||
"""
|
||||
if self.component_provider_override is not None:
|
||||
provider = self.component_provider_override
|
||||
else:
|
||||
provider = _type_to_provider_str(self.__class__)
|
||||
# Warn if internal module name is used,
|
||||
if "._" in provider:
|
||||
warnings.warn(
|
||||
"Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if "<locals>" in provider:
|
||||
raise TypeError("Cannot dump component with local class")
|
||||
|
||||
if not hasattr(self, "component_type"):
|
||||
raise AttributeError("component_type not defined")
|
||||
|
||||
obj_config = self._to_config().model_dump(exclude_none=True)
|
||||
model = ComponentModel(
|
||||
provider=provider,
|
||||
component_type=self.component_type,
|
||||
version=self.component_version,
|
||||
component_version=self.component_version,
|
||||
description=None,
|
||||
config=obj_config,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
ExpectedType = TypeVar("ExpectedType")
|
||||
@ -171,9 +208,9 @@ class ComponentLoader:
|
||||
|
||||
module_path, class_name = output
|
||||
module = importlib.import_module(module_path)
|
||||
component_class = cast(ComponentConfigImpl[BaseModel], module.__getattribute__(class_name))
|
||||
component_class = module.__getattribute__(class_name)
|
||||
|
||||
if not isinstance(component_class, ComponentConfigImpl):
|
||||
if not is_component_class(component_class):
|
||||
raise TypeError("Invalid component class")
|
||||
|
||||
# We need to check the schema is valid
|
||||
@ -192,7 +229,7 @@ class ComponentLoader:
|
||||
f"Tried to load component {component_class} which is on version {component_class.component_version} with a config on version {loaded_config_version} but _from_config_past_version is not implemented"
|
||||
) from e
|
||||
else:
|
||||
schema = component_class.component_config_schema
|
||||
schema = component_class.component_config_schema # type: ignore
|
||||
validated_config = schema.model_validate(loaded_model.config)
|
||||
|
||||
# We're allowed to use the private method here
|
||||
@ -208,8 +245,35 @@ class ComponentLoader:
|
||||
return cast(ExpectedType, instance)
|
||||
|
||||
|
||||
class Component(ComponentConfigImpl[ConfigT], ComponentLoader, Generic[ConfigT]):
|
||||
"""To create a component class, inherit from this class. Then implement two class variables:
|
||||
class ComponentSchemaType(Generic[ConfigT]):
|
||||
# Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context)
|
||||
component_config_schema: Type[ConfigT]
|
||||
"""The Pydantic model class which represents the configuration of the component."""
|
||||
|
||||
required_class_vars = ["component_config_schema", "component_type"]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
if cls.__name__ != "Component" and not cls.__name__ == "_ConcreteComponent":
|
||||
# TODO: validate provider is loadable
|
||||
for var in cls.required_class_vars:
|
||||
if not hasattr(cls, var):
|
||||
warnings.warn(
|
||||
f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
class ComponentBase(ComponentToConfig[ConfigT], ComponentLoader, Generic[ConfigT]): ...
|
||||
|
||||
|
||||
class Component(
|
||||
ComponentFromConfig[ConfigT],
|
||||
ComponentSchemaType[ConfigT],
|
||||
Generic[ConfigT],
|
||||
):
|
||||
"""To create a component class, inherit from this class for the concrete class and ComponentBase on the interface. Then implement two class variables:
|
||||
|
||||
- :py:attr:`component_config_schema` - A Pydantic model class which represents the configuration of the component. This is also the type parameter of Component.
|
||||
- :py:attr:`component_type` - What is the logical type of the component.
|
||||
@ -243,55 +307,39 @@ class Component(ComponentConfigImpl[ConfigT], ComponentLoader, Generic[ConfigT])
|
||||
return cls(value=config.value)
|
||||
"""
|
||||
|
||||
required_class_vars: ClassVar[List[str]] = ["component_config_schema", "component_type"]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, **kwargs: Any):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# TODO: validate provider is loadable
|
||||
for var in cls.required_class_vars:
|
||||
if not hasattr(cls, var):
|
||||
warnings.warn(
|
||||
f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component", stacklevel=2
|
||||
)
|
||||
if not is_component_class(cls):
|
||||
warnings.warn(
|
||||
f"Component class '{cls.__name__}' must subclass the following: ComponentFromConfig, ComponentToConfig, ComponentSchemaType, ComponentLoader, individually or with ComponentBase and Component. Look at the component config documentation or how OpenAIChatCompletionClient does it.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def dump_component(self) -> ComponentModel:
|
||||
"""Dump the component to a model that can be loaded back in.
|
||||
|
||||
Raises:
|
||||
TypeError: If the component is a local class.
|
||||
# Should never be used directly, only for type checking
|
||||
class _ConcreteComponent(
|
||||
ComponentFromConfig[ConfigT],
|
||||
ComponentSchemaType[ConfigT],
|
||||
ComponentToConfig[ConfigT],
|
||||
ComponentLoader,
|
||||
Generic[ConfigT],
|
||||
): ...
|
||||
|
||||
Returns:
|
||||
ComponentModel: The model representing the component.
|
||||
"""
|
||||
if self.component_provider_override is not None:
|
||||
provider = self.component_provider_override
|
||||
else:
|
||||
provider = _type_to_provider_str(self.__class__)
|
||||
# Warn if internal module name is used,
|
||||
if "._" in provider:
|
||||
warnings.warn(
|
||||
"Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if "<locals>" in provider:
|
||||
raise TypeError("Cannot dump component with local class")
|
||||
def is_component_instance(cls: Any) -> TypeGuard[_ConcreteComponent[BaseModel]]:
|
||||
return (
|
||||
isinstance(cls, ComponentFromConfig)
|
||||
and isinstance(cls, ComponentToConfig)
|
||||
and isinstance(cls, ComponentSchemaType)
|
||||
and isinstance(cls, ComponentLoader)
|
||||
)
|
||||
|
||||
if not hasattr(self, "component_type"):
|
||||
raise AttributeError("component_type not defined")
|
||||
|
||||
obj_config = self._to_config().model_dump(exclude_none=True)
|
||||
model = ComponentModel(
|
||||
provider=provider,
|
||||
component_type=self.component_type,
|
||||
version=self.component_version,
|
||||
component_version=self.component_version,
|
||||
description=None,
|
||||
config=obj_config,
|
||||
)
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self:
|
||||
raise NotImplementedError()
|
||||
def is_component_class(cls: type) -> TypeGuard[Type[_ConcreteComponent[BaseModel]]]:
|
||||
return (
|
||||
issubclass(cls, ComponentFromConfig)
|
||||
and issubclass(cls, ComponentToConfig)
|
||||
and issubclass(cls, ComponentSchemaType)
|
||||
and issubclass(cls, ComponentLoader)
|
||||
)
|
||||
|
||||
@ -4,10 +4,11 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal, Mapping, Optional, Sequence, TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Any, AsyncGenerator, Required, TypedDict, Union, deprecated
|
||||
|
||||
from .. import CancellationToken
|
||||
from .._component_config import ComponentLoader
|
||||
from .._component_config import ComponentBase
|
||||
from ..tools import Tool, ToolSchema
|
||||
from ._types import CreateResult, LLMMessage, RequestUsage
|
||||
|
||||
@ -47,7 +48,7 @@ class ModelInfo(TypedDict, total=False):
|
||||
"""Model family should be one of the constants from :py:class:`ModelFamily` or a string representing an unknown model family."""
|
||||
|
||||
|
||||
class ChatCompletionClient(ABC, ComponentLoader):
|
||||
class ChatCompletionClient(ComponentBase[BaseModel], ABC):
|
||||
# Caching has to be handled internally as they can depend on the create args that were stored in the constructor
|
||||
@abstractmethod
|
||||
async def create(
|
||||
|
||||
@ -4,7 +4,7 @@ import json
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from autogen_core import Component, ComponentLoader, ComponentModel
|
||||
from autogen_core import Component, ComponentBase, ComponentLoader, ComponentModel
|
||||
from autogen_core._component_config import _type_to_provider_str # type: ignore
|
||||
from autogen_core.models import ChatCompletionClient
|
||||
from autogen_test_utils import MyInnerComponent, MyOuterComponent
|
||||
@ -16,7 +16,7 @@ class MyConfig(BaseModel):
|
||||
info: str
|
||||
|
||||
|
||||
class MyComponent(Component[MyConfig]):
|
||||
class MyComponent(ComponentBase[MyConfig], Component[MyConfig]):
|
||||
component_config_schema = MyConfig
|
||||
component_type = "custom"
|
||||
|
||||
@ -95,7 +95,7 @@ def test_cannot_import_locals() -> None:
|
||||
class InvalidModelClientConfig(BaseModel):
|
||||
info: str
|
||||
|
||||
class MyInvalidModelClient(Component[InvalidModelClientConfig]):
|
||||
class MyInvalidModelClient(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]):
|
||||
component_config_schema = InvalidModelClientConfig
|
||||
component_type = "model"
|
||||
|
||||
@ -119,7 +119,7 @@ class InvalidModelClientConfig(BaseModel):
|
||||
info: str
|
||||
|
||||
|
||||
class MyInvalidModelClient(Component[InvalidModelClientConfig]):
|
||||
class MyInvalidModelClient(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]):
|
||||
component_config_schema = InvalidModelClientConfig
|
||||
component_type = "model"
|
||||
|
||||
@ -143,7 +143,7 @@ def test_type_error_on_creation() -> None:
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
|
||||
class MyInvalidMissingAttrs(Component[InvalidModelClientConfig]):
|
||||
class MyInvalidMissingAttrs(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]):
|
||||
def __init__(self, info: str):
|
||||
self.info = info
|
||||
|
||||
@ -189,7 +189,7 @@ def test_config_optional_values() -> None:
|
||||
assert component.__class__ == MyComponent
|
||||
|
||||
|
||||
class ConfigProviderOverrided(Component[MyConfig]):
|
||||
class ConfigProviderOverrided(ComponentBase[MyConfig], Component[MyConfig]):
|
||||
component_provider_override = "InvalidButStillOverridden"
|
||||
component_config_schema = MyConfig
|
||||
component_type = "custom"
|
||||
@ -215,7 +215,7 @@ class MyConfig2(BaseModel):
|
||||
info2: str
|
||||
|
||||
|
||||
class ComponentNonOneVersion(Component[MyConfig2]):
|
||||
class ComponentNonOneVersion(ComponentBase[MyConfig2], Component[MyConfig2]):
|
||||
component_config_schema = MyConfig2
|
||||
component_version = 2
|
||||
component_type = "custom"
|
||||
@ -231,7 +231,7 @@ class ComponentNonOneVersion(Component[MyConfig2]):
|
||||
return cls(info=config.info2)
|
||||
|
||||
|
||||
class ComponentNonOneVersionWithUpgrade(Component[MyConfig2]):
|
||||
class ComponentNonOneVersionWithUpgrade(ComponentBase[MyConfig2], Component[MyConfig2]):
|
||||
component_config_schema = MyConfig2
|
||||
component_version = 2
|
||||
component_type = "custom"
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
from autogen_core import Component
|
||||
from autogen_core import Component, ComponentBase
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
@ -13,7 +13,7 @@ class TokenProviderConfig(BaseModel):
|
||||
scopes: List[str]
|
||||
|
||||
|
||||
class AzureTokenProvider(Component[TokenProviderConfig]):
|
||||
class AzureTokenProvider(ComponentBase[TokenProviderConfig], Component[TokenProviderConfig]):
|
||||
component_type = "token_provider"
|
||||
component_config_schema = TokenProviderConfig
|
||||
component_provider_override = "autogen_ext.auth.azure.AzureTokenProvider"
|
||||
|
||||
@ -6,13 +6,14 @@ from typing import Any
|
||||
from autogen_core import (
|
||||
BaseAgent,
|
||||
Component,
|
||||
ComponentBase,
|
||||
ComponentModel,
|
||||
DefaultTopicId,
|
||||
MessageContext,
|
||||
RoutedAgent,
|
||||
default_subscription,
|
||||
message_handler,
|
||||
)
|
||||
from autogen_core._component_config import ComponentModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@ -76,7 +77,7 @@ class MyInnerConfig(BaseModel):
|
||||
inner_message: str
|
||||
|
||||
|
||||
class MyInnerComponent(Component[MyInnerConfig]):
|
||||
class MyInnerComponent(ComponentBase[MyInnerConfig], Component[MyInnerConfig]):
|
||||
component_config_schema = MyInnerConfig
|
||||
component_type = "custom"
|
||||
|
||||
@ -96,7 +97,7 @@ class MyOuterConfig(BaseModel):
|
||||
inner_class: ComponentModel
|
||||
|
||||
|
||||
class MyOuterComponent(Component[MyOuterConfig]):
|
||||
class MyOuterComponent(ComponentBase[MyOuterConfig], Component[MyOuterConfig]):
|
||||
component_config_schema = MyOuterConfig
|
||||
component_type = "custom"
|
||||
|
||||
|
||||
@ -5,7 +5,8 @@ from typing import Any, DefaultDict, Dict, List, TypeVar
|
||||
from autogen_core import ComponentModel
|
||||
from autogen_core._component_config import (
|
||||
WELL_KNOWN_PROVIDERS,
|
||||
ComponentConfigImpl,
|
||||
ComponentSchemaType,
|
||||
ComponentToConfig,
|
||||
_type_to_provider_str, # type: ignore
|
||||
)
|
||||
from autogen_ext.auth.azure import AzureTokenProvider
|
||||
@ -17,10 +18,13 @@ all_defs: Dict[str, Any] = {}
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
def build_specific_component_schema(component: type[ComponentConfigImpl[T]], provider_str: str) -> Dict[str, Any]:
|
||||
def build_specific_component_schema(component: type[ComponentSchemaType[T]], provider_str: str) -> Dict[str, Any]:
|
||||
model = component.component_config_schema # type: ignore
|
||||
model_schema = model.model_json_schema()
|
||||
|
||||
# We can't specify component to be the union of two types, so we assert it here
|
||||
assert issubclass(component, ComponentToConfig)
|
||||
|
||||
component_model_schema = ComponentModel.model_json_schema()
|
||||
if "$defs" not in component_model_schema:
|
||||
component_model_schema["$defs"] = {}
|
||||
@ -70,7 +74,9 @@ def main() -> None:
|
||||
for key, value in WELL_KNOWN_PROVIDERS.items():
|
||||
reverse_provider_lookup_table[value].append(key)
|
||||
|
||||
def add_type(type: type[ComponentConfigImpl[T]]) -> None:
|
||||
def add_type(type: type[ComponentSchemaType[T]]) -> None:
|
||||
# We can't specify component to be the union of two types, so we assert it here
|
||||
assert issubclass(type, ComponentToConfig)
|
||||
canonical = type.component_provider_override or _type_to_provider_str(type)
|
||||
reverse_provider_lookup_table[canonical].append(canonical)
|
||||
for provider_str in reverse_provider_lookup_table[canonical]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user