Fixed BigQuery to support data profiling (#494)

Ayush Shah 2021-09-15 10:37:51 +05:30 committed by GitHub
parent 20a98aca81
commit 9b04d781e0
2 changed files with 97 additions and 48 deletions


@@ -14,7 +14,8 @@
 # limitations under the License.
 from typing import Optional, Tuple
+import os
 from metadata.generated.schema.entity.data.table import TableData
 # This import verifies that the dependencies are available.
@@ -41,12 +41,13 @@ class BigquerySource(SQLSource):
     @classmethod
     def create(cls, config_dict, metadata_config_dict, ctx):
-        config = BigQueryConfig.parse_obj(config_dict)
+        config: SQLConnectionConfig = BigQueryConfig.parse_obj(config_dict)
         metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict)
+        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = config.options['credentials_path']
         return cls(config, metadata_config, ctx)

     def standardize_schema_table_names(
         self, schema: str, table: str
     ) -> Tuple[str, str]:
         segments = table.split(".")
         if len(segments) != 2:
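The added line in `create` is what lets the profiler authenticate: it exports the service-account key path into the process environment, where the google-cloud client libraries pick it up when no explicit credentials are passed. A minimal sketch of what the ingestion config is expected to carry (the file path and surrounding keys are made up for illustration):

```python
import os

# Hypothetical ingestion config; only the `credentials_path` option
# under `options` is taken from the diff above.
config_dict = {
    "service_name": "bigquery",
    "options": {"credentials_path": "/tmp/sa-key.json"},  # illustrative path
}

# Mirrors the line added in BigquerySource.create: GOOGLE_APPLICATION_CREDENTIALS
# is the standard variable google-cloud libraries consult for a key file.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = config_dict["options"]["credentials_path"]
```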


@@ -18,6 +18,7 @@ import uuid
 from abc import abstractmethod
 from dataclasses import dataclass, field
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type
+from urllib.parse import quote_plus

 from pydantic import ValidationError
@@ -28,7 +29,8 @@ from metadata.generated.schema.type.entityReference import EntityReference
 from metadata.generated.schema.entity.data.database import Database
-from metadata.generated.schema.entity.data.table import Table, Column, ColumnConstraint, TableType, TableData, \
-    TableProfile
+from metadata.generated.schema.entity.data.table import Table, Column, ColumnConstraint, TableType, \
+    TableData, \
+    TableProfile
 from sqlalchemy import create_engine
 from sqlalchemy.engine.reflection import Inspector
@@ -57,7 +59,9 @@ class SQLSourceStatus(SourceStatus):
         self.success.append(table_name)
         logger.info('Table Scanned: {}'.format(table_name))

-    def filter(self, table_name: str, err: str, dataset_name: str = None, col_type: str = None) -> None:
+    def filter(
+        self, table_name: str, err: str, dataset_name: str = None, col_type: str = None
+    ) -> None:
         self.filtered.append(table_name)
         logger.warning("Dropped Table {} due to {}".format(table_name, err))
@@ -83,10 +87,10 @@ class SQLConnectionConfig(ConfigModel):
     @abstractmethod
     def get_connection_url(self):
         url = f"{self.scheme}://"
-        if self.username:
-            url += f"{self.username}"
-        if self.password:
-            url += f":{self.password}"
+        if self.username is not None:
+            url += f"{quote_plus(self.username)}"
+        if self.password is not None:
+            url += f":{quote_plus(self.password)}"
         url += "@"
         url += f"{self.host_port}"
         if self.database:
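The switch to `quote_plus` matters whenever a username or password contains characters that are significant inside a SQLAlchemy URL, such as `@`, `/`, or `:`. A quick illustration with `urllib.parse` (the credential and URL scheme here are made up):

```python
from urllib.parse import quote_plus

# An unescaped '@' or '/' in the password would corrupt the
# scheme://user:pass@host/db structure of the connection URL.
password = "p@ss/word"  # hypothetical credential
print(quote_plus(password))  # -> p%40ss%2Fword

# Hypothetical connection URL built the same way get_connection_url does.
print(f"postgresql://user:{quote_plus(password)}@localhost:5432/shop")
```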
@@ -170,8 +174,10 @@ def _get_table_description(schema: str, table: str, inspector: Inspector) -> str
 class SQLSource(Source):
-    def __init__(self, config: SQLConnectionConfig, metadata_config: MetadataServerConfig,
-                 ctx: WorkflowContext):
+    def __init__(
+        self, config: SQLConnectionConfig, metadata_config: MetadataServerConfig,
+        ctx: WorkflowContext
+    ):
         super().__init__(ctx)
         self.config = config
         self.metadata_config = metadata_config
@@ -182,8 +188,10 @@ class SQLSource(Source):
         self.engine = create_engine(self.connection_string, **self.sql_config.options)
         self.connection = self.engine.connect()
         if self.config.data_profiler_enabled:
-            self.data_profiler = DataProfiler(status=self.status,
-                                              connection_str=self.connection_string)
+            self.data_profiler = DataProfiler(
+                status=self.status,
+                connection_str=self.connection_string,
+            )

     def prepare(self):
         pass
@@ -223,26 +231,32 @@ class SQLSource(Source):
         if self.config.include_views:
             yield from self.fetch_views(inspector, schema)

-    def fetch_tables(self,
-                     inspector: Inspector,
-                     schema: str) -> Iterable[OMetaDatabaseAndTable]:
+    def fetch_tables(
+        self,
+        inspector: Inspector,
+        schema: str
+    ) -> Iterable[OMetaDatabaseAndTable]:
         for table_name in inspector.get_table_names(schema):
             try:
                 schema, table_name = self.standardize_schema_table_names(schema, table_name)
                 if not self.sql_config.filter_pattern.included(table_name):
-                    self.status.filter('{}.{}'.format(self.config.get_service_name(), table_name),
-                                       "Table pattern not allowed")
+                    self.status.filter(
+                        '{}.{}'.format(self.config.get_service_name(), table_name),
+                        "Table pattern not allowed"
+                    )
                     continue
                 self.status.scanned('{}.{}'.format(self.config.get_service_name(), table_name))
                 description = _get_table_description(schema, table_name, inspector)
                 table_columns = self._get_columns(schema, table_name, inspector)
-                table_entity = Table(id=uuid.uuid4(),
-                                     name=table_name,
-                                     tableType='Regular',
-                                     description=description if description is not None else ' ',
-                                     columns=table_columns)
+                table_entity = Table(
+                    id=uuid.uuid4(),
+                    name=table_name,
+                    tableType='Regular',
+                    description=description if description is not None else ' ',
+                    columns=table_columns
+                )
                 if self.sql_config.generate_sample_data:
                     table_data = self.fetch_sample_data(schema, table_name)
                     table_entity.sampleData = table_data
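`fetch_tables` re-standardizes the (schema, table) pair before filtering because BigQuery's SQLAlchemy dialect can return a name already qualified as `dataset.table`. A hedged sketch of the split; the source's own implementation is truncated above, so the fallthrough behavior for unqualified names is an assumption:

```python
from typing import Tuple

def standardize_schema_table_names(schema: str, table: str) -> Tuple[str, str]:
    # BigQuery can hand back "dataset.table" as a single table name.
    segments = table.split(".")
    if len(segments) != 2:
        return schema, table  # assumption: pass unqualified names through
    return segments[0], segments[1]

# e.g. ('my_dataset', 'events')
print(standardize_schema_table_names("my_dataset", "my_dataset.events"))
```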
@@ -251,41 +265,61 @@ class SQLSource(Source):
                     profile = self.run_data_profiler(table_name, schema)
                     table_entity.tableProfile = profile
-                table_and_db = OMetaDatabaseAndTable(table=table_entity, database=self._get_database(schema))
+                table_and_db = OMetaDatabaseAndTable(
+                    table=table_entity, database=self._get_database(schema)
+                )
                 yield table_and_db
             except ValidationError as err:
                 logger.error(err)
                 self.status.failures.append('{}.{}'.format(self.config.service_name, table_name))
                 continue

-    def fetch_views(self,
-                    inspector: Inspector,
-                    schema: str) -> Iterable[OMetaDatabaseAndTable]:
+    def fetch_views(
+        self,
+        inspector: Inspector,
+        schema: str
+    ) -> Iterable[OMetaDatabaseAndTable]:
         for view_name in inspector.get_view_names(schema):
             try:
+                if self.config.scheme == "bigquery":
+                    schema, view_name = self.standardize_schema_table_names(schema, view_name)
                 if not self.sql_config.filter_pattern.included(view_name):
-                    self.status.filter('{}.{}'.format(self.config.get_service_name(), view_name),
-                                       "View pattern not allowed")
+                    self.status.filter(
+                        '{}.{}'.format(self.config.get_service_name(), view_name),
+                        "View pattern not allowed"
+                    )
                     continue
                 try:
-                    view_definition = inspector.get_view_definition(view_name, schema)
+                    if self.config.scheme == "bigquery":
+                        view_definition = inspector.get_view_definition(
+                            f"{self.config.project_id}.{schema}.{view_name}"
+                        )
+                    else:
+                        view_definition = inspector.get_view_definition(
+                            view_name, schema
+                        )
                     view_definition = "" if view_definition is None else str(view_definition)
                 except NotImplementedError:
                     view_definition = ""
                 description = _get_table_description(schema, view_name, inspector)
                 table_columns = self._get_columns(schema, view_name, inspector)
-                table = Table(id=uuid.uuid4(),
-                              name=view_name,
-                              tableType='View',
-                              description=description if description is not None else ' ',
-                              columns=table_columns,
-                              viewDefinition=view_definition)
+                table = Table(
+                    id=uuid.uuid4(),
+                    name=view_name,
+                    tableType='View',
+                    description=description if description is not None else ' ',
+                    columns=table_columns,
+                    viewDefinition=view_definition
+                )
                 if self.sql_config.generate_sample_data:
                     table_data = self.fetch_sample_data(schema, view_name)
                     table.sampleData = table_data
-                table_and_db = OMetaDatabaseAndTable(table=table, database=self._get_database(schema))
+                table_and_db = OMetaDatabaseAndTable(
+                    table=table, database=self._get_database(schema)
+                )
                 yield table_and_db
             except ValidationError as err:
                 logger.error(err)
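The BigQuery branch in `fetch_views` passes a single fully qualified `project.dataset.view` identifier to `get_view_definition` instead of the usual `(view, schema)` pair, since BigQuery resolves views by their complete path. A small sketch of how that identifier is assembled (project and names are illustrative):

```python
project_id = "my-gcp-project"   # hypothetical GCP project
schema = "analytics"            # BigQuery dataset
view_name = "daily_active_users"

scheme = "bigquery"
if scheme == "bigquery":
    # One fully qualified identifier, mirroring the branch above.
    view_ref = f"{project_id}.{schema}.{view_name}"
else:
    # Other dialects take the view name and schema separately.
    view_ref = (view_name, schema)

print(view_ref)  # my-gcp-project.analytics.daily_active_users
```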
@@ -293,13 +327,16 @@ class SQLSource(Source):
                 continue

     def _get_database(self, schema: str) -> Database:
-        return Database(name=schema,
-                        service=EntityReference(id=self.service.id, type=self.config.service_type))
+        return Database(
+            name=schema,
+            service=EntityReference(id=self.service.id, type=self.config.service_type)
+        )

     def _get_columns(self, schema: str, table: str, inspector: Inspector) -> List[Column]:
         pk_constraints = inspector.get_pk_constraint(table, schema)
-        pk_columns = pk_constraints['column_constraints'] if len(
-            pk_constraints) > 0 and "column_constraints" in pk_constraints.keys() else {}
+        pk_columns = pk_constraints['column_constraints'] if len(
+            pk_constraints
+        ) > 0 and "column_constraints" in pk_constraints.keys() else {}
         unique_constraints = []
         try:
             unique_constraints = inspector.get_unique_constraints(table, schema)
@@ -329,11 +366,15 @@ class SQLSource(Source):
                 col_constraint = ColumnConstraint.PRIMARY_KEY
             elif column['name'] in unique_columns:
                 col_constraint = ColumnConstraint.UNIQUE
-            table_columns.append(Column(name=column['name'],
-                                        description=column.get("comment", None),
-                                        columnDataType=col_type,
-                                        columnConstraint=col_constraint,
-                                        ordinalPosition=row_order))
+            table_columns.append(
+                Column(
+                    name=column['name'],
+                    description=column.get("comment", None),
+                    columnDataType=col_type,
+                    columnConstraint=col_constraint,
+                    ordinalPosition=row_order
+                )
+            )
             row_order = row_order + 1
         return table_columns
@@ -345,14 +386,21 @@ class SQLSource(Source):
     ) -> TableProfile:
         dataset_name = f"{schema}.{table}"
         self.status.scanned(f"profile of {dataset_name}")
-        logger.info(f"Running Profiling for {dataset_name}. "
-                    f"If you haven't configured offset and limit this process can take longer")
+        logger.info(
+            f"Running Profiling for {dataset_name}. "
+            f"If you haven't configured offset and limit this process can take longer"
+        )
+        if self.config.scheme == "bigquery":
+            table = dataset_name
         profile = self.data_profiler.run_profiler(
             dataset_name=dataset_name,
             schema=schema,
             table=table,
             limit=self.sql_config.data_profiler_limit,
-            offset=self.sql_config.data_profiler_offset)
+            offset=self.sql_config.data_profiler_offset,
+            project_id=self.config.project_id if self.config.scheme == "bigquery" else None
+        )
         logger.debug(f"Finished profiling {dataset_name}")
         return profile
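Putting the BigQuery special-casing together: the profiler always receives a schema-qualified `dataset_name`, but for BigQuery the `table` argument is widened to `dataset.table` and a `project_id` is passed so the profiler can build a fully qualified reference. A hedged sketch of just that dispatch logic (the helper and its values are illustrative, not the profiler's actual interface):

```python
def build_profiler_args(scheme: str, schema: str, table: str,
                        project_id: str = None) -> dict:
    # Mirrors the branch added in run_data_profiler above.
    dataset_name = f"{schema}.{table}"
    if scheme == "bigquery":
        table = dataset_name  # BigQuery profiles by dataset-qualified name
    return {
        "dataset_name": dataset_name,
        "schema": schema,
        "table": table,
        "project_id": project_id if scheme == "bigquery" else None,
    }

print(build_profiler_args("bigquery", "my_dataset", "events", "my-gcp-project"))
print(build_profiler_args("mysql", "shop", "orders"))
```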