MINOR: Implement dependency injection on ingestion (#21719)

* Initial implementation for our Connection Class

* Implement the Initial Connection class

* Add Unit Tests

* Implement Dependency Injection for the Ingestion Framework

* Fix Test

* Fix Profile Test Connection

* Fix test, making the injection test run last

* Update connections.py

* Changed NewType to an AbstractClass to avoid linting issues

* remove comment

* Fix bug in service spec

* Update PyTest version to avoid importlib.reader wrong import
This commit is contained in:
IceS2 2025-06-16 08:03:38 +02:00 committed by GitHub
parent 092932b4a3
commit 49df5fc9de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 722 additions and 13 deletions

View File

@ -1,4 +1,4 @@
# Copyright 2025 Collate
# https://github.com/open-metadata/OpenMetadata/actions/runs/15640676139/job/44066998708?pr=21719 Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -377,7 +377,7 @@ dev = {
# Dependencies for unit testing in addition to dev dependencies and plugins
test_unit = {
"pytest==7.0.0",
"pytest==7.0.1",
"pytest-cov",
"pytest-order",
"dirty-equals",
@ -396,7 +396,7 @@ test = {
# Install GE because it's not in the `all` plugin
VERSIONS["great-expectations"],
"basedpyright~=1.14",
"pytest==7.0.0",
"pytest==7.0.1",
"pytest-cov",
"pytest-order",
"dirty-equals",

View File

@ -0,0 +1,22 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenMetadata package initialization.
"""
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer
from metadata.utils.service_spec.service_spec import DefaultSourceLoader, SourceLoader
# Initialize the dependency container
container = DependencyContainer()
# Register the source loader
container.register(SourceLoader, DefaultSourceLoader)

View File

@ -0,0 +1,154 @@
# OpenMetadata Dependency Injection System
This module provides a type-safe dependency injection system for OpenMetadata that uses Python's type hints to automatically inject dependencies into functions and methods.
## Features
- Type-safe dependency injection using Python's type hints
- Thread-safe singleton container for managing dependencies
- Support for dependency overrides (useful for testing)
- Automatic dependency resolution
## Basic Usage
### 1. Define Your Dependencies
First, define your dependencies as classes or functions:
```python
class Database:
def __init__(self, connection_string: str):
self.connection_string = connection_string
def query(self, query: str) -> dict:
# Implementation
pass
```
### 2. Register Dependencies
Register your dependencies with the container:
```python
from metadata.utils.dependency_injector import DependencyContainer, Inject, inject
# Create a container instance
container = DependencyContainer[Callable]()
# Register a dependency (usually as a factory function)
container.register(Database, lambda: Database("postgresql://localhost:5432"))
```
### 3. Use Dependency Injection
Use the `@inject` decorator and `Inject` type to automatically inject dependencies:
```python
@inject
def get_user(user_id: int, db: Inject[Database]) -> dict:
return db.query(f"SELECT * FROM users WHERE id = {user_id}")
# The db parameter will be automatically injected
user = get_user(user_id=1)
```
## Advanced Usage
### Dependency Overrides
You can temporarily override dependencies, which is useful for testing:
```python
# Override the Database dependency
container.override(Database, lambda: Database("postgresql://test:5432"))
# Use the overridden dependency
user = get_user(user_id=1)
# Remove the override when done
container.remove_override(Database)
```
### Explicit Dependency Injection
You can also explicitly provide dependencies when calling functions:
```python
# Explicitly provide the database
custom_db = Database("postgresql://custom:5432")
user = get_user(user_id=1, db=custom_db)
```
### Checking Dependencies
Check if a dependency is registered:
```python
if container.has(Database):
print("Database dependency is registered")
```
### Clearing Dependencies
Clear all registered dependencies and overrides:
```python
container.clear()
```
## Best Practices
1. **Use Factory Functions**: Register dependencies as factory functions to ensure fresh instances:
```python
container.register(Database, lambda: Database("postgresql://localhost:5432"))
```
2. **Type Safety**: Always use proper type hints with `Inject`:
```python
@inject
def my_function(db: Inject[Database]):
pass
```
3. **Testing**: Use dependency overrides in tests:
```python
def test_get_user():
container.override(Database, lambda: MockDatabase())
try:
user = get_user(user_id=1)
assert user is not None
finally:
container.remove_override(Database)
```
4. **Error Handling**: Handle missing dependencies gracefully:
```python
try:
user = get_user(user_id=1)
except DependencyNotFoundError:
# Handle missing dependency
pass
```
## Error Types
The system provides specific exceptions for different error cases:
- `DependencyInjectionError`: Base exception for all dependency injection errors
- `DependencyNotFoundError`: Raised when a required dependency is not found
- `InvalidInjectionTypeError`: Raised when an invalid injection type is used
## Thread Safety
The dependency container is thread-safe and uses a reentrant lock (RLock) to support:
- Multiple threads accessing the container simultaneously
- The same thread acquiring the lock multiple times
- Safe dependency registration and retrieval
## Limitations
1. Dependencies must be registered before they can be injected
2. The system uses type names as keys, so different types with the same name will conflict
3. Circular dependencies are not supported
4. Dependencies are always treated as optional and can be overridden
5. Dependencies can't be passed as *arg. Must be passed as *kwargs

View File

@ -0,0 +1,343 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Dependency injection utilities for OpenMetadata.
This module provides a type-safe dependency injection system that uses Python's type hints
to automatically inject dependencies into functions and methods.
Example:
```python
from typing import Annotated
from metadata.utils.dependency_injector import Inject, inject, DependencyContainer
class Database:
def __init__(self, connection_string: str):
self.connection_string = connection_string
# Register a dependency
container = DependencyContainer[Callable]()
container.register(Database, lambda: Database("postgresql://localhost:5432"))
# Use dependency injection
@inject
def get_user(user_id: int, db: Inject[Database]) -> dict:
return db.query(f"SELECT * FROM users WHERE id = {user_id}")
```
"""
from functools import wraps
from threading import RLock
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Dict,
Generic,
Optional,
Type,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
)
from metadata.utils.logger import utils_logger
logger = utils_logger()
T = TypeVar("T")
class DependencyInjectionError(Exception):
"""Base exception for dependency injection errors."""
pass
class DependencyNotFoundError(DependencyInjectionError):
"""Raised when a required dependency is not found in the container."""
pass
class InvalidInjectionTypeError(DependencyInjectionError):
"""Raised when an invalid injection type is used."""
pass
if TYPE_CHECKING:
Inject = Annotated[Union[T, None], "Inject Marker"]
else:
class Inject(Generic[T]):
"""
Type for dependency injection that uses types as keys.
This type is used to mark parameters that should be automatically injected
by the dependency container. Injection is always treated as Optioonal. It can be overriden.
Type Parameters:
T: The type of the dependency to inject
Example:
```python
@inject
def my_function(db: Inject[Database]):
# db will be automatically injected
pass
```
"""
class DependencyContainer:
"""
Thread-safe singleton container for managing dependencies.
This container uses RLock to support reentrant locking, allowing the same thread
to acquire the lock multiple times. It maintains two dictionaries:
- _dependencies: The original registered dependencies
- _overrides: Temporary overrides that take precedence over original dependencies
Type Parameters:
T: The base type for all dependencies in this container
Example:
```python
container = DependencyContainer[Callable]()
container.register(Database, lambda: Database("postgresql://localhost:5432"))
container.override(Database, lambda: Database("postgresql://test:5432"))
```
"""
_instance: Optional["DependencyContainer"] = None
_lock = RLock()
_dependencies: Dict[str, Callable[[], Any]] = {}
_overrides: Dict[str, Callable[[], Any]] = {}
def __new__(cls) -> "DependencyContainer":
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def register(
self, dependency_type: Type[Any], dependency: Callable[[], Any]
) -> None:
"""
Register a dependency with the container.
Args:
dependency_type: The type of the dependency to register
dependency: The dependency to register (usually a factory function)
Example:
```python
container.register(Database, lambda: Database("postgresql://localhost:5432"))
```
"""
with self._lock:
self._dependencies[dependency_type.__name__] = dependency
def override(
self, dependency_type: Type[Any], dependency: Callable[[], Any]
) -> None:
"""
Override a dependency with a new implementation.
The override takes precedence over the original dependency and is useful
for testing or temporary changes.
Args:
dependency_type: The type of the dependency to override
dependency: The new dependency implementation
Example:
```python
container.override(Database, lambda: Database("postgresql://test:5432"))
```
"""
with self._lock:
self._overrides[dependency_type.__name__] = dependency
def remove_override(self, dependency_type: Type[T]) -> None:
"""
Remove an override for a dependency.
Args:
dependency_type: The type of the dependency override to remove
Example:
```python
container.remove_override(Database)
```
"""
with self._lock:
self._overrides.pop(dependency_type.__name__, None)
def get(self, dependency_type: Type[Any]) -> Optional[Any]:
"""
Get a dependency from the container.
Checks overrides first, then falls back to original dependencies.
Args:
dependency_type: The type of the dependency to retrieve
Returns:
The dependency if found, None otherwise
Example:
```python
db_factory = container.get(Database)
if db_factory:
db = db_factory()
```
"""
with self._lock:
type_name = dependency_type.__name__
factory = self._overrides.get(type_name) or self._dependencies.get(
type_name
)
if factory is None:
return None
return factory()
def clear(self) -> None:
"""
Clear all dependencies and overrides.
Example:
```python
container.clear() # Remove all registered dependencies and overrides
```
"""
with self._lock:
self._dependencies.clear()
self._overrides.clear()
def has(self, dependency_type: Type[T]) -> bool:
"""
Check if a dependency exists in the container.
Args:
dependency_type: The type to check
Returns:
True if the dependency exists, False otherwise
Example:
```python
if container.has(Database):
print("Database dependency is registered")
```"""
with self._lock:
type_name = dependency_type.__name__
return type_name in self._overrides or type_name in self._dependencies
def inject(func: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorator to inject dependencies based on type hints.
This decorator automatically injects dependencies into function parameters
based on their type hints. It uses types as keys for dependency lookup and
allows explicit injection by passing dependencies as keyword arguments.
Args:
func: The function to inject dependencies into
Returns:
A function with dependencies injected
Example:
```python
@inject
def get_user(user_id: int, db: Inject[Database]) -> dict:
return db.query(f"SELECT * FROM users WHERE id = {user_id}")
# Dependencies can also be passed explicitly
get_user(user_id=1, db=Database("postgresql://localhost:5432"))
```
"""
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
container = DependencyContainer()
type_hints = get_type_hints(func, include_extras=True)
for param_name, param_type in type_hints.items():
# Skip if parameter is already provided explicitly
if param_name in kwargs:
continue
# Check if it's an Inject type
if is_inject_type(param_type):
dependency_type = extract_inject_arg(param_type)
dependency = container.get(dependency_type)
if dependency is None:
raise DependencyNotFoundError(
f"Dependency of type {dependency_type} not found in container. "
f"Make sure to register it using container.register({dependency_type.__name__}, ...)"
)
kwargs[param_name] = dependency
return func(*args, **kwargs)
return wrapper
def is_inject_type(tp: Any) -> bool:
"""
Check if a type is an Inject type or Optional[Inject].
Args:
tp: The type to check
Returns:
True if the type is Inject or Optional[Inject], False otherwise
"""
origin = get_origin(tp)
if origin is Inject:
return True
if origin is Union:
args = get_args(tp)
return any(get_origin(arg) is Inject for arg in args)
return False
def extract_inject_arg(tp: Any) -> Any:
"""
Extract the type argument from an Inject type.
Args:
tp: The type to extract from
Returns:
The type argument from the Inject type
Raises:
InvalidInjectionTypeError: If the type is not Inject or Optional[Inject]
"""
origin = get_origin(tp)
if origin is Inject:
return get_args(tp)[0]
if origin is Union:
for arg in get_args(tp):
if get_origin(arg) is Inject:
return get_args(arg)[0]
raise InvalidInjectionTypeError(
f"Type {tp} is not Inject or Optional[Inject]. "
f"Use Annotated[YourType, 'Inject'] to mark a parameter for injection."
)

View File

@ -2,7 +2,8 @@
Manifests are used to store class information
"""
from typing import Optional, Type, cast
from abc import ABC, abstractmethod
from typing import Any, Optional, Type, cast
from pydantic import model_validator
@ -13,6 +14,7 @@ from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.models.custom_pydantic import BaseModel
from metadata.profiler.interface.profiler_interface import ProfilerInterface
from metadata.sampler.sampler_interface import SamplerInterface
from metadata.utils.dependency_injector.dependency_injector import Inject, inject
from metadata.utils.importer import (
TYPE_SEPARATOR,
DynamicImportException,
@ -25,6 +27,14 @@ from metadata.utils.logger import utils_logger
logger = utils_logger()
class SourceLoader(ABC):
@abstractmethod
def __call__(
self, service_type: ServiceType, source_type: str, from_: str
) -> Type[Any]:
"""Load the service spec for a given service type and source type."""
class BaseSpec(BaseModel):
"""
# The OpenMetadata Ingestion Service Specification (Spec)
@ -67,8 +77,13 @@ class BaseSpec(BaseModel):
return values
@classmethod
@inject
def get_for_source(
cls, service_type: ServiceType, source_type: str, from_: str = "ingestion"
cls,
service_type: ServiceType,
source_type: str,
from_: str = "ingestion",
source_loader: Inject[SourceLoader] = None,
) -> "BaseSpec":
"""Retrieves the manifest for a given source type. If it does not exist will attempt to retrieve
a default manifest for the service type.
@ -81,14 +96,26 @@ class BaseSpec(BaseModel):
Returns:
BaseSpec: The manifest for the source type.
"""
return cls.model_validate(
import_from_module(
"metadata.{}.source.{}.{}.{}.ServiceSpec".format( # pylint: disable=C0209
from_,
service_type.name.lower(),
get_module_dir(source_type),
"service_spec",
)
if not source_loader:
raise ValueError("Source loader is required")
return cls.model_validate(source_loader(service_type, source_type, from_))
class DefaultSourceLoader(SourceLoader):
def __call__(
self,
service_type: ServiceType,
source_type: str,
from_: str = "ingestion",
) -> Type[Any]:
"""Default implementation for loading service specifications."""
return import_from_module(
"metadata.{}.source.{}.{}.{}.ServiceSpec".format( # pylint: disable=C0209
from_,
service_type.name.lower(),
get_module_dir(source_type),
"service_spec",
)
)

View File

@ -6,3 +6,25 @@ def pytest_pycollect_makeitem(collector, name, obj):
return []
except AttributeError:
pass
def pytest_collection_modifyitems(session, config, items):
"""Reorder test items to ensure certain files run last."""
# List of test files that should run last
last_files = [
"test_dependency_injector.py",
# Add other files that should run last here
]
# Get all test items that should run last
last_items = []
other_items = []
for item in items:
if any(file in item.nodeid for file in last_files):
last_items.append(item)
else:
other_items.append(item)
# Reorder the items
items[:] = other_items + last_items

View File

@ -0,0 +1,141 @@
import pytest
from metadata.utils.dependency_injector.dependency_injector import (
DependencyContainer,
DependencyNotFoundError,
Inject,
inject,
)
# Test classes for dependency injection
class Database:
def __init__(self, connection_string: str):
self.connection_string = connection_string
def query(self, query: str) -> str:
return f"Executed: {query}"
class Cache:
def __init__(self, host: str):
self.host = host
def get(self, key: str) -> str:
return f"Cache hit for {key}"
# Test functions for injection
@inject
def get_user(user_id: int, db: Inject[Database] = None) -> str:
if db is None:
raise DependencyNotFoundError("Database dependency not found")
return db.query(f"SELECT * FROM users WHERE id = {user_id}")
@inject
def get_cached_user(user_id: int, db: Inject[Database], cache: Inject[Cache]) -> str:
if db is None:
raise DependencyNotFoundError("Database dependency not found")
if cache is None:
raise DependencyNotFoundError("Cache dependency not found")
cache_key = f"user:{user_id}"
cached = cache.get(cache_key)
if cached:
return cached
return db.query(f"SELECT * FROM users WHERE id = {user_id}")
class TestDependencyContainer:
def test_register_and_get_dependency(self):
container = DependencyContainer()
db_factory = lambda: Database("postgresql://localhost:5432")
container.register(Database, db_factory)
db = container.get(Database)
assert db is not None
assert isinstance(db, Database)
assert db.connection_string == "postgresql://localhost:5432"
def test_override_dependency(self):
container = DependencyContainer()
original_factory = lambda: Database("postgresql://localhost:5432")
override_factory = lambda: Database("postgresql://test:5432")
container.register(Database, original_factory)
container.override(Database, override_factory)
db = container.get(Database)
assert db is not None
assert db.connection_string == "postgresql://test:5432"
def test_remove_override(self):
container = DependencyContainer()
original_factory = lambda: Database("postgresql://localhost:5432")
override_factory = lambda: Database("postgresql://test:5432")
container.register(Database, original_factory)
container.override(Database, override_factory)
container.remove_override(Database)
db = container.get(Database)
assert db is not None
assert db.connection_string == "postgresql://localhost:5432"
def test_clear_dependencies(self):
container = DependencyContainer()
db_factory = lambda: Database("postgresql://localhost:5432")
cache_factory = lambda: Cache("localhost")
container.register(Database, db_factory)
container.register(Cache, cache_factory)
container.clear()
assert container.get(Database) is None
assert container.get(Cache) is None
def test_has_dependency(self):
container = DependencyContainer()
db_factory = lambda: Database("postgresql://localhost:5432")
assert not container.has(Database)
container.register(Database, db_factory)
assert container.has(Database)
class TestInjectDecorator:
def test_inject_single_dependency(self):
container = DependencyContainer()
db_factory = lambda: Database("postgresql://localhost:5432")
container.register(Database, db_factory)
result = get_user(user_id=1)
assert result == "Executed: SELECT * FROM users WHERE id = 1"
def test_inject_multiple_dependencies(self):
container = DependencyContainer()
db_factory = lambda: Database("postgresql://localhost:5432")
cache_factory = lambda: Cache("localhost")
container.register(Database, db_factory)
container.register(Cache, cache_factory)
result = get_cached_user(user_id=1)
assert result == "Cache hit for user:1"
def test_missing_dependency(self):
container = DependencyContainer()
container.clear() # Ensure no dependencies are registered
with pytest.raises(DependencyNotFoundError):
get_user(user_id=1)
def test_explicit_dependency_override(self):
container = DependencyContainer()
db_factory = lambda: Database("postgresql://localhost:5432")
container.register(Database, db_factory)
custom_db = Database("postgresql://custom:5432")
result = get_user(user_id=1, db=custom_db)
assert result == "Executed: SELECT * FROM users WHERE id = 1"