feat(s3/ingest): performance improvements for get_dir_to_process and get_folder_info (#14709)

Authored by Michael Maltese on 2025-10-02 09:51:02 -04:00; committed by GitHub
parent f7ea7f033d
commit 5da54bf14d
8 changed files with 368 additions and 185 deletions

View File

@@ -126,7 +126,10 @@ def list_folders_path(
 def list_objects_recursive_path(
-    s3_uri: str, *, startswith: str, aws_config: Optional[AwsConnectionConfig]
+    s3_uri: str,
+    *,
+    startswith: str = "",
+    aws_config: Optional[AwsConnectionConfig] = None,
 ) -> Iterable["ObjectSummary"]:
     """
     Given an S3 URI to a folder or bucket, return all objects underneath that URI, optionally
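
A minimal usage sketch of the relaxed signature (the bucket names and prefixes below are illustrative, and omitting aws_config assumes the helper falls back to default AWS credential resolution):

    from datahub.ingestion.source.aws.s3_boto_utils import list_objects_recursive_path

    # List every object under a folder URI (no prefix filter).
    for obj in list_objects_recursive_path("s3://example-bucket/folder_a/"):
        print(obj.key, obj.size)

    # Narrow the listing to keys under the prefix that start with "part".
    for obj in list_objects_recursive_path(
        "s3://example-bucket/folder_a/table_x/year=2022/",
        startswith="part",
    ):
        print(obj.key)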

View File

@@ -194,6 +194,9 @@ class PathSpec(ConfigModel):
         return True

     def dir_allowed(self, path: str) -> bool:
+        if not path.endswith("/"):
+            path += "/"
         if self.glob_include.endswith("**"):
             return self.allowed(path, ignore_ext=True)
@@ -221,9 +224,8 @@ class PathSpec(ConfigModel):
         ):
             return False

-        file_name_pattern = self.include.rsplit("/", 1)[1]
         table_name, _ = self.extract_table_name_and_path(
-            os.path.join(path, file_name_pattern)
+            path + self.get_remaining_glob_include(path)
         )
         if not self.tables_filter_pattern.allowed(table_name):
             return False
@@ -571,3 +573,38 @@ class PathSpec(ConfigModel):
                 "/".join(path.split("/")[:depth]) + "/" + parsed_vars.named["table"]
             )
         return self._extract_table_name(parsed_vars.named), table_path
+
+    def has_correct_number_of_directory_components(self, path: str) -> bool:
+        """
+        Checks that a given path has the same number of components as the path spec
+        has directory components. Useful for checking if a path needs to descend further
+        into child directories or if the source can switch into file listing mode. If the
+        glob form of the path spec ends in "**", this always returns False.
+        """
+        if self.glob_include.endswith("**"):
+            return False
+        if not path.endswith("/"):
+            path += "/"
+        path_slash = path.count("/")
+        glob_slash = self.glob_include.count("/")
+        if path_slash == glob_slash:
+            return True
+        return False
+
+    def get_remaining_glob_include(self, path: str) -> str:
+        """
+        Given a path, return the remaining components of the path spec (if any
+        exist) in glob form. If the glob form of the path spec ends in "**", this
+        function's return value also always ends in "**", regardless of how
+        many components the input path has.
+        """
+        if not path.endswith("/"):
+            path += "/"
+        path_slash = path.count("/")
+        remainder = "/".join(self.glob_include.split("/")[path_slash:])
+        if remainder:
+            return remainder
+        if self.glob_include.endswith("**"):
+            return "**"
+        return ""

View File

@@ -3,14 +3,14 @@ import functools
 import logging
 import os
 import pathlib
+import posixpath
 import re
 import time
 from datetime import datetime
 from pathlib import PurePath
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple

 import smart_open.compression as so_compression
-from more_itertools import peekable
 from pyspark.conf import SparkConf
 from pyspark.sql import SparkSession
 from pyspark.sql.dataframe import DataFrame
@@ -36,9 +36,7 @@ from datahub.ingestion.api.source import MetadataWorkUnitProcessor
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source.aws.s3_boto_utils import (
     get_s3_tags,
-    list_folders,
     list_folders_path,
-    list_objects_recursive,
     list_objects_recursive_path,
 )
 from datahub.ingestion.source.aws.s3_util import (
@@ -83,9 +81,6 @@ from datahub.metadata.schema_classes import (
 from datahub.telemetry import stats, telemetry
 from datahub.utilities.perf_timer import PerfTimer

-if TYPE_CHECKING:
-    from mypy_boto3_s3.service_resource import Bucket
-
 # hide annoying debug errors from py4j
 logging.getLogger("py4j").setLevel(logging.ERROR)
 logger: logging.Logger = logging.getLogger(__name__)
@@ -872,35 +867,44 @@ class S3Source(StatefulIngestionSourceBase):
     def get_dir_to_process(
         self,
-        bucket_name: str,
-        folder: str,
+        uri: str,
         path_spec: PathSpec,
-        protocol: str,
         min: bool = False,
     ) -> List[str]:
-        # if len(path_spec.include.split("/")) == len(f"{protocol}{bucket_name}/{folder}".split("/")):
-        #     return [f"{protocol}{bucket_name}/{folder}"]
+        # Add any remaining parts of the path_spec before globs, excluding the
+        # final filename component, to the URI and prefix so that we don't
+        # unnecessarily list too many objects.
+        if not uri.endswith("/"):
+            uri += "/"
+        remaining = posixpath.dirname(path_spec.get_remaining_glob_include(uri)).split(
+            "*"
+        )[0]
+        uri += posixpath.dirname(remaining)
+        prefix = posixpath.basename(remaining)

-        iterator = list_folders(
-            bucket_name=bucket_name,
-            prefix=folder,
+        # Check if we're at the end of the include path. If so, no need to list sub-folders.
+        if path_spec.has_correct_number_of_directory_components(uri):
+            return [uri]
+
+        logger.debug(f"get_dir_to_process listing folders {uri=} {prefix=}")
+        iterator = list_folders_path(
+            s3_uri=uri,
+            startswith=prefix,
             aws_config=self.source_config.aws_config,
         )
-        iterator = peekable(iterator)
-        if iterator:
         sorted_dirs = sorted(
             iterator,
-            key=functools.cmp_to_key(partitioned_folder_comparator),
+            key=lambda dir: functools.cmp_to_key(partitioned_folder_comparator)(
+                dir.name
+            ),
             reverse=not min,
         )
         folders = []
         for dir in sorted_dirs:
-            if path_spec.dir_allowed(f"{protocol}{bucket_name}/{dir}/"):
+            if path_spec.dir_allowed(dir.path):
                 folders_list = self.get_dir_to_process(
-                    bucket_name=bucket_name,
-                    folder=dir + "/",
+                    uri=dir.path,
                     path_spec=path_spec,
-                    protocol=protocol,
                     min=min,
                 )
                 folders.extend(folders_list)
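
A worked example of the prefix narrowing above, with illustrative values that follow directly from the code (include = "s3://bucket/{table}/year=2022/month={month}/*.json", incoming uri = "s3://bucket/tbl/"):

    # path_spec.get_remaining_glob_include(uri)        -> "year=2022/month=*/*.json"
    # posixpath.dirname(...)                           -> "year=2022/month=*"
    # ...split("*")[0]                                 -> "year=2022/month="
    # uri += posixpath.dirname("year=2022/month=")     -> "s3://bucket/tbl/year=2022"
    # prefix = posixpath.basename("year=2022/month=")  -> "month="

list_folders_path() then only has to enumerate folders beginning with "month=" directly under the literal year=2022 component, rather than every folder under the table.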
@@ -909,18 +913,16 @@ class S3Source(StatefulIngestionSourceBase):
         if folders:
             return folders
         else:
-            return [f"{protocol}{bucket_name}/{folder}"]
-        return [f"{protocol}{bucket_name}/{folder}"]
+            return [uri]

     def get_folder_info(
         self,
         path_spec: PathSpec,
-        bucket: "Bucket",
-        prefix: str,
+        uri: str,
     ) -> Iterable[Folder]:
         """
-        Retrieves all the folders in a path by listing all the files in the prefix.
-        If the prefix is a full path then only that folder will be extracted.
+        Retrieves all the folders in a path by recursively listing all the files under the
+        given URI.

         A folder has creation and modification times, size, and a sample file path.
         - Creation time is the earliest creation time of all files in the folder.
@@ -930,8 +932,7 @@ class S3Source(StatefulIngestionSourceBase):
         Parameters:
         path_spec (PathSpec): The path specification used to determine partitioning.
-        bucket (Bucket): The S3 bucket object.
-        prefix (str): The prefix path in the S3 bucket to list objects from.
+        uri (str): The path in the S3 bucket to list objects from.

         Returns:
         List[Folder]: A list of Folder objects representing the partitions found.
@@ -947,12 +948,22 @@ class S3Source(StatefulIngestionSourceBase):
                 self.report.report_file_dropped(s3_uri)
             return allowed

+        # Add any remaining parts of the path_spec before globs to the URI and prefix,
+        # so that we don't unnecessarily list too many objects.
+        if not uri.endswith("/"):
+            uri += "/"
+        remaining = path_spec.get_remaining_glob_include(uri).split("*")[0]
+        uri += posixpath.dirname(remaining)
+        prefix = posixpath.basename(remaining)
+
         # Process objects in a memory-efficient streaming fashion
         # Instead of loading all objects into memory, we'll accumulate folder data incrementally
         folder_data: Dict[str, FolderInfo] = {}  # dirname -> FolderInfo

-        for obj in list_objects_recursive(
-            bucket.name, prefix, self.source_config.aws_config
+        logger.info(f"Listing objects under {repr(uri)} with {prefix=}")
+        for obj in list_objects_recursive_path(
+            uri, startswith=prefix, aws_config=self.source_config.aws_config
         ):
             s3_path = self.create_s3_path(obj.bucket_name, obj.key)
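
The same narrowing applies to the file-listing case; a worked example with illustrative values (remaining glob "part*.json" under the partition folder "s3://bucket/tbl/year=2022/month=jan/"):

    # path_spec.get_remaining_glob_include(uri)  -> "part*.json"
    # ...split("*")[0]                           -> "part"
    # uri += posixpath.dirname("part")           -> uri unchanged
    # prefix = posixpath.basename("part")        -> "part"

so the recursive object listing only returns keys that start with "part" instead of every object in the partition.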
@@ -1047,7 +1058,7 @@ class S3Source(StatefulIngestionSourceBase):
             # This creates individual file-level datasets
             yield from self._process_simple_path(path_spec)

-    def _process_templated_path(self, path_spec: PathSpec) -> Iterable[BrowsePath]:  # noqa: C901
+    def _process_templated_path(self, path_spec: PathSpec) -> Iterable[BrowsePath]:
         """
         Process S3 paths containing {table} templates to create table-level datasets.
@@ -1133,20 +1144,12 @@ class S3Source(StatefulIngestionSourceBase):
             # STEP 4: Process each table folder to create a table-level dataset
             for folder in table_folders:
-                bucket_name = get_bucket_name(folder.path)
-                table_folder = get_bucket_relative_path(folder.path)
-                bucket = s3.Bucket(bucket_name)
-
-                # Create the full S3 path for this table
-                table_s3_path = self.create_s3_path(bucket_name, table_folder)
-                logger.info(
-                    f"Processing table folder: {table_folder} -> {table_s3_path}"
-                )
+                logger.info(f"Processing table path: {folder.path}")

                 # Extract table name using the ORIGINAL path spec pattern matching (not the modified one)
                 # This uses the compiled regex pattern to extract the table name from the full path
                 table_name, _ = self.extract_table_name_and_path(
-                    path_spec, table_s3_path
+                    path_spec, folder.path
                 )

                 # Apply table name filtering if configured
@@ -1155,90 +1158,52 @@ class S3Source(StatefulIngestionSourceBase):
                     continue

                 # STEP 5: Handle partition traversal based on configuration
-                # Get all partition folders first
-                all_partition_folders = list(
-                    list_folders(
-                        bucket_name, table_folder, self.source_config.aws_config
-                    )
-                )
-                logger.info(
-                    f"Found {len(all_partition_folders)} partition folders under table {table_name} using method {path_spec.traversal_method}"
-                )
-
-                if all_partition_folders:
-                    # Apply the same traversal logic as the original code
                 dirs_to_process = []
                 if path_spec.traversal_method == FolderTraversalMethod.ALL:
                     # Process ALL partitions (original behavior)
-                    dirs_to_process = all_partition_folders
+                    dirs_to_process = [folder.path]
                     logger.debug(
-                        f"Processing ALL {len(all_partition_folders)} partitions"
+                        f"Processing ALL partition folders under: {folder.path}"
                     )
                 else:
                     # Use the original get_dir_to_process logic for MIN/MAX
-                    protocol = "s3://"  # Default protocol for S3
                     if (
-                        path_spec.traversal_method
-                        == FolderTraversalMethod.MIN_MAX
-                        or path_spec.traversal_method
-                        == FolderTraversalMethod.MAX
+                        path_spec.traversal_method == FolderTraversalMethod.MIN_MAX
+                        or path_spec.traversal_method == FolderTraversalMethod.MAX
                     ):
                         # Get MAX partition using original logic
                         dirs_to_process_max = self.get_dir_to_process(
-                            bucket_name=bucket_name,
-                            folder=table_folder + "/",
+                            uri=folder.path,
                             path_spec=path_spec,
-                            protocol=protocol,
                             min=False,
                         )
                         if dirs_to_process_max:
-                            # Convert full S3 paths back to relative paths for processing
-                            dirs_to_process.extend(
-                                [
-                                    d.replace(f"{protocol}{bucket_name}/", "")
-                                    for d in dirs_to_process_max
-                                ]
-                            )
+                            dirs_to_process.extend(dirs_to_process_max)
                             logger.debug(
                                 f"Added MAX partition: {dirs_to_process_max}"
                             )
-                    if (
-                        path_spec.traversal_method
-                        == FolderTraversalMethod.MIN_MAX
-                    ):
+                    if path_spec.traversal_method == FolderTraversalMethod.MIN_MAX:
                         # Get MIN partition using original logic
                         dirs_to_process_min = self.get_dir_to_process(
-                            bucket_name=bucket_name,
-                            folder=table_folder + "/",
+                            uri=folder.path,
                             path_spec=path_spec,
-                            protocol=protocol,
                             min=True,
                         )
                         if dirs_to_process_min:
-                            # Convert full S3 paths back to relative paths for processing
-                            dirs_to_process.extend(
-                                [
-                                    d.replace(f"{protocol}{bucket_name}/", "")
-                                    for d in dirs_to_process_min
-                                ]
-                            )
+                            dirs_to_process.extend(dirs_to_process_min)
                             logger.debug(
                                 f"Added MIN partition: {dirs_to_process_min}"
                             )

                 # Process the selected partitions
                 all_folders = []
-                for partition_folder in dirs_to_process:
-                    # Ensure we have a clean folder path
-                    clean_folder = partition_folder.rstrip("/")
-                    logger.info(f"Scanning files in partition: {clean_folder}")
+                for partition_path in dirs_to_process:
+                    logger.info(f"Scanning files in partition: {partition_path}")

                     partition_files = list(
-                        self.get_folder_info(path_spec, bucket, clean_folder)
+                        self.get_folder_info(path_spec, partition_path)
                     )
                     all_folders.extend(partition_files)
@@ -1267,10 +1232,6 @@ class S3Source(StatefulIngestionSourceBase):
                     logger.warning(
                         f"No files found in processed partitions for table {table_name}"
                     )
-                else:
-                    logger.warning(
-                        f"No partition folders found under table {table_name}"
-                    )

         except Exception as e:
             if isinstance(e, s3.meta.client.exceptions.NoSuchBucket):

View File

@@ -8,13 +8,13 @@
       "json": {
         "customProperties": {
           "schema_inferred_from": "gs://my-test-bucket/folder_a/folder_aa/folder_aaa/folder_aaaa/pokemon_abilities_yearwise_2021/month=march/part2.json",
-          "number_of_partitions": "6"
+          "number_of_partitions": "1"
         },
         "externalUrl": "https://console.cloud.google.com/storage/browser/my-test-bucket/folder_a/folder_aa/folder_aaa/folder_aaaa",
         "name": "folder_aaaa",
         "description": "",
         "created": {
-          "time": 1586847680000
+          "time": 1586847780000
         },
         "lastModified": {
           "time": 1586847790000
@@ -628,9 +628,9 @@
     "aspect": {
       "json": {
         "minPartition": {
-          "partition": "partition_0=pokemon_abilities_yearwise_2019/partition_1=month=feb",
-          "createdTime": 1586847680000,
-          "lastModifiedTime": 1586847690000
+          "partition": "partition_0=pokemon_abilities_yearwise_2021/partition_1=month=march",
+          "createdTime": 1586847780000,
+          "lastModifiedTime": 1586847790000
         },
         "maxPartition": {
           "partition": "partition_0=pokemon_abilities_yearwise_2021/partition_1=month=march",

View File

@@ -8,13 +8,13 @@
       "json": {
         "customProperties": {
          "schema_inferred_from": "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/folder_aaaa/pokemon_abilities_yearwise_2021/month=march/part2.json",
-          "number_of_partitions": "6"
+          "number_of_partitions": "1"
         },
         "externalUrl": "https://us-east-1.console.aws.amazon.com/s3/buckets/my-test-bucket?prefix=folder_a/folder_aa/folder_aaa/folder_aaaa",
         "name": "folder_aaaa",
         "description": "",
         "created": {
-          "time": 1586847680000
+          "time": 1586847780000
         },
         "lastModified": {
           "time": 1586847790000
@@ -628,9 +628,9 @@
     "aspect": {
       "json": {
         "minPartition": {
-          "partition": "partition_0=pokemon_abilities_yearwise_2019/partition_1=month=feb",
-          "createdTime": 1586847680000,
-          "lastModifiedTime": 1586847690000
+          "partition": "partition_0=pokemon_abilities_yearwise_2021/partition_1=month=march",
+          "createdTime": 1586847780000,
+          "lastModifiedTime": 1586847790000
         },
         "maxPartition": {
           "partition": "partition_0=pokemon_abilities_yearwise_2021/partition_1=month=march",

View File

@@ -2,7 +2,7 @@ import json
 import logging
 import os
 from datetime import datetime
-from unittest.mock import patch
+from unittest.mock import Mock, call, patch

 import moto.s3
 import pytest
@@ -11,6 +11,11 @@ from moto import mock_s3
 from pydantic import ValidationError

 from datahub.ingestion.run.pipeline import Pipeline, PipelineContext
+from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
+from datahub.ingestion.source.aws.s3_boto_utils import (
+    list_folders_path,
+    list_objects_recursive_path,
+)
 from datahub.ingestion.source.s3.source import S3Source
 from datahub.testing import mce_helpers
@@ -367,3 +372,158 @@ def test_data_lake_incorrect_config_raises_error(tmp_path, mock_time):
     }
     with pytest.raises(ValidationError, match=r"\*\*"):
         S3Source.create(source, ctx)
+
+
+@pytest.mark.parametrize(
+    "calls_test_tuple",
+    [
+        (
+            "partitions_and_filename_with_prefix",
+            {
+                "include": "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/{table}/year={year}/month={month}/part*.json",
+                "tables_filter_pattern": {"allow": ["^pokemon_abilities_json$"]},
+            },
+            [
+                call.list_folders_path(
+                    "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/"
+                ),
+                call.list_folders_path(
+                    s3_uri="s3://my-test-bucket/folder_a/folder_aa/folder_aaa/pokemon_abilities_json/",
+                    startswith="year=",
+                ),
+                call.list_folders_path(
+                    s3_uri="s3://my-test-bucket/folder_a/folder_aa/folder_aaa/pokemon_abilities_json/year=2022/",
+                    startswith="month=",
+                ),
+                call.list_objects_recursive_path(
+                    "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/pokemon_abilities_json/year=2022/month=jan/",
+                    startswith="part",
+                ),
+            ],
+        ),
+        (
+            "filter_specific_partition",
+            {
+                "include": "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/{table}/year=2022/month={month}/*.json",
+                "tables_filter_pattern": {"allow": ["^pokemon_abilities_json$"]},
+            },
+            [
+                call.list_folders_path(
+                    "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/"
+                ),
+                call.list_folders_path(
+                    s3_uri="s3://my-test-bucket/folder_a/folder_aa/folder_aaa/pokemon_abilities_json/year=2022",
+                    startswith="month=",
+                ),
+                call.list_objects_recursive_path(
+                    "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/pokemon_abilities_json/year=2022/month=jan/",
+                    startswith="",
+                ),
+            ],
+        ),
+        (
+            "partition_autodetection",
+            {
+                "include": "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/{table}/",
+                "tables_filter_pattern": {"allow": ["^pokemon_abilities_json$"]},
+            },
+            [
+                call.list_folders_path(
+                    "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/"
+                ),
+                call.list_folders_path(
+                    s3_uri="s3://my-test-bucket/folder_a/folder_aa/folder_aaa/pokemon_abilities_json/",
+                    startswith="",
+                ),
+                call.list_folders_path(
+                    s3_uri="s3://my-test-bucket/folder_a/folder_aa/folder_aaa/pokemon_abilities_json/year=2022/",
+                    startswith="",
+                ),
+                call.list_folders_path(
+                    s3_uri="s3://my-test-bucket/folder_a/folder_aa/folder_aaa/pokemon_abilities_json/year=2022/month=jan/",
+                    startswith="",
+                ),
+                call.list_objects_recursive_path(
+                    "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/pokemon_abilities_json/year=2022/month=jan/",
+                    startswith="",
+                ),
+            ],
+        ),
+        (
+            "partitions_traversal_all",
+            {
+                "include": "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/{table}/year={year}/month={month}/*.json",
+                "tables_filter_pattern": {"allow": ["^pokemon_abilities_json$"]},
+                "traversal_method": "ALL",
+            },
+            [
+                call.list_folders_path(
+                    "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/"
+                ),
+                call.list_objects_recursive_path(
+                    "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/pokemon_abilities_json/",
+                    startswith="year=",
+                ),
+            ],
+        ),
+        (
+            "filter_specific_partition_traversal_all",
+            {
+                "include": "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/{table}/year=2022/month={month}/part*.json",
+                "tables_filter_pattern": {"allow": ["^pokemon_abilities_json$"]},
+                "traversal_method": "ALL",
+            },
+            [
+                call.list_folders_path(
+                    "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/"
+                ),
+                call.list_objects_recursive_path(
+                    "s3://my-test-bucket/folder_a/folder_aa/folder_aaa/pokemon_abilities_json/year=2022",
+                    startswith="month=",
+                ),
+            ],
+        ),
+    ],
+    ids=lambda calls_test_tuple: calls_test_tuple[0],
+)
+def test_data_lake_s3_calls(s3_populate, calls_test_tuple):
+    _, path_spec, expected_calls = calls_test_tuple
+
+    ctx = PipelineContext(run_id="test-s3")
+    config = {
+        "path_specs": [path_spec],
+        "aws_config": {
+            "aws_region": "us-east-1",
+            "aws_access_key_id": "testing",
+            "aws_secret_access_key": "testing",
+        },
+    }
+    source = S3Source.create(config, ctx)
+
+    m = Mock()
+    m.list_folders_path.side_effect = list_folders_path
+    m.list_objects_recursive_path.side_effect = list_objects_recursive_path
+    with (
+        patch(
+            "datahub.ingestion.source.s3.source.list_folders_path", m.list_folders_path
+        ),
+        patch(
+            "datahub.ingestion.source.s3.source.list_objects_recursive_path",
+            m.list_objects_recursive_path,
+        ),
+    ):
+        for _ in source.get_workunits_internal():
+            pass
+
+    # Verify S3 calls. We're checking that we make the minimum necessary calls with
+    # prefixes when possible to reduce the amount of queries to the S3 API.
+    calls = []
+    for c in m.mock_calls:
+        if isinstance(c.kwargs, dict):  # type assertion
+            c.kwargs.pop("aws_config", None)
+        if len(c.args) == 3 and isinstance(c.args[2], AwsConnectionConfig):
+            c = getattr(call, c[0])(*(c.args[:2]), **c.kwargs)
+        calls.append(c)
+    assert calls == expected_calls

View File

@@ -331,6 +331,25 @@ def test_dir_allowed_with_debug_logging() -> None:
     assert result is True


+@pytest.mark.parametrize(
+    "include, allowed",
+    [
+        ("s3://bucket/{table}/1/*.json", "s3://bucket/table/1/"),
+        ("s3://bucket/{table}/1/*/*.json", "s3://bucket/table/1/"),
+        ("s3://bucket/{table}/1/*/*.json", "s3://bucket/table/1/2/"),
+    ],
+)
+def test_dir_allowed_with_table_filter_pattern(include: str, allowed: str) -> None:
+    """Test dir_allowed method with table filter patterns."""
+    path_spec = PathSpec(
+        include=include,
+        tables_filter_pattern=AllowDenyPattern(
+            allow=["^table$"],
+        ),
+    )
+    assert path_spec.dir_allowed(allowed) is True
+
+
 # Tests for get_parsable_include classmethod
 @pytest.mark.parametrize(
     "include, expected",

View File

@@ -1,7 +1,7 @@
 import logging
 from datetime import datetime, timezone
 from typing import List, Tuple
-from unittest.mock import Mock, call
+from unittest.mock import Mock, call, patch

 import pytest
 from boto3.session import Session
@@ -314,7 +314,8 @@ def test_get_folder_info_returns_latest_file_in_each_folder(s3_resource):
     # act
     res = _get_s3_source(path_spec).get_folder_info(
-        path_spec, bucket, prefix="my-folder"
+        path_spec,
+        "s3://my-bucket/my-folder",
     )
     res = list(res)
@@ -329,12 +330,9 @@ def test_get_folder_info_ignores_disallowed_path(s3_resource, caplog):
     Test S3Source.get_folder_info skips disallowed files and logs a message
     """
     # arrange
-    path_spec = Mock(
-        spec=PathSpec,
+    path_spec = PathSpec(
         include="s3://my-bucket/{table}/{partition0}/*.csv",
-        table_name="{table}",
     )
-    path_spec.allowed = Mock(return_value=False)

     bucket = s3_resource.Bucket("my-bucket")
     bucket.create()
@@ -343,13 +341,17 @@ def test_get_folder_info_ignores_disallowed_path(s3_resource, caplog):
     s3_source = _get_s3_source(path_spec)

     # act
-    res = s3_source.get_folder_info(path_spec, bucket, prefix="my-folder")
+    with patch(
+        "datahub.ingestion.source.data_lake_common.path_spec.PathSpec.allowed"
+    ) as allowed:
+        allowed.return_value = False
+        res = s3_source.get_folder_info(path_spec, "s3://my-bucket/my-folder")
         res = list(res)

     # assert
     expected_called_s3_uri = "s3://my-bucket/my-folder/ignore/this/path/0001.csv"
-    assert path_spec.allowed.call_args_list == [call(expected_called_s3_uri)], (
+    assert allowed.call_args_list == [call(expected_called_s3_uri)], (
         "File should be checked if it's allowed"
     )
     assert f"File {expected_called_s3_uri} not allowed and skipping" in caplog.text, (
@@ -377,7 +379,8 @@ def test_get_folder_info_returns_expected_folder(s3_resource):
     # act
     res = _get_s3_source(path_spec).get_folder_info(
-        path_spec, bucket, prefix="my-folder"
+        path_spec,
+        "s3://my-bucket/my-folder",
     )
     res = list(res)