diff --git a/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py b/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py index 247b5fc0b7d..541d5f9e631 100644 --- a/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py @@ -54,10 +54,16 @@ from metadata.ingestion.models.ometa_classification import OMetaTagAndClassifica from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.source.connections import get_connection from metadata.ingestion.source.database.database_service import DatabaseServiceSource -from metadata.ingestion.source.database.datalake.models import DatalakeColumnWrapper -from metadata.ingestion.source.database.datalake.utils import COMPLEX_COLUMN_SEPARATOR +from metadata.ingestion.source.database.datalake.models import ( + DatalakeTableSchemaWrapper, +) from metadata.utils import fqn from metadata.utils.constants import DEFAULT_DATABASE +from metadata.utils.datalake.datalake_utils import ( + COMPLEX_COLUMN_SEPARATOR, + SUPPORTED_TYPES, + fetch_dataframe, +) from metadata.utils.filters import filter_by_schema, filter_by_table from metadata.utils.logger import ingestion_logger @@ -73,52 +79,6 @@ DATALAKE_DATA_TYPES = { ), } -JSON_SUPPORTED_TYPES = (".json", ".json.gz", ".json.zip") - -DATALAKE_SUPPORTED_FILE_TYPES = ( - ".csv", - ".tsv", - ".parquet", - ".avro", -) + JSON_SUPPORTED_TYPES - - -def ometa_to_dataframe(config_source, client, table): - """ - Method to get dataframe for profiling - """ - - data = None - if isinstance(config_source, GCSConfig): - data = DatalakeSource.get_gcs_files( - client=client, - key=table.name.__root__, - bucket_name=table.databaseSchema.name, - ) - if isinstance(config_source, S3Config): - data = DatalakeSource.get_s3_files( - client=client, - key=table.name.__root__, - bucket_name=table.databaseSchema.name, - ) - if isinstance(config_source, AzureConfig): - connection_args = config_source.securityConfig - data = DatalakeSource.get_azure_files( - client=client, - key=table.name.__root__, - container_name=table.databaseSchema.name, - storage_options={ - "tenant_id": connection_args.tenantId, - "client_id": connection_args.clientId, - "client_secret": connection_args.clientSecret.get_secret_value(), - "account_name": connection_args.accountName, - }, - ) - if isinstance(data, DatalakeColumnWrapper): - data = data.dataframes - - return data - class DatalakeSource(DatabaseServiceSource): """ @@ -410,44 +370,22 @@ class DatalakeSource(DatabaseServiceSource): From topology. 
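
The hunk above removes the provider-specific `ometa_to_dataframe` / `get_gcs_files` / `get_s3_files` / `get_azure_files` helpers; callers now go through a single `fetch_dataframe` entry point keyed by a `DatalakeTableSchemaWrapper`. A minimal sketch of the migrated call, based only on the signatures introduced later in this diff (the `config_source`, `client`, schema and table values are placeholders, not part of the patch):

```python
# Sketch only: how a caller moves from the removed helpers to fetch_dataframe.
from metadata.ingestion.source.database.datalake.models import (
    DatalakeTableSchemaWrapper,
)
from metadata.utils.datalake.datalake_utils import fetch_dataframe


def load_table_chunks(config_source, client, schema_name: str, table_name: str):
    """Return the object's data as a list of dataframe chunks (None on failure)."""
    # securityConfig is what yield_table() forwards as connection_kwargs in this
    # patch; hedged with getattr since not every config source may carry it.
    connection_args = getattr(config_source, "securityConfig", None)
    return fetch_dataframe(
        config_source=config_source,  # GCSConfig / S3Config / AzureConfig instance
        client=client,                # provider-specific client from get_connection
        file_fqn=DatalakeTableSchemaWrapper(
            key=table_name,           # object key, e.g. "dir/users.parquet"
            bucket_name=schema_name,  # bucket / container name
        ),
        connection_kwargs=connection_args,
    )
```
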
Prepare a table request and pass it to the sink """ - from pandas import DataFrame # pylint: disable=import-outside-toplevel - table_name, table_type = table_name_and_type schema_name = self.context.database_schema.name.__root__ columns = [] try: table_constraints = None - if isinstance(self.service_connection.configSource, GCSConfig): - data_frame = self.get_gcs_files( - client=self.client, key=table_name, bucket_name=schema_name - ) - if isinstance(self.service_connection.configSource, S3Config): - connection_args = self.service_connection.configSource.securityConfig - data_frame = self.get_s3_files( - client=self.client, + connection_args = self.service_connection.configSource.securityConfig + data_frame = fetch_dataframe( + config_source=self.service_connection.configSource, + client=self.client, + file_fqn=DatalakeTableSchemaWrapper( key=table_name, bucket_name=schema_name, - client_kwargs=connection_args, - ) - if isinstance(self.service_connection.configSource, AzureConfig): - connection_args = self.service_connection.configSource.securityConfig - storage_options = { - "tenant_id": connection_args.tenantId, - "client_id": connection_args.clientId, - "client_secret": connection_args.clientSecret.get_secret_value(), - } - data_frame = self.get_azure_files( - client=self.client, - key=table_name, - container_name=schema_name, - storage_options=storage_options, - ) - if isinstance(data_frame, DataFrame): - columns = self.get_columns(data_frame) - if isinstance(data_frame, list) and data_frame: - columns = self.get_columns(data_frame[0]) - if isinstance(data_frame, DatalakeColumnWrapper): - columns = data_frame.columns + ), + connection_kwargs=connection_args, + ) + columns = self.get_columns(data_frame[0]) if columns: table_request = CreateTableRequest( name=table_name, @@ -465,117 +403,6 @@ class DatalakeSource(DatabaseServiceSource): logger.warning(error) self.status.failed(table_name, error, traceback.format_exc()) - @staticmethod - def get_gcs_files(client, key, bucket_name): - """ - Fetch GCS Bucket files - """ - from metadata.utils.gcs_utils import ( # pylint: disable=import-outside-toplevel - read_avro_from_gcs, - read_csv_from_gcs, - read_json_from_gcs, - read_parquet_from_gcs, - read_tsv_from_gcs, - ) - - try: - if key.endswith(".csv"): - return read_csv_from_gcs(key, bucket_name) - - if key.endswith(".tsv"): - return read_tsv_from_gcs(key, bucket_name) - - if key.endswith(JSON_SUPPORTED_TYPES): - return read_json_from_gcs(client, key, bucket_name) - - if key.endswith(".parquet"): - return read_parquet_from_gcs(key, bucket_name) - - if key.endswith(".avro"): - return read_avro_from_gcs(client, key, bucket_name) - - except Exception as exc: - logger.debug(traceback.format_exc()) - logger.error( - f"Unexpected exception to get GCS files from [{bucket_name}]: {exc}" - ) - return None - - @staticmethod - def get_azure_files(client, key, container_name, storage_options): - """ - Fetch Azure Storage files - """ - from metadata.utils.azure_utils import ( # pylint: disable=import-outside-toplevel - read_avro_from_azure, - read_csv_from_azure, - read_json_from_azure, - read_parquet_from_azure, - ) - - try: - if key.endswith(".csv"): - return read_csv_from_azure(client, key, container_name, storage_options) - - if key.endswith(JSON_SUPPORTED_TYPES): - return read_json_from_azure(client, key, container_name) - - if key.endswith(".parquet"): - return read_parquet_from_azure( - client, key, container_name, storage_options - ) - - if key.endswith(".tsv"): - return read_csv_from_azure( - 
client, key, container_name, storage_options, sep="\t" - ) - - if key.endswith(".avro"): - return read_avro_from_azure(client, key, container_name) - - except Exception as exc: - logger.debug(traceback.format_exc()) - logger.error( - f"Unexpected exception get in azure for file [{key}] for {container_name}: {exc}" - ) - return None - - @staticmethod - def get_s3_files(client, key, bucket_name, client_kwargs=None): - """ - Fetch S3 Bucket files - """ - from metadata.utils.s3_utils import ( # pylint: disable=import-outside-toplevel - read_avro_from_s3, - read_csv_from_s3, - read_json_from_s3, - read_parquet_from_s3, - read_tsv_from_s3, - ) - - try: - if key.endswith(".csv"): - return read_csv_from_s3(client, key, bucket_name) - - if key.endswith(".tsv"): - return read_tsv_from_s3(client, key, bucket_name) - - if key.endswith(JSON_SUPPORTED_TYPES): - return read_json_from_s3(client, key, bucket_name) - - if key.endswith(".parquet"): - return read_parquet_from_s3(client_kwargs, key, bucket_name) - - if key.endswith(".avro"): - return read_avro_from_s3(client, key, bucket_name) - - except Exception as exc: - logger.debug(traceback.format_exc()) - logger.error( - f"Unexpected exception to get S3 file [{key}] from bucket [{bucket_name}]: {exc}" - ) - return None - @staticmethod def _parse_complex_column( data_frame, @@ -679,7 +506,7 @@ class DatalakeSource(DatabaseServiceSource): return data_type @staticmethod - def get_columns(data_frame): + def get_columns(data_frame: list): """ method to process column details """ @@ -733,8 +560,9 @@ class DatalakeSource(DatabaseServiceSource): return table def check_valid_file_type(self, key_name): - if key_name.endswith(DATALAKE_SUPPORTED_FILE_TYPES): - return True + for supported_types in SUPPORTED_TYPES: + if key_name.endswith(supported_types.value): + return True return False def close(self): diff --git a/ingestion/src/metadata/ingestion/source/database/datalake/models.py b/ingestion/src/metadata/ingestion/source/database/datalake/models.py index c48b270b2eb..4804c04af05 100644 --- a/ingestion/src/metadata/ingestion/source/database/datalake/models.py +++ b/ingestion/src/metadata/ingestion/source/database/datalake/models.py @@ -28,4 +28,13 @@ class DatalakeColumnWrapper(BaseModel): """ columns: Optional[List[Column]] - dataframes: Optional[List[Any]] # pandas.Dataframe does not have any validators + dataframes: Optional[Any] # pandas.Dataframe does not have any validators + + +class DatalakeTableSchemaWrapper(BaseModel): + """ + Instead of sending the whole Table model from profiler, we send only key and bucket name using this model + """ + + key: str + bucket_name: str diff --git a/ingestion/src/metadata/ingestion/source/storage/s3/metadata.py b/ingestion/src/metadata/ingestion/source/storage/s3/metadata.py index a3438e7287c..a3fbbc34d98 100644 --- a/ingestion/src/metadata/ingestion/source/storage/s3/metadata.py +++ b/ingestion/src/metadata/ingestion/source/storage/s3/metadata.py @@ -26,6 +26,9 @@ from metadata.generated.schema.entity.data.container import ( ContainerDataModel, ) from metadata.generated.schema.entity.data.table import Column +from metadata.generated.schema.entity.services.connections.database.datalake.s3Config import ( + S3Config, +) from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import ( OpenMetadataConnection, ) @@ -42,12 +45,15 @@ from metadata.generated.schema.metadataIngestion.workflow import ( from metadata.generated.schema.type.entityReference import EntityReference from 
metadata.ingestion.api.source import InvalidSourceException from metadata.ingestion.source.database.datalake.metadata import DatalakeSource -from metadata.ingestion.source.database.datalake.models import DatalakeColumnWrapper +from metadata.ingestion.source.database.datalake.models import ( + DatalakeTableSchemaWrapper, +) from metadata.ingestion.source.storage.s3.models import ( S3BucketResponse, S3ContainerDetails, ) from metadata.ingestion.source.storage.storage_service import StorageServiceSource +from metadata.utils.datalake.datalake_utils import fetch_dataframe from metadata.utils.filters import filter_by_container from metadata.utils.logger import ingestion_logger @@ -194,20 +200,23 @@ class S3Source(StorageServiceSource): def extract_column_definitions( self, bucket_name: str, sample_key: str ) -> List[Column]: - client_args = self.service_connection.awsConfig - data_structure_details = DatalakeSource.get_s3_files( - self.s3_client, - key=sample_key, - bucket_name=bucket_name, - client_kwargs=client_args, + """ + Extract Column related metadata from s3 + """ + connection_args = self.service_connection.awsConfig + data_structure_details = fetch_dataframe( + config_source=S3Config(), + client=self.s3_client, + file_fqn=DatalakeTableSchemaWrapper( + key=sample_key, bucket_name=bucket_name + ), + connection_kwargs=connection_args, ) columns = [] if isinstance(data_structure_details, DataFrame): columns = DatalakeSource.get_columns(data_structure_details) if isinstance(data_structure_details, list) and data_structure_details: columns = DatalakeSource.get_columns(data_structure_details[0]) - if isinstance(data_structure_details, DatalakeColumnWrapper): - columns = data_structure_details.columns # pylint: disable=no-member return columns def fetch_buckets(self) -> List[S3BucketResponse]: diff --git a/ingestion/src/metadata/mixins/pandas/pandas_mixin.py b/ingestion/src/metadata/mixins/pandas/pandas_mixin.py index ddd419072bb..6650ea4b746 100644 --- a/ingestion/src/metadata/mixins/pandas/pandas_mixin.py +++ b/ingestion/src/metadata/mixins/pandas/pandas_mixin.py @@ -15,7 +15,7 @@ supporting sqlalchemy abstraction layer """ import math import random -from typing import List, cast +from typing import cast from metadata.data_quality.validations.table.pandas.tableRowInsertedCountToBeBetween import ( TableRowInsertedCountToBeBetweenValidator, @@ -25,7 +25,10 @@ from metadata.generated.schema.entity.data.table import ( PartitionProfilerConfig, ProfileSampleType, ) -from metadata.ingestion.source.database.datalake.metadata import ometa_to_dataframe +from metadata.ingestion.source.database.datalake.models import ( + DatalakeTableSchemaWrapper, +) +from metadata.utils.datalake.datalake_utils import fetch_dataframe from metadata.utils.logger import test_suite_logger logger = test_suite_logger() @@ -86,16 +89,15 @@ class PandasInterfaceMixin: """ returns sampled ometa dataframes """ - from pandas import DataFrame # pylint: disable=import-outside-toplevel - - data = ometa_to_dataframe( + data = fetch_dataframe( config_source=service_connection_config.configSource, client=client, - table=table, + file_fqn=DatalakeTableSchemaWrapper( + key=table.name.__root__, bucket_name=table.databaseSchema.name + ), + is_profiler=True, ) - if isinstance(data, DataFrame): - data: List[DataFrame] = [data] - if data and isinstance(data, list): + if data: random.shuffle(data) # sampling data based on profiler config (if any) if hasattr(profile_sample_config, "profile_sample"): diff --git 
a/ingestion/src/metadata/profiler/interface/pandas/pandas_profiler_interface.py b/ingestion/src/metadata/profiler/interface/pandas/pandas_profiler_interface.py index 0f7c0aafbf7..a235579471d 100644 --- a/ingestion/src/metadata/profiler/interface/pandas/pandas_profiler_interface.py +++ b/ingestion/src/metadata/profiler/interface/pandas/pandas_profiler_interface.py @@ -29,10 +29,7 @@ from metadata.generated.schema.entity.services.connections.database.datalakeConn ) from metadata.ingestion.api.processor import ProfilerProcessorStatus from metadata.ingestion.source.connections import get_connection -from metadata.ingestion.source.database.datalake.metadata import ( - DATALAKE_DATA_TYPES, - DatalakeSource, -) +from metadata.ingestion.source.database.datalake.metadata import DatalakeSource from metadata.mixins.pandas.pandas_mixin import PandasInterfaceMixin from metadata.profiler.interface.profiler_protocol import ProfilerProtocol from metadata.profiler.metrics.core import MetricTypes diff --git a/ingestion/src/metadata/utils/azure_utils.py b/ingestion/src/metadata/utils/azure_utils.py deleted file mode 100644 index 73a03395094..00000000000 --- a/ingestion/src/metadata/utils/azure_utils.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2021 Collate -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
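
The deleted `azure_utils.py` below built the `abfs://` URL and the `storage_options` dict inline in every reader. In the new layout those pieces are centralised in `AZURE_PATH` and `return_azure_storage_options`, added further down in `datalake_utils.py`. A hedged sketch of the equivalent read path, assuming an `AzureConfig` whose `securityConfig` carries the credentials and that `adlfs` is installed so pandas can resolve `abfs://` URLs:

```python
import pandas as pd

from metadata.utils.datalake.datalake_utils import (
    AZURE_PATH,
    return_azure_storage_options,
)


def read_adls_csv(config_source, container_name: str, key: str) -> pd.DataFrame:
    """Sketch of the consolidated ADLS read path."""
    # Credentials are pulled from AzureConfig.securityConfig, exactly as
    # return_azure_storage_options builds them in datalake_utils.py below.
    storage_options = return_azure_storage_options(config_source)
    path = AZURE_PATH.format(
        bucket_name=container_name,
        account_name=storage_options.get("account_name"),
        key=key,
    )
    # pandas resolves the abfs:// URL through fsspec/adlfs using storage_options
    return pd.read_csv(path, storage_options=storage_options)
```
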
- -""" -Utils module to convert different file types from azure file system into a dataframe -""" - -import gzip -import io -import traceback -import zipfile -from typing import Any - -import pandas as pd - -from metadata.ingestion.source.database.datalake.utils import ( - read_from_avro, - read_from_json, -) -from metadata.utils.logger import utils_logger - -logger = utils_logger() - - -def _get_json_text(key: str, text: str) -> str: - if key.endswith(".gz"): - return gzip.decompress(text) - if key.endswith(".zip"): - with zipfile.ZipFile(io.BytesIO(text)) as zip_file: - return zip_file.read(zip_file.infolist()[0]).decode("utf-8") - return text - - -def get_file_text(client: Any, key: str, container_name: str): - container_client = client.get_container_client(container_name) - blob_client = container_client.get_blob_client(key) - return blob_client.download_blob().readall() - - -def read_csv_from_azure( - client: Any, key: str, container_name: str, storage_options: dict, sep: str = "," -): - """ - Read the csv file from the azure container and return a dataframe - """ - try: - account_url = ( - f"abfs://{container_name}@{client.account_name}.dfs.core.windows.net/{key}" - ) - dataframe = pd.read_csv(account_url, storage_options=storage_options, sep=sep) - return dataframe - except Exception as exc: - logger.debug(traceback.format_exc()) - logger.warning(f"Error reading CSV from ADLS - {exc}") - return None - - -def read_json_from_azure(client: Any, key: str, container_name: str, sample_size=100): - """ - Read the json file from the azure container and return a dataframe - """ - json_text = get_file_text(client=client, key=key, container_name=container_name) - return read_from_json( - key=key, json_text=json_text, sample_size=sample_size, decode=True - ) - - -def read_parquet_from_azure( - client: Any, key: str, container_name: str, storage_options: dict -): - """ - Read the parquet file from the container and return a dataframe - """ - try: - account_url = ( - f"abfs://{container_name}@{client.account_name}.dfs.core.windows.net/{key}" - ) - dataframe = pd.read_parquet(account_url, storage_options=storage_options) - return dataframe - except Exception as exc: - logger.debug(traceback.format_exc()) - logger.warning(f"Error reading parquet file from azure - {exc}") - return None - - -def read_avro_from_azure(client: Any, key: str, container_name: str): - """ - Read the avro file from the gcs bucket and return a dataframe - """ - return read_from_avro( - get_file_text(client=client, key=key, container_name=container_name) - ) diff --git a/ingestion/src/metadata/ingestion/source/database/datalake/utils.py b/ingestion/src/metadata/utils/datalake/avro_dispatch.py similarity index 53% rename from ingestion/src/metadata/ingestion/source/database/datalake/utils.py rename to ingestion/src/metadata/utils/datalake/avro_dispatch.py index d36e3b227d5..af83c75a8c7 100644 --- a/ingestion/src/metadata/ingestion/source/database/datalake/utils.py +++ b/ingestion/src/metadata/utils/datalake/avro_dispatch.py @@ -10,24 +10,34 @@ # limitations under the License. 
""" -Module to define helper methods for datalake +Module to define helper methods for datalake and to fetch data and metadata +from Avro file formats """ -import gzip + import io -import json -import zipfile -from typing import List, Union +from functools import singledispatch +from typing import Any from avro.datafile import DataFileReader from avro.errors import InvalidAvroBinaryEncoding from avro.io import DatumReader from metadata.generated.schema.entity.data.table import Column +from metadata.generated.schema.entity.services.connections.database.datalake.azureConfig import ( + AzureConfig, +) +from metadata.generated.schema.entity.services.connections.database.datalake.gcsConfig import ( + GCSConfig, +) +from metadata.generated.schema.entity.services.connections.database.datalake.s3Config import ( + S3Config, +) from metadata.generated.schema.type.schema import DataTypeTopic from metadata.ingestion.source.database.datalake.models import DatalakeColumnWrapper from metadata.parsers.avro_parser import parse_avro_schema from metadata.utils.constants import UTF_8 +from metadata.utils.datalake.datalake_utils import DatalakeFileFormatException from metadata.utils.logger import utils_logger logger = utils_logger() @@ -43,12 +53,11 @@ PD_AVRO_FIELD_MAP = { } AVRO_SCHEMA = "avro.schema" -COMPLEX_COLUMN_SEPARATOR = "_##" def read_from_avro( avro_text: bytes, -) -> Union[DatalakeColumnWrapper, List]: +) -> DatalakeColumnWrapper: """ Method to parse the avro data from storage sources """ @@ -62,49 +71,46 @@ def read_from_avro( columns=parse_avro_schema( schema=elements.meta.get(AVRO_SCHEMA).decode(UTF_8), cls=Column ), - dataframes=[DataFrame.from_records(elements)], + dataframes=DataFrame.from_records(elements), ) - return [DataFrame.from_records(elements)] + return DatalakeColumnWrapper(dataframes=DataFrame.from_records(elements)) except (AssertionError, InvalidAvroBinaryEncoding): columns = parse_avro_schema(schema=avro_text, cls=Column) field_map = { col.name.__root__: Series(PD_AVRO_FIELD_MAP.get(col.dataType.value, "str")) for col in columns } - return DatalakeColumnWrapper(columns=columns, dataframes=[DataFrame(field_map)]) + return DatalakeColumnWrapper(columns=columns, dataframes=DataFrame(field_map)) -def _get_json_text(key: str, text: bytes, decode: bool) -> str: - if key.endswith(".gz"): - return gzip.decompress(text) - if key.endswith(".zip"): - with zipfile.ZipFile(io.BytesIO(text)) as zip_file: - return zip_file.read(zip_file.infolist()[0]).decode(UTF_8) - if decode: - return text.decode(UTF_8) - return text +@singledispatch +def read_avro_dispatch(config_source: Any, key: str, **kwargs): + raise DatalakeFileFormatException(config_source=config_source, file_name=key) -def read_from_json( - key: str, json_text: str, sample_size: int = 100, decode: bool = False -) -> List: +@read_avro_dispatch.register +def _(_: GCSConfig, key: str, bucket_name: str, client, **kwargs): """ - Read the json file from the azure container and return a dataframe + Read the avro file from the gcs bucket and return a dataframe """ + from metadata.utils.datalake.datalake_utils import dataframe_to_chunks - # pylint: disable=import-outside-toplevel - from pandas import json_normalize + avro_text = client.get_bucket(bucket_name).get_blob(key).download_as_string() + return dataframe_to_chunks(read_from_avro(avro_text).dataframes) - json_text = _get_json_text(key, json_text, decode) - try: - data = json.loads(json_text) - except json.decoder.JSONDecodeError: - logger.debug("Failed to read as JSON object trying to 
read as JSON Lines") - data = [ - json.loads(json_obj) - for json_obj in json_text.strip().split("\n")[:sample_size] - ] - if isinstance(data, list): - return [json_normalize(data[:sample_size], sep=COMPLEX_COLUMN_SEPARATOR)] - return [json_normalize(data, sep=COMPLEX_COLUMN_SEPARATOR)] +@read_avro_dispatch.register +def _(_: S3Config, key: str, bucket_name: str, client, **kwargs): + from metadata.utils.datalake.datalake_utils import dataframe_to_chunks + + avro_text = client.get_object(Bucket=bucket_name, Key=key)["Body"].read() + return dataframe_to_chunks(read_from_avro(avro_text).dataframes) + + +@read_avro_dispatch.register +def _(_: AzureConfig, key: str, bucket_name: str, client, **kwargs): + from metadata.utils.datalake.datalake_utils import dataframe_to_chunks + + container_client = client.get_container_client(bucket_name) + avro_text = container_client.get_blob_client(key).download_blob().readall() + return dataframe_to_chunks(read_from_avro(avro_text).dataframes) diff --git a/ingestion/src/metadata/utils/datalake/csv_tsv_dispatch.py b/ingestion/src/metadata/utils/datalake/csv_tsv_dispatch.py new file mode 100644 index 00000000000..937936ffcae --- /dev/null +++ b/ingestion/src/metadata/utils/datalake/csv_tsv_dispatch.py @@ -0,0 +1,128 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Module to define helper methods for datalake and to fetch data and metadata +from Csv and Tsv file formats +""" +from functools import singledispatch +from typing import Any + +import pandas as pd + +from metadata.generated.schema.entity.services.connections.database.datalake.azureConfig import ( + AzureConfig, +) +from metadata.generated.schema.entity.services.connections.database.datalake.gcsConfig import ( + GCSConfig, +) +from metadata.generated.schema.entity.services.connections.database.datalake.s3Config import ( + S3Config, +) +from metadata.utils.constants import CHUNKSIZE +from metadata.utils.datalake.datalake_utils import DatalakeFileFormatException +from metadata.utils.logger import utils_logger + +logger = utils_logger() + +TSV_SEPARATOR = "\t" +CSV_SEPARATOR = "," + + +def read_from_pandas(path: str, separator: str, storage_options=None): + chunk_list = [] + with pd.read_csv( + path, sep=separator, chunksize=CHUNKSIZE, storage_options=storage_options + ) as reader: + for chunks in reader: + chunk_list.append(chunks) + return chunk_list + + +@singledispatch +def read_csv_dispatch(config_source: Any, key: str, **kwargs): + raise DatalakeFileFormatException(config_source=config_source, file_name=key) + + +@singledispatch +def read_tsv_dispatch(config_source: Any, key: str, **kwargs): + raise DatalakeFileFormatException(config_source=config_source, file_name=key) + + +@read_csv_dispatch.register +def _(_: GCSConfig, key: str, bucket_name: str, **kwargs): + """ + Read the CSV file from the gcs bucket and return a dataframe + """ + path = f"gs://{bucket_name}/{key}" + return read_from_pandas(path=path, separator=CSV_SEPARATOR) + + +@read_csv_dispatch.register +def _(_: S3Config, key: str, bucket_name: str, client, **kwargs): + path = client.get_object(Bucket=bucket_name, Key=key)["Body"] + return read_from_pandas(path=path, separator=CSV_SEPARATOR) + + +@read_csv_dispatch.register +def _(config_source: AzureConfig, key: str, bucket_name: str, **kwargs): + from metadata.utils.datalake.datalake_utils import ( + AZURE_PATH, + return_azure_storage_options, + ) + + storage_options = return_azure_storage_options(config_source) + path = AZURE_PATH.format( + bucket_name=bucket_name, + account_name=storage_options.get("account_name"), + key=key, + ) + return read_from_pandas( + path=path, + separator=CSV_SEPARATOR, + storage_options=storage_options, + ) + + +@read_tsv_dispatch.register +def _(_: GCSConfig, key: str, bucket_name: str, **kwargs): + """ + Read the TSV file from the gcs bucket and return a dataframe + """ + path = f"gs://{bucket_name}/{key}" + return read_from_pandas(path=path, separator=TSV_SEPARATOR) + + +@read_tsv_dispatch.register +def _(_: S3Config, key: str, bucket_name: str, client, **kwargs): + path = client.get_object(Bucket=bucket_name, Key=key)["Body"] + return read_from_pandas(path=path, separator=TSV_SEPARATOR) + + +@read_tsv_dispatch.register +def _(config_source: AzureConfig, key: str, bucket_name: str, **kwargs): + from metadata.utils.datalake.datalake_utils import ( + AZURE_PATH, + return_azure_storage_options, + ) + + storage_options = return_azure_storage_options(config_source) + + path = AZURE_PATH.format( + bucket_name=bucket_name, + account_name=storage_options.get("account_name"), + key=key, + ) + return read_from_pandas( + path=path, + separator=TSV_SEPARATOR, + storage_options=storage_options, + ) diff --git a/ingestion/src/metadata/utils/datalake/datalake_utils.py b/ingestion/src/metadata/utils/datalake/datalake_utils.py new file mode 100644 
index 00000000000..3707d0e9b32 --- /dev/null +++ b/ingestion/src/metadata/utils/datalake/datalake_utils.py @@ -0,0 +1,119 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module to define helper methods for datalake and to fetch data and metadata +from different auths and different file systems. +""" + + +from enum import Enum +from typing import Any, Dict + +from metadata.ingestion.source.database.datalake.models import ( + DatalakeTableSchemaWrapper, +) +from metadata.utils.constants import CHUNKSIZE +from metadata.utils.logger import utils_logger + +logger = utils_logger() +COMPLEX_COLUMN_SEPARATOR = "_##" +AZURE_PATH = "abfs://{bucket_name}@{account_name}.dfs.core.windows.net/{key}" +logger = utils_logger() + + +class DatalakeFileFormatException(Exception): + def __init__(self, config_source: Any, file_name: str) -> None: + message = f"Missing implementation for {config_source.__class__.__name__} for {file_name}" + super().__init__(message) + + +class FILE_FORMAT_DISPATCH_MAP: + @classmethod + def fetch_dispatch(cls): + from metadata.utils.datalake.avro_dispatch import read_avro_dispatch + from metadata.utils.datalake.csv_tsv_dispatch import ( + read_csv_dispatch, + read_tsv_dispatch, + ) + from metadata.utils.datalake.json_dispatch import read_json_dispatch + from metadata.utils.datalake.parquet_dispatch import read_parquet_dispatch + + return { + SUPPORTED_TYPES.CSV: read_csv_dispatch, + SUPPORTED_TYPES.TSV: read_tsv_dispatch, + SUPPORTED_TYPES.AVRO: read_avro_dispatch, + SUPPORTED_TYPES.PARQUET: read_parquet_dispatch, + SUPPORTED_TYPES.JSON: read_json_dispatch, + SUPPORTED_TYPES.JSONGZ: read_json_dispatch, + SUPPORTED_TYPES.JSONZIP: read_json_dispatch, + } + + +class SUPPORTED_TYPES(Enum): + CSV = "csv" + TSV = "tsv" + AVRO = "avro" + PARQUET = "parquet" + JSON = "json" + JSONGZ = "json.gz" + JSONZIP = "json.zip" + + @property + def return_dispatch(self): + return FILE_FORMAT_DISPATCH_MAP.fetch_dispatch().get(self) + + +def return_azure_storage_options(config_source: Any) -> Dict: + connection_args = config_source.securityConfig + return { + "tenant_id": connection_args.tenantId, + "client_id": connection_args.clientId, + "account_name": connection_args.accountName, + "client_secret": connection_args.clientSecret.get_secret_value(), + } + + +def dataframe_to_chunks(df): + """ + Reads the Dataframe and returns list of dataframes broken down in chunks + """ + return [ + df[range_iter : range_iter + CHUNKSIZE] + for range_iter in range(0, len(df), CHUNKSIZE) + ] + + +def fetch_dataframe( + config_source, client, file_fqn: DatalakeTableSchemaWrapper, **kwargs +): + """ + Method to get dataframe for profiling + """ + # dispatch to handle fetching of data from multiple file formats (csv, tsv, json, avro and parquet) + key: str = file_fqn.key + bucket_name: str = file_fqn.bucket_name + + try: + for supported_types_enum in SUPPORTED_TYPES: + if key.endswith(supported_types_enum.value): + return supported_types_enum.return_dispatch( + config_source, + key=key, + 
bucket_name=bucket_name, + client=client, + **kwargs, + ) + except Exception as err: + logger.error( + f"Error fetching file {bucket_name}/{key} using {config_source.__class__.__name__} due to: {err}" + ) + return None diff --git a/ingestion/src/metadata/utils/datalake/json_dispatch.py b/ingestion/src/metadata/utils/datalake/json_dispatch.py new file mode 100644 index 00000000000..0ca8a99ccba --- /dev/null +++ b/ingestion/src/metadata/utils/datalake/json_dispatch.py @@ -0,0 +1,100 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module to define helper methods for datalake and to fetch data and metadata +from Json file formats +""" +import gzip +import io +import json +import zipfile +from functools import singledispatch +from typing import Any, List + +from metadata.generated.schema.entity.services.connections.database.datalake.azureConfig import ( + AzureConfig, +) +from metadata.generated.schema.entity.services.connections.database.datalake.gcsConfig import ( + GCSConfig, +) +from metadata.generated.schema.entity.services.connections.database.datalake.s3Config import ( + S3Config, +) +from metadata.utils.constants import UTF_8 +from metadata.utils.datalake.datalake_utils import DatalakeFileFormatException +from metadata.utils.logger import utils_logger + +logger = utils_logger() + + +def _get_json_text(key: str, text: bytes, decode: bool) -> str: + if key.endswith(".gz"): + return gzip.decompress(text) + if key.endswith(".zip"): + with zipfile.ZipFile(io.BytesIO(text)) as zip_file: + return zip_file.read(zip_file.infolist()[0]).decode(UTF_8) + if decode: + return text.decode(UTF_8) + return text + + +def read_from_json( + key: str, json_text: str, decode: bool = False, is_profiler: bool = False, **_ +) -> List: + """ + Read the json file from the azure container and return a dataframe + """ + + # pylint: disable=import-outside-toplevel + from pandas import json_normalize + + from metadata.utils.datalake.datalake_utils import ( + COMPLEX_COLUMN_SEPARATOR, + dataframe_to_chunks, + ) + + json_text = _get_json_text(key, json_text, decode) + try: + data = json.loads(json_text) + except json.decoder.JSONDecodeError: + logger.debug("Failed to read as JSON object trying to read as JSON Lines") + data = [json.loads(json_obj) for json_obj in json_text.strip().split("\n")] + if is_profiler: + return dataframe_to_chunks(json_normalize(data)) + return dataframe_to_chunks(json_normalize(data, sep=COMPLEX_COLUMN_SEPARATOR)) + + +@singledispatch +def read_json_dispatch(config_source: Any, key: str, **kwargs): + raise DatalakeFileFormatException(config_source=config_source, file_name=key) + + +@read_json_dispatch.register +def _(_: GCSConfig, key: str, bucket_name: str, client, **kwargs): + """ + Read the json file from the gcs bucket and return a dataframe + """ + json_text = client.get_bucket(bucket_name).get_blob(key).download_as_string() + return read_from_json(key=key, json_text=json_text, decode=True, **kwargs) + + +@read_json_dispatch.register +def _(_: S3Config, key: str, 
bucket_name: str, client, **kwargs): + json_text = client.get_object(Bucket=bucket_name, Key=key)["Body"].read() + return read_from_json(key=key, json_text=json_text, decode=True, **kwargs) + + +@read_json_dispatch.register +def _(_: AzureConfig, key: str, bucket_name: str, client, **kwargs): + container_client = client.get_container_client(bucket_name) + json_text = container_client.get_blob_client(key).download_blob().readall() + return read_from_json(key=key, json_text=json_text, decode=True, **kwargs) diff --git a/ingestion/src/metadata/utils/datalake/parquet_dispatch.py b/ingestion/src/metadata/utils/datalake/parquet_dispatch.py new file mode 100644 index 00000000000..3e11fea1696 --- /dev/null +++ b/ingestion/src/metadata/utils/datalake/parquet_dispatch.py @@ -0,0 +1,110 @@ +# Copyright 2021 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module to define helper methods for datalake and to fetch data and metadata +from Parquet file formats +""" + + +from functools import singledispatch +from typing import Any + +import pandas as pd + +from metadata.generated.schema.entity.services.connections.database.datalake.azureConfig import ( + AzureConfig, +) +from metadata.generated.schema.entity.services.connections.database.datalake.gcsConfig import ( + GCSConfig, +) +from metadata.generated.schema.entity.services.connections.database.datalake.s3Config import ( + S3Config, +) +from metadata.utils.datalake.datalake_utils import DatalakeFileFormatException +from metadata.utils.logger import utils_logger + +logger = utils_logger() + + +@singledispatch +def read_parquet_dispatch(config_source: Any, key: str, **kwargs): + raise DatalakeFileFormatException(config_source=config_source, file_name=key) + + +@read_parquet_dispatch.register +def _(_: GCSConfig, key: str, bucket_name: str, **kwargs): + """ + Read the parquet file from the gcs bucket and return a dataframe + """ + # pylint: disable=import-outside-toplevel + from gcsfs import GCSFileSystem + from pyarrow.parquet import ParquetFile + + from metadata.utils.datalake.datalake_utils import dataframe_to_chunks + + gcs = GCSFileSystem() + file = gcs.open(f"gs://{bucket_name}/{key}") + dataframe_response = ( + ParquetFile(file).read().to_pandas(split_blocks=True, self_destruct=True) + ) + return dataframe_to_chunks(dataframe_response) + + +@read_parquet_dispatch.register +def _(_: S3Config, key: str, bucket_name: str, connection_kwargs, **kwargs): + """ + Read the parquet file from the s3 bucket and return a dataframe + """ + # pylint: disable=import-outside-toplevel + import s3fs + from pyarrow.parquet import ParquetDataset + + from metadata.utils.datalake.datalake_utils import dataframe_to_chunks + + client_kwargs = {} + client = connection_kwargs + if client.endPointURL: + client_kwargs["endpoint_url"] = client.endPointURL + + if client.awsRegion: + client_kwargs["region_name"] = client.awsRegion + + s3_fs = s3fs.S3FileSystem(client_kwargs=client_kwargs) + + if client.awsAccessKeyId and client.awsSecretAccessKey: + s3_fs = s3fs.S3FileSystem( 
+ key=client.awsAccessKeyId, + secret=client.awsSecretAccessKey.get_secret_value(), + token=client.awsSessionToken, + client_kwargs=client_kwargs, + ) + bucket_uri = f"s3://{bucket_name}/{key}" + dataset = ParquetDataset(bucket_uri, filesystem=s3_fs) + return dataframe_to_chunks(dataset.read_pandas().to_pandas()) + + +@read_parquet_dispatch.register +def _(config_source: AzureConfig, key: str, bucket_name: str, **kwargs): + from metadata.utils.datalake.datalake_utils import ( + AZURE_PATH, + dataframe_to_chunks, + return_azure_storage_options, + ) + + storage_options = return_azure_storage_options(config_source) + account_url = AZURE_PATH.format( + bucket_name=bucket_name, + account_name=storage_options.get("account_name"), + key=key, + ) + dataframe = pd.read_parquet(account_url, storage_options=storage_options) + return dataframe_to_chunks(dataframe) diff --git a/ingestion/src/metadata/utils/gcs_utils.py b/ingestion/src/metadata/utils/gcs_utils.py deleted file mode 100644 index efb36241d7d..00000000000 --- a/ingestion/src/metadata/utils/gcs_utils.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2021 Collate -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Utils module to convert different file types from gcs buckets into a dataframe -""" - -import traceback -from typing import Any - -import gcsfs -import pandas as pd -from pandas import DataFrame -from pyarrow.parquet import ParquetFile - -from metadata.ingestion.source.database.datalake.utils import ( - read_from_avro, - read_from_json, -) -from metadata.utils.constants import CHUNKSIZE -from metadata.utils.logger import utils_logger - -logger = utils_logger() - - -def get_file_text(client: Any, key: str, bucket_name: str): - bucket = client.get_bucket(bucket_name) - return bucket.get_blob(key).download_as_string() - - -def read_csv_from_gcs( # pylint: disable=inconsistent-return-statements - key: str, bucket_name: str -) -> DataFrame: - """ - Read the csv file from the gcs bucket and return a dataframe - """ - - try: - chunk_list = [] - with pd.read_csv( - f"gs://{bucket_name}/{key}", sep=",", chunksize=CHUNKSIZE - ) as reader: - for chunks in reader: - chunk_list.append(chunks) - return chunk_list - except Exception as exc: - logger.debug(traceback.format_exc()) - logger.warning(f"Error reading CSV from GCS - {exc}") - - -def read_tsv_from_gcs( # pylint: disable=inconsistent-return-statements - key: str, bucket_name: str -) -> DataFrame: - """ - Read the tsv file from the gcs bucket and return a dataframe - """ - try: - chunk_list = [] - with pd.read_csv( - f"gs://{bucket_name}/{key}", sep="\t", chunksize=CHUNKSIZE - ) as reader: - for chunks in reader: - chunk_list.append(chunks) - return chunk_list - except Exception as exc: - logger.debug(traceback.format_exc()) - logger.warning(f"Error reading CSV from GCS - {exc}") - - -def read_json_from_gcs(client: Any, key: str, bucket_name: str) -> DataFrame: - """ - Read the json file from the gcs bucket and return a dataframe - """ - json_text = get_file_text(client=client, key=key, 
bucket_name=bucket_name) - return read_from_json(key=key, json_text=json_text, decode=True) - - -def read_parquet_from_gcs(key: str, bucket_name: str) -> DataFrame: - """ - Read the parquet file from the gcs bucket and return a dataframe - """ - - gcs = gcsfs.GCSFileSystem() - file = gcs.open(f"gs://{bucket_name}/{key}") - return [ParquetFile(file).read().to_pandas()] - - -def read_avro_from_gcs(client: Any, key: str, bucket_name: str) -> DataFrame: - """ - Read the avro file from the gcs bucket and return a dataframe - """ - avro_text = get_file_text(client=client, key=key, bucket_name=bucket_name) - return read_from_avro(avro_text) diff --git a/ingestion/src/metadata/utils/s3_utils.py b/ingestion/src/metadata/utils/s3_utils.py deleted file mode 100644 index 0238f228161..00000000000 --- a/ingestion/src/metadata/utils/s3_utils.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2021 Collate -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Utils module to convert different file types from s3 buckets into a dataframe -""" - -import traceback -from typing import Any - -import pandas as pd -import pyarrow.parquet as pq -import s3fs - -from metadata.ingestion.source.database.datalake.utils import ( - read_from_avro, - read_from_json, -) -from metadata.utils.constants import CHUNKSIZE -from metadata.utils.logger import utils_logger - -logger = utils_logger() - - -def get_file_text(client: Any, key: str, bucket_name: str): - obj = client.get_object(Bucket=bucket_name, Key=key) - return obj["Body"].read() - - -def read_csv_from_s3( - client: Any, - key: str, - bucket_name: str, - sep: str = ",", -): - """ - Read the csv file from the s3 bucket and return a dataframe - """ - try: - stream = client.get_object(Bucket=bucket_name, Key=key)["Body"] - chunk_list = [] - with pd.read_csv(stream, sep=sep, chunksize=CHUNKSIZE) as reader: - for chunks in reader: - chunk_list.append(chunks) - return chunk_list - except Exception as exc: - logger.debug(traceback.format_exc()) - logger.warning(f"Error reading CSV from s3 - {exc}") - return None - - -def read_tsv_from_s3( - client, - key: str, - bucket_name: str, -): - """ - Read the tsv file from the s3 bucket and return a dataframe - """ - try: - return read_csv_from_s3(client, key, bucket_name, sep="\t") - except Exception as exc: - logger.debug(traceback.format_exc()) - logger.warning(f"Error reading TSV from s3 - {exc}") - return None - - -def read_json_from_s3(client: Any, key: str, bucket_name: str, sample_size=100): - """ - Read the json file from the s3 bucket and return a dataframe - """ - json_text = get_file_text(client=client, key=key, bucket_name=bucket_name) - return read_from_json( - key=key, json_text=json_text, sample_size=sample_size, decode=True - ) - - -def read_parquet_from_s3(client: Any, key: str, bucket_name: str): - """ - Read the parquet file from the s3 bucket and return a dataframe - """ - client_kwargs = {} - if client.endPointURL: - client_kwargs["endpoint_url"] = client.endPointURL - - if client.awsRegion: - client_kwargs["region_name"] = 
client.awsRegion - - s3_fs = s3fs.S3FileSystem(client_kwargs=client_kwargs) - - if client.awsAccessKeyId and client.awsSecretAccessKey: - s3_fs = s3fs.S3FileSystem( - key=client.awsAccessKeyId, - secret=client.awsSecretAccessKey.get_secret_value(), - token=client.awsSessionToken, - client_kwargs=client_kwargs, - ) - bucket_uri = f"s3://{bucket_name}/{key}" - dataset = pq.ParquetDataset(bucket_uri, filesystem=s3_fs) - return [dataset.read_pandas().to_pandas()] - - -def read_avro_from_s3(client: Any, key: str, bucket_name: str): - """ - Read the avro file from the s3 bucket and return a dataframe - """ - return read_from_avro( - get_file_text(client=client, key=key, bucket_name=bucket_name) - ) diff --git a/ingestion/tests/cli_e2e/test_cli_datalake.py b/ingestion/tests/cli_e2e/test_cli_datalake.py new file mode 100644 index 00000000000..acc645e97d2 --- /dev/null +++ b/ingestion/tests/cli_e2e/test_cli_datalake.py @@ -0,0 +1,102 @@ +# Copyright 2022 Collate +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test Datalake connector with CLI +""" +from typing import List + +from .common.test_cli_db import CliCommonDB +from .common_e2e_sqa_mixins import SQACommonMethods + + +class MysqlCliTest(CliCommonDB.TestSuite, SQACommonMethods): + create_table_query: str = """ + CREATE TABLE persons ( + person_id int, + full_name varchar(255) + ) + """ + + create_view_query: str = """ + CREATE VIEW view_persons AS + SELECT * + FROM openmetadata_db.persons; + """ + + insert_data_queries: List[str] = [ + "INSERT INTO persons (person_id, full_name) VALUES (1,'Peter Parker');", + "INSERT INTO persons (person_id, full_name) VALUES (1, 'Clark Kent');", + ] + + drop_table_query: str = """ + DROP TABLE IF EXISTS openmetadata_db.persons; + """ + + drop_view_query: str = """ + DROP VIEW IF EXISTS openmetadata_db.view_persons; + """ + + @staticmethod + def get_connector_name() -> str: + return "mysql" + + def create_table_and_view(self) -> None: + SQACommonMethods.create_table_and_view(self) + + def delete_table_and_view(self) -> None: + SQACommonMethods.delete_table_and_view(self) + + @staticmethod + def expected_tables() -> int: + return 49 + + def inserted_rows_count(self) -> int: + return len(self.insert_data_queries) + + def view_column_lineage_count(self) -> int: + return 2 + + @staticmethod + def fqn_created_table() -> str: + return "local_mysql.default.openmetadata_db.persons" + + @staticmethod + def get_includes_schemas() -> List[str]: + return ["openmetadata_db.*"] + + @staticmethod + def get_includes_tables() -> List[str]: + return ["entity_*"] + + @staticmethod + def get_excludes_tables() -> List[str]: + return [".*bot.*"] + + @staticmethod + def expected_filtered_schema_includes() -> int: + return 0 + + @staticmethod + def expected_filtered_schema_excludes() -> int: + return 1 + + @staticmethod + def expected_filtered_table_includes() -> int: + return 48 + + @staticmethod + def expected_filtered_table_excludes() -> int: + return 4 + + @staticmethod + def expected_filtered_mix() -> int: + return 48 diff --git 
a/ingestion/tests/unit/test_ometa_to_dataframe.py b/ingestion/tests/unit/test_ometa_to_dataframe.py index a3af2f4e828..6ad6dc7684f 100644 --- a/ingestion/tests/unit/test_ometa_to_dataframe.py +++ b/ingestion/tests/unit/test_ometa_to_dataframe.py @@ -24,9 +24,9 @@ from metadata.generated.schema.entity.services.connections.database.datalakeConn from metadata.generated.schema.metadataIngestion.workflow import ( OpenMetadataWorkflowConfig, ) +from metadata.generated.schema.type.entityReference import EntityReference from metadata.ingestion.source.database.datalake.metadata import DatalakeSource from metadata.mixins.pandas.pandas_mixin import PandasInterfaceMixin -from metadata.utils.gcs_utils import read_parquet_from_gcs from .topology.database.test_datalake import mock_datalake_config @@ -41,7 +41,7 @@ method_resp_file = [resp_parquet_file] class TestStringMethods(unittest.TestCase): def test_dl_column_parser(self): with patch( - "metadata.utils.gcs_utils.read_parquet_from_gcs", + "metadata.utils.datalake.datalake_utils.fetch_dataframe", return_value=method_resp_file, ) as exec_mock_method: resp = exec_mock_method("key", "string") @@ -52,7 +52,7 @@ class TestStringMethods(unittest.TestCase): ) def test_return_ometa_dataframes_sampled(self, test_connection): with patch( - "metadata.mixins.pandas.pandas_mixin.ometa_to_dataframe", + "metadata.mixins.pandas.pandas_mixin.fetch_dataframe", return_value=[resp_parquet_file], ): config = OpenMetadataWorkflowConfig.parse_obj(mock_datalake_config) @@ -61,7 +61,19 @@ class TestStringMethods(unittest.TestCase): config.workflowConfig.openMetadataServerConfig, ) resp = PandasInterfaceMixin().return_ometa_dataframes_sampled( - datalake_source.service_connection, None, None, None + service_connection_config=datalake_source.service_connection, + table=Table( + id="cec14ccf-123f-4271-8c90-0ae54cc4227e", + columns=[], + name="test", + databaseSchema=EntityReference( + name="Test", + id="cec14ccf-123f-4271-8c90-0ae54cc4227e", + type="databaseSchema", + ), + ), + client=None, + profile_sample_config=None, ) assert resp == method_resp_file @@ -72,7 +84,7 @@ class TestStringMethods(unittest.TestCase): ) def test_return_ometa_dataframes_sampled_fail(self, test_connection): with patch( - "metadata.mixins.pandas.pandas_mixin.ometa_to_dataframe", + "metadata.mixins.pandas.pandas_mixin.fetch_dataframe", return_value=None, ): with self.assertRaises(TypeError) as context: @@ -83,12 +95,18 @@ class TestStringMethods(unittest.TestCase): ) resp = PandasInterfaceMixin().return_ometa_dataframes_sampled( service_connection_config=datalake_source.service_connection, - client=None, table=Table( - id="1dabab2c-0d15-41ca-a834-7c0421d9c951", - name="test", + id="cec14ccf-123f-4271-8c90-0ae54cc4227e", columns=[], + name="test", + databaseSchema=EntityReference( + name="Test", + id="cec14ccf-123f-4271-8c90-0ae54cc4227e", + type="databaseSchema", + ), ), + client=None, profile_sample_config=None, ) + self.assertEqual(context.exception.args[0], "Couldn't fetch test") diff --git a/ingestion/tests/unit/topology/database/test_datalake.py b/ingestion/tests/unit/topology/database/test_datalake.py index c75ce1552fb..50d05624cb3 100644 --- a/ingestion/tests/unit/topology/database/test_datalake.py +++ b/ingestion/tests/unit/topology/database/test_datalake.py @@ -31,10 +31,8 @@ from metadata.generated.schema.metadataIngestion.workflow import ( ) from metadata.generated.schema.type.entityReference import EntityReference from metadata.ingestion.source.database.datalake.metadata import 
DatalakeSource -from metadata.ingestion.source.database.datalake.utils import ( - read_from_avro, - read_from_json, -) +from metadata.utils.datalake.avro_dispatch import read_from_avro +from metadata.utils.datalake.json_dispatch import read_from_json mock_datalake_config = { "source": { diff --git a/ingestion/tests/unit/topology/storage/test_storage.py b/ingestion/tests/unit/topology/storage/test_storage.py index 4767776f556..64f532b0b6a 100644 --- a/ingestion/tests/unit/topology/storage/test_storage.py +++ b/ingestion/tests/unit/topology/storage/test_storage.py @@ -36,7 +36,6 @@ from metadata.generated.schema.metadataIngestion.workflow import ( ) from metadata.generated.schema.type.entityReference import EntityReference from metadata.ingestion.api.source import InvalidSourceException -from metadata.ingestion.source.database.datalake.metadata import DatalakeSource from metadata.ingestion.source.storage.s3.metadata import ( S3BucketResponse, S3ContainerDetails, @@ -249,34 +248,37 @@ class StorageUnitTest(TestCase): # Most of the parsing support are covered in test_datalake unit tests related to the Data lake implementation def test_extract_column_definitions(self): - DatalakeSource.get_s3_files = lambda client, key, bucket_name, client_kwargs: [ - pd.DataFrame.from_dict( - [ - {"transaction_id": 1, "transaction_value": 100}, - {"transaction_id": 2, "transaction_value": 200}, - {"transaction_id": 3, "transaction_value": 300}, - ] - ) - ] - self.assertListEqual( - [ - Column( - name=ColumnName(__root__="transaction_id"), - dataType=DataType.INT, - dataTypeDisplay="INT", - dataLength=1, - ), - Column( - name=ColumnName(__root__="transaction_value"), - dataType=DataType.INT, - dataTypeDisplay="INT", - dataLength=1, - ), + with patch( + "metadata.ingestion.source.storage.s3.metadata.fetch_dataframe", + return_value=[ + pd.DataFrame.from_dict( + [ + {"transaction_id": 1, "transaction_value": 100}, + {"transaction_id": 2, "transaction_value": 200}, + {"transaction_id": 3, "transaction_value": 300}, + ] + ) ], - self.object_store_source.extract_column_definitions( - bucket_name="test_bucket", sample_key="test.json" - ), - ) + ): + self.assertListEqual( + [ + Column( + name=ColumnName(__root__="transaction_id"), + dataType=DataType.INT, + dataTypeDisplay="INT", + dataLength=1, + ), + Column( + name=ColumnName(__root__="transaction_value"), + dataType=DataType.INT, + dataTypeDisplay="INT", + dataLength=1, + ), + ], + self.object_store_source.extract_column_definitions( + bucket_name="test_bucket", sample_key="test.json" + ), + ) def test_get_sample_file_prefix_for_structured_and_partitioned_metadata(self): input_metadata = MetadataEntry(
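
Because the static `get_s3_files`-style helpers are gone, the updated tests patch the new `fetch_dataframe` name in the module that imports it rather than monkey-patching methods on `DatalakeSource`. A condensed sketch of the pattern used in `test_storage.py` above, assuming an `S3Source` instance (`object_store_source`) built as in that test's setup:

```python
from unittest.mock import patch

import pandas as pd


def columns_from_stubbed_dataframe(object_store_source):
    """Sketch: drive extract_column_definitions against a canned dataframe."""
    sample = pd.DataFrame.from_dict(
        [
            {"transaction_id": 1, "transaction_value": 100},
            {"transaction_id": 2, "transaction_value": 200},
        ]
    )
    # Patch the name where s3/metadata.py imported it, not where it is defined
    # in metadata.utils.datalake.datalake_utils.
    with patch(
        "metadata.ingestion.source.storage.s3.metadata.fetch_dataframe",
        return_value=[sample],
    ):
        return object_store_source.extract_column_definitions(
            bucket_name="test_bucket", sample_key="test.json"
        )
```
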