mirror of https://github.com/open-metadata/OpenMetadata.git
synced 2025-11-04 04:29:13 +00:00

* linting: fix python linting
* fix: get column types from parquet schema for parquet files
* style: python linting
* fix: remove displayType check in test as variation depending on OS
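A minimal usage sketch (not part of the diff below) of the refactored column-parser API this commit introduces. The in-memory dataframe is hypothetical; the names used (DataFrameColumnParser.create, get_columns, GenericDataFrameColumnParser.fetch_col_types, SupportedTypes) are taken from the changed files shown further down.

    import pandas as pd

    from metadata.readers.dataframe.reader_factory import SupportedTypes
    from metadata.utils.datalake.datalake_utils import (
        DataFrameColumnParser,
        GenericDataFrameColumnParser,
    )

    # Hypothetical dataframe standing in for data read from a datalake file.
    df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})

    # Parquet files are now routed to a pyarrow-schema-based parser; every other
    # format falls back to the generic pandas-dtype-based parser.
    parser = DataFrameColumnParser.create(df, file_type=SupportedTypes.PARQUET)
    columns = parser.get_columns()

    # The old module-level fetch_col_types() now lives on the generic parser.
    col_type = GenericDataFrameColumnParser.fetch_col_types(df, column_name="id")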
This commit is contained in:
parent 6838fadec6
commit 9a4a9df836
@@ -54,7 +54,6 @@ class TestCaseRunner(Processor):
    """Execute the test suite tests and create test cases from the YAML config"""

    def __init__(self, config: OpenMetadataWorkflowConfig, metadata: OpenMetadata):

        super().__init__()

        self.config = config

@@ -149,7 +149,6 @@ class TestSuiteSource(Source):
            )

        else:

            test_suite_cases = self._get_test_cases_from_test_suite(table.testSuite)

            yield Either(

@@ -17,7 +17,7 @@ from typing import Optional

from metadata.profiler.metrics.core import add_props
from metadata.profiler.metrics.registry import Metrics
from metadata.utils.datalake.datalake_utils import fetch_col_types
from metadata.utils.datalake.datalake_utils import GenericDataFrameColumnParser
from metadata.utils.entity_link import get_decoded_column
from metadata.utils.sqa_like_column import SQALikeColumn

@@ -28,7 +28,9 @@ class PandasValidatorMixin:
    def get_column_name(self, entity_link: str, dfs) -> SQALikeColumn:
        # we'll use the first dataframe chunk to get the column name.
        column = dfs[0][get_decoded_column(entity_link)]
        _type = fetch_col_types(dfs[0], get_decoded_column(entity_link))
        _type = GenericDataFrameColumnParser.fetch_col_types(
            dfs[0], get_decoded_column(entity_link)
        )
        sqa_like_column = SQALikeColumn(
            name=column.name,
            type=_type,

@@ -389,7 +389,6 @@ class LineageParser:
    def _evaluate_best_parser(
        query: str, dialect: Dialect, timeout_seconds: int
    ) -> Optional[LineageRunner]:

        if query is None:
            return None

@@ -210,9 +210,9 @@ def _determine_restricted_operation(
    Only retain add operation for restrict_update_fields fields
    """
    path = patch_ops.get("path")
    op = patch_ops.get("op")
    ops = patch_ops.get("op")
    for field in restrict_update_fields or []:
        if field in path and op != PatchOperation.ADD.value:
        if field in path and ops != PatchOperation.ADD.value:
            return False
    return True

@@ -135,7 +135,6 @@ class OMetaPatchMixin(OMetaPatchMixinBase):
            Updated Entity
        """
        try:

            patch = build_patch(
                source=source,
                destination=destination,

@@ -531,7 +531,6 @@ class DashboardServiceSource(TopologyRunnerMixin, Source, ABC):
        )

    def check_database_schema_name(self, database_schema_name: str):

        """
        Check if the input database schema name is equal to "<default>" and return the input name if it is not.

@@ -833,7 +833,6 @@ class LookerSource(DashboardServiceSource):
                    to_entity.id.__root__
                    not in self._added_lineage[from_entity.id.__root__]
                ):

                    self._added_lineage[from_entity.id.__root__].append(
                        to_entity.id.__root__
                    )
@@ -943,7 +942,6 @@ class LookerSource(DashboardServiceSource):
        dashboard_name = self.context.dashboard

        try:

            dashboard_fqn = fqn.build(
                metadata=self.metadata,
                entity_type=Dashboard,

@@ -192,7 +192,6 @@ class SupersetAPISource(SupersetSourceMixin):
    def yield_datamodel(
        self, dashboard_details: DashboardResult
    ) -> Iterable[Either[CreateDashboardDataModelRequest]]:

        if self.source_config.includeDataModels:
            for chart_id in self._get_charts_of_dashboard(dashboard_details):
                try:

@@ -216,7 +216,6 @@ class SupersetDBSource(SupersetSourceMixin):
    def yield_datamodel(
        self, dashboard_details: FetchDashboard
    ) -> Iterable[Either[CreateDashboardDataModelRequest]]:

        if self.source_config.includeDataModels:
            for chart_id in self._get_charts_of_dashboard(dashboard_details):
                chart_json = self.all_charts.get(chart_id)

@@ -77,7 +77,6 @@ class AzuresqlSource(CommonDbSourceService, MultiDBSource):
        yield from self._execute_database_query(AZURE_SQL_GET_DATABASES)

    def get_database_names(self) -> Iterable[str]:

        if not self.config.serviceConnection.__root__.config.ingestAllDatabases:
            configured_db = self.config.serviceConnection.__root__.config.database
            self.set_inspector(database_name=configured_db)
@@ -46,7 +46,7 @@ from metadata.ingestion.source.database.database_service import DatabaseServiceS
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.utils import fqn
from metadata.utils.constants import DEFAULT_DATABASE
from metadata.utils.datalake.datalake_utils import get_columns
from metadata.utils.datalake.datalake_utils import DataFrameColumnParser
from metadata.utils.filters import filter_by_schema, filter_by_table
from metadata.utils.logger import ingestion_logger

@@ -217,7 +217,8 @@ class CommonNoSQLSource(DatabaseServiceSource, ABC):
        try:
            data = self.get_table_columns_dict(schema_name, table_name)
            df = pd.DataFrame.from_records(list(data))
            columns = get_columns(df)
            column_parser = DataFrameColumnParser.create(df)
            columns = column_parser.get_columns()
            table_request = CreateTableRequest(
                name=table_name,
                tableType=table_type,

@@ -94,7 +94,6 @@ class DatabricksClient:
        Method returns List the history of queries through SQL warehouses
        """
        try:

            data = {}
            daydiff = end_date - start_date
@@ -77,8 +77,8 @@ from metadata.utils import fqn
from metadata.utils.constants import DEFAULT_DATABASE
from metadata.utils.credentials import GOOGLE_CREDENTIALS
from metadata.utils.datalake.datalake_utils import (
    DataFrameColumnParser,
    fetch_dataframe,
    get_columns,
    get_file_format_type,
)
from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table
@@ -416,9 +416,14 @@ class DatalakeSource(DatabaseServiceSource):
                    file_extension=table_extension,
                ),
            )

            if data_frame:
                column_parser = DataFrameColumnParser.create(
                    data_frame[0], table_extension
                )
                columns = column_parser.get_columns()
            else:
                # If no data_frame (due to unsupported type), ignore
            columns = get_columns(data_frame[0]) if data_frame else None
                columns = None
            if columns:
                table_request = CreateTableRequest(
                    name=table_name,

@@ -692,7 +692,6 @@ class DbtSource(DbtServiceSource):
        )
        if table_entity:
            try:

                service_name, database_name, schema_name, table_name = fqn.split(
                    table_entity.fullyQualifiedName.__root__
                )

@@ -126,7 +126,6 @@ def get_metastore_connection(connection: Any) -> Engine:

@get_metastore_connection.register
def _(connection: PostgresConnection):

    # import required to load sqlalchemy plugin
    # pylint: disable=import-outside-toplevel,unused-import
    from metadata.ingestion.source.database.hive.metastore_dialects.postgres import (  # nopycln: import
@@ -153,7 +152,6 @@ def _(connection: PostgresConnection):

@get_metastore_connection.register
def _(connection: MysqlConnection):

    # import required to load sqlalchemy plugin
    # pylint: disable=import-outside-toplevel,unused-import
    from metadata.ingestion.source.database.hive.metastore_dialects.mysql import (  # nopycln: import

@@ -114,7 +114,6 @@ class MssqlSource(StoredProcedureMixin, CommonDbSourceService, MultiDBSource):
        yield from self._execute_database_query(MSSQL_GET_DATABASE)

    def get_database_names(self) -> Iterable[str]:

        if not self.config.serviceConnection.__root__.config.ingestAllDatabases:
            configured_db = self.config.serviceConnection.__root__.config.database
            self.set_inspector(database_name=configured_db)

@@ -62,7 +62,6 @@ def get_view_definition(
    dblink="",
    **kw,
):

    return get_view_definition_wrapper(
        self,
        connection,

@@ -75,7 +75,6 @@ def get_lineage_from_multi_tenant_table(
    connection: any,
    service_name: str,
) -> Iterator[Either[AddLineageRequest]]:

    """
    For PGSpider, firstly, get list of multi-tenant tables.
    Next, get child foreign tables of each multi-tenant tables.

@@ -800,7 +800,6 @@ class SampleDataSource(

        # Create table and stored procedure lineage
        for lineage_entities in self.stored_procedures["lineage"]:

            from_table = self.metadata.get_by_name(
                entity=Table, fqn=lineage_entities["from_table_fqn"]
            )

@@ -465,7 +465,6 @@ class SasSource(
                or table_entity.extension.__root__.get("analysisTimeStamp")
                != table_extension.get("analysisTimeStamp")
            ):

                # create the columns of the table
                columns, col_profile_list = self.create_columns_and_profiles(
                    col_entity_instances, table_entity_instance
@@ -711,10 +710,10 @@ class SasSource(
                if "state" in table_resource and table_resource["state"] == "unloaded":
                    self.sas_client.load_table(table_uri + "/state?value=loaded")

            except HTTPError as e:
            except HTTPError as exc:
                # append http error to table description if it can't be found
                logger.error(f"table_uri: {table_uri}")
                self.report_description.append(str(e))
                self.report_description.append(str(exc))
                name_index = table_uri.rindex("/")
                table_name = table_uri[name_index + 1 :]
                param = f"filter=eq(name,'{table_name}')"
@@ -39,7 +39,6 @@ from metadata.utils.sqlalchemy_utils import (


def _quoted_name(entity_name: Optional[str]) -> Optional[str]:

    if entity_name:
        return fqn.quote_name(entity_name)


@@ -153,7 +153,6 @@ class StoredProcedureMixin(ABC):
            query_type=query_by_procedure.query_type,
            query_text=query_by_procedure.query_text,
        ):

            self.context.stored_procedure_query_lineage = True
            for either_lineage in get_lineage_by_query(
                self.metadata,

@@ -246,7 +246,6 @@ class CommonBrokerSource(MessagingServiceSource, ABC):
                if messages:
                    for message in messages:
                        try:

                            value = message.value()
                            sample_data.append(
                                self.decode_message(

@@ -131,7 +131,6 @@ class SplineSource(PipelineServiceSource):
        return None

    def _get_table_from_datasource_name(self, datasource: str) -> Optional[Table]:

        if (
            not datasource
            and not datasource.startswith("dbfs")

@@ -111,7 +111,6 @@ class ElasticsearchSource(SearchServiceSource):
        Method to Get Sample Data of Search Index Entity
        """
        if self.source_config.includeSampleData and self.context.search_index:

            sample_data = self.client.search(
                index=self.context.search_index,
                q=WILDCARD_SEARCH,
@@ -14,8 +14,6 @@ Base class for ingesting Object Storage services
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional, Set

from pandas import DataFrame

from metadata.generated.schema.api.data.createContainer import CreateContainerRequest
from metadata.generated.schema.entity.data.container import Container
from metadata.generated.schema.entity.services.storageService import (
@@ -53,7 +51,10 @@ from metadata.readers.dataframe.models import DatalakeTableSchemaWrapper
from metadata.readers.dataframe.reader_factory import SupportedTypes
from metadata.readers.models import ConfigSource
from metadata.utils import fqn
from metadata.utils.datalake.datalake_utils import fetch_dataframe, get_columns
from metadata.utils.datalake.datalake_utils import (
    DataFrameColumnParser,
    fetch_dataframe,
)
from metadata.utils.logger import ingestion_logger
from metadata.utils.storage_metadata_config import (
    StorageMetadataConfigException,
@@ -67,7 +68,6 @@ OPENMETADATA_TEMPLATE_FILE_NAME = "openmetadata.json"


class StorageServiceTopology(ServiceTopology):

    root = TopologyNode(
        producer="get_services",
        stages=[
@@ -271,10 +271,10 @@ class StorageServiceSource(TopologyRunnerMixin, Source, ABC):
            ),
        )
        columns = []
        if isinstance(data_structure_details, DataFrame):
            columns = get_columns(data_structure_details)
        if isinstance(data_structure_details, list) and data_structure_details:
            columns = get_columns(data_structure_details[0])
        column_parser = DataFrameColumnParser.create(
            data_structure_details, SupportedTypes(metadata_entry.structureFormat)
        )
        columns = column_parser.get_columns()
        return columns

    def _get_columns(
@@ -34,7 +34,10 @@ from metadata.profiler.metrics.core import MetricTypes
from metadata.profiler.metrics.registry import Metrics
from metadata.readers.dataframe.models import DatalakeTableSchemaWrapper
from metadata.utils.constants import COMPLEX_COLUMN_SEPARATOR, SAMPLE_DATA_DEFAULT_COUNT
from metadata.utils.datalake.datalake_utils import fetch_col_types, fetch_dataframe
from metadata.utils.datalake.datalake_utils import (
    GenericDataFrameColumnParser,
    fetch_dataframe,
)
from metadata.utils.logger import profiler_interface_registry_logger
from metadata.utils.sqa_like_column import SQALikeColumn

@@ -411,7 +414,9 @@ class PandasProfilerInterface(ProfilerInterface, PandasInterfaceMixin):
                sqalike_columns.append(
                    SQALikeColumn(
                        column_name,
                        fetch_col_types(self.complex_dataframe_sample[0], column_name),
                        GenericDataFrameColumnParser.fetch_col_types(
                            self.complex_dataframe_sample[0], column_name
                        ),
                    )
                )
            return sqalike_columns

@@ -30,6 +30,7 @@ from metadata.utils.logger import profiler_logger

logger = profiler_logger()


# pylint: disable=too-many-locals
class Histogram(HybridMetric):
    """

@@ -126,7 +126,6 @@ class AbstractTableMetricComputer(ABC):
        table: Table,
        where_clause: Optional[List[ColumnOperators]] = None,
    ):

        query = select(*columns).select_from(table)
        if where_clause:
            query = query.where(*where_clause)

@@ -40,7 +40,6 @@ class ProfilerProcessor(Processor):
    """

    def __init__(self, config: OpenMetadataWorkflowConfig):

        super().__init__()

        self.config = config
@@ -56,7 +55,6 @@ class ProfilerProcessor(Processor):
        return "Profiler"

    def _run(self, record: ProfilerSourceAndEntity) -> Either[ProfilerResponse]:

        profiler_runner: Profiler = record.profiler_source.get_profiler_runner(
            record.entity, self.profiler_config
        )

@@ -217,7 +217,6 @@ class ProfilerSource(ProfilerSourceInterface):
    def _get_context_entities(
        self, entity: Table
    ) -> Tuple[DatabaseSchema, Database, DatabaseService]:

        schema_entity = None
        database_entity = None
        db_service = None

@@ -29,7 +29,6 @@ def bigquery_type_mapper(_type_map: dict, col: Column):
    from sqlalchemy_bigquery import STRUCT

    def build_struct(_type_map: dict, col: Column):

        structs = []
        for child in col.children:
            if child.dataType != DataType.STRUCT:

@@ -35,7 +35,6 @@ class ApiReader(Reader, ABC):
    """

    def __init__(self, credentials: ReadersCredentials):

        self._auth_headers = None
        self.credentials = credentials
@@ -15,8 +15,9 @@ from different auths and different file systems.
"""
import ast
import json
import random
import traceback
from typing import Dict, List, Optional, cast
from typing import Dict, List, Optional, Union, cast

from metadata.generated.schema.entity.data.table import Column, DataType
from metadata.ingestion.source.database.column_helpers import truncate_column_name
@@ -29,18 +30,6 @@ from metadata.utils.logger import utils_logger

logger = utils_logger()

DATALAKE_DATA_TYPES = {
    **dict.fromkeys(["int64", "int", "int32"], DataType.INT),
    "dict": DataType.JSON,
    "list": DataType.ARRAY,
    **dict.fromkeys(["float64", "float32", "float"], DataType.FLOAT),
    "bool": DataType.BOOLEAN,
    **dict.fromkeys(
        ["datetime64", "timedelta[ns]", "datetime64[ns]"], DataType.DATETIME
    ),
    "str": DataType.STRING,
}


def fetch_dataframe(
    config_source,
@@ -100,76 +89,105 @@ def get_file_format_type(key_name, metadata_entry=None):
    return False


def unique_json_structure(dicts: List[Dict]) -> Dict:
    """Given a sample of `n` json objects, return a json object that represents the unique structure of all `n` objects.
    Note that the type of the key will be that of the last object seen in the sample.
# pylint: disable=import-outside-toplevel
class DataFrameColumnParser:
    """A column parser object. This serves as a Creator class for the appropriate column parser object parser
    for datalake types. It allows us to implement different schema parsers for different datalake types without
    implementing many conditionals statements.

    e.g. if we want to implement a column parser for parquet files, we can simply implement a
    ParquetDataFrameColumnParser class and add it as part of the `create` method. The `create` method will then return
    the appropriate parser based on the file type. The `ColumnParser` class has a single entry point `get_columns` which
    will call the `get_columns` method of the appropriate parser.
    """

    def __init__(self, parser):
        """Initialize the column parser object"""
        self.parser = parser

    @classmethod
    def create(
        cls,
        data_frame: "DataFrame",
        file_type: Optional[SupportedTypes] = None,
        sample: bool = True,
        shuffle: bool = False,
    ):
        """Instantiate a column parser object with the appropriate parser

        Args:
        dicts: list of json objects
            data_frame: the dataframe object
            file_type: the file type of the dataframe. Will be used to determine the appropriate parser.
            sample: whether to sample the dataframe or not if we have a list of dataframes.
                If sample is False, we will concatenate the dataframes, which can be cause OOM error for large dataset.
                (default: True)
            shuffle: whether to shuffle the dataframe list or not if sample is True. (default: False)
        """
    result = {}
    for dict_ in dicts:
        for key, value in dict_.items():
            if isinstance(value, dict):
                nested_json = result.get(key, {})
                # `isinstance(nested_json, dict)` if for a key we first see a non dict value
                # but then see a dict value later, we will consider the key to be a dict.
                result[key] = unique_json_structure(
                    [nested_json if isinstance(nested_json, dict) else {}, value]
                )
            else:
                result[key] = value
    return result
        data_frame = cls._get_data_frame(data_frame, sample, shuffle)
        if file_type == SupportedTypes.PARQUET:
            parser = ParquetDataFrameColumnParser(data_frame)
            return cls(parser)
        parser = GenericDataFrameColumnParser(data_frame)
        return cls(parser)

    @staticmethod
    def _get_data_frame(
        data_frame: Union[List["DataFrame"], "DataFrame"], sample: bool, shuffle: bool
    ):
        """Return the dataframe to use for parsing"""
        import pandas as pd

        if not isinstance(data_frame, list):
            return data_frame

        if sample:
            if shuffle:
                random.shuffle(data_frame)
            return data_frame[0]

        return pd.concat(data_frame)

    def get_columns(self):
        """Get the columns from the parser"""
        return self.parser.get_columns()


def construct_json_column_children(json_column: Dict) -> List[Dict]:
    """Construt a dict representation of a Column object
class GenericDataFrameColumnParser:
    """Given a dataframe object, parse the columns and return a list of Column objects.

    Args:
        json_column: unique json structure of a column
    # TODO: We should consider making the function above part of the `GenericDataFrameColumnParser` class
    # though we need to do a thorough overview of where they are used to ensure unnecessary coupling.
    """
    children = []
    for key, value in json_column.items():
        column = {}
        type_ = type(value).__name__.lower()
        column["dataTypeDisplay"] = DATALAKE_DATA_TYPES.get(
            type_, DataType.UNKNOWN
        ).value
        column["dataType"] = DATALAKE_DATA_TYPES.get(type_, DataType.UNKNOWN).value
        column["name"] = truncate_column_name(key)
        column["displayName"] = key
        if isinstance(value, dict):
            column["children"] = construct_json_column_children(value)
        children.append(column)

    return children
    _data_formats = {
        **dict.fromkeys(["int64", "int", "int32"], DataType.INT),
        "dict": DataType.JSON,
        "list": DataType.ARRAY,
        **dict.fromkeys(["float64", "float32", "float"], DataType.FLOAT),
        "bool": DataType.BOOLEAN,
        **dict.fromkeys(
            ["datetime64", "timedelta[ns]", "datetime64[ns]"], DataType.DATETIME
        ),
        "str": DataType.STRING,
    }

    def __init__(self, data_frame: "DataFrame"):
        self.data_frame = data_frame

def get_children(json_column) -> List[Dict]:
    """Get children of json column.

    Args:
        json_column (pandas.Series): column with 100 sample rows.
            Sample rows will be used to infer children.
    """
    from pandas import Series  # pylint: disable=import-outside-toplevel

    json_column = cast(Series, json_column)
    try:
        json_column = json_column.apply(json.loads)
    except TypeError:
        # if values are not strings, we will assume they are already json objects
        # based on the read class logic
        pass
    json_structure = unique_json_structure(json_column.values.tolist())

    return construct_json_column_children(json_structure)


def get_columns(data_frame: "DataFrame"):
    def get_columns(self):
        """
        method to process column details
        """
        return self._get_columns(self.data_frame)

    @classmethod
    def _get_columns(cls, data_frame: "DataFrame"):
        """
        method to process column details.

        Note this was move from a function to a class method to bring it closer to the
        `GenericDataFrameColumnParser` class. Should be rethought as part of the TODO.
        """
        cols = []
        if hasattr(data_frame, "columns"):
            df_columns = list(data_frame.columns)
@@ -178,7 +196,7 @@ def get_columns(data_frame: "DataFrame"):
                data_type = DataType.STRING
                try:
                    if hasattr(data_frame[column], "dtypes"):
                    data_type = fetch_col_types(data_frame, column_name=column)
                        data_type = cls.fetch_col_types(data_frame, column_name=column)

                    parsed_string = {
                        "dataTypeDisplay": data_type.value,
@@ -190,20 +208,25 @@ def get_columns(data_frame: "DataFrame"):
                        parsed_string["arrayDataType"] = DataType.UNKNOWN

                    if data_type == DataType.JSON:
                    parsed_string["children"] = get_children(
                        parsed_string["children"] = cls.get_children(
                            data_frame[column].dropna()[:100]
                        )

                    cols.append(Column(**parsed_string))
                except Exception as exc:
                    logger.debug(traceback.format_exc())
                logger.warning(f"Unexpected exception parsing column [{column}]: {exc}")
                    logger.warning(
                        f"Unexpected exception parsing column [{column}]: {exc}"
                    )
        return cols


def fetch_col_types(data_frame, column_name):
    @classmethod
    def fetch_col_types(cls, data_frame, column_name):
        """fetch_col_types: Fetch Column Type for the c

        Note this was move from a function to a class method to bring it closer to the
        `GenericDataFrameColumnParser` class. Should be rethought as part of the TODO.

        Args:
            data_frame (DataFrame)
            column_name (string)
@@ -223,7 +246,7 @@ def fetch_col_types(data_frame, column_name):
                    # Handle any exceptions that may occur
                    data_type = "string"

        data_type = DATALAKE_DATA_TYPES.get(
            data_type = cls._data_formats.get(
                data_type or data_frame[column_name].dtypes.name, DataType.STRING
            )
        except Exception as err:
@@ -232,3 +255,213 @@ def fetch_col_types(data_frame, column_name):
            )
            logger.debug(traceback.format_exc())
        return data_type

    @classmethod
    def unique_json_structure(cls, dicts: List[Dict]) -> Dict:
        """Given a sample of `n` json objects, return a json object that represents the unique
        structure of all `n` objects. Note that the type of the key will be that of
        the last object seen in the sample.

        Args:
            dicts: list of json objects
        """
        result = {}
        for dict_ in dicts:
            for key, value in dict_.items():
                if isinstance(value, dict):
                    nested_json = result.get(key, {})
                    # `isinstance(nested_json, dict)` if for a key we first see a non dict value
                    # but then see a dict value later, we will consider the key to be a dict.
                    result[key] = cls.unique_json_structure(
                        [nested_json if isinstance(nested_json, dict) else {}, value]
                    )
                else:
                    result[key] = value
        return result

    @classmethod
    def construct_json_column_children(cls, json_column: Dict) -> List[Dict]:
        """Construt a dict representation of a Column object

        Args:
            json_column: unique json structure of a column
        """
        children = []
        for key, value in json_column.items():
            column = {}
            type_ = type(value).__name__.lower()
            column["dataTypeDisplay"] = cls._data_formats.get(
                type_, DataType.UNKNOWN
            ).value
            column["dataType"] = cls._data_formats.get(type_, DataType.UNKNOWN).value
            column["name"] = truncate_column_name(key)
            column["displayName"] = key
            if isinstance(value, dict):
                column["children"] = cls.construct_json_column_children(value)
            children.append(column)

        return children

    @classmethod
    def get_children(cls, json_column) -> List[Dict]:
        """Get children of json column.

        Args:
            json_column (pandas.Series): column with 100 sample rows.
                Sample rows will be used to infer children.
        """
        from pandas import Series  # pylint: disable=import-outside-toplevel

        json_column = cast(Series, json_column)
        try:
            json_column = json_column.apply(json.loads)
        except TypeError:
            # if values are not strings, we will assume they are already json objects
            # based on the read class logic
            pass
        json_structure = cls.unique_json_structure(json_column.values.tolist())

        return cls.construct_json_column_children(json_structure)


# pylint: disable=import-outside-toplevel
class ParquetDataFrameColumnParser:
    """Given a dataframe object generated from a parquet file, parse the columns and return a list of Column objects."""

    import pyarrow as pa

    _data_formats = {
        **dict.fromkeys(
            ["int8", "int16", "int32", "int64", "int", pa.DurationType], DataType.INT
        ),
        **dict.fromkeys(["uint8", "uint16", "uint32", "uint64", "uint"], DataType.UINT),
        pa.StructType: DataType.STRUCT,
        **dict.fromkeys([pa.ListType, pa.LargeListType], DataType.ARRAY),
        **dict.fromkeys(
            ["halffloat", "float32", "float64", "double", "float"], DataType.FLOAT
        ),
        "bool": DataType.BOOLEAN,
        **dict.fromkeys(
            [
                "datetime64",
                "timedelta[ns]",
                "datetime64[ns]",
                "time32[s]",
                "time32[ms]",
                "time64[ns]",
                "time64[us]",
                pa.TimestampType,
                "date64",
            ],
            DataType.DATETIME,
        ),
        "date32[day]": DataType.DATE,
        "string": DataType.STRING,
        **dict.fromkeys(
            ["binary", "large_binary", pa.FixedSizeBinaryType], DataType.BINARY
        ),
        **dict.fromkeys([pa.Decimal128Type, pa.Decimal256Type], DataType.DECIMAL),
    }

    def __init__(self, data_frame: "DataFrame"):
        import pyarrow as pa

        self.data_frame = data_frame
        self._arrow_table = pa.Table.from_pandas(self.data_frame)

    def get_columns(self):
        """
        method to process column details for parquet files
        """
        import pyarrow as pa

        schema: List[pa.Field] = self._arrow_table.schema
        columns = []
        for column in schema:
            parsed_column = {
                "dataTypeDisplay": str(column.type),
                "dataType": self._get_pq_data_type(column),
                "name": truncate_column_name(column.name),
                "displayName": column.name,
            }

            if parsed_column["dataType"] == DataType.ARRAY:
                try:
                    item_field = column.type.value_field
                    parsed_column["arrayDataType"] = self._get_pq_data_type(item_field)
                except AttributeError:
                    # if the value field is not specified, we will set it to UNKNOWN
                    parsed_column["arrayDataType"] = DataType.UNKNOWN

            if parsed_column["dataType"] == DataType.BINARY:
                try:
                    data_length = type(column.type).byte_width
                except AttributeError:
                    # if the byte width is not specified, we will set it to -1
                    # following pyarrow convention
                    data_length = -1
                parsed_column["dataLength"] = data_length

            if parsed_column["dataType"] == DataType.STRUCT:
                parsed_column["children"] = self._get_children(column)
            columns.append(Column(**parsed_column))

        return columns

    def _get_children(self, column):
        """For struct types, get the children of the column

        Args:
            column (pa.Field): pa column
        """
        field_idx = column.type.num_fields

        children = []
        for idx in range(field_idx):
            child = column.type.field(idx)
            data_type = self._get_pq_data_type(child)

            child_column = {
                "dataTypeDisplay": str(child.type),
                "dataType": data_type,
                "name": truncate_column_name(child.name),
                "displayName": child.name,
            }
            if data_type == DataType.STRUCT:
                child_column["children"] = self._get_children(child)
            children.append(child_column)

        return children

    def _get_pq_data_type(self, column):
        """Given a column return the type of the column

        Args:
            column (pa.Field): pa column
        """
        import pyarrow as pa

        if isinstance(
            column.type,
            (
                pa.DurationType,
                pa.StructType,
                pa.ListType,
                pa.LargeListType,
                pa.TimestampType,
                pa.Decimal128Type,
                pa.Decimal256Type,
                pa.FixedSizeBinaryType,
            ),
        ):
            # the above type can take many shape
            # (i.e. pa.ListType(pa.StructType([pa.column("a", pa.int64())])), etc,)
            # so we'll use their type to determine the data type
            data_type = self._data_formats.get(type(column.type), DataType.UNKNOWN)
        else:
            # for the other types we need to use their string representation
            # to determine the data type as `type(column.type)` will return
            # a generic `pyarrow.lib.DataType`
            data_type = self._data_formats.get(str(column.type), DataType.UNKNOWN)

        return data_type
@@ -63,7 +63,6 @@ def _(provider: SecretsManagerProvider) -> Optional[AWSCredentials]:

@secrets_manager_client_loader.add(SecretsManagerClientLoader.env.value)
def _(provider: SecretsManagerProvider) -> Optional[AWSCredentials]:

    if provider in {
        SecretsManagerProvider.aws,
        SecretsManagerProvider.managed_aws,

@@ -104,7 +104,6 @@ class SecretsManagerFactory(metaclass=Singleton):
        return self.secrets_manager

    def _load_secrets_manager_credentials(self) -> Optional["AWSCredentials"]:

        if not self.secrets_manager_loader:
            return None


@@ -80,7 +80,6 @@ class ApplicationWorkflow(BaseWorkflow, ABC):
    runner: Optional[AppRunner]

    def __init__(self, config_dict: dict):

        self.runner = None  # Will be passed in post-init
        # TODO: Create a parse_gracefully method
        self.config = OpenMetadataApplicationConfig.parse_obj(config_dict)

@@ -235,7 +235,6 @@ class BaseWorkflow(ABC, WorkflowStatusMixin):
            service = self._get_ingestion_pipeline_service()

            if service is not None:

                return self.metadata.create_or_update(
                    CreateIngestionPipelineRequest(
                        name=pipeline_name,

@@ -31,7 +31,6 @@ class MetadataWorkflow(IngestionWorkflow):
    """

    def set_steps(self):

        # We keep the source registered in the workflow
        self.source = self._get_source()
        sink = self._get_sink()

@@ -33,7 +33,6 @@ class UsageWorkflow(IngestionWorkflow):
    """

    def set_steps(self):

        # We keep the source registered in the workflow
        self.source = self._get_source()
        processor = self._get_processor()
@@ -20,7 +20,6 @@ from .common_e2e_sqa_mixins import SQACommonMethods


class HiveCliTest(CliCommonDB.TestSuite, SQACommonMethods):

    prepare_e2e: List[str] = [
        "DROP DATABASE IF EXISTS e2e_cli_tests CASCADE",
        "CREATE DATABASE e2e_cli_tests",

@@ -20,7 +20,6 @@ from .common.test_cli_dashboard import CliCommonDashboard


class MetabaseCliTest(CliCommonDashboard.TestSuite):

    # in case we want to do something before running the tests
    def prepare(self) -> None:
        redshift_file_path = str(

@@ -23,7 +23,6 @@ from .common_e2e_sqa_mixins import SQACommonMethods


class OracleCliTest(CliCommonDB.TestSuite, SQACommonMethods):

    create_table_query: str = """
       CREATE TABLE admin.admin_emp (
         empno      NUMBER(5) PRIMARY KEY,

@@ -20,7 +20,6 @@ from .common.test_cli_dashboard import CliCommonDashboard


class PowerBICliTest(CliCommonDashboard.TestSuite):

    # in case we want to do something before running the tests
    def prepare(self) -> None:
        redshift_file_path = str(

@@ -20,7 +20,6 @@ from .common.test_cli_dashboard import CliCommonDashboard


class TableauCliTest(CliCommonDashboard.TestSuite):

    # in case we want to do something before running the tests
    def prepare(self) -> None:
        redshift_file_path = str(
@@ -157,7 +157,6 @@ class TestAirflowLineageRuner(TestCase):
        )

    def test_lineage_runner(self):

        with DAG("test_runner", start_date=datetime(2021, 1, 1)) as dag:
            BashOperator(
                task_id="print_date",

@@ -300,7 +300,6 @@ def get_create_test_case(
def get_test_dag(name: str) -> DAG:
    """Get a DAG with the tasks created in the CreatePipelineRequest"""
    with DAG(name, start_date=datetime(2021, 1, 1)) as dag:

        tasks = [
            BashOperator(
                task_id=task_id,

@@ -52,7 +52,6 @@ class TestSecretsManagerFactory(TestCase):
            )

    def test_invalid_config_secret_manager(self):

        om_connection: OpenMetadataConnection = self.build_open_metadata_connection(
            SecretsManagerProvider.db,
            SecretsManagerClientLoader.noop,

@@ -40,7 +40,6 @@ class ColumnNameScannerTest(TestCase):
        self.assertIsNone(ColumnNameScanner.scan("user_id"))

    def test_column_names_sensitive(self):

        # Bank
        self.assertEqual(ColumnNameScanner.scan("bank_account"), EXPECTED_SENSITIVE)

@@ -62,7 +62,6 @@ class FakeConnection:


class PandasInterfaceTest(TestCase):

    import pandas as pd

    col_names = [

@@ -328,7 +328,6 @@ class ProfilerInterfaceTest(TestCase):
        self.assertEqual(50, actual)

    def test_table_config_casting(self):

        expected = TableConfig(
            profileSample=200,
            profileSampleType=ProfileSampleType.PERCENTAGE,

BIN  ingestion/tests/unit/resources/datalake/example.parquet (new file, binary file not shown)
@@ -18,7 +18,7 @@ from unittest import TestCase

from metadata.generated.schema.entity.data.table import DataType
from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser
from metadata.utils.datalake.datalake_utils import fetch_col_types
from metadata.utils.datalake.datalake_utils import GenericDataFrameColumnParser

COLUMN_TYPE_PARSE = [
    "array<string>",
@@ -129,4 +129,6 @@ def test_check_datalake_type():
    }
    df = pd.read_csv("ingestion/tests/unit/test_column_type_parser.csv")
    for column_name in df.columns.values.tolist():
        assert assert_col_type_dict.get(column_name) == fetch_col_types(df, column_name)
        assert assert_col_type_dict.get(
            column_name
        ) == GenericDataFrameColumnParser.fetch_col_types(df, column_name)

@@ -622,7 +622,6 @@ class PGSpiderLineageUnitTests(TestCase):
                connection=self.postgres.service_connection,
                service_name=self.postgres.config.serviceName,
            ):

                if isinstance(record, AddLineageRequest):
                    requests.append(record)

@@ -661,7 +660,6 @@ class PGSpiderLineageUnitTests(TestCase):
                connection=self.postgres.service_connection,
                service_name=self.postgres.config.serviceName,
            ):

                if isinstance(record, AddLineageRequest):
                    requests.append(record)

@@ -700,7 +698,6 @@ class PGSpiderLineageUnitTests(TestCase):
                connection=self.postgres.service_connection,
                service_name=self.postgres.config.serviceName,
            ):

                if isinstance(record, AddLineageRequest):
                    requests.append(record)

@@ -738,7 +735,6 @@ class PGSpiderLineageUnitTests(TestCase):
                connection=self.postgres.service_connection,
                service_name=self.postgres.config.serviceName,
            ):

                if isinstance(record, AddLineageRequest):
                    requests.append(record)

@@ -773,7 +769,6 @@ class PGSpiderLineageUnitTests(TestCase):
                connection=self.postgres.service_connection,
                service_name=self.postgres.config.serviceName,
            ):

                if isinstance(record, AddLineageRequest):
                    requests.append(record)

@@ -809,7 +804,6 @@ class PGSpiderLineageUnitTests(TestCase):
                connection=self.postgres.service_connection,
                service_name=self.postgres.config.serviceName,
            ):

                if isinstance(record, AddLineageRequest):
                    requests.append(record)
@ -33,7 +33,7 @@ from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.source.database.datalake.metadata import DatalakeSource
from metadata.readers.dataframe.avro import AvroDataFrameReader
from metadata.readers.dataframe.json import JSONDataFrameReader
from metadata.utils.datalake.datalake_utils import get_columns
from metadata.utils.datalake.datalake_utils import GenericDataFrameColumnParser

mock_datalake_config = {
    "source": {
@ -459,13 +459,17 @@ class DatalakeUnitTest(TestCase):
        actual_df_3 = JSONDataFrameReader.read_from_json(
            key="file.json", json_text=EXAMPLE_JSON_TEST_3, decode=True
        )[0]
        actual_cols_3 = get_columns(actual_df_3)
        actual_cols_3 = GenericDataFrameColumnParser._get_columns(
            actual_df_3
        )  # pylint: disable=protected-access
        assert actual_cols_3 == EXAMPLE_JSON_COL_3

        actual_df_4 = JSONDataFrameReader.read_from_json(
            key="file.json", json_text=EXAMPLE_JSON_TEST_4, decode=True
        )[0]
        actual_cols_4 = get_columns(actual_df_4)
        actual_cols_4 = GenericDataFrameColumnParser._get_columns(
            actual_df_4
        )  # pylint: disable=protected-access
        assert actual_cols_4 == EXAMPLE_JSON_COL_4

    def test_avro_file_parse(self):
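The same move applies to column extraction: the former get_columns helper is now the protected _get_columns classmethod. A rough sketch of the call on a small ad-hoc dataframe, again assuming openmetadata-ingestion and pandas are available; the sample record is made up and is not one of the EXAMPLE_JSON_TEST fixtures.

import pandas as pd

from metadata.utils.datalake.datalake_utils import GenericDataFrameColumnParser

# Made-up flat records standing in for a JSON-backed dataframe.
records = [{"name": "alice", "age": 30, "is_active": True}]
df = pd.DataFrame(records)

# Protected helper exercised by the unit test above; it returns a list of
# OpenMetadata Column objects derived from the dataframe schema.
for column in GenericDataFrameColumnParser._get_columns(df):  # pylint: disable=protected-access
    print(column.name, column.dataType)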
@ -641,7 +641,6 @@ class IcebergUnitTest(TestCase):
        with patch.object(
            HiveCatalog, "list_tables", return_value=MOCK_TABLE_LIST
        ), patch.object(HiveCatalog, "load_table", return_value=LoadTableMock()):

            for i, table in enumerate(self.iceberg.get_tables_name_and_type()):
                self.assertEqual(table, EXPECTED_TABLE_LIST[i])

@ -655,7 +654,6 @@ class IcebergUnitTest(TestCase):
        ), patch.object(
            HiveCatalog, "load_table", side_effect=raise_no_such_iceberg_table
        ):

            self.assertEqual(len(list(self.iceberg.get_tables_name_and_type())), 0)

        # When pyiceberg.exceptions.NoSuchTableError is raised

@ -666,7 +664,6 @@ class IcebergUnitTest(TestCase):
        with patch.object(
            HiveCatalog, "list_tables", return_value=MOCK_TABLE_LIST
        ), patch.object(HiveCatalog, "load_table", side_effect=raise_no_such_table):

            self.assertEqual(len(list(self.iceberg.get_tables_name_and_type())), 0)

    def test_get_owner_ref(self):

@ -802,7 +799,6 @@ class IcebergUnitTest(TestCase):
        with patch.object(
            OpenMetadata, "get_reference_by_email", return_value=ref
        ), patch.object(fqn, "build", return_value=fq_database_schema):

            result = next(self.iceberg.yield_table((table_name, table_type))).right

            self.assertEqual(result, expected)
@ -192,7 +192,6 @@ class AirbyteUnitTest(TestCase):
        assert pipline == EXPECTED_CREATED_PIPELINES

    def test_pipeline_status(self):

        status = [
            either.right
            for either in self.airbyte.yield_pipeline_status(EXPECTED_ARIBYTE_DETAILS)
@ -308,7 +308,6 @@ class StorageUnitTest(TestCase):
                )
            ],
        ):

            Column.__eq__ = custom_column_compare
            self.assertListEqual(
                [
@ -12,12 +12,17 @@
Test datalake utils
"""

import os
from unittest import TestCase

from metadata.generated.schema.entity.data.table import Column
import pandas as pd

from metadata.generated.schema.entity.data.table import Column, DataType
from metadata.readers.dataframe.reader_factory import SupportedTypes
from metadata.utils.datalake.datalake_utils import (
    construct_json_column_children,
    unique_json_structure,
    DataFrameColumnParser,
    GenericDataFrameColumnParser,
    ParquetDataFrameColumnParser,
)

STRUCTURE = {
@ -53,7 +58,7 @@ class TestDatalakeUtils(TestCase):
        ]
        expected = STRUCTURE

        actual = unique_json_structure(sample_data)
        actual = GenericDataFrameColumnParser.unique_json_structure(sample_data)

        self.assertDictEqual(expected, actual)
@ -153,14 +158,16 @@ class TestDatalakeUtils(TestCase):
                ],
            },
        ]
        actual = construct_json_column_children(STRUCTURE)
        actual = GenericDataFrameColumnParser.construct_json_column_children(STRUCTURE)

        for el in zip(expected, actual):
            self.assertDictEqual(el[0], el[1])

    def test_create_column_object(self):
        """test create column object fn"""
        formatted_column = construct_json_column_children(STRUCTURE)
        formatted_column = GenericDataFrameColumnParser.construct_json_column_children(
            STRUCTURE
        )
        column = {
            "dataTypeDisplay": "STRING",
            "dataType": "STRING",
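Both helpers exercised in these tests, unique_json_structure and construct_json_column_children, are likewise classmethods on GenericDataFrameColumnParser now. A short sketch of chaining them on an ad-hoc payload, assuming openmetadata-ingestion is installed; the sample records are illustrative and not the STRUCTURE fixture used above.

from metadata.utils.datalake.datalake_utils import GenericDataFrameColumnParser

# Two records with overlapping but not identical keys; unique_json_structure
# merges them into a single representative structure.
sample_data = [
    {"id": 1, "user": {"name": "alice"}},
    {"id": 2, "user": {"name": "bob", "age": 30}},
]
structure = GenericDataFrameColumnParser.unique_json_structure(sample_data)

# construct_json_column_children turns that structure into the nested
# children payload used to build OpenMetadata Column objects.
children = GenericDataFrameColumnParser.construct_json_column_children(structure)
print(children)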
@ -170,3 +177,270 @@ class TestDatalakeUtils(TestCase):
        }
        column_obj = Column(**column)
        assert len(column_obj.children) == 3


class TestParquetDataFrameColumnParser(TestCase):
    """Test parquet dataframe column parser"""

    @classmethod
    def setUpClass(cls) -> None:
        resources_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), "resources"
        )
        cls.parquet_path = os.path.join(resources_path, "datalake", "example.parquet")

        cls.df = pd.read_parquet(cls.parquet_path)

        cls.parquet_parser = ParquetDataFrameColumnParser(cls.df)
    def test_parser_instantiation(self):
        """Test the right parser is instantiated from the creator method"""
        parquet_parser = DataFrameColumnParser.create(self.df, SupportedTypes.PARQUET)
        self.assertIsInstance(parquet_parser.parser, ParquetDataFrameColumnParser)

        other_types = [typ for typ in SupportedTypes if typ != SupportedTypes.PARQUET]
        for other_type in other_types:
            with self.subTest(other_type=other_type):
                generic_parser = DataFrameColumnParser.create(self.df, other_type)
                self.assertIsInstance(
                    generic_parser.parser, GenericDataFrameColumnParser
                )
    def test_shuffle_and_sample_from_parser(self):
        """test the shuffle and sampling logic from the parser creator method"""
        parquet_parser = DataFrameColumnParser.create(self.df, SupportedTypes.PARQUET)
        self.assertEqual(parquet_parser.parser.data_frame.shape, self.df.shape)

        parquet_parser = DataFrameColumnParser.create(
            [self.df, self.df], SupportedTypes.PARQUET
        )
        self.assertEqual(parquet_parser.parser.data_frame.shape, self.df.shape)

        parquet_parser = DataFrameColumnParser.create(
            [self.df, self.df], SupportedTypes.PARQUET, sample=False
        )
        self.assertEqual(
            parquet_parser.parser.data_frame.shape, pd.concat([self.df, self.df]).shape
        )
    def test_get_columns(self):
        """test `get_columns` method of the parquet column parser"""
        expected = [
            Column(dataTypeDisplay="bool", dataType=DataType.BOOLEAN, name="a", displayName="a"),  # type: ignore
            Column(dataTypeDisplay="int8", dataType=DataType.INT, name="b", displayName="b"),  # type: ignore
            Column(dataTypeDisplay="int16", dataType=DataType.INT, name="c", displayName="c"),  # type: ignore
            Column(dataTypeDisplay="int32", dataType=DataType.INT, name="d", displayName="d"),  # type: ignore
            Column(dataTypeDisplay="int64", dataType=DataType.INT, name="e", displayName="e"),  # type: ignore
            Column(dataTypeDisplay="uint8", dataType=DataType.UINT, name="f", displayName="f"),  # type: ignore
            Column(dataTypeDisplay="uint16", dataType=DataType.UINT, name="g", displayName="g"),  # type: ignore
            Column(dataTypeDisplay="uint32", dataType=DataType.UINT, name="h", displayName="h"),  # type: ignore
            Column(dataTypeDisplay="uint64", dataType=DataType.UINT, name="i", displayName="i"),  # type: ignore
            Column(dataTypeDisplay="float", dataType=DataType.FLOAT, name="k", displayName="k"),  # type: ignore
            Column(dataTypeDisplay="double", dataType=DataType.FLOAT, name="l", displayName="l"),  # type: ignore
            Column(dataTypeDisplay="time64[us]", dataType=DataType.DATETIME, name="n", displayName="n"),  # type: ignore
            Column(dataTypeDisplay="timestamp[ns]", dataType=DataType.DATETIME, name="o", displayName="o"),  # type: ignore
            Column(dataTypeDisplay="date32[day]", dataType=DataType.DATE, name="p", displayName="p"),  # type: ignore
            Column(dataTypeDisplay="date32[day]", dataType=DataType.DATE, name="q", displayName="q"),  # type: ignore
            Column(dataTypeDisplay="duration[ns]", dataType=DataType.INT, name="r", displayName="r"),  # type: ignore
            Column(dataTypeDisplay="binary", dataType=DataType.BINARY, name="t", displayName="t"),  # type: ignore
            Column(dataTypeDisplay="string", dataType=DataType.STRING, name="u", displayName="u"),  # type: ignore
            Column(dataTypeDisplay="string", dataType=DataType.STRING, name="v", displayName="v"),  # type: ignore
            Column(dataTypeDisplay="binary", dataType=DataType.BINARY, name="w", displayName="w"),  # type: ignore
            Column(dataTypeDisplay="string", dataType=DataType.STRING, name="x", displayName="x"),  # type: ignore
            Column(dataTypeDisplay="string", dataType=DataType.STRING, name="y", displayName="y"),  # type: ignore
            Column(dataTypeDisplay="list<item: int64>", dataType=DataType.ARRAY, name="aa", displayName="aa"),  # type: ignore
            Column(dataTypeDisplay="list<item: int64>", dataType=DataType.ARRAY, name="bb", displayName="bb"),  # type: ignore
            Column(
                dataTypeDisplay="struct<ee: int64, ff: int64, gg: struct<hh: struct<ii: int64, jj: int64, kk: int64>>>",
                dataType=DataType.STRUCT,
                name="dd",
                displayName="dd",
                children=[
                    Column(dataTypeDisplay="int64", dataType=DataType.INT, name="ee", displayName="ee"),  # type: ignore
                    Column(dataTypeDisplay="int64", dataType=DataType.INT, name="ff", displayName="ff"),  # type: ignore
                    Column(
                        dataTypeDisplay="struct<hh: struct<ii: int64, jj: int64, kk: int64>>",
                        dataType=DataType.STRUCT,
                        name="gg",
                        displayName="gg",
                        children=[
                            Column(
                                dataTypeDisplay="struct<ii: int64, jj: int64, kk: int64>",
                                dataType=DataType.STRUCT,
                                name="hh",
                                displayName="hh",
                                children=[
                                    Column(dataTypeDisplay="int64", dataType=DataType.INT, name="ii", displayName="ii"),  # type: ignore
                                    Column(dataTypeDisplay="int64", dataType=DataType.INT, name="jj", displayName="jj"),  # type: ignore
                                    Column(dataTypeDisplay="int64", dataType=DataType.INT, name="kk", displayName="kk"),  # type: ignore
                                ],
                            ),
                        ],
                    ),
                ],
            ),  # type: ignore
        ]
        actual = self.parquet_parser.get_columns()
        for validation in zip(expected, actual):
            with self.subTest(validation=validation):
                expected_col, actual_col = validation
                self.assertEqual(expected_col.name, actual_col.name)
                self.assertEqual(expected_col.displayName, actual_col.displayName)
                self.assertEqual(expected_col.dataType, actual_col.dataType)

    def _validate_parsed_column(self, expected, actual):
        """validate parsed column"""
        self.assertEqual(expected.name, actual.name)
        self.assertEqual(expected.dataType, actual.dataType)
        self.assertEqual(expected.displayName, actual.displayName)
        if expected.children:
            self.assertEqual(len(expected.children), len(actual.children))
            for validation in zip(expected.children, actual.children):
                with self.subTest(validation=validation):
                    expected_col, actual_col = validation
                    self._validate_parsed_column(expected_col, actual_col)
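The example.parquet fixture added by this commit (binary, not shown) is what drives the schema expectations above. The sketch below, which assumes pyarrow, pandas and openmetadata-ingestion are installed, shows how a comparable mixed-type parquet file could be produced and then run through the new factory; the file and column subset are illustrative and are not the committed fixture.

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

from metadata.readers.dataframe.reader_factory import SupportedTypes
from metadata.utils.datalake.datalake_utils import DataFrameColumnParser

# Small table whose Arrow schema covers a few of the types asserted above
# (bool, int8, uint16, float32, string, list and struct).
table = pa.table(
    {
        "a": pa.array([True, False], type=pa.bool_()),
        "b": pa.array([1, 2], type=pa.int8()),
        "g": pa.array([1, 2], type=pa.uint16()),
        "k": pa.array([1.5, 2.5], type=pa.float32()),
        "u": pa.array(["foo", "bar"], type=pa.string()),
        "aa": pa.array([[1, 2], [3]], type=pa.list_(pa.int64())),
        "dd": pa.array(
            [{"ee": 1, "ff": 2}, {"ee": 3, "ff": 4}],
            type=pa.struct([("ee", pa.int64()), ("ff", pa.int64())]),
        ),
    }
)
pq.write_table(table, "example_like.parquet")

# Read the file back and let the factory pick the parquet-aware parser,
# as test_parser_instantiation does with SupportedTypes.PARQUET.
df = pd.read_parquet("example_like.parquet")
parser = DataFrameColumnParser.create(df, SupportedTypes.PARQUET)

# The wrapped parser exposes the parsed OpenMetadata columns, mirroring
# what test_get_columns asserts against the committed fixture.
for column in parser.parser.get_columns():
    print(column.name, column.dataType, column.dataTypeDisplay)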
@ -110,7 +110,6 @@ def get_dagbag():

class ScanDagsTask(Process):
    def run(self):

        if airflow_version >= "2.6":
            scheduler_job = self._run_new_scheduler_job()
        else:
@ -149,7 +149,8 @@
        "LARGEINT",
        "QUANTILE_STATE",
        "AGG_STATE",
        "BITMAP"
        "BITMAP",
        "UINT"
      ]
    },
    "constraint": {