# dify/api/core/schemas/resolver.py (Python, 398 lines, 12 KiB)

import copy
import logging
import re
import threading
from collections import deque
from dataclasses import dataclass
from typing import Any, Union

from core.schemas.registry import SchemaRegistry

logger = logging.getLogger(__name__)
# Type aliases for better clarity
# SchemaType: any value that can appear in a schema tree (containers or scalars).
SchemaType = Union[dict[str, Any], list[Any], str, int, float, bool, None]
# SchemaDict: a schema object, i.e. a mapping with string keys.
SchemaDict = dict[str, Any]
# Pre-compiled pattern for better performance.
# Matches URIs of the form https://dify.ai/schemas/v<digits>/<name>.json;
# group 1 captures the version segment ("v" + digits) and group 2 captures
# everything between the version segment and the ".json" suffix.
_DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$")
class SchemaResolutionError(Exception):
    """Common base class for every error raised while resolving schemas.

    Catch this type to handle any schema-resolution failure uniformly.
    """
class CircularReferenceError(SchemaResolutionError):
    """Signals that a schema reference leads back to itself.

    Attributes:
        ref_uri: The reference URI that closed the cycle.
        ref_path: The chain of reference URIs that led to ``ref_uri``.
    """

    def __init__(self, ref_uri: str, ref_path: list[str]):
        self.ref_uri = ref_uri
        self.ref_path = ref_path
        # Render the chain as "a -> b -> c" for the human-readable message.
        chain = " -> ".join(ref_path)
        message = f"Circular reference detected: {ref_uri} in path {chain}"
        super().__init__(message)
class MaxDepthExceededError(SchemaResolutionError):
    """Signals that resolution went deeper than the configured limit.

    Attributes:
        max_depth: The depth limit that was reached.
    """

    def __init__(self, max_depth: int):
        self.max_depth = max_depth
        message = f"Maximum resolution depth ({max_depth}) exceeded"
        super().__init__(message)
class SchemaNotFoundError(SchemaResolutionError):
    """Signals that a referenced schema could not be located.

    Attributes:
        ref_uri: The reference URI that could not be found.
    """

    def __init__(self, ref_uri: str):
        self.ref_uri = ref_uri
        message = f"Schema not found: {ref_uri}"
        super().__init__(message)
@dataclass
class QueueItem:
    """One unit of work in the breadth-first traversal of a schema."""

    # The node (dict, list, or any other value) to be processed.
    current: Any
    # The container holding ``current``; ``None`` for the root node.
    parent: Any | None
    # Dict key or list index of ``current`` inside ``parent``; ``None`` for the root.
    key: str | int | None
    # Nesting level of ``current`` counted from the root (root is 0).
    depth: int
    # Reference URIs already expanded on the way down to ``current``.
    ref_path: set[str]
class SchemaResolver:
    """Resolver for Dify schema references with caching and optimizations

    The resolver walks a schema breadth-first and replaces every dict whose
    ``$ref`` is a Dify schema URI with the schema fetched from the registry.
    Fetched schemas are kept in a class-level cache, so the cache is shared
    by all resolver instances until ``clear_cache`` is called.
    """

    # Class-level cache of cleaned schemas keyed by reference URI. Entries are
    # never handed out directly: callers always receive a deep copy, so the
    # in-place edits made during resolution cannot leak into the cache.
    _cache: dict[str, SchemaDict] = {}
    # Guards every read and write of ``_cache``.
    _cache_lock = threading.Lock()

    def __init__(self, registry: SchemaRegistry | None = None, max_depth: int = 10):
        """
        Initialize the schema resolver

        Args:
            registry: Schema registry to use (defaults to default registry)
            max_depth: Maximum depth for reference resolution
        """
        self.registry = registry or SchemaRegistry.default_registry()
        self.max_depth = max_depth

    @classmethod
    def clear_cache(cls) -> None:
        """Clear the global schema cache"""
        with cls._cache_lock:
            cls._cache.clear()

    def resolve(self, schema: SchemaType) -> SchemaType:
        """
        Resolve all Dify $ref references in the schema

        Performance optimization: quickly checks for $ref presence before processing.

        Args:
            schema: Schema to resolve

        Returns:
            The original object, unchanged, when it is not a dict/list or
            contains no Dify references; otherwise a deep copy of the schema
            with the references expanded. A reference that points back to a
            URI already being expanded is left in place and flagged with
            ``"$circular_ref": True``; a reference whose schema cannot be
            found is left untouched. Both cases are logged as warnings.

        Raises:
            MaxDepthExceededError: If nesting reaches ``max_depth``
        """
        if not isinstance(schema, (dict, list)):
            return schema

        # Fast path: if no Dify refs found, return original schema unchanged
        # This avoids expensive deepcopy and BFS traversal for schemas without refs
        if not _has_dify_refs(schema):
            return schema

        # Slow path: work on a private copy so the caller's schema is never mutated
        result = copy.deepcopy(schema)

        # Breadth-first traversal; processing an item may enqueue more items
        queue = deque([QueueItem(current=result, parent=None, key=None, depth=0, ref_path=set())])
        while queue:
            self._process_queue_item(queue, queue.popleft())

        return result

    def _process_queue_item(self, queue: deque, item: QueueItem) -> None:
        """Dispatch a queue item to the dict or list handler; other values are ignored"""
        if isinstance(item.current, dict):
            self._process_dict(queue, item)
        elif isinstance(item.current, list):
            self._process_list(queue, item)

    def _process_dict(self, queue: deque, item: QueueItem) -> None:
        """Resolve the dict if it is a Dify reference, otherwise queue its nested containers"""
        ref_uri = item.current.get("$ref")
        if ref_uri and _is_dify_schema_ref(ref_uri):
            self._resolve_ref(queue, item, ref_uri)
        else:
            self._enqueue_children(queue, item, item.current.items())

    def _process_list(self, queue: deque, item: QueueItem) -> None:
        """Queue every nested container found in the list"""
        self._enqueue_children(queue, item, enumerate(item.current))

    def _enqueue_children(self, queue: deque, item: QueueItem, entries) -> None:
        """
        Queue each dict/list child of ``item.current`` one level deeper

        Args:
            queue: BFS queue to append to
            item: Queue item whose children are being visited
            entries: Iterable of ``(key_or_index, value)`` pairs of ``item.current``

        Raises:
            MaxDepthExceededError: If a nested container would reach ``max_depth``
        """
        next_depth = item.depth + 1
        for key, value in entries:
            if not isinstance(value, (dict, list)):
                continue
            # Only containers count towards depth; scalars never trigger the limit
            if next_depth >= self.max_depth:
                raise MaxDepthExceededError(self.max_depth)
            queue.append(
                QueueItem(current=value, parent=item.current, key=key, depth=next_depth, ref_path=item.ref_path)
            )

    def _resolve_ref(self, queue: deque, item: QueueItem, ref_uri: str) -> None:
        """
        Replace the ``$ref`` dict held by ``item`` with the referenced schema

        The replacement is queued again so that references nested inside the
        resolved schema are expanded as well.

        Raises:
            MaxDepthExceededError: If the replacement would reach ``max_depth``
        """
        # A URI already on the path means expanding it again would never end
        if ref_uri in item.ref_path:
            # Mark as circular and leave the $ref in place
            item.current["$circular_ref"] = True
            logger.warning("Circular reference detected: %s", ref_uri)
            return

        # Independent deep copy of the referenced schema (from cache or registry)
        resolved_schema = self._get_resolved_schema(ref_uri)
        if not resolved_schema:
            logger.warning("Schema not found: %s", ref_uri)
            return

        # New set per branch, so sibling branches do not see each other's refs
        new_ref_path = item.ref_path | {ref_uri}

        next_depth = item.depth + 1
        if next_depth >= self.max_depth:
            raise MaxDepthExceededError(self.max_depth)

        if item.parent is None:
            # Root level: mutate in place because ``resolve`` returns this very object
            item.current.clear()
            item.current.update(resolved_schema)
            replacement = item.current
        else:
            # Swap the $ref dict for the resolved schema inside its container
            item.parent[item.key] = resolved_schema
            replacement = resolved_schema

        queue.append(
            QueueItem(
                current=replacement,
                parent=item.parent,
                key=item.key,
                depth=next_depth,
                ref_path=new_ref_path,
            )
        )

    def _get_resolved_schema(self, ref_uri: str) -> SchemaDict | None:
        """
        Get the schema for ``ref_uri`` from the cache or the registry

        Returns:
            A deep copy of the cleaned schema that the caller may mutate
            freely, or ``None`` when the registry has no (non-empty) schema
            for the URI.
        """
        # Check cache first
        with self._cache_lock:
            cached = self._cache.get(ref_uri)
        if cached is not None:
            # Deep copy: resolution edits nested dicts in place, and a shallow
            # copy would let those edits corrupt the cached entry.
            return copy.deepcopy(cached)

        # Fetch from registry
        schema = self.registry.get_schema(ref_uri)
        if not schema:
            return None

        # Clean, then detach from the registry's own object before caching
        cleaned = copy.deepcopy(_remove_metadata_fields(schema))
        with self._cache_lock:
            self._cache[ref_uri] = cleaned

        return copy.deepcopy(cleaned)
def resolve_dify_schema_refs(
    schema: SchemaType, registry: SchemaRegistry | None = None, max_depth: int = 30
) -> SchemaType:
    """
    Expand Dify $ref references in a schema using a throwaway resolver

    Convenience wrapper: a ``SchemaResolver`` is only constructed when the
    schema actually contains a Dify reference, so reference-free schemas are
    returned as-is without any copying.

    Args:
        schema: Schema object that may contain $ref references
        registry: Optional schema registry, passed through to ``SchemaResolver``
        max_depth: Maximum depth passed through to ``SchemaResolver`` (default: 30)

    Returns:
        The original schema object when it holds no Dify references,
        otherwise the result of ``SchemaResolver.resolve``

    Raises:
        MaxDepthExceededError: Propagated from the resolver when nesting
            reaches ``max_depth``
    """
    if _has_dify_refs(schema):
        # Slow path: at least one Dify reference needs expanding
        return SchemaResolver(registry, max_depth).resolve(schema)

    # Fast path: nothing to resolve, hand back the caller's object untouched
    return schema
def _remove_metadata_fields(schema: dict) -> dict:
    """
    Return a shallow copy of the schema without its metadata keys

    The keys ``$id``, ``$schema`` and ``version`` are dropped from the top
    level only; the input itself is left untouched and nested values are
    shared with it.

    Args:
        schema: Schema dictionary

    Returns:
        Copy of the schema with the metadata keys removed
    """
    metadata_keys = frozenset(("$id", "$schema", "version"))
    stripped = schema.copy()
    # Only delete the metadata keys that are actually present in the copy
    for present in metadata_keys.intersection(stripped):
        del stripped[present]
    return stripped
def _is_dify_schema_ref(ref_uri: Any) -> bool:
    """
    Tell whether a value is a Dify schema reference URI

    Args:
        ref_uri: Value to check; anything that is not a string is rejected

    Returns:
        True only for strings matched by the pre-compiled Dify schema pattern
    """
    # Short-circuit keeps non-string values away from the regex engine
    return isinstance(ref_uri, str) and _DIFY_SCHEMA_PATTERN.match(ref_uri) is not None
def _has_dify_refs_recursive(schema: SchemaType) -> bool:
    """
    Check whether a schema contains any Dify $ref references by walking it

    This is the precise check, also used as the fallback when string-based
    detection is not possible. The walk uses an explicit stack instead of
    recursion and remembers visited containers by identity, so it terminates
    on self-referencing structures and is not bounded by the interpreter's
    recursion limit.

    Args:
        schema: Schema to check for references

    Returns:
        True if any dict in the structure has a ``$ref`` that is a Dify
        schema reference, False otherwise
    """
    pending = [schema]
    # ids of containers already expanded; all of them stay alive for the whole
    # walk because they are reachable from ``schema``, so ids cannot be reused.
    visited: set[int] = set()

    while pending:
        node = pending.pop()

        if isinstance(node, dict):
            children = node.values()
        elif isinstance(node, list):
            children = node
        else:
            # Primitive types don't contain refs
            continue

        marker = id(node)
        if marker in visited:
            # Already inspected: either a cycle or a sub-schema shared by several parents
            continue
        visited.add(marker)

        if isinstance(node, dict):
            ref_uri = node.get("$ref")
            if ref_uri and _is_dify_schema_ref(ref_uri):
                return True

        # Only containers can hold further references
        pending.extend(child for child in children if isinstance(child, (dict, list)))

    return False
def _has_dify_refs_hybrid(schema: SchemaType) -> bool:
    """
    Two-phase detection of Dify $ref references

    Phase 1 serializes the schema to compact JSON and looks for the literal
    ``"$ref"`` key and the Dify schema URL prefix; if either is missing the
    schema cannot contain a Dify reference. Phase 2 runs the precise
    structural check, both for schemas that pass phase 1 and for schemas
    that cannot be serialized at all.

    Args:
        schema: Schema to check for references

    Returns:
        True if any Dify $ref is found, False otherwise
    """
    import json

    try:
        serialized = json.dumps(schema, separators=(",", ":"))
    except (TypeError, ValueError, OverflowError):
        # Not serializable (e.g. circular references, unsupported objects):
        # skip the text pre-filter and rely on the structural check alone
        logger.debug("JSON serialization failed for schema, using recursive detection")
    else:
        mentions_ref_key = '"$ref"' in serialized
        mentions_dify_url = "https://dify.ai/schemas/" in serialized
        if not (mentions_ref_key and mentions_dify_url):
            # Quick elimination: one of the two required markers is absent
            return False

    # Precise validation for everything the pre-filter could not rule out
    return _has_dify_refs_recursive(schema)
def _has_dify_refs(schema: SchemaType) -> bool:
    """
    Report whether a schema contains any Dify $ref references

    Single entry point for reference detection; it currently delegates to the
    hybrid strategy (string pre-filter followed by a precise structural check).

    Args:
        schema: Schema to check for references

    Returns:
        True if any Dify $ref is found, False otherwise
    """
    contains_refs = _has_dify_refs_hybrid(schema)
    return contains_refs
def parse_dify_schema_uri(uri: str) -> tuple[str, str]:
    """
    Split a Dify schema URI into its version and schema name

    Args:
        uri: Schema URI to parse

    Returns:
        ``(version, schema_name)`` taken from the two capture groups of the
        Dify schema pattern, or ``("", "")`` when the URI does not match
    """
    if (found := _DIFY_SCHEMA_PATTERN.match(uri)) is None:
        return "", ""

    version, schema_name = found.group(1, 2)
    return version, schema_name