476 lines
17 KiB
Python
Raw Normal View History

from typing import List, Optional
from unittest.mock import Mock, patch
import cachetools
import pytest
# Import the classes from your module
from datahub.api.entities.external.external_entities import (
CaseSensitivity,
ExternalEntityId,
LinkedResourceSet,
MissingExternalEntity,
PlatformResourceRepository,
)
from datahub.api.entities.platformresource.platform_resource import (
PlatformResource,
PlatformResourceKey,
)
from datahub.ingestion.graph.client import DataHubGraph
from datahub.metadata.urns import PlatformResourceUrn, Urn
from datahub.utilities.search_utils import ElasticDocumentQuery
from datahub.utilities.urns.error import InvalidUrnError
class TestPlatformResourceRepository:
"""Tests for PlatformResourceRepository class."""
@pytest.fixture
def mock_graph(self) -> Mock:
"""Create a mock DataHubGraph."""
return Mock(spec=DataHubGraph)
@pytest.fixture
def repository(self, mock_graph: Mock) -> PlatformResourceRepository:
"""Create a PlatformResourceRepository instance."""
return PlatformResourceRepository(mock_graph)
@pytest.fixture
def mock_platform_resource(self) -> Mock:
"""Create a mock PlatformResource."""
resource: Mock = Mock(spec=PlatformResource)
resource.id = "test-resource-id"
return resource
def test_init(self, mock_graph: Mock) -> None:
"""Test repository initialization."""
repo: PlatformResourceRepository = PlatformResourceRepository(mock_graph)
assert repo.graph == mock_graph
assert isinstance(repo.cache, cachetools.TTLCache)
assert repo.cache.maxsize == 1000
assert repo.cache.ttl == 300 # 60 * 5
@patch(
"datahub.api.entities.platformresource.platform_resource.PlatformResource.search_by_filters"
)
def test_search_by_filter_with_cache(
self,
mock_search: Mock,
repository: PlatformResourceRepository,
mock_platform_resource: Mock,
) -> None:
"""Test search_by_filter with caching enabled."""
mock_filter = Mock(spec=ElasticDocumentQuery)
mock_search.return_value = [mock_platform_resource]
results = list(repository.search_by_filter(mock_filter, add_to_cache=True))
assert len(results) == 1
assert results[0] == mock_platform_resource
assert repository.cache[mock_platform_resource.id] == mock_platform_resource
mock_search.assert_called_once_with(repository.graph, mock_filter)
@patch(
"datahub.api.entities.platformresource.platform_resource.PlatformResource.search_by_filters"
)
def test_search_by_filter_without_cache(
self,
mock_search: Mock,
repository: PlatformResourceRepository,
mock_platform_resource: Mock,
) -> None:
"""Test search_by_filter with caching disabled."""
mock_filter = Mock(spec=ElasticDocumentQuery)
mock_search.return_value = [mock_platform_resource]
results = list(repository.search_by_filter(mock_filter, add_to_cache=False))
assert len(results) == 1
assert results[0] == mock_platform_resource
assert mock_platform_resource.id not in repository.cache
mock_search.assert_called_once_with(repository.graph, mock_filter)
def test_create(
self, repository: PlatformResourceRepository, mock_platform_resource: Mock
) -> None:
"""Test create method."""
repository.create(mock_platform_resource)
mock_platform_resource.to_datahub.assert_called_once_with(repository.graph)
assert repository.cache[mock_platform_resource.id] == mock_platform_resource
def test_get_existing(self, repository: PlatformResourceRepository) -> None:
"""Test get method for existing resource."""
mock_platform_resource_key: PlatformResourceKey = PlatformResourceKey(
platform="test-platform",
resource_type="test-resource-type",
platform_instance="test-instance",
primary_key="test-primary-key",
)
mock_platform_resource = PlatformResource.create(
key=mock_platform_resource_key, secondary_keys=[], value={}
)
repository.cache[mock_platform_resource_key.id] = mock_platform_resource
result: Optional[PlatformResource] = repository.get(mock_platform_resource_key)
assert result == mock_platform_resource
def test_get_non_existing(self, repository: PlatformResourceRepository) -> None:
"""Test get method for non-existing resource."""
mock_platform_resource_key: PlatformResourceKey = PlatformResourceKey(
platform="test-platform",
resource_type="test-resource-type",
platform_instance="test-instance",
primary_key="test-primary-key",
)
result: Optional[PlatformResource] = repository.get(mock_platform_resource_key)
assert result is None
def test_delete(self, repository: PlatformResourceRepository) -> None:
"""Test delete method."""
mock_platform_resource_key: PlatformResourceKey = PlatformResourceKey(
platform="test-platform",
resource_type="test-resource-type",
platform_instance="test-instance",
primary_key="test-primary-key",
)
mock_platform_resource: PlatformResource = PlatformResource.create(
key=mock_platform_resource_key, secondary_keys=[], value={}
)
# Add item to cache first
repository.cache[mock_platform_resource_key.id] = mock_platform_resource
repository.delete(mock_platform_resource_key)
repository.graph.delete_entity.assert_called_once_with( # type: ignore[attr-defined]
urn=PlatformResourceUrn(mock_platform_resource.id).urn(), hard=True
)
assert mock_platform_resource_key.id not in repository.cache
class TestCaseSensitivity:
"""Tests for CaseSensitivity enum and methods."""
def test_detect_case_sensitivity_upper(self) -> None:
"""Test detecting uppercase strings."""
result: CaseSensitivity = CaseSensitivity.detect_case_sensitivity("HELLO")
assert result == CaseSensitivity.UPPER
def test_detect_case_sensitivity_lower(self) -> None:
"""Test detecting lowercase strings."""
result: CaseSensitivity = CaseSensitivity.detect_case_sensitivity("hello")
assert result == CaseSensitivity.LOWER
def test_detect_case_sensitivity_mixed(self) -> None:
"""Test detecting mixed case strings."""
result: CaseSensitivity = CaseSensitivity.detect_case_sensitivity("Hello")
assert result == CaseSensitivity.MIXED
def test_detect_for_many_empty_list(self) -> None:
"""Test detect_for_many with empty list."""
result: CaseSensitivity = CaseSensitivity.detect_for_many([])
assert result == CaseSensitivity.MIXED
def test_detect_for_many_all_upper(self) -> None:
"""Test detect_for_many with all uppercase strings."""
result: CaseSensitivity = CaseSensitivity.detect_for_many(
["HELLO", "WORLD", "TEST"]
)
assert result == CaseSensitivity.UPPER
def test_detect_for_many_all_lower(self) -> None:
"""Test detect_for_many with all lowercase strings."""
result: CaseSensitivity = CaseSensitivity.detect_for_many(
["hello", "world", "test"]
)
assert result == CaseSensitivity.LOWER
def test_detect_for_many_mixed_cases(self) -> None:
"""Test detect_for_many with mixed case strings."""
result: CaseSensitivity = CaseSensitivity.detect_for_many(
["HELLO", "world", "Test"]
)
assert result == CaseSensitivity.MIXED
class TestLinkedResourceSet:
"""Tests for LinkedResourceSet class."""
@pytest.fixture
def empty_resource_set(self) -> LinkedResourceSet:
"""Create an empty LinkedResourceSet."""
return LinkedResourceSet(urns=[])
@pytest.fixture
def mock_urn(self) -> Mock:
"""Create a mock URN."""
urn: Mock = Mock(spec=Urn)
urn.urn.return_value = (
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
)
urn.entity_type = "dataset"
urn.get_entity_id_as_string.return_value = "test.table"
return urn
@pytest.fixture
def populated_resource_set(self) -> LinkedResourceSet:
"""Create a LinkedResourceSet with some URNs."""
return LinkedResourceSet(
urns=[
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table1,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table2,PROD)",
]
)
def test_has_conflict_duplicate_urn(
self, populated_resource_set: LinkedResourceSet
) -> None:
"""Test _has_conflict with duplicate URN."""
urn: Urn = Urn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table1,PROD)"
)
result: bool = populated_resource_set._has_conflict(urn)
assert result is False # No conflict for duplicate URNs
def test_has_conflict_different_entity_types(
self, empty_resource_set: LinkedResourceSet
) -> None:
"""Test _has_conflict with different entity types."""
# Setup existing URN
empty_resource_set.urns = ["urn:li:tag2:test"]
# New URN with different entity type
new_urn: Urn = Urn.from_string("urn:li:tag:test")
result: bool = empty_resource_set._has_conflict(new_urn)
assert result is True
def test_has_conflict_invalid_existing_urn(
self, empty_resource_set: LinkedResourceSet
) -> None:
"""Test _has_conflict with invalid existing URN."""
empty_resource_set.urns = ["invalid-urn"]
new_urn: Urn = Urn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
)
with pytest.raises(InvalidUrnError):
empty_resource_set._has_conflict(new_urn)
def test_add_new_urn_string(self, empty_resource_set: LinkedResourceSet) -> None:
"""Test adding a new URN as string."""
empty_resource_set.urns = [
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
]
result: bool = empty_resource_set.add(
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table2,PROD)"
)
assert result is True
assert (
empty_resource_set.urns.count(
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table2,PROD)"
)
== 1
)
assert (
empty_resource_set.urns.count(
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
)
== 1
)
def test_add_new_urn_object(
self, empty_resource_set: LinkedResourceSet, mock_urn: Mock
) -> None:
"""Test adding a new URN as Urn object."""
empty_resource_set.urns = [
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
]
result: bool = empty_resource_set.add(
Urn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table2,PROD)"
)
)
assert result is True
assert (
empty_resource_set.urns.count(
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table2,PROD)"
)
== 1
)
assert (
empty_resource_set.urns.count(
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
)
== 1
)
def test_add_duplicate_urn(
self, empty_resource_set: LinkedResourceSet, mock_urn: Mock
) -> None:
"""Test adding a duplicate URN."""
empty_resource_set.urns = [
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
]
result: bool = empty_resource_set.add(
Urn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
)
)
assert result is False # Already exists
assert (
empty_resource_set.urns.count(
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
)
== 1
)
def test_add_conflicting_urn(
self, empty_resource_set: LinkedResourceSet, mock_urn: Mock
) -> None:
"""Test adding a conflicting URN raises ValueError."""
empty_resource_set.urns = [
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
]
result: bool = empty_resource_set._has_conflict(
Urn.from_string("urn:li:tag:myTag")
)
assert result is True # Conflict detected
def test_add_deduplicates_urns(self, empty_resource_set: LinkedResourceSet) -> None:
"""Test that add method deduplicates existing URNs."""
empty_resource_set.urns = [
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)",
]
result: bool = empty_resource_set.add(
Urn.from_string(
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
)
)
assert result is False # Already exists after deduplication
assert (
empty_resource_set.urns.count(
"urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
)
== 1
)
class TestMissingExternalEntity:
"""Tests for MissingExternalEntity class."""
@pytest.fixture
def mock_external_entity_id(self) -> Mock:
"""Create a mock ExternalEntityId."""
return Mock(spec=ExternalEntityId)
@pytest.fixture
def missing_entity(self, mock_external_entity_id: Mock) -> MissingExternalEntity:
"""Create a MissingExternalEntity instance."""
return MissingExternalEntity(id=mock_external_entity_id)
def test_is_managed_by_datahub(self, missing_entity: MissingExternalEntity) -> None:
"""Test is_managed_by_datahub returns False."""
result: bool = missing_entity.is_managed_by_datahub()
assert result is False
def test_datahub_linked_resources(
self, missing_entity: MissingExternalEntity
) -> None:
"""Test datahub_linked_resources returns empty LinkedResourceSet."""
result: LinkedResourceSet = missing_entity.datahub_linked_resources()
assert isinstance(result, LinkedResourceSet)
assert result.urns == []
def test_as_platform_resource(self, missing_entity: MissingExternalEntity) -> None:
"""Test as_platform_resource returns None."""
result: Optional[PlatformResource] = missing_entity.as_platform_resource()
assert result is None
def test_get_id(
self, missing_entity: MissingExternalEntity, mock_external_entity_id: Mock
) -> None:
"""Test get_id returns the correct id."""
result: ExternalEntityId = missing_entity.get_id()
assert result == mock_external_entity_id
class TestIntegration:
"""Integration tests combining multiple classes."""
@pytest.fixture
def mock_graph(self) -> Mock:
return Mock(spec=DataHubGraph)
@pytest.fixture
def repository(self, mock_graph: Mock) -> PlatformResourceRepository:
return PlatformResourceRepository(mock_graph)
def test_repository_with_linked_resource_set(
self, repository: PlatformResourceRepository
) -> None:
"""Test repository operations with LinkedResourceSet."""
# Create a mock platform resource with linked URNs
mock_resource: Mock = Mock(spec=PlatformResource)
mock_resource.id = "test-resource"
# Test create and cache
repository.create(mock_resource)
assert repository.cache["test-resource"] == mock_resource
def test_case_sensitivity_with_linked_resources(self) -> None:
"""Test case sensitivity detection with LinkedResourceSet."""
# Test with mixed case detection
test_values: List[str] = ["TABLE1", "table2", "Table3"]
result: CaseSensitivity = CaseSensitivity.detect_for_many(test_values)
assert result == CaseSensitivity.MIXED
# Additional edge case tests
class TestEdgeCases:
"""Tests for edge cases and error conditions."""
def test_cache_ttl_expiration(self) -> None:
"""Test that cache TTL works correctly."""
mock_graph: Mock = Mock(spec=DataHubGraph)
repo: PlatformResourceRepository = PlatformResourceRepository(mock_graph)
# Add item to cache
repo.cache["test-key"] = "test-value"
assert "test-key" in repo.cache
# Cache should still contain the item within TTL
cached_value: Optional[str] = repo.cache.get("test-key")
assert cached_value == "test-value"
def test_linked_resource_set_with_empty_urns(self) -> None:
"""Test LinkedResourceSet behavior with empty URN list."""
test_urn: str = "urn:li:dataset:(urn:li:dataPlatform:mysql,test.table,PROD)"
resource_set: LinkedResourceSet = LinkedResourceSet(urns=[])
result: bool = resource_set.add(test_urn)
assert result is True
assert test_urn in resource_set.urns
def test_case_sensitivity_enum_values(self) -> None:
"""Test CaseSensitivity enum values."""
assert CaseSensitivity.UPPER.value == "upper"
assert CaseSensitivity.LOWER.value == "lower"
assert CaseSensitivity.MIXED.value == "mixed"