diff --git a/ingestion/setup.py b/ingestion/setup.py index f16c0c22dd5..a23197e19b1 100644 --- a/ingestion/setup.py +++ b/ingestion/setup.py @@ -1,4 +1,4 @@ -# Copyright 2025 Collate +# https://github.com/open-metadata/OpenMetadata/actions/runs/15640676139/job/44066998708?pr=21719 Copyright 2025 Collate # Licensed under the Collate Community License, Version 1.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -377,7 +377,7 @@ dev = { # Dependencies for unit testing in addition to dev dependencies and plugins test_unit = { - "pytest==7.0.0", + "pytest==7.0.1", "pytest-cov", "pytest-order", "dirty-equals", @@ -396,7 +396,7 @@ test = { # Install GE because it's not in the `all` plugin VERSIONS["great-expectations"], "basedpyright~=1.14", - "pytest==7.0.0", + "pytest==7.0.1", "pytest-cov", "pytest-order", "dirty-equals", diff --git a/ingestion/src/metadata/__init__.py b/ingestion/src/metadata/__init__.py new file mode 100644 index 00000000000..4a794f00c76 --- /dev/null +++ b/ingestion/src/metadata/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +OpenMetadata package initialization. 
+""" + +from metadata.utils.dependency_injector.dependency_injector import DependencyContainer +from metadata.utils.service_spec.service_spec import DefaultSourceLoader, SourceLoader + +# Initialize the dependency container +container = DependencyContainer() + +# Register the source loader +container.register(SourceLoader, DefaultSourceLoader) diff --git a/ingestion/src/metadata/utils/dependency_injector/README.md b/ingestion/src/metadata/utils/dependency_injector/README.md new file mode 100644 index 00000000000..f709e484a54 --- /dev/null +++ b/ingestion/src/metadata/utils/dependency_injector/README.md @@ -0,0 +1,154 @@ +# OpenMetadata Dependency Injection System + +This module provides a type-safe dependency injection system for OpenMetadata that uses Python's type hints to automatically inject dependencies into functions and methods. + +## Features + +- Type-safe dependency injection using Python's type hints +- Thread-safe singleton container for managing dependencies +- Support for dependency overrides (useful for testing) +- Automatic dependency resolution + +## Basic Usage + +### 1. Define Your Dependencies + +First, define your dependencies as classes or functions: + +```python +class Database: + def __init__(self, connection_string: str): + self.connection_string = connection_string + + def query(self, query: str) -> dict: + # Implementation + pass +``` + +### 2. Register Dependencies + +Register your dependencies with the container: + +```python +from metadata.utils.dependency_injector.dependency_injector import DependencyContainer, Inject, inject + +# Get the container (it is a process-wide singleton) +container = DependencyContainer() + +# Register a dependency (usually as a factory function) +container.register(Database, lambda: Database("postgresql://localhost:5432")) +``` + +### 3.
Use Dependency Injection + +Use the `@inject` decorator and `Inject` type to automatically inject dependencies: + +```python +@inject +def get_user(user_id: int, db: Inject[Database]) -> dict: + return db.query(f"SELECT * FROM users WHERE id = {user_id}") + +# The db parameter will be automatically injected +user = get_user(user_id=1) +``` + +## Advanced Usage + +### Dependency Overrides + +You can temporarily override dependencies, which is useful for testing: + +```python +# Override the Database dependency +container.override(Database, lambda: Database("postgresql://test:5432")) + +# Use the overridden dependency +user = get_user(user_id=1) + +# Remove the override when done +container.remove_override(Database) +``` + +### Explicit Dependency Injection + +You can also explicitly provide dependencies when calling functions: + +```python +# Explicitly provide the database +custom_db = Database("postgresql://custom:5432") +user = get_user(user_id=1, db=custom_db) +``` + +### Checking Dependencies + +Check if a dependency is registered: + +```python +if container.has(Database): + print("Database dependency is registered") +``` + +### Clearing Dependencies + +Clear all registered dependencies and overrides: + +```python +container.clear() +``` + +## Best Practices + +1. **Use Factory Functions**: Register dependencies as factory functions to ensure fresh instances: + ```python + container.register(Database, lambda: Database("postgresql://localhost:5432")) + ``` + +2. **Type Safety**: Always use proper type hints with `Inject`: + ```python + @inject + def my_function(db: Inject[Database]): + pass + ``` + +3. **Testing**: Use dependency overrides in tests: + ```python + def test_get_user(): + container.override(Database, lambda: MockDatabase()) + try: + user = get_user(user_id=1) + assert user is not None + finally: + container.remove_override(Database) + ``` + +4. 
**Error Handling**: Handle missing dependencies gracefully: + ```python + try: + user = get_user(user_id=1) + except DependencyNotFoundError: + # Handle missing dependency + pass + ``` + +## Error Types + +The system provides specific exceptions for different error cases: + +- `DependencyInjectionError`: Base exception for all dependency injection errors +- `DependencyNotFoundError`: Raised when a required dependency is not found +- `InvalidInjectionTypeError`: Raised when an invalid injection type is used + +## Thread Safety + +The dependency container is thread-safe and uses a reentrant lock (RLock) to support: +- Multiple threads accessing the container simultaneously +- The same thread acquiring the lock multiple times +- Safe dependency registration and retrieval + +## Limitations + +1. Dependencies must be registered before they can be injected +2. The system uses type names as keys, so different types with the same name will conflict +3. Circular dependencies are not supported +4. Dependencies are always treated as optional and can be overridden +5. Pass explicitly provided dependencies as keyword arguments (e.g. `db=custom_db`), not as positional arguments \ No newline at end of file diff --git a/ingestion/src/metadata/utils/dependency_injector/dependency_injector.py b/ingestion/src/metadata/utils/dependency_injector/dependency_injector.py new file mode 100644 index 00000000000..7065dd9eb21 --- /dev/null +++ b/ingestion/src/metadata/utils/dependency_injector/dependency_injector.py @@ -0,0 +1,343 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
T = TypeVar("T")


class DependencyInjectionError(Exception):
    """Base exception for dependency injection errors."""


class DependencyNotFoundError(DependencyInjectionError):
    """Raised when a required dependency is not found in the container."""


class InvalidInjectionTypeError(DependencyInjectionError):
    """Raised when an invalid injection type is used."""


if TYPE_CHECKING:
    # Static type checkers read ``Inject[X]`` as ``Optional[X]``, which is why
    # injected parameters can be declared with a ``None`` default.
    Inject = Annotated[Union[T, None], "Inject Marker"]
else:

    class Inject(Generic[T]):
        """
        Marker type for parameters that the ``inject`` decorator should fill in.

        The type argument is the key used to look the dependency up in the
        ``DependencyContainer``. Injection is always treated as optional: a
        caller may pass the dependency explicitly, which takes precedence over
        the container.

        Type Parameters:
            T: The type of the dependency to inject

        Example:
            ```python
            @inject
            def my_function(db: Inject[Database]):
                # db will be automatically injected
                pass
            ```
        """


class DependencyContainer:
    """
    Thread-safe singleton container for managing dependencies.

    Every call to ``DependencyContainer()`` returns the same instance. The
    container guards its state with an ``RLock`` so the same thread can acquire
    the lock multiple times. It maintains two dictionaries, both keyed by the
    ``__name__`` of the dependency type:
    - _dependencies: The original registered dependencies
    - _overrides: Temporary overrides that take precedence over original dependencies

    Because the key is the bare type name, two different types that share a
    name map to the same entry.

    Example:
        ```python
        container = DependencyContainer()
        container.register(Database, lambda: Database("postgresql://localhost:5432"))
        container.override(Database, lambda: Database("postgresql://test:5432"))
        ```
    """

    _instance: Optional["DependencyContainer"] = None
    _lock = RLock()
    # Class-level state: shared by construction, since only one instance exists.
    _dependencies: Dict[str, Callable[[], Any]] = {}
    _overrides: Dict[str, Callable[[], Any]] = {}

    def __new__(cls) -> "DependencyContainer":
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created the
                # instance between the first check and acquiring the lock.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def register(
        self, dependency_type: Type[Any], dependency: Callable[[], Any]
    ) -> None:
        """
        Register a dependency with the container.

        Args:
            dependency_type: The type of the dependency to register
            dependency: Zero-argument callable (usually a factory function or a
                class) that is called on every ``get`` to produce the dependency

        Example:
            ```python
            container.register(Database, lambda: Database("postgresql://localhost:5432"))
            ```
        """
        with self._lock:
            self._dependencies[dependency_type.__name__] = dependency

    def override(
        self, dependency_type: Type[Any], dependency: Callable[[], Any]
    ) -> None:
        """
        Override a dependency with a new implementation.

        The override takes precedence over the original dependency and is useful
        for testing or temporary changes.

        Args:
            dependency_type: The type of the dependency to override
            dependency: Zero-argument callable producing the replacement dependency

        Example:
            ```python
            container.override(Database, lambda: Database("postgresql://test:5432"))
            ```
        """
        with self._lock:
            self._overrides[dependency_type.__name__] = dependency

    def remove_override(self, dependency_type: Type[Any]) -> None:
        """
        Remove an override for a dependency. Does nothing if none is set.

        Args:
            dependency_type: The type of the dependency override to remove

        Example:
            ```python
            container.remove_override(Database)
            ```
        """
        with self._lock:
            self._overrides.pop(dependency_type.__name__, None)

    def get(self, dependency_type: Type[Any]) -> Optional[Any]:
        """
        Build and return a dependency from the container.

        Checks overrides first, then falls back to original dependencies, and
        calls the selected factory. The factory is called on every ``get``; the
        container does not cache what it returns.

        Args:
            dependency_type: The type of the dependency to retrieve

        Returns:
            The object produced by the registered factory, or None when nothing
            is registered for the type

        Example:
            ```python
            db = container.get(Database)
            if db is not None:
                db.query("SELECT 1")
            ```
        """
        with self._lock:
            type_name = dependency_type.__name__
            factory = self._overrides.get(type_name)
            # Compare against None rather than relying on truthiness, so that a
            # callable object that happens to be falsy is still honoured.
            if factory is None:
                factory = self._dependencies.get(type_name)
            if factory is None:
                return None
            return factory()

    def clear(self) -> None:
        """
        Clear all dependencies and overrides.

        Example:
            ```python
            container.clear()  # Remove all registered dependencies and overrides
            ```
        """
        with self._lock:
            self._dependencies.clear()
            self._overrides.clear()

    def has(self, dependency_type: Type[Any]) -> bool:
        """
        Check if a dependency exists in the container.

        Args:
            dependency_type: The type to check

        Returns:
            True if the type has a registered dependency or an override,
            False otherwise

        Example:
            ```python
            if container.has(Database):
                print("Database dependency is registered")
            ```
        """
        with self._lock:
            type_name = dependency_type.__name__
            return type_name in self._overrides or type_name in self._dependencies


def inject(func: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorator to inject dependencies based on type hints.

    Parameters annotated with ``Inject[SomeType]`` (or ``Optional[Inject[SomeType]]``)
    are looked up in the ``DependencyContainer`` and passed to ``func`` as
    keyword arguments. A dependency that the caller provides explicitly, either
    positionally or by keyword, is left untouched.

    Args:
        func: The function to inject dependencies into

    Returns:
        A wrapper around ``func`` that fills in missing injectable parameters

    Raises:
        DependencyNotFoundError: At call time, when an injectable parameter was
            not provided and the container cannot supply it

    Example:
        ```python
        @inject
        def get_user(user_id: int, db: Inject[Database]) -> dict:
            return db.query(f"SELECT * FROM users WHERE id = {user_id}")

        # Dependencies can also be passed explicitly
        get_user(user_id=1, db=Database("postgresql://localhost:5432"))
        ```
    """
    # Imported locally so the decorator needs no new module-level import.
    import inspect  # pylint: disable=import-outside-toplevel

    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):
        # No introspectable signature: only keyword arguments can be recognised
        # as explicitly provided.
        signature = None

    injectable: Optional[Dict[str, Any]] = None

    def injectable_params() -> Dict[str, Any]:
        """Map injectable parameter names to their dependency types."""
        nonlocal injectable
        if injectable is None:
            # Resolved on first call, not at decoration time, so that forward
            # references to names defined later in the module can be resolved.
            hints = get_type_hints(func, include_extras=True)
            injectable = {
                name: extract_inject_arg(hint)
                for name, hint in hints.items()
                # "return" is the return annotation, not a parameter.
                if name != "return" and is_inject_type(hint)
            }
        return injectable

    def provided_names(args: Any, kwargs: Dict[str, Any]) -> set:
        """Return the names of the parameters the caller supplied explicitly."""
        names = set(kwargs)
        if signature is not None and args:
            try:
                names.update(signature.bind_partial(*args, **kwargs).arguments)
            except TypeError:
                # The call itself is invalid; func(*args, **kwargs) in the
                # wrapper raises the canonical error for it.
                pass
        return names

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        container = DependencyContainer()
        provided = provided_names(args, kwargs)

        for param_name, dependency_type in injectable_params().items():
            # Skip if parameter is already provided explicitly
            if param_name in provided:
                continue

            dependency = container.get(dependency_type)
            if dependency is None:
                raise DependencyNotFoundError(
                    f"Dependency of type {dependency_type} not found in container. "
                    f"Make sure to register it using container.register({dependency_type.__name__}, ...)"
                )
            kwargs[param_name] = dependency

        return func(*args, **kwargs)

    return wrapper


def _find_inject_alias(tp: Any) -> Optional[Any]:
    """Return the ``Inject[...]`` alias inside ``tp``, or None if there is none."""
    origin = get_origin(tp)
    if origin is Inject:
        return tp
    if origin is Union:
        for arg in get_args(tp):
            if get_origin(arg) is Inject:
                return arg
    return None


def is_inject_type(tp: Any) -> bool:
    """
    Check if a type is an Inject type or Optional[Inject].

    Args:
        tp: The type to check

    Returns:
        True if the type is Inject or Optional[Inject], False otherwise
    """
    return _find_inject_alias(tp) is not None


def extract_inject_arg(tp: Any) -> Any:
    """
    Extract the type argument from an Inject type.

    Args:
        tp: The type to extract from

    Returns:
        The type argument from the Inject type

    Raises:
        InvalidInjectionTypeError: If the type is not Inject or Optional[Inject]
    """
    alias = _find_inject_alias(tp)
    if alias is None:
        raise InvalidInjectionTypeError(
            f"Type {tp} is not Inject or Optional[Inject]. "
            f"Use Inject[YourType] to mark a parameter for injection."
        )
    return get_args(alias)[0]
class DefaultSourceLoader(SourceLoader):
    """Source loader that resolves a service spec through a dynamic import."""

    def __call__(
        self,
        service_type: ServiceType,
        source_type: str,
        from_: str = "ingestion",
    ) -> Type[Any]:
        """Import and return the ``ServiceSpec`` object of a source module.

        The dotted path is built from ``from_``, the lower-cased name of
        ``service_type``, the module directory derived from ``source_type`` and
        the fixed ``service_spec`` module name.

        Args:
            service_type: Service type whose lower-cased ``name`` is a path segment
            source_type: Source type, converted to a module directory name
            from_: Package segment placed right after ``metadata``

        Returns:
            Whatever ``import_from_module`` resolves for the built path
        """
        spec_path = "metadata.{}.source.{}.{}.{}.ServiceSpec".format(  # pylint: disable=C0209
            from_,
            service_type.name.lower(),
            get_module_dir(source_type),
            "service_spec",
        )
        return import_from_module(spec_path)
def pytest_collection_modifyitems(session, config, items):
    """Move the tests of selected files to the end of the collected order.

    The relative order inside each of the two groups (regular tests and
    deferred tests) is kept as collected.
    """
    # Test files whose items must be scheduled after everything else
    deferred_files = (
        "test_dependency_injector.py",
        # Add other files that should run last here
    )

    def runs_last(item):
        # Substring match against the node id of the collected item
        return any(name in item.nodeid for name in deferred_files)

    leading = [item for item in items if not runs_last(item)]
    trailing = [item for item in items if runs_last(item)]

    # Slice assignment rewrites the list object that was passed in
    items[:] = leading + trailing
@pytest.fixture(autouse=True)
def _isolated_container():
    """Run every test against an empty container, then restore its prior state.

    ``DependencyContainer`` is a process-wide singleton, so registrations made
    by one test (or by application code before the tests run) would otherwise
    leak into the next test, and ``clear()`` calls made here would wipe
    registrations that other code relies on.
    """
    container = DependencyContainer()
    # pylint: disable=protected-access
    saved_dependencies = dict(container._dependencies)
    saved_overrides = dict(container._overrides)
    container.clear()

    yield container

    container.clear()
    container._dependencies.update(saved_dependencies)
    container._overrides.update(saved_overrides)


class TestDependencyContainer:
    """Behaviour of the container itself: register, override, clear and has."""

    def test_register_and_get_dependency(self):
        container = DependencyContainer()
        container.register(Database, lambda: Database("postgresql://localhost:5432"))

        db = container.get(Database)

        assert db is not None
        assert isinstance(db, Database)
        assert db.connection_string == "postgresql://localhost:5432"

    def test_override_dependency(self):
        container = DependencyContainer()
        container.register(Database, lambda: Database("postgresql://localhost:5432"))
        container.override(Database, lambda: Database("postgresql://test:5432"))

        db = container.get(Database)

        assert db is not None
        assert db.connection_string == "postgresql://test:5432"

    def test_remove_override(self):
        container = DependencyContainer()
        container.register(Database, lambda: Database("postgresql://localhost:5432"))
        container.override(Database, lambda: Database("postgresql://test:5432"))
        container.remove_override(Database)

        db = container.get(Database)

        assert db is not None
        assert db.connection_string == "postgresql://localhost:5432"

    def test_clear_dependencies(self):
        container = DependencyContainer()
        container.register(Database, lambda: Database("postgresql://localhost:5432"))
        container.register(Cache, lambda: Cache("localhost"))

        container.clear()

        assert container.get(Database) is None
        assert container.get(Cache) is None

    def test_has_dependency(self):
        container = DependencyContainer()

        # Holds regardless of test order because the fixture empties the container.
        assert not container.has(Database)
        container.register(Database, lambda: Database("postgresql://localhost:5432"))
        assert container.has(Database)


class TestInjectDecorator:
    """Behaviour of functions wrapped with ``@inject``."""

    def test_inject_single_dependency(self):
        container = DependencyContainer()
        container.register(Database, lambda: Database("postgresql://localhost:5432"))

        result = get_user(user_id=1)

        assert result == "Executed: SELECT * FROM users WHERE id = 1"

    def test_inject_multiple_dependencies(self):
        container = DependencyContainer()
        container.register(Database, lambda: Database("postgresql://localhost:5432"))
        container.register(Cache, lambda: Cache("localhost"))

        result = get_cached_user(user_id=1)

        assert result == "Cache hit for user:1"

    def test_missing_dependency(self):
        container = DependencyContainer()
        container.clear()  # Ensure no dependencies are registered

        with pytest.raises(DependencyNotFoundError):
            get_user(user_id=1)

    def test_explicit_dependency_override(self):
        container = DependencyContainer()
        container.register(Database, lambda: Database("postgresql://localhost:5432"))

        @inject
        def connection_string(db: Inject[Database] = None) -> str:
            return db.connection_string

        custom_db = Database("postgresql://custom:5432")

        result = get_user(user_id=1, db=custom_db)
        assert result == "Executed: SELECT * FROM users WHERE id = 1"
        # The query text is the same for any Database, so also check that the
        # explicitly passed instance, not the registered one, reached the function.
        assert connection_string(db=custom_db) == "postgresql://custom:5432"
        assert connection_string() == "postgresql://localhost:5432"