datahub/metadata-ingestion/tests/unit/test_mongodb_source.py

449 lines
14 KiB
Python
Raw Normal View History

from typing import Any, Dict, List
from unittest.mock import MagicMock, patch
import bson
import pytest
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.mongodb import MongoDBConfig, MongoDBSource
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeProposal
from datahub.metadata.schema_classes import (
ContainerPropertiesClass,
DatasetPropertiesClass,
SchemaMetadataClass,
)
from datahub.utilities.urns.urn import guess_entity_type
@pytest.fixture
def mock_mongo_client():
    """Patch the source module's ``MongoClient`` and yield the client instance mock.

    The patched class hands back one ``MagicMock`` instance whose
    ``admin.command(...)`` returns ``{"ok": 1}`` (presumably the connection
    ping the source issues on startup — confirm against the source). Tests
    wire databases and collections onto the yielded instance. The patch is
    reverted when the fixture is torn down.
    """
    with patch("datahub.ingestion.source.mongodb.MongoClient") as client_cls:
        client = MagicMock()
        client.admin.command.return_value = {"ok": 1}
        client_cls.return_value = client
        yield client
@pytest.fixture
def pipeline_context():
    """Build a ``PipelineContext`` with a fixed run id for each test."""
    ctx = PipelineContext(run_id="test-mongodb-run")
    return ctx
def test_mongodb_schema_inference_respects_max_schema_size(
    mock_mongo_client, pipeline_context
):
    """
    Test that maxSchemaSize limits the number of fields in the output schema.

    The single sampled document has ``_id`` plus 50 scalar fields, which is
    far wider than ``maxSchemaSize=10``. The emitted schema must therefore
    hold at most 10 fields, and the dataset must carry the custom property
    ``schema.downsampled == "True"``.
    """
    mock_mongo_client.list_database_names.return_value = ["test_db"]
    mock_database = MagicMock()
    mock_mongo_client.__getitem__.return_value = mock_database
    mock_database.list_collection_names.return_value = ["wide_table"]
    mock_collection = MagicMock()
    mock_database.__getitem__.return_value = mock_collection

    # One document: ``_id`` followed by field_0 .. field_49 (insertion order kept).
    mock_document: Dict[str, Any] = {
        "_id": bson.ObjectId("507f1f77bcf86cd799439011"),
        **{f"field_{i}": f"value_{i}" for i in range(50)},
    }
    mock_collection.aggregate.return_value = [mock_document]

    config = MongoDBConfig(
        connect_uri="mongodb://localhost:27017",
        enableSchemaInference=True,
        maxSchemaSize=10,
    )
    source = MongoDBSource(ctx=pipeline_context, config=config)
    workunits = list(source.get_workunits_internal())

    schema_metadata_aspects = [
        aspect
        for wu in workunits
        if (aspect := wu.get_aspect_of_type(SchemaMetadataClass))
    ]
    assert len(schema_metadata_aspects) == 1
    schema_metadata = schema_metadata_aspects[0]
    assert len(schema_metadata.fields) <= 10

    dataset_properties_aspects: List[DatasetPropertiesClass] = [
        properties_aspect
        for wu in workunits
        if (properties_aspect := wu.get_aspect_of_type(DatasetPropertiesClass))
    ]
    assert len(dataset_properties_aspects) == 1
    custom_props = dataset_properties_aspects[0].customProperties
    assert "schema.downsampled" in custom_props
    assert custom_props["schema.downsampled"] == "True"
def test_mongodb_complex_schema_trimming(mock_mongo_client, pipeline_context):
    """
    Test maxSchemaSize trimming on a schema with nested objects and arrays.

    The sampled documents produce more distinct field paths than the limit:
    ``_id``, ``metadata`` and its nested children. With ``maxSchemaSize=3``
    the emitted schema must be exactly ``metadata``,
    ``metadata.properties`` and ``metadata.properties.name``, in that order.
    A retained nested field keeps its ancestors, and parents are listed
    before their children. The dataset must also carry the custom property
    ``schema.downsampled == "True"``.
    """
    mock_mongo_client.list_database_names.return_value = ["test_db"]
    mock_database = MagicMock()
    mock_mongo_client.__getitem__.return_value = mock_database
    mock_database.list_collection_names.return_value = ["events"]
    mock_collection = MagicMock()
    mock_database.__getitem__.return_value = mock_collection

    mock_documents = [
        {
            "_id": bson.ObjectId("507f1f77bcf86cd799439011"),
            "metadata": {
                "source": "web",
                # Array elements with uneven shapes: every element has
                # ``name`` but only the first has ``value``.
                "properties": [
                    {"name": "user_agent", "value": "Mozilla/5.0"},
                    {"name": "ip_address2"},
                    {"name": "ip_address3"},
                    {"name": "ip_address4"},
                ],
                "tags": {"foo": "bar"},
            },
        },
        {
            "_id": bson.ObjectId("507f1f77bcf86cd799439012"),
        },
    ]
    mock_collection.aggregate.return_value = mock_documents

    config = MongoDBConfig(
        connect_uri="mongodb://localhost:27017",
        enableSchemaInference=True,
        maxSchemaSize=3,
    )
    source = MongoDBSource(ctx=pipeline_context, config=config)
    workunits = list(source.get_workunits_internal())

    schema_metadata_aspects = [
        aspect
        for wu in workunits
        if (aspect := wu.get_aspect_of_type(SchemaMetadataClass))
    ]
    assert len(schema_metadata_aspects) == 1
    schema_metadata = schema_metadata_aspects[0]
    field_paths = [field.fieldPath for field in schema_metadata.fields]
    # Order matters: parents must precede their children after trimming.
    assert field_paths == [
        "metadata",
        "metadata.properties",
        "metadata.properties.name",
    ]

    dataset_properties_aspects = [
        properties_aspect
        for wu in workunits
        if (properties_aspect := wu.get_aspect_of_type(DatasetPropertiesClass))
    ]
    assert len(dataset_properties_aspects) == 1
    custom_props = dataset_properties_aspects[0].customProperties
    assert "schema.downsampled" in custom_props
    assert custom_props["schema.downsampled"] == "True"
def test_mongodb_schema_inference_with_deeply_nested_structures(
    mock_mongo_client, pipeline_context
):
    """
    Test schema inference on documents with nested objects and arrays of objects.

    The two sampled documents differ in shape (only the first carries
    ``metadata.tags``). The inferred schema must contain the union of every
    path seen: the nested objects, the array of objects, and their leaf
    fields, plus ``_id``.
    """
    db = MagicMock()
    coll = MagicMock()
    mock_mongo_client.list_database_names.return_value = ["test_db"]
    mock_mongo_client.__getitem__.return_value = db
    db.list_collection_names.return_value = ["events"]
    db.__getitem__.return_value = coll

    web_event = {
        "_id": bson.ObjectId("507f1f77bcf86cd799439011"),
        "metadata": {
            "source": "web",
            "properties": [
                {"name": "user_agent", "value": "Mozilla/5.0"},
                {"name": "ip_address", "value": "192.168.1.1"},
            ],
            "tags": {"foo": "bar"},
        },
    }
    mobile_event = {
        "_id": bson.ObjectId("507f1f77bcf86cd799439012"),
        "metadata": {
            "source": "mobile",
            "properties": [
                {"name": "device_id", "value": "abc123"},
            ],
        },
    }
    coll.aggregate.return_value = [web_event, mobile_event]

    source = MongoDBSource(
        ctx=pipeline_context,
        config=MongoDBConfig(
            connect_uri="mongodb://localhost:27017", enableSchemaInference=True
        ),
    )
    units = list(source.get_workunits_internal())

    schemas = [
        schema
        for schema in (wu.get_aspect_of_type(SchemaMetadataClass) for wu in units)
        if schema
    ]
    assert len(schemas) == 1

    observed = {f.fieldPath for f in schemas[0].fields}
    assert observed == {
        "metadata",
        "metadata.source",
        "metadata.properties",
        "metadata.properties.name",
        "metadata.properties.value",
        "metadata.tags",
        "metadata.tags.foo",
        "_id",
    }
def test_mongodb_schema_inference_disabled(mock_mongo_client, pipeline_context):
    """
    Test that no SchemaMetadata aspect is emitted when enableSchemaInference=False.

    The mocked collection still returns nested documents, so an empty result
    is due to the flag rather than to missing data.
    """
    db = MagicMock()
    coll = MagicMock()
    mock_mongo_client.list_database_names.return_value = ["test_db"]
    mock_mongo_client.__getitem__.return_value = db
    db.list_collection_names.return_value = ["collection1"]
    db.__getitem__.return_value = coll

    coll.aggregate.return_value = [
        {
            "_id": bson.ObjectId("507f1f77bcf86cd799439011"),
            "metadata": {
                "source": "web",
                "properties": [
                    {"name": "user_agent", "value": "Mozilla/5.0"},
                    {"name": "ip_address", "value": "192.168.1.1"},
                ],
                "tags": {"foo": "bar"},
            },
        },
        {
            "_id": bson.ObjectId("507f1f77bcf86cd799439012"),
            "metadata": {
                "source": "mobile",
                "properties": [
                    {"name": "device_id", "value": "abc123"},
                ],
            },
        },
    ]

    source = MongoDBSource(
        ctx=pipeline_context,
        config=MongoDBConfig(
            connect_uri="mongodb://localhost:27017", enableSchemaInference=False
        ),
    )
    units = list(source.get_workunits_internal())

    schemas = []
    for wu in units:
        schema = wu.get_aspect_of_type(SchemaMetadataClass)
        if schema:
            schemas.append(schema)
    assert schemas == []
def test_mongodb_schema_field_ordering_with_arrays(mock_mongo_client, pipeline_context):
    """
    Test that an array-of-objects field is listed before its nested children.

    For the ``items`` array, the ``items`` path must come earlier in the
    schema field list than ``items.sku``, ``items.quantity`` and
    ``items.price``. Parent-before-child ordering keeps nested fields
    attached to their container in the UI, especially after trimming.
    """
    db = MagicMock()
    coll = MagicMock()
    mock_mongo_client.list_database_names.return_value = ["test_db"]
    mock_mongo_client.__getitem__.return_value = db
    db.list_collection_names.return_value = ["orders"]
    db.__getitem__.return_value = coll

    first_order = {
        "_id": bson.ObjectId("507f1f77bcf86cd799439011"),
        "order_id": "ORD-001",
        "items": [
            {"sku": "ITEM-A", "quantity": 2, "price": 10.50},
            {"sku": "ITEM-B", "quantity": 1, "price": 25.00},
        ],
    }
    second_order = {
        "_id": bson.ObjectId("507f1f77bcf86cd799439012"),
        "order_id": "ORD-002",
        "items": [
            {"sku": "ITEM-C", "quantity": 3, "price": 15.75},
        ],
    }
    coll.aggregate.return_value = [first_order, second_order]

    source = MongoDBSource(
        ctx=pipeline_context,
        config=MongoDBConfig(
            connect_uri="mongodb://localhost:27017", enableSchemaInference=True
        ),
    )
    units = list(source.get_workunits_internal())

    schemas = [
        schema
        for schema in (wu.get_aspect_of_type(SchemaMetadataClass) for wu in units)
        if schema
    ]
    assert len(schemas) == 1

    paths = [f.fieldPath for f in schemas[0].fields]
    # Resolve every position up front (a missing path raises ValueError here).
    parent_pos = paths.index("items")
    child_positions = [
        paths.index(child) for child in ("items.sku", "items.quantity", "items.price")
    ]
    for child_pos in child_positions:
        assert parent_pos < child_pos
def test_mongodb_database_filtering(mock_mongo_client, pipeline_context):
    """
    Test database_pattern filtering.

    Only databases matching ``^prod_.*`` get a container emitted. Non-matching
    user databases are recorded in ``report.filtered``. The internal
    ``admin``, ``config`` and ``local`` databases, as well as the allowed
    databases, must not appear there.
    """
    mock_mongo_client.list_database_names.return_value = [
        "prod_db",
        "prod_instances",
        "test_db",
        "dev_db",
        "admin",
        "config",
        "local",
    ]
    db = MagicMock()
    mock_mongo_client.__getitem__.return_value = db
    db.list_collection_names.return_value = []

    source = MongoDBSource(
        ctx=pipeline_context,
        config=MongoDBConfig(
            connect_uri="mongodb://localhost:27017",
            enableSchemaInference=True,
            database_pattern={"allow": ["^prod_.*"]},
        ),
    )
    units = list(source.get_workunits_internal())

    container_names: List[str] = []
    for wu in units:
        props = wu.get_aspect_of_type(ContainerPropertiesClass)
        if props:
            container_names.append(props.name)
    assert container_names == ["prod_db", "prod_instances"]

    filtered = list(source.report.filtered)
    # Denied by the pattern: reported as filtered.
    for name in ("test_db", "dev_db"):
        assert name in filtered
    # Internal databases (skipped by default) and allowed databases: not reported.
    for name in ("admin", "config", "local", "prod_db", "prod_instances"):
        assert name not in filtered
def test_mongodb_collection_filtering(mock_mongo_client, pipeline_context):
    """
    Test collection_pattern deny rules.

    Collections whose ``<db>.<collection>`` name matches the deny regex
    (``temp_*`` / ``cache_*`` in ``test_db``) must not yield dataset
    metadata and must be listed in ``report.filtered``. The remaining
    collections must yield dataset URNs.
    """
    db = MagicMock()
    coll = MagicMock()
    mock_mongo_client.list_database_names.return_value = ["test_db"]
    mock_mongo_client.__getitem__.return_value = db
    db.list_collection_names.return_value = [
        "users",
        "orders",
        "temp_data",
        "cache_entries",
    ]
    db.__getitem__.return_value = coll
    coll.aggregate.return_value = []

    source = MongoDBSource(
        ctx=pipeline_context,
        config=MongoDBConfig(
            connect_uri="mongodb://localhost:27017",
            enableSchemaInference=True,
            collection_pattern={"deny": ["test_db\\.(temp_.*|cache_.*)"]},
        ),
    )
    units = list(source.get_workunits_internal())

    # Collect the entity URNs of every dataset-typed metadata change proposal.
    seen_dataset_urns = set()
    for wu in units:
        md = wu.metadata
        if not isinstance(md, (MetadataChangeProposal, MetadataChangeProposalWrapper)):
            continue
        urn = md.entityUrn
        if urn and guess_entity_type(urn) == "dataset":
            seen_dataset_urns.add(urn)

    assert seen_dataset_urns == {
        "urn:li:dataset:(urn:li:dataPlatform:mongodb,test_db.orders,PROD)",
        "urn:li:dataset:(urn:li:dataPlatform:mongodb,test_db.users,PROD)",
    }

    filtered = list(source.report.filtered)
    for denied in ("test_db.temp_data", "test_db.cache_entries"):
        assert denied in filtered