Aegis structure message (#6289)

Added support for structured message component using the Json to
Pydantic utility functions. Note: also adding the ability to use a
format string for structured messages.

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
abhinav-aegis 2025-04-17 05:00:14 +10:00 committed by GitHub
parent 8bd162f8fc
commit a4a16fd2f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1701 additions and 8 deletions

View File

@ -46,6 +46,7 @@ from ..messages import (
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
StructuredMessage,
StructuredMessageFactory,
TextMessage,
ThoughtEvent,
ToolCallExecutionEvent,
@ -74,6 +75,7 @@ class AssistantAgentConfig(BaseModel):
reflect_on_tool_use: bool
tool_call_summary_format: str
metadata: Dict[str, str] | None = None
structured_message_factory: ComponentModel | None = None
class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
@ -183,6 +185,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
This will be used with the model client to generate structured output.
If this is set, the agent will respond with a :class:`~autogen_agentchat.messages.StructuredMessage` instead of a :class:`~autogen_agentchat.messages.TextMessage`
in the final response, unless `reflect_on_tool_use` is `False` and a tool call is made.
output_content_type_format (str | None, optional): (Experimental) The format string used for the content of a :class:`~autogen_agentchat.messages.StructuredMessage` response.
tool_call_summary_format (str, optional): The format string used to create the content for a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` response.
The format string is used to format the tool call summary for every tool call result.
Defaults to "{result}".
@ -635,6 +638,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
reflect_on_tool_use: bool | None = None,
tool_call_summary_format: str = "{result}",
output_content_type: type[BaseModel] | None = None,
output_content_type_format: str | None = None,
memory: Sequence[Memory] | None = None,
metadata: Dict[str, str] | None = None,
):
@ -643,6 +647,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
self._model_client = model_client
self._model_client_stream = model_client_stream
self._output_content_type: type[BaseModel] | None = output_content_type
self._output_content_type_format = output_content_type_format
self._structured_message_factory: StructuredMessageFactory | None = None
if output_content_type is not None:
self._structured_message_factory = StructuredMessageFactory(
input_model=output_content_type, format_string=output_content_type_format
)
self._memory = None
if memory is not None:
if isinstance(memory, list):
@ -771,6 +782,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
reflect_on_tool_use = self._reflect_on_tool_use
tool_call_summary_format = self._tool_call_summary_format
output_content_type = self._output_content_type
format_string = self._output_content_type_format
# STEP 1: Add new user/handoff messages to the model context
await self._add_messages_to_context(
@ -840,6 +852,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
reflect_on_tool_use=reflect_on_tool_use,
tool_call_summary_format=tool_call_summary_format,
output_content_type=output_content_type,
format_string=format_string,
):
yield output_event
@ -942,6 +955,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
reflect_on_tool_use: bool,
tool_call_summary_format: str,
output_content_type: type[BaseModel] | None,
format_string: str | None = None,
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
"""
Handle final or partial responses from model_result, including tool calls, handoffs,
@ -957,6 +971,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
content=content,
source=agent_name,
models_usage=model_result.usage,
format_string=format_string,
),
inner_messages=inner_messages,
)
@ -1277,9 +1292,6 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
def _to_config(self) -> AssistantAgentConfig:
"""Convert the assistant agent to a declarative config."""
if self._output_content_type:
raise ValueError("AssistantAgent with output_content_type does not support declarative config.")
return AssistantAgentConfig(
name=self.name,
model_client=self._model_client.dump_component(),
@ -1294,12 +1306,23 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_client_stream=self._model_client_stream,
reflect_on_tool_use=self._reflect_on_tool_use,
tool_call_summary_format=self._tool_call_summary_format,
structured_message_factory=self._structured_message_factory.dump_component()
if self._structured_message_factory
else None,
metadata=self._metadata,
)
@classmethod
def _from_config(cls, config: AssistantAgentConfig) -> Self:
"""Create an assistant agent from a declarative config."""
if config.structured_message_factory:
structured_message_factory = StructuredMessageFactory.load_component(config.structured_message_factory)
format_string = structured_message_factory.format_string
output_content_type = structured_message_factory.ContentModel
else:
format_string = None
output_content_type = None
return cls(
name=config.name,
@ -1313,5 +1336,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_client_stream=config.model_client_stream,
reflect_on_tool_use=config.reflect_on_tool_use,
tool_call_summary_format=config.tool_call_summary_format,
output_content_type=output_content_type,
output_content_type_format=format_string,
metadata=config.metadata,
)

View File

@ -5,9 +5,9 @@ class and includes specific fields relevant to the type of message being sent.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Literal, Mapping, TypeVar
from typing import Any, Dict, Generic, List, Literal, Mapping, Optional, Type, TypeVar
from autogen_core import FunctionCall, Image
from autogen_core import Component, ComponentBase, FunctionCall, Image
from autogen_core.code_executor import CodeBlock, CodeResult
from autogen_core.memory import MemoryContent
from autogen_core.models import (
@ -16,6 +16,7 @@ from autogen_core.models import (
RequestUsage,
UserMessage,
)
from autogen_core.utils import schema_to_pydantic_model
from pydantic import BaseModel, Field, computed_field
from typing_extensions import Annotated, Self
@ -182,21 +183,56 @@ class StructuredMessage(BaseChatMessage, Generic[StructuredContentType]):
print(message.to_text()) # {"text": "Hello", "number": 42}
.. code-block:: python
from pydantic import BaseModel
from autogen_agentchat.messages import StructuredMessage
class MyMessageContent(BaseModel):
text: str
number: int
message = StructuredMessage[MyMessageContent](
content=MyMessageContent(text="Hello", number=42),
source="agent",
format_string="Hello, {text} {number}!",
)
print(message.to_text()) # Hello, agent 42!
"""
content: StructuredContentType
"""The content of the message. Must be a subclass of
`Pydantic BaseModel <https://docs.pydantic.dev/latest/concepts/models/>`_."""
format_string: Optional[str] = None
"""(Experimental) An optional format string to render the content into a human-readable format.
The format string can use the fields of the content model as placeholders.
For example, if the content model has a field `name`, you can use
`{name}` in the format string to include the value of that field.
The format string is used in the :meth:`to_text` method to create a
human-readable representation of the message.
This setting is experimental and will change in the future.
"""
@computed_field
def type(self) -> str:
return self.__class__.__name__
def to_text(self) -> str:
return self.content.model_dump_json(indent=2)
if self.format_string is not None:
return self.format_string.format(**self.content.model_dump())
else:
return self.content.model_dump_json()
def to_model_text(self) -> str:
return self.content.model_dump_json()
if self.format_string is not None:
return self.format_string.format(**self.content.model_dump())
else:
return self.content.model_dump_json()
def to_model_message(self) -> UserMessage:
return UserMessage(
@ -205,6 +241,113 @@ class StructuredMessage(BaseChatMessage, Generic[StructuredContentType]):
)
class StructureMessageConfig(BaseModel):
"""The declarative configuration for the structured output."""
json_schema: Dict[str, Any]
format_string: Optional[str] = None
content_model_name: str
class StructuredMessageFactory(ComponentBase[StructureMessageConfig], Component[StructureMessageConfig]):
""":meta private:
A component that creates structured chat messages from Pydantic models or JSON schemas.
This component helps you generate strongly-typed chat messages with content defined using a Pydantic model.
It can be used in declarative workflows where message structure must be validated, formatted, and serialized.
You can initialize the component directly using a `BaseModel` subclass, or dynamically from a configuration
object (e.g., loaded from disk or a database).
### Example 1: Create from a Pydantic Model
.. code-block:: python
from pydantic import BaseModel
from autogen_agentchat.messages import StructuredMessageFactory
class TestContent(BaseModel):
field1: str
field2: int
format_string = "This is a string {field1} and this is an int {field2}"
sm_component = StructuredMessageFactory(input_model=TestContent, format_string=format_string)
message = sm_component.StructuredMessage(
source="test_agent", content=TestContent(field1="Hello", field2=42), format_string=format_string
)
print(message.to_model_text()) # Output: This is a string Hello and this is an int 42
config = sm_component.dump_component()
s_m_dyn = StructuredMessageFactory.load_component(config)
message = s_m_dyn.StructuredMessage(
source="test_agent",
content=s_m_dyn.ContentModel(field1="dyn agent", field2=43),
format_string=s_m_dyn.format_string,
)
print(type(message)) # StructuredMessage[GeneratedModel]
print(message.to_model_text()) # Output: This is a string dyn agent and this is an int 43
Attributes:
component_config_schema (StructureMessageConfig): Defines the configuration structure for this component.
component_provider_override (str): Path used to reference this component in external tooling.
component_type (str): Identifier used for categorization (e.g., "structured_message").
Raises:
ValueError: If neither `json_schema` nor `input_model` is provided.
Args:
json_schema (Optional[str]): JSON schema to dynamically create a Pydantic model.
input_model (Optional[Type[BaseModel]]): A subclass of `BaseModel` that defines the expected message structure.
format_string (Optional[str]): Optional string to render content into a human-readable format.
content_model_name (Optional[str]): Optional name for the generated Pydantic model.
"""
component_config_schema = StructureMessageConfig
component_provider_override = "autogen_agentchat.messages.StructuredMessageFactory"
component_type = "structured_message"
def __init__(
self,
json_schema: Optional[Dict[str, Any]] = None,
input_model: Optional[Type[BaseModel]] = None,
format_string: Optional[str] = None,
content_model_name: Optional[str] = None,
) -> None:
self.format_string = format_string
if json_schema:
self.ContentModel = schema_to_pydantic_model(
json_schema, model_name=content_model_name or "GeneratedContentModel"
)
elif input_model:
self.ContentModel = input_model
else:
raise ValueError("Either `json_schema` or `input_model` must be provided.")
self.StructuredMessage = StructuredMessage[self.ContentModel] # type: ignore[name-defined]
def _to_config(self) -> StructureMessageConfig:
return StructureMessageConfig(
json_schema=self.ContentModel.model_json_schema(),
format_string=self.format_string,
content_model_name=self.ContentModel.__name__,
)
@classmethod
def _from_config(cls, config: StructureMessageConfig) -> "StructuredMessageFactory":
return cls(
json_schema=config.json_schema,
format_string=config.format_string,
content_model_name=config.content_model_name,
)
class TextMessage(BaseTextChatMessage):
"""A text message with string-only content."""
@ -468,6 +611,7 @@ __all__ = [
"BaseTextChatMessage",
"StructuredContentType",
"StructuredMessage",
"StructuredMessageFactory",
"HandoffMessage",
"MultiModalMessage",
"StopMessage",

View File

@ -21,6 +21,7 @@ from ...messages import (
MessageFactory,
ModelClientStreamingChunkEvent,
StopMessage,
StructuredMessage,
TextMessage,
)
from ...state import TeamState
@ -68,6 +69,15 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
for message_type in custom_message_types:
self._message_factory.register(message_type)
for agent in participants:
for message_type in agent.produced_message_types:
try:
if issubclass(message_type, StructuredMessage):
self._message_factory.register(message_type) # type: ignore[reportUnknownArgumentType]
except TypeError:
# Not a class or not a valid subclassable type (skip)
pass
# The team ID is a UUID that is used to identify the team and its participants
# in the agent runtime. It is used to create unique topic types for each participant.
# Currently, team ID is binded to an object instance of the group chat class.

View File

@ -36,7 +36,7 @@ from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from utils import FileLogHandler
logger = logging.getLogger(EVENT_LOGGER_NAME)
@ -1104,3 +1104,103 @@ async def test_model_client_stream_with_tool_calls() -> None:
elif isinstance(message, ModelClientStreamingChunkEvent):
chunks.append(message.content)
assert "".join(chunks) == "Example response 2 to task"
@pytest.mark.asyncio
async def test_invalid_structured_output_format() -> None:
class AgentResponse(BaseModel):
response: str
status: str
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="stop",
content='{"response": "Hello"}',
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
),
]
)
agent = AssistantAgent(
name="assistant",
model_client=model_client,
output_content_type=AgentResponse,
)
with pytest.raises(ValidationError):
await agent.run()
@pytest.mark.asyncio
async def test_structured_message_factory_serialization() -> None:
class AgentResponse(BaseModel):
result: str
status: str
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="stop",
content=AgentResponse(result="All good", status="ok").model_dump_json(),
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
)
]
)
agent = AssistantAgent(
name="structured_agent",
model_client=model_client,
output_content_type=AgentResponse,
output_content_type_format="{result} - {status}",
)
dumped = agent.dump_component()
restored_agent = AssistantAgent.load_component(dumped)
result = await restored_agent.run()
assert isinstance(result.messages[0], StructuredMessage)
assert result.messages[0].content.result == "All good" # type: ignore[reportUnknownMemberType]
assert result.messages[0].content.status == "ok" # type: ignore[reportUnknownMemberType]
@pytest.mark.asyncio
async def test_structured_message_format_string() -> None:
class AgentResponse(BaseModel):
field1: str
field2: str
expected = AgentResponse(field1="foo", field2="bar")
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="stop",
content=expected.model_dump_json(),
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
)
]
)
agent = AssistantAgent(
name="formatted_agent",
model_client=model_client,
output_content_type=AgentResponse,
output_content_type_format="{field1} - {field2}",
)
result = await agent.run()
assert len(result.messages) == 1
message = result.messages[0]
# Check that it's a StructuredMessage with the correct content model
assert isinstance(message, StructuredMessage)
assert isinstance(message.content, AgentResponse) # type: ignore[reportUnknownMemberType]
assert message.content == expected
# Check that the format_string was applied correctly
assert message.to_model_text() == "foo - bar"

View File

@ -1441,3 +1441,87 @@ async def test_declarative_groupchats_with_config(runtime: AgentRuntime | None)
assert selector.dump_component().provider == "autogen_agentchat.teams.SelectorGroupChat"
assert swarm.dump_component().provider == "autogen_agentchat.teams.Swarm"
assert magentic.dump_component().provider == "autogen_agentchat.teams.MagenticOneGroupChat"
class _StructuredContent(BaseModel):
message: str
class _StructuredAgent(BaseChatAgent):
def __init__(self, name: str, description: str) -> None:
super().__init__(name, description)
self._message = _StructuredContent(message="Structured hello")
@property
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (StructuredMessage[_StructuredContent],)
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
return Response(
chat_message=StructuredMessage[_StructuredContent](
source=self.name,
content=self._message,
format_string="Structured says: {message}",
)
)
async def on_reset(self, cancellation_token: CancellationToken) -> None:
pass
@pytest.mark.asyncio
async def test_message_type_auto_registration(runtime: AgentRuntime | None) -> None:
agent1 = _StructuredAgent("structured", description="emits structured messages")
agent2 = _EchoAgent("echo", description="echoes input")
team = RoundRobinGroupChat(participants=[agent1, agent2], max_turns=2, runtime=runtime)
result = await team.run(task="Say something structured")
assert len(result.messages) == 3
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], StructuredMessage)
assert isinstance(result.messages[2], TextMessage)
assert result.messages[1].to_text() == "Structured says: Structured hello"
@pytest.mark.asyncio
async def test_structured_message_state_roundtrip(runtime: AgentRuntime | None) -> None:
agent1 = _StructuredAgent("structured", description="sends structured")
agent2 = _EchoAgent("echo", description="echoes")
team1 = RoundRobinGroupChat(
participants=[agent1, agent2],
termination_condition=MaxMessageTermination(2),
runtime=runtime,
)
await team1.run(task="Say something structured")
state1 = await team1.save_state()
# Recreate team without needing custom_message_types
agent3 = _StructuredAgent("structured", description="sends structured")
agent4 = _EchoAgent("echo", description="echoes")
team2 = RoundRobinGroupChat(
participants=[agent3, agent4],
termination_condition=MaxMessageTermination(2),
runtime=runtime,
)
await team2.load_state(state1)
state2 = await team2.save_state()
# Assert full state equality
assert state1 == state2
# Assert message thread content match
manager1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId(f"{team1._group_chat_manager_name}_{team1._team_id}", team1._team_id), # pyright: ignore
RoundRobinGroupChatManager,
)
manager2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId(f"{team2._group_chat_manager_name}_{team2._team_id}", team2._team_id), # pyright: ignore
RoundRobinGroupChatManager,
)
assert manager1._message_thread == manager2._message_thread # pyright: ignore

View File

@ -11,6 +11,7 @@ from autogen_agentchat.messages import (
MultiModalMessage,
StopMessage,
StructuredMessage,
StructuredMessageFactory,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
@ -52,6 +53,28 @@ def test_structured_message() -> None:
assert dumped_message["type"] == "StructuredMessage[TestContent]"
def test_structured_message_component() -> None:
# Create a structured message with the test contentformat_string="this is a string {field1} and this is an int {field2}"
format_string = "this is a string {field1} and this is an int {field2}"
s_m = StructuredMessageFactory(input_model=TestContent, format_string=format_string)
config = s_m.dump_component()
s_m_dyn = StructuredMessageFactory.load_component(config)
message = s_m_dyn.StructuredMessage(
source="test_agent", content=s_m_dyn.ContentModel(field1="test", field2=42), format_string=s_m_dyn.format_string
)
assert isinstance(message.content, s_m_dyn.ContentModel)
assert not isinstance(message.content, TestContent)
assert message.content.field1 == "test" # type: ignore[attr-defined]
assert message.content.field2 == 42 # type: ignore[attr-defined]
dumped_message = message.model_dump()
assert dumped_message["source"] == "test_agent"
assert dumped_message["content"]["field1"] == "test"
assert dumped_message["content"]["field2"] == 42
assert message.to_model_text() == format_string.format(field1="test", field2=42)
def test_message_factory() -> None:
factory = MessageFactory()
@ -109,6 +132,22 @@ def test_message_factory() -> None:
assert structured_message.content.field2 == 42
assert structured_message.type == "StructuredMessage[TestContent]" # type: ignore[comparison-overlap]
sm_factory = StructuredMessageFactory(input_model=TestContent, format_string=None, content_model_name="TestContent")
config = sm_factory.dump_component()
config.config["content_model_name"] = "DynamicTestContent"
sm_factory_dynamic = StructuredMessageFactory.load_component(config)
factory.register(sm_factory_dynamic.StructuredMessage)
msg = sm_factory_dynamic.StructuredMessage(
content=sm_factory_dynamic.ContentModel(field1="static", field2=123), source="static_agent"
)
restored = factory.create(msg.dump())
assert isinstance(restored, StructuredMessage)
assert isinstance(restored.content, sm_factory_dynamic.ContentModel) # type: ignore[reportUnkownMemberType]
assert restored.source == "static_agent"
assert restored.content.field1 == "static" # type: ignore[attr-defined]
assert restored.content.field2 == 123 # type: ignore[attr-defined]
class TestContainer(BaseModel):
chat_messages: List[ChatMessage]

View File

@ -0,0 +1,3 @@
from ._json_to_pydantic import schema_to_pydantic_model
__all__ = ["schema_to_pydantic_model"]

View File

@ -0,0 +1,531 @@
import datetime
from ipaddress import IPv4Address, IPv6Address
from typing import Annotated, Any, Dict, ForwardRef, List, Literal, Optional, Type, Union, cast
from pydantic import (
UUID1,
UUID3,
UUID4,
UUID5,
AnyUrl,
BaseModel,
EmailStr,
Field,
conbytes,
confloat,
conint,
conlist,
constr,
create_model,
)
from pydantic.fields import FieldInfo
class SchemaConversionError(Exception):
"""Base class for schema conversion exceptions."""
pass
class ReferenceNotFoundError(SchemaConversionError):
"""Raised when a $ref cannot be resolved."""
pass
class FormatNotSupportedError(SchemaConversionError):
"""Raised when a format is not supported."""
pass
class UnsupportedKeywordError(SchemaConversionError):
"""Raised when an unsupported JSON Schema keyword is encountered."""
pass
TYPE_MAPPING: Dict[str, Type[Any]] = {
"string": str,
"integer": int,
"boolean": bool,
"number": float,
"array": List,
"object": dict,
"null": type(None),
}
FORMAT_MAPPING: Dict[str, Any] = {
"uuid": UUID4,
"uuid1": UUID1,
"uuid2": UUID4,
"uuid3": UUID3,
"uuid4": UUID4,
"uuid5": UUID5,
"email": EmailStr,
"uri": AnyUrl,
"hostname": constr(strict=True),
"ipv4": IPv4Address,
"ipv6": IPv6Address,
"ipv4-network": IPv4Address,
"ipv6-network": IPv6Address,
"date-time": datetime.datetime,
"date": datetime.date,
"time": datetime.time,
"duration": datetime.timedelta,
"int32": conint(strict=True, ge=-(2**31), le=2**31 - 1),
"int64": conint(strict=True, ge=-(2**63), le=2**63 - 1),
"float": confloat(strict=True),
"double": float,
"decimal": float,
"byte": conbytes(strict=True),
"binary": conbytes(strict=True),
"password": str,
"path": str,
}
def _make_field(
default: Any,
*,
title: Optional[str] = None,
description: Optional[str] = None,
) -> Any:
"""Construct a Pydantic Field with proper typing."""
field_kwargs: Dict[str, Any] = {}
if title is not None:
field_kwargs["title"] = title
if description is not None:
field_kwargs["description"] = description
return Field(default, **field_kwargs)
class _JSONSchemaToPydantic:
def __init__(self) -> None:
self._model_cache: Dict[str, Optional[Union[Type[BaseModel], ForwardRef]]] = {}
def _resolve_ref(self, ref: str, schema: Dict[str, Any]) -> Dict[str, Any]:
ref_key = ref.split("/")[-1]
definitions = cast(dict[str, dict[str, Any]], schema.get("$defs", {}))
if ref_key not in definitions:
raise ReferenceNotFoundError(
f"Reference `{ref}` not found in `$defs`. Available keys: {list(definitions.keys())}"
)
return definitions[ref_key]
def get_ref(self, ref_name: str) -> Any:
if ref_name not in self._model_cache:
raise ReferenceNotFoundError(
f"Reference `{ref_name}` not found in cache. Available: {list(self._model_cache.keys())}"
)
if self._model_cache[ref_name] is None:
return ForwardRef(ref_name)
return self._model_cache[ref_name]
def _process_definitions(self, root_schema: Dict[str, Any]) -> None:
if "$defs" in root_schema:
for model_name in root_schema["$defs"]:
if model_name not in self._model_cache:
self._model_cache[model_name] = None
for model_name, model_schema in root_schema["$defs"].items():
if self._model_cache[model_name] is None:
self._model_cache[model_name] = self.json_schema_to_pydantic(model_schema, model_name, root_schema)
def json_schema_to_pydantic(
self, schema: Dict[str, Any], model_name: str = "GeneratedModel", root_schema: Optional[Dict[str, Any]] = None
) -> Type[BaseModel]:
if root_schema is None:
root_schema = schema
self._process_definitions(root_schema)
if "$ref" in schema:
resolved = self._resolve_ref(schema["$ref"], root_schema)
schema = {**resolved, **{k: v for k, v in schema.items() if k != "$ref"}}
if "allOf" in schema:
merged: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
for s in schema["allOf"]:
part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s
merged["properties"].update(part.get("properties", {}))
merged["required"].extend(part.get("required", []))
for k, v in schema.items():
if k not in {"allOf", "properties", "required"}:
merged[k] = v
merged["required"] = list(set(merged["required"]))
schema = merged
return self._json_schema_to_model(schema, model_name, root_schema)
def _resolve_union_types(self, schemas: List[Dict[str, Any]]) -> List[Any]:
types: List[Any] = []
for s in schemas:
if "$ref" in s:
types.append(self.get_ref(s["$ref"].split("/")[-1]))
elif "enum" in s:
types.append(Literal[tuple(s["enum"])] if len(s["enum"]) > 0 else Any)
else:
json_type = s.get("type")
if json_type not in TYPE_MAPPING:
raise UnsupportedKeywordError(f"Unsupported or missing type `{json_type}` in union")
types.append(TYPE_MAPPING[json_type])
return types
def _extract_field_type(self, key: str, value: Dict[str, Any], model_name: str, root_schema: Dict[str, Any]) -> Any:
json_type = value.get("type")
if json_type not in TYPE_MAPPING:
raise UnsupportedKeywordError(
f"Unsupported or missing type `{json_type}` for field `{key}` in `{model_name}`"
)
base_type = TYPE_MAPPING[json_type]
constraints: Dict[str, Any] = {}
if json_type == "string":
if "minLength" in value:
constraints["min_length"] = value["minLength"]
if "maxLength" in value:
constraints["max_length"] = value["maxLength"]
if "pattern" in value:
constraints["pattern"] = value["pattern"]
if constraints:
base_type = constr(**constraints)
elif json_type == "integer":
if "minimum" in value:
constraints["ge"] = value["minimum"]
if "maximum" in value:
constraints["le"] = value["maximum"]
if "exclusiveMinimum" in value:
constraints["gt"] = value["exclusiveMinimum"]
if "exclusiveMaximum" in value:
constraints["lt"] = value["exclusiveMaximum"]
if constraints:
base_type = conint(**constraints)
elif json_type == "number":
if "minimum" in value:
constraints["ge"] = value["minimum"]
if "maximum" in value:
constraints["le"] = value["maximum"]
if "exclusiveMinimum" in value:
constraints["gt"] = value["exclusiveMinimum"]
if "exclusiveMaximum" in value:
constraints["lt"] = value["exclusiveMaximum"]
if constraints:
base_type = confloat(**constraints)
elif json_type == "array":
if "minItems" in value:
constraints["min_length"] = value["minItems"]
if "maxItems" in value:
constraints["max_length"] = value["maxItems"]
item_schema = value.get("items", {"type": "string"})
if "$ref" in item_schema:
item_type = self.get_ref(item_schema["$ref"].split("/")[-1])
else:
item_type_name = item_schema.get("type")
if item_type_name not in TYPE_MAPPING:
raise UnsupportedKeywordError(
f"Unsupported or missing item type `{item_type_name}` for array field `{key}` in `{model_name}`"
)
item_type = TYPE_MAPPING[item_type_name]
base_type = conlist(item_type, **constraints) if constraints else List[item_type] # type: ignore[valid-type]
if "format" in value:
format_type = FORMAT_MAPPING.get(value["format"])
if format_type is None:
raise FormatNotSupportedError(f"Unknown format `{value['format']}` for `{key}` in `{model_name}`")
if not isinstance(format_type, type):
return format_type
if not issubclass(format_type, str):
return format_type
return format_type
return base_type
def _json_schema_to_model(
self, schema: Dict[str, Any], model_name: str, root_schema: Dict[str, Any]
) -> Type[BaseModel]:
if "allOf" in schema:
merged: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
for s in schema["allOf"]:
part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s
merged["properties"].update(part.get("properties", {}))
merged["required"].extend(part.get("required", []))
for k, v in schema.items():
if k not in {"allOf", "properties", "required"}:
merged[k] = v
merged["required"] = list(set(merged["required"]))
schema = merged
fields: Dict[str, tuple[Any, FieldInfo]] = {}
required_fields = set(schema.get("required", []))
for key, value in schema.get("properties", {}).items():
if "$ref" in value:
ref_name = value["$ref"].split("/")[-1]
field_type = self.get_ref(ref_name)
elif "anyOf" in value:
sub_models = self._resolve_union_types(value["anyOf"])
field_type = Union[tuple(sub_models)]
elif "oneOf" in value:
sub_models = self._resolve_union_types(value["oneOf"])
field_type = Union[tuple(sub_models)]
if "discriminator" in value:
discriminator = value["discriminator"]["propertyName"]
field_type = Annotated[field_type, Field(discriminator=discriminator)]
elif "enum" in value:
field_type = Literal[tuple(value["enum"])]
elif "allOf" in value:
merged = {"type": "object", "properties": {}, "required": []}
for s in value["allOf"]:
part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s
merged["properties"].update(part.get("properties", {}))
merged["required"].extend(part.get("required", []))
for k, v in value.items():
if k not in {"allOf", "properties", "required"}:
merged[k] = v
merged["required"] = list(set(merged["required"]))
field_type = self._json_schema_to_model(merged, f"{model_name}_{key}", root_schema)
elif value.get("type") == "object" and "properties" in value:
field_type = self._json_schema_to_model(value, f"{model_name}_{key}", root_schema)
else:
field_type = self._extract_field_type(key, value, model_name, root_schema)
if field_type is None:
raise UnsupportedKeywordError(f"Unsupported or missing type for field `{key}` in `{model_name}`")
default_value = value.get("default")
is_required = key in required_fields
if not is_required and default_value is None:
field_type = Optional[field_type]
field_args = {
"default": default_value if not is_required else ...,
}
if "title" in value:
field_args["title"] = value["title"]
if "description" in value:
field_args["description"] = value["description"]
fields[key] = (
field_type,
_make_field(
default_value if not is_required else ...,
title=value.get("title"),
description=value.get("description"),
),
)
model: Type[BaseModel] = create_model(model_name, **cast(dict[str, Any], fields))
model.model_rebuild()
return model
def schema_to_pydantic_model(schema: Dict[str, Any], model_name: str = "GeneratedModel") -> Type[BaseModel]:
"""
Convert a JSON Schema dictionary to a fully-typed Pydantic model.
This function handles schema translation and validation logic to produce
a Pydantic model.
**Supported JSON Schema Features**
- **Primitive types**: `string`, `integer`, `number`, `boolean`, `object`, `array`, `null`
- **String formats**:
- `email`, `uri`, `uuid`, `uuid1`, `uuid3`, `uuid4`, `uuid5`
- `hostname`, `ipv4`, `ipv6`, `ipv4-network`, `ipv6-network`
- `date`, `time`, `date-time`, `duration`
- `byte`, `binary`, `password`, `path`
- **String constraints**:
- `minLength`, `maxLength`, `pattern`
- **Numeric constraints**:
- `minimum`, `maximum`, `exclusiveMinimum`, `exclusiveMaximum`
- **Array constraints**:
- `minItems`, `maxItems`, `items`
- **Object schema support**:
- `properties`, `required`, `title`, `description`, `default`
- **Enums**:
- Converted to Python `Literal` type
- **Union types**:
- `anyOf`, `oneOf` supported with optional `discriminator`
- **Inheritance and composition**:
- `allOf` merges multiple schemas into one model
- **$ref and $defs resolution**:
- Supports references to sibling definitions and self-referencing schemas
.. code-block:: python
from json_schema_to_pydantic import schema_to_pydantic_model
# Example 1: Simple user model
schema = {
"title": "User",
"type": "object",
"properties": {
"name": {"type": "string"},
"email": {"type": "string", "format": "email"},
"age": {"type": "integer", "minimum": 0},
},
"required": ["name", "email"],
}
UserModel = schema_to_pydantic_model(schema)
user = UserModel(name="Alice", email="alice@example.com", age=30)
.. code-block:: python
# Example 2: Nested model
schema = {
"title": "BlogPost",
"type": "object",
"properties": {
"title": {"type": "string"},
"tags": {"type": "array", "items": {"type": "string"}},
"author": {
"type": "object",
"properties": {"name": {"type": "string"}, "email": {"type": "string", "format": "email"}},
"required": ["name"],
},
},
"required": ["title", "author"],
}
BlogPost = schema_to_pydantic_model(schema)
.. code-block:: python
# Example 3: allOf merging with $refs
schema = {
"title": "EmployeeWithDepartment",
"allOf": [{"$ref": "#/$defs/Employee"}, {"$ref": "#/$defs/Department"}],
"$defs": {
"Employee": {
"type": "object",
"properties": {"id": {"type": "string"}, "name": {"type": "string"}},
"required": ["id", "name"],
},
"Department": {
"type": "object",
"properties": {"department": {"type": "string"}},
"required": ["department"],
},
},
}
Model = schema_to_pydantic_model(schema)
.. code-block:: python
# Example 4: Self-referencing (recursive) model
schema = {
"title": "Category",
"type": "object",
"properties": {
"name": {"type": "string"},
"subcategories": {"type": "array", "items": {"$ref": "#/$defs/Category"}},
},
"required": ["name"],
"$defs": {
"Category": {
"type": "object",
"properties": {
"name": {"type": "string"},
"subcategories": {"type": "array", "items": {"$ref": "#/$defs/Category"}},
},
"required": ["name"],
}
},
}
Category = schema_to_pydantic_model(schema)
.. code-block:: python
# Example 5: Serializing and deserializing with Pydantic
from uuid import uuid4
from pydantic import BaseModel, EmailStr, Field
from typing import Optional, List, Dict, Any
from autogen_core.utils import schema_to_pydantic_model
class Address(BaseModel):
street: str
city: str
zipcode: str
class User(BaseModel):
id: str
name: str
email: EmailStr
age: int = Field(..., ge=18)
address: Address
class Employee(BaseModel):
id: str
name: str
manager: Optional["Employee"] = None
class Department(BaseModel):
name: str
employees: List[Employee]
class ComplexModel(BaseModel):
user: User
extra_info: Optional[Dict[str, Any]] = None
sub_items: List[Employee]
# Convert ComplexModel to JSON schema
complex_schema = ComplexModel.model_json_schema()
# Rebuild a new Pydantic model from JSON schema
ReconstructedModel = schema_to_pydantic_model(complex_schema, "ComplexModel")
# Instantiate reconstructed model
reconstructed = ReconstructedModel(
user={
"id": str(uuid4()),
"name": "Alice",
"email": "alice@example.com",
"age": 30,
"address": {"street": "123 Main St", "city": "Wonderland", "zipcode": "12345"},
},
sub_items=[{"id": str(uuid4()), "name": "Bob", "manager": {"id": str(uuid4()), "name": "Eve"}}],
)
print(reconstructed.model_dump())
Args:
schema (Dict[str, Any]): A valid JSON Schema dictionary.
model_name (str, optional): The name of the root model. Defaults to "GeneratedModel".
Returns:
Type[BaseModel]: A dynamically generated Pydantic model class.
Raises:
ReferenceNotFoundError: If a `$ref` key references a missing entry.
FormatNotSupportedError: If a `format` keyword is unknown or unsupported.
UnsupportedKeywordError: If the schema contains an unsupported `type`.
See Also:
- :class:`pydantic.BaseModel`
- :func:`pydantic.create_model`
- https://json-schema.org/
"""
...
return _JSONSchemaToPydantic().json_schema_to_pydantic(schema, model_name)

View File

@ -0,0 +1,757 @@
import types
from typing import Any, Dict, List, Literal, Optional, Type, get_args, get_origin
from uuid import UUID, uuid4
import pytest
from autogen_core.utils._json_to_pydantic import (
FORMAT_MAPPING,
TYPE_MAPPING,
FormatNotSupportedError,
ReferenceNotFoundError,
UnsupportedKeywordError,
_JSONSchemaToPydantic, # pyright: ignore[reportPrivateUsage]
)
from pydantic import BaseModel, EmailStr, Field, ValidationError
# ✅ Define Pydantic models for testing
class Address(BaseModel):
street: str
city: str
zipcode: str
class User(BaseModel):
id: UUID
name: str
email: EmailStr
age: int = Field(..., ge=18) # Minimum age = 18
address: Address
class Employee(BaseModel):
id: UUID
name: str
manager: Optional["Employee"] = None # Recursive self-reference
class Department(BaseModel):
name: str
employees: List[Employee] # Array of objects
class ComplexModel(BaseModel):
user: User
extra_info: Optional[Dict[str, Any]] = None # Optional dictionary
sub_items: List[Employee] # List of Employees
@pytest.fixture
def converter() -> _JSONSchemaToPydantic:
"""Fixture to create a fresh instance of JSONSchemaToPydantic for every test."""
return _JSONSchemaToPydantic()
@pytest.fixture
def sample_json_schema() -> Dict[str, Any]:
"""Fixture that returns a JSON schema dynamically using model_json_schema()."""
return User.model_json_schema()
@pytest.fixture
def sample_json_schema_recursive() -> Dict[str, Any]:
"""Fixture that returns a self-referencing JSON schema."""
return Employee.model_json_schema()
@pytest.fixture
def sample_json_schema_nested() -> Dict[str, Any]:
"""Fixture that returns a nested schema with arrays of objects."""
return Department.model_json_schema()
@pytest.fixture
def sample_json_schema_complex() -> Dict[str, Any]:
"""Fixture that returns a complex schema with multiple structures."""
return ComplexModel.model_json_schema()
@pytest.mark.parametrize(
"schema_fixture, model_name, expected_fields",
[
(sample_json_schema, "User", ["id", "name", "email", "age", "address"]),
(sample_json_schema_recursive, "Employee", ["id", "name", "manager"]),
(sample_json_schema_nested, "Department", ["name", "employees"]),
(sample_json_schema_complex, "ComplexModel", ["user", "extra_info", "sub_items"]),
],
)
def test_json_schema_to_pydantic(
converter: _JSONSchemaToPydantic,
schema_fixture: Any,
model_name: str,
expected_fields: List[str],
request: Any,
) -> None:
"""Test conversion of JSON Schema to Pydantic model using the class instance."""
schema = request.getfixturevalue(schema_fixture.__name__)
Model = converter.json_schema_to_pydantic(schema, model_name)
for field in expected_fields:
assert field in Model.__annotations__, f"Expected '{field}' missing in {model_name}Model"
# ✅ **Valid Data Tests**
@pytest.mark.parametrize(
"schema_fixture, model_name, valid_data",
[
(
sample_json_schema,
"User",
{
"id": str(uuid4()),
"name": "Alice",
"email": "alice@example.com",
"age": 25,
"address": {"street": "123 Main St", "city": "Metropolis", "zipcode": "12345"},
},
),
(
sample_json_schema_recursive,
"Employee",
{
"id": str(uuid4()),
"name": "Alice",
"manager": {
"id": str(uuid4()),
"name": "Bob",
},
},
),
(
sample_json_schema_nested,
"Department",
{
"name": "Engineering",
"employees": [
{
"id": str(uuid4()),
"name": "Alice",
"manager": {
"id": str(uuid4()),
"name": "Bob",
},
}
],
},
),
(
sample_json_schema_complex,
"ComplexModel",
{
"user": {
"id": str(uuid4()),
"name": "Charlie",
"email": "charlie@example.com",
"age": 30,
"address": {"street": "456 Side St", "city": "Gotham", "zipcode": "67890"},
},
"extra_info": {"hobby": "Chess", "level": "Advanced"},
"sub_items": [
{"id": str(uuid4()), "name": "Eve"},
{"id": str(uuid4()), "name": "David", "manager": {"id": str(uuid4()), "name": "Frank"}},
],
},
),
],
)
def test_valid_data_model(
converter: _JSONSchemaToPydantic,
schema_fixture: Any,
model_name: str,
valid_data: Dict[str, Any],
request: Any,
) -> None:
"""Test that valid data is accepted by the generated model."""
schema = request.getfixturevalue(schema_fixture.__name__)
Model = converter.json_schema_to_pydantic(schema, model_name)
instance = Model(**valid_data)
assert instance
dumped = instance.model_dump(mode="json", exclude_none=True)
assert dumped == valid_data, f"Model output mismatch.\nExpected: {valid_data}\nGot: {dumped}"
# ✅ **Invalid Data Tests**
@pytest.mark.parametrize(
"schema_fixture, model_name, invalid_data",
[
(
sample_json_schema,
"User",
{
"id": "not-a-uuid", # Invalid UUID
"name": "Alice",
"email": "not-an-email", # Invalid email
"age": 17, # Below minimum
"address": {"street": "123 Main St", "city": "Metropolis"},
},
),
(
sample_json_schema_recursive,
"Employee",
{
"id": str(uuid4()),
"name": "Alice",
"manager": {
"id": "not-a-uuid", # Invalid UUID
"name": "Bob",
},
},
),
(
sample_json_schema_nested,
"Department",
{
"name": "Engineering",
"employees": [
{
"id": "not-a-uuid", # Invalid UUID
"name": "Alice",
"manager": {
"id": str(uuid4()),
"name": "Bob",
},
}
],
},
),
(
sample_json_schema_complex,
"ComplexModel",
{
"user": {
"id": str(uuid4()),
"name": "Charlie",
"email": "charlie@example.com",
"age": "thirty", # Invalid: Should be an int
"address": {"street": "456 Side St", "city": "Gotham", "zipcode": "67890"},
},
"extra_info": "should-be-dictionary", # Invalid type
"sub_items": [
{"id": "invalid-uuid", "name": "Eve"}, # Invalid UUID
{"id": str(uuid4()), "name": 123}, # Invalid name type
],
},
),
],
)
def test_invalid_data_model(
converter: _JSONSchemaToPydantic,
schema_fixture: Any,
model_name: str,
invalid_data: Dict[str, Any],
request: Any,
) -> None:
"""Test that invalid data raises ValidationError."""
schema = request.getfixturevalue(schema_fixture.__name__)
Model = converter.json_schema_to_pydantic(schema, model_name)
with pytest.raises(ValidationError):
Model(**invalid_data)
class ListDictModel(BaseModel):
"""Example for `List[Dict[str, Any]]`"""
data: List[Dict[str, Any]]
class DictListModel(BaseModel):
"""Example for `Dict[str, List[Any]]`"""
mapping: Dict[str, List[Any]]
class NestedListModel(BaseModel):
"""Example for `List[List[str]]`"""
matrix: List[List[str]]
@pytest.fixture
def sample_json_schema_list_dict() -> Dict[str, Any]:
"""Fixture for `List[Dict[str, Any]]`"""
return ListDictModel.model_json_schema()
@pytest.fixture
def sample_json_schema_dict_list() -> Dict[str, Any]:
"""Fixture for `Dict[str, List[Any]]`"""
return DictListModel.model_json_schema()
@pytest.fixture
def sample_json_schema_nested_list() -> Dict[str, Any]:
"""Fixture for `List[List[str]]`"""
return NestedListModel.model_json_schema()
@pytest.mark.parametrize(
"schema_fixture, model_name, expected_fields",
[
(sample_json_schema_list_dict, "ListDictModel", ["data"]),
(sample_json_schema_dict_list, "DictListModel", ["mapping"]),
(sample_json_schema_nested_list, "NestedListModel", ["matrix"]),
],
)
def test_json_schema_to_pydantic_nested(
converter: _JSONSchemaToPydantic,
schema_fixture: Any,
model_name: str,
expected_fields: list[str],
request: Any,
) -> None:
"""Test conversion of JSON Schema to Pydantic model using the class instance."""
schema = request.getfixturevalue(schema_fixture.__name__)
Model = converter.json_schema_to_pydantic(schema, model_name)
for field in expected_fields:
assert field in Model.__annotations__, f"Expected '{field}' missing in {model_name}Model"
# ✅ **Valid Data Tests**
@pytest.mark.parametrize(
"schema_fixture, model_name, valid_data",
[
(
sample_json_schema_list_dict,
"ListDictModel",
{
"data": [
{"key1": "value1", "key2": 10},
{"another_key": False, "nested": {"subkey": "data"}},
]
},
),
(
sample_json_schema_dict_list,
"DictListModel",
{
"mapping": {
"first": ["a", "b", "c"],
"second": [1, 2, 3, 4],
"third": [True, False, True],
}
},
),
(
sample_json_schema_nested_list,
"NestedListModel",
{"matrix": [["A", "B"], ["C", "D"], ["E", "F"]]},
),
],
)
def test_valid_data_model_nested(
converter: _JSONSchemaToPydantic,
schema_fixture: Any,
model_name: str,
valid_data: Dict[str, Any],
request: Any,
) -> None:
"""Test that valid data is accepted by the generated model."""
schema = request.getfixturevalue(schema_fixture.__name__)
Model = converter.json_schema_to_pydantic(schema, model_name)
instance = Model(**valid_data)
assert instance
for field, value in valid_data.items():
assert (
getattr(instance, field) == value
), f"Mismatch in field `{field}`: expected `{value}`, got `{getattr(instance, field)}`"
# ✅ **Invalid Data Tests**
@pytest.mark.parametrize(
"schema_fixture, model_name, invalid_data",
[
(
sample_json_schema_list_dict,
"ListDictModel",
{
"data": "should-be-a-list", # ❌ Should be a list of dicts
},
),
(
sample_json_schema_dict_list,
"DictListModel",
{
"mapping": [
"should-be-a-dictionary", # ❌ Should be a dict of lists
]
},
),
(
sample_json_schema_nested_list,
"NestedListModel",
{"matrix": [["A", "B"], "C", ["D", "E"]]}, # ❌ "C" is not a list
),
],
)
def test_invalid_data_model_nested(
converter: _JSONSchemaToPydantic,
schema_fixture: Any,
model_name: str,
invalid_data: Dict[str, Any],
request: Any,
) -> None:
"""Test that invalid data raises ValidationError."""
schema = request.getfixturevalue(schema_fixture.__name__)
Model = converter.json_schema_to_pydantic(schema, model_name)
with pytest.raises(ValidationError):
Model(**invalid_data)
def test_reference_not_found(converter: _JSONSchemaToPydantic) -> None:
schema = {"type": "object", "properties": {"manager": {"$ref": "#/$defs/MissingRef"}}}
with pytest.raises(ReferenceNotFoundError):
converter.json_schema_to_pydantic(schema, "MissingRefModel")
def test_format_not_supported(converter: _JSONSchemaToPydantic) -> None:
schema = {"type": "object", "properties": {"custom_field": {"type": "string", "format": "unsupported-format"}}}
with pytest.raises(FormatNotSupportedError):
converter.json_schema_to_pydantic(schema, "UnsupportedFormatModel")
def test_unsupported_keyword(converter: _JSONSchemaToPydantic) -> None:
schema = {"type": "object", "properties": {"broken_field": {"title": "Missing type"}}}
with pytest.raises(UnsupportedKeywordError):
converter.json_schema_to_pydantic(schema, "MissingTypeModel")
def test_enum_field_schema() -> None:
schema = {
"type": "object",
"properties": {
"status": {"type": "string", "enum": ["pending", "approved", "rejected"]},
"priority": {"type": "integer", "enum": [1, 2, 3]},
},
"required": ["status"],
}
converter: _JSONSchemaToPydantic = _JSONSchemaToPydantic()
Model = converter.json_schema_to_pydantic(schema, "Task")
status_ann = Model.model_fields["status"].annotation
assert get_origin(status_ann) is Literal
assert set(get_args(status_ann)) == {"pending", "approved", "rejected"}
priority_ann = Model.model_fields["priority"].annotation
args = get_args(priority_ann)
assert type(None) in args
assert Literal[1, 2, 3] in args
instance = Model(status="approved", priority=2)
assert instance.status == "approved" # type: ignore[attr-defined]
assert instance.priority == 2 # type: ignore[attr-defined]
def test_metadata_title_description(converter: _JSONSchemaToPydantic) -> None:
schema = {
"title": "CustomerProfile",
"description": "A profile containing personal and contact info",
"type": "object",
"properties": {
"first_name": {"type": "string", "title": "First Name", "description": "Given name of the user"},
"age": {"type": "integer", "title": "Age", "description": "Age in years"},
"contact": {
"type": "object",
"title": "Contact Information",
"description": "How to reach the user",
"properties": {
"email": {
"type": "string",
"format": "email",
"title": "Email Address",
"description": "Primary email",
}
},
},
},
"required": ["first_name"],
}
Model: Type[BaseModel] = converter.json_schema_to_pydantic(schema, "CustomerProfile")
generated_schema = Model.model_json_schema()
assert generated_schema["title"] == "CustomerProfile"
props = generated_schema["properties"]
assert props["first_name"]["title"] == "First Name"
assert props["first_name"]["description"] == "Given name of the user"
assert props["age"]["title"] == "Age"
assert props["age"]["description"] == "Age in years"
contact = props["contact"]
assert contact["title"] == "Contact Information"
assert contact["description"] == "How to reach the user"
# Follow the $ref
ref_key = contact["anyOf"][0]["$ref"].split("/")[-1]
contact_def = generated_schema["$defs"][ref_key]
email = contact_def["properties"]["email"]
assert email["title"] == "Email Address"
assert email["description"] == "Primary email"
def test_oneof_with_discriminator(converter: _JSONSchemaToPydantic) -> None:
schema = {
"title": "PetWrapper",
"type": "object",
"properties": {
"pet": {
"oneOf": [{"$ref": "#/$defs/Cat"}, {"$ref": "#/$defs/Dog"}],
"discriminator": {"propertyName": "pet_type"},
}
},
"required": ["pet"],
"$defs": {
"Cat": {
"type": "object",
"properties": {"pet_type": {"type": "string", "enum": ["cat"]}, "hunting_skill": {"type": "string"}},
"required": ["pet_type", "hunting_skill"],
"title": "Cat",
},
"Dog": {
"type": "object",
"properties": {"pet_type": {"type": "string", "enum": ["dog"]}, "pack_size": {"type": "integer"}},
"required": ["pet_type", "pack_size"],
"title": "Dog",
},
},
}
Model = converter.json_schema_to_pydantic(schema, "PetWrapper")
# Instantiate with a Cat
cat = Model(pet={"pet_type": "cat", "hunting_skill": "expert"})
assert cat.pet.pet_type == "cat" # type: ignore[attr-defined]
# Instantiate with a Dog
dog = Model(pet={"pet_type": "dog", "pack_size": 4})
assert dog.pet.pet_type == "dog" # type: ignore[attr-defined]
# Check round-trip schema includes discriminator
model_schema = Model.model_json_schema()
assert "discriminator" in model_schema["properties"]["pet"]
assert model_schema["properties"]["pet"]["discriminator"]["propertyName"] == "pet_type"
def test_allof_merging_with_refs(converter: _JSONSchemaToPydantic) -> None:
schema = {
"title": "EmployeeWithDepartment",
"allOf": [{"$ref": "#/$defs/Employee"}, {"$ref": "#/$defs/Department"}],
"$defs": {
"Employee": {
"type": "object",
"properties": {"id": {"type": "string"}, "name": {"type": "string"}},
"required": ["id", "name"],
"title": "Employee",
},
"Department": {
"type": "object",
"properties": {"department": {"type": "string"}},
"required": ["department"],
"title": "Department",
},
},
}
Model = converter.json_schema_to_pydantic(schema, "EmployeeWithDepartment")
instance = Model(id="123", name="Alice", department="Engineering")
assert instance.id == "123" # type: ignore[attr-defined]
assert instance.name == "Alice" # type: ignore[attr-defined]
assert instance.department == "Engineering" # type: ignore[attr-defined]
dumped = instance.model_dump()
assert dumped == {"id": "123", "name": "Alice", "department": "Engineering"}
def test_nested_allof_merging(converter: _JSONSchemaToPydantic) -> None:
schema = {
"title": "ContainerModel",
"type": "object",
"properties": {
"nested": {
"type": "object",
"properties": {
"data": {
"allOf": [
{"$ref": "#/$defs/Base"},
{"type": "object", "properties": {"extra": {"type": "string"}}, "required": ["extra"]},
]
}
},
"required": ["data"],
}
},
"required": ["nested"],
"$defs": {
"Base": {
"type": "object",
"properties": {"base_field": {"type": "string"}},
"required": ["base_field"],
"title": "Base",
}
},
}
Model = converter.json_schema_to_pydantic(schema, "ContainerModel")
instance = Model(nested={"data": {"base_field": "abc", "extra": "xyz"}})
assert instance.nested.data.base_field == "abc" # type: ignore[attr-defined]
assert instance.nested.data.extra == "xyz" # type: ignore[attr-defined]
@pytest.mark.parametrize(
"schema, field_name, valid_values, invalid_values",
[
# String constraints
(
{
"type": "object",
"properties": {
"username": {"type": "string", "minLength": 3, "maxLength": 10, "pattern": "^[a-zA-Z0-9_]+$"}
},
"required": ["username"],
},
"username",
["user_123", "abc", "Name2023"],
["", "ab", "toolongusername123", "invalid!char"],
),
# Integer constraints
(
{
"type": "object",
"properties": {"age": {"type": "integer", "minimum": 18, "maximum": 99}},
"required": ["age"],
},
"age",
[18, 25, 99],
[17, 100, -1],
),
# Float constraints
(
{
"type": "object",
"properties": {"score": {"type": "number", "minimum": 0.0, "exclusiveMaximum": 1.0}},
"required": ["score"],
},
"score",
[0.0, 0.5, 0.999],
[-0.1, 1.0, 2.5],
),
# Array constraints
(
{
"type": "object",
"properties": {"tags": {"type": "array", "items": {"type": "string"}, "minItems": 1, "maxItems": 3}},
"required": ["tags"],
},
"tags",
[["a"], ["a", "b"], ["x", "y", "z"]],
[[], ["one", "two", "three", "four"]],
),
],
)
def test_field_constraints(
schema: Dict[str, Any],
field_name: str,
valid_values: List[Any],
invalid_values: List[Any],
) -> None:
converter = _JSONSchemaToPydantic()
Model = converter.json_schema_to_pydantic(schema, "ConstraintModel")
for value in valid_values:
instance = Model(**{field_name: value})
assert getattr(instance, field_name) == value
for value in invalid_values:
with pytest.raises(ValidationError):
Model(**{field_name: value})
@pytest.mark.parametrize(
"schema",
[
# Top-level field
{"type": "object", "properties": {"weird": {"type": "abc"}}, "required": ["weird"]},
# Inside array items
{"type": "object", "properties": {"items": {"type": "array", "items": {"type": "abc"}}}, "required": ["items"]},
# Inside anyOf
{
"type": "object",
"properties": {"choice": {"anyOf": [{"type": "string"}, {"type": "abc"}]}},
"required": ["choice"],
},
],
)
def test_unknown_type_raises(schema: Dict[str, Any]) -> None:
converter = _JSONSchemaToPydantic()
with pytest.raises(UnsupportedKeywordError):
converter.json_schema_to_pydantic(schema, "UnknownTypeModel")
@pytest.mark.parametrize("json_type, expected_type", list(TYPE_MAPPING.items()))
def test_basic_type_mapping(json_type: str, expected_type: type) -> None:
schema = {
"type": "object",
"properties": {"field": {"type": json_type}},
"required": ["field"],
}
converter = _JSONSchemaToPydantic()
Model = converter.json_schema_to_pydantic(schema, f"{json_type.capitalize()}Model")
assert "field" in Model.__annotations__
field_type = Model.__annotations__["field"]
# For array/object/null we check the outer type only
if json_type == "null":
assert field_type is type(None)
elif json_type == "array":
assert getattr(field_type, "__origin__", None) is list
elif json_type == "object":
assert field_type in (dict, Dict) or getattr(field_type, "__origin__", None) in (dict, Dict)
else:
assert field_type == expected_type
@pytest.mark.parametrize("format_name, expected_type", list(FORMAT_MAPPING.items()))
def test_format_mapping(format_name: str, expected_type: Any) -> None:
schema = {
"type": "object",
"properties": {"field": {"type": "string", "format": format_name}},
"required": ["field"],
}
converter = _JSONSchemaToPydantic()
Model = converter.json_schema_to_pydantic(schema, f"{format_name.capitalize()}Model")
assert "field" in Model.__annotations__
field_type = Model.__annotations__["field"]
if isinstance(expected_type, types.FunctionType): # if it's a constrained constructor (e.g., conint)
assert callable(field_type)
else:
assert field_type == expected_type
def test_unknown_format_raises() -> None:
schema = {
"type": "object",
"properties": {"bad_field": {"type": "string", "format": "definitely-not-a-format"}},
}
converter = _JSONSchemaToPydantic()
with pytest.raises(FormatNotSupportedError):
converter.json_schema_to_pydantic(schema, "UnknownFormatModel")