From db985fda57daeeee9fb3814d1dd9ee329018712e Mon Sep 17 00:00:00 2001
From: Pere Miquel Brull
Date: Fri, 26 Jan 2024 14:11:16 +0100
Subject: [PATCH] MINOR - Snowflake system queries to work with ES & IDENTIFIER
 (#14864)

---
 .../metrics/system/queries/snowflake.py      | 110 ++++++++++++++---
 .../profiler/metrics/system/system.py        |  82 +++++++++----
 .../metadata/profiler/processor/default.py   |  17 ++-
 .../profiler/source/base/profiler_source.py  |   6 +-
 ingestion/tests/unit/profiler/conftest.py    |   6 +-
 .../unit/profiler/sqlalchemy/test_metrics.py |   4 +-
 ingestion/tests/unit/profiler/test_utils.py  | 112 +++++++++++++++++-
 7 files changed, 285 insertions(+), 52 deletions(-)

diff --git a/ingestion/src/metadata/profiler/metrics/system/queries/snowflake.py b/ingestion/src/metadata/profiler/metrics/system/queries/snowflake.py
index 90e2d46fb65..ef895bd7095 100644
--- a/ingestion/src/metadata/profiler/metrics/system/queries/snowflake.py
+++ b/ingestion/src/metadata/profiler/metrics/system/queries/snowflake.py
@@ -15,10 +15,13 @@ Snowflake System Metric Queries and query operations
 
 import re
 import traceback
-from typing import Optional
+from typing import Optional, Tuple
 
 from sqlalchemy.engine.row import Row
 
+from metadata.generated.schema.entity.services.databaseService import DatabaseService
+from metadata.ingestion.lineage.sql_lineage import search_table_entities
+from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.utils.logger import profiler_logger
 from metadata.utils.profiler_utils import (
     SnowflakeQueryResult,
@@ -55,20 +58,92 @@ RESULT_SCAN = """
 """
 
 
-def get_snowflake_system_queries(
-    row: Row, database: str, schema: str
-) -> Optional[SnowflakeQueryResult]:
-    """get snowflake system queries for a specific database and schema. Parsing the query
-    is the only reliable way to get the DDL operation as fields in the table are not. If parsing
-    fails we'll fall back to regex lookup
+QUERY_PATTERN = r"(?:(INSERT\s*INTO\s*|INSERT\s*OVERWRITE\s*INTO\s*|UPDATE\s*|MERGE\s*INTO\s*|DELETE\s*FROM\s*))([\w._\"\'()]+)(?=[\s*\n])"  # pylint: disable=line-too-long
+IDENTIFIER_PATTERN = r"(IDENTIFIER\(\')([\w._\"]+)(\'\))"
 
-    1. Parse the query and check if we have an Identifier
-    2.
+
+def _parse_query(query: str) -> Optional[str]:
+    """Parse snowflake queries to extract the identifiers"""
+    match = re.match(QUERY_PATTERN, query, re.IGNORECASE)
+    try:
+        # This will match results like `DATABASE.SCHEMA.TABLE1` or IDENTIFIER('TABLE1').
+        # If we get `IDENTIFIER`-style queries coming from Stored Procedures, we'll need
+        # to clean them up further.
+        identifier = match.group(2)
+
+        match_internal_identifier = re.match(
+            IDENTIFIER_PATTERN, identifier, re.IGNORECASE
+        )
+        internal_identifier = (
+            match_internal_identifier.group(2) if match_internal_identifier else None
+        )
+        if internal_identifier:
+            return internal_identifier
+
+        return identifier
+    except (IndexError, AttributeError):
+        logger.debug("Could not find identifier in query. Skipping row.")
+        return None
+
+
+def get_identifiers(
+    identifier: str, ometa_client: OpenMetadata, db_service: DatabaseService
+) -> Tuple[Optional[str], Optional[str], Optional[str]]:
+    """Get query identifiers and, if needed, fetch them from ES"""
+    database_name, schema_name, table_name = get_identifiers_from_string(identifier)
+
+    if not table_name:
+        logger.debug("Could not extract the table name. Skipping operation.")
+        return database_name, schema_name, table_name
+
+    if not all([database_name, schema_name]):
+        logger.debug(
+            "Missing database or schema info from the query. We'll look for it in ES."
+        )
+        es_tables = search_table_entities(
+            metadata=ometa_client,
+            service_name=db_service.fullyQualifiedName.__root__,
+            database=database_name,
+            database_schema=schema_name,
+            table=table_name,
+        )
+
+        if not es_tables:
+            logger.debug("No tables match the search criteria.")
+            return database_name, schema_name, table_name
+
+        if len(es_tables) > 1:
+            logger.debug(
+                "Found more than 1 table matching the search criteria."
+                " Skipping the computation to not mix system data."
+            )
+            return database_name, schema_name, table_name
+
+        matched_table = es_tables[0]
+        database_name = matched_table.database.name
+        schema_name = matched_table.databaseSchema.name
+
+    return database_name, schema_name, table_name
+
+
+def get_snowflake_system_queries(
+    row: Row,
+    database: str,
+    schema: str,
+    ometa_client: OpenMetadata,
+    db_service: DatabaseService,
+) -> Optional[SnowflakeQueryResult]:
+    """
+    Run a regex lookup on the query to identify which operation ran against the table.
+
+    If the query does not have the complete set of `database.schema.table` when it runs,
+    we'll use ES to pick up the table, if we find it.
 
     Args:
         row (dict): row from the snowflake system queries table
         database (str): database name
         schema (str): schema name
+        ometa_client (OpenMetadata): OpenMetadata client to search against ES
+        db_service (DatabaseService): DB service the process is running against
     Returns:
         QueryResult: namedtuple with the query result
     """
@@ -78,20 +153,17 @@ def get_snowflake_system_queries(
         query_text = dict_row.get("QUERY_TEXT", dict_row.get("query_text"))
         logger.debug(f"Trying to parse query:\n{query_text}\n")
 
-        pattern = r"(?:(INSERT\s*INTO\s*|INSERT\s*OVERWRITE\s*INTO\s*|UPDATE\s*|MERGE\s*INTO\s*|DELETE\s*FROM\s*))([\w._\"]+)(?=[\s*\n])"  # pylint: disable=line-too-long
-        match = re.match(pattern, query_text, re.IGNORECASE)
-        try:
-            identifier = match.group(2)
-        except (IndexError, AttributeError):
-            logger.debug("Could not find identifier in query. Skipping row.")
+        identifier = _parse_query(query_text)
+        if not identifier:
             return None
 
-        database_name, schema_name, table_name = get_identifiers_from_string(identifier)
+        database_name, schema_name, table_name = get_identifiers(
+            identifier=identifier,
+            ometa_client=ometa_client,
+            db_service=db_service,
+        )
 
         if not all([database_name, schema_name, table_name]):
-            logger.debug(
-                "Missing database, schema, or table. Can't link operation to table entity in OpenMetadata."
-            )
             return None
 
         if (
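
Before the next file, a standalone sketch of how the two patterns above compose.
The regexes are copied verbatim from the diff; the sample queries are
illustrative only, not taken from Snowflake's query history:

    import re

    QUERY_PATTERN = r"(?:(INSERT\s*INTO\s*|INSERT\s*OVERWRITE\s*INTO\s*|UPDATE\s*|MERGE\s*INTO\s*|DELETE\s*FROM\s*))([\w._\"\'()]+)(?=[\s*\n])"
    IDENTIFIER_PATTERN = r"(IDENTIFIER\(\')([\w._\"]+)(\'\))"

    for query in (
        "INSERT INTO DATABASE.SCHEMA.TABLE1 (col1) VALUES (1)",
        "INSERT INTO IDENTIFIER('DATABASE.SCHEMA.TABLE1') (col1) VALUES (1)",
    ):
        # First pass: capture whatever follows the DML keyword
        identifier = re.match(QUERY_PATTERN, query, re.IGNORECASE).group(2)
        # Second pass: unwrap IDENTIFIER('...') when present
        inner = re.match(IDENTIFIER_PATTERN, identifier, re.IGNORECASE)
        print(inner.group(2) if inner else identifier)
        # Both iterations print: DATABASE.SCHEMA.TABLE1
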
diff --git a/ingestion/src/metadata/profiler/metrics/system/system.py b/ingestion/src/metadata/profiler/metrics/system/system.py
index 1d9826c9222..9d70b507a89 100644
--- a/ingestion/src/metadata/profiler/metrics/system/system.py
+++ b/ingestion/src/metadata/profiler/metrics/system/system.py
@@ -23,6 +23,8 @@ from sqlalchemy.orm import DeclarativeMeta, Session
 from metadata.generated.schema.entity.services.connections.database.bigQueryConnection import (
     BigQueryConnection,
 )
+from metadata.generated.schema.entity.services.databaseService import DatabaseService
+from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.profiler.metrics.core import SystemMetric
 from metadata.profiler.metrics.system.dml_operation import (
     DML_OPERATION_MAP,
@@ -46,7 +48,11 @@ from metadata.profiler.orm.registry import Dialects
 from metadata.utils.dispatch import valuedispatch
 from metadata.utils.helpers import deep_size_of_dict
 from metadata.utils.logger import profiler_logger
-from metadata.utils.profiler_utils import get_value_from_cache, set_cache
+from metadata.utils.profiler_utils import (
+    SnowflakeQueryResult,
+    get_value_from_cache,
+    set_cache,
+)
 
 logger = profiler_logger()
 
@@ -285,11 +291,48 @@ def _(
     return metric_results
 
 
+def _snowflake_build_query_result(
+    session: Session,
+    table: DeclarativeMeta,
+    database: str,
+    schema: str,
+    ometa_client: OpenMetadata,
+    db_service: DatabaseService,
+) -> List[SnowflakeQueryResult]:
+    """List and parse snowflake DML query results"""
+    rows = session.execute(
+        text(
+            INFORMATION_SCHEMA_QUERY.format(
+                tablename=table.__tablename__,  # type: ignore
+                insert=DatabaseDMLOperations.INSERT.value,
+                update=DatabaseDMLOperations.UPDATE.value,
+                delete=DatabaseDMLOperations.DELETE.value,
+                merge=DatabaseDMLOperations.MERGE.value,
+            )
+        )
+    )
+    query_results = []
+    for row in rows:
+        result = get_snowflake_system_queries(
+            row=row,
+            database=database,
+            schema=schema,
+            ometa_client=ometa_client,
+            db_service=db_service,
+        )
+        if result:
+            query_results.append(result)
+
+    return query_results
+
+
 @get_system_metrics_for_dialect.register(Dialects.Snowflake)
 def _(
     dialect: str,
     session: Session,
     table: DeclarativeMeta,
+    ometa_client: OpenMetadata,
+    db_service: DatabaseService,
     *args,
     **kwargs,
 ) -> Optional[List[Dict]]:
@@ -315,22 +358,14 @@ def _(
 
     metric_results: List[Dict] = []
 
-    rows = session.execute(
-        text(
-            INFORMATION_SCHEMA_QUERY.format(
-                tablename=table.__tablename__,  # type: ignore
-                insert=DatabaseDMLOperations.INSERT.value,
-                update=DatabaseDMLOperations.UPDATE.value,
-                delete=DatabaseDMLOperations.DELETE.value,
-                merge=DatabaseDMLOperations.MERGE.value,
-            )
-        )
+    query_results = _snowflake_build_query_result(
+        session=session,
+        table=table,
+        database=database,
+        schema=schema,
+        ometa_client=ometa_client,
+        db_service=db_service,
     )
-    query_results = []
-    for row in rows:
-        result = get_snowflake_system_queries(row, database, schema)
-        if result:
-            query_results.append(result)
 
     for query_result in query_results:
         rows_affected = None
@@ -409,12 +444,17 @@ class System(SystemMetric):
         logger.debug("Clearing system cache")
         SYSTEM_QUERY_RESULT_CACHE.clear()
 
+    def _validate_attrs(self, attr_list: List[str]) -> None:
+        """Validate the necessary attributes given via add_props"""
+        for attr in attr_list:
+            if not hasattr(self, attr):
+                raise AttributeError(
+                    f"System requires the attribute {attr} to be set: add_props({attr}=...)(Metrics.SYSTEM.value)"
+                )
+
     def sql(self, session: Session, **kwargs):
         """Implements the SQL logic to fetch system data"""
-        if not hasattr(self, "table"):
-            raise AttributeError(
-                "System requires a table to be set: add_props(table=...)(Metrics.COLUMN_COUNT)"
-            )
+        self._validate_attrs(["table", "ometa_client", "db_service"])
 
         conn_config = kwargs.get("conn_config")
 
@@ -423,6 +463,8 @@ class System(SystemMetric):
             session=session,
             table=self.table,  # pylint: disable=no-member
             conn_config=conn_config,
+            ometa_client=self.ometa_client,  # pylint: disable=no-member
+            db_service=self.db_service,  # pylint: disable=no-member
         )
         self._manage_cache()
         return system_metrics
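
With `_validate_attrs` in place, the SYSTEM metric fails fast when a prop is
missing. A minimal sketch (it assumes a SQLAlchemy `User` model and an open
`session`, mirroring the test_metrics.py change further down; the client and
service may be None when the ES fallback is not exercised):

    from metadata.profiler.metrics.core import add_props
    from metadata.profiler.metrics.registry import Metrics

    system = add_props(table=User, ometa_client=None, db_service=None)(
        Metrics.SYSTEM.value
    )
    system().sql(session)

    # Omitting a prop now raises up front instead of failing later inside
    # the dialect handler:
    add_props(table=User)(Metrics.SYSTEM.value)().sql(session)  # AttributeError
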
diff --git a/ingestion/src/metadata/profiler/processor/default.py b/ingestion/src/metadata/profiler/processor/default.py
index aea8e7b5e0f..75b8c72aaf7 100644
--- a/ingestion/src/metadata/profiler/processor/default.py
+++ b/ingestion/src/metadata/profiler/processor/default.py
@@ -17,19 +17,28 @@ from typing import List, Optional
 from sqlalchemy.orm import DeclarativeMeta
 
 from metadata.generated.schema.entity.data.table import ColumnProfilerConfig
+from metadata.generated.schema.entity.services.databaseService import DatabaseService
+from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.profiler.interface.profiler_interface import ProfilerInterface
 from metadata.profiler.metrics.core import Metric, add_props
 from metadata.profiler.metrics.registry import Metrics
 from metadata.profiler.processor.core import Profiler
 
 
-def get_default_metrics(table: DeclarativeMeta) -> List[Metric]:
+def get_default_metrics(
+    table: DeclarativeMeta,
+    ometa_client: Optional[OpenMetadata] = None,
+    db_service: Optional[DatabaseService] = None,
+) -> List[Metric]:
     return [
         # Table Metrics
         Metrics.ROW_COUNT.value,
         add_props(table=table)(Metrics.COLUMN_COUNT.value),
         add_props(table=table)(Metrics.COLUMN_NAMES.value),
-        add_props(table=table)(Metrics.SYSTEM.value),
+        # We'll use the ometa_client & db_service in case we need to fetch info from ES
+        add_props(table=table, ometa_client=ometa_client, db_service=db_service)(
+            Metrics.SYSTEM.value
+        ),
         # Column Metrics
         Metrics.MEDIAN.value,
         Metrics.FIRST_QUARTILE.value,
@@ -65,7 +74,9 @@ class DefaultProfiler(Profiler):
         include_columns: Optional[List[ColumnProfilerConfig]] = None,
         exclude_columns: Optional[List[str]] = None,
     ):
-        _metrics = get_default_metrics(profiler_interface.table)
+        _metrics = get_default_metrics(
+            table=profiler_interface.table, ometa_client=profiler_interface.ometa_client
+        )
 
         super().__init__(
             *_metrics,
diff --git a/ingestion/src/metadata/profiler/source/base/profiler_source.py b/ingestion/src/metadata/profiler/source/base/profiler_source.py
index cd4ba730a7d..0ef3994a29a 100644
--- a/ingestion/src/metadata/profiler/source/base/profiler_source.py
+++ b/ingestion/src/metadata/profiler/source/base/profiler_source.py
@@ -279,7 +279,11 @@ class ProfilerSource(ProfilerSourceInterface):
         metrics = (
             [Metrics.get(name) for name in profiler_config.profiler.metrics]
             if profiler_config.profiler.metrics
-            else get_default_metrics(profiler_interface.table)
+            else get_default_metrics(
+                table=profiler_interface.table,
+                ometa_client=self.ometa_client,
+                db_service=db_service,
+            )
         )
 
         return Profiler(
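
Both call sites keep working because the new get_default_metrics parameters
are optional. A usage sketch (`User`, `metadata` and `service` are
illustrative names, not part of the patch):

    from metadata.profiler.processor.default import get_default_metrics

    # ES fallback wired in, as ProfilerSource does above:
    metrics = get_default_metrics(table=User, ometa_client=metadata, db_service=service)

    # Still valid: the SYSTEM metric is then built with ometa_client=None and
    # db_service=None, so partial Snowflake identifiers can't be resolved via ES.
    metrics = get_default_metrics(table=User)
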
""" -Confest for profiler tests +Conftest for profiler tests """ from uuid import UUID @@ -71,7 +71,7 @@ class Row: self.QUERY_TEXT = query_text def __iter__(self): - """implemetation to support dict(row)""" + """implementation to support dict(row)""" yield "QUERY_ID", self.QUERY_ID yield "QUERY_TYPE", self.QUERY_TYPE yield "START_TIME", self.START_TIME @@ -92,7 +92,7 @@ class LowerRow: self.QUERY_TEXT = query_text def __iter__(self): - """implemetation to support dict(row)""" + """implementation to support dict(row)""" yield "query_id", self.QUERY_ID yield "query_type", self.QUERY_TYPE yield "start_time", self.START_TIME diff --git a/ingestion/tests/unit/profiler/sqlalchemy/test_metrics.py b/ingestion/tests/unit/profiler/sqlalchemy/test_metrics.py index 5d25e7acf0f..866899764aa 100644 --- a/ingestion/tests/unit/profiler/sqlalchemy/test_metrics.py +++ b/ingestion/tests/unit/profiler/sqlalchemy/test_metrics.py @@ -862,7 +862,9 @@ class MetricsTest(TestCase): assert res == 61 def test_system_metric(self): - system = add_props(table=User)(Metrics.SYSTEM.value) + system = add_props(table=User, ometa_client=None, db_service=None)( + Metrics.SYSTEM.value + ) session = self.sqa_profiler_interface.session system().sql(session) diff --git a/ingestion/tests/unit/profiler/test_utils.py b/ingestion/tests/unit/profiler/test_utils.py index 48175f7ada7..1d57e80e66e 100644 --- a/ingestion/tests/unit/profiler/test_utils.py +++ b/ingestion/tests/unit/profiler/test_utils.py @@ -12,15 +12,32 @@ """ Tests utils function for the profiler """ - +import uuid from datetime import datetime from unittest import TestCase +from unittest.mock import patch import pytest from sqlalchemy import Column from sqlalchemy.orm import declarative_base from sqlalchemy.sql.sqltypes import Integer, String +from metadata.generated.schema.entity.data.table import Column as OMetaColumn +from metadata.generated.schema.entity.data.table import DataType, Table +from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import ( + AuthProvider, + OpenMetadataConnection, +) +from metadata.generated.schema.entity.services.databaseService import ( + DatabaseService, + DatabaseServiceType, +) +from metadata.generated.schema.security.client.openMetadataJWTClientConfig import ( + OpenMetadataJWTClientConfig, +) +from metadata.generated.schema.type.basic import EntityName, FullyQualifiedEntityName +from metadata.generated.schema.type.entityReference import EntityReference +from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.profiler.metrics.hybrid.histogram import Histogram from metadata.profiler.metrics.system.queries.snowflake import ( get_snowflake_system_queries, @@ -115,7 +132,10 @@ def test_get_snowflake_system_queries(): query_text="INSERT INTO DATABASE.SCHEMA.TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')", ) - query_result = get_snowflake_system_queries(row, "DATABASE", "SCHEMA") # type: ignore + # We don't need the ometa_client nor the db_service if we have all the db.schema.table in the query + query_result = get_snowflake_system_queries( + row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=... 
diff --git a/ingestion/tests/unit/profiler/test_utils.py b/ingestion/tests/unit/profiler/test_utils.py
index 48175f7ada7..1d57e80e66e 100644
--- a/ingestion/tests/unit/profiler/test_utils.py
+++ b/ingestion/tests/unit/profiler/test_utils.py
@@ -12,15 +12,32 @@
 """
 Tests utils function for the profiler
 """
-
+import uuid
 from datetime import datetime
 from unittest import TestCase
+from unittest.mock import patch
 
 import pytest
 from sqlalchemy import Column
 from sqlalchemy.orm import declarative_base
 from sqlalchemy.sql.sqltypes import Integer, String
 
+from metadata.generated.schema.entity.data.table import Column as OMetaColumn
+from metadata.generated.schema.entity.data.table import DataType, Table
+from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
+    AuthProvider,
+    OpenMetadataConnection,
+)
+from metadata.generated.schema.entity.services.databaseService import (
+    DatabaseService,
+    DatabaseServiceType,
+)
+from metadata.generated.schema.security.client.openMetadataJWTClientConfig import (
+    OpenMetadataJWTClientConfig,
+)
+from metadata.generated.schema.type.basic import EntityName, FullyQualifiedEntityName
+from metadata.generated.schema.type.entityReference import EntityReference
+from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.profiler.metrics.hybrid.histogram import Histogram
 from metadata.profiler.metrics.system.queries.snowflake import (
     get_snowflake_system_queries,
@@ -115,7 +132,10 @@ def test_get_snowflake_system_queries():
         query_text="INSERT INTO DATABASE.SCHEMA.TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
     )
 
-    query_result = get_snowflake_system_queries(row, "DATABASE", "SCHEMA")  # type: ignore
+    # We don't need the ometa_client nor the db_service if we have all the db.schema.table in the query
+    query_result = get_snowflake_system_queries(
+        row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=...
+    )  # type: ignore
     assert query_result
     assert query_result.query_id == "1"
     assert query_result.query_type == "INSERT"
@@ -130,7 +150,9 @@ def test_get_snowflake_system_queries():
         query_text="INSERT INTO SCHEMA.TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
     )
 
-    query_result = get_snowflake_system_queries(row, "DATABASE", "SCHEMA")  # type: ignore
+    query_result = get_snowflake_system_queries(
+        row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=...
+    )  # type: ignore
 
     assert not query_result
 
@@ -138,6 +160,10 @@ def test_get_snowflake_system_queries():
 @pytest.mark.parametrize(
     "query, expected",
     [
+        (
+            "INSERT INTO IDENTIFIER('DATABASE.SCHEMA.TABLE1') (col1, col2) VALUES (1, 'a'), (2, 'b')",
+            "INSERT",
+        ),
         (
             "INSERT INTO DATABASE.SCHEMA.TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
             "INSERT",
@@ -172,7 +198,9 @@ def test_get_snowflake_system_queries_all_dll(query, expected):
         query_text=query,
     )
 
-    query_result = get_snowflake_system_queries(row, "DATABASE", "SCHEMA")  # type: ignore
+    query_result = get_snowflake_system_queries(
+        row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=...
+    )  # type: ignore
 
     assert query_result
     assert query_result.query_type == expected
@@ -180,7 +208,9 @@ def test_get_snowflake_system_queries_all_dll(query, expected):
     assert query_result.schema_name == "schema"
     assert query_result.table_name == "table1"
 
-    query_result = get_snowflake_system_queries(lower_row, "DATABASE", "SCHEMA")  # type: ignore
+    query_result = get_snowflake_system_queries(
+        row=lower_row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=...
+    )  # type: ignore
 
     assert query_result
     assert query_result.query_type == expected
@@ -189,6 +219,75 @@ def test_get_snowflake_system_queries_all_dll(query, expected):
     assert query_result.table_name == "table1"
 
 
+def test_get_snowflake_system_queries_from_es():
+    """Test the ES integration"""
+
+    ometa_client = OpenMetadata(
+        OpenMetadataConnection(
+            hostPort="http://localhost:8585/api",
+            authProvider=AuthProvider.openmetadata,
+            enableVersionValidation=False,
+            securityConfig=OpenMetadataJWTClientConfig(jwtToken="token"),
+        )
+    )
+
+    db_service = DatabaseService(
+        id=uuid.uuid4(),
+        name=EntityName(__root__="service"),
+        fullyQualifiedName=FullyQualifiedEntityName(__root__="service"),
+        serviceType=DatabaseServiceType.CustomDatabase,
+    )
+
+    table = Table(
+        id=uuid.uuid4(),
+        name="TABLE",
+        columns=[OMetaColumn(name="id", dataType=DataType.BIGINT)],
+        database=EntityReference(id=uuid.uuid4(), type="database", name="database"),
+        databaseSchema=EntityReference(
+            id=uuid.uuid4(), type="databaseSchema", name="schema"
+        ),
+    )
+
+    # With too many responses, we won't return anything since we don't want false results
+    # that we cannot properly assign
+    with patch.object(OpenMetadata, "es_search_from_fqn", return_value=[table] * 4):
+        row = Row(
+            query_id=1,
+            query_type="INSERT",
+            start_time=datetime.now(),
+            query_text="INSERT INTO TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
+        )
+        query_result = get_snowflake_system_queries(
+            row=row,
+            database="DATABASE",
+            schema="SCHEMA",
+            ometa_client=ometa_client,
+            db_service=db_service,
+        )
+        assert not query_result
+
+    # Returning a single table should work fine
+    with patch.object(OpenMetadata, "es_search_from_fqn", return_value=[table]):
+        row = Row(
+            query_id=1,
+            query_type="INSERT",
+            start_time=datetime.now(),
+            query_text="INSERT INTO TABLE2 (col1, col2) VALUES (1, 'a'), (2, 'b')",
+        )
+        query_result = get_snowflake_system_queries(
+            row=row,
+            database="DATABASE",
+            schema="SCHEMA",
+            ometa_client=ometa_client,
+            db_service=db_service,
+        )
+        assert query_result
+        assert query_result.query_type == "INSERT"
+        assert query_result.database_name == "database"
+        assert query_result.schema_name == "schema"
+        assert query_result.table_name == "table2"
+
+
 @pytest.mark.parametrize(
     "identifier, expected",
     [
@@ -202,6 +301,9 @@ def test_get_snowflake_system_queries_all_dll(query, expected):
             '"DATABASE.DOT"."SCHEMA.DOT"."TABLE.DOT"',
             ("DATABASE.DOT", "SCHEMA.DOT", "TABLE.DOT"),
         ),
+        ("SCHEMA.TABLE", (None, "SCHEMA", "TABLE")),
+        ("TABLE", (None, None, "TABLE")),
+        ('"SCHEMA.DOT"."TABLE.DOT"', (None, "SCHEMA.DOT", "TABLE.DOT")),
     ],
 )
 def test_get_identifiers_from_string(identifier, expected):
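
For reference, the splits that the new parametrized cases assert; any None in
the database or schema position is what sends get_identifiers to the ES
lookup (the helper is imported from metadata.utils.profiler_utils, as in
snowflake.py above):

    from metadata.utils.profiler_utils import get_identifiers_from_string

    get_identifiers_from_string("SCHEMA.TABLE")              # (None, "SCHEMA", "TABLE")
    get_identifiers_from_string("TABLE")                     # (None, None, "TABLE")
    get_identifiers_from_string('"SCHEMA.DOT"."TABLE.DOT"')  # (None, "SCHEMA.DOT", "TABLE.DOT")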