mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-08-14 12:06:54 +00:00
MINOR: Implement dependency injection on ingestion (#21719)
* Initial implementation for our Connection Class * Implement the Initial Connection class * Add Unit Tests * Implement Dependency Injection for the Ingestion Framework * Fix Test * Fix Profile Test Connection * Fix test, making the injection test run last * Update connections.py * Changed NewType to an AbstractClass to avoid linting issues * remove comment * Fix bug in service spec * Update PyTest version to avoid importlib.reader wrong import
This commit is contained in:
parent
092932b4a3
commit
49df5fc9de
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 Collate
|
||||
# https://github.com/open-metadata/OpenMetadata/actions/runs/15640676139/job/44066998708?pr=21719 Copyright 2025 Collate
|
||||
# Licensed under the Collate Community License, Version 1.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
@ -377,7 +377,7 @@ dev = {
|
||||
|
||||
# Dependencies for unit testing in addition to dev dependencies and plugins
|
||||
test_unit = {
|
||||
"pytest==7.0.0",
|
||||
"pytest==7.0.1",
|
||||
"pytest-cov",
|
||||
"pytest-order",
|
||||
"dirty-equals",
|
||||
@ -396,7 +396,7 @@ test = {
|
||||
# Install GE because it's not in the `all` plugin
|
||||
VERSIONS["great-expectations"],
|
||||
"basedpyright~=1.14",
|
||||
"pytest==7.0.0",
|
||||
"pytest==7.0.1",
|
||||
"pytest-cov",
|
||||
"pytest-order",
|
||||
"dirty-equals",
|
||||
|
22
ingestion/src/metadata/__init__.py
Normal file
22
ingestion/src/metadata/__init__.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Copyright 2025 Collate
|
||||
# Licensed under the Collate Community License, Version 1.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
OpenMetadata package initialization.
|
||||
"""
|
||||
|
||||
from metadata.utils.dependency_injector.dependency_injector import DependencyContainer
|
||||
from metadata.utils.service_spec.service_spec import DefaultSourceLoader, SourceLoader
|
||||
|
||||
# Initialize the dependency container
|
||||
container = DependencyContainer()
|
||||
|
||||
# Register the source loader
|
||||
container.register(SourceLoader, DefaultSourceLoader)
|
154
ingestion/src/metadata/utils/dependency_injector/README.md
Normal file
154
ingestion/src/metadata/utils/dependency_injector/README.md
Normal file
@ -0,0 +1,154 @@
|
||||
# OpenMetadata Dependency Injection System
|
||||
|
||||
This module provides a type-safe dependency injection system for OpenMetadata that uses Python's type hints to automatically inject dependencies into functions and methods.
|
||||
|
||||
## Features
|
||||
|
||||
- Type-safe dependency injection using Python's type hints
|
||||
- Thread-safe singleton container for managing dependencies
|
||||
- Support for dependency overrides (useful for testing)
|
||||
- Automatic dependency resolution
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### 1. Define Your Dependencies
|
||||
|
||||
First, define your dependencies as classes or functions:
|
||||
|
||||
```python
|
||||
class Database:
|
||||
def __init__(self, connection_string: str):
|
||||
self.connection_string = connection_string
|
||||
|
||||
def query(self, query: str) -> dict:
|
||||
# Implementation
|
||||
pass
|
||||
```
|
||||
|
||||
### 2. Register Dependencies
|
||||
|
||||
Register your dependencies with the container:
|
||||
|
||||
```python
|
||||
from metadata.utils.dependency_injector import DependencyContainer, Inject, inject
|
||||
|
||||
# Create a container instance
|
||||
container = DependencyContainer[Callable]()
|
||||
|
||||
# Register a dependency (usually as a factory function)
|
||||
container.register(Database, lambda: Database("postgresql://localhost:5432"))
|
||||
```
|
||||
|
||||
### 3. Use Dependency Injection
|
||||
|
||||
Use the `@inject` decorator and `Inject` type to automatically inject dependencies:
|
||||
|
||||
```python
|
||||
@inject
|
||||
def get_user(user_id: int, db: Inject[Database]) -> dict:
|
||||
return db.query(f"SELECT * FROM users WHERE id = {user_id}")
|
||||
|
||||
# The db parameter will be automatically injected
|
||||
user = get_user(user_id=1)
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Dependency Overrides
|
||||
|
||||
You can temporarily override dependencies, which is useful for testing:
|
||||
|
||||
```python
|
||||
# Override the Database dependency
|
||||
container.override(Database, lambda: Database("postgresql://test:5432"))
|
||||
|
||||
# Use the overridden dependency
|
||||
user = get_user(user_id=1)
|
||||
|
||||
# Remove the override when done
|
||||
container.remove_override(Database)
|
||||
```
|
||||
|
||||
### Explicit Dependency Injection
|
||||
|
||||
You can also explicitly provide dependencies when calling functions:
|
||||
|
||||
```python
|
||||
# Explicitly provide the database
|
||||
custom_db = Database("postgresql://custom:5432")
|
||||
user = get_user(user_id=1, db=custom_db)
|
||||
```
|
||||
|
||||
### Checking Dependencies
|
||||
|
||||
Check if a dependency is registered:
|
||||
|
||||
```python
|
||||
if container.has(Database):
|
||||
print("Database dependency is registered")
|
||||
```
|
||||
|
||||
### Clearing Dependencies
|
||||
|
||||
Clear all registered dependencies and overrides:
|
||||
|
||||
```python
|
||||
container.clear()
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use Factory Functions**: Register dependencies as factory functions to ensure fresh instances:
|
||||
```python
|
||||
container.register(Database, lambda: Database("postgresql://localhost:5432"))
|
||||
```
|
||||
|
||||
2. **Type Safety**: Always use proper type hints with `Inject`:
|
||||
```python
|
||||
@inject
|
||||
def my_function(db: Inject[Database]):
|
||||
pass
|
||||
```
|
||||
|
||||
3. **Testing**: Use dependency overrides in tests:
|
||||
```python
|
||||
def test_get_user():
|
||||
container.override(Database, lambda: MockDatabase())
|
||||
try:
|
||||
user = get_user(user_id=1)
|
||||
assert user is not None
|
||||
finally:
|
||||
container.remove_override(Database)
|
||||
```
|
||||
|
||||
4. **Error Handling**: Handle missing dependencies gracefully:
|
||||
```python
|
||||
try:
|
||||
user = get_user(user_id=1)
|
||||
except DependencyNotFoundError:
|
||||
# Handle missing dependency
|
||||
pass
|
||||
```
|
||||
|
||||
## Error Types
|
||||
|
||||
The system provides specific exceptions for different error cases:
|
||||
|
||||
- `DependencyInjectionError`: Base exception for all dependency injection errors
|
||||
- `DependencyNotFoundError`: Raised when a required dependency is not found
|
||||
- `InvalidInjectionTypeError`: Raised when an invalid injection type is used
|
||||
|
||||
## Thread Safety
|
||||
|
||||
The dependency container is thread-safe and uses a reentrant lock (RLock) to support:
|
||||
- Multiple threads accessing the container simultaneously
|
||||
- The same thread acquiring the lock multiple times
|
||||
- Safe dependency registration and retrieval
|
||||
|
||||
## Limitations
|
||||
|
||||
1. Dependencies must be registered before they can be injected
|
||||
2. The system uses type names as keys, so different types with the same name will conflict
|
||||
3. Circular dependencies are not supported
|
||||
4. Dependencies are always treated as optional and can be overridden
|
||||
5. Dependencies can't be passed as *arg. Must be passed as *kwargs
|
@ -0,0 +1,343 @@
|
||||
# Copyright 2025 Collate
|
||||
# Licensed under the Collate Community License, Version 1.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Dependency injection utilities for OpenMetadata.
|
||||
|
||||
This module provides a type-safe dependency injection system that uses Python's type hints
|
||||
to automatically inject dependencies into functions and methods.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from typing import Annotated
|
||||
from metadata.utils.dependency_injector import Inject, inject, DependencyContainer
|
||||
|
||||
class Database:
|
||||
def __init__(self, connection_string: str):
|
||||
self.connection_string = connection_string
|
||||
|
||||
# Register a dependency
|
||||
container = DependencyContainer[Callable]()
|
||||
container.register(Database, lambda: Database("postgresql://localhost:5432"))
|
||||
|
||||
# Use dependency injection
|
||||
@inject
|
||||
def get_user(user_id: int, db: Inject[Database]) -> dict:
|
||||
return db.query(f"SELECT * FROM users WHERE id = {user_id}")
|
||||
```
|
||||
"""
|
||||
from functools import wraps
|
||||
from threading import RLock
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from metadata.utils.logger import utils_logger
|
||||
|
||||
logger = utils_logger()
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class DependencyInjectionError(Exception):
|
||||
"""Base exception for dependency injection errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DependencyNotFoundError(DependencyInjectionError):
|
||||
"""Raised when a required dependency is not found in the container."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidInjectionTypeError(DependencyInjectionError):
|
||||
"""Raised when an invalid injection type is used."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
Inject = Annotated[Union[T, None], "Inject Marker"]
|
||||
else:
|
||||
|
||||
class Inject(Generic[T]):
|
||||
"""
|
||||
Type for dependency injection that uses types as keys.
|
||||
|
||||
This type is used to mark parameters that should be automatically injected
|
||||
by the dependency container. Injection is always treated as Optioonal. It can be overriden.
|
||||
|
||||
Type Parameters:
|
||||
T: The type of the dependency to inject
|
||||
|
||||
Example:
|
||||
```python
|
||||
@inject
|
||||
def my_function(db: Inject[Database]):
|
||||
# db will be automatically injected
|
||||
pass
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class DependencyContainer:
|
||||
"""
|
||||
Thread-safe singleton container for managing dependencies.
|
||||
|
||||
This container uses RLock to support reentrant locking, allowing the same thread
|
||||
to acquire the lock multiple times. It maintains two dictionaries:
|
||||
- _dependencies: The original registered dependencies
|
||||
- _overrides: Temporary overrides that take precedence over original dependencies
|
||||
|
||||
Type Parameters:
|
||||
T: The base type for all dependencies in this container
|
||||
|
||||
Example:
|
||||
```python
|
||||
container = DependencyContainer[Callable]()
|
||||
container.register(Database, lambda: Database("postgresql://localhost:5432"))
|
||||
container.override(Database, lambda: Database("postgresql://test:5432"))
|
||||
```
|
||||
"""
|
||||
|
||||
_instance: Optional["DependencyContainer"] = None
|
||||
_lock = RLock()
|
||||
_dependencies: Dict[str, Callable[[], Any]] = {}
|
||||
_overrides: Dict[str, Callable[[], Any]] = {}
|
||||
|
||||
def __new__(cls) -> "DependencyContainer":
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def register(
|
||||
self, dependency_type: Type[Any], dependency: Callable[[], Any]
|
||||
) -> None:
|
||||
"""
|
||||
Register a dependency with the container.
|
||||
|
||||
Args:
|
||||
dependency_type: The type of the dependency to register
|
||||
dependency: The dependency to register (usually a factory function)
|
||||
|
||||
Example:
|
||||
```python
|
||||
container.register(Database, lambda: Database("postgresql://localhost:5432"))
|
||||
```
|
||||
"""
|
||||
with self._lock:
|
||||
self._dependencies[dependency_type.__name__] = dependency
|
||||
|
||||
def override(
|
||||
self, dependency_type: Type[Any], dependency: Callable[[], Any]
|
||||
) -> None:
|
||||
"""
|
||||
Override a dependency with a new implementation.
|
||||
|
||||
The override takes precedence over the original dependency and is useful
|
||||
for testing or temporary changes.
|
||||
|
||||
Args:
|
||||
dependency_type: The type of the dependency to override
|
||||
dependency: The new dependency implementation
|
||||
|
||||
Example:
|
||||
```python
|
||||
container.override(Database, lambda: Database("postgresql://test:5432"))
|
||||
```
|
||||
"""
|
||||
with self._lock:
|
||||
self._overrides[dependency_type.__name__] = dependency
|
||||
|
||||
def remove_override(self, dependency_type: Type[T]) -> None:
|
||||
"""
|
||||
Remove an override for a dependency.
|
||||
|
||||
Args:
|
||||
dependency_type: The type of the dependency override to remove
|
||||
|
||||
Example:
|
||||
```python
|
||||
container.remove_override(Database)
|
||||
```
|
||||
"""
|
||||
with self._lock:
|
||||
self._overrides.pop(dependency_type.__name__, None)
|
||||
|
||||
def get(self, dependency_type: Type[Any]) -> Optional[Any]:
|
||||
"""
|
||||
Get a dependency from the container.
|
||||
|
||||
Checks overrides first, then falls back to original dependencies.
|
||||
|
||||
Args:
|
||||
dependency_type: The type of the dependency to retrieve
|
||||
|
||||
Returns:
|
||||
The dependency if found, None otherwise
|
||||
|
||||
Example:
|
||||
```python
|
||||
db_factory = container.get(Database)
|
||||
if db_factory:
|
||||
db = db_factory()
|
||||
```
|
||||
"""
|
||||
with self._lock:
|
||||
type_name = dependency_type.__name__
|
||||
factory = self._overrides.get(type_name) or self._dependencies.get(
|
||||
type_name
|
||||
)
|
||||
if factory is None:
|
||||
return None
|
||||
return factory()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all dependencies and overrides.
|
||||
|
||||
Example:
|
||||
```python
|
||||
container.clear() # Remove all registered dependencies and overrides
|
||||
```
|
||||
"""
|
||||
with self._lock:
|
||||
self._dependencies.clear()
|
||||
self._overrides.clear()
|
||||
|
||||
def has(self, dependency_type: Type[T]) -> bool:
|
||||
"""
|
||||
Check if a dependency exists in the container.
|
||||
|
||||
Args:
|
||||
dependency_type: The type to check
|
||||
|
||||
Returns:
|
||||
True if the dependency exists, False otherwise
|
||||
|
||||
Example:
|
||||
```python
|
||||
if container.has(Database):
|
||||
print("Database dependency is registered")
|
||||
```"""
|
||||
with self._lock:
|
||||
type_name = dependency_type.__name__
|
||||
return type_name in self._overrides or type_name in self._dependencies
|
||||
|
||||
|
||||
def inject(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""
|
||||
Decorator to inject dependencies based on type hints.
|
||||
|
||||
This decorator automatically injects dependencies into function parameters
|
||||
based on their type hints. It uses types as keys for dependency lookup and
|
||||
allows explicit injection by passing dependencies as keyword arguments.
|
||||
|
||||
Args:
|
||||
func: The function to inject dependencies into
|
||||
|
||||
Returns:
|
||||
A function with dependencies injected
|
||||
|
||||
Example:
|
||||
```python
|
||||
@inject
|
||||
def get_user(user_id: int, db: Inject[Database]) -> dict:
|
||||
return db.query(f"SELECT * FROM users WHERE id = {user_id}")
|
||||
|
||||
# Dependencies can also be passed explicitly
|
||||
get_user(user_id=1, db=Database("postgresql://localhost:5432"))
|
||||
```
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
container = DependencyContainer()
|
||||
type_hints = get_type_hints(func, include_extras=True)
|
||||
|
||||
for param_name, param_type in type_hints.items():
|
||||
# Skip if parameter is already provided explicitly
|
||||
if param_name in kwargs:
|
||||
continue
|
||||
|
||||
# Check if it's an Inject type
|
||||
if is_inject_type(param_type):
|
||||
dependency_type = extract_inject_arg(param_type)
|
||||
dependency = container.get(dependency_type)
|
||||
if dependency is None:
|
||||
raise DependencyNotFoundError(
|
||||
f"Dependency of type {dependency_type} not found in container. "
|
||||
f"Make sure to register it using container.register({dependency_type.__name__}, ...)"
|
||||
)
|
||||
kwargs[param_name] = dependency
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_inject_type(tp: Any) -> bool:
|
||||
"""
|
||||
Check if a type is an Inject type or Optional[Inject].
|
||||
|
||||
Args:
|
||||
tp: The type to check
|
||||
|
||||
Returns:
|
||||
True if the type is Inject or Optional[Inject], False otherwise
|
||||
"""
|
||||
origin = get_origin(tp)
|
||||
if origin is Inject:
|
||||
return True
|
||||
if origin is Union:
|
||||
args = get_args(tp)
|
||||
return any(get_origin(arg) is Inject for arg in args)
|
||||
return False
|
||||
|
||||
|
||||
def extract_inject_arg(tp: Any) -> Any:
|
||||
"""
|
||||
Extract the type argument from an Inject type.
|
||||
|
||||
Args:
|
||||
tp: The type to extract from
|
||||
|
||||
Returns:
|
||||
The type argument from the Inject type
|
||||
|
||||
Raises:
|
||||
InvalidInjectionTypeError: If the type is not Inject or Optional[Inject]
|
||||
"""
|
||||
origin = get_origin(tp)
|
||||
if origin is Inject:
|
||||
return get_args(tp)[0]
|
||||
if origin is Union:
|
||||
for arg in get_args(tp):
|
||||
if get_origin(arg) is Inject:
|
||||
return get_args(arg)[0]
|
||||
raise InvalidInjectionTypeError(
|
||||
f"Type {tp} is not Inject or Optional[Inject]. "
|
||||
f"Use Annotated[YourType, 'Inject'] to mark a parameter for injection."
|
||||
)
|
@ -2,7 +2,8 @@
|
||||
Manifests are used to store class information
|
||||
"""
|
||||
|
||||
from typing import Optional, Type, cast
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Type, cast
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
@ -13,6 +14,7 @@ from metadata.ingestion.connections.connection import BaseConnection
|
||||
from metadata.ingestion.models.custom_pydantic import BaseModel
|
||||
from metadata.profiler.interface.profiler_interface import ProfilerInterface
|
||||
from metadata.sampler.sampler_interface import SamplerInterface
|
||||
from metadata.utils.dependency_injector.dependency_injector import Inject, inject
|
||||
from metadata.utils.importer import (
|
||||
TYPE_SEPARATOR,
|
||||
DynamicImportException,
|
||||
@ -25,6 +27,14 @@ from metadata.utils.logger import utils_logger
|
||||
logger = utils_logger()
|
||||
|
||||
|
||||
class SourceLoader(ABC):
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self, service_type: ServiceType, source_type: str, from_: str
|
||||
) -> Type[Any]:
|
||||
"""Load the service spec for a given service type and source type."""
|
||||
|
||||
|
||||
class BaseSpec(BaseModel):
|
||||
"""
|
||||
# The OpenMetadata Ingestion Service Specification (Spec)
|
||||
@ -67,8 +77,13 @@ class BaseSpec(BaseModel):
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
@inject
|
||||
def get_for_source(
|
||||
cls, service_type: ServiceType, source_type: str, from_: str = "ingestion"
|
||||
cls,
|
||||
service_type: ServiceType,
|
||||
source_type: str,
|
||||
from_: str = "ingestion",
|
||||
source_loader: Inject[SourceLoader] = None,
|
||||
) -> "BaseSpec":
|
||||
"""Retrieves the manifest for a given source type. If it does not exist will attempt to retrieve
|
||||
a default manifest for the service type.
|
||||
@ -81,14 +96,26 @@ class BaseSpec(BaseModel):
|
||||
Returns:
|
||||
BaseSpec: The manifest for the source type.
|
||||
"""
|
||||
return cls.model_validate(
|
||||
import_from_module(
|
||||
"metadata.{}.source.{}.{}.{}.ServiceSpec".format( # pylint: disable=C0209
|
||||
from_,
|
||||
service_type.name.lower(),
|
||||
get_module_dir(source_type),
|
||||
"service_spec",
|
||||
)
|
||||
if not source_loader:
|
||||
raise ValueError("Source loader is required")
|
||||
|
||||
return cls.model_validate(source_loader(service_type, source_type, from_))
|
||||
|
||||
|
||||
class DefaultSourceLoader(SourceLoader):
|
||||
def __call__(
|
||||
self,
|
||||
service_type: ServiceType,
|
||||
source_type: str,
|
||||
from_: str = "ingestion",
|
||||
) -> Type[Any]:
|
||||
"""Default implementation for loading service specifications."""
|
||||
return import_from_module(
|
||||
"metadata.{}.source.{}.{}.{}.ServiceSpec".format( # pylint: disable=C0209
|
||||
from_,
|
||||
service_type.name.lower(),
|
||||
get_module_dir(source_type),
|
||||
"service_spec",
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -6,3 +6,25 @@ def pytest_pycollect_makeitem(collector, name, obj):
|
||||
return []
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(session, config, items):
|
||||
"""Reorder test items to ensure certain files run last."""
|
||||
# List of test files that should run last
|
||||
last_files = [
|
||||
"test_dependency_injector.py",
|
||||
# Add other files that should run last here
|
||||
]
|
||||
|
||||
# Get all test items that should run last
|
||||
last_items = []
|
||||
other_items = []
|
||||
|
||||
for item in items:
|
||||
if any(file in item.nodeid for file in last_files):
|
||||
last_items.append(item)
|
||||
else:
|
||||
other_items.append(item)
|
||||
|
||||
# Reorder the items
|
||||
items[:] = other_items + last_items
|
||||
|
@ -0,0 +1,141 @@
|
||||
import pytest
|
||||
|
||||
from metadata.utils.dependency_injector.dependency_injector import (
|
||||
DependencyContainer,
|
||||
DependencyNotFoundError,
|
||||
Inject,
|
||||
inject,
|
||||
)
|
||||
|
||||
|
||||
# Test classes for dependency injection
|
||||
class Database:
|
||||
def __init__(self, connection_string: str):
|
||||
self.connection_string = connection_string
|
||||
|
||||
def query(self, query: str) -> str:
|
||||
return f"Executed: {query}"
|
||||
|
||||
|
||||
class Cache:
|
||||
def __init__(self, host: str):
|
||||
self.host = host
|
||||
|
||||
def get(self, key: str) -> str:
|
||||
return f"Cache hit for {key}"
|
||||
|
||||
|
||||
# Test functions for injection
|
||||
@inject
|
||||
def get_user(user_id: int, db: Inject[Database] = None) -> str:
|
||||
if db is None:
|
||||
raise DependencyNotFoundError("Database dependency not found")
|
||||
return db.query(f"SELECT * FROM users WHERE id = {user_id}")
|
||||
|
||||
|
||||
@inject
|
||||
def get_cached_user(user_id: int, db: Inject[Database], cache: Inject[Cache]) -> str:
|
||||
if db is None:
|
||||
raise DependencyNotFoundError("Database dependency not found")
|
||||
if cache is None:
|
||||
raise DependencyNotFoundError("Cache dependency not found")
|
||||
cache_key = f"user:{user_id}"
|
||||
cached = cache.get(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
return db.query(f"SELECT * FROM users WHERE id = {user_id}")
|
||||
|
||||
|
||||
class TestDependencyContainer:
|
||||
def test_register_and_get_dependency(self):
|
||||
container = DependencyContainer()
|
||||
db_factory = lambda: Database("postgresql://localhost:5432")
|
||||
|
||||
container.register(Database, db_factory)
|
||||
db = container.get(Database)
|
||||
|
||||
assert db is not None
|
||||
assert isinstance(db, Database)
|
||||
assert db.connection_string == "postgresql://localhost:5432"
|
||||
|
||||
def test_override_dependency(self):
|
||||
container = DependencyContainer()
|
||||
original_factory = lambda: Database("postgresql://localhost:5432")
|
||||
override_factory = lambda: Database("postgresql://test:5432")
|
||||
|
||||
container.register(Database, original_factory)
|
||||
container.override(Database, override_factory)
|
||||
|
||||
db = container.get(Database)
|
||||
assert db is not None
|
||||
assert db.connection_string == "postgresql://test:5432"
|
||||
|
||||
def test_remove_override(self):
|
||||
container = DependencyContainer()
|
||||
original_factory = lambda: Database("postgresql://localhost:5432")
|
||||
override_factory = lambda: Database("postgresql://test:5432")
|
||||
|
||||
container.register(Database, original_factory)
|
||||
container.override(Database, override_factory)
|
||||
container.remove_override(Database)
|
||||
|
||||
db = container.get(Database)
|
||||
assert db is not None
|
||||
assert db.connection_string == "postgresql://localhost:5432"
|
||||
|
||||
def test_clear_dependencies(self):
|
||||
container = DependencyContainer()
|
||||
db_factory = lambda: Database("postgresql://localhost:5432")
|
||||
cache_factory = lambda: Cache("localhost")
|
||||
|
||||
container.register(Database, db_factory)
|
||||
container.register(Cache, cache_factory)
|
||||
container.clear()
|
||||
|
||||
assert container.get(Database) is None
|
||||
assert container.get(Cache) is None
|
||||
|
||||
def test_has_dependency(self):
|
||||
container = DependencyContainer()
|
||||
db_factory = lambda: Database("postgresql://localhost:5432")
|
||||
|
||||
assert not container.has(Database)
|
||||
container.register(Database, db_factory)
|
||||
assert container.has(Database)
|
||||
|
||||
|
||||
class TestInjectDecorator:
|
||||
def test_inject_single_dependency(self):
|
||||
container = DependencyContainer()
|
||||
db_factory = lambda: Database("postgresql://localhost:5432")
|
||||
container.register(Database, db_factory)
|
||||
|
||||
result = get_user(user_id=1)
|
||||
assert result == "Executed: SELECT * FROM users WHERE id = 1"
|
||||
|
||||
def test_inject_multiple_dependencies(self):
|
||||
container = DependencyContainer()
|
||||
db_factory = lambda: Database("postgresql://localhost:5432")
|
||||
cache_factory = lambda: Cache("localhost")
|
||||
|
||||
container.register(Database, db_factory)
|
||||
container.register(Cache, cache_factory)
|
||||
|
||||
result = get_cached_user(user_id=1)
|
||||
assert result == "Cache hit for user:1"
|
||||
|
||||
def test_missing_dependency(self):
|
||||
container = DependencyContainer()
|
||||
container.clear() # Ensure no dependencies are registered
|
||||
|
||||
with pytest.raises(DependencyNotFoundError):
|
||||
get_user(user_id=1)
|
||||
|
||||
def test_explicit_dependency_override(self):
|
||||
container = DependencyContainer()
|
||||
db_factory = lambda: Database("postgresql://localhost:5432")
|
||||
container.register(Database, db_factory)
|
||||
|
||||
custom_db = Database("postgresql://custom:5432")
|
||||
result = get_user(user_id=1, db=custom_db)
|
||||
assert result == "Executed: SELECT * FROM users WHERE id = 1"
|
Loading…
x
Reference in New Issue
Block a user