Refactored Datalake and Deltalake for Topology (#6034)

* rebasing with main

* refactored deltalake for topology

* using requests instead of urllib

* formatting fixes

Co-authored-by: Onkar Ravgan <onkarravgan@Onkars-MacBook-Pro.local>
Authored by Onkar Ravgan on 2022-07-19 10:07:27 +05:30; committed via GitHub.
parent e7dca141ed
commit 8c9dc91ccf
6 changed files with 390 additions and 261 deletions

ometa_api.py (OpenMetadata client)

@@ -25,6 +25,7 @@ except ImportError:
     from typing_compat import get_args
 
 from pydantic import BaseModel
+from requests.utils import quote
 
 from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
 from metadata.generated.schema.entity.data.chart import Chart

@@ -458,7 +459,11 @@ class OpenMetadata(
         Return entity by name or None
         """
-        return self._get(entity=entity, path=f"name/{model_str(fqn)}", fields=fields)
+        return self._get(
+            entity=entity,
+            path=f"name/{quote(model_str(fqn), safe='')}",
+            fields=fields,
+        )
 
     def get_by_id(
         self,
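The point of the change above is to URL-encode the fully qualified name before it lands in the request path, so characters such as "/" or quotes inside an FQN no longer break the name/... lookup. A minimal sketch of the behaviour, using a made-up FQN rather than one from the codebase:

from requests.utils import quote  # thin re-export of urllib.parse.quote

# Hypothetical FQN containing characters that would otherwise split the URL path.
fqn = 'my_service.my_db.my_schema."table/with/slashes"'
print(quote(fqn, safe=""))  # safe="" also encodes "/" as %2F
# my_service.my_db.my_schema.%22table%2Fwith%2Fslashes%22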

datalake.py (Datalake source)

@@ -13,15 +13,17 @@
 DataLake connector to fetch metadata from a files stored s3, gcs and Hdfs
 """
 import traceback
-import uuid
-from typing import Iterable, Optional
+from typing import Iterable, Optional, Tuple
 
-from metadata.generated.schema.entity.data.database import Database
-from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
+from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
+from metadata.generated.schema.api.data.createDatabaseSchema import (
+    CreateDatabaseSchemaRequest,
+)
+from metadata.generated.schema.api.data.createTable import CreateTableRequest
+from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
 from metadata.generated.schema.entity.data.table import (
     Column,
     DataType,
-    Table,
     TableData,
     TableType,
 )
@@ -42,10 +44,13 @@ from metadata.generated.schema.metadataIngestion.workflow import (
 )
 from metadata.generated.schema.type.entityReference import EntityReference
 from metadata.ingestion.api.common import Entity
-from metadata.ingestion.api.source import InvalidSourceException, Source, SourceStatus
-from metadata.ingestion.models.ometa_table_db import OMetaDatabaseAndTable
+from metadata.ingestion.api.source import InvalidSourceException, SourceStatus
+from metadata.ingestion.models.ometa_tag_category import OMetaTagAndCategory
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
-from metadata.ingestion.source.database.common_db_source import SQLSourceStatus
+from metadata.ingestion.source.database.database_service import (
+    DatabaseServiceSource,
+    SQLSourceStatus,
+)
 from metadata.utils.connections import get_connection, test_connection
 from metadata.utils.filters import filter_by_table
 from metadata.utils.gcs_utils import (
@@ -66,10 +71,11 @@ logger = ingestion_logger()
 
 DATALAKE_INT_TYPES = {"int64", "INT"}
 
+DATALAKE_SUPPORTED_FILE_TYPES = (".csv", ".tsv", ".json", ".parquet")
+
 
-class DatalakeSource(Source[Entity]):
+class DatalakeSource(DatabaseServiceSource):
     def __init__(self, config: WorkflowSource, metadata_config: OpenMetadataConnection):
-        super().__init__()
         self.status = SQLSourceStatus()
         self.config = config
@@ -84,8 +90,10 @@ class DatalakeSource(Source[Entity]):
         )
         self.connection = get_connection(self.service_connection)
         self.client = self.connection.client
+        self.table_constraints = None
+        self.database_source_state = set()
+        super().__init__()
 
     @classmethod
     def create(cls, config_dict, metadata_config: OpenMetadataConnection):
@@ -97,146 +105,171 @@ class DatalakeSource(Source[Entity]):
         )
         return cls(config, metadata_config)
 
-    def prepare(self):
-        pass
-
-    def next_record(self) -> Iterable[Entity]:
-        try:
-            bucket_name = self.service_connection.bucketName
-            prefix = self.service_connection.prefix
-            if isinstance(self.service_connection.configSource, GCSConfig):
-                if bucket_name:
-                    yield from self.get_gcs_files(bucket_name, prefix)
-                else:
-                    for bucket in self.client.list_buckets():
-                        yield from self.get_gcs_files(bucket.name, prefix)
-            if isinstance(self.service_connection.configSource, S3Config):
-                if bucket_name:
-                    yield from self.get_s3_files(bucket_name, prefix)
-                else:
-                    for bucket in self.client.list_buckets()["Buckets"]:
-                        yield from self.get_s3_files(bucket["Name"], prefix)
-        except Exception as err:
-            logger.error(traceback.format_exc())
-
-    def get_gcs_files(self, bucket_name, prefix):
-        bucket = self.client.get_bucket(bucket_name)
-        for key in bucket.list_blobs(prefix=prefix):
-            try:
-                if filter_by_table(
-                    self.config.sourceConfig.config.tableFilterPattern, key.name
-                ):
-                    self.status.filter(
-                        "{}".format(key["Key"]),
-                        "Table pattern not allowed",
-                    )
-                    continue
-                if key.name.endswith(".csv"):
-                    df = read_csv_from_gcs(key, bucket_name)
-                    yield from self.ingest_tables(key.name, df, bucket_name)
-                if key.name.endswith(".tsv"):
-                    df = read_tsv_from_gcs(key, bucket_name)
-                    yield from self.ingest_tables(key.name, df, bucket_name)
-                if key.name.endswith(".json"):
-                    df = read_json_from_gcs(key)
-                    yield from self.ingest_tables(key.name, df, bucket_name)
-                if key.name.endswith(".parquet"):
-                    df = read_parquet_from_gcs(key, bucket_name)
-                    yield from self.ingest_tables(key.name, df, bucket_name)
-            except Exception as err:
-                logger.debug(traceback.format_exc())
-                logger.error(err)
-
-    def get_s3_files(self, bucket_name, prefix):
-        kwargs = {"Bucket": bucket_name}
-        if prefix:
-            kwargs["Prefix"] = prefix if prefix.endswith("/") else f"{prefix}/"
-        for key in self.client.list_objects(**kwargs)["Contents"]:
-            try:
-                if filter_by_table(
-                    self.config.sourceConfig.config.tableFilterPattern, key["Key"]
-                ):
-                    self.status.filter(
-                        "{}".format(key["Key"]),
-                        "Table pattern not allowed",
-                    )
-                    continue
-                if key["Key"].endswith(".csv"):
-                    df = read_csv_from_s3(self.client, key, bucket_name)
-                    yield from self.ingest_tables(key["Key"], df, bucket_name)
-                if key["Key"].endswith(".tsv"):
-                    df = read_tsv_from_s3(self.client, key, bucket_name)
-                    yield from self.ingest_tables(key["Key"], df, bucket_name)
-                if key["Key"].endswith(".json"):
-                    df = read_json_from_s3(self.client, key, bucket_name)
-                    yield from self.ingest_tables(key["Key"], df, bucket_name)
-                if key["Key"].endswith(".parquet"):
-                    df = read_parquet_from_s3(self.client, key, bucket_name)
-                    yield from self.ingest_tables(key["Key"], df, bucket_name)
-            except Exception as err:
-                logger.debug(traceback.format_exc())
-                logger.error(err)
-
-    def ingest_tables(self, key, df, bucket_name) -> Iterable[OMetaDatabaseAndTable]:
-        try:
-            table_columns = self.get_columns(df)
-            database_entity = Database(
-                id=uuid.uuid4(),
-                name="default",
-                service=EntityReference(id=self.service.id, type="databaseService"),
-            )
-            table_entity = Table(
-                id=uuid.uuid4(),
-                name=key,
-                description="",
-                columns=table_columns,
-                tableType=TableType.External,
-            )
-            schema_entity = DatabaseSchema(
-                id=uuid.uuid4(),
-                name=bucket_name,
-                database=EntityReference(id=database_entity.id, type="database"),
-                service=EntityReference(id=self.service.id, type="databaseService"),
-            )
-            table_and_db = OMetaDatabaseAndTable(
-                table=table_entity,
-                database=database_entity,
-                database_schema=schema_entity,
-            )
-            yield table_and_db
+    def get_database_names(self) -> Iterable[str]:
+        """
+        Default case with a single database.
+        It might come informed - or not - from the source.
+
+        Sources with multiple databases should overwrite this and
+        apply the necessary filters.
+        """
+        database_name = "default"
+        yield database_name
+
+    def yield_database(self, database_name: str) -> Iterable[CreateDatabaseRequest]:
+        """
+        From topology.
+        Prepare a database request and pass it to the sink
+        """
+        yield CreateDatabaseRequest(
+            name=database_name,
+            service=EntityReference(
+                id=self.context.database_service.id,
+                type="databaseService",
+            ),
+        )
+
+    def get_database_schema_names(self) -> Iterable[str]:
+        """
+        return schema names
+        """
+        bucket_name = self.service_connection.bucketName
+        if isinstance(self.service_connection.configSource, GCSConfig):
+            if bucket_name:
+                yield bucket_name
+            else:
+                for bucket in self.client.list_buckets():
+                    yield bucket.name
+        if isinstance(self.service_connection.configSource, S3Config):
+            if bucket_name:
+                yield bucket_name
+            else:
+                for bucket in self.client.list_buckets()["Buckets"]:
+                    yield bucket["Name"]
+
+    def yield_database_schema(
+        self, schema_name: str
+    ) -> Iterable[CreateDatabaseSchemaRequest]:
+        """
+        From topology.
+        Prepare a database schema request and pass it to the sink
+        """
+        yield CreateDatabaseSchemaRequest(
+            name=schema_name,
+            database=EntityReference(id=self.context.database.id, type="database"),
+        )
+
+    def get_tables_name_and_type(self) -> Optional[Iterable[Tuple[str, str]]]:
+        """
+        Handle table and views.
+
+        Fetches them up using the context information and
+        the inspector set when preparing the db.
+
+        :return: tables or views, depending on config
+        """
+        bucket_name = self.context.database_schema.name.__root__
+        prefix = self.service_connection.prefix
+        if self.source_config.includeTables:
+            if isinstance(self.service_connection.configSource, GCSConfig):
+                bucket = self.client.get_bucket(bucket_name)
+                for key in bucket.list_blobs(prefix=prefix):
+                    if filter_by_table(
+                        self.config.sourceConfig.config.tableFilterPattern, key.name
+                    ) or not self.check_valid_file_type(key.name):
+                        self.status.filter(
+                            "{}".format(key["Key"]),
+                            "Table pattern not allowed",
+                        )
+                        continue
+                    table_name = self.standardize_table_name(bucket_name, key.name)
+                    yield table_name, TableType.Regular
+            if isinstance(self.service_connection.configSource, S3Config):
+                kwargs = {"Bucket": bucket_name}
+                if prefix:
+                    kwargs["Prefix"] = prefix if prefix.endswith("/") else f"{prefix}/"
+                for key in self.client.list_objects(**kwargs)["Contents"]:
+                    if filter_by_table(
+                        self.config.sourceConfig.config.tableFilterPattern, key["Key"]
+                    ) or not self.check_valid_file_type(key["Key"]):
+                        self.status.filter(
+                            "{}".format(key["Key"]),
+                            "Table pattern not allowed",
+                        )
+                        continue
+                    table_name = self.standardize_table_name(bucket_name, key["Key"])
+                    yield table_name, TableType.Regular
+
+    def yield_table(
+        self, table_name_and_type: Tuple[str, str]
+    ) -> Iterable[Optional[CreateTableRequest]]:
+        """
+        From topology.
+        Prepare a table request and pass it to the sink
+        """
+        table_name, table_type = table_name_and_type
+        schema_name = self.context.database_schema.name.__root__
+        try:
+            table_constraints = None
+            if isinstance(self.service_connection.configSource, GCSConfig):
+                df = self.get_gcs_files(key=table_name, bucket_name=schema_name)
+            if isinstance(self.service_connection.configSource, S3Config):
+                df = self.get_s3_files(key=table_name, bucket_name=schema_name)
+            columns = self.get_columns(df)
+            table_request = CreateTableRequest(
+                name=table_name,
+                tableType=table_type,
+                description="",
+                columns=columns,
+                tableConstraints=table_constraints if table_constraints else None,
+                databaseSchema=EntityReference(
+                    id=self.context.database_schema.id,
+                    type="databaseSchema",
+                ),
+            )
+            yield table_request
+            self.register_record(table_request=table_request)
+        except Exception as err:
+            logger.debug(traceback.format_exc())
+            logger.error(err)
+            self.status.failures.append(
+                "{}.{}".format(self.config.serviceName, table_name)
+            )
+
+    def get_gcs_files(self, key, bucket_name):
+        try:
+            if key.endswith(".csv"):
+                return read_csv_from_gcs(key, bucket_name)
+            if key.endswith(".tsv"):
+                return read_tsv_from_gcs(key, bucket_name)
+            if key.endswith(".json"):
+                return read_json_from_gcs(self.client, key, bucket_name)
+            if key.endswith(".parquet"):
+                return read_parquet_from_gcs(key, bucket_name)
+        except Exception as err:
+            logger.debug(traceback.format_exc())
+            logger.error(err)
+
+    def get_s3_files(self, key, bucket_name):
+        try:
+            if key.endswith(".csv"):
+                return read_csv_from_s3(self.client, key, bucket_name)
+            if key.endswith(".tsv"):
+                return read_tsv_from_s3(self.client, key, bucket_name)
+            if key.endswith(".json"):
+                return read_json_from_s3(self.client, key, bucket_name)
+            if key.endswith(".parquet"):
+                return read_parquet_from_s3(self.client, key, bucket_name)
         except Exception as err:
             logger.debug(traceback.format_exc())
@@ -259,27 +292,43 @@ class DatalakeSource(Source[Entity]):
         return None
 
     def get_columns(self, df):
-        df_columns = list(df.columns)
-        for column in df_columns:
-            try:
-                if (
-                    hasattr(df[column], "dtypes")
-                    and df[column].dtypes.name in DATALAKE_INT_TYPES
-                ):
-                    if df[column].dtypes.name == "int64":
-                        data_type = DataType.INT.value
-                else:
-                    data_type = DataType.STRING.value
-                parsed_string = {}
-                parsed_string["dataTypeDisplay"] = data_type
-                parsed_string["dataType"] = data_type
-                parsed_string["name"] = column[:64]
-                parsed_string["dataLength"] = parsed_string.get("dataLength", 1)
-                yield Column(**parsed_string)
-            except Exception as err:
-                logger.debug(traceback.format_exc())
-                logger.error(err)
+        if hasattr(df, "columns"):
+            df_columns = list(df.columns)
+            for column in df_columns:
+                try:
+                    if (
+                        hasattr(df[column], "dtypes")
+                        and df[column].dtypes.name in DATALAKE_INT_TYPES
+                    ):
+                        if df[column].dtypes.name == "int64":
+                            data_type = DataType.INT.value
+                    else:
+                        data_type = DataType.STRING.value
+                    parsed_string = {}
+                    parsed_string["dataTypeDisplay"] = data_type
+                    parsed_string["dataType"] = data_type
+                    parsed_string["name"] = column[:64]
+                    parsed_string["dataLength"] = parsed_string.get("dataLength", 1)
+                    yield Column(**parsed_string)
+                except Exception as err:
+                    logger.debug(traceback.format_exc())
+                    logger.error(err)
+
+    def yield_view_lineage(
+        self, table_name_and_type: Tuple[str, str]
+    ) -> Optional[Iterable[AddLineageRequest]]:
+        pass
+
+    def yield_tag(self, schema_name: str) -> Iterable[OMetaTagAndCategory]:
+        pass
+
+    def standardize_table_name(self, schema: str, table: str) -> str:
+        return table
+
+    def check_valid_file_type(self, key_name):
+        if key_name.endswith(DATALAKE_SUPPORTED_FILE_TYPES):
+            return True
+        return False
 
     def close(self):
         pass
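Two small details in the new Datalake code are easy to miss: check_valid_file_type relies on str.endswith accepting a tuple of suffixes, and get_columns derives the OpenMetadata type from the pandas dtype name. A standalone sketch of both ideas; the sample frame and object names are illustrative only, not part of the connector:

import pandas as pd

DATALAKE_SUPPORTED_FILE_TYPES = (".csv", ".tsv", ".json", ".parquet")

# str.endswith accepts a tuple, so one call covers every supported extension.
print("raw/orders.parquet".endswith(DATALAKE_SUPPORTED_FILE_TYPES))  # True
print("raw/orders.avro".endswith(DATALAKE_SUPPORTED_FILE_TYPES))     # False

# The same dtype-name check that get_columns performs, on a toy frame.
df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})
for column in df.columns:
    data_type = "INT" if df[column].dtypes.name == "int64" else "STRING"
    print(column, df[column].dtypes.name, data_type)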

deltalake.py (Deltalake source)

@@ -1,15 +1,18 @@
-import logging
 import re
-import uuid
-from typing import Any, Dict, Iterable, List, Optional
+import traceback
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 from pyspark.sql import SparkSession
 from pyspark.sql.types import ArrayType, MapType, StructType
 from pyspark.sql.utils import AnalysisException, ParseException
 
-from metadata.generated.schema.entity.data.database import Database
-from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
-from metadata.generated.schema.entity.data.table import Column, Table, TableType
+from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
+from metadata.generated.schema.api.data.createDatabaseSchema import (
+    CreateDatabaseSchemaRequest,
+)
+from metadata.generated.schema.api.data.createTable import CreateTableRequest
+from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
+from metadata.generated.schema.entity.data.table import Column, TableType
 from metadata.generated.schema.entity.services.connections.database.deltaLakeConnection import (
     DeltaLakeConnection,
 )
@@ -17,15 +20,20 @@ from metadata.generated.schema.entity.services.connections.metadata.openMetadata
     OpenMetadataConnection,
 )
 from metadata.generated.schema.entity.services.databaseService import DatabaseService
+from metadata.generated.schema.metadataIngestion.databaseServiceMetadataPipeline import (
+    DatabaseServiceMetadataPipeline,
+)
 from metadata.generated.schema.metadataIngestion.workflow import (
     Source as WorkflowSource,
 )
 from metadata.generated.schema.type.entityReference import EntityReference
-from metadata.ingestion.api.common import Entity
-from metadata.ingestion.api.source import InvalidSourceException, Source
-from metadata.ingestion.models.ometa_table_db import OMetaDatabaseAndTable
+from metadata.ingestion.api.source import InvalidSourceException
+from metadata.ingestion.models.ometa_tag_category import OMetaTagAndCategory
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
-from metadata.ingestion.source.database.common_db_source import SQLSourceStatus
+from metadata.ingestion.source.database.database_service import (
+    DatabaseServiceSource,
+    SQLSourceStatus,
+)
 from metadata.utils.column_type_parser import ColumnTypeParser
 from metadata.utils.connections import get_connection
 from metadata.utils.filters import filter_by_schema, filter_by_table
@@ -42,7 +50,7 @@ class MetaStoreNotFoundException(Exception):
     """
 
-class DeltalakeSource(Source[Entity]):
+class DeltalakeSource(DatabaseServiceSource):
     spark: SparkSession = None
 
     def __init__(
@@ -50,8 +58,11 @@ class DeltalakeSource(Source[Entity]):
         config: WorkflowSource,
         metadata_config: OpenMetadataConnection,
     ):
-        super().__init__()
         self.config = config
+        self.source_config: DatabaseServiceMetadataPipeline = (
+            self.config.sourceConfig.config
+        )
         self.metadata_config = metadata_config
         self.metadata = OpenMetadata(metadata_config)
         self.service_connection = self.config.serviceConnection.__root__.config
@@ -72,6 +83,9 @@ class DeltalakeSource(Source[Entity]):
         self.array_datatype_replace_map = {"(": "<", ")": ">", "=": ":", "<>": ""}
         self.ARRAY_CHILD_START_INDEX = 6
         self.ARRAY_CHILD_END_INDEX = -1
+        self.table_constraints = None
+        self.database_source_state = set()
+        super().__init__()
 
     @classmethod
     def create(cls, config_dict, metadata_config: OpenMetadataConnection):
@@ -87,7 +101,35 @@ class DeltalakeSource(Source[Entity]):
         )
         return cls(config, metadata_config)
 
-    def next_record(self) -> Iterable[OMetaDatabaseAndTable]:
+    def get_database_names(self) -> Iterable[str]:
+        """
+        Default case with a single database.
+        It might come informed - or not - from the source.
+
+        Sources with multiple databases should overwrite this and
+        apply the necessary filters.
+        """
+        yield DEFAULT_DATABASE
+
+    def yield_database(self, database_name: str) -> Iterable[CreateDatabaseRequest]:
+        """
+        From topology.
+        Prepare a database request and pass it to the sink
+        """
+        yield CreateDatabaseRequest(
+            name=database_name,
+            service=EntityReference(
+                id=self.context.database_service.id,
+                type="databaseService",
+            ),
+        )
+
+    def get_database_schema_names(self) -> Iterable[str]:
+        """
+        return schema names
+        """
         schemas = self.spark.catalog.listDatabases()
         for schema in schemas:
             if filter_by_schema(
@@ -95,88 +137,110 @@ class DeltalakeSource(Source[Entity]):
             ):
                 self.status.filter(schema.name, "Schema pattern not allowed")
                 continue
-            yield from self.fetch_tables(schema.name)
-
-    def get_status(self):
-        return self.status
-
-    def prepare(self):
-        pass
-
-    def _get_table_type(self, table_type: str):
-        return self.table_type_map.get(table_type.lower(), TableType.Regular.value)
-
-    def fetch_tables(self, schema: str) -> Iterable[OMetaDatabaseAndTable]:
-        for table in self.spark.catalog.listTables(schema):
-            try:
-                table_name = table.name
-                if filter_by_table(
-                    self.config.sourceConfig.config.tableFilterPattern, table_name
-                ):
-                    self.status.filter(
-                        "{}.{}".format(self.config.serviceName, table_name),
-                        "Table pattern not allowed",
-                    )
-                    continue
-                self.status.scanned("{}.{}".format(self.config.serviceName, table_name))
-                table_columns = self._fetch_columns(schema, table_name)
-                if table.tableType and table.tableType.lower() != "view":
-                    table_entity = Table(
-                        id=uuid.uuid4(),
-                        name=table_name,
-                        tableType=self._get_table_type(table.tableType),
-                        description=table.description,
-                        columns=table_columns,
-                    )
-                else:
-                    view_definition = self._fetch_view_schema(table_name)
-                    table_entity = Table(
-                        id=uuid.uuid4(),
-                        name=table_name,
-                        tableType=self._get_table_type(table.tableType),
-                        description=table.description,
-                        columns=table_columns,
-                        viewDefinition=view_definition,
-                    )
-                database = self.get_database_entity()
-                table_and_db = OMetaDatabaseAndTable(
-                    table=table_entity,
-                    database=database,
-                    database_schema=self._get_database_schema(database, schema),
-                )
-                yield table_and_db
-            except Exception as err:
-                logger.error(err)
-                self.status.warnings.append(
-                    "{}.{}".format(self.config.serviceName, table.name)
-                )
-
-    def get_database_entity(self) -> Database:
-        return Database(
-            id=uuid.uuid4(),
-            name=DEFAULT_DATABASE,
-            service=EntityReference(
-                id=self.service.id, type=self.service_connection.type.value
-            ),
-        )
-
-    def _get_database_schema(self, database: Database, schema: str) -> DatabaseSchema:
-        return DatabaseSchema(
-            name=schema,
-            service=EntityReference(
-                id=self.service.id, type=self.service_connection.type.value
-            ),
-            database=EntityReference(id=database.id, type="database"),
-        )
-
-    def _fetch_table_description(self, table_name: str) -> Optional[Dict]:
-        try:
-            table_details_df = self.spark.sql(f"describe detail {table_name}")
-            table_detail = table_details_df.collect()[0]
-            return table_detail.asDict()
-        except Exception as e:
-            logging.error(e)
+            yield schema.name
+
+    def yield_database_schema(
+        self, schema_name: str
+    ) -> Iterable[CreateDatabaseSchemaRequest]:
+        """
+        From topology.
+        Prepare a database schema request and pass it to the sink
+        """
+        yield CreateDatabaseSchemaRequest(
+            name=schema_name,
+            database=EntityReference(id=self.context.database.id, type="database"),
+        )
+
+    def get_tables_name_and_type(self) -> Optional[Iterable[Tuple[str, str]]]:
+        """
+        Handle table and views.
+
+        Fetches them up using the context information and
+        the inspector set when preparing the db.
+
+        :return: tables or views, depending on config
+        """
+        schema_name = self.context.database_schema.name.__root__
+        for table in self.spark.catalog.listTables(schema_name):
+            try:
+                table_name = table.name
+                if filter_by_table(
+                    self.source_config.tableFilterPattern, table_name=table_name
+                ):
+                    self.status.filter(
+                        f"{table_name}",
+                        "Table pattern not allowed",
+                    )
+                    continue
+                if (
+                    self.source_config.includeTables
+                    and table.tableType
+                    and table.tableType.lower() != "view"
+                ):
+                    table_name = self.standardize_table_name(schema_name, table_name)
+                    self.context.table_description = table.description
+                    yield table_name, TableType.Regular
+
+                if (
+                    self.source_config.includeViews
+                    and table.tableType
+                    and table.tableType.lower() == "view"
+                ):
+                    view_name = self.standardize_table_name(schema_name, table_name)
+                    self.context.table_description = table.description
+                    yield view_name, TableType.View
+
+            except Exception as err:
+                logger.error(err)
+                self.status.warnings.append(
+                    "{}.{}".format(self.config.serviceName, table.name)
+                )
+
+    def yield_table(
+        self, table_name_and_type: Tuple[str, str]
+    ) -> Iterable[Optional[CreateTableRequest]]:
+        """
+        From topology.
+        Prepare a table request and pass it to the sink
+        """
+        table_name, table_type = table_name_and_type
+        schema_name = self.context.database_schema.name.__root__
+        try:
+            columns = self.get_columns(schema_name, table_name)
+            view_definition = (
+                self._fetch_view_schema(table_name)
+                if table_type == TableType.View
+                else None
+            )
+
+            table_request = CreateTableRequest(
+                name=table_name,
+                tableType=table_type,
+                description=self.context.table_description,
+                columns=columns,
+                tableConstraints=None,
+                databaseSchema=EntityReference(
+                    id=self.context.database_schema.id,
+                    type="databaseSchema",
+                ),
+                viewDefinition=view_definition,
+            )
+            yield table_request
+            self.register_record(table_request=table_request)
+        except Exception as err:
+            logger.debug(traceback.format_exc())
+            logger.error(err)
+            self.status.failures.append(
+                "{}.{}".format(self.config.serviceName, table_name)
+            )
+
+    def get_status(self):
+        return self.status
+
+    def prepare(self):
+        pass
 
     def _fetch_view_schema(self, view_name: str) -> Optional[Dict]:
         describe_output = []
@@ -251,7 +315,7 @@ class DeltalakeSource(Source[Entity]):
         )
         return column
 
-    def _fetch_columns(self, schema: str, table: str) -> List[Column]:
+    def get_columns(self, schema: str, table: str) -> List[Column]:
         raw_columns = []
         field_dict: Dict[str, Any] = {}
         table_name = f"{schema}.{table}"
@@ -275,12 +339,19 @@ class DeltalakeSource(Source[Entity]):
         return parsed_columns
 
-    def _is_complex_delta_type(self, delta_type: Any) -> bool:
-        return (
-            isinstance(delta_type, StructType)
-            or isinstance(delta_type, ArrayType)
-            or isinstance(delta_type, MapType)
-        )
+    def yield_view_lineage(
+        self, table_name_and_type: Tuple[str, str]
+    ) -> Optional[Iterable[AddLineageRequest]]:
+        pass
+
+    def yield_tag(self, schema_name: str) -> Iterable[OMetaTagAndCategory]:
+        pass
+
+    def close(self):
+        pass
+
+    def standardize_table_name(self, schema: str, table: str) -> str:
+        return table
 
     def test_connection(self) -> None:
         pass
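Both sources now implement the same set of topology hooks (get_database_names, yield_database, get_database_schema_names, yield_database_schema, get_tables_name_and_type, yield_table) instead of a single next_record generator. A deliberately simplified sketch of how such hooks might be driven; this is not the actual DatabaseServiceSource implementation, whose topology and context handling are richer:

def simplified_topology_run(source):
    """Hypothetical driver: walk databases -> schemas -> tables and emit requests."""
    for database_name in source.get_database_names():
        # In the real framework, self.context (database, schema, ...) is updated
        # by the topology runner between these steps.
        yield from source.yield_database(database_name)
        for schema_name in source.get_database_schema_names():
            yield from source.yield_database_schema(schema_name)
            for table_name_and_type in source.get_tables_name_and_type() or []:
                yield from source.yield_table(table_name_and_type)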

connections.py (connection utilities)

@@ -280,6 +280,10 @@ def _(connection: DeltaLakeConnection, verbose: bool = False) -> DeltaLakeClient
     elif connection.metastoreFilePath:
         builder.config("spark.sql.warehouse.dir", f"{connection.metastoreFilePath}")
 
+    if connection.connectionArguments:
+        for key, value in connection.connectionArguments:
+            builder.config(key, value)
+
     deltalake_connection = DeltaLakeClient(
         configure_spark_with_delta_pip(builder).getOrCreate()
     )
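The new connectionArguments loop simply forwards arbitrary key/value pairs to the Spark builder. A small sketch of the same idea against a plain SparkSession builder; the sample keys and values are placeholders, not recommendations:

from pyspark.sql import SparkSession

connection_arguments = {
    "spark.sql.shuffle.partitions": "8",      # placeholder value
    "spark.ui.showConsoleProgress": "false",  # placeholder value
}

builder = SparkSession.builder.appName("deltalake-ingestion")
for key, value in connection_arguments.items():
    builder = builder.config(key, value)  # each pair becomes a Spark config entry

# spark = builder.getOrCreate()  # uncomment to actually start the session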

gcs_utils.py (GCS file readers)

@@ -16,7 +16,6 @@ import dask.dataframe as dd
 import gcsfs
 import pandas as pd
 import pyarrow.parquet as pq
-from google.cloud.storage.blob import Blob
 from pandas import DataFrame
 
 from metadata.utils.logger import utils_logger
@@ -24,19 +23,21 @@ from metadata.utils.logger import utils_logger
 logger = utils_logger()
 
-def read_csv_from_gcs(key: Blob, bucket_name: str) -> DataFrame:
-    df = dd.read_csv(f"gs://{bucket_name}/{key.name}")
+def read_csv_from_gcs(key: str, bucket_name: str) -> DataFrame:
+    df = dd.read_csv(f"gs://{bucket_name}/{key}")
     return df
 
-def read_tsv_from_gcs(key: Blob, bucket_name: str) -> DataFrame:
-    df = dd.read_csv(f"gs://{bucket_name}/{key.name}", sep="\t")
+def read_tsv_from_gcs(key: str, bucket_name: str) -> DataFrame:
+    df = dd.read_csv(f"gs://{bucket_name}/{key}", sep="\t")
     return df
 
-def read_json_from_gcs(key: Blob) -> DataFrame:
+def read_json_from_gcs(client, key: str, bucket_name: str) -> DataFrame:
     try:
-        data = key.download_as_string().decode()
+        bucket = client.get_bucket(bucket_name)
+        blob = bucket.get_blob(key)
+        data = blob.download_as_string().decode()
         data = json.loads(data)
         if isinstance(data, list):
             df = pd.DataFrame.from_dict(data)
@@ -51,8 +52,8 @@ def read_json_from_gcs(key: Blob) -> DataFrame:
         logger.error(verr)
 
-def read_parquet_from_gcs(key: Blob, bucket_name: str) -> DataFrame:
+def read_parquet_from_gcs(key: str, bucket_name: str) -> DataFrame:
     gs = gcsfs.GCSFileSystem()
-    arrow_df = pq.ParquetDataset(f"gs://{bucket_name}/{key.name}", filesystem=gs)
+    arrow_df = pq.ParquetDataset(f"gs://{bucket_name}/{key}", filesystem=gs)
     df = arrow_df.read_pandas().to_pandas()
     return df
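With the Blob dependency gone, the GCS readers take plain object names and build gs:// URLs for dask and pyarrow. A usage sketch, assuming application-default GCP credentials and a made-up bucket and object name:

import dask.dataframe as dd

bucket_name = "my-example-bucket"   # hypothetical
key = "raw/orders.csv"              # hypothetical object name

# Same call shape as read_csv_from_gcs: dask resolves gs:// paths via gcsfs.
df = dd.read_csv(f"gs://{bucket_name}/{key}")
print(df.columns)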

s3_utils.py (S3 file readers)

@@ -16,25 +16,24 @@ import pandas as pd
 from pandas import DataFrame
 
-def read_csv_from_s3(client: Any, key: dict, bucket_name: str) -> DataFrame:
-    csv_obj = client.get_object(Bucket=bucket_name, Key=key["Key"])
+def read_csv_from_s3(client: Any, key: str, bucket_name: str) -> DataFrame:
+    csv_obj = client.get_object(Bucket=bucket_name, Key=key)
     body = csv_obj["Body"]
     csv_string = body.read().decode("utf-8")
     df = pd.read_csv(StringIO(csv_string))
     return df
 
-def read_tsv_from_s3(client: Any, key: dict, bucket_name: str) -> DataFrame:
-    tsv_obj = client.get_object(Bucket=bucket_name, Key=key["Key"])
+def read_tsv_from_s3(client: Any, key: str, bucket_name: str) -> DataFrame:
+    tsv_obj = client.get_object(Bucket=bucket_name, Key=key)
     body = tsv_obj["Body"]
     tsv_string = body.read().decode("utf-8")
     df = pd.read_csv(StringIO(tsv_string), sep="\t")
     return df
 
-def read_json_from_s3(client: Any, key: dict, bucket_name: str) -> DataFrame:
-    obj = client.get_object(Bucket=bucket_name, Key=key["Key"])
+def read_json_from_s3(client: Any, key: str, bucket_name: str) -> DataFrame:
+    obj = client.get_object(Bucket=bucket_name, Key=key)
     json_text = obj["Body"].read().decode("utf-8")
     data = json.loads(json_text)
     if isinstance(data, list):
@@ -44,7 +43,7 @@ def read_json_from_s3(client: Any, key: dict, bucket_name: str) -> DataFrame:
     return df
 
-def read_parquet_from_s3(client: Any, key: dict, bucket_name: str) -> DataFrame:
-    obj = client.get_object(Bucket=bucket_name, Key=key["Key"])
+def read_parquet_from_s3(client: Any, key: str, bucket_name: str) -> DataFrame:
+    obj = client.get_object(Bucket=bucket_name, Key=key)
     df = pd.read_parquet(BytesIO(obj["Body"].read()))
     return df
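The S3 readers now expect the object key as a plain string rather than the listing dict, so callers pass key["Key"] themselves. A usage sketch with boto3 and made-up names:

from io import StringIO

import boto3
import pandas as pd

client = boto3.client("s3")        # assumes AWS credentials are configured
bucket_name = "my-example-bucket"  # hypothetical
key = "raw/orders.csv"             # hypothetical object key

# Mirrors read_csv_from_s3: fetch the object and parse the decoded body.
obj = client.get_object(Bucket=bucket_name, Key=key)
df = pd.read_csv(StringIO(obj["Body"].read().decode("utf-8")))
print(df.head())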