diff --git a/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py b/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py
index ae56f4b6567..6f938f6af86 100644
--- a/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py
@@ -58,7 +58,15 @@ from metadata.utils.logger import ingestion_logger
 
 logger = ingestion_logger()
 
-DATALAKE_INT_TYPES = {"int64", "INT", "int32"}
+DATALAKE_DATA_TYPES = {
+    **dict.fromkeys(["int64", "INT", "int32"], DataType.INT.value),
+    "object": DataType.STRING.value,
+    **dict.fromkeys(["float64", "float32", "float"], DataType.FLOAT.value),
+    "bool": DataType.BOOLEAN.value,
+    **dict.fromkeys(
+        ["datetime64", "timedelta[ns]", "datetime64[ns]"], DataType.DATETIME.value
+    ),
+}
 
 DATALAKE_SUPPORTED_FILE_TYPES = (".csv", ".tsv", ".json", ".parquet", ".json.gz")
 
@@ -374,6 +382,7 @@ class DatalakeSource(DatabaseServiceSource):  # pylint: disable=too-many-public-
 
         table_name, table_type = table_name_and_type
         schema_name = self.context.database_schema.name.__root__
+        columns = []
         try:
             table_constraints = None
             if isinstance(self.service_connection.configSource, GCSConfig):
@@ -381,11 +390,14 @@ class DatalakeSource(DatabaseServiceSource):  # pylint: disable=too-many-public-
                     client=self.client, key=table_name, bucket_name=schema_name
                 )
             if isinstance(self.service_connection.configSource, S3Config):
+                connection_args = self.service_connection.configSource.securityConfig
                 data_frame = self.get_s3_files(
-                    client=self.client, key=table_name, bucket_name=schema_name
+                    client=self.client,
+                    key=table_name,
+                    bucket_name=schema_name,
+                    client_kwargs=connection_args,
                 )
             if isinstance(self.service_connection.configSource, AzureConfig):
-                columns = None
                 connection_args = self.service_connection.configSource.securityConfig
                 storage_options = {
                     "tenant_id": connection_args.tenantId,
@@ -492,7 +504,7 @@ class DatalakeSource(DatabaseServiceSource):  # pylint: disable=too-many-public-
         return None
 
     @staticmethod
-    def get_s3_files(client, key, bucket_name):
+    def get_s3_files(client, key, bucket_name, client_kwargs=None):
        """
        Fetch S3 Bucket files
        """
@@ -514,7 +526,7 @@ class DatalakeSource(DatabaseServiceSource):  # pylint: disable=too-many-public-
                 return read_json_from_s3(client, key, bucket_name)
 
             if key.endswith(".parquet"):
-                return read_parquet_from_s3(client, key, bucket_name)
+                return read_parquet_from_s3(client_kwargs, key, bucket_name)
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
@@ -534,14 +546,11 @@ class DatalakeSource(DatabaseServiceSource):  # pylint: disable=too-many-public-
 
             for column in df_columns:
                 # use String by default
                 data_type = DataType.STRING.value
-                try:
-                    if (
-                        hasattr(data_frame[column], "dtypes")
-                        and data_frame[column].dtypes.name in DATALAKE_INT_TYPES
-                        and data_frame[column].dtypes.name in ("int64", "int32")
-                    ):
-                        data_type = DataType.INT.value
+                if hasattr(data_frame[column], "dtypes"):
+                    data_type = DATALAKE_DATA_TYPES.get(
+                        data_frame[column].dtypes.name, DataType.STRING.value
+                    )
 
                 parsed_string = {
                     "dataTypeDisplay": data_type,
diff --git a/ingestion/src/metadata/utils/s3_utils.py b/ingestion/src/metadata/utils/s3_utils.py
index fa0527f57b5..8b990ee159c 100644
--- a/ingestion/src/metadata/utils/s3_utils.py
+++ b/ingestion/src/metadata/utils/s3_utils.py
@@ -15,13 +15,12 @@ Utils module to convert different file types from s3 buckets into a dataframe
 
 import gzip
 import json
-import os
 import traceback
 from typing import Any
 
 import pandas as pd
-from pyarrow import fs
-from pyarrow.parquet import ParquetFile
+import pyarrow.parquet as pq
+import s3fs
 
 from metadata.utils.logger import utils_logger
 
@@ -90,10 +89,11 @@ def read_parquet_from_s3(client: Any, key: str, bucket_name: str):
     """
     Read the parquet file from the s3 bucket and return a dataframe
     """
-
-    s3_file = fs.S3FileSystem(region=client.meta.region_name)
-    return [
-        ParquetFile(s3_file.open_input_file(os.path.join(bucket_name, key)))
-        .read()
-        .to_pandas()
-    ]
+    s3_fs = s3fs.S3FileSystem(
+        key=client.awsAccessKeyId,
+        secret=client.awsSecretAccessKey.get_secret_value(),
+        token=client.awsSessionToken,
+    )
+    bucket_uri = f"s3://{bucket_name}/{key}"
+    dataset = pq.ParquetDataset(bucket_uri, filesystem=s3_fs)
+    return [dataset.read_pandas().to_pandas()]
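
Note (not part of the patch): a minimal sketch of how the new DATALAKE_DATA_TYPES lookup resolves pandas dtype names. Plain strings stand in for the DataType enum values used by the connector; the fallback to STRING mirrors the new get_columns logic.

import pandas as pd

DATALAKE_DATA_TYPES = {
    **dict.fromkeys(["int64", "INT", "int32"], "INT"),
    "object": "STRING",
    **dict.fromkeys(["float64", "float32", "float"], "FLOAT"),
    "bool": "BOOLEAN",
    **dict.fromkeys(["datetime64", "timedelta[ns]", "datetime64[ns]"], "DATETIME"),
}

df = pd.DataFrame(
    {"id": [1, 2], "price": [9.99, 5.0], "name": ["a", "b"], "active": [True, False]}
)
for column in df.columns:
    # Unknown dtype names fall back to STRING, as in the connector.
    data_type = DATALAKE_DATA_TYPES.get(df[column].dtypes.name, "STRING")
    print(column, data_type)  # id INT, price FLOAT, name STRING, active BOOLEAN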
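
Note (not part of the patch): a minimal sketch of the reworked parquet read path, assuming plain-string AWS credentials in place of the connector's securityConfig object (which wraps the secret key in a SecretStr). s3fs owns authentication, and pyarrow only sees a generic filesystem.

import pyarrow.parquet as pq
import s3fs


def read_parquet(bucket_name, key, access_key, secret_key, session_token=None):
    # Build an authenticated S3 filesystem; token is only needed for STS keys.
    s3_fs = s3fs.S3FileSystem(key=access_key, secret=secret_key, token=session_token)
    # ParquetDataset accepts a single object key or a partitioned directory.
    dataset = pq.ParquetDataset(f"s3://{bucket_name}/{key}", filesystem=s3_fs)
    # Return a pandas DataFrame, matching what the CSV/JSON readers produce.
    return dataset.read_pandas().to_pandas()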