mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-01 20:37:33 +00:00
Aegis structure message (#6289)
Added support for structured message component using the Json to Pydantic utility functions. Note: also adding the ability to use a format string for structured messages. Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
8bd162f8fc
commit
a4a16fd2f8
@ -46,6 +46,7 @@ from ..messages import (
|
||||
MemoryQueryEvent,
|
||||
ModelClientStreamingChunkEvent,
|
||||
StructuredMessage,
|
||||
StructuredMessageFactory,
|
||||
TextMessage,
|
||||
ThoughtEvent,
|
||||
ToolCallExecutionEvent,
|
||||
@ -74,6 +75,7 @@ class AssistantAgentConfig(BaseModel):
|
||||
reflect_on_tool_use: bool
|
||||
tool_call_summary_format: str
|
||||
metadata: Dict[str, str] | None = None
|
||||
structured_message_factory: ComponentModel | None = None
|
||||
|
||||
|
||||
class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
@ -183,6 +185,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
This will be used with the model client to generate structured output.
|
||||
If this is set, the agent will respond with a :class:`~autogen_agentchat.messages.StructuredMessage` instead of a :class:`~autogen_agentchat.messages.TextMessage`
|
||||
in the final response, unless `reflect_on_tool_use` is `False` and a tool call is made.
|
||||
output_content_type_format (str | None, optional): (Experimental) The format string used for the content of a :class:`~autogen_agentchat.messages.StructuredMessage` response.
|
||||
tool_call_summary_format (str, optional): The format string used to create the content for a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` response.
|
||||
The format string is used to format the tool call summary for every tool call result.
|
||||
Defaults to "{result}".
|
||||
@ -635,6 +638,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
reflect_on_tool_use: bool | None = None,
|
||||
tool_call_summary_format: str = "{result}",
|
||||
output_content_type: type[BaseModel] | None = None,
|
||||
output_content_type_format: str | None = None,
|
||||
memory: Sequence[Memory] | None = None,
|
||||
metadata: Dict[str, str] | None = None,
|
||||
):
|
||||
@ -643,6 +647,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
self._model_client = model_client
|
||||
self._model_client_stream = model_client_stream
|
||||
self._output_content_type: type[BaseModel] | None = output_content_type
|
||||
self._output_content_type_format = output_content_type_format
|
||||
self._structured_message_factory: StructuredMessageFactory | None = None
|
||||
if output_content_type is not None:
|
||||
self._structured_message_factory = StructuredMessageFactory(
|
||||
input_model=output_content_type, format_string=output_content_type_format
|
||||
)
|
||||
|
||||
self._memory = None
|
||||
if memory is not None:
|
||||
if isinstance(memory, list):
|
||||
@ -771,6 +782,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
reflect_on_tool_use = self._reflect_on_tool_use
|
||||
tool_call_summary_format = self._tool_call_summary_format
|
||||
output_content_type = self._output_content_type
|
||||
format_string = self._output_content_type_format
|
||||
|
||||
# STEP 1: Add new user/handoff messages to the model context
|
||||
await self._add_messages_to_context(
|
||||
@ -840,6 +852,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
reflect_on_tool_use=reflect_on_tool_use,
|
||||
tool_call_summary_format=tool_call_summary_format,
|
||||
output_content_type=output_content_type,
|
||||
format_string=format_string,
|
||||
):
|
||||
yield output_event
|
||||
|
||||
@ -942,6 +955,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
reflect_on_tool_use: bool,
|
||||
tool_call_summary_format: str,
|
||||
output_content_type: type[BaseModel] | None,
|
||||
format_string: str | None = None,
|
||||
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
|
||||
"""
|
||||
Handle final or partial responses from model_result, including tool calls, handoffs,
|
||||
@ -957,6 +971,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
content=content,
|
||||
source=agent_name,
|
||||
models_usage=model_result.usage,
|
||||
format_string=format_string,
|
||||
),
|
||||
inner_messages=inner_messages,
|
||||
)
|
||||
@ -1277,9 +1292,6 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
def _to_config(self) -> AssistantAgentConfig:
|
||||
"""Convert the assistant agent to a declarative config."""
|
||||
|
||||
if self._output_content_type:
|
||||
raise ValueError("AssistantAgent with output_content_type does not support declarative config.")
|
||||
|
||||
return AssistantAgentConfig(
|
||||
name=self.name,
|
||||
model_client=self._model_client.dump_component(),
|
||||
@ -1294,12 +1306,23 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
model_client_stream=self._model_client_stream,
|
||||
reflect_on_tool_use=self._reflect_on_tool_use,
|
||||
tool_call_summary_format=self._tool_call_summary_format,
|
||||
structured_message_factory=self._structured_message_factory.dump_component()
|
||||
if self._structured_message_factory
|
||||
else None,
|
||||
metadata=self._metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: AssistantAgentConfig) -> Self:
|
||||
"""Create an assistant agent from a declarative config."""
|
||||
if config.structured_message_factory:
|
||||
structured_message_factory = StructuredMessageFactory.load_component(config.structured_message_factory)
|
||||
format_string = structured_message_factory.format_string
|
||||
output_content_type = structured_message_factory.ContentModel
|
||||
|
||||
else:
|
||||
format_string = None
|
||||
output_content_type = None
|
||||
|
||||
return cls(
|
||||
name=config.name,
|
||||
@ -1313,5 +1336,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
model_client_stream=config.model_client_stream,
|
||||
reflect_on_tool_use=config.reflect_on_tool_use,
|
||||
tool_call_summary_format=config.tool_call_summary_format,
|
||||
output_content_type=output_content_type,
|
||||
output_content_type_format=format_string,
|
||||
metadata=config.metadata,
|
||||
)
|
||||
|
@ -5,9 +5,9 @@ class and includes specific fields relevant to the type of message being sent.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generic, List, Literal, Mapping, TypeVar
|
||||
from typing import Any, Dict, Generic, List, Literal, Mapping, Optional, Type, TypeVar
|
||||
|
||||
from autogen_core import FunctionCall, Image
|
||||
from autogen_core import Component, ComponentBase, FunctionCall, Image
|
||||
from autogen_core.code_executor import CodeBlock, CodeResult
|
||||
from autogen_core.memory import MemoryContent
|
||||
from autogen_core.models import (
|
||||
@ -16,6 +16,7 @@ from autogen_core.models import (
|
||||
RequestUsage,
|
||||
UserMessage,
|
||||
)
|
||||
from autogen_core.utils import schema_to_pydantic_model
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from typing_extensions import Annotated, Self
|
||||
|
||||
@ -182,21 +183,56 @@ class StructuredMessage(BaseChatMessage, Generic[StructuredContentType]):
|
||||
|
||||
print(message.to_text()) # {"text": "Hello", "number": 42}
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pydantic import BaseModel
|
||||
from autogen_agentchat.messages import StructuredMessage
|
||||
|
||||
|
||||
class MyMessageContent(BaseModel):
|
||||
text: str
|
||||
number: int
|
||||
|
||||
|
||||
message = StructuredMessage[MyMessageContent](
|
||||
content=MyMessageContent(text="Hello", number=42),
|
||||
source="agent",
|
||||
format_string="Hello, {text} {number}!",
|
||||
)
|
||||
|
||||
print(message.to_text()) # Hello, agent 42!
|
||||
|
||||
"""
|
||||
|
||||
content: StructuredContentType
|
||||
"""The content of the message. Must be a subclass of
|
||||
`Pydantic BaseModel <https://docs.pydantic.dev/latest/concepts/models/>`_."""
|
||||
|
||||
format_string: Optional[str] = None
|
||||
"""(Experimental) An optional format string to render the content into a human-readable format.
|
||||
The format string can use the fields of the content model as placeholders.
|
||||
For example, if the content model has a field `name`, you can use
|
||||
`{name}` in the format string to include the value of that field.
|
||||
The format string is used in the :meth:`to_text` method to create a
|
||||
human-readable representation of the message.
|
||||
This setting is experimental and will change in the future.
|
||||
"""
|
||||
|
||||
@computed_field
|
||||
def type(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
def to_text(self) -> str:
|
||||
return self.content.model_dump_json(indent=2)
|
||||
if self.format_string is not None:
|
||||
return self.format_string.format(**self.content.model_dump())
|
||||
else:
|
||||
return self.content.model_dump_json()
|
||||
|
||||
def to_model_text(self) -> str:
|
||||
return self.content.model_dump_json()
|
||||
if self.format_string is not None:
|
||||
return self.format_string.format(**self.content.model_dump())
|
||||
else:
|
||||
return self.content.model_dump_json()
|
||||
|
||||
def to_model_message(self) -> UserMessage:
|
||||
return UserMessage(
|
||||
@ -205,6 +241,113 @@ class StructuredMessage(BaseChatMessage, Generic[StructuredContentType]):
|
||||
)
|
||||
|
||||
|
||||
class StructureMessageConfig(BaseModel):
|
||||
"""The declarative configuration for the structured output."""
|
||||
|
||||
json_schema: Dict[str, Any]
|
||||
format_string: Optional[str] = None
|
||||
content_model_name: str
|
||||
|
||||
|
||||
class StructuredMessageFactory(ComponentBase[StructureMessageConfig], Component[StructureMessageConfig]):
|
||||
""":meta private:
|
||||
|
||||
A component that creates structured chat messages from Pydantic models or JSON schemas.
|
||||
|
||||
This component helps you generate strongly-typed chat messages with content defined using a Pydantic model.
|
||||
It can be used in declarative workflows where message structure must be validated, formatted, and serialized.
|
||||
|
||||
You can initialize the component directly using a `BaseModel` subclass, or dynamically from a configuration
|
||||
object (e.g., loaded from disk or a database).
|
||||
|
||||
### Example 1: Create from a Pydantic Model
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pydantic import BaseModel
|
||||
from autogen_agentchat.messages import StructuredMessageFactory
|
||||
|
||||
|
||||
class TestContent(BaseModel):
|
||||
field1: str
|
||||
field2: int
|
||||
|
||||
|
||||
format_string = "This is a string {field1} and this is an int {field2}"
|
||||
sm_component = StructuredMessageFactory(input_model=TestContent, format_string=format_string)
|
||||
|
||||
message = sm_component.StructuredMessage(
|
||||
source="test_agent", content=TestContent(field1="Hello", field2=42), format_string=format_string
|
||||
)
|
||||
|
||||
print(message.to_model_text()) # Output: This is a string Hello and this is an int 42
|
||||
|
||||
config = sm_component.dump_component()
|
||||
|
||||
s_m_dyn = StructuredMessageFactory.load_component(config)
|
||||
message = s_m_dyn.StructuredMessage(
|
||||
source="test_agent",
|
||||
content=s_m_dyn.ContentModel(field1="dyn agent", field2=43),
|
||||
format_string=s_m_dyn.format_string,
|
||||
)
|
||||
print(type(message)) # StructuredMessage[GeneratedModel]
|
||||
print(message.to_model_text()) # Output: This is a string dyn agent and this is an int 43
|
||||
|
||||
Attributes:
|
||||
component_config_schema (StructureMessageConfig): Defines the configuration structure for this component.
|
||||
component_provider_override (str): Path used to reference this component in external tooling.
|
||||
component_type (str): Identifier used for categorization (e.g., "structured_message").
|
||||
|
||||
Raises:
|
||||
ValueError: If neither `json_schema` nor `input_model` is provided.
|
||||
|
||||
Args:
|
||||
json_schema (Optional[str]): JSON schema to dynamically create a Pydantic model.
|
||||
input_model (Optional[Type[BaseModel]]): A subclass of `BaseModel` that defines the expected message structure.
|
||||
format_string (Optional[str]): Optional string to render content into a human-readable format.
|
||||
content_model_name (Optional[str]): Optional name for the generated Pydantic model.
|
||||
"""
|
||||
|
||||
component_config_schema = StructureMessageConfig
|
||||
component_provider_override = "autogen_agentchat.messages.StructuredMessageFactory"
|
||||
component_type = "structured_message"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
json_schema: Optional[Dict[str, Any]] = None,
|
||||
input_model: Optional[Type[BaseModel]] = None,
|
||||
format_string: Optional[str] = None,
|
||||
content_model_name: Optional[str] = None,
|
||||
) -> None:
|
||||
self.format_string = format_string
|
||||
|
||||
if json_schema:
|
||||
self.ContentModel = schema_to_pydantic_model(
|
||||
json_schema, model_name=content_model_name or "GeneratedContentModel"
|
||||
)
|
||||
elif input_model:
|
||||
self.ContentModel = input_model
|
||||
else:
|
||||
raise ValueError("Either `json_schema` or `input_model` must be provided.")
|
||||
|
||||
self.StructuredMessage = StructuredMessage[self.ContentModel] # type: ignore[name-defined]
|
||||
|
||||
def _to_config(self) -> StructureMessageConfig:
|
||||
return StructureMessageConfig(
|
||||
json_schema=self.ContentModel.model_json_schema(),
|
||||
format_string=self.format_string,
|
||||
content_model_name=self.ContentModel.__name__,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: StructureMessageConfig) -> "StructuredMessageFactory":
|
||||
return cls(
|
||||
json_schema=config.json_schema,
|
||||
format_string=config.format_string,
|
||||
content_model_name=config.content_model_name,
|
||||
)
|
||||
|
||||
|
||||
class TextMessage(BaseTextChatMessage):
|
||||
"""A text message with string-only content."""
|
||||
|
||||
@ -468,6 +611,7 @@ __all__ = [
|
||||
"BaseTextChatMessage",
|
||||
"StructuredContentType",
|
||||
"StructuredMessage",
|
||||
"StructuredMessageFactory",
|
||||
"HandoffMessage",
|
||||
"MultiModalMessage",
|
||||
"StopMessage",
|
||||
|
@ -21,6 +21,7 @@ from ...messages import (
|
||||
MessageFactory,
|
||||
ModelClientStreamingChunkEvent,
|
||||
StopMessage,
|
||||
StructuredMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from ...state import TeamState
|
||||
@ -68,6 +69,15 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
for message_type in custom_message_types:
|
||||
self._message_factory.register(message_type)
|
||||
|
||||
for agent in participants:
|
||||
for message_type in agent.produced_message_types:
|
||||
try:
|
||||
if issubclass(message_type, StructuredMessage):
|
||||
self._message_factory.register(message_type) # type: ignore[reportUnknownArgumentType]
|
||||
except TypeError:
|
||||
# Not a class or not a valid subclassable type (skip)
|
||||
pass
|
||||
|
||||
# The team ID is a UUID that is used to identify the team and its participants
|
||||
# in the agent runtime. It is used to create unique topic types for each participant.
|
||||
# Currently, team ID is binded to an object instance of the group chat class.
|
||||
|
@ -36,7 +36,7 @@ from autogen_core.models._model_client import ModelFamily
|
||||
from autogen_core.tools import BaseTool, FunctionTool
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_ext.models.replay import ReplayChatCompletionClient
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from utils import FileLogHandler
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
@ -1104,3 +1104,103 @@ async def test_model_client_stream_with_tool_calls() -> None:
|
||||
elif isinstance(message, ModelClientStreamingChunkEvent):
|
||||
chunks.append(message.content)
|
||||
assert "".join(chunks) == "Example response 2 to task"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_structured_output_format() -> None:
|
||||
class AgentResponse(BaseModel):
|
||||
response: str
|
||||
status: str
|
||||
|
||||
model_client = ReplayChatCompletionClient(
|
||||
[
|
||||
CreateResult(
|
||||
finish_reason="stop",
|
||||
content='{"response": "Hello"}',
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
cached=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
agent = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=model_client,
|
||||
output_content_type=AgentResponse,
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
await agent.run()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_message_factory_serialization() -> None:
|
||||
class AgentResponse(BaseModel):
|
||||
result: str
|
||||
status: str
|
||||
|
||||
model_client = ReplayChatCompletionClient(
|
||||
[
|
||||
CreateResult(
|
||||
finish_reason="stop",
|
||||
content=AgentResponse(result="All good", status="ok").model_dump_json(),
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
cached=False,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
agent = AssistantAgent(
|
||||
name="structured_agent",
|
||||
model_client=model_client,
|
||||
output_content_type=AgentResponse,
|
||||
output_content_type_format="{result} - {status}",
|
||||
)
|
||||
|
||||
dumped = agent.dump_component()
|
||||
restored_agent = AssistantAgent.load_component(dumped)
|
||||
result = await restored_agent.run()
|
||||
|
||||
assert isinstance(result.messages[0], StructuredMessage)
|
||||
assert result.messages[0].content.result == "All good" # type: ignore[reportUnknownMemberType]
|
||||
assert result.messages[0].content.status == "ok" # type: ignore[reportUnknownMemberType]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_message_format_string() -> None:
|
||||
class AgentResponse(BaseModel):
|
||||
field1: str
|
||||
field2: str
|
||||
|
||||
expected = AgentResponse(field1="foo", field2="bar")
|
||||
|
||||
model_client = ReplayChatCompletionClient(
|
||||
[
|
||||
CreateResult(
|
||||
finish_reason="stop",
|
||||
content=expected.model_dump_json(),
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
cached=False,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
agent = AssistantAgent(
|
||||
name="formatted_agent",
|
||||
model_client=model_client,
|
||||
output_content_type=AgentResponse,
|
||||
output_content_type_format="{field1} - {field2}",
|
||||
)
|
||||
|
||||
result = await agent.run()
|
||||
|
||||
assert len(result.messages) == 1
|
||||
message = result.messages[0]
|
||||
|
||||
# Check that it's a StructuredMessage with the correct content model
|
||||
assert isinstance(message, StructuredMessage)
|
||||
assert isinstance(message.content, AgentResponse) # type: ignore[reportUnknownMemberType]
|
||||
assert message.content == expected
|
||||
|
||||
# Check that the format_string was applied correctly
|
||||
assert message.to_model_text() == "foo - bar"
|
||||
|
@ -1441,3 +1441,87 @@ async def test_declarative_groupchats_with_config(runtime: AgentRuntime | None)
|
||||
assert selector.dump_component().provider == "autogen_agentchat.teams.SelectorGroupChat"
|
||||
assert swarm.dump_component().provider == "autogen_agentchat.teams.Swarm"
|
||||
assert magentic.dump_component().provider == "autogen_agentchat.teams.MagenticOneGroupChat"
|
||||
|
||||
|
||||
class _StructuredContent(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class _StructuredAgent(BaseChatAgent):
|
||||
def __init__(self, name: str, description: str) -> None:
|
||||
super().__init__(name, description)
|
||||
self._message = _StructuredContent(message="Structured hello")
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
|
||||
return (StructuredMessage[_StructuredContent],)
|
||||
|
||||
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
return Response(
|
||||
chat_message=StructuredMessage[_StructuredContent](
|
||||
source=self.name,
|
||||
content=self._message,
|
||||
format_string="Structured says: {message}",
|
||||
)
|
||||
)
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_type_auto_registration(runtime: AgentRuntime | None) -> None:
|
||||
agent1 = _StructuredAgent("structured", description="emits structured messages")
|
||||
agent2 = _EchoAgent("echo", description="echoes input")
|
||||
|
||||
team = RoundRobinGroupChat(participants=[agent1, agent2], max_turns=2, runtime=runtime)
|
||||
|
||||
result = await team.run(task="Say something structured")
|
||||
|
||||
assert len(result.messages) == 3
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], StructuredMessage)
|
||||
assert isinstance(result.messages[2], TextMessage)
|
||||
assert result.messages[1].to_text() == "Structured says: Structured hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_message_state_roundtrip(runtime: AgentRuntime | None) -> None:
|
||||
agent1 = _StructuredAgent("structured", description="sends structured")
|
||||
agent2 = _EchoAgent("echo", description="echoes")
|
||||
|
||||
team1 = RoundRobinGroupChat(
|
||||
participants=[agent1, agent2],
|
||||
termination_condition=MaxMessageTermination(2),
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
await team1.run(task="Say something structured")
|
||||
state1 = await team1.save_state()
|
||||
|
||||
# Recreate team without needing custom_message_types
|
||||
agent3 = _StructuredAgent("structured", description="sends structured")
|
||||
agent4 = _EchoAgent("echo", description="echoes")
|
||||
team2 = RoundRobinGroupChat(
|
||||
participants=[agent3, agent4],
|
||||
termination_condition=MaxMessageTermination(2),
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
await team2.load_state(state1)
|
||||
state2 = await team2.save_state()
|
||||
|
||||
# Assert full state equality
|
||||
assert state1 == state2
|
||||
|
||||
# Assert message thread content match
|
||||
manager1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||
AgentId(f"{team1._group_chat_manager_name}_{team1._team_id}", team1._team_id), # pyright: ignore
|
||||
RoundRobinGroupChatManager,
|
||||
)
|
||||
manager2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||
AgentId(f"{team2._group_chat_manager_name}_{team2._team_id}", team2._team_id), # pyright: ignore
|
||||
RoundRobinGroupChatManager,
|
||||
)
|
||||
|
||||
assert manager1._message_thread == manager2._message_thread # pyright: ignore
|
||||
|
@ -11,6 +11,7 @@ from autogen_agentchat.messages import (
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
StructuredMessage,
|
||||
StructuredMessageFactory,
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
@ -52,6 +53,28 @@ def test_structured_message() -> None:
|
||||
assert dumped_message["type"] == "StructuredMessage[TestContent]"
|
||||
|
||||
|
||||
def test_structured_message_component() -> None:
|
||||
# Create a structured message with the test contentformat_string="this is a string {field1} and this is an int {field2}"
|
||||
format_string = "this is a string {field1} and this is an int {field2}"
|
||||
s_m = StructuredMessageFactory(input_model=TestContent, format_string=format_string)
|
||||
config = s_m.dump_component()
|
||||
s_m_dyn = StructuredMessageFactory.load_component(config)
|
||||
message = s_m_dyn.StructuredMessage(
|
||||
source="test_agent", content=s_m_dyn.ContentModel(field1="test", field2=42), format_string=s_m_dyn.format_string
|
||||
)
|
||||
|
||||
assert isinstance(message.content, s_m_dyn.ContentModel)
|
||||
assert not isinstance(message.content, TestContent)
|
||||
assert message.content.field1 == "test" # type: ignore[attr-defined]
|
||||
assert message.content.field2 == 42 # type: ignore[attr-defined]
|
||||
|
||||
dumped_message = message.model_dump()
|
||||
assert dumped_message["source"] == "test_agent"
|
||||
assert dumped_message["content"]["field1"] == "test"
|
||||
assert dumped_message["content"]["field2"] == 42
|
||||
assert message.to_model_text() == format_string.format(field1="test", field2=42)
|
||||
|
||||
|
||||
def test_message_factory() -> None:
|
||||
factory = MessageFactory()
|
||||
|
||||
@ -109,6 +132,22 @@ def test_message_factory() -> None:
|
||||
assert structured_message.content.field2 == 42
|
||||
assert structured_message.type == "StructuredMessage[TestContent]" # type: ignore[comparison-overlap]
|
||||
|
||||
sm_factory = StructuredMessageFactory(input_model=TestContent, format_string=None, content_model_name="TestContent")
|
||||
config = sm_factory.dump_component()
|
||||
config.config["content_model_name"] = "DynamicTestContent"
|
||||
sm_factory_dynamic = StructuredMessageFactory.load_component(config)
|
||||
|
||||
factory.register(sm_factory_dynamic.StructuredMessage)
|
||||
msg = sm_factory_dynamic.StructuredMessage(
|
||||
content=sm_factory_dynamic.ContentModel(field1="static", field2=123), source="static_agent"
|
||||
)
|
||||
restored = factory.create(msg.dump())
|
||||
assert isinstance(restored, StructuredMessage)
|
||||
assert isinstance(restored.content, sm_factory_dynamic.ContentModel) # type: ignore[reportUnkownMemberType]
|
||||
assert restored.source == "static_agent"
|
||||
assert restored.content.field1 == "static" # type: ignore[attr-defined]
|
||||
assert restored.content.field2 == 123 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class TestContainer(BaseModel):
|
||||
chat_messages: List[ChatMessage]
|
||||
|
@ -0,0 +1,3 @@
|
||||
from ._json_to_pydantic import schema_to_pydantic_model
|
||||
|
||||
__all__ = ["schema_to_pydantic_model"]
|
@ -0,0 +1,531 @@
|
||||
import datetime
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
from typing import Annotated, Any, Dict, ForwardRef, List, Literal, Optional, Type, Union, cast
|
||||
|
||||
from pydantic import (
|
||||
UUID1,
|
||||
UUID3,
|
||||
UUID4,
|
||||
UUID5,
|
||||
AnyUrl,
|
||||
BaseModel,
|
||||
EmailStr,
|
||||
Field,
|
||||
conbytes,
|
||||
confloat,
|
||||
conint,
|
||||
conlist,
|
||||
constr,
|
||||
create_model,
|
||||
)
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
|
||||
class SchemaConversionError(Exception):
|
||||
"""Base class for schema conversion exceptions."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ReferenceNotFoundError(SchemaConversionError):
|
||||
"""Raised when a $ref cannot be resolved."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FormatNotSupportedError(SchemaConversionError):
|
||||
"""Raised when a format is not supported."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedKeywordError(SchemaConversionError):
|
||||
"""Raised when an unsupported JSON Schema keyword is encountered."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
TYPE_MAPPING: Dict[str, Type[Any]] = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"boolean": bool,
|
||||
"number": float,
|
||||
"array": List,
|
||||
"object": dict,
|
||||
"null": type(None),
|
||||
}
|
||||
|
||||
FORMAT_MAPPING: Dict[str, Any] = {
|
||||
"uuid": UUID4,
|
||||
"uuid1": UUID1,
|
||||
"uuid2": UUID4,
|
||||
"uuid3": UUID3,
|
||||
"uuid4": UUID4,
|
||||
"uuid5": UUID5,
|
||||
"email": EmailStr,
|
||||
"uri": AnyUrl,
|
||||
"hostname": constr(strict=True),
|
||||
"ipv4": IPv4Address,
|
||||
"ipv6": IPv6Address,
|
||||
"ipv4-network": IPv4Address,
|
||||
"ipv6-network": IPv6Address,
|
||||
"date-time": datetime.datetime,
|
||||
"date": datetime.date,
|
||||
"time": datetime.time,
|
||||
"duration": datetime.timedelta,
|
||||
"int32": conint(strict=True, ge=-(2**31), le=2**31 - 1),
|
||||
"int64": conint(strict=True, ge=-(2**63), le=2**63 - 1),
|
||||
"float": confloat(strict=True),
|
||||
"double": float,
|
||||
"decimal": float,
|
||||
"byte": conbytes(strict=True),
|
||||
"binary": conbytes(strict=True),
|
||||
"password": str,
|
||||
"path": str,
|
||||
}
|
||||
|
||||
|
||||
def _make_field(
|
||||
default: Any,
|
||||
*,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""Construct a Pydantic Field with proper typing."""
|
||||
field_kwargs: Dict[str, Any] = {}
|
||||
if title is not None:
|
||||
field_kwargs["title"] = title
|
||||
if description is not None:
|
||||
field_kwargs["description"] = description
|
||||
return Field(default, **field_kwargs)
|
||||
|
||||
|
||||
class _JSONSchemaToPydantic:
|
||||
def __init__(self) -> None:
|
||||
self._model_cache: Dict[str, Optional[Union[Type[BaseModel], ForwardRef]]] = {}
|
||||
|
||||
def _resolve_ref(self, ref: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
ref_key = ref.split("/")[-1]
|
||||
definitions = cast(dict[str, dict[str, Any]], schema.get("$defs", {}))
|
||||
|
||||
if ref_key not in definitions:
|
||||
raise ReferenceNotFoundError(
|
||||
f"Reference `{ref}` not found in `$defs`. Available keys: {list(definitions.keys())}"
|
||||
)
|
||||
|
||||
return definitions[ref_key]
|
||||
|
||||
def get_ref(self, ref_name: str) -> Any:
|
||||
if ref_name not in self._model_cache:
|
||||
raise ReferenceNotFoundError(
|
||||
f"Reference `{ref_name}` not found in cache. Available: {list(self._model_cache.keys())}"
|
||||
)
|
||||
|
||||
if self._model_cache[ref_name] is None:
|
||||
return ForwardRef(ref_name)
|
||||
|
||||
return self._model_cache[ref_name]
|
||||
|
||||
def _process_definitions(self, root_schema: Dict[str, Any]) -> None:
|
||||
if "$defs" in root_schema:
|
||||
for model_name in root_schema["$defs"]:
|
||||
if model_name not in self._model_cache:
|
||||
self._model_cache[model_name] = None
|
||||
|
||||
for model_name, model_schema in root_schema["$defs"].items():
|
||||
if self._model_cache[model_name] is None:
|
||||
self._model_cache[model_name] = self.json_schema_to_pydantic(model_schema, model_name, root_schema)
|
||||
|
||||
def json_schema_to_pydantic(
|
||||
self, schema: Dict[str, Any], model_name: str = "GeneratedModel", root_schema: Optional[Dict[str, Any]] = None
|
||||
) -> Type[BaseModel]:
|
||||
if root_schema is None:
|
||||
root_schema = schema
|
||||
self._process_definitions(root_schema)
|
||||
|
||||
if "$ref" in schema:
|
||||
resolved = self._resolve_ref(schema["$ref"], root_schema)
|
||||
schema = {**resolved, **{k: v for k, v in schema.items() if k != "$ref"}}
|
||||
|
||||
if "allOf" in schema:
|
||||
merged: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
||||
for s in schema["allOf"]:
|
||||
part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s
|
||||
merged["properties"].update(part.get("properties", {}))
|
||||
merged["required"].extend(part.get("required", []))
|
||||
for k, v in schema.items():
|
||||
if k not in {"allOf", "properties", "required"}:
|
||||
merged[k] = v
|
||||
merged["required"] = list(set(merged["required"]))
|
||||
schema = merged
|
||||
|
||||
return self._json_schema_to_model(schema, model_name, root_schema)
|
||||
|
||||
def _resolve_union_types(self, schemas: List[Dict[str, Any]]) -> List[Any]:
|
||||
types: List[Any] = []
|
||||
for s in schemas:
|
||||
if "$ref" in s:
|
||||
types.append(self.get_ref(s["$ref"].split("/")[-1]))
|
||||
elif "enum" in s:
|
||||
types.append(Literal[tuple(s["enum"])] if len(s["enum"]) > 0 else Any)
|
||||
else:
|
||||
json_type = s.get("type")
|
||||
if json_type not in TYPE_MAPPING:
|
||||
raise UnsupportedKeywordError(f"Unsupported or missing type `{json_type}` in union")
|
||||
types.append(TYPE_MAPPING[json_type])
|
||||
return types
|
||||
|
||||
def _extract_field_type(self, key: str, value: Dict[str, Any], model_name: str, root_schema: Dict[str, Any]) -> Any:
|
||||
json_type = value.get("type")
|
||||
if json_type not in TYPE_MAPPING:
|
||||
raise UnsupportedKeywordError(
|
||||
f"Unsupported or missing type `{json_type}` for field `{key}` in `{model_name}`"
|
||||
)
|
||||
|
||||
base_type = TYPE_MAPPING[json_type]
|
||||
constraints: Dict[str, Any] = {}
|
||||
|
||||
if json_type == "string":
|
||||
if "minLength" in value:
|
||||
constraints["min_length"] = value["minLength"]
|
||||
if "maxLength" in value:
|
||||
constraints["max_length"] = value["maxLength"]
|
||||
if "pattern" in value:
|
||||
constraints["pattern"] = value["pattern"]
|
||||
if constraints:
|
||||
base_type = constr(**constraints)
|
||||
|
||||
elif json_type == "integer":
|
||||
if "minimum" in value:
|
||||
constraints["ge"] = value["minimum"]
|
||||
if "maximum" in value:
|
||||
constraints["le"] = value["maximum"]
|
||||
if "exclusiveMinimum" in value:
|
||||
constraints["gt"] = value["exclusiveMinimum"]
|
||||
if "exclusiveMaximum" in value:
|
||||
constraints["lt"] = value["exclusiveMaximum"]
|
||||
if constraints:
|
||||
base_type = conint(**constraints)
|
||||
|
||||
elif json_type == "number":
|
||||
if "minimum" in value:
|
||||
constraints["ge"] = value["minimum"]
|
||||
if "maximum" in value:
|
||||
constraints["le"] = value["maximum"]
|
||||
if "exclusiveMinimum" in value:
|
||||
constraints["gt"] = value["exclusiveMinimum"]
|
||||
if "exclusiveMaximum" in value:
|
||||
constraints["lt"] = value["exclusiveMaximum"]
|
||||
if constraints:
|
||||
base_type = confloat(**constraints)
|
||||
|
||||
elif json_type == "array":
|
||||
if "minItems" in value:
|
||||
constraints["min_length"] = value["minItems"]
|
||||
if "maxItems" in value:
|
||||
constraints["max_length"] = value["maxItems"]
|
||||
item_schema = value.get("items", {"type": "string"})
|
||||
if "$ref" in item_schema:
|
||||
item_type = self.get_ref(item_schema["$ref"].split("/")[-1])
|
||||
else:
|
||||
item_type_name = item_schema.get("type")
|
||||
if item_type_name not in TYPE_MAPPING:
|
||||
raise UnsupportedKeywordError(
|
||||
f"Unsupported or missing item type `{item_type_name}` for array field `{key}` in `{model_name}`"
|
||||
)
|
||||
item_type = TYPE_MAPPING[item_type_name]
|
||||
|
||||
base_type = conlist(item_type, **constraints) if constraints else List[item_type] # type: ignore[valid-type]
|
||||
|
||||
if "format" in value:
|
||||
format_type = FORMAT_MAPPING.get(value["format"])
|
||||
if format_type is None:
|
||||
raise FormatNotSupportedError(f"Unknown format `{value['format']}` for `{key}` in `{model_name}`")
|
||||
if not isinstance(format_type, type):
|
||||
return format_type
|
||||
if not issubclass(format_type, str):
|
||||
return format_type
|
||||
return format_type
|
||||
|
||||
return base_type
|
||||
|
||||
def _json_schema_to_model(
|
||||
self, schema: Dict[str, Any], model_name: str, root_schema: Dict[str, Any]
|
||||
) -> Type[BaseModel]:
|
||||
if "allOf" in schema:
|
||||
merged: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
||||
for s in schema["allOf"]:
|
||||
part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s
|
||||
merged["properties"].update(part.get("properties", {}))
|
||||
merged["required"].extend(part.get("required", []))
|
||||
for k, v in schema.items():
|
||||
if k not in {"allOf", "properties", "required"}:
|
||||
merged[k] = v
|
||||
merged["required"] = list(set(merged["required"]))
|
||||
schema = merged
|
||||
|
||||
fields: Dict[str, tuple[Any, FieldInfo]] = {}
|
||||
required_fields = set(schema.get("required", []))
|
||||
|
||||
for key, value in schema.get("properties", {}).items():
|
||||
if "$ref" in value:
|
||||
ref_name = value["$ref"].split("/")[-1]
|
||||
field_type = self.get_ref(ref_name)
|
||||
elif "anyOf" in value:
|
||||
sub_models = self._resolve_union_types(value["anyOf"])
|
||||
field_type = Union[tuple(sub_models)]
|
||||
elif "oneOf" in value:
|
||||
sub_models = self._resolve_union_types(value["oneOf"])
|
||||
field_type = Union[tuple(sub_models)]
|
||||
if "discriminator" in value:
|
||||
discriminator = value["discriminator"]["propertyName"]
|
||||
field_type = Annotated[field_type, Field(discriminator=discriminator)]
|
||||
elif "enum" in value:
|
||||
field_type = Literal[tuple(value["enum"])]
|
||||
elif "allOf" in value:
|
||||
merged = {"type": "object", "properties": {}, "required": []}
|
||||
for s in value["allOf"]:
|
||||
part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s
|
||||
merged["properties"].update(part.get("properties", {}))
|
||||
merged["required"].extend(part.get("required", []))
|
||||
for k, v in value.items():
|
||||
if k not in {"allOf", "properties", "required"}:
|
||||
merged[k] = v
|
||||
merged["required"] = list(set(merged["required"]))
|
||||
field_type = self._json_schema_to_model(merged, f"{model_name}_{key}", root_schema)
|
||||
elif value.get("type") == "object" and "properties" in value:
|
||||
field_type = self._json_schema_to_model(value, f"{model_name}_{key}", root_schema)
|
||||
else:
|
||||
field_type = self._extract_field_type(key, value, model_name, root_schema)
|
||||
|
||||
if field_type is None:
|
||||
raise UnsupportedKeywordError(f"Unsupported or missing type for field `{key}` in `{model_name}`")
|
||||
|
||||
default_value = value.get("default")
|
||||
is_required = key in required_fields
|
||||
|
||||
if not is_required and default_value is None:
|
||||
field_type = Optional[field_type]
|
||||
|
||||
field_args = {
|
||||
"default": default_value if not is_required else ...,
|
||||
}
|
||||
if "title" in value:
|
||||
field_args["title"] = value["title"]
|
||||
if "description" in value:
|
||||
field_args["description"] = value["description"]
|
||||
|
||||
fields[key] = (
|
||||
field_type,
|
||||
_make_field(
|
||||
default_value if not is_required else ...,
|
||||
title=value.get("title"),
|
||||
description=value.get("description"),
|
||||
),
|
||||
)
|
||||
|
||||
model: Type[BaseModel] = create_model(model_name, **cast(dict[str, Any], fields))
|
||||
model.model_rebuild()
|
||||
return model
|
||||
|
||||
|
||||
def schema_to_pydantic_model(schema: Dict[str, Any], model_name: str = "GeneratedModel") -> Type[BaseModel]:
|
||||
"""
|
||||
Convert a JSON Schema dictionary to a fully-typed Pydantic model.
|
||||
|
||||
This function handles schema translation and validation logic to produce
|
||||
a Pydantic model.
|
||||
|
||||
**Supported JSON Schema Features**
|
||||
|
||||
- **Primitive types**: `string`, `integer`, `number`, `boolean`, `object`, `array`, `null`
|
||||
- **String formats**:
|
||||
- `email`, `uri`, `uuid`, `uuid1`, `uuid3`, `uuid4`, `uuid5`
|
||||
- `hostname`, `ipv4`, `ipv6`, `ipv4-network`, `ipv6-network`
|
||||
- `date`, `time`, `date-time`, `duration`
|
||||
- `byte`, `binary`, `password`, `path`
|
||||
- **String constraints**:
|
||||
- `minLength`, `maxLength`, `pattern`
|
||||
- **Numeric constraints**:
|
||||
- `minimum`, `maximum`, `exclusiveMinimum`, `exclusiveMaximum`
|
||||
- **Array constraints**:
|
||||
- `minItems`, `maxItems`, `items`
|
||||
- **Object schema support**:
|
||||
- `properties`, `required`, `title`, `description`, `default`
|
||||
- **Enums**:
|
||||
- Converted to Python `Literal` type
|
||||
- **Union types**:
|
||||
- `anyOf`, `oneOf` supported with optional `discriminator`
|
||||
- **Inheritance and composition**:
|
||||
- `allOf` merges multiple schemas into one model
|
||||
- **$ref and $defs resolution**:
|
||||
- Supports references to sibling definitions and self-referencing schemas
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from json_schema_to_pydantic import schema_to_pydantic_model
|
||||
|
||||
# Example 1: Simple user model
|
||||
schema = {
|
||||
"title": "User",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"email": {"type": "string", "format": "email"},
|
||||
"age": {"type": "integer", "minimum": 0},
|
||||
},
|
||||
"required": ["name", "email"],
|
||||
}
|
||||
|
||||
UserModel = schema_to_pydantic_model(schema)
|
||||
user = UserModel(name="Alice", email="alice@example.com", age=30)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Example 2: Nested model
|
||||
schema = {
|
||||
"title": "BlogPost",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"tags": {"type": "array", "items": {"type": "string"}},
|
||||
"author": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "email": {"type": "string", "format": "email"}},
|
||||
"required": ["name"],
|
||||
},
|
||||
},
|
||||
"required": ["title", "author"],
|
||||
}
|
||||
|
||||
BlogPost = schema_to_pydantic_model(schema)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Example 3: allOf merging with $refs
|
||||
schema = {
|
||||
"title": "EmployeeWithDepartment",
|
||||
"allOf": [{"$ref": "#/$defs/Employee"}, {"$ref": "#/$defs/Department"}],
|
||||
"$defs": {
|
||||
"Employee": {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "string"}, "name": {"type": "string"}},
|
||||
"required": ["id", "name"],
|
||||
},
|
||||
"Department": {
|
||||
"type": "object",
|
||||
"properties": {"department": {"type": "string"}},
|
||||
"required": ["department"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Model = schema_to_pydantic_model(schema)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Example 4: Self-referencing (recursive) model
|
||||
schema = {
|
||||
"title": "Category",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"subcategories": {"type": "array", "items": {"$ref": "#/$defs/Category"}},
|
||||
},
|
||||
"required": ["name"],
|
||||
"$defs": {
|
||||
"Category": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"subcategories": {"type": "array", "items": {"$ref": "#/$defs/Category"}},
|
||||
},
|
||||
"required": ["name"],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
Category = schema_to_pydantic_model(schema)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Example 5: Serializing and deserializing with Pydantic
|
||||
|
||||
from uuid import uuid4
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from typing import Optional, List, Dict, Any
|
||||
from autogen_core.utils import schema_to_pydantic_model
|
||||
|
||||
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
zipcode: str
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: EmailStr
|
||||
age: int = Field(..., ge=18)
|
||||
address: Address
|
||||
|
||||
|
||||
class Employee(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
manager: Optional["Employee"] = None
|
||||
|
||||
|
||||
class Department(BaseModel):
|
||||
name: str
|
||||
employees: List[Employee]
|
||||
|
||||
|
||||
class ComplexModel(BaseModel):
|
||||
user: User
|
||||
extra_info: Optional[Dict[str, Any]] = None
|
||||
sub_items: List[Employee]
|
||||
|
||||
|
||||
# Convert ComplexModel to JSON schema
|
||||
complex_schema = ComplexModel.model_json_schema()
|
||||
|
||||
# Rebuild a new Pydantic model from JSON schema
|
||||
ReconstructedModel = schema_to_pydantic_model(complex_schema, "ComplexModel")
|
||||
|
||||
# Instantiate reconstructed model
|
||||
reconstructed = ReconstructedModel(
|
||||
user={
|
||||
"id": str(uuid4()),
|
||||
"name": "Alice",
|
||||
"email": "alice@example.com",
|
||||
"age": 30,
|
||||
"address": {"street": "123 Main St", "city": "Wonderland", "zipcode": "12345"},
|
||||
},
|
||||
sub_items=[{"id": str(uuid4()), "name": "Bob", "manager": {"id": str(uuid4()), "name": "Eve"}}],
|
||||
)
|
||||
|
||||
print(reconstructed.model_dump())
|
||||
|
||||
|
||||
Args:
|
||||
schema (Dict[str, Any]): A valid JSON Schema dictionary.
|
||||
model_name (str, optional): The name of the root model. Defaults to "GeneratedModel".
|
||||
|
||||
Returns:
|
||||
Type[BaseModel]: A dynamically generated Pydantic model class.
|
||||
|
||||
Raises:
|
||||
ReferenceNotFoundError: If a `$ref` key references a missing entry.
|
||||
FormatNotSupportedError: If a `format` keyword is unknown or unsupported.
|
||||
UnsupportedKeywordError: If the schema contains an unsupported `type`.
|
||||
|
||||
See Also:
|
||||
- :class:`pydantic.BaseModel`
|
||||
- :func:`pydantic.create_model`
|
||||
- https://json-schema.org/
|
||||
"""
|
||||
...
|
||||
|
||||
return _JSONSchemaToPydantic().json_schema_to_pydantic(schema, model_name)
|
757
python/packages/autogen-core/tests/test_json_to_pydantic.py
Normal file
757
python/packages/autogen-core/tests/test_json_to_pydantic.py
Normal file
@ -0,0 +1,757 @@
|
||||
import types
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, get_args, get_origin
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from autogen_core.utils._json_to_pydantic import (
|
||||
FORMAT_MAPPING,
|
||||
TYPE_MAPPING,
|
||||
FormatNotSupportedError,
|
||||
ReferenceNotFoundError,
|
||||
UnsupportedKeywordError,
|
||||
_JSONSchemaToPydantic, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
from pydantic import BaseModel, EmailStr, Field, ValidationError
|
||||
|
||||
|
||||
# ✅ Define Pydantic models for testing
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
zipcode: str
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
email: EmailStr
|
||||
age: int = Field(..., ge=18) # Minimum age = 18
|
||||
address: Address
|
||||
|
||||
|
||||
class Employee(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
manager: Optional["Employee"] = None # Recursive self-reference
|
||||
|
||||
|
||||
class Department(BaseModel):
|
||||
name: str
|
||||
employees: List[Employee] # Array of objects
|
||||
|
||||
|
||||
class ComplexModel(BaseModel):
|
||||
user: User
|
||||
extra_info: Optional[Dict[str, Any]] = None # Optional dictionary
|
||||
sub_items: List[Employee] # List of Employees
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def converter() -> _JSONSchemaToPydantic:
|
||||
"""Fixture to create a fresh instance of JSONSchemaToPydantic for every test."""
|
||||
return _JSONSchemaToPydantic()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_json_schema() -> Dict[str, Any]:
|
||||
"""Fixture that returns a JSON schema dynamically using model_json_schema()."""
|
||||
return User.model_json_schema()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_json_schema_recursive() -> Dict[str, Any]:
|
||||
"""Fixture that returns a self-referencing JSON schema."""
|
||||
return Employee.model_json_schema()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_json_schema_nested() -> Dict[str, Any]:
|
||||
"""Fixture that returns a nested schema with arrays of objects."""
|
||||
return Department.model_json_schema()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_json_schema_complex() -> Dict[str, Any]:
|
||||
"""Fixture that returns a complex schema with multiple structures."""
|
||||
return ComplexModel.model_json_schema()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema_fixture, model_name, expected_fields",
|
||||
[
|
||||
(sample_json_schema, "User", ["id", "name", "email", "age", "address"]),
|
||||
(sample_json_schema_recursive, "Employee", ["id", "name", "manager"]),
|
||||
(sample_json_schema_nested, "Department", ["name", "employees"]),
|
||||
(sample_json_schema_complex, "ComplexModel", ["user", "extra_info", "sub_items"]),
|
||||
],
|
||||
)
|
||||
def test_json_schema_to_pydantic(
|
||||
converter: _JSONSchemaToPydantic,
|
||||
schema_fixture: Any,
|
||||
model_name: str,
|
||||
expected_fields: List[str],
|
||||
request: Any,
|
||||
) -> None:
|
||||
"""Test conversion of JSON Schema to Pydantic model using the class instance."""
|
||||
schema = request.getfixturevalue(schema_fixture.__name__)
|
||||
Model = converter.json_schema_to_pydantic(schema, model_name)
|
||||
|
||||
for field in expected_fields:
|
||||
assert field in Model.__annotations__, f"Expected '{field}' missing in {model_name}Model"
|
||||
|
||||
|
||||
# ✅ **Valid Data Tests**
|
||||
@pytest.mark.parametrize(
|
||||
"schema_fixture, model_name, valid_data",
|
||||
[
|
||||
(
|
||||
sample_json_schema,
|
||||
"User",
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"name": "Alice",
|
||||
"email": "alice@example.com",
|
||||
"age": 25,
|
||||
"address": {"street": "123 Main St", "city": "Metropolis", "zipcode": "12345"},
|
||||
},
|
||||
),
|
||||
(
|
||||
sample_json_schema_recursive,
|
||||
"Employee",
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"name": "Alice",
|
||||
"manager": {
|
||||
"id": str(uuid4()),
|
||||
"name": "Bob",
|
||||
},
|
||||
},
|
||||
),
|
||||
(
|
||||
sample_json_schema_nested,
|
||||
"Department",
|
||||
{
|
||||
"name": "Engineering",
|
||||
"employees": [
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"name": "Alice",
|
||||
"manager": {
|
||||
"id": str(uuid4()),
|
||||
"name": "Bob",
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
),
|
||||
(
|
||||
sample_json_schema_complex,
|
||||
"ComplexModel",
|
||||
{
|
||||
"user": {
|
||||
"id": str(uuid4()),
|
||||
"name": "Charlie",
|
||||
"email": "charlie@example.com",
|
||||
"age": 30,
|
||||
"address": {"street": "456 Side St", "city": "Gotham", "zipcode": "67890"},
|
||||
},
|
||||
"extra_info": {"hobby": "Chess", "level": "Advanced"},
|
||||
"sub_items": [
|
||||
{"id": str(uuid4()), "name": "Eve"},
|
||||
{"id": str(uuid4()), "name": "David", "manager": {"id": str(uuid4()), "name": "Frank"}},
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_valid_data_model(
|
||||
converter: _JSONSchemaToPydantic,
|
||||
schema_fixture: Any,
|
||||
model_name: str,
|
||||
valid_data: Dict[str, Any],
|
||||
request: Any,
|
||||
) -> None:
|
||||
"""Test that valid data is accepted by the generated model."""
|
||||
schema = request.getfixturevalue(schema_fixture.__name__)
|
||||
Model = converter.json_schema_to_pydantic(schema, model_name)
|
||||
|
||||
instance = Model(**valid_data)
|
||||
assert instance
|
||||
dumped = instance.model_dump(mode="json", exclude_none=True)
|
||||
assert dumped == valid_data, f"Model output mismatch.\nExpected: {valid_data}\nGot: {dumped}"
|
||||
|
||||
|
||||
# ✅ **Invalid Data Tests**
|
||||
@pytest.mark.parametrize(
|
||||
"schema_fixture, model_name, invalid_data",
|
||||
[
|
||||
(
|
||||
sample_json_schema,
|
||||
"User",
|
||||
{
|
||||
"id": "not-a-uuid", # Invalid UUID
|
||||
"name": "Alice",
|
||||
"email": "not-an-email", # Invalid email
|
||||
"age": 17, # Below minimum
|
||||
"address": {"street": "123 Main St", "city": "Metropolis"},
|
||||
},
|
||||
),
|
||||
(
|
||||
sample_json_schema_recursive,
|
||||
"Employee",
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"name": "Alice",
|
||||
"manager": {
|
||||
"id": "not-a-uuid", # Invalid UUID
|
||||
"name": "Bob",
|
||||
},
|
||||
},
|
||||
),
|
||||
(
|
||||
sample_json_schema_nested,
|
||||
"Department",
|
||||
{
|
||||
"name": "Engineering",
|
||||
"employees": [
|
||||
{
|
||||
"id": "not-a-uuid", # Invalid UUID
|
||||
"name": "Alice",
|
||||
"manager": {
|
||||
"id": str(uuid4()),
|
||||
"name": "Bob",
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
),
|
||||
(
|
||||
sample_json_schema_complex,
|
||||
"ComplexModel",
|
||||
{
|
||||
"user": {
|
||||
"id": str(uuid4()),
|
||||
"name": "Charlie",
|
||||
"email": "charlie@example.com",
|
||||
"age": "thirty", # Invalid: Should be an int
|
||||
"address": {"street": "456 Side St", "city": "Gotham", "zipcode": "67890"},
|
||||
},
|
||||
"extra_info": "should-be-dictionary", # Invalid type
|
||||
"sub_items": [
|
||||
{"id": "invalid-uuid", "name": "Eve"}, # Invalid UUID
|
||||
{"id": str(uuid4()), "name": 123}, # Invalid name type
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_data_model(
|
||||
converter: _JSONSchemaToPydantic,
|
||||
schema_fixture: Any,
|
||||
model_name: str,
|
||||
invalid_data: Dict[str, Any],
|
||||
request: Any,
|
||||
) -> None:
|
||||
"""Test that invalid data raises ValidationError."""
|
||||
schema = request.getfixturevalue(schema_fixture.__name__)
|
||||
Model = converter.json_schema_to_pydantic(schema, model_name)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
Model(**invalid_data)
|
||||
|
||||
|
||||
class ListDictModel(BaseModel):
|
||||
"""Example for `List[Dict[str, Any]]`"""
|
||||
|
||||
data: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class DictListModel(BaseModel):
|
||||
"""Example for `Dict[str, List[Any]]`"""
|
||||
|
||||
mapping: Dict[str, List[Any]]
|
||||
|
||||
|
||||
class NestedListModel(BaseModel):
|
||||
"""Example for `List[List[str]]`"""
|
||||
|
||||
matrix: List[List[str]]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_json_schema_list_dict() -> Dict[str, Any]:
|
||||
"""Fixture for `List[Dict[str, Any]]`"""
|
||||
return ListDictModel.model_json_schema()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_json_schema_dict_list() -> Dict[str, Any]:
|
||||
"""Fixture for `Dict[str, List[Any]]`"""
|
||||
return DictListModel.model_json_schema()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_json_schema_nested_list() -> Dict[str, Any]:
|
||||
"""Fixture for `List[List[str]]`"""
|
||||
return NestedListModel.model_json_schema()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema_fixture, model_name, expected_fields",
|
||||
[
|
||||
(sample_json_schema_list_dict, "ListDictModel", ["data"]),
|
||||
(sample_json_schema_dict_list, "DictListModel", ["mapping"]),
|
||||
(sample_json_schema_nested_list, "NestedListModel", ["matrix"]),
|
||||
],
|
||||
)
|
||||
def test_json_schema_to_pydantic_nested(
|
||||
converter: _JSONSchemaToPydantic,
|
||||
schema_fixture: Any,
|
||||
model_name: str,
|
||||
expected_fields: list[str],
|
||||
request: Any,
|
||||
) -> None:
|
||||
"""Test conversion of JSON Schema to Pydantic model using the class instance."""
|
||||
schema = request.getfixturevalue(schema_fixture.__name__)
|
||||
Model = converter.json_schema_to_pydantic(schema, model_name)
|
||||
|
||||
for field in expected_fields:
|
||||
assert field in Model.__annotations__, f"Expected '{field}' missing in {model_name}Model"
|
||||
|
||||
|
||||
# ✅ **Valid Data Tests**
|
||||
@pytest.mark.parametrize(
|
||||
"schema_fixture, model_name, valid_data",
|
||||
[
|
||||
(
|
||||
sample_json_schema_list_dict,
|
||||
"ListDictModel",
|
||||
{
|
||||
"data": [
|
||||
{"key1": "value1", "key2": 10},
|
||||
{"another_key": False, "nested": {"subkey": "data"}},
|
||||
]
|
||||
},
|
||||
),
|
||||
(
|
||||
sample_json_schema_dict_list,
|
||||
"DictListModel",
|
||||
{
|
||||
"mapping": {
|
||||
"first": ["a", "b", "c"],
|
||||
"second": [1, 2, 3, 4],
|
||||
"third": [True, False, True],
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
sample_json_schema_nested_list,
|
||||
"NestedListModel",
|
||||
{"matrix": [["A", "B"], ["C", "D"], ["E", "F"]]},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_valid_data_model_nested(
|
||||
converter: _JSONSchemaToPydantic,
|
||||
schema_fixture: Any,
|
||||
model_name: str,
|
||||
valid_data: Dict[str, Any],
|
||||
request: Any,
|
||||
) -> None:
|
||||
"""Test that valid data is accepted by the generated model."""
|
||||
schema = request.getfixturevalue(schema_fixture.__name__)
|
||||
Model = converter.json_schema_to_pydantic(schema, model_name)
|
||||
|
||||
instance = Model(**valid_data)
|
||||
assert instance
|
||||
for field, value in valid_data.items():
|
||||
assert (
|
||||
getattr(instance, field) == value
|
||||
), f"Mismatch in field `{field}`: expected `{value}`, got `{getattr(instance, field)}`"
|
||||
|
||||
|
||||
# ✅ **Invalid Data Tests**
|
||||
@pytest.mark.parametrize(
|
||||
"schema_fixture, model_name, invalid_data",
|
||||
[
|
||||
(
|
||||
sample_json_schema_list_dict,
|
||||
"ListDictModel",
|
||||
{
|
||||
"data": "should-be-a-list", # ❌ Should be a list of dicts
|
||||
},
|
||||
),
|
||||
(
|
||||
sample_json_schema_dict_list,
|
||||
"DictListModel",
|
||||
{
|
||||
"mapping": [
|
||||
"should-be-a-dictionary", # ❌ Should be a dict of lists
|
||||
]
|
||||
},
|
||||
),
|
||||
(
|
||||
sample_json_schema_nested_list,
|
||||
"NestedListModel",
|
||||
{"matrix": [["A", "B"], "C", ["D", "E"]]}, # ❌ "C" is not a list
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_data_model_nested(
|
||||
converter: _JSONSchemaToPydantic,
|
||||
schema_fixture: Any,
|
||||
model_name: str,
|
||||
invalid_data: Dict[str, Any],
|
||||
request: Any,
|
||||
) -> None:
|
||||
"""Test that invalid data raises ValidationError."""
|
||||
schema = request.getfixturevalue(schema_fixture.__name__)
|
||||
Model = converter.json_schema_to_pydantic(schema, model_name)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
Model(**invalid_data)
|
||||
|
||||
|
||||
def test_reference_not_found(converter: _JSONSchemaToPydantic) -> None:
|
||||
schema = {"type": "object", "properties": {"manager": {"$ref": "#/$defs/MissingRef"}}}
|
||||
with pytest.raises(ReferenceNotFoundError):
|
||||
converter.json_schema_to_pydantic(schema, "MissingRefModel")
|
||||
|
||||
|
||||
def test_format_not_supported(converter: _JSONSchemaToPydantic) -> None:
|
||||
schema = {"type": "object", "properties": {"custom_field": {"type": "string", "format": "unsupported-format"}}}
|
||||
with pytest.raises(FormatNotSupportedError):
|
||||
converter.json_schema_to_pydantic(schema, "UnsupportedFormatModel")
|
||||
|
||||
|
||||
def test_unsupported_keyword(converter: _JSONSchemaToPydantic) -> None:
|
||||
schema = {"type": "object", "properties": {"broken_field": {"title": "Missing type"}}}
|
||||
with pytest.raises(UnsupportedKeywordError):
|
||||
converter.json_schema_to_pydantic(schema, "MissingTypeModel")
|
||||
|
||||
|
||||
def test_enum_field_schema() -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"status": {"type": "string", "enum": ["pending", "approved", "rejected"]},
|
||||
"priority": {"type": "integer", "enum": [1, 2, 3]},
|
||||
},
|
||||
"required": ["status"],
|
||||
}
|
||||
|
||||
converter: _JSONSchemaToPydantic = _JSONSchemaToPydantic()
|
||||
Model = converter.json_schema_to_pydantic(schema, "Task")
|
||||
|
||||
status_ann = Model.model_fields["status"].annotation
|
||||
assert get_origin(status_ann) is Literal
|
||||
assert set(get_args(status_ann)) == {"pending", "approved", "rejected"}
|
||||
|
||||
priority_ann = Model.model_fields["priority"].annotation
|
||||
args = get_args(priority_ann)
|
||||
assert type(None) in args
|
||||
assert Literal[1, 2, 3] in args
|
||||
|
||||
instance = Model(status="approved", priority=2)
|
||||
assert instance.status == "approved" # type: ignore[attr-defined]
|
||||
assert instance.priority == 2 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_metadata_title_description(converter: _JSONSchemaToPydantic) -> None:
|
||||
schema = {
|
||||
"title": "CustomerProfile",
|
||||
"description": "A profile containing personal and contact info",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"first_name": {"type": "string", "title": "First Name", "description": "Given name of the user"},
|
||||
"age": {"type": "integer", "title": "Age", "description": "Age in years"},
|
||||
"contact": {
|
||||
"type": "object",
|
||||
"title": "Contact Information",
|
||||
"description": "How to reach the user",
|
||||
"properties": {
|
||||
"email": {
|
||||
"type": "string",
|
||||
"format": "email",
|
||||
"title": "Email Address",
|
||||
"description": "Primary email",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["first_name"],
|
||||
}
|
||||
|
||||
Model: Type[BaseModel] = converter.json_schema_to_pydantic(schema, "CustomerProfile")
|
||||
generated_schema = Model.model_json_schema()
|
||||
|
||||
assert generated_schema["title"] == "CustomerProfile"
|
||||
|
||||
props = generated_schema["properties"]
|
||||
assert props["first_name"]["title"] == "First Name"
|
||||
assert props["first_name"]["description"] == "Given name of the user"
|
||||
assert props["age"]["title"] == "Age"
|
||||
assert props["age"]["description"] == "Age in years"
|
||||
|
||||
contact = props["contact"]
|
||||
assert contact["title"] == "Contact Information"
|
||||
assert contact["description"] == "How to reach the user"
|
||||
|
||||
# Follow the $ref
|
||||
ref_key = contact["anyOf"][0]["$ref"].split("/")[-1]
|
||||
contact_def = generated_schema["$defs"][ref_key]
|
||||
email = contact_def["properties"]["email"]
|
||||
assert email["title"] == "Email Address"
|
||||
assert email["description"] == "Primary email"
|
||||
|
||||
|
||||
def test_oneof_with_discriminator(converter: _JSONSchemaToPydantic) -> None:
|
||||
schema = {
|
||||
"title": "PetWrapper",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pet": {
|
||||
"oneOf": [{"$ref": "#/$defs/Cat"}, {"$ref": "#/$defs/Dog"}],
|
||||
"discriminator": {"propertyName": "pet_type"},
|
||||
}
|
||||
},
|
||||
"required": ["pet"],
|
||||
"$defs": {
|
||||
"Cat": {
|
||||
"type": "object",
|
||||
"properties": {"pet_type": {"type": "string", "enum": ["cat"]}, "hunting_skill": {"type": "string"}},
|
||||
"required": ["pet_type", "hunting_skill"],
|
||||
"title": "Cat",
|
||||
},
|
||||
"Dog": {
|
||||
"type": "object",
|
||||
"properties": {"pet_type": {"type": "string", "enum": ["dog"]}, "pack_size": {"type": "integer"}},
|
||||
"required": ["pet_type", "pack_size"],
|
||||
"title": "Dog",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Model = converter.json_schema_to_pydantic(schema, "PetWrapper")
|
||||
|
||||
# Instantiate with a Cat
|
||||
cat = Model(pet={"pet_type": "cat", "hunting_skill": "expert"})
|
||||
assert cat.pet.pet_type == "cat" # type: ignore[attr-defined]
|
||||
|
||||
# Instantiate with a Dog
|
||||
dog = Model(pet={"pet_type": "dog", "pack_size": 4})
|
||||
assert dog.pet.pet_type == "dog" # type: ignore[attr-defined]
|
||||
|
||||
# Check round-trip schema includes discriminator
|
||||
model_schema = Model.model_json_schema()
|
||||
assert "discriminator" in model_schema["properties"]["pet"]
|
||||
assert model_schema["properties"]["pet"]["discriminator"]["propertyName"] == "pet_type"
|
||||
|
||||
|
||||
def test_allof_merging_with_refs(converter: _JSONSchemaToPydantic) -> None:
|
||||
schema = {
|
||||
"title": "EmployeeWithDepartment",
|
||||
"allOf": [{"$ref": "#/$defs/Employee"}, {"$ref": "#/$defs/Department"}],
|
||||
"$defs": {
|
||||
"Employee": {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "string"}, "name": {"type": "string"}},
|
||||
"required": ["id", "name"],
|
||||
"title": "Employee",
|
||||
},
|
||||
"Department": {
|
||||
"type": "object",
|
||||
"properties": {"department": {"type": "string"}},
|
||||
"required": ["department"],
|
||||
"title": "Department",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Model = converter.json_schema_to_pydantic(schema, "EmployeeWithDepartment")
|
||||
instance = Model(id="123", name="Alice", department="Engineering")
|
||||
assert instance.id == "123" # type: ignore[attr-defined]
|
||||
assert instance.name == "Alice" # type: ignore[attr-defined]
|
||||
assert instance.department == "Engineering" # type: ignore[attr-defined]
|
||||
|
||||
dumped = instance.model_dump()
|
||||
assert dumped == {"id": "123", "name": "Alice", "department": "Engineering"}
|
||||
|
||||
|
||||
def test_nested_allof_merging(converter: _JSONSchemaToPydantic) -> None:
|
||||
schema = {
|
||||
"title": "ContainerModel",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nested": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"allOf": [
|
||||
{"$ref": "#/$defs/Base"},
|
||||
{"type": "object", "properties": {"extra": {"type": "string"}}, "required": ["extra"]},
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": ["data"],
|
||||
}
|
||||
},
|
||||
"required": ["nested"],
|
||||
"$defs": {
|
||||
"Base": {
|
||||
"type": "object",
|
||||
"properties": {"base_field": {"type": "string"}},
|
||||
"required": ["base_field"],
|
||||
"title": "Base",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
Model = converter.json_schema_to_pydantic(schema, "ContainerModel")
|
||||
instance = Model(nested={"data": {"base_field": "abc", "extra": "xyz"}})
|
||||
|
||||
assert instance.nested.data.base_field == "abc" # type: ignore[attr-defined]
|
||||
assert instance.nested.data.extra == "xyz" # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema, field_name, valid_values, invalid_values",
|
||||
[
|
||||
# String constraints
|
||||
(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"username": {"type": "string", "minLength": 3, "maxLength": 10, "pattern": "^[a-zA-Z0-9_]+$"}
|
||||
},
|
||||
"required": ["username"],
|
||||
},
|
||||
"username",
|
||||
["user_123", "abc", "Name2023"],
|
||||
["", "ab", "toolongusername123", "invalid!char"],
|
||||
),
|
||||
# Integer constraints
|
||||
(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"age": {"type": "integer", "minimum": 18, "maximum": 99}},
|
||||
"required": ["age"],
|
||||
},
|
||||
"age",
|
||||
[18, 25, 99],
|
||||
[17, 100, -1],
|
||||
),
|
||||
# Float constraints
|
||||
(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"score": {"type": "number", "minimum": 0.0, "exclusiveMaximum": 1.0}},
|
||||
"required": ["score"],
|
||||
},
|
||||
"score",
|
||||
[0.0, 0.5, 0.999],
|
||||
[-0.1, 1.0, 2.5],
|
||||
),
|
||||
# Array constraints
|
||||
(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"tags": {"type": "array", "items": {"type": "string"}, "minItems": 1, "maxItems": 3}},
|
||||
"required": ["tags"],
|
||||
},
|
||||
"tags",
|
||||
[["a"], ["a", "b"], ["x", "y", "z"]],
|
||||
[[], ["one", "two", "three", "four"]],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_field_constraints(
|
||||
schema: Dict[str, Any],
|
||||
field_name: str,
|
||||
valid_values: List[Any],
|
||||
invalid_values: List[Any],
|
||||
) -> None:
|
||||
converter = _JSONSchemaToPydantic()
|
||||
Model = converter.json_schema_to_pydantic(schema, "ConstraintModel")
|
||||
|
||||
for value in valid_values:
|
||||
instance = Model(**{field_name: value})
|
||||
assert getattr(instance, field_name) == value
|
||||
|
||||
for value in invalid_values:
|
||||
with pytest.raises(ValidationError):
|
||||
Model(**{field_name: value})
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema",
|
||||
[
|
||||
# Top-level field
|
||||
{"type": "object", "properties": {"weird": {"type": "abc"}}, "required": ["weird"]},
|
||||
# Inside array items
|
||||
{"type": "object", "properties": {"items": {"type": "array", "items": {"type": "abc"}}}, "required": ["items"]},
|
||||
# Inside anyOf
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"choice": {"anyOf": [{"type": "string"}, {"type": "abc"}]}},
|
||||
"required": ["choice"],
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_unknown_type_raises(schema: Dict[str, Any]) -> None:
|
||||
converter = _JSONSchemaToPydantic()
|
||||
with pytest.raises(UnsupportedKeywordError):
|
||||
converter.json_schema_to_pydantic(schema, "UnknownTypeModel")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("json_type, expected_type", list(TYPE_MAPPING.items()))
|
||||
def test_basic_type_mapping(json_type: str, expected_type: type) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"field": {"type": json_type}},
|
||||
"required": ["field"],
|
||||
}
|
||||
converter = _JSONSchemaToPydantic()
|
||||
Model = converter.json_schema_to_pydantic(schema, f"{json_type.capitalize()}Model")
|
||||
|
||||
assert "field" in Model.__annotations__
|
||||
field_type = Model.__annotations__["field"]
|
||||
|
||||
# For array/object/null we check the outer type only
|
||||
if json_type == "null":
|
||||
assert field_type is type(None)
|
||||
elif json_type == "array":
|
||||
assert getattr(field_type, "__origin__", None) is list
|
||||
elif json_type == "object":
|
||||
assert field_type in (dict, Dict) or getattr(field_type, "__origin__", None) in (dict, Dict)
|
||||
|
||||
else:
|
||||
assert field_type == expected_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize("format_name, expected_type", list(FORMAT_MAPPING.items()))
|
||||
def test_format_mapping(format_name: str, expected_type: Any) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"field": {"type": "string", "format": format_name}},
|
||||
"required": ["field"],
|
||||
}
|
||||
converter = _JSONSchemaToPydantic()
|
||||
Model = converter.json_schema_to_pydantic(schema, f"{format_name.capitalize()}Model")
|
||||
|
||||
assert "field" in Model.__annotations__
|
||||
field_type = Model.__annotations__["field"]
|
||||
if isinstance(expected_type, types.FunctionType): # if it's a constrained constructor (e.g., conint)
|
||||
assert callable(field_type)
|
||||
else:
|
||||
assert field_type == expected_type
|
||||
|
||||
|
||||
def test_unknown_format_raises() -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"bad_field": {"type": "string", "format": "definitely-not-a-format"}},
|
||||
}
|
||||
converter = _JSONSchemaToPydantic()
|
||||
with pytest.raises(FormatNotSupportedError):
|
||||
converter.json_schema_to_pydantic(schema, "UnknownFormatModel")
|
Loading…
x
Reference in New Issue
Block a user