From 6b3d9c1b1de48506dc12b0a4943ddfbae9bf0ccf Mon Sep 17 00:00:00 2001 From: Ayush Shah Date: Sat, 21 Aug 2021 00:43:07 +0530 Subject: [PATCH 1/3] Few Pipeline fixes and Documentation updated --- docs/install/setup-ingestion.md | 11 ++++++++--- ingestion/ingestion_scheduler/jobs.py | 3 ++- ingestion/src/metadata/cmd.py | 3 ++- ingestion/src/metadata/ingestion/api/source.py | 13 +++++-------- .../metadata/ingestion/bulksink/metadata_usage.py | 2 ++ .../metadata/ingestion/source/snowflake_usage.py | 5 ++++- .../src/metadata/ingestion/source/sql_source.py | 3 ++- 7 files changed, 25 insertions(+), 15 deletions(-) diff --git a/docs/install/setup-ingestion.md b/docs/install/setup-ingestion.md index a2b6eb85bce..408471283ba 100644 --- a/docs/install/setup-ingestion.md +++ b/docs/install/setup-ingestion.md @@ -39,14 +39,14 @@ You only need to run above command once. ```text source env/bin/activate -metadata ingest -c ./pipelines/redshift.json +metadata ingest -c ./examples/workflows/redshift.json ``` #### Generate Redshift Usage Data ```text source env/bin/activate - metadata ingest -c ./pipelines/redshift_usage.json + metadata ingest -c ./examples/workflows/redshift_usage.json ``` #### Generate Sample Tables @@ -55,7 +55,12 @@ metadata ingest -c ./pipelines/redshift.json source env/bin/activate metadata ingest -c ./pipelines/sample_tables.json ``` +#### Generate Sample Usage +```text + source env/bin/activate + metadata ingest -c ./pipelines/sample_usage.json +``` #### Generate Sample Users ```text @@ -75,7 +80,7 @@ metadata ingest -c ./pipelines/redshift.json ```text source env/bin/activate export GOOGLE_APPLICATION_CREDENTIALS="$PWD/examples/creds/bigquery-cred.json" - metadata ingest -c ./pipelines/bigquery.json + metadata ingest -c ./examples/workflows/bigquery.json ``` #### Index Metadata into ElasticSearch diff --git a/ingestion/ingestion_scheduler/jobs.py b/ingestion/ingestion_scheduler/jobs.py index eb8c8b3aadd..e0ab96790e7 100644 --- a/ingestion/ingestion_scheduler/jobs.py +++ b/ingestion/ingestion_scheduler/jobs.py @@ -35,7 +35,8 @@ class MetadataLoaderJob(job.JobBase, Workflow): def run(self, pipeline_data, *args, **kwargs): config_data = json.loads(pipeline_data) - del config_data['cron'] + if config_data.get('cron'): + del config_data['cron'] self.workflow = Workflow.create(config_data) self.workflow.execute() self.workflow.raise_from_status() diff --git a/ingestion/src/metadata/cmd.py b/ingestion/src/metadata/cmd.py index 90a0672ea21..cc1009e1bdc 100644 --- a/ingestion/src/metadata/cmd.py +++ b/ingestion/src/metadata/cmd.py @@ -66,7 +66,8 @@ def ingest(config: str) -> None: try: logger.info(f"Using config: {workflow_config}") - del workflow_config['cron'] + if workflow_config.get('cron'): + del workflow_config['cron'] workflow = Workflow.create(workflow_config) except ValidationError as e: click.echo(e, err=True) diff --git a/ingestion/src/metadata/ingestion/api/source.py b/ingestion/src/metadata/ingestion/api/source.py index 8ba2170acda..5eef7d41274 100644 --- a/ingestion/src/metadata/ingestion/api/source.py +++ b/ingestion/src/metadata/ingestion/api/source.py @@ -25,21 +25,18 @@ from .status import Status class SourceStatus(Status): records = 0 - warnings: Dict[str, List[str]] = field(default_factory=dict) - failures: Dict[str, List[str]] = field(default_factory=dict) + warnings: List[str] = field(default_factory=list) + failures: List[str] = field(default_factory=list) def scanned(self, record: Record) -> None: self.records += 1 def warning(self, key: str, reason: str) -> None: - if key not in self.warnings: - self.warnings[key] = [] - self.warnings[key].append(reason) + self.warnings.append({key:reason}) def failure(self, key: str, reason: str) -> None: - if key not in self.failures: - self.failures[key] = [] - self.failures[key].append(reason) + self.failures.append({key:reason}) + @dataclass # type: ignore[misc] diff --git a/ingestion/src/metadata/ingestion/bulksink/metadata_usage.py b/ingestion/src/metadata/ingestion/bulksink/metadata_usage.py index cadf40c20fd..6e83af57d40 100644 --- a/ingestion/src/metadata/ingestion/bulksink/metadata_usage.py +++ b/ingestion/src/metadata/ingestion/bulksink/metadata_usage.py @@ -71,6 +71,8 @@ class MetadataUsageBulkSink(BulkSink): usage_records = [json.loads(l) for l in self.file_handler.readlines()] for record in usage_records: table_usage = TableUsageCount(**json.loads(record)) + if '.' in table_usage.table: + table_usage.table = table_usage.table.split(".")[1] if table_usage.table in self.tables_dict: table_entity = self.tables_dict[table_usage.table] table_usage_request = TableUsageRequest(date=table_usage.date, count=table_usage.count) diff --git a/ingestion/src/metadata/ingestion/source/snowflake_usage.py b/ingestion/src/metadata/ingestion/source/snowflake_usage.py index 30c0ccd884e..53d413297c8 100644 --- a/ingestion/src/metadata/ingestion/source/snowflake_usage.py +++ b/ingestion/src/metadata/ingestion/source/snowflake_usage.py @@ -83,7 +83,10 @@ class SnowflakeUsageSource(Source): for row in self._get_raw_extract_iter(): tq = TableQuery(row['query'], row['label'], 0, 0, 0, str(row['starttime']), str(row['endtime']), str(row['starttime'])[0:19], 2, row['database'], 0, row['sql']) - self.report.scanned(f"{row['database']}.{row['schema_name']}") + if row['schema_name'] is not None: + self.report.scanned(f"{row['database']}.{row['schema_name']}") + else: + self.report.scanned(f"{row['database']}") yield tq def get_report(self): diff --git a/ingestion/src/metadata/ingestion/source/sql_source.py b/ingestion/src/metadata/ingestion/source/sql_source.py index 5bfe9371a9c..5d9f42cf2c7 100644 --- a/ingestion/src/metadata/ingestion/source/sql_source.py +++ b/ingestion/src/metadata/ingestion/source/sql_source.py @@ -83,6 +83,7 @@ class SQLConnectionConfig(ConfigModel): url += f"{self.host_port}" if self.database: url += f"/{self.database}" + logger.info(url) return url def get_service_type(self) -> DatabaseServiceType: @@ -209,7 +210,7 @@ class SQLSource(Source): for table_name in inspector.get_table_names(schema): try: schema, table = self.standardize_schema_table_names(schema, table_name) - if not self.sql_config.filter_pattern.included(table_name): + if not self.sql_config.filter_pattern.included(f'{schema}.{table_name}'): self.status.filtered('{}.{}'.format(self.config.get_service_name(), table_name), "Table pattern not allowed") continue From f543c3f99b42891012d0ed1cd8ab52fc39511823 Mon Sep 17 00:00:00 2001 From: Ayush Shah Date: Tue, 24 Aug 2021 02:13:55 +0530 Subject: [PATCH 2/3] Pipeline Fixes --- ingestion/examples/workflows/snowflake.json | 6 +- .../src/metadata/ingestion/source/athena.py | 2 + .../src/metadata/ingestion/source/bigquery.py | 22 +++---- .../src/metadata/ingestion/source/hive.py | 8 ++- .../src/metadata/ingestion/source/mssql.py | 3 + .../src/metadata/ingestion/source/mysql.py | 3 + .../src/metadata/ingestion/source/oracle.py | 6 ++ .../src/metadata/ingestion/source/postgres.py | 3 + .../src/metadata/ingestion/source/presto.py | 3 + .../src/metadata/ingestion/source/redshift.py | 3 + .../metadata/ingestion/source/snowflake.py | 3 + .../metadata/ingestion/source/sql_source.py | 57 +++++++++---------- 12 files changed, 74 insertions(+), 45 deletions(-) diff --git a/ingestion/examples/workflows/snowflake.json b/ingestion/examples/workflows/snowflake.json index 4b7f4c41351..f217f37831d 100644 --- a/ingestion/examples/workflows/snowflake.json +++ b/ingestion/examples/workflows/snowflake.json @@ -10,10 +10,8 @@ "service_name": "snowflake", "service_type": "Snowflake", "filter_pattern": { - "includes": [ - "(\\w)*tpcds_sf100tcl", - "(\\w)*tpcds_sf100tcl", - "(\\w)*tpcds_sf10tcl" + "excludes": [ + "tpcds_sf100tcl" ] } } diff --git a/ingestion/src/metadata/ingestion/source/athena.py b/ingestion/src/metadata/ingestion/source/athena.py index 448274c7a2b..6f1cefca85a 100644 --- a/ingestion/src/metadata/ingestion/source/athena.py +++ b/ingestion/src/metadata/ingestion/source/athena.py @@ -45,6 +45,8 @@ class AthenaConfig(SQLConnectionConfig): return url + def fetch_sample_data(self, schema: str, table: str, connection): + return super().fetch_sample_data(schema, table, connection) class AthenaSource(SQLSource): def __init__(self, config, metadata_config, ctx): diff --git a/ingestion/src/metadata/ingestion/source/bigquery.py b/ingestion/src/metadata/ingestion/source/bigquery.py index dd74c1ce025..ad96d517089 100644 --- a/ingestion/src/metadata/ingestion/source/bigquery.py +++ b/ingestion/src/metadata/ingestion/source/bigquery.py @@ -15,6 +15,8 @@ from typing import Optional, Tuple +from metadata.generated.schema.entity.data.table import TableData + # This import verifies that the dependencies are available. from .sql_source import SQLConnectionConfig, SQLSource @@ -30,6 +32,16 @@ class BigQueryConfig(SQLConnectionConfig, SQLSource): return f"{self.scheme}://{self.project_id}" return f"{self.scheme}://" + def fetch_sample_data(self, schema: str, table: str, connection): + query = f"select * from {self.project_id}.{schema},{table} limit 50" + results = self.connection.execute(query) + cols = list(results.keys()) + rows = [] + for r in results: + row = list(r) + rows.append(row) + return TableData(columns=cols, rows=rows) + class BigquerySource(SQLSource): def __init__(self, config, metadata_config, ctx): @@ -40,13 +52,3 @@ class BigquerySource(SQLSource): config = BigQueryConfig.parse_obj(config_dict) metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict) return cls(config, metadata_config, ctx) - - def standardize_schema_table_names( - self, schema: str, table: str - ) -> Tuple[str, str]: - segments = table.split(".") - if len(segments) != 2: - raise ValueError(f"expected table to contain schema name already {table}") - if segments[0] != schema: - raise ValueError(f"schema {schema} does not match table {table}") - return segments[0], segments[1] diff --git a/ingestion/src/metadata/ingestion/source/hive.py b/ingestion/src/metadata/ingestion/source/hive.py index c56a978b654..80a5372ae08 100644 --- a/ingestion/src/metadata/ingestion/source/hive.py +++ b/ingestion/src/metadata/ingestion/source/hive.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from pyhive import hive # noqa: F401 from pyhive.sqlalchemy_hive import HiveDate, HiveDecimal, HiveTimestamp @@ -30,9 +31,14 @@ register_custom_type(HiveDecimal, "NUMBER") class HiveConfig(SQLConnectionConfig): scheme = "hive" + auth_options = Optional[str] def get_connection_url(self): - return super().get_connection_url() + url = super().get_connection_url() + return f'{url};{self.auth_options}' + + def fetch_sample_data(self, schema: str, table: str, connection): + return super().fetch_sample_data(schema, table, connection) class HiveSource(SQLSource): diff --git a/ingestion/src/metadata/ingestion/source/mssql.py b/ingestion/src/metadata/ingestion/source/mssql.py index 31bcb9ff5d4..76fef34f82c 100644 --- a/ingestion/src/metadata/ingestion/source/mssql.py +++ b/ingestion/src/metadata/ingestion/source/mssql.py @@ -27,6 +27,9 @@ class MssqlConfig(SQLConnectionConfig): def get_connection_url(self): return super().get_connection_url() + def fetch_sample_data(self, schema: str, table: str, connection): + return super().fetch_sample_data(schema, table, connection) + class MssqlSource(SQLSource): def __init__(self, config, metadata_config, ctx): diff --git a/ingestion/src/metadata/ingestion/source/mysql.py b/ingestion/src/metadata/ingestion/source/mysql.py index 5129b1209e6..e0dc18751c5 100644 --- a/ingestion/src/metadata/ingestion/source/mysql.py +++ b/ingestion/src/metadata/ingestion/source/mysql.py @@ -24,6 +24,9 @@ class MySQLConfig(SQLConnectionConfig): def get_connection_url(self): return super().get_connection_url() + def fetch_sample_data(self, schema: str, table: str, connection): + return super().fetch_sample_data(schema, table, connection) + class MysqlSource(SQLSource): def __init__(self, config, metadata_config, ctx): diff --git a/ingestion/src/metadata/ingestion/source/oracle.py b/ingestion/src/metadata/ingestion/source/oracle.py index c43fcda8102..19276be3abe 100644 --- a/ingestion/src/metadata/ingestion/source/oracle.py +++ b/ingestion/src/metadata/ingestion/source/oracle.py @@ -24,6 +24,12 @@ class OracleConfig(SQLConnectionConfig): # defaults scheme = "oracle+cx_oracle" + def fetch_sample_data(self, schema: str, table: str, connection): + return super().fetch_sample_data(schema, table, connection) + + def get_connection_url(self): + return super().get_connection_url() + class OracleSource(SQLSource): def __init__(self, config, metadata_config, ctx): diff --git a/ingestion/src/metadata/ingestion/source/postgres.py b/ingestion/src/metadata/ingestion/source/postgres.py index 7d7213cd5c5..d5ef6c2e293 100644 --- a/ingestion/src/metadata/ingestion/source/postgres.py +++ b/ingestion/src/metadata/ingestion/source/postgres.py @@ -50,6 +50,9 @@ class PostgresSourceConfig(SQLConnectionConfig): def get_connection_url(self): return super().get_connection_url() + def fetch_sample_data(self, schema: str, table: str, connection): + return super().fetch_sample_data(schema, table, connection) + def get_table_key(row: Dict[str, Any]) -> Union[TableKey, None]: """ diff --git a/ingestion/src/metadata/ingestion/source/presto.py b/ingestion/src/metadata/ingestion/source/presto.py index b1ea29319c7..e654d663783 100644 --- a/ingestion/src/metadata/ingestion/source/presto.py +++ b/ingestion/src/metadata/ingestion/source/presto.py @@ -33,6 +33,9 @@ class PrestoConfig(SQLConnectionConfig): url += f"?schema={quote_plus(self.database)}" return url + def fetch_sample_data(self, schema: str, table: str, connection): + return super().fetch_sample_data(schema, table, connection) + class PrestoSource(SQLSource): def __init__(self, config, metadata_config, ctx): diff --git a/ingestion/src/metadata/ingestion/source/redshift.py b/ingestion/src/metadata/ingestion/source/redshift.py index d7355a0a630..0401a9ea142 100644 --- a/ingestion/src/metadata/ingestion/source/redshift.py +++ b/ingestion/src/metadata/ingestion/source/redshift.py @@ -37,6 +37,9 @@ class RedshiftConfig(SQLConnectionConfig): def get_connection_url(self): return super().get_connection_url() + def fetch_sample_data(self, schema: str, table: str, connection): + return super().fetch_sample_data(schema, table, connection) + class RedshiftSource(SQLSource): diff --git a/ingestion/src/metadata/ingestion/source/snowflake.py b/ingestion/src/metadata/ingestion/source/snowflake.py index 87cf1fa7e5e..6916779cb6f 100644 --- a/ingestion/src/metadata/ingestion/source/snowflake.py +++ b/ingestion/src/metadata/ingestion/source/snowflake.py @@ -49,6 +49,9 @@ class SnowflakeConfig(SQLConnectionConfig): connect_string = f"{connect_string}?{params}" return connect_string + def fetch_sample_data(self, schema: str, table: str, connection): + return super().fetch_sample_data(schema, table, connection) + class SnowflakeSource(SQLSource): def __init__(self, config, metadata_config, ctx): diff --git a/ingestion/src/metadata/ingestion/source/sql_source.py b/ingestion/src/metadata/ingestion/source/sql_source.py index 5d9f42cf2c7..2326c90f05c 100644 --- a/ingestion/src/metadata/ingestion/source/sql_source.py +++ b/ingestion/src/metadata/ingestion/source/sql_source.py @@ -72,6 +72,21 @@ class SQLConnectionConfig(ConfigModel): generate_sample_data: Optional[bool] = True filter_pattern: IncludeFilterPattern = IncludeFilterPattern.allow_all() + @abstractmethod + def fetch_sample_data(self, schema: str, table: str, connection): + try: + query = f"select * from {schema}.{table} limit 50" + logger.info("Fetching sample data, this may take a while {}".format(query)) + results = connection.execute(query) + cols = list(results.keys()) + rows = [] + for r in results: + row = list(r) + rows.append(row) + return TableData(columns=cols, rows=rows) + except Exception as err: + logger.error("Failed to generate sample data for {} - {}".format(table, err)) + @abstractmethod def get_connection_url(self): url = f"{self.scheme}://" @@ -161,7 +176,7 @@ class SQLSource(Source): self.status = SQLSourceStatus() self.sql_config = self.config self.engine = create_engine(self.sql_config.get_connection_url(), **self.sql_config.options) - self.connection = None + self.connection = self.engine.connect() def prepare(self): pass @@ -170,29 +185,7 @@ class SQLSource(Source): def create(cls, config_dict: dict, metadata_config_dict: dict, ctx: WorkflowContext): pass - def fetch_sample_data(self, schema: str, table: str): - try: - if self.connection is None: - self.connection = self.engine.connect() - query = f"select * from {schema}.{table} limit 50" - logger.info("Fetching sample data, this may take a while {}".format(query)) - results = self.connection.execute(query) - cols = list(results.keys()) - rows = [] - for r in results: - row = list(r) - rows.append(row) - return TableData(columns=cols, rows=rows) - except: - logger.error("Failed to generate sample data for {}".format(table)) - - def standardize_schema_table_names( - self, schema: str, table: str - ) -> Tuple[str, str]: - return schema, table - def next_record(self) -> Iterable[OMetaDatabaseAndTable]: - inspector = inspect(self.engine) for schema in inspector.get_schema_names(): if not self.sql_config.filter_pattern.included(schema): @@ -209,8 +202,7 @@ class SQLSource(Source): schema: str) -> Iterable[OMetaDatabaseAndTable]: for table_name in inspector.get_table_names(schema): try: - schema, table = self.standardize_schema_table_names(schema, table_name) - if not self.sql_config.filter_pattern.included(f'{schema}.{table_name}'): + if not self.sql_config.filter_pattern.included(table_name): self.status.filtered('{}.{}'.format(self.config.get_service_name(), table_name), "Table pattern not allowed") continue @@ -226,7 +218,7 @@ class SQLSource(Source): description=description if description is not None else ' ', columns=table_columns) if self.sql_config.generate_sample_data: - table_data = self.fetch_sample_data(schema, table_name) + table_data = self.sql_config.fetch_sample_data(schema, table_name, self.connection) table.sampleData = table_data table_and_db = OMetaDatabaseAndTable(table=table, database=self._get_database(schema)) @@ -261,7 +253,7 @@ class SQLSource(Source): columns=table_columns, viewDefinition=view_definition) if self.sql_config.generate_sample_data: - table_data = self.fetch_sample_data(schema, view_name) + table_data = self.sql_config.fetch_sample_data(schema, view_name, self.connection) table.sampleData = table_data table_and_db = OMetaDatabaseAndTable(table=table, database=self._get_database(schema)) @@ -294,7 +286,11 @@ class SQLSource(Source): table_columns = [] row_order = 1 for column in columns: - col_type = get_column_type(self.status, dataset_name, column['type']) + col_type = None + try: + col_type = get_column_type(self.status, dataset_name, column['type']) + except Exception as err: + logger.error(err) col_constraint = None if column['nullable']: col_constraint = ColumnConstraint.NULL @@ -316,10 +312,11 @@ class SQLSource(Source): return table_columns def _get_table_description(self, schema: str, table: str, inspector: Inspector) -> str: + description = None try: table_info: dict = inspector.get_table_comment(table, schema) - except NotImplementedError: - description: Optional[str] = None + except Exception as err: + logger.error(f"Table Description Error : {err}") else: description = table_info["text"] return description From 07dd7b36f4062a1bd90094d6c591e39cbfece655 Mon Sep 17 00:00:00 2001 From: Ayush Shah Date: Tue, 24 Aug 2021 12:01:32 +0530 Subject: [PATCH 3/3] Pipeline Ingestion refactoring --- ingestion/examples/workflows/bigquery.json | 9 ++- .../src/metadata/ingestion/source/athena.py | 3 +- .../src/metadata/ingestion/source/bigquery.py | 30 ++++++---- .../src/metadata/ingestion/source/hive.py | 5 +- .../src/metadata/ingestion/source/mssql.py | 3 +- .../src/metadata/ingestion/source/mysql.py | 3 - .../src/metadata/ingestion/source/oracle.py | 3 +- .../src/metadata/ingestion/source/postgres.py | 3 +- .../src/metadata/ingestion/source/presto.py | 3 +- .../src/metadata/ingestion/source/redshift.py | 3 +- .../metadata/ingestion/source/snowflake.py | 3 +- .../metadata/ingestion/source/sql_source.py | 58 ++++++++++--------- 12 files changed, 67 insertions(+), 59 deletions(-) diff --git a/ingestion/examples/workflows/bigquery.json b/ingestion/examples/workflows/bigquery.json index 6ac3cd657f6..9f6b8670328 100644 --- a/ingestion/examples/workflows/bigquery.json +++ b/ingestion/examples/workflows/bigquery.json @@ -6,7 +6,14 @@ "host_port": "https://bigquery.googleapis.com", "username": "gcpuser@project_id.iam.gserviceaccount.com", "service_name": "gcp_bigquery", - "service_type": "BigQuery" + "service_type": "BigQuery", + "filter_pattern": { + "excludes": [ + "[\\w]*cloudaudit.*", + "[\\w]*logging_googleapis_com.*", + "[\\w]*clouderrorreporting.*" + ] + } } }, "processor": { diff --git a/ingestion/src/metadata/ingestion/source/athena.py b/ingestion/src/metadata/ingestion/source/athena.py index 6f1cefca85a..3c587caa7cb 100644 --- a/ingestion/src/metadata/ingestion/source/athena.py +++ b/ingestion/src/metadata/ingestion/source/athena.py @@ -45,8 +45,7 @@ class AthenaConfig(SQLConnectionConfig): return url - def fetch_sample_data(self, schema: str, table: str, connection): - return super().fetch_sample_data(schema, table, connection) + class AthenaSource(SQLSource): def __init__(self, config, metadata_config, ctx): diff --git a/ingestion/src/metadata/ingestion/source/bigquery.py b/ingestion/src/metadata/ingestion/source/bigquery.py index ad96d517089..58fe12c6c04 100644 --- a/ingestion/src/metadata/ingestion/source/bigquery.py +++ b/ingestion/src/metadata/ingestion/source/bigquery.py @@ -32,16 +32,6 @@ class BigQueryConfig(SQLConnectionConfig, SQLSource): return f"{self.scheme}://{self.project_id}" return f"{self.scheme}://" - def fetch_sample_data(self, schema: str, table: str, connection): - query = f"select * from {self.project_id}.{schema},{table} limit 50" - results = self.connection.execute(query) - cols = list(results.keys()) - rows = [] - for r in results: - row = list(r) - rows.append(row) - return TableData(columns=cols, rows=rows) - class BigquerySource(SQLSource): def __init__(self, config, metadata_config, ctx): @@ -52,3 +42,23 @@ class BigquerySource(SQLSource): config = BigQueryConfig.parse_obj(config_dict) metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict) return cls(config, metadata_config, ctx) + + def fetch_sample_data(self, schema: str, table: str): + query = f"select * from {self.config.project_id}.{schema}.{table} limit 50" + results = self.connection.execute(query) + cols = list(results.keys()) + rows = [] + for r in results: + row = list(r) + rows.append(row) + return TableData(columns=cols, rows=rows) + + def standardize_schema_table_names( + self, schema: str, table: str + ) -> Tuple[str, str]: + segments = table.split(".") + if len(segments) != 2: + raise ValueError(f"expected table to contain schema name already {table}") + if segments[0] != schema: + raise ValueError(f"schema {schema} does not match table {table}") + return segments[0], segments[1] diff --git a/ingestion/src/metadata/ingestion/source/hive.py b/ingestion/src/metadata/ingestion/source/hive.py index 80a5372ae08..d4b22b4621c 100644 --- a/ingestion/src/metadata/ingestion/source/hive.py +++ b/ingestion/src/metadata/ingestion/source/hive.py @@ -31,15 +31,12 @@ register_custom_type(HiveDecimal, "NUMBER") class HiveConfig(SQLConnectionConfig): scheme = "hive" - auth_options = Optional[str] + auth_options: Optional[str] = None def get_connection_url(self): url = super().get_connection_url() return f'{url};{self.auth_options}' - def fetch_sample_data(self, schema: str, table: str, connection): - return super().fetch_sample_data(schema, table, connection) - class HiveSource(SQLSource): def __init__(self, config, metadata_config, ctx): diff --git a/ingestion/src/metadata/ingestion/source/mssql.py b/ingestion/src/metadata/ingestion/source/mssql.py index 76fef34f82c..0ad4ef09d7b 100644 --- a/ingestion/src/metadata/ingestion/source/mssql.py +++ b/ingestion/src/metadata/ingestion/source/mssql.py @@ -27,8 +27,7 @@ class MssqlConfig(SQLConnectionConfig): def get_connection_url(self): return super().get_connection_url() - def fetch_sample_data(self, schema: str, table: str, connection): - return super().fetch_sample_data(schema, table, connection) + class MssqlSource(SQLSource): diff --git a/ingestion/src/metadata/ingestion/source/mysql.py b/ingestion/src/metadata/ingestion/source/mysql.py index e0dc18751c5..5129b1209e6 100644 --- a/ingestion/src/metadata/ingestion/source/mysql.py +++ b/ingestion/src/metadata/ingestion/source/mysql.py @@ -24,9 +24,6 @@ class MySQLConfig(SQLConnectionConfig): def get_connection_url(self): return super().get_connection_url() - def fetch_sample_data(self, schema: str, table: str, connection): - return super().fetch_sample_data(schema, table, connection) - class MysqlSource(SQLSource): def __init__(self, config, metadata_config, ctx): diff --git a/ingestion/src/metadata/ingestion/source/oracle.py b/ingestion/src/metadata/ingestion/source/oracle.py index 19276be3abe..9b59bfc55d5 100644 --- a/ingestion/src/metadata/ingestion/source/oracle.py +++ b/ingestion/src/metadata/ingestion/source/oracle.py @@ -24,8 +24,7 @@ class OracleConfig(SQLConnectionConfig): # defaults scheme = "oracle+cx_oracle" - def fetch_sample_data(self, schema: str, table: str, connection): - return super().fetch_sample_data(schema, table, connection) + def get_connection_url(self): return super().get_connection_url() diff --git a/ingestion/src/metadata/ingestion/source/postgres.py b/ingestion/src/metadata/ingestion/source/postgres.py index d5ef6c2e293..553e80e9069 100644 --- a/ingestion/src/metadata/ingestion/source/postgres.py +++ b/ingestion/src/metadata/ingestion/source/postgres.py @@ -50,8 +50,7 @@ class PostgresSourceConfig(SQLConnectionConfig): def get_connection_url(self): return super().get_connection_url() - def fetch_sample_data(self, schema: str, table: str, connection): - return super().fetch_sample_data(schema, table, connection) + def get_table_key(row: Dict[str, Any]) -> Union[TableKey, None]: diff --git a/ingestion/src/metadata/ingestion/source/presto.py b/ingestion/src/metadata/ingestion/source/presto.py index e654d663783..5e638467e77 100644 --- a/ingestion/src/metadata/ingestion/source/presto.py +++ b/ingestion/src/metadata/ingestion/source/presto.py @@ -33,8 +33,7 @@ class PrestoConfig(SQLConnectionConfig): url += f"?schema={quote_plus(self.database)}" return url - def fetch_sample_data(self, schema: str, table: str, connection): - return super().fetch_sample_data(schema, table, connection) + class PrestoSource(SQLSource): diff --git a/ingestion/src/metadata/ingestion/source/redshift.py b/ingestion/src/metadata/ingestion/source/redshift.py index 0401a9ea142..da576a6ef06 100644 --- a/ingestion/src/metadata/ingestion/source/redshift.py +++ b/ingestion/src/metadata/ingestion/source/redshift.py @@ -37,8 +37,7 @@ class RedshiftConfig(SQLConnectionConfig): def get_connection_url(self): return super().get_connection_url() - def fetch_sample_data(self, schema: str, table: str, connection): - return super().fetch_sample_data(schema, table, connection) + class RedshiftSource(SQLSource): diff --git a/ingestion/src/metadata/ingestion/source/snowflake.py b/ingestion/src/metadata/ingestion/source/snowflake.py index 6916779cb6f..ca573ab08f4 100644 --- a/ingestion/src/metadata/ingestion/source/snowflake.py +++ b/ingestion/src/metadata/ingestion/source/snowflake.py @@ -49,8 +49,7 @@ class SnowflakeConfig(SQLConnectionConfig): connect_string = f"{connect_string}?{params}" return connect_string - def fetch_sample_data(self, schema: str, table: str, connection): - return super().fetch_sample_data(schema, table, connection) + class SnowflakeSource(SQLSource): diff --git a/ingestion/src/metadata/ingestion/source/sql_source.py b/ingestion/src/metadata/ingestion/source/sql_source.py index 2326c90f05c..78ba132fcf7 100644 --- a/ingestion/src/metadata/ingestion/source/sql_source.py +++ b/ingestion/src/metadata/ingestion/source/sql_source.py @@ -55,7 +55,7 @@ class SQLSourceStatus(SourceStatus): def filtered(self, table_name: str, err: str, dataset_name: str = None, col_type: str = None) -> None: self.warnings.append(table_name) - logger.warning("Dropped Table {} due to {}".format(dataset_name, err)) + logger.warning("Dropped Table {} due to {}".format(table_name, err)) class SQLConnectionConfig(ConfigModel): @@ -72,21 +72,6 @@ class SQLConnectionConfig(ConfigModel): generate_sample_data: Optional[bool] = True filter_pattern: IncludeFilterPattern = IncludeFilterPattern.allow_all() - @abstractmethod - def fetch_sample_data(self, schema: str, table: str, connection): - try: - query = f"select * from {schema}.{table} limit 50" - logger.info("Fetching sample data, this may take a while {}".format(query)) - results = connection.execute(query) - cols = list(results.keys()) - rows = [] - for r in results: - row = list(r) - rows.append(row) - return TableData(columns=cols, rows=rows) - except Exception as err: - logger.error("Failed to generate sample data for {} - {}".format(table, err)) - @abstractmethod def get_connection_url(self): url = f"{self.scheme}://" @@ -185,6 +170,26 @@ class SQLSource(Source): def create(cls, config_dict: dict, metadata_config_dict: dict, ctx: WorkflowContext): pass + def standardize_schema_table_names( + self, schema: str, table: str + ) -> Tuple[str, str]: + print("IN SQL SOURCE") + return schema, table + + def fetch_sample_data(self, schema: str, table: str): + try: + query = f"select * from {schema}.{table} limit 50" + logger.info("Fetching sample data, this may take a while {}".format(query)) + results = self.connection.execute(query) + cols = list(results.keys()) + rows = [] + for r in results: + row = list(r) + rows.append(row) + return TableData(columns=cols, rows=rows) + except Exception as err: + logger.error("Failed to generate sample data for {} - {}".format(table, err)) + def next_record(self) -> Iterable[OMetaDatabaseAndTable]: inspector = inspect(self.engine) for schema in inspector.get_schema_names(): @@ -202,6 +207,7 @@ class SQLSource(Source): schema: str) -> Iterable[OMetaDatabaseAndTable]: for table_name in inspector.get_table_names(schema): try: + schema, table_name = self.standardize_schema_table_names(schema, table_name) if not self.sql_config.filter_pattern.included(table_name): self.status.filtered('{}.{}'.format(self.config.get_service_name(), table_name), "Table pattern not allowed") @@ -211,17 +217,16 @@ class SQLSource(Source): description = self._get_table_description(schema, table_name, inspector) table_columns = self._get_columns(schema, table_name, inspector) - - table = Table(id=uuid.uuid4(), - name=table_name, - tableType='Regular', - description=description if description is not None else ' ', - columns=table_columns) + table_entity = Table(id=uuid.uuid4(), + name=table_name, + tableType='Regular', + description=description if description is not None else ' ', + columns=table_columns) if self.sql_config.generate_sample_data: - table_data = self.sql_config.fetch_sample_data(schema, table_name, self.connection) - table.sampleData = table_data + table_data = self.fetch_sample_data(schema, table_name) + table_entity.sampleData = table_data - table_and_db = OMetaDatabaseAndTable(table=table, database=self._get_database(schema)) + table_and_db = OMetaDatabaseAndTable(table=table_entity, database=self._get_database(schema)) yield table_and_db except ValidationError as err: logger.error(err) @@ -253,7 +258,7 @@ class SQLSource(Source): columns=table_columns, viewDefinition=view_definition) if self.sql_config.generate_sample_data: - table_data = self.sql_config.fetch_sample_data(schema, view_name, self.connection) + table_data = self.fetch_sample_data(schema, view_name) table.sampleData = table_data table_and_db = OMetaDatabaseAndTable(table=table, database=self._get_database(schema)) @@ -301,7 +306,6 @@ class SQLSource(Source): col_constraint = ColumnConstraint.PRIMARY_KEY elif column['name'] in unique_columns: col_constraint = ColumnConstraint.UNIQUE - table_columns.append(Column(name=column['name'], description=column.get("comment", None), columnDataType=col_type,