refactor(ingest/s3): enhance readability (#12686)

Austin SeungJun Park 2025-02-28 13:19:46 +00:00 committed by GitHub
parent 9c3bd34995
commit e65f133667
2 changed files with 70 additions and 40 deletions


@@ -847,7 +847,7 @@ class S3Source(StatefulIngestionSourceBase):
         path_spec: PathSpec,
         bucket: "Bucket",
         prefix: str,
-    ) -> List[Folder]:
+    ) -> Iterable[Folder]:
         """
         Retrieves all the folders in a path by listing all the files in the prefix.
         If the prefix is a full path then only that folder will be extracted.
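Note: with the return type changed from List[Folder] to Iterable[Folder], get_folder_info becomes a generator: it does no listing work until iterated, and it can only be consumed once. A minimal consumption sketch (source, path_spec, and bucket stand for objects configured elsewhere; the print is illustrative):

    # The generator is lazy; materialize it when a list or len() is needed,
    # as the updated tests do with res = list(res).
    folders = list(source.get_folder_info(path_spec, bucket, prefix="my-folder/"))
    for folder in folders:
        print(folder.sample_file, folder.size)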
@@ -877,51 +877,30 @@ class S3Source(StatefulIngestionSourceBase):
         s3_objects = (
             obj
             for obj in bucket.objects.filter(Prefix=prefix).page_size(PAGE_SIZE)
-            if _is_allowed_path(path_spec, f"s3://{obj.bucket_name}/{obj.key}")
+            if _is_allowed_path(
+                path_spec, self.create_s3_path(obj.bucket_name, obj.key)
+            )
         )
 
-        partitions: List[Folder] = []
         grouped_s3_objects_by_dirname = groupby_unsorted(
             s3_objects,
             key=lambda obj: obj.key.rsplit("/", 1)[0],
         )
-        for key, group in grouped_s3_objects_by_dirname:
-            file_size = 0
-            creation_time = None
-            modification_time = None
+        for _, group in grouped_s3_objects_by_dirname:
+            max_file = max(group, key=lambda x: x.last_modified)
+            max_file_s3_path = self.create_s3_path(max_file.bucket_name, max_file.key)
 
-            for item in group:
-                file_size += item.size
-                if creation_time is None or item.last_modified < creation_time:
-                    creation_time = item.last_modified
-                if modification_time is None or item.last_modified > modification_time:
-                    modification_time = item.last_modified
-                    max_file = item
+            # If partition_id is None, it means the folder is not a partition
+            partition_id = path_spec.get_partition_from_path(max_file_s3_path)
 
-            if modification_time is None:
-                logger.warning(
-                    f"Unable to find any files in the folder {key}. Skipping..."
-                )
-                continue
-
-            id = path_spec.get_partition_from_path(
-                self.create_s3_path(max_file.bucket_name, max_file.key)
+            yield Folder(
+                partition_id=partition_id,
+                is_partition=bool(partition_id),
+                creation_time=min(obj.last_modified for obj in group),
+                modification_time=max_file.last_modified,
+                sample_file=max_file_s3_path,
+                size=sum(obj.size for obj in group),
             )
 
-            # If id is None, it means the folder is not a partition
-            partitions.append(
-                Folder(
-                    partition_id=id,
-                    is_partition=bool(id),
-                    creation_time=creation_time if creation_time else None,  # type: ignore[arg-type]
-                    modification_time=modification_time,
-                    sample_file=self.create_s3_path(max_file.bucket_name, max_file.key),
-                    size=file_size,
-                )
-            )
-
-        return partitions
-
     def s3_browser(self, path_spec: PathSpec, sample_size: int) -> Iterable[BrowsePath]:
         if self.source_config.aws_config is None:
             raise ValueError("aws_config not set. Cannot browse s3")
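Note: the per-folder accumulator loop (file_size, creation_time, modification_time, max_file) is replaced by max, min, and sum over each group, and the "empty folder" warning branch disappears because grouping only ever yields groups containing at least one object. Since each group is traversed three times, this relies on groupby_unsorted yielding reusable, list-backed groups rather than one-shot iterators. A self-contained sketch of that aggregation pattern (FakeObj and groupby_unsorted_sketch are illustrative stand-ins, not DataHub APIs):

    from datetime import datetime
    from typing import Dict, List

    class FakeObj:
        # Models only the fields the aggregation reads.
        def __init__(self, key: str, last_modified: datetime, size: int):
            self.key = key
            self.last_modified = last_modified
            self.size = size

    def groupby_unsorted_sketch(items, key):
        # Groups arrive in any input order; each group is a list, so it
        # survives the three traversals below (max, min, sum).
        groups: Dict[str, List[FakeObj]] = {}
        for item in items:
            groups.setdefault(key(item), []).append(item)
        return groups.items()

    objs = [
        FakeObj("t/a/1.csv", datetime(2025, 1, 1), 100),
        FakeObj("t/b/1.csv", datetime(2025, 1, 2), 10),
        FakeObj("t/a/2.csv", datetime(2025, 1, 3), 50),
    ]
    for dirname, group in groupby_unsorted_sketch(objs, lambda o: o.key.rsplit("/", 1)[0]):
        newest = max(group, key=lambda o: o.last_modified)  # becomes the sample file
        print(dirname, sum(o.size for o in group), min(o.last_modified for o in group), newest.key)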
@@ -1000,7 +979,7 @@ class S3Source(StatefulIngestionSourceBase):
                     min=True,
                 )
                 dirs_to_process.append(dirs_to_process_min[0])
-        folders = []
+        folders: List[Folder] = []
         for dir in dirs_to_process:
             logger.info(f"Getting files from folder: {dir}")
             prefix_to_process = urlparse(dir).path.lstrip("/")
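Note: the added annotation follows from the signature change above: folders now accumulates values from an Iterable[Folder], and an empty list literal on its own can leave a type checker without an element type to infer. A tiny sketch of the pattern (gen and nums are illustrative names):

    from typing import Iterable, List

    def gen() -> Iterable[int]:
        yield 1

    nums: List[int] = []  # the annotation supplies the element type that [] alone does not
    nums.extend(gen())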


@@ -9,7 +9,11 @@ from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source.data_lake_common.data_lake_utils import ContainerWUCreator
 from datahub.ingestion.source.data_lake_common.path_spec import PathSpec
-from datahub.ingestion.source.s3.source import S3Source, partitioned_folder_comparator
+from datahub.ingestion.source.s3.source import (
+    Folder,
+    S3Source,
+    partitioned_folder_comparator,
+)
 
 
 def _get_s3_source(path_spec_: PathSpec) -> S3Source:
@@ -257,7 +261,7 @@ def test_container_generation_with_multiple_folders():
     }
 
 
-def test_get_folder_info():
+def test_get_folder_info_returns_latest_file_in_each_folder() -> None:
     """
     Test S3Source.get_folder_info returns the latest file in each folder
     """
@@ -298,6 +302,7 @@ def test_get_folder_info():
     res = _get_s3_source(path_spec).get_folder_info(
         path_spec, bucket, prefix="/my-folder"
     )
+    res = list(res)
 
     # assert
     assert len(res) == 2
@@ -336,6 +341,7 @@ def test_get_folder_info_ignores_disallowed_path(
 
     # act
     res = s3_source.get_folder_info(path_spec, bucket, prefix="/my-folder")
+    res = list(res)
 
     # assert
     expected_called_s3_uri = "s3://my-bucket/my-folder/ignore/this/path/0001.csv"
@@ -350,3 +356,48 @@ def test_get_folder_info_ignores_disallowed_path(
         "Dropped file should be in the report.filtered"
     )
     assert res == [], "Dropped file should not be in the result"
+
+
+def test_get_folder_info_returns_expected_folder() -> None:
+    # arrange
+    path_spec = PathSpec(
+        include="s3://my-bucket/{table}/{partition0}/*.csv",
+        table_name="{table}",
+    )
+
+    bucket = Mock()
+    bucket.objects.filter().page_size = Mock(
+        return_value=[
+            Mock(
+                bucket_name="my-bucket",
+                key="my-folder/dir1/0001.csv",
+                creation_time=datetime(2025, 1, 1, 1),
+                last_modified=datetime(2025, 1, 1, 1),
+                size=100,
+            ),
+            Mock(
+                bucket_name="my-bucket",
+                key="my-folder/dir1/0002.csv",
+                creation_time=datetime(2025, 1, 1, 2),
+                last_modified=datetime(2025, 1, 1, 2),
+                size=50,
+            ),
+        ]
+    )
+
+    # act
+    res = _get_s3_source(path_spec).get_folder_info(
+        path_spec, bucket, prefix="/my-folder"
+    )
+    res = list(res)
+
+    # assert
+    assert len(res) == 1
+    assert res[0] == Folder(
+        partition_id=[("partition0", "dir1")],
+        is_partition=True,
+        creation_time=datetime(2025, 1, 1, 1),
+        modification_time=datetime(2025, 1, 1, 2),
+        size=150,
+        sample_file="s3://my-bucket/my-folder/dir1/0002.csv",
+    )
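Note: the expected Folder here can be re-derived by hand, which also makes visible that the creation_time attributes on the mocks are never read: get_folder_info derives creation_time from the minimum last_modified in the group. A plain-Python sketch of the arithmetic (values mirror the mocks above):

    from datetime import datetime

    files = [  # (key, last_modified, size) for the two mocked objects
        ("my-folder/dir1/0001.csv", datetime(2025, 1, 1, 1), 100),
        ("my-folder/dir1/0002.csv", datetime(2025, 1, 1, 2), 50),
    ]
    newest = max(files, key=lambda f: f[1])
    assert sum(size for _, _, size in files) == 150                      # Folder.size
    assert min(mtime for _, mtime, _ in files) == datetime(2025, 1, 1, 1)  # creation_time
    assert newest[0].endswith("0002.csv")  # the newest file becomes sample_file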