diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 4c94b497d..3e4996fe8 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -46,6 +46,7 @@ from ..messages import ( MemoryQueryEvent, ModelClientStreamingChunkEvent, StructuredMessage, + StructuredMessageFactory, TextMessage, ThoughtEvent, ToolCallExecutionEvent, @@ -74,6 +75,7 @@ class AssistantAgentConfig(BaseModel): reflect_on_tool_use: bool tool_call_summary_format: str metadata: Dict[str, str] | None = None + structured_message_factory: ComponentModel | None = None class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): @@ -183,6 +185,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): This will be used with the model client to generate structured output. If this is set, the agent will respond with a :class:`~autogen_agentchat.messages.StructuredMessage` instead of a :class:`~autogen_agentchat.messages.TextMessage` in the final response, unless `reflect_on_tool_use` is `False` and a tool call is made. + output_content_type_format (str | None, optional): (Experimental) The format string used for the content of a :class:`~autogen_agentchat.messages.StructuredMessage` response. tool_call_summary_format (str, optional): The format string used to create the content for a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` response. The format string is used to format the tool call summary for every tool call result. Defaults to "{result}". 
@@ -635,6 +638,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): reflect_on_tool_use: bool | None = None, tool_call_summary_format: str = "{result}", output_content_type: type[BaseModel] | None = None, + output_content_type_format: str | None = None, memory: Sequence[Memory] | None = None, metadata: Dict[str, str] | None = None, ): @@ -643,6 +647,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): self._model_client = model_client self._model_client_stream = model_client_stream self._output_content_type: type[BaseModel] | None = output_content_type + self._output_content_type_format = output_content_type_format + self._structured_message_factory: StructuredMessageFactory | None = None + if output_content_type is not None: + self._structured_message_factory = StructuredMessageFactory( + input_model=output_content_type, format_string=output_content_type_format + ) + self._memory = None if memory is not None: if isinstance(memory, list): @@ -771,6 +782,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): reflect_on_tool_use = self._reflect_on_tool_use tool_call_summary_format = self._tool_call_summary_format output_content_type = self._output_content_type + format_string = self._output_content_type_format # STEP 1: Add new user/handoff messages to the model context await self._add_messages_to_context( @@ -840,6 +852,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): reflect_on_tool_use=reflect_on_tool_use, tool_call_summary_format=tool_call_summary_format, output_content_type=output_content_type, + format_string=format_string, ): yield output_event @@ -942,6 +955,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): reflect_on_tool_use: bool, tool_call_summary_format: str, output_content_type: type[BaseModel] | None, + format_string: str | None = None, ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]: """ Handle final or partial 
responses from model_result, including tool calls, handoffs, @@ -957,6 +971,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): content=content, source=agent_name, models_usage=model_result.usage, + format_string=format_string, ), inner_messages=inner_messages, ) @@ -1277,9 +1292,6 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): def _to_config(self) -> AssistantAgentConfig: """Convert the assistant agent to a declarative config.""" - if self._output_content_type: - raise ValueError("AssistantAgent with output_content_type does not support declarative config.") - return AssistantAgentConfig( name=self.name, model_client=self._model_client.dump_component(), @@ -1294,12 +1306,23 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): model_client_stream=self._model_client_stream, reflect_on_tool_use=self._reflect_on_tool_use, tool_call_summary_format=self._tool_call_summary_format, + structured_message_factory=self._structured_message_factory.dump_component() + if self._structured_message_factory + else None, metadata=self._metadata, ) @classmethod def _from_config(cls, config: AssistantAgentConfig) -> Self: """Create an assistant agent from a declarative config.""" + if config.structured_message_factory: + structured_message_factory = StructuredMessageFactory.load_component(config.structured_message_factory) + format_string = structured_message_factory.format_string + output_content_type = structured_message_factory.ContentModel + + else: + format_string = None + output_content_type = None return cls( name=config.name, @@ -1313,5 +1336,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): model_client_stream=config.model_client_stream, reflect_on_tool_use=config.reflect_on_tool_use, tool_call_summary_format=config.tool_call_summary_format, + output_content_type=output_content_type, + output_content_type_format=format_string, metadata=config.metadata, ) diff --git 
a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index e24b6993c..da67e07fd 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -5,9 +5,9 @@ class and includes specific fields relevant to the type of message being sent. """ from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, List, Literal, Mapping, TypeVar +from typing import Any, Dict, Generic, List, Literal, Mapping, Optional, Type, TypeVar -from autogen_core import FunctionCall, Image +from autogen_core import Component, ComponentBase, FunctionCall, Image from autogen_core.code_executor import CodeBlock, CodeResult from autogen_core.memory import MemoryContent from autogen_core.models import ( @@ -16,6 +16,7 @@ from autogen_core.models import ( RequestUsage, UserMessage, ) +from autogen_core.utils import schema_to_pydantic_model from pydantic import BaseModel, Field, computed_field from typing_extensions import Annotated, Self @@ -182,21 +183,56 @@ class StructuredMessage(BaseChatMessage, Generic[StructuredContentType]): print(message.to_text()) # {"text": "Hello", "number": 42} + .. code-block:: python + + from pydantic import BaseModel + from autogen_agentchat.messages import StructuredMessage + + + class MyMessageContent(BaseModel): + text: str + number: int + + + message = StructuredMessage[MyMessageContent]( + content=MyMessageContent(text="Hello", number=42), + source="agent", + format_string="Hello, {text} {number}!", + ) + + print(message.to_text()) # Hello, Hello 42! + """ content: StructuredContentType """The content of the message. Must be a subclass of `Pydantic BaseModel `_.""" + format_string: Optional[str] = None + """(Experimental) An optional format string to render the content into a human-readable format. 
+ The format string can use the fields of the content model as placeholders. + For example, if the content model has a field `name`, you can use + `{name}` in the format string to include the value of that field. + The format string is used in the :meth:`to_text` method to create a + human-readable representation of the message. + This setting is experimental and will change in the future. + """ + @computed_field def type(self) -> str: return self.__class__.__name__ def to_text(self) -> str: - return self.content.model_dump_json(indent=2) + if self.format_string is not None: + return self.format_string.format(**self.content.model_dump()) + else: + return self.content.model_dump_json() def to_model_text(self) -> str: - return self.content.model_dump_json() + if self.format_string is not None: + return self.format_string.format(**self.content.model_dump()) + else: + return self.content.model_dump_json() def to_model_message(self) -> UserMessage: return UserMessage( @@ -205,6 +241,113 @@ class StructuredMessage(BaseChatMessage, Generic[StructuredContentType]): ) +class StructureMessageConfig(BaseModel): + """The declarative configuration for the structured output.""" + + json_schema: Dict[str, Any] + format_string: Optional[str] = None + content_model_name: str + + +class StructuredMessageFactory(ComponentBase[StructureMessageConfig], Component[StructureMessageConfig]): + """:meta private: + + A component that creates structured chat messages from Pydantic models or JSON schemas. + + This component helps you generate strongly-typed chat messages with content defined using a Pydantic model. + It can be used in declarative workflows where message structure must be validated, formatted, and serialized. + + You can initialize the component directly using a `BaseModel` subclass, or dynamically from a configuration + object (e.g., loaded from disk or a database). + + ### Example 1: Create from a Pydantic Model + + .. 
code-block:: python + + from pydantic import BaseModel + from autogen_agentchat.messages import StructuredMessageFactory + + + class TestContent(BaseModel): + field1: str + field2: int + + + format_string = "This is a string {field1} and this is an int {field2}" + sm_component = StructuredMessageFactory(input_model=TestContent, format_string=format_string) + + message = sm_component.StructuredMessage( + source="test_agent", content=TestContent(field1="Hello", field2=42), format_string=format_string + ) + + print(message.to_model_text()) # Output: This is a string Hello and this is an int 42 + + config = sm_component.dump_component() + + s_m_dyn = StructuredMessageFactory.load_component(config) + message = s_m_dyn.StructuredMessage( + source="test_agent", + content=s_m_dyn.ContentModel(field1="dyn agent", field2=43), + format_string=s_m_dyn.format_string, + ) + print(type(message)) # StructuredMessage[GeneratedModel] + print(message.to_model_text()) # Output: This is a string dyn agent and this is an int 43 + + Attributes: + component_config_schema (StructureMessageConfig): Defines the configuration structure for this component. + component_provider_override (str): Path used to reference this component in external tooling. + component_type (str): Identifier used for categorization (e.g., "structured_message"). + + Raises: + ValueError: If neither `json_schema` nor `input_model` is provided. + + Args: + json_schema (Optional[Dict[str, Any]]): JSON schema to dynamically create a Pydantic model. + input_model (Optional[Type[BaseModel]]): A subclass of `BaseModel` that defines the expected message structure. + format_string (Optional[str]): Optional string to render content into a human-readable format. + content_model_name (Optional[str]): Optional name for the generated Pydantic model. 
+ """ + + component_config_schema = StructureMessageConfig + component_provider_override = "autogen_agentchat.messages.StructuredMessageFactory" + component_type = "structured_message" + + def __init__( + self, + json_schema: Optional[Dict[str, Any]] = None, + input_model: Optional[Type[BaseModel]] = None, + format_string: Optional[str] = None, + content_model_name: Optional[str] = None, + ) -> None: + self.format_string = format_string + + if json_schema: + self.ContentModel = schema_to_pydantic_model( + json_schema, model_name=content_model_name or "GeneratedContentModel" + ) + elif input_model: + self.ContentModel = input_model + else: + raise ValueError("Either `json_schema` or `input_model` must be provided.") + + self.StructuredMessage = StructuredMessage[self.ContentModel] # type: ignore[name-defined] + + def _to_config(self) -> StructureMessageConfig: + return StructureMessageConfig( + json_schema=self.ContentModel.model_json_schema(), + format_string=self.format_string, + content_model_name=self.ContentModel.__name__, + ) + + @classmethod + def _from_config(cls, config: StructureMessageConfig) -> "StructuredMessageFactory": + return cls( + json_schema=config.json_schema, + format_string=config.format_string, + content_model_name=config.content_model_name, + ) + + class TextMessage(BaseTextChatMessage): """A text message with string-only content.""" @@ -468,6 +611,7 @@ __all__ = [ "BaseTextChatMessage", "StructuredContentType", "StructuredMessage", + "StructuredMessageFactory", "HandoffMessage", "MultiModalMessage", "StopMessage", diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 242bf5cde..23efbf459 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ 
b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -21,6 +21,7 @@ from ...messages import ( MessageFactory, ModelClientStreamingChunkEvent, StopMessage, + StructuredMessage, TextMessage, ) from ...state import TeamState @@ -68,6 +69,15 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]): for message_type in custom_message_types: self._message_factory.register(message_type) + for agent in participants: + for message_type in agent.produced_message_types: + try: + if issubclass(message_type, StructuredMessage): + self._message_factory.register(message_type) # type: ignore[reportUnknownArgumentType] + except TypeError: + # Not a class or not a valid subclassable type (skip) + pass + # The team ID is a UUID that is used to identify the team and its participants # in the agent runtime. It is used to create unique topic types for each participant. # Currently, team ID is binded to an object instance of the group chat class. diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 381583919..192ad2fe9 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -36,7 +36,7 @@ from autogen_core.models._model_client import ModelFamily from autogen_core.tools import BaseTool, FunctionTool from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_ext.models.replay import ReplayChatCompletionClient -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from utils import FileLogHandler logger = logging.getLogger(EVENT_LOGGER_NAME) @@ -1104,3 +1104,103 @@ async def test_model_client_stream_with_tool_calls() -> None: elif isinstance(message, ModelClientStreamingChunkEvent): chunks.append(message.content) assert "".join(chunks) == "Example response 2 to task" + + +@pytest.mark.asyncio +async def 
test_invalid_structured_output_format() -> None: + class AgentResponse(BaseModel): + response: str + status: str + + model_client = ReplayChatCompletionClient( + [ + CreateResult( + finish_reason="stop", + content='{"response": "Hello"}', + usage=RequestUsage(prompt_tokens=10, completion_tokens=5), + cached=False, + ), + ] + ) + + agent = AssistantAgent( + name="assistant", + model_client=model_client, + output_content_type=AgentResponse, + ) + + with pytest.raises(ValidationError): + await agent.run() + + +@pytest.mark.asyncio +async def test_structured_message_factory_serialization() -> None: + class AgentResponse(BaseModel): + result: str + status: str + + model_client = ReplayChatCompletionClient( + [ + CreateResult( + finish_reason="stop", + content=AgentResponse(result="All good", status="ok").model_dump_json(), + usage=RequestUsage(prompt_tokens=10, completion_tokens=5), + cached=False, + ) + ] + ) + + agent = AssistantAgent( + name="structured_agent", + model_client=model_client, + output_content_type=AgentResponse, + output_content_type_format="{result} - {status}", + ) + + dumped = agent.dump_component() + restored_agent = AssistantAgent.load_component(dumped) + result = await restored_agent.run() + + assert isinstance(result.messages[0], StructuredMessage) + assert result.messages[0].content.result == "All good" # type: ignore[reportUnknownMemberType] + assert result.messages[0].content.status == "ok" # type: ignore[reportUnknownMemberType] + + +@pytest.mark.asyncio +async def test_structured_message_format_string() -> None: + class AgentResponse(BaseModel): + field1: str + field2: str + + expected = AgentResponse(field1="foo", field2="bar") + + model_client = ReplayChatCompletionClient( + [ + CreateResult( + finish_reason="stop", + content=expected.model_dump_json(), + usage=RequestUsage(prompt_tokens=10, completion_tokens=5), + cached=False, + ) + ] + ) + + agent = AssistantAgent( + name="formatted_agent", + model_client=model_client, + 
output_content_type=AgentResponse, + output_content_type_format="{field1} - {field2}", + ) + + result = await agent.run() + + assert len(result.messages) == 1 + message = result.messages[0] + + # Check that it's a StructuredMessage with the correct content model + assert isinstance(message, StructuredMessage) + assert isinstance(message.content, AgentResponse) # type: ignore[reportUnknownMemberType] + assert message.content == expected + + # Check that the format_string was applied correctly + assert message.to_model_text() == "foo - bar" diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 6e04c2b8e..e9aee1ba7 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -1441,3 +1441,87 @@ async def test_declarative_groupchats_with_config(runtime: AgentRuntime | None) assert selector.dump_component().provider == "autogen_agentchat.teams.SelectorGroupChat" assert swarm.dump_component().provider == "autogen_agentchat.teams.Swarm" assert magentic.dump_component().provider == "autogen_agentchat.teams.MagenticOneGroupChat" + + +class _StructuredContent(BaseModel): + message: str + + +class _StructuredAgent(BaseChatAgent): + def __init__(self, name: str, description: str) -> None: + super().__init__(name, description) + self._message = _StructuredContent(message="Structured hello") + + @property + def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: + return (StructuredMessage[_StructuredContent],) + + async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response: + return Response( + chat_message=StructuredMessage[_StructuredContent]( + source=self.name, + content=self._message, + format_string="Structured says: {message}", + ) + ) + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + pass + + 
+@pytest.mark.asyncio +async def test_message_type_auto_registration(runtime: AgentRuntime | None) -> None: + agent1 = _StructuredAgent("structured", description="emits structured messages") + agent2 = _EchoAgent("echo", description="echoes input") + + team = RoundRobinGroupChat(participants=[agent1, agent2], max_turns=2, runtime=runtime) + + result = await team.run(task="Say something structured") + + assert len(result.messages) == 3 + assert isinstance(result.messages[0], TextMessage) + assert isinstance(result.messages[1], StructuredMessage) + assert isinstance(result.messages[2], TextMessage) + assert result.messages[1].to_text() == "Structured says: Structured hello" + + +@pytest.mark.asyncio +async def test_structured_message_state_roundtrip(runtime: AgentRuntime | None) -> None: + agent1 = _StructuredAgent("structured", description="sends structured") + agent2 = _EchoAgent("echo", description="echoes") + + team1 = RoundRobinGroupChat( + participants=[agent1, agent2], + termination_condition=MaxMessageTermination(2), + runtime=runtime, + ) + + await team1.run(task="Say something structured") + state1 = await team1.save_state() + + # Recreate team without needing custom_message_types + agent3 = _StructuredAgent("structured", description="sends structured") + agent4 = _EchoAgent("echo", description="echoes") + team2 = RoundRobinGroupChat( + participants=[agent3, agent4], + termination_condition=MaxMessageTermination(2), + runtime=runtime, + ) + + await team2.load_state(state1) + state2 = await team2.save_state() + + # Assert full state equality + assert state1 == state2 + + # Assert message thread content match + manager1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore + AgentId(f"{team1._group_chat_manager_name}_{team1._team_id}", team1._team_id), # pyright: ignore + RoundRobinGroupChatManager, + ) + manager2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore + 
AgentId(f"{team2._group_chat_manager_name}_{team2._team_id}", team2._team_id), # pyright: ignore + RoundRobinGroupChatManager, + ) + + assert manager1._message_thread == manager2._message_thread # pyright: ignore diff --git a/python/packages/autogen-agentchat/tests/test_messages.py b/python/packages/autogen-agentchat/tests/test_messages.py index 2129cd66c..17f53f27e 100644 --- a/python/packages/autogen-agentchat/tests/test_messages.py +++ b/python/packages/autogen-agentchat/tests/test_messages.py @@ -11,6 +11,7 @@ from autogen_agentchat.messages import ( MultiModalMessage, StopMessage, StructuredMessage, + StructuredMessageFactory, TextMessage, ToolCallExecutionEvent, ToolCallRequestEvent, @@ -52,6 +53,28 @@ def test_structured_message() -> None: assert dumped_message["type"] == "StructuredMessage[TestContent]" +def test_structured_message_component() -> None: + # Create a structured message factory from the test content model and a format string, then round-trip it through its component config + format_string = "this is a string {field1} and this is an int {field2}" + s_m = StructuredMessageFactory(input_model=TestContent, format_string=format_string) + config = s_m.dump_component() + s_m_dyn = StructuredMessageFactory.load_component(config) + message = s_m_dyn.StructuredMessage( + source="test_agent", content=s_m_dyn.ContentModel(field1="test", field2=42), format_string=s_m_dyn.format_string + ) + + assert isinstance(message.content, s_m_dyn.ContentModel) + assert not isinstance(message.content, TestContent) + assert message.content.field1 == "test" # type: ignore[attr-defined] + assert message.content.field2 == 42 # type: ignore[attr-defined] + + dumped_message = message.model_dump() + assert dumped_message["source"] == "test_agent" + assert dumped_message["content"]["field1"] == "test" + assert dumped_message["content"]["field2"] == 42 + assert message.to_model_text() == format_string.format(field1="test", field2=42) + + def test_message_factory() -> None: factory = 
MessageFactory() @@ -109,6 +132,22 @@ def test_message_factory() -> None: assert structured_message.content.field2 == 42 assert structured_message.type == "StructuredMessage[TestContent]" # type: ignore[comparison-overlap] + sm_factory = StructuredMessageFactory(input_model=TestContent, format_string=None, content_model_name="TestContent") + config = sm_factory.dump_component() + config.config["content_model_name"] = "DynamicTestContent" + sm_factory_dynamic = StructuredMessageFactory.load_component(config) + + factory.register(sm_factory_dynamic.StructuredMessage) + msg = sm_factory_dynamic.StructuredMessage( + content=sm_factory_dynamic.ContentModel(field1="static", field2=123), source="static_agent" + ) + restored = factory.create(msg.dump()) + assert isinstance(restored, StructuredMessage) + assert isinstance(restored.content, sm_factory_dynamic.ContentModel) # type: ignore[reportUnkownMemberType] + assert restored.source == "static_agent" + assert restored.content.field1 == "static" # type: ignore[attr-defined] + assert restored.content.field2 == 123 # type: ignore[attr-defined] + class TestContainer(BaseModel): chat_messages: List[ChatMessage] diff --git a/python/packages/autogen-core/src/autogen_core/utils/__init__.py b/python/packages/autogen-core/src/autogen_core/utils/__init__.py new file mode 100644 index 000000000..c5c0cfde5 --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/utils/__init__.py @@ -0,0 +1,3 @@ +from ._json_to_pydantic import schema_to_pydantic_model + +__all__ = ["schema_to_pydantic_model"] diff --git a/python/packages/autogen-core/src/autogen_core/utils/_json_to_pydantic.py b/python/packages/autogen-core/src/autogen_core/utils/_json_to_pydantic.py new file mode 100644 index 000000000..892a22a90 --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/utils/_json_to_pydantic.py @@ -0,0 +1,531 @@ +import datetime +from ipaddress import IPv4Address, IPv6Address +from typing import Annotated, Any, Dict, ForwardRef, 
List, Literal, Optional, Type, Union, cast + +from pydantic import ( + UUID1, + UUID3, + UUID4, + UUID5, + AnyUrl, + BaseModel, + EmailStr, + Field, + conbytes, + confloat, + conint, + conlist, + constr, + create_model, +) +from pydantic.fields import FieldInfo + + +class SchemaConversionError(Exception): + """Base class for schema conversion exceptions.""" + + pass + + +class ReferenceNotFoundError(SchemaConversionError): + """Raised when a $ref cannot be resolved.""" + + pass + + +class FormatNotSupportedError(SchemaConversionError): + """Raised when a format is not supported.""" + + pass + + +class UnsupportedKeywordError(SchemaConversionError): + """Raised when an unsupported JSON Schema keyword is encountered.""" + + pass + + +TYPE_MAPPING: Dict[str, Type[Any]] = { + "string": str, + "integer": int, + "boolean": bool, + "number": float, + "array": List, + "object": dict, + "null": type(None), +} + +FORMAT_MAPPING: Dict[str, Any] = { + "uuid": UUID4, + "uuid1": UUID1, + "uuid2": UUID4, + "uuid3": UUID3, + "uuid4": UUID4, + "uuid5": UUID5, + "email": EmailStr, + "uri": AnyUrl, + "hostname": constr(strict=True), + "ipv4": IPv4Address, + "ipv6": IPv6Address, + "ipv4-network": IPv4Address, + "ipv6-network": IPv6Address, + "date-time": datetime.datetime, + "date": datetime.date, + "time": datetime.time, + "duration": datetime.timedelta, + "int32": conint(strict=True, ge=-(2**31), le=2**31 - 1), + "int64": conint(strict=True, ge=-(2**63), le=2**63 - 1), + "float": confloat(strict=True), + "double": float, + "decimal": float, + "byte": conbytes(strict=True), + "binary": conbytes(strict=True), + "password": str, + "path": str, +} + + +def _make_field( + default: Any, + *, + title: Optional[str] = None, + description: Optional[str] = None, +) -> Any: + """Construct a Pydantic Field with proper typing.""" + field_kwargs: Dict[str, Any] = {} + if title is not None: + field_kwargs["title"] = title + if description is not None: + field_kwargs["description"] = description + 
return Field(default, **field_kwargs) + + +class _JSONSchemaToPydantic: + def __init__(self) -> None: + self._model_cache: Dict[str, Optional[Union[Type[BaseModel], ForwardRef]]] = {} + + def _resolve_ref(self, ref: str, schema: Dict[str, Any]) -> Dict[str, Any]: + ref_key = ref.split("/")[-1] + definitions = cast(dict[str, dict[str, Any]], schema.get("$defs", {})) + + if ref_key not in definitions: + raise ReferenceNotFoundError( + f"Reference `{ref}` not found in `$defs`. Available keys: {list(definitions.keys())}" + ) + + return definitions[ref_key] + + def get_ref(self, ref_name: str) -> Any: + if ref_name not in self._model_cache: + raise ReferenceNotFoundError( + f"Reference `{ref_name}` not found in cache. Available: {list(self._model_cache.keys())}" + ) + + if self._model_cache[ref_name] is None: + return ForwardRef(ref_name) + + return self._model_cache[ref_name] + + def _process_definitions(self, root_schema: Dict[str, Any]) -> None: + if "$defs" in root_schema: + for model_name in root_schema["$defs"]: + if model_name not in self._model_cache: + self._model_cache[model_name] = None + + for model_name, model_schema in root_schema["$defs"].items(): + if self._model_cache[model_name] is None: + self._model_cache[model_name] = self.json_schema_to_pydantic(model_schema, model_name, root_schema) + + def json_schema_to_pydantic( + self, schema: Dict[str, Any], model_name: str = "GeneratedModel", root_schema: Optional[Dict[str, Any]] = None + ) -> Type[BaseModel]: + if root_schema is None: + root_schema = schema + self._process_definitions(root_schema) + + if "$ref" in schema: + resolved = self._resolve_ref(schema["$ref"], root_schema) + schema = {**resolved, **{k: v for k, v in schema.items() if k != "$ref"}} + + if "allOf" in schema: + merged: Dict[str, Any] = {"type": "object", "properties": {}, "required": []} + for s in schema["allOf"]: + part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s + 
merged["properties"].update(part.get("properties", {})) + merged["required"].extend(part.get("required", [])) + for k, v in schema.items(): + if k not in {"allOf", "properties", "required"}: + merged[k] = v + merged["required"] = list(set(merged["required"])) + schema = merged + + return self._json_schema_to_model(schema, model_name, root_schema) + + def _resolve_union_types(self, schemas: List[Dict[str, Any]]) -> List[Any]: + types: List[Any] = [] + for s in schemas: + if "$ref" in s: + types.append(self.get_ref(s["$ref"].split("/")[-1])) + elif "enum" in s: + types.append(Literal[tuple(s["enum"])] if len(s["enum"]) > 0 else Any) + else: + json_type = s.get("type") + if json_type not in TYPE_MAPPING: + raise UnsupportedKeywordError(f"Unsupported or missing type `{json_type}` in union") + types.append(TYPE_MAPPING[json_type]) + return types + + def _extract_field_type(self, key: str, value: Dict[str, Any], model_name: str, root_schema: Dict[str, Any]) -> Any: + json_type = value.get("type") + if json_type not in TYPE_MAPPING: + raise UnsupportedKeywordError( + f"Unsupported or missing type `{json_type}` for field `{key}` in `{model_name}`" + ) + + base_type = TYPE_MAPPING[json_type] + constraints: Dict[str, Any] = {} + + if json_type == "string": + if "minLength" in value: + constraints["min_length"] = value["minLength"] + if "maxLength" in value: + constraints["max_length"] = value["maxLength"] + if "pattern" in value: + constraints["pattern"] = value["pattern"] + if constraints: + base_type = constr(**constraints) + + elif json_type == "integer": + if "minimum" in value: + constraints["ge"] = value["minimum"] + if "maximum" in value: + constraints["le"] = value["maximum"] + if "exclusiveMinimum" in value: + constraints["gt"] = value["exclusiveMinimum"] + if "exclusiveMaximum" in value: + constraints["lt"] = value["exclusiveMaximum"] + if constraints: + base_type = conint(**constraints) + + elif json_type == "number": + if "minimum" in value: + constraints["ge"] = 
value["minimum"] + if "maximum" in value: + constraints["le"] = value["maximum"] + if "exclusiveMinimum" in value: + constraints["gt"] = value["exclusiveMinimum"] + if "exclusiveMaximum" in value: + constraints["lt"] = value["exclusiveMaximum"] + if constraints: + base_type = confloat(**constraints) + + elif json_type == "array": + if "minItems" in value: + constraints["min_length"] = value["minItems"] + if "maxItems" in value: + constraints["max_length"] = value["maxItems"] + item_schema = value.get("items", {"type": "string"}) + if "$ref" in item_schema: + item_type = self.get_ref(item_schema["$ref"].split("/")[-1]) + else: + item_type_name = item_schema.get("type") + if item_type_name not in TYPE_MAPPING: + raise UnsupportedKeywordError( + f"Unsupported or missing item type `{item_type_name}` for array field `{key}` in `{model_name}`" + ) + item_type = TYPE_MAPPING[item_type_name] + + base_type = conlist(item_type, **constraints) if constraints else List[item_type] # type: ignore[valid-type] + + if "format" in value: + format_type = FORMAT_MAPPING.get(value["format"]) + if format_type is None: + raise FormatNotSupportedError(f"Unknown format `{value['format']}` for `{key}` in `{model_name}`") + if not isinstance(format_type, type): + return format_type + if not issubclass(format_type, str): + return format_type + return format_type + + return base_type + + def _json_schema_to_model( + self, schema: Dict[str, Any], model_name: str, root_schema: Dict[str, Any] + ) -> Type[BaseModel]: + if "allOf" in schema: + merged: Dict[str, Any] = {"type": "object", "properties": {}, "required": []} + for s in schema["allOf"]: + part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s + merged["properties"].update(part.get("properties", {})) + merged["required"].extend(part.get("required", [])) + for k, v in schema.items(): + if k not in {"allOf", "properties", "required"}: + merged[k] = v + merged["required"] = list(set(merged["required"])) + schema = merged + + 
fields: Dict[str, tuple[Any, FieldInfo]] = {} + required_fields = set(schema.get("required", [])) + + for key, value in schema.get("properties", {}).items(): + if "$ref" in value: + ref_name = value["$ref"].split("/")[-1] + field_type = self.get_ref(ref_name) + elif "anyOf" in value: + sub_models = self._resolve_union_types(value["anyOf"]) + field_type = Union[tuple(sub_models)] + elif "oneOf" in value: + sub_models = self._resolve_union_types(value["oneOf"]) + field_type = Union[tuple(sub_models)] + if "discriminator" in value: + discriminator = value["discriminator"]["propertyName"] + field_type = Annotated[field_type, Field(discriminator=discriminator)] + elif "enum" in value: + field_type = Literal[tuple(value["enum"])] + elif "allOf" in value: + merged = {"type": "object", "properties": {}, "required": []} + for s in value["allOf"]: + part = self._resolve_ref(s["$ref"], root_schema) if "$ref" in s else s + merged["properties"].update(part.get("properties", {})) + merged["required"].extend(part.get("required", [])) + for k, v in value.items(): + if k not in {"allOf", "properties", "required"}: + merged[k] = v + merged["required"] = list(set(merged["required"])) + field_type = self._json_schema_to_model(merged, f"{model_name}_{key}", root_schema) + elif value.get("type") == "object" and "properties" in value: + field_type = self._json_schema_to_model(value, f"{model_name}_{key}", root_schema) + else: + field_type = self._extract_field_type(key, value, model_name, root_schema) + + if field_type is None: + raise UnsupportedKeywordError(f"Unsupported or missing type for field `{key}` in `{model_name}`") + + default_value = value.get("default") + is_required = key in required_fields + + if not is_required and default_value is None: + field_type = Optional[field_type] + + field_args = { + "default": default_value if not is_required else ..., + } + if "title" in value: + field_args["title"] = value["title"] + if "description" in value: + field_args["description"] = 
value["description"] + + fields[key] = ( + field_type, + _make_field( + default_value if not is_required else ..., + title=value.get("title"), + description=value.get("description"), + ), + ) + + model: Type[BaseModel] = create_model(model_name, **cast(dict[str, Any], fields)) + model.model_rebuild() + return model + + +def schema_to_pydantic_model(schema: Dict[str, Any], model_name: str = "GeneratedModel") -> Type[BaseModel]: + """ + Convert a JSON Schema dictionary to a fully-typed Pydantic model. + + This function handles schema translation and validation logic to produce + a Pydantic model. + + **Supported JSON Schema Features** + + - **Primitive types**: `string`, `integer`, `number`, `boolean`, `object`, `array`, `null` + - **String formats**: + - `email`, `uri`, `uuid`, `uuid1`, `uuid3`, `uuid4`, `uuid5` + - `hostname`, `ipv4`, `ipv6`, `ipv4-network`, `ipv6-network` + - `date`, `time`, `date-time`, `duration` + - `byte`, `binary`, `password`, `path` + - **String constraints**: + - `minLength`, `maxLength`, `pattern` + - **Numeric constraints**: + - `minimum`, `maximum`, `exclusiveMinimum`, `exclusiveMaximum` + - **Array constraints**: + - `minItems`, `maxItems`, `items` + - **Object schema support**: + - `properties`, `required`, `title`, `description`, `default` + - **Enums**: + - Converted to Python `Literal` type + - **Union types**: + - `anyOf`, `oneOf` supported with optional `discriminator` + - **Inheritance and composition**: + - `allOf` merges multiple schemas into one model + - **$ref and $defs resolution**: + - Supports references to sibling definitions and self-referencing schemas + + .. 
code-block:: python + + from autogen_core.utils import schema_to_pydantic_model + + # Example 1: Simple user model + schema = { + "title": "User", + "type": "object", + "properties": { + "name": {"type": "string"}, + "email": {"type": "string", "format": "email"}, + "age": {"type": "integer", "minimum": 0}, + }, + "required": ["name", "email"], + } + + UserModel = schema_to_pydantic_model(schema) + user = UserModel(name="Alice", email="alice@example.com", age=30) + + .. code-block:: python + + # Example 2: Nested model + schema = { + "title": "BlogPost", + "type": "object", + "properties": { + "title": {"type": "string"}, + "tags": {"type": "array", "items": {"type": "string"}}, + "author": { + "type": "object", + "properties": {"name": {"type": "string"}, "email": {"type": "string", "format": "email"}}, + "required": ["name"], + }, + }, + "required": ["title", "author"], + } + + BlogPost = schema_to_pydantic_model(schema) + + .. code-block:: python + + # Example 3: allOf merging with $refs + schema = { + "title": "EmployeeWithDepartment", + "allOf": [{"$ref": "#/$defs/Employee"}, {"$ref": "#/$defs/Department"}], + "$defs": { + "Employee": { + "type": "object", + "properties": {"id": {"type": "string"}, "name": {"type": "string"}}, + "required": ["id", "name"], + }, + "Department": { + "type": "object", + "properties": {"department": {"type": "string"}}, + "required": ["department"], + }, + }, + } + + Model = schema_to_pydantic_model(schema) + + ..
code-block:: python + + # Example 4: Self-referencing (recursive) model + schema = { + "title": "Category", + "type": "object", + "properties": { + "name": {"type": "string"}, + "subcategories": {"type": "array", "items": {"$ref": "#/$defs/Category"}}, + }, + "required": ["name"], + "$defs": { + "Category": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "subcategories": {"type": "array", "items": {"$ref": "#/$defs/Category"}}, + }, + "required": ["name"], + } + }, + } + + Category = schema_to_pydantic_model(schema) + + .. code-block:: python + + # Example 5: Serializing and deserializing with Pydantic + + from uuid import uuid4 + from pydantic import BaseModel, EmailStr, Field + from typing import Optional, List, Dict, Any + from autogen_core.utils import schema_to_pydantic_model + + + class Address(BaseModel): + street: str + city: str + zipcode: str + + + class User(BaseModel): + id: str + name: str + email: EmailStr + age: int = Field(..., ge=18) + address: Address + + + class Employee(BaseModel): + id: str + name: str + manager: Optional["Employee"] = None + + + class Department(BaseModel): + name: str + employees: List[Employee] + + + class ComplexModel(BaseModel): + user: User + extra_info: Optional[Dict[str, Any]] = None + sub_items: List[Employee] + + + # Convert ComplexModel to JSON schema + complex_schema = ComplexModel.model_json_schema() + + # Rebuild a new Pydantic model from JSON schema + ReconstructedModel = schema_to_pydantic_model(complex_schema, "ComplexModel") + + # Instantiate reconstructed model + reconstructed = ReconstructedModel( + user={ + "id": str(uuid4()), + "name": "Alice", + "email": "alice@example.com", + "age": 30, + "address": {"street": "123 Main St", "city": "Wonderland", "zipcode": "12345"}, + }, + sub_items=[{"id": str(uuid4()), "name": "Bob", "manager": {"id": str(uuid4()), "name": "Eve"}}], + ) + + print(reconstructed.model_dump()) + + + Args: + schema (Dict[str, Any]): A valid JSON Schema dictionary. 
+ model_name (str, optional): The name of the root model. Defaults to "GeneratedModel". + + Returns: + Type[BaseModel]: A dynamically generated Pydantic model class. + + Raises: + ReferenceNotFoundError: If a `$ref` key references a missing entry. + FormatNotSupportedError: If a `format` keyword is unknown or unsupported. + UnsupportedKeywordError: If the schema contains an unsupported `type`. + + See Also: + - :class:`pydantic.BaseModel` + - :func:`pydantic.create_model` + - https://json-schema.org/ + """ + ... + + return _JSONSchemaToPydantic().json_schema_to_pydantic(schema, model_name) diff --git a/python/packages/autogen-core/tests/test_json_to_pydantic.py b/python/packages/autogen-core/tests/test_json_to_pydantic.py new file mode 100644 index 000000000..990ec6339 --- /dev/null +++ b/python/packages/autogen-core/tests/test_json_to_pydantic.py @@ -0,0 +1,757 @@ +import types +from typing import Any, Dict, List, Literal, Optional, Type, get_args, get_origin +from uuid import UUID, uuid4 + +import pytest +from autogen_core.utils._json_to_pydantic import ( + FORMAT_MAPPING, + TYPE_MAPPING, + FormatNotSupportedError, + ReferenceNotFoundError, + UnsupportedKeywordError, + _JSONSchemaToPydantic, # pyright: ignore[reportPrivateUsage] +) +from pydantic import BaseModel, EmailStr, Field, ValidationError + + +# ✅ Define Pydantic models for testing +class Address(BaseModel): + street: str + city: str + zipcode: str + + +class User(BaseModel): + id: UUID + name: str + email: EmailStr + age: int = Field(..., ge=18) # Minimum age = 18 + address: Address + + +class Employee(BaseModel): + id: UUID + name: str + manager: Optional["Employee"] = None # Recursive self-reference + + +class Department(BaseModel): + name: str + employees: List[Employee] # Array of objects + + +class ComplexModel(BaseModel): + user: User + extra_info: Optional[Dict[str, Any]] = None # Optional dictionary + sub_items: List[Employee] # List of Employees + + +@pytest.fixture +def converter() -> 
_JSONSchemaToPydantic: + """Fixture to create a fresh instance of JSONSchemaToPydantic for every test.""" + return _JSONSchemaToPydantic() + + +@pytest.fixture +def sample_json_schema() -> Dict[str, Any]: + """Fixture that returns a JSON schema dynamically using model_json_schema().""" + return User.model_json_schema() + + +@pytest.fixture +def sample_json_schema_recursive() -> Dict[str, Any]: + """Fixture that returns a self-referencing JSON schema.""" + return Employee.model_json_schema() + + +@pytest.fixture +def sample_json_schema_nested() -> Dict[str, Any]: + """Fixture that returns a nested schema with arrays of objects.""" + return Department.model_json_schema() + + +@pytest.fixture +def sample_json_schema_complex() -> Dict[str, Any]: + """Fixture that returns a complex schema with multiple structures.""" + return ComplexModel.model_json_schema() + + +@pytest.mark.parametrize( + "schema_fixture, model_name, expected_fields", + [ + (sample_json_schema, "User", ["id", "name", "email", "age", "address"]), + (sample_json_schema_recursive, "Employee", ["id", "name", "manager"]), + (sample_json_schema_nested, "Department", ["name", "employees"]), + (sample_json_schema_complex, "ComplexModel", ["user", "extra_info", "sub_items"]), + ], +) +def test_json_schema_to_pydantic( + converter: _JSONSchemaToPydantic, + schema_fixture: Any, + model_name: str, + expected_fields: List[str], + request: Any, +) -> None: + """Test conversion of JSON Schema to Pydantic model using the class instance.""" + schema = request.getfixturevalue(schema_fixture.__name__) + Model = converter.json_schema_to_pydantic(schema, model_name) + + for field in expected_fields: + assert field in Model.__annotations__, f"Expected '{field}' missing in {model_name}Model" + + +# ✅ **Valid Data Tests** +@pytest.mark.parametrize( + "schema_fixture, model_name, valid_data", + [ + ( + sample_json_schema, + "User", + { + "id": str(uuid4()), + "name": "Alice", + "email": "alice@example.com", + "age": 25, + 
"address": {"street": "123 Main St", "city": "Metropolis", "zipcode": "12345"}, + }, + ), + ( + sample_json_schema_recursive, + "Employee", + { + "id": str(uuid4()), + "name": "Alice", + "manager": { + "id": str(uuid4()), + "name": "Bob", + }, + }, + ), + ( + sample_json_schema_nested, + "Department", + { + "name": "Engineering", + "employees": [ + { + "id": str(uuid4()), + "name": "Alice", + "manager": { + "id": str(uuid4()), + "name": "Bob", + }, + } + ], + }, + ), + ( + sample_json_schema_complex, + "ComplexModel", + { + "user": { + "id": str(uuid4()), + "name": "Charlie", + "email": "charlie@example.com", + "age": 30, + "address": {"street": "456 Side St", "city": "Gotham", "zipcode": "67890"}, + }, + "extra_info": {"hobby": "Chess", "level": "Advanced"}, + "sub_items": [ + {"id": str(uuid4()), "name": "Eve"}, + {"id": str(uuid4()), "name": "David", "manager": {"id": str(uuid4()), "name": "Frank"}}, + ], + }, + ), + ], +) +def test_valid_data_model( + converter: _JSONSchemaToPydantic, + schema_fixture: Any, + model_name: str, + valid_data: Dict[str, Any], + request: Any, +) -> None: + """Test that valid data is accepted by the generated model.""" + schema = request.getfixturevalue(schema_fixture.__name__) + Model = converter.json_schema_to_pydantic(schema, model_name) + + instance = Model(**valid_data) + assert instance + dumped = instance.model_dump(mode="json", exclude_none=True) + assert dumped == valid_data, f"Model output mismatch.\nExpected: {valid_data}\nGot: {dumped}" + + +# ✅ **Invalid Data Tests** +@pytest.mark.parametrize( + "schema_fixture, model_name, invalid_data", + [ + ( + sample_json_schema, + "User", + { + "id": "not-a-uuid", # Invalid UUID + "name": "Alice", + "email": "not-an-email", # Invalid email + "age": 17, # Below minimum + "address": {"street": "123 Main St", "city": "Metropolis"}, + }, + ), + ( + sample_json_schema_recursive, + "Employee", + { + "id": str(uuid4()), + "name": "Alice", + "manager": { + "id": "not-a-uuid", # Invalid 
UUID + "name": "Bob", + }, + }, + ), + ( + sample_json_schema_nested, + "Department", + { + "name": "Engineering", + "employees": [ + { + "id": "not-a-uuid", # Invalid UUID + "name": "Alice", + "manager": { + "id": str(uuid4()), + "name": "Bob", + }, + } + ], + }, + ), + ( + sample_json_schema_complex, + "ComplexModel", + { + "user": { + "id": str(uuid4()), + "name": "Charlie", + "email": "charlie@example.com", + "age": "thirty", # Invalid: Should be an int + "address": {"street": "456 Side St", "city": "Gotham", "zipcode": "67890"}, + }, + "extra_info": "should-be-dictionary", # Invalid type + "sub_items": [ + {"id": "invalid-uuid", "name": "Eve"}, # Invalid UUID + {"id": str(uuid4()), "name": 123}, # Invalid name type + ], + }, + ), + ], +) +def test_invalid_data_model( + converter: _JSONSchemaToPydantic, + schema_fixture: Any, + model_name: str, + invalid_data: Dict[str, Any], + request: Any, +) -> None: + """Test that invalid data raises ValidationError.""" + schema = request.getfixturevalue(schema_fixture.__name__) + Model = converter.json_schema_to_pydantic(schema, model_name) + + with pytest.raises(ValidationError): + Model(**invalid_data) + + +class ListDictModel(BaseModel): + """Example for `List[Dict[str, Any]]`""" + + data: List[Dict[str, Any]] + + +class DictListModel(BaseModel): + """Example for `Dict[str, List[Any]]`""" + + mapping: Dict[str, List[Any]] + + +class NestedListModel(BaseModel): + """Example for `List[List[str]]`""" + + matrix: List[List[str]] + + +@pytest.fixture +def sample_json_schema_list_dict() -> Dict[str, Any]: + """Fixture for `List[Dict[str, Any]]`""" + return ListDictModel.model_json_schema() + + +@pytest.fixture +def sample_json_schema_dict_list() -> Dict[str, Any]: + """Fixture for `Dict[str, List[Any]]`""" + return DictListModel.model_json_schema() + + +@pytest.fixture +def sample_json_schema_nested_list() -> Dict[str, Any]: + """Fixture for `List[List[str]]`""" + return NestedListModel.model_json_schema() + + 
+@pytest.mark.parametrize( + "schema_fixture, model_name, expected_fields", + [ + (sample_json_schema_list_dict, "ListDictModel", ["data"]), + (sample_json_schema_dict_list, "DictListModel", ["mapping"]), + (sample_json_schema_nested_list, "NestedListModel", ["matrix"]), + ], +) +def test_json_schema_to_pydantic_nested( + converter: _JSONSchemaToPydantic, + schema_fixture: Any, + model_name: str, + expected_fields: list[str], + request: Any, +) -> None: + """Test conversion of JSON Schema to Pydantic model using the class instance.""" + schema = request.getfixturevalue(schema_fixture.__name__) + Model = converter.json_schema_to_pydantic(schema, model_name) + + for field in expected_fields: + assert field in Model.__annotations__, f"Expected '{field}' missing in {model_name}Model" + + +# ✅ **Valid Data Tests** +@pytest.mark.parametrize( + "schema_fixture, model_name, valid_data", + [ + ( + sample_json_schema_list_dict, + "ListDictModel", + { + "data": [ + {"key1": "value1", "key2": 10}, + {"another_key": False, "nested": {"subkey": "data"}}, + ] + }, + ), + ( + sample_json_schema_dict_list, + "DictListModel", + { + "mapping": { + "first": ["a", "b", "c"], + "second": [1, 2, 3, 4], + "third": [True, False, True], + } + }, + ), + ( + sample_json_schema_nested_list, + "NestedListModel", + {"matrix": [["A", "B"], ["C", "D"], ["E", "F"]]}, + ), + ], +) +def test_valid_data_model_nested( + converter: _JSONSchemaToPydantic, + schema_fixture: Any, + model_name: str, + valid_data: Dict[str, Any], + request: Any, +) -> None: + """Test that valid data is accepted by the generated model.""" + schema = request.getfixturevalue(schema_fixture.__name__) + Model = converter.json_schema_to_pydantic(schema, model_name) + + instance = Model(**valid_data) + assert instance + for field, value in valid_data.items(): + assert ( + getattr(instance, field) == value + ), f"Mismatch in field `{field}`: expected `{value}`, got `{getattr(instance, field)}`" + + +# ✅ **Invalid Data Tests** 
+@pytest.mark.parametrize( + "schema_fixture, model_name, invalid_data", + [ + ( + sample_json_schema_list_dict, + "ListDictModel", + { + "data": "should-be-a-list", # ❌ Should be a list of dicts + }, + ), + ( + sample_json_schema_dict_list, + "DictListModel", + { + "mapping": [ + "should-be-a-dictionary", # ❌ Should be a dict of lists + ] + }, + ), + ( + sample_json_schema_nested_list, + "NestedListModel", + {"matrix": [["A", "B"], "C", ["D", "E"]]}, # ❌ "C" is not a list + ), + ], +) +def test_invalid_data_model_nested( + converter: _JSONSchemaToPydantic, + schema_fixture: Any, + model_name: str, + invalid_data: Dict[str, Any], + request: Any, +) -> None: + """Test that invalid data raises ValidationError.""" + schema = request.getfixturevalue(schema_fixture.__name__) + Model = converter.json_schema_to_pydantic(schema, model_name) + + with pytest.raises(ValidationError): + Model(**invalid_data) + + +def test_reference_not_found(converter: _JSONSchemaToPydantic) -> None: + schema = {"type": "object", "properties": {"manager": {"$ref": "#/$defs/MissingRef"}}} + with pytest.raises(ReferenceNotFoundError): + converter.json_schema_to_pydantic(schema, "MissingRefModel") + + +def test_format_not_supported(converter: _JSONSchemaToPydantic) -> None: + schema = {"type": "object", "properties": {"custom_field": {"type": "string", "format": "unsupported-format"}}} + with pytest.raises(FormatNotSupportedError): + converter.json_schema_to_pydantic(schema, "UnsupportedFormatModel") + + +def test_unsupported_keyword(converter: _JSONSchemaToPydantic) -> None: + schema = {"type": "object", "properties": {"broken_field": {"title": "Missing type"}}} + with pytest.raises(UnsupportedKeywordError): + converter.json_schema_to_pydantic(schema, "MissingTypeModel") + + +def test_enum_field_schema() -> None: + schema = { + "type": "object", + "properties": { + "status": {"type": "string", "enum": ["pending", "approved", "rejected"]}, + "priority": {"type": "integer", "enum": [1, 2, 3]}, + 
}, + "required": ["status"], + } + + converter: _JSONSchemaToPydantic = _JSONSchemaToPydantic() + Model = converter.json_schema_to_pydantic(schema, "Task") + + status_ann = Model.model_fields["status"].annotation + assert get_origin(status_ann) is Literal + assert set(get_args(status_ann)) == {"pending", "approved", "rejected"} + + priority_ann = Model.model_fields["priority"].annotation + args = get_args(priority_ann) + assert type(None) in args + assert Literal[1, 2, 3] in args + + instance = Model(status="approved", priority=2) + assert instance.status == "approved" # type: ignore[attr-defined] + assert instance.priority == 2 # type: ignore[attr-defined] + + +def test_metadata_title_description(converter: _JSONSchemaToPydantic) -> None: + schema = { + "title": "CustomerProfile", + "description": "A profile containing personal and contact info", + "type": "object", + "properties": { + "first_name": {"type": "string", "title": "First Name", "description": "Given name of the user"}, + "age": {"type": "integer", "title": "Age", "description": "Age in years"}, + "contact": { + "type": "object", + "title": "Contact Information", + "description": "How to reach the user", + "properties": { + "email": { + "type": "string", + "format": "email", + "title": "Email Address", + "description": "Primary email", + } + }, + }, + }, + "required": ["first_name"], + } + + Model: Type[BaseModel] = converter.json_schema_to_pydantic(schema, "CustomerProfile") + generated_schema = Model.model_json_schema() + + assert generated_schema["title"] == "CustomerProfile" + + props = generated_schema["properties"] + assert props["first_name"]["title"] == "First Name" + assert props["first_name"]["description"] == "Given name of the user" + assert props["age"]["title"] == "Age" + assert props["age"]["description"] == "Age in years" + + contact = props["contact"] + assert contact["title"] == "Contact Information" + assert contact["description"] == "How to reach the user" + + # Follow the $ref + 
ref_key = contact["anyOf"][0]["$ref"].split("/")[-1] + contact_def = generated_schema["$defs"][ref_key] + email = contact_def["properties"]["email"] + assert email["title"] == "Email Address" + assert email["description"] == "Primary email" + + +def test_oneof_with_discriminator(converter: _JSONSchemaToPydantic) -> None: + schema = { + "title": "PetWrapper", + "type": "object", + "properties": { + "pet": { + "oneOf": [{"$ref": "#/$defs/Cat"}, {"$ref": "#/$defs/Dog"}], + "discriminator": {"propertyName": "pet_type"}, + } + }, + "required": ["pet"], + "$defs": { + "Cat": { + "type": "object", + "properties": {"pet_type": {"type": "string", "enum": ["cat"]}, "hunting_skill": {"type": "string"}}, + "required": ["pet_type", "hunting_skill"], + "title": "Cat", + }, + "Dog": { + "type": "object", + "properties": {"pet_type": {"type": "string", "enum": ["dog"]}, "pack_size": {"type": "integer"}}, + "required": ["pet_type", "pack_size"], + "title": "Dog", + }, + }, + } + + Model = converter.json_schema_to_pydantic(schema, "PetWrapper") + + # Instantiate with a Cat + cat = Model(pet={"pet_type": "cat", "hunting_skill": "expert"}) + assert cat.pet.pet_type == "cat" # type: ignore[attr-defined] + + # Instantiate with a Dog + dog = Model(pet={"pet_type": "dog", "pack_size": 4}) + assert dog.pet.pet_type == "dog" # type: ignore[attr-defined] + + # Check round-trip schema includes discriminator + model_schema = Model.model_json_schema() + assert "discriminator" in model_schema["properties"]["pet"] + assert model_schema["properties"]["pet"]["discriminator"]["propertyName"] == "pet_type" + + +def test_allof_merging_with_refs(converter: _JSONSchemaToPydantic) -> None: + schema = { + "title": "EmployeeWithDepartment", + "allOf": [{"$ref": "#/$defs/Employee"}, {"$ref": "#/$defs/Department"}], + "$defs": { + "Employee": { + "type": "object", + "properties": {"id": {"type": "string"}, "name": {"type": "string"}}, + "required": ["id", "name"], + "title": "Employee", + }, + "Department": 
{ + "type": "object", + "properties": {"department": {"type": "string"}}, + "required": ["department"], + "title": "Department", + }, + }, + } + + Model = converter.json_schema_to_pydantic(schema, "EmployeeWithDepartment") + instance = Model(id="123", name="Alice", department="Engineering") + assert instance.id == "123" # type: ignore[attr-defined] + assert instance.name == "Alice" # type: ignore[attr-defined] + assert instance.department == "Engineering" # type: ignore[attr-defined] + + dumped = instance.model_dump() + assert dumped == {"id": "123", "name": "Alice", "department": "Engineering"} + + +def test_nested_allof_merging(converter: _JSONSchemaToPydantic) -> None: + schema = { + "title": "ContainerModel", + "type": "object", + "properties": { + "nested": { + "type": "object", + "properties": { + "data": { + "allOf": [ + {"$ref": "#/$defs/Base"}, + {"type": "object", "properties": {"extra": {"type": "string"}}, "required": ["extra"]}, + ] + } + }, + "required": ["data"], + } + }, + "required": ["nested"], + "$defs": { + "Base": { + "type": "object", + "properties": {"base_field": {"type": "string"}}, + "required": ["base_field"], + "title": "Base", + } + }, + } + + Model = converter.json_schema_to_pydantic(schema, "ContainerModel") + instance = Model(nested={"data": {"base_field": "abc", "extra": "xyz"}}) + + assert instance.nested.data.base_field == "abc" # type: ignore[attr-defined] + assert instance.nested.data.extra == "xyz" # type: ignore[attr-defined] + + +@pytest.mark.parametrize( + "schema, field_name, valid_values, invalid_values", + [ + # String constraints + ( + { + "type": "object", + "properties": { + "username": {"type": "string", "minLength": 3, "maxLength": 10, "pattern": "^[a-zA-Z0-9_]+$"} + }, + "required": ["username"], + }, + "username", + ["user_123", "abc", "Name2023"], + ["", "ab", "toolongusername123", "invalid!char"], + ), + # Integer constraints + ( + { + "type": "object", + "properties": {"age": {"type": "integer", "minimum": 18, 
"maximum": 99}}, + "required": ["age"], + }, + "age", + [18, 25, 99], + [17, 100, -1], + ), + # Float constraints + ( + { + "type": "object", + "properties": {"score": {"type": "number", "minimum": 0.0, "exclusiveMaximum": 1.0}}, + "required": ["score"], + }, + "score", + [0.0, 0.5, 0.999], + [-0.1, 1.0, 2.5], + ), + # Array constraints + ( + { + "type": "object", + "properties": {"tags": {"type": "array", "items": {"type": "string"}, "minItems": 1, "maxItems": 3}}, + "required": ["tags"], + }, + "tags", + [["a"], ["a", "b"], ["x", "y", "z"]], + [[], ["one", "two", "three", "four"]], + ), + ], +) +def test_field_constraints( + schema: Dict[str, Any], + field_name: str, + valid_values: List[Any], + invalid_values: List[Any], +) -> None: + converter = _JSONSchemaToPydantic() + Model = converter.json_schema_to_pydantic(schema, "ConstraintModel") + + for value in valid_values: + instance = Model(**{field_name: value}) + assert getattr(instance, field_name) == value + + for value in invalid_values: + with pytest.raises(ValidationError): + Model(**{field_name: value}) + + +@pytest.mark.parametrize( + "schema", + [ + # Top-level field + {"type": "object", "properties": {"weird": {"type": "abc"}}, "required": ["weird"]}, + # Inside array items + {"type": "object", "properties": {"items": {"type": "array", "items": {"type": "abc"}}}, "required": ["items"]}, + # Inside anyOf + { + "type": "object", + "properties": {"choice": {"anyOf": [{"type": "string"}, {"type": "abc"}]}}, + "required": ["choice"], + }, + ], +) +def test_unknown_type_raises(schema: Dict[str, Any]) -> None: + converter = _JSONSchemaToPydantic() + with pytest.raises(UnsupportedKeywordError): + converter.json_schema_to_pydantic(schema, "UnknownTypeModel") + + +@pytest.mark.parametrize("json_type, expected_type", list(TYPE_MAPPING.items())) +def test_basic_type_mapping(json_type: str, expected_type: type) -> None: + schema = { + "type": "object", + "properties": {"field": {"type": json_type}}, + "required": 
["field"], + } + converter = _JSONSchemaToPydantic() + Model = converter.json_schema_to_pydantic(schema, f"{json_type.capitalize()}Model") + + assert "field" in Model.__annotations__ + field_type = Model.__annotations__["field"] + + # For array/object/null we check the outer type only + if json_type == "null": + assert field_type is type(None) + elif json_type == "array": + assert getattr(field_type, "__origin__", None) is list + elif json_type == "object": + assert field_type in (dict, Dict) or getattr(field_type, "__origin__", None) in (dict, Dict) + + else: + assert field_type == expected_type + + +@pytest.mark.parametrize("format_name, expected_type", list(FORMAT_MAPPING.items())) +def test_format_mapping(format_name: str, expected_type: Any) -> None: + schema = { + "type": "object", + "properties": {"field": {"type": "string", "format": format_name}}, + "required": ["field"], + } + converter = _JSONSchemaToPydantic() + Model = converter.json_schema_to_pydantic(schema, f"{format_name.capitalize()}Model") + + assert "field" in Model.__annotations__ + field_type = Model.__annotations__["field"] + if isinstance(expected_type, types.FunctionType): # if it's a constrained constructor (e.g., conint) + assert callable(field_type) + else: + assert field_type == expected_type + + +def test_unknown_format_raises() -> None: + schema = { + "type": "object", + "properties": {"bad_field": {"type": "string", "format": "definitely-not-a-format"}}, + } + converter = _JSONSchemaToPydantic() + with pytest.raises(FormatNotSupportedError): + converter.json_schema_to_pydantic(schema, "UnknownFormatModel")