diff --git a/ingestion/src/_openmetadata_testutils/pydantic/test_utils.py b/ingestion/src/_openmetadata_testutils/pydantic/test_utils.py
index 34c91246c72..386488955af 100644
--- a/ingestion/src/_openmetadata_testutils/pydantic/test_utils.py
+++ b/ingestion/src/_openmetadata_testutils/pydantic/test_utils.py
@@ -1,10 +1,13 @@
 from collections import deque
+from typing import List, Union
 
 from pydantic import BaseModel
 
 
 def assert_equal_pydantic_objects(
-    expected: BaseModel, actual: BaseModel, ignore_none=True
+    expected: Union[BaseModel, List[BaseModel]],
+    actual: Union[BaseModel, List[BaseModel]],
+    ignore_none=True,
 ):
     """Compare 2 pydantic objects recursively and raise an AssertionError if they are not equal along with
     all the differences by field. If `ignore_none` is set to True, expected None values will be ignored. This can be
@@ -32,6 +35,10 @@ def assert_equal_pydantic_objects(
     Traceback (most recent call last):
     ...
     AssertionError: objects mismatched on field: [b.a], expected: [1], actual: [2]
+    >>> assert_equal_pydantic_objects([a1, a2], [a2, a1])
+    Traceback (most recent call last):
+    ...
+    AssertionError: objects mismatched on field: [0].a, expected: [1], actual: [2]
 
     Args:
-        expected (BaseModel): The expected pydantic object.
+        expected (Union[BaseModel, List[BaseModel]]): The expected pydantic object(s).
@@ -69,11 +76,24 @@ def assert_equal_pydantic_objects(
                     errors.append(
                         f"objects mismatched on field: [{new_key_prefix}], expected: [{expected_value}], actual: [{actual_value}]"
                     )
+        elif isinstance(expected, list):
+            if not isinstance(actual, list):
+                errors.append(
+                    f"validation error on field: [{current_key_prefix}], expected: [list], actual: [{type(actual).__name__}]"
+                )
+            elif len(expected) != len(actual):
+                errors.append(
+                    f"length mismatch on field: [{current_key_prefix}], expected: [{len(expected)}], actual: [{len(actual)}]"
+                )
+            else:
+                for i, (expected_item, actual_item) in enumerate(zip(expected, actual)):
+                    queue.append(
+                        (expected_item, actual_item, f"{current_key_prefix}[{i}]")
+                    )
         else:
             if expected != actual:
                 errors.append(
                     f"mismatch at {current_key_prefix}: expected: [{expected}], actual: [{actual}]"
                 )
-
     if errors:
         raise AssertionError("\n".join(errors))
diff --git a/ingestion/src/metadata/ingestion/ometa/mixins/table_mixin.py b/ingestion/src/metadata/ingestion/ometa/mixins/table_mixin.py
index e6a539b111e..107732d7e90 100644
--- a/ingestion/src/metadata/ingestion/ometa/mixins/table_mixin.py
+++ b/ingestion/src/metadata/ingestion/ometa/mixins/table_mixin.py
@@ -16,7 +16,7 @@ To be used by OpenMetadata class
 import traceback
 from typing import List, Optional, Type, TypeVar
 
-from pydantic import BaseModel
+from pydantic import BaseModel, validate_call
 from requests.utils import quote
 
 from metadata.generated.schema.api.data.createTableProfile import (
@@ -227,6 +227,7 @@ class OMetaTableMixin:
 
         return None
 
+    @validate_call
     def get_profile_data(
         self,
         fqn: str,
@@ -253,7 +254,6 @@ class OMetaTableMixin:
         Returns:
             EntityList: EntityList list object
         """
-
         url_after = f"&after={after}" if after else ""
         profile_type_url = profile_type.__name__[0].lower() + profile_type.__name__[1:]
diff --git a/ingestion/src/metadata/ingestion/processor/query_parser.py b/ingestion/src/metadata/ingestion/processor/query_parser.py
index 65021d1212e..cb9252bc56a 100644
--- a/ingestion/src/metadata/ingestion/processor/query_parser.py
+++ b/ingestion/src/metadata/ingestion/processor/query_parser.py
@@ -26,7 +26,7 @@ from metadata.ingestion.lineage.models import ConnectionTypeDialectMapper, Dialect
 from metadata.ingestion.lineage.parser import LineageParser
 from
metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.utils.logger import ingestion_logger -from metadata.utils.time_utils import convert_timestamp_to_milliseconds +from metadata.utils.time_utils import datetime_to_timestamp logger = ingestion_logger() @@ -46,7 +46,7 @@ def parse_sql_statement(record: TableQuery, dialect: Dialect) -> Optional[Parsed start_date = start_time.root.date() start_time = datetime.datetime.strptime(str(start_date.isoformat()), "%Y-%m-%d") - start_time = convert_timestamp_to_milliseconds(int(start_time.timestamp())) + start_time = datetime_to_timestamp(start_time, milliseconds=True) lineage_parser = LineageParser(record.query, dialect=dialect) diff --git a/ingestion/src/metadata/ingestion/source/database/life_cycle_query_mixin.py b/ingestion/src/metadata/ingestion/source/database/life_cycle_query_mixin.py index b25692a6f14..1b444d87532 100644 --- a/ingestion/src/metadata/ingestion/source/database/life_cycle_query_mixin.py +++ b/ingestion/src/metadata/ingestion/source/database/life_cycle_query_mixin.py @@ -36,7 +36,7 @@ from metadata.ingestion.models.topology import TopologyContextManager from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.utils import fqn from metadata.utils.logger import ingestion_logger -from metadata.utils.time_utils import convert_timestamp_to_milliseconds +from metadata.utils.time_utils import datetime_to_timestamp logger = ingestion_logger() @@ -104,10 +104,8 @@ class LifeCycleQueryMixin: life_cycle = LifeCycle( created=AccessDetails( timestamp=Timestamp( - int( - convert_timestamp_to_milliseconds( - life_cycle_data.created_at.timestamp() - ) + datetime_to_timestamp( + life_cycle_data.created_at, milliseconds=True ) ) ) diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/models.py b/ingestion/src/metadata/ingestion/source/database/snowflake/models.py index 4b0dc69ce8c..a69b24fca59 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/models.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/models.py @@ -15,11 +15,18 @@ import urllib from datetime import datetime from typing import List, Optional -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, TypeAdapter, field_validator from requests.utils import quote +from sqlalchemy import text +from sqlalchemy.orm import Session from metadata.generated.schema.entity.data.storedProcedure import Language +from metadata.ingestion.source.database.snowflake.queries import ( + SNOWFLAKE_QUERY_LOG_QUERY, +) +from metadata.profiler.metrics.system.dml_operation import DatabaseDMLOperations from metadata.utils.logger import ingestion_logger +from metadata.utils.profiler_utils import QueryResult logger = ingestion_logger() @@ -95,3 +102,44 @@ class SnowflakeTableList(BaseModel): def get_not_deleted(self) -> List[SnowflakeTable]: return [table for table in self.tables if not table.deleted] + + +class SnowflakeQueryLogEntry(BaseModel): + """Entry for a Snowflake query log at SNOWFLAKE.ACCOUNT_USAGE.QUERY_HISTORY + More info at: https://docs.snowflake.com/en/sql-reference/account-usage/query_history + """ + + query_id: str + database_name: Optional[str] = None + schema_name: Optional[str] = None + query_type: str + start_time: datetime + query_text: Optional[str] = None + rows_inserted: Optional[int] = None + rows_updated: Optional[int] = None + rows_deleted: Optional[int] = None + + @staticmethod + def get_for_table(session: Session, tablename: str): + rows = session.execute( + 
text( + SNOWFLAKE_QUERY_LOG_QUERY.format( + tablename=tablename, # type: ignore + insert=DatabaseDMLOperations.INSERT.value, + update=DatabaseDMLOperations.UPDATE.value, + delete=DatabaseDMLOperations.DELETE.value, + merge=DatabaseDMLOperations.MERGE.value, + ) + ) + ) + return TypeAdapter(List[SnowflakeQueryLogEntry]).validate_python( + map(dict, rows) + ) + + +class SnowflakeQueryResult(QueryResult): + """Snowflake system metric query result""" + + rows_inserted: Optional[int] = None + rows_updated: Optional[int] = None + rows_deleted: Optional[int] = None diff --git a/ingestion/tests/unit/metadata/__init__.py b/ingestion/src/metadata/ingestion/source/database/snowflake/profiler/__init__.py similarity index 100% rename from ingestion/tests/unit/metadata/__init__.py rename to ingestion/src/metadata/ingestion/source/database/snowflake/profiler/__init__.py diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/profiler/system_metrics.py b/ingestion/src/metadata/ingestion/source/database/snowflake/profiler/system_metrics.py new file mode 100644 index 00000000000..b7d398d79ea --- /dev/null +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/profiler/system_metrics.py @@ -0,0 +1,260 @@ +import re +import traceback +from typing import List, Optional, Tuple + +import sqlalchemy.orm +from sqlalchemy.orm import DeclarativeMeta, Session + +from metadata.ingestion.source.database.snowflake.models import ( + SnowflakeQueryLogEntry, + SnowflakeQueryResult, +) +from metadata.utils.logger import profiler_logger +from metadata.utils.lru_cache import LRU_CACHE_SIZE, LRUCache +from metadata.utils.profiler_utils import get_identifiers_from_string + +PUBLIC_SCHEMA = "PUBLIC" +logger = profiler_logger() +RESULT_SCAN = """ + SELECT * + FROM TABLE(RESULT_SCAN('{query_id}')); + """ +QUERY_PATTERN = r"(?:(INSERT\s*INTO\s*|INSERT\s*OVERWRITE\s*INTO\s*|UPDATE\s*|MERGE\s*INTO\s*|DELETE\s*FROM\s*))([\w._\"\'()]+)(?=[\s*\n])" # pylint: disable=line-too-long +IDENTIFIER_PATTERN = r"(IDENTIFIER\(\')([\w._\"]+)(\'\))" + + +def _parse_query(query: str) -> Optional[str]: + """Parse snowflake queries to extract the identifiers""" + match = re.match(QUERY_PATTERN, query, re.IGNORECASE) + try: + # This will match results like `DATABASE.SCHEMA.TABLE1` or IDENTIFIER('TABLE1') + # If we have `IDENTIFIER` type of queries coming from Stored Procedures, we'll need to further clean it up. + identifier = match.group(2) + + match_internal_identifier = re.match( + IDENTIFIER_PATTERN, identifier, re.IGNORECASE + ) + internal_identifier = ( + match_internal_identifier.group(2) if match_internal_identifier else None + ) + if internal_identifier: + return internal_identifier + + return identifier + except (IndexError, AttributeError): + logger.debug("Could not find identifier in query. Skipping row.") + return None + + +class SnowflakeTableResovler: + def __init__(self, session: sqlalchemy.orm.Session): + self._cache = LRUCache(LRU_CACHE_SIZE) + self.session = session + + def show_tables(self, db, schema, table): + return self.session.execute( + f'SHOW TABLES LIKE \'{table}\' IN SCHEMA "{db}"."{schema}" LIMIT 1;' + ).fetchone() + + def table_exists(self, db, schema, table): + """Return True if the table exists in Snowflake. Uses cache to store the results. 
+
+        Args:
+            db (str): Database name
+            schema (str): Schema name
+            table (str): Table name
+
+        Returns:
+            bool: True if the table exists in Snowflake
+        """
+        if f"{db}.{schema}.{table}" in self._cache:
+            return self._cache.get(f"{db}.{schema}.{table}")
+        row = self.show_tables(db, schema, table)
+        if row:
+            self._cache.put(f"{db}.{schema}.{table}", True)
+            return True
+        return False
+
+    def resolve_implicit_fqn(
+        self,
+        context_database: str,
+        context_schema: Optional[str],
+        table_name: str,
+    ) -> Tuple[str, str, str]:
+        """Resolve the fully qualified name of the table from Snowflake based on the following logic:
+        1. If the schema is provided:
+            a. search for the table in the schema
+            b. if not found, go to (2)
+        2. Search for the table in the public schema.
+
+        Args:
+            context_database (str): Database name
+            context_schema (Optional[str]): Schema name. If not provided, we'll search in the public schema.
+            table_name (str): Table name
+        Returns:
+            tuple: Tuple of database, schema and table names
+        Raises:
+            RuntimeError: If the table is not found in Snowflake in either schema
+
+        """
+        search_paths = []
+        if context_schema:
+            search_paths.append(".".join([context_database, context_schema, table_name]))
+            if self.table_exists(context_database, context_schema, table_name):
+                return context_database, context_schema, table_name
+        if context_schema != PUBLIC_SCHEMA:
+            search_paths.append(".".join([context_database, PUBLIC_SCHEMA, table_name]))
+            if self.table_exists(context_database, PUBLIC_SCHEMA, table_name):
+                return context_database, PUBLIC_SCHEMA, table_name
+        raise RuntimeError(
+            "Could not find the table {search_paths}.".format(
+                search_paths=" OR ".join(map(lambda x: f"[{x}]", search_paths))
+            )
+        )
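
For reference, the resolution cascade above can be exercised on its own. A minimal usage sketch, assuming a live Snowflake SQLAlchemy session (the `resolve` wrapper and the table names are hypothetical, not part of this diff):

```python
# Sketch only: exercising the PUBLIC-schema fallback of the resolver.
from sqlalchemy.orm import Session

from metadata.ingestion.source.database.snowflake.profiler.system_metrics import (
    SnowflakeTableResovler,
)


def resolve(session: Session):
    resolver = SnowflakeTableResovler(session)
    # Checks DB1.SALES.ORDERS first, then falls back to DB1.PUBLIC.ORDERS;
    # raises RuntimeError if the table exists in neither schema.
    return resolver.resolve_implicit_fqn(
        context_database="DB1",
        context_schema="SALES",
        table_name="ORDERS",
    )
```
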
+
+    def resolve_snowflake_fqn(
+        self,
+        context_database: str,
+        context_schema: Optional[str],
+        identifier: str,
+    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
+        """Get query identifiers from the query text. If the schema is not provided in the query, we'll look for
+        the table under "PUBLIC" in Snowflake.
+        The database can be retrieved from the query or the query context.
+        If the schema doesn't exist in the query but does in the context, we need to check with Snowflake whether
+        the table exists in (1) the context schema or in (2) the public schema in order to imitate the behavior
+        of the query engine. There are edge cases where the table was deleted (and hence cannot be found). In such
+        cases, the function will raise an error. It is advised to set the profiler window such that there will be
+        minimal drift between the query execution and the profiler run.
+
+        Args:
+            context_database (str): Database name from the query context
+            context_schema (Optional[str]): Schema name from the query context
+            identifier (str): Identifier string extracted from a query (can be 'db.schema.table', 'schema.table' or just 'table')
+        Returns:
+            Tuple[Optional[str], Optional[str], Optional[str]]: Tuple of database, schema and table names
+        Raises:
+            RuntimeError: If the table name is not found in the query or if fqn resolution fails
+        """
+        (
+            database_identifier,
+            schema_identifier,
+            table_name,
+        ) = get_identifiers_from_string(identifier)
+        if not table_name:
+            raise RuntimeError("Could not extract the table name.")
+        if not context_database and not database_identifier:
+            logger.debug(
+                f"Could not resolve database name. {identifier=}, {context_database=}"
+            )
+            raise RuntimeError("Could not resolve database name.")
+        if schema_identifier is not None:
+            return (
+                database_identifier or context_database,
+                schema_identifier,
+                table_name,
+            )
+        logger.debug(
+            "Missing schema info from the query. We'll look for it in Snowflake for [%s] or [%s]",
+            (
+                ".".join(
+                    [
+                        database_identifier or context_database,
+                        context_schema,
+                        table_name,
+                    ]
+                )
+                if context_schema
+                else None
+            ),
+            ".".join(
+                [database_identifier or context_database, PUBLIC_SCHEMA, table_name]
+            ),
+        )
+        # If the schema is not explicitly provided in the query, we'll need to resolve it from Snowflake
+        # by cascading the search from the context schema to the public schema.
+        result = self.resolve_implicit_fqn(
+            context_database=context_database,
+            context_schema=context_schema,
+            table_name=table_name,
+        )
+        logger.debug("Resolved table [%s]", ".".join(result))
+        return result
+
+
+def get_snowflake_system_queries(
+    query_log_entry: SnowflakeQueryLogEntry,
+    resolver: SnowflakeTableResovler,
+) -> Optional[SnowflakeQueryResult]:
+    """
+    Run a regex lookup on the query to identify which operation ran against the table.
+
+    If the query does not have the complete set of `database.schema.table` when it runs,
+    we'll resolve the missing parts against Snowflake through the resolver.
+
+    Args:
+        query_log_entry (SnowflakeQueryLogEntry): row from the Snowflake query history table
+        resolver (SnowflakeTableResovler): resolver to get the table identifiers
+    Returns:
+        SnowflakeQueryResult: pydantic model with the query result
+    """
+
+    try:
+        logger.debug(f"Trying to parse query [{query_log_entry.query_id}]")
+        identifier = _parse_query(query_log_entry.query_text)
+        if not identifier:
+            raise RuntimeError("Could not identify the table from the query.")
+
+        database_name, schema_name, table_name = resolver.resolve_snowflake_fqn(
+            identifier=identifier,
+            context_database=query_log_entry.database_name,
+            context_schema=query_log_entry.schema_name,
+        )
+
+        if not all([database_name, schema_name, table_name]):
+            raise RuntimeError(
+                f"Could not extract the identifiers from the query [{query_log_entry.query_id}]."
+ ) + + return SnowflakeQueryResult( + query_id=query_log_entry.query_id, + database_name=database_name.lower(), + schema_name=schema_name.lower(), + table_name=table_name.lower(), + query_text=query_log_entry.query_text, + query_type=query_log_entry.query_type, + start_time=query_log_entry.start_time, + rows_inserted=query_log_entry.rows_inserted, + rows_updated=query_log_entry.rows_updated, + rows_deleted=query_log_entry.rows_deleted, + ) + except Exception as exc: + logger.debug(traceback.format_exc()) + logger.warning( + f"""Error while processing query with id [{query_log_entry.query_id}]: {exc}\n + To investigate the query run: + SELECT * FROM SNOWFLAKE.ACCOUNT_USAGE.QUERY_HISTORY WHERE query_id = '{query_log_entry.query_id}' + """ + ) + return None + + +def build_snowflake_query_results( + session: Session, + table: DeclarativeMeta, +) -> List[SnowflakeQueryResult]: + """List and parse snowflake DML query results""" + query_results = [] + resolver = SnowflakeTableResovler( + session=session, + ) + for row in SnowflakeQueryLogEntry.get_for_table(session, table.__tablename__): + result = get_snowflake_system_queries( + query_log_entry=row, + resolver=resolver, + ) + if result: + query_results.append(result) + return query_results diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/queries.py b/ingestion/src/metadata/ingestion/source/database/snowflake/queries.py index 104c3ccaf59..0ff9b3f0450 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/queries.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/queries.py @@ -355,3 +355,26 @@ ORDER BY PROCEDURE_START_TIME DESC SNOWFLAKE_GET_TABLE_DDL = """ SELECT GET_DDL('TABLE','{table_name}') AS \"text\" """ +SNOWFLAKE_QUERY_LOG_QUERY = """ + SELECT + QUERY_ID, + QUERY_TEXT, + QUERY_TYPE, + START_TIME, + DATABASE_NAME, + SCHEMA_NAME, + ROWS_INSERTED, + ROWS_UPDATED, + ROWS_DELETED + FROM "SNOWFLAKE"."ACCOUNT_USAGE"."QUERY_HISTORY" + WHERE + start_time>= DATEADD('DAY', -1, CURRENT_TIMESTAMP) + AND QUERY_TEXT ILIKE '%{tablename}%' + AND QUERY_TYPE IN ( + '{insert}', + '{update}', + '{delete}', + '{merge}' + ) + AND EXECUTION_STATUS = 'SUCCESS'; +""" diff --git a/ingestion/src/metadata/profiler/interface/profiler_interface.py b/ingestion/src/metadata/profiler/interface/profiler_interface.py index 26176480c92..2db4c071c9d 100644 --- a/ingestion/src/metadata/profiler/interface/profiler_interface.py +++ b/ingestion/src/metadata/profiler/interface/profiler_interface.py @@ -29,6 +29,7 @@ from metadata.generated.schema.entity.data.databaseSchema import ( ) from metadata.generated.schema.entity.data.table import ( PartitionProfilerConfig, + SystemProfile, Table, TableData, ) @@ -462,7 +463,7 @@ class ProfilerInterface(ABC): runner, *args, **kwargs, - ): + ) -> List[SystemProfile]: """Get metrics""" raise NotImplementedError diff --git a/ingestion/src/metadata/profiler/interface/sqlalchemy/profiler_interface.py b/ingestion/src/metadata/profiler/interface/sqlalchemy/profiler_interface.py index 005f7c95ea1..68ef0695767 100644 --- a/ingestion/src/metadata/profiler/interface/sqlalchemy/profiler_interface.py +++ b/ingestion/src/metadata/profiler/interface/sqlalchemy/profiler_interface.py @@ -442,7 +442,6 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin): sample, ) row = None - try: row = self._get_metric_fn[metric_func.metric_type.value]( metric_func.metrics, @@ -451,11 +450,9 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin): column=metric_func.column, 
                sample=sample,
            )
-            if row and isinstance(row, dict):
+            if isinstance(row, dict):
                 row = self._validate_nulls(row)
-
-            # System metrics return a list of dictionaries, with UPDATE, INSERT or DELETE ops results
-            if row and metric_func.metric_type == MetricTypes.System:
+            if isinstance(row, list):
                 row = [
                     self._validate_nulls(r) if isinstance(r, dict) else r
                     for r in row
@@ -537,6 +534,9 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
             logger.debug(traceback.format_exc())
             logger.error(f"Operation was cancelled due to TimeoutError - {exc}")
             raise concurrent.futures.TimeoutError
+        except KeyboardInterrupt:
+            pool.shutdown(wait=True, cancel_futures=True)
+            raise
 
         return profile_results
 
diff --git a/ingestion/src/metadata/profiler/metrics/system/queries/redshift.py b/ingestion/src/metadata/profiler/metrics/system/queries/redshift.py
index 1e76ab80b98..37712e4f1b7 100644
--- a/ingestion/src/metadata/profiler/metrics/system/queries/redshift.py
+++ b/ingestion/src/metadata/profiler/metrics/system/queries/redshift.py
@@ -19,6 +19,7 @@ from sqlalchemy import text
 from sqlalchemy.orm import Session
 
 from metadata.utils.profiler_utils import QueryResult
+from metadata.utils.time_utils import datetime_to_timestamp
 
 STL_QUERY = """
 with data as (
@@ -73,7 +74,7 @@ def get_query_results(
             table_name=row.table,
             query_text=None,
             query_type=operation,
-            timestamp=row.starttime,
+            start_time=row.starttime,
             rows=row.rows,
         )
         for row in cursor
@@ -94,7 +95,7 @@ def get_metric_result(ddls: List[QueryResult], table_name: str) -> List:
     """
     return [
         {
-            "timestamp": int(ddl.timestamp.timestamp() * 1000),
+            "timestamp": datetime_to_timestamp(ddl.start_time, milliseconds=True),
             "operation": ddl.query_type,
             "rowsAffected": ddl.rows,
         }
diff --git a/ingestion/src/metadata/profiler/metrics/system/queries/snowflake.py b/ingestion/src/metadata/profiler/metrics/system/queries/snowflake.py
deleted file mode 100644
index eaff79ec47e..00000000000
--- a/ingestion/src/metadata/profiler/metrics/system/queries/snowflake.py
+++ /dev/null
@@ -1,191 +0,0 @@
-# Copyright 2021 Collate
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -""" -Snowflake System Metric Queries and query operations -""" - -import re -import traceback -from typing import Optional, Tuple - -from sqlalchemy.engine.row import Row - -from metadata.generated.schema.entity.services.databaseService import DatabaseService -from metadata.ingestion.lineage.sql_lineage import search_table_entities -from metadata.ingestion.ometa.ometa_api import OpenMetadata -from metadata.utils.logger import profiler_logger -from metadata.utils.profiler_utils import ( - SnowflakeQueryResult, - get_identifiers_from_string, -) - -logger = profiler_logger() - -INFORMATION_SCHEMA_QUERY = """ - SELECT - QUERY_ID, - QUERY_TEXT, - QUERY_TYPE, - START_TIME, - ROWS_INSERTED, - ROWS_UPDATED, - ROWS_DELETED - FROM "SNOWFLAKE"."ACCOUNT_USAGE"."QUERY_HISTORY" - WHERE - start_time>= DATEADD('DAY', -1, CURRENT_TIMESTAMP) - AND QUERY_TEXT ILIKE '%{tablename}%' - AND QUERY_TYPE IN ( - '{insert}', - '{update}', - '{delete}', - '{merge}' - ) - AND EXECUTION_STATUS = 'SUCCESS'; -""" - -RESULT_SCAN = """ - SELECT * - FROM TABLE(RESULT_SCAN('{query_id}')); - """ - - -QUERY_PATTERN = r"(?:(INSERT\s*INTO\s*|INSERT\s*OVERWRITE\s*INTO\s*|UPDATE\s*|MERGE\s*INTO\s*|DELETE\s*FROM\s*))([\w._\"\'()]+)(?=[\s*\n])" # pylint: disable=line-too-long -IDENTIFIER_PATTERN = r"(IDENTIFIER\(\')([\w._\"]+)(\'\))" - - -def _parse_query(query: str) -> Optional[str]: - """Parse snowflake queries to extract the identifiers""" - match = re.match(QUERY_PATTERN, query, re.IGNORECASE) - try: - # This will match results like `DATABASE.SCHEMA.TABLE1` or IDENTIFIER('TABLE1') - # If we have `IDENTIFIER` type of queries coming from Stored Procedures, we'll need to further clean it up. - identifier = match.group(2) - - match_internal_identifier = re.match( - IDENTIFIER_PATTERN, identifier, re.IGNORECASE - ) - internal_identifier = ( - match_internal_identifier.group(2) if match_internal_identifier else None - ) - if internal_identifier: - return internal_identifier - - return identifier - except (IndexError, AttributeError): - logger.debug("Could not find identifier in query. Skipping row.") - return None - - -def get_identifiers( - identifier: str, ometa_client: OpenMetadata, db_service: DatabaseService -) -> Tuple[Optional[str], Optional[str], Optional[str]]: - """Get query identifiers and if needed, fetch them from ES""" - database_name, schema_name, table_name = get_identifiers_from_string(identifier) - - if not table_name: - logger.debug("Could not extract the table name. Skipping operation.") - return database_name, schema_name, table_name - - if not all([database_name, schema_name]): - logger.debug( - "Missing database or schema info from the query. We'll look for it in ES." - ) - es_tables = search_table_entities( - metadata=ometa_client, - service_name=db_service.fullyQualifiedName.root, - database=database_name, - database_schema=schema_name, - table=table_name, - ) - - if not es_tables: - logger.debug("No tables match the search criteria.") - return database_name, schema_name, table_name - - if len(es_tables) > 1: - logger.debug( - "Found more than 1 table matching the search criteria." - " Skipping the computation to not mix system data." 
- ) - return database_name, schema_name, table_name - - matched_table = es_tables[0] - database_name = matched_table.database.name - schema_name = matched_table.databaseSchema.name - - return database_name, schema_name, table_name - - -def get_snowflake_system_queries( - row: Row, - database: str, - schema: str, - ometa_client: OpenMetadata, - db_service: DatabaseService, -) -> Optional[SnowflakeQueryResult]: - """ - Run a regex lookup on the query to identify which operation ran against the table. - - If the query does not have the complete set of `database.schema.table` when it runs, - we'll use ES to pick up the table, if we find it. - - Args: - row (dict): row from the snowflake system queries table - database (str): database name - schema (str): schema name - ometa_client (OpenMetadata): OpenMetadata client to search against ES - db_service (DatabaseService): DB service where the process is running against - Returns: - QueryResult: namedtuple with the query result - """ - - try: - dict_row = dict(row) - query_text = dict_row.get("QUERY_TEXT", dict_row.get("query_text")) - logger.debug(f"Trying to parse query:\n{query_text}\n") - - identifier = _parse_query(query_text) - if not identifier: - return None - - database_name, schema_name, table_name = get_identifiers( - identifier=identifier, - ometa_client=ometa_client, - db_service=db_service, - ) - - if not all([database_name, schema_name, table_name]): - return None - - if ( - database.lower() == database_name.lower() - and schema.lower() == schema_name.lower() - ): - return SnowflakeQueryResult( - query_id=dict_row.get("QUERY_ID", dict_row.get("query_id")), - database_name=database_name.lower(), - schema_name=schema_name.lower(), - table_name=table_name.lower(), - query_text=query_text, - query_type=dict_row.get("QUERY_TYPE", dict_row.get("query_type")), - timestamp=dict_row.get("START_TIME", dict_row.get("start_time")), - rows_inserted=dict_row.get( - "ROWS_INSERTED", dict_row.get("rows_inserted") - ), - rows_updated=dict_row.get("ROWS_UPDATED", dict_row.get("rows_updated")), - rows_deleted=dict_row.get("ROWS_DELETED", dict_row.get("rows_deleted")), - ) - except Exception: - logger.debug(traceback.format_exc()) - return None - - return None diff --git a/ingestion/src/metadata/profiler/metrics/system/system.py b/ingestion/src/metadata/profiler/metrics/system/system.py index 9496b02a98c..a8454263c30 100644 --- a/ingestion/src/metadata/profiler/metrics/system/system.py +++ b/ingestion/src/metadata/profiler/metrics/system/system.py @@ -17,15 +17,18 @@ import traceback from collections import defaultdict from typing import Dict, List, Optional +from pydantic import TypeAdapter from sqlalchemy import text from sqlalchemy.orm import DeclarativeMeta, Session from metadata.generated.schema.configuration.profilerConfiguration import MetricType +from metadata.generated.schema.entity.data.table import SystemProfile from metadata.generated.schema.entity.services.connections.database.bigQueryConnection import ( BigQueryConnection, ) -from metadata.generated.schema.entity.services.databaseService import DatabaseService -from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.ingestion.source.database.snowflake.profiler.system_metrics import ( + build_snowflake_query_results, +) from metadata.profiler.metrics.core import SystemMetric from metadata.profiler.metrics.system.dml_operation import ( DML_OPERATION_MAP, @@ -41,19 +44,12 @@ from metadata.profiler.metrics.system.queries.redshift import ( get_metric_result, 
     get_query_results,
 )
-from metadata.profiler.metrics.system.queries.snowflake import (
-    INFORMATION_SCHEMA_QUERY,
-    get_snowflake_system_queries,
-)
 from metadata.profiler.orm.registry import Dialects
 from metadata.utils.dispatch import valuedispatch
 from metadata.utils.helpers import deep_size_of_dict
 from metadata.utils.logger import profiler_logger
-from metadata.utils.profiler_utils import (
-    SnowflakeQueryResult,
-    get_value_from_cache,
-    set_cache,
-)
+from metadata.utils.profiler_utils import get_value_from_cache, set_cache
+from metadata.utils.time_utils import datetime_to_timestamp
 
 logger = profiler_logger()
 
@@ -75,7 +71,7 @@ def get_system_metrics_for_dialect(
     table: DeclarativeMeta,
     *args,
     **kwargs,
-) -> Optional[Dict]:
+) -> Optional[List[SystemProfile]]:
     """_summary_
 
     Args:
@@ -91,6 +87,7 @@ def get_system_metrics_for_dialect(
     } else returns None
     """
     logger.debug(f"System metrics not support for {dialect}. Skipping processing.")
+    return None
 
 
 @get_system_metrics_for_dialect.register(Dialects.BigQuery)
@@ -101,7 +98,7 @@ def _(
     conn_config: BigQueryConnection,
     *args,
     **kwargs,
-) -> List[Dict]:
+) -> List[SystemProfile]:
     """Compute system metrics for bigquery
 
     Args:
@@ -190,7 +187,7 @@ def _(
             }
         )
 
-    return metric_results
+    return TypeAdapter(List[SystemProfile]).validate_python(metric_results)
 
 
 @get_system_metrics_for_dialect.register(Dialects.Redshift)
@@ -200,7 +197,7 @@ def _(
     table: DeclarativeMeta,
     *args,
     **kwargs,
-) -> List[Dict]:
+) -> List[SystemProfile]:
     """List all the DML operations for reshifts tables
 
     Args:
@@ -289,42 +286,7 @@ def _(
     )
     metric_results.extend(get_metric_result(updates, table.__tablename__))  # type: ignore
 
-    return metric_results
-
-
-def _snowflake_build_query_result(
-    session: Session,
-    table: DeclarativeMeta,
-    database: str,
-    schema: str,
-    ometa_client: OpenMetadata,
-    db_service: DatabaseService,
-) -> List[SnowflakeQueryResult]:
-    """List and parse snowflake DML query results"""
-    rows = session.execute(
-        text(
-            INFORMATION_SCHEMA_QUERY.format(
-                tablename=table.__tablename__,  # type: ignore
-                insert=DatabaseDMLOperations.INSERT.value,
-                update=DatabaseDMLOperations.UPDATE.value,
-                delete=DatabaseDMLOperations.DELETE.value,
-                merge=DatabaseDMLOperations.MERGE.value,
-            )
-        )
-    )
-    query_results = []
-    for row in rows:
-        result = get_snowflake_system_queries(
-            row=row,
-            database=database,
-            schema=schema,
-            ometa_client=ometa_client,
-            db_service=db_service,
-        )
-        if result:
-            query_results.append(result)
-
-    return query_results
+    return TypeAdapter(List[SystemProfile]).validate_python(metric_results)
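
The handlers above now funnel their raw metric dicts through pydantic's TypeAdapter, so every dialect returns validated SystemProfile rows. A minimal, self-contained sketch of that pattern (the values are illustrative, not taken from this diff):

```python
# Illustrative only: coercing raw dicts into SystemProfile rows the same way
# the BigQuery/Redshift handlers above (and the Snowflake handler below) do.
from typing import List

from pydantic import TypeAdapter

from metadata.generated.schema.entity.data.table import SystemProfile

raw = [
    {"timestamp": 1704110400000, "operation": "INSERT", "rowsAffected": 10},
    {"timestamp": 1704110460000, "operation": "DELETE", "rowsAffected": 2},
]
profiles: List[SystemProfile] = TypeAdapter(List[SystemProfile]).validate_python(raw)
assert profiles[0].rowsAffected == 10  # invalid payloads raise a ValidationError
```
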
 
 
 @get_system_metrics_for_dialect.register(Dialects.Snowflake)
@@ -332,8 +294,6 @@ def _(
     dialect: str,
     session: Session,
     table: DeclarativeMeta,
-    ometa_client: OpenMetadata,
-    db_service: DatabaseService,
     *args,
     **kwargs,
-) -> Optional[List[Dict]]:
+) -> Optional[List[SystemProfile]]:
@@ -354,18 +314,12 @@ def _(
         Dict: system metric
     """
     logger.debug(f"Fetching system metrics for {dialect}")
-    database = session.get_bind().url.database
-    schema = table.__table_args__["schema"]  # type: ignore
 
     metric_results: List[Dict] = []
 
-    query_results = _snowflake_build_query_result(
+    query_results = build_snowflake_query_results(
         session=session,
         table=table,
-        database=database,
-        schema=schema,
-        ometa_client=ometa_client,
-        db_service=db_service,
     )
 
     for query_result in query_results:
@@ -380,7 +334,9 @@ def _(
         if query_result.rows_inserted:
             metric_results.append(
                 {
-                    "timestamp": int(query_result.timestamp.timestamp() * 1000),
+                    "timestamp": datetime_to_timestamp(
+                        query_result.start_time, milliseconds=True
+                    ),
                     "operation": DatabaseDMLOperations.INSERT.value,
                     "rowsAffected": query_result.rows_inserted,
                 }
@@ -388,7 +344,9 @@ def _(
         if query_result.rows_updated:
             metric_results.append(
                 {
-                    "timestamp": int(query_result.timestamp.timestamp() * 1000),
+                    "timestamp": datetime_to_timestamp(
+                        query_result.start_time, milliseconds=True
+                    ),
                     "operation": DatabaseDMLOperations.UPDATE.value,
                     "rowsAffected": query_result.rows_updated,
                 }
@@ -397,13 +355,15 @@ def _(
 
             metric_results.append(
                 {
-                    "timestamp": int(query_result.timestamp.timestamp() * 1000),
+                    "timestamp": datetime_to_timestamp(
+                        query_result.start_time, milliseconds=True
+                    ),
                     "operation": DML_OPERATION_MAP.get(query_result.query_type),
                     "rowsAffected": rows_affected,
                 }
             )
 
-    return metric_results
+    return TypeAdapter(List[SystemProfile]).validate_python(metric_results)
 
 
 class System(SystemMetric):
diff --git a/ingestion/src/metadata/profiler/processor/core.py b/ingestion/src/metadata/profiler/processor/core.py
index 6f7ff6d4d02..ec9db6da458 100644
--- a/ingestion/src/metadata/profiler/processor/core.py
+++ b/ingestion/src/metadata/profiler/processor/core.py
@@ -274,9 +274,6 @@ class Profiler(Generic[TMetric]):
 
         Data should be saved under self.results
         """
-
-        logger.debug("Running post Profiler...")
-
         current_col_results: Dict[str, Any] = self._column_results.get(col.name)
         if not current_col_results:
             logger.debug(
diff --git a/ingestion/src/metadata/utils/profiler_utils.py b/ingestion/src/metadata/utils/profiler_utils.py
index 061960b5337..c6841e56872 100644
--- a/ingestion/src/metadata/utils/profiler_utils.py
+++ b/ingestion/src/metadata/utils/profiler_utils.py
@@ -34,20 +34,12 @@ class QueryResult(BaseModel):
     schema_name: str
     table_name: str
     query_type: str
-    timestamp: datetime
+    start_time: datetime
     query_id: Optional[str] = None
     query_text: Optional[str] = None
     rows: Optional[int] = None
 
 
-class SnowflakeQueryResult(QueryResult):
-    """Snowflake system metric query result"""
-
-    rows_inserted: Optional[int] = None
-    rows_updated: Optional[int] = None
-    rows_deleted: Optional[int] = None
-
-
 def clean_up_query(query: str) -> str:
     """remove comments and newlines from query"""
     return sqlparse.format(query, strip_comments=True).replace("\\n", "")
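
The time_utils diff below tightens `datetime_to_timestamp`, which the system-metrics handlers above now use for the `timestamp` field. A small worked example of the documented contract (assuming the post-diff behavior; timezone-aware datetimes are the safe input, since naive ones are interpreted in the machine's local timezone by `datetime.timestamp()`):

```python
# Sketch: epoch conversion with an explicitly UTC-aware datetime.
from datetime import datetime, timezone

from metadata.utils.time_utils import datetime_to_timestamp

aware = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
assert datetime_to_timestamp(aware) == 1704110400  # seconds
assert datetime_to_timestamp(aware, milliseconds=True) == 1704110400000
```
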
diff --git a/ingestion/src/metadata/utils/time_utils.py b/ingestion/src/metadata/utils/time_utils.py
index ddcee67e650..b2c4196dab3 100644
--- a/ingestion/src/metadata/utils/time_utils.py
+++ b/ingestion/src/metadata/utils/time_utils.py
@@ -17,21 +17,25 @@ from datetime import datetime, time, timedelta, timezone
 from math import floor
 from typing import Union
 
+from metadata.utils.deprecation import deprecated
 from metadata.utils.helpers import datetime_to_ts
 
 
-def datetime_to_timestamp(datetime_value, milliseconds=False) -> int:
-    """Convert a datetime object to timestamp integer
+def datetime_to_timestamp(datetime_value: datetime, milliseconds=False) -> int:
+    """Convert a datetime object to a timestamp integer. The datetime can be timezone-aware or naive;
+    naive values are interpreted in the local timezone. The result is always an epoch-based (UTC) timestamp.
 
     Args:
-        datetime_value (_type_): datetime object
+        datetime_value (datetime): datetime object
         milliseconds (bool, optional): make it a milliseconds timestamp. Defaults to False.
 
     Returns:
-        int:
+        int: timestamp in seconds or milliseconds
     """
     if not getattr(datetime_value, "timestamp", None):
-        raise TypeError(f"Object of type {datetime_value} has not method `timestamp()`")
+        raise TypeError(
+            f"Object of type {type(datetime_value).__name__} has no method `timestamp()`"
+        )
 
     tmsap = datetime_value.timestamp()
     if milliseconds:
@@ -115,6 +119,7 @@ def convert_timestamp(timestamp: str) -> Union[int, float]:
     return float(timestamp) / 1000
 
 
+@deprecated("Use `datetime_to_timestamp` instead", "1.7.0")
 def convert_timestamp_to_milliseconds(timestamp: Union[int, float]) -> int:
     """convert timestamp to milliseconds
     Args:
diff --git a/ingestion/tests/cli_e2e/common/test_cli_db.py b/ingestion/tests/cli_e2e/common/test_cli_db.py
index a41b736d086..cb59aad0156 100644
--- a/ingestion/tests/cli_e2e/common/test_cli_db.py
+++ b/ingestion/tests/cli_e2e/common/test_cli_db.py
@@ -13,12 +13,18 @@ Test database connectors which extend from `CommonDbSourceService` with CLI
 """
 import json
+import os
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Optional
 
+import yaml
 from sqlalchemy.engine import Engine
 
+from metadata.generated.schema.entity.services.databaseService import DatabaseService
+from metadata.generated.schema.metadataIngestion.workflow import (
+    OpenMetadataWorkflowConfig,
+)
 from metadata.ingestion.api.status import Status
 from metadata.workflow.metadata import MetadataWorkflow
 
@@ -45,6 +51,19 @@ class CliCommonDB:
             Path(PATH_TO_RESOURCES + f"/database/{connector}/test.yaml")
         )
 
+    @classmethod
+    def tearDownClass(cls):
+        workflow = OpenMetadataWorkflowConfig.model_validate(
+            yaml.safe_load(open(cls.config_file_path))
+        )
+        db_service: DatabaseService = cls.openmetadata.get_by_name(
+            DatabaseService, workflow.source.serviceName
+        )
+        if db_service and os.getenv("E2E_CLEAN_DB", "false") == "true":
+            cls.openmetadata.delete(
+                DatabaseService, db_service.id, hard_delete=True, recursive=True
+            )
+
     def tearDown(self) -> None:
         self.engine.dispose()
 
diff --git a/ingestion/tests/cli_e2e/test_cli_snowflake.py b/ingestion/tests/cli_e2e/test_cli_snowflake.py
index cceb19d2674..dc27e2fee9d 100644
--- a/ingestion/tests/cli_e2e/test_cli_snowflake.py
+++ b/ingestion/tests/cli_e2e/test_cli_snowflake.py
@@ -12,10 +12,15 @@
 """
 Test Snowflake connector with CLI
 """
+from datetime import datetime
+from time import sleep
 from typing import List
 
 import pytest
 
+from _openmetadata_testutils.pydantic.test_utils import assert_equal_pydantic_objects
+from metadata.generated.schema.entity.data.table import DmlOperationType, SystemProfile
+from metadata.generated.schema.type.basic import Timestamp
 from metadata.ingestion.api.status import Status
 
 from .base.e2e_types import E2EType
@@ -40,6 +45,9 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
         "CREATE OR REPLACE TABLE e2e_test.test_departments(e2e_testdepartment_id INT PRIMARY KEY,e2e_testdepartment_name VARCHAR (30) NOT NULL,e2e_testlocation_id INT);",
         "CREATE OR REPLACE TABLE e2e_test.test_employees(e2e_testemployee_id INT PRIMARY KEY,e2e_testfirst_name VARCHAR (20),e2e_testlast_name VARCHAR (25) NOT NULL,e2e_testemail VARCHAR (100) NOT NULL,e2e_testphone_number VARCHAR (20),e2e_testhire_date DATE NOT NULL,e2e_testjob_id INT NOT NULL,e2e_testsalary DECIMAL (8, 2) NOT NULL,e2e_testmanager_id INT,e2e_testdepartment_id INT);",
         "CREATE OR REPLACE TABLE e2e_test.test_dependents(e2e_testdependent_id INT PRIMARY KEY,e2e_testfirst_name VARCHAR (50) NOT NULL,e2e_testlast_name VARCHAR (50) NOT
NULL,e2e_testrelationship VARCHAR (25) NOT NULL,e2e_testemployee_id INT NOT NULL);", + "CREATE OR REPLACE TABLE e2e_test.e2e_table(varchar_column VARCHAR(255),int_column INT);", + "CREATE OR REPLACE TABLE public.public_table(varchar_column VARCHAR(255),int_column INT);", + "CREATE OR REPLACE TABLE public.e2e_table(varchar_column VARCHAR(255),int_column INT);", ] create_table_query: str = """ @@ -58,6 +66,13 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods): insert_data_queries: List[str] = [ "INSERT INTO E2E_DB.e2e_test.persons (person_id, full_name) VALUES (1,'Peter Parker');", "INSERT INTO E2E_DB.e2e_test.persons (person_id, full_name) VALUES (1, 'Clark Kent');", + "INSERT INTO e2e_test.e2e_table (varchar_column, int_column) VALUES ('e2e_test.e2e_table', 1);", + "INSERT INTO public.e2e_table (varchar_column, int_column) VALUES ('public.e2e_table', 1);", + "INSERT INTO e2e_table (varchar_column, int_column) VALUES ('e2e_table', 1);", + "INSERT INTO public.public_table (varchar_column, int_column) VALUES ('public.public_table', 1);", + "INSERT INTO public_table (varchar_column, int_column) VALUES ('public_table', 1);", + "MERGE INTO public_table USING (SELECT 'public_table' as varchar_column, 2 as int_column) as source ON public_table.varchar_column = source.varchar_column WHEN MATCHED THEN UPDATE SET public_table.int_column = source.int_column WHEN NOT MATCHED THEN INSERT (varchar_column, int_column) VALUES (source.varchar_column, source.int_column);", + "DELETE FROM public_table WHERE varchar_column = 'public.public_table';", ] drop_table_query: str = """ @@ -68,6 +83,19 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods): DROP VIEW IF EXISTS E2E_DB.e2e_test.view_persons; """ + teardown_sql_statements: List[str] = [ + "DROP TABLE IF EXISTS E2E_DB.e2e_test.e2e_table;", + "DROP TABLE IF EXISTS E2E_DB.public.e2e_table;", + "DROP TABLE IF EXISTS E2E_DB.public.public_table;", + ] + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + with cls.engine.connect() as connection: + for stmt in cls.teardown_sql_statements: + connection.execute(stmt) + def setUp(self) -> None: with self.engine.connect() as connection: for sql_statements in self.prepare_snowflake_e2e: @@ -83,15 +111,15 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods): self.assertTrue(len(source_status.failures) == 0) self.assertTrue(len(source_status.warnings) == 0) self.assertTrue(len(source_status.filtered) == 1) - self.assertTrue( - (len(source_status.records) + len(source_status.updated_records)) - >= self.expected_tables() + self.assertGreaterEqual( + (len(source_status.records) + len(source_status.updated_records)), + self.expected_tables(), ) self.assertTrue(len(sink_status.failures) == 0) self.assertTrue(len(sink_status.warnings) == 0) - self.assertTrue( - (len(sink_status.records) + len(sink_status.updated_records)) - > self.expected_tables() + self.assertGreater( + (len(sink_status.records) + len(sink_status.updated_records)), + self.expected_tables(), ) def create_table_and_view(self) -> None: @@ -130,17 +158,22 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods): # Otherwise the sampling here does not pick up rows extra_args={"profileSample": 100}, ) + # wait for query log to be updated + self.wait_for_query_log() # run profiler with new tables result = self.run_command("profile") sink_status, source_status = self.retrieve_statuses(result) self.assert_for_table_with_profiler(source_status, sink_status) + self.custom_profiler_assertions() 
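
The `custom_profiler_assertions` helper called above (defined further down) leans on the list support added to `assert_equal_pydantic_objects` in this diff. A minimal sketch of the comparison it performs (values illustrative, not taken from the test data):

```python
# Illustrative only: list-aware comparison of SystemProfile objects.
from _openmetadata_testutils.pydantic.test_utils import assert_equal_pydantic_objects
from metadata.generated.schema.entity.data.table import DmlOperationType, SystemProfile
from metadata.generated.schema.type.basic import Timestamp

expected = [
    SystemProfile(
        timestamp=Timestamp(root=0),
        operation=DmlOperationType.INSERT,
        rowsAffected=1,
    )
]
# Raises an AssertionError listing every mismatched field when the lists differ.
assert_equal_pydantic_objects(expected, list(expected))
```
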
@staticmethod def expected_tables() -> int: return 7 def inserted_rows_count(self) -> int: - return len(self.insert_data_queries) + return len( + [q for q in self.insert_data_queries if "E2E_DB.e2e_test.persons" in q] + ) def view_column_lineage_count(self) -> int: return 2 @@ -171,7 +204,7 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods): @staticmethod def expected_filtered_table_includes() -> int: - return 5 + return 8 @staticmethod def expected_filtered_table_excludes() -> int: @@ -179,7 +212,7 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods): @staticmethod def expected_filtered_mix() -> int: - return 6 + return 7 @staticmethod def delete_queries() -> List[str]: @@ -196,3 +229,90 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods): UPDATE E2E_DB.E2E_TEST.PERSONS SET full_name = 'Bruce Wayne' WHERE full_name = 'Clark Kent' """, ] + + def custom_profiler_assertions(self): + cases = [ + ( + "e2e_snowflake.E2E_DB.E2E_TEST.E2E_TABLE", + [ + SystemProfile( + timestamp=Timestamp(root=0), + operation=DmlOperationType.INSERT, + rowsAffected=1, + ), + SystemProfile( + timestamp=Timestamp(root=0), + operation=DmlOperationType.INSERT, + rowsAffected=1, + ), + ], + ), + ( + "e2e_snowflake.E2E_DB.PUBLIC.E2E_TABLE", + [ + SystemProfile( + timestamp=Timestamp(root=0), + operation=DmlOperationType.INSERT, + rowsAffected=1, + ) + ], + ), + ( + "e2e_snowflake.E2E_DB.PUBLIC.PUBLIC_TABLE", + [ + SystemProfile( + timestamp=Timestamp(root=0), + operation=DmlOperationType.INSERT, + rowsAffected=1, + ), + SystemProfile( + timestamp=Timestamp(root=0), + operation=DmlOperationType.INSERT, + rowsAffected=1, + ), + SystemProfile( + timestamp=Timestamp(root=0), + operation=DmlOperationType.UPDATE, + rowsAffected=1, + ), + SystemProfile( + timestamp=Timestamp(root=0), + operation=DmlOperationType.DELETE, + rowsAffected=1, + ), + ], + ), + ] + for table_fqn, expected_profile in cases: + actual_profiles = self.openmetadata.get_profile_data( + table_fqn, + start_ts=int((datetime.now().timestamp() - 600) * 1000), + end_ts=int(datetime.now().timestamp() * 1000), + profile_type=SystemProfile, + ).entities + actual_profiles = sorted(actual_profiles, key=lambda x: x.timestamp.root) + actual_profiles = actual_profiles[-len(expected_profile) :] + actual_profiles = [ + p.copy(update={"timestamp": Timestamp(root=0)}) for p in actual_profiles + ] + try: + assert_equal_pydantic_objects(expected_profile, actual_profiles) + except AssertionError as e: + raise AssertionError(f"Table: {table_fqn}\n{e}") + + @classmethod + def wait_for_query_log(cls, timeout=600): + start = datetime.now().timestamp() + cls.engine.execute("SELECT 'e2e_query_log_wait'") + latest = 0 + while latest < start: + sleep(5) + latest = ( + cls.engine.execute( + 'SELECT max(start_time) FROM "SNOWFLAKE"."ACCOUNT_USAGE"."QUERY_HISTORY"' + ) + .scalar() + .timestamp() + ) + if (datetime.now().timestamp() - start) > timeout: + raise TimeoutError(f"Query log not updated for {timeout} seconds") diff --git a/ingestion/tests/unit/metadata/data_quality/__init__.py b/ingestion/tests/unit/metadata/data_quality/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/ingestion/tests/unit/metadata/ingestion/source/database/snowflake/profiler/test_system_metrics.py b/ingestion/tests/unit/metadata/ingestion/source/database/snowflake/profiler/test_system_metrics.py new file mode 100644 index 00000000000..c2fcbdc090f --- /dev/null +++ 
b/ingestion/tests/unit/metadata/ingestion/source/database/snowflake/profiler/test_system_metrics.py
@@ -0,0 +1,114 @@
+from unittest.mock import MagicMock, Mock
+
+import pytest
+
+from metadata.ingestion.source.database.snowflake.profiler.system_metrics import (
+    PUBLIC_SCHEMA,
+    SnowflakeTableResovler,
+)
+from metadata.utils.profiler_utils import get_identifiers_from_string
+
+
+@pytest.mark.parametrize(
+    "schema_name",
+    ["test_schema", PUBLIC_SCHEMA, None],
+)
+@pytest.mark.parametrize(
+    "existing_tables",
+    [
+        ["db.test_schema.test_table", "db.PUBLIC.test_table"],
+        ["db.test_schema.test_table"],
+        ["db.PUBLIC.test_table"],
+        [],
+    ],
+)
+def test_resolve_snowflake_fqn(schema_name, existing_tables):
+    def expected_result(schema_name, existing_tables):
+        if len(existing_tables) == 0:
+            return RuntimeError
+        if schema_name == "test_schema":
+            if "db.test_schema.test_table" in existing_tables:
+                return "db", "test_schema", "test_table"
+            if "db.PUBLIC.test_table" in existing_tables:
+                return "db", PUBLIC_SCHEMA, "test_table"
+        if (
+            schema_name in [None, PUBLIC_SCHEMA]
+            and "db.PUBLIC.test_table" in existing_tables
+        ):
+            return "db", PUBLIC_SCHEMA, "test_table"
+        return RuntimeError
+
+    resolver = SnowflakeTableResovler(Mock())
+
+    def mock_show_tables(_, schema, table):
+        for t in existing_tables:
+            if t == f"db.{schema}.{table}":
+                return True
+
+    resolver.show_tables = mock_show_tables
+    expected = expected_result(schema_name, existing_tables)
+    if expected == RuntimeError:
+        with pytest.raises(expected):
+            resolver.resolve_implicit_fqn("db", schema_name, "test_table")
+    else:
+        result = resolver.resolve_implicit_fqn("db", schema_name, "test_table")
+        assert result == expected
+
+
+@pytest.mark.parametrize("context_database", [None, "context_db"])
+@pytest.mark.parametrize("context_schema", [None, "context_schema", PUBLIC_SCHEMA])
+@pytest.mark.parametrize(
+    "identifier",
+    [
+        "",
+        "test_table",
+        "id_schema.test_table",
+        "PUBLIC.test_table",
+        "id_db.test_schema.test_table",
+        "id_db.PUBLIC.test_table",
+    ],
+)
+@pytest.mark.parametrize(
+    "resolved_schema",
+    [
+        PUBLIC_SCHEMA,
+        "context_schema",
+        RuntimeError("could not resolve schema"),
+    ],
+)
+def test_get_identifiers(
+    context_database,
+    context_schema,
+    identifier,
+    resolved_schema,
+):
+    def expected_result():
+        if identifier == "":
+            return RuntimeError("Could not extract the table name.")
+        db, id_schema, table = get_identifiers_from_string(identifier)
+        if db is None and context_database is None:
+            return RuntimeError("Could not resolve database name.")
+        if id_schema is None and isinstance(resolved_schema, RuntimeError):
+            return RuntimeError("could not resolve schema")
+        return (
+            (db or context_database),
+            (id_schema or resolved_schema or context_schema),
+            table,
+        )
+
+    resolver = SnowflakeTableResovler(Mock())
+    if isinstance(resolved_schema, RuntimeError):
+        resolver.resolve_implicit_fqn = MagicMock(side_effect=resolved_schema)
+    else:
+        resolver.resolve_implicit_fqn = MagicMock(
+            return_value=(context_database, resolved_schema, identifier)
+        )
+
+    expected_value = expected_result()
+    if isinstance(expected_value, RuntimeError):
+        with pytest.raises(type(expected_value), match=str(expected_value)):
+            resolver.resolve_snowflake_fqn(
+                context_database, context_schema, identifier
+            )
+    else:
+        assert expected_value == resolver.resolve_snowflake_fqn(
+            context_database, context_schema, identifier
+        )
diff --git a/ingestion/tests/unit/profiler/test_utils.py
b/ingestion/tests/unit/profiler/test_utils.py index a57b1eb8e12..523ad24c645 100644 --- a/ingestion/tests/unit/profiler/test_utils.py +++ b/ingestion/tests/unit/profiler/test_utils.py @@ -12,36 +12,21 @@ """ Tests utils function for the profiler """ -import uuid from datetime import datetime from unittest import TestCase -from unittest.mock import patch +from unittest.mock import Mock import pytest from sqlalchemy import Column from sqlalchemy.orm import declarative_base from sqlalchemy.sql.sqltypes import Integer, String -from metadata.generated.schema.entity.data.table import Column as OMetaColumn -from metadata.generated.schema.entity.data.table import DataType, Table -from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import ( - AuthProvider, - OpenMetadataConnection, -) -from metadata.generated.schema.entity.services.databaseService import ( - DatabaseService, - DatabaseServiceType, -) -from metadata.generated.schema.security.client.openMetadataJWTClientConfig import ( - OpenMetadataJWTClientConfig, -) -from metadata.generated.schema.type.basic import EntityName, FullyQualifiedEntityName -from metadata.generated.schema.type.entityReference import EntityReference -from metadata.ingestion.ometa.ometa_api import OpenMetadata -from metadata.profiler.metrics.hybrid.histogram import Histogram -from metadata.profiler.metrics.system.queries.snowflake import ( +from metadata.ingestion.source.database.snowflake.models import SnowflakeQueryLogEntry +from metadata.ingestion.source.database.snowflake.profiler.system_metrics import ( + SnowflakeTableResovler, get_snowflake_system_queries, ) +from metadata.profiler.metrics.hybrid.histogram import Histogram from metadata.profiler.metrics.system.system import recursive_dic from metadata.utils.profiler_utils import ( get_identifiers_from_string, @@ -50,8 +35,6 @@ from metadata.utils.profiler_utils import ( ) from metadata.utils.sqa_utils import is_array -from .conftest import Row - Base = declarative_base() @@ -125,7 +108,7 @@ def test_is_array(): def test_get_snowflake_system_queries(): """Test get snowflake system queries""" - row = Row( + row = SnowflakeQueryLogEntry( query_id="1", query_type="INSERT", start_time=datetime.now(), @@ -133,8 +116,10 @@ def test_get_snowflake_system_queries(): ) # We don't need the ometa_client nor the db_service if we have all the db.schema.table in the query + resolver = SnowflakeTableResovler(Mock()) query_result = get_snowflake_system_queries( - row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=... + query_log_entry=row, + resolver=resolver, ) # type: ignore assert query_result assert query_result.query_id == "1" @@ -143,15 +128,16 @@ def test_get_snowflake_system_queries(): assert query_result.schema_name == "schema" assert query_result.table_name == "table1" - row = Row( - query_id=1, + row = SnowflakeQueryLogEntry( + query_id="1", query_type="INSERT", start_time=datetime.now(), query_text="INSERT INTO SCHEMA.TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')", ) query_result = get_snowflake_system_queries( - row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=... 
+ query_log_entry=row, + resolver=resolver, ) # type: ignore assert not query_result @@ -184,15 +170,17 @@ def test_get_snowflake_system_queries_all_dll(query, expected): """test we ca get all ddl queries reference https://docs.snowflake.com/en/sql-reference/sql-dml """ - row = Row( + row = SnowflakeQueryLogEntry( query_id="1", query_type=expected, start_time=datetime.now(), query_text=query, ) - + resolver = Mock() + resolver.resolve_snowflake_fqn = Mock(return_value=("database", "schema", "table1")) query_result = get_snowflake_system_queries( - row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=... + query_log_entry=row, + resolver=resolver, ) # type: ignore assert query_result @@ -202,7 +190,8 @@ def test_get_snowflake_system_queries_all_dll(query, expected): assert query_result.table_name == "table1" query_result = get_snowflake_system_queries( - row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=... + query_log_entry=SnowflakeQueryLogEntry.model_validate(row), + resolver=resolver, ) # type: ignore assert query_result @@ -212,75 +201,6 @@ def test_get_snowflake_system_queries_all_dll(query, expected): assert query_result.table_name == "table1" -def test_get_snowflake_system_queries_from_es(): - """Test the ES integration""" - - ometa_client = OpenMetadata( - OpenMetadataConnection( - hostPort="http://localhost:8585/api", - authProvider=AuthProvider.openmetadata, - enableVersionValidation=False, - securityConfig=OpenMetadataJWTClientConfig(jwtToken="token"), - ) - ) - - db_service = DatabaseService( - id=uuid.uuid4(), - name=EntityName("service"), - fullyQualifiedName=FullyQualifiedEntityName("service"), - serviceType=DatabaseServiceType.CustomDatabase, - ) - - table = Table( - id=uuid.uuid4(), - name="TABLE", - columns=[OMetaColumn(name="id", dataType=DataType.BIGINT)], - database=EntityReference(id=uuid.uuid4(), type="database", name="database"), - databaseSchema=EntityReference( - id=uuid.uuid4(), type="databaseSchema", name="schema" - ), - ) - - # With too many responses, we won't return anything since we don't want false results - # that we cannot properly assign - with patch.object(OpenMetadata, "es_search_from_fqn", return_value=[table] * 4): - row = Row( - query_id=1, - query_type="INSERT", - start_time=datetime.now(), - query_text="INSERT INTO TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')", - ) - query_result = get_snowflake_system_queries( - row=row, - database="DATABASE", - schema="SCHEMA", - ometa_client=ometa_client, - db_service=db_service, - ) - assert not query_result - - # Returning a single table should work fine - with patch.object(OpenMetadata, "es_search_from_fqn", return_value=[table]): - row = Row( - query_id="1", - query_type="INSERT", - start_time=datetime.now(), - query_text="INSERT INTO TABLE2 (col1, col2) VALUES (1, 'a'), (2, 'b')", - ) - query_result = get_snowflake_system_queries( - row=row, - database="DATABASE", - schema="SCHEMA", - ometa_client=ometa_client, - db_service=db_service, - ) - assert query_result - assert query_result.query_type == "INSERT" - assert query_result.database_name == "database" - assert query_result.schema_name == "schema" - assert query_result.table_name == "table2" - - @pytest.mark.parametrize( "identifier, expected", [