Fix 17698: use resolution logic for snowflake system metrics profiler (#17699)

* fix(profiler): snowflake

resolve tables using the snowflake engine instead of OpenMetadata

* added env for cleaning up dbs in E2E

* moved system metric method to profiler. all the rest stays in snowflake

* format

* revert unnecessary changes

* removed test for previous resolution method

* use shutdown39
Imri Paran 2024-09-06 09:25:10 +02:00 committed by GitHub
parent b2f21fa070
commit 84be1a3162
21 changed files with 687 additions and 400 deletions


@@ -1,10 +1,13 @@
 from collections import deque
+from typing import List, Union

 from pydantic import BaseModel


 def assert_equal_pydantic_objects(
-    expected: BaseModel, actual: BaseModel, ignore_none=True
+    expected: Union[BaseModel, List[BaseModel]],
+    actual: Union[BaseModel, List[BaseModel]],
+    ignore_none=True,
 ):
     """Compare 2 pydantic objects recursively and raise an AssertionError if they are not equal along with all
     the differences by field. If `ignore_none` is set to True, expected None values will be ignored. This can be
@@ -32,6 +35,10 @@ def assert_equal_pydantic_objects(
     Traceback (most recent call last):
     ```
     AssertionError: objects mismatched on field: [b.a], expected: [1], actual: [2]
+    >>> assert_equal_pydantic_objects([a1, a2], [a2, a1])
+    Traceback (most recent call last):
+    ```
+    AssertionError: objects mismatched on field: [0].a, expected: [1], actual: [2]

     Args:
         expected (BaseModel): The expected pydantic object.
@@ -69,11 +76,24 @@ def assert_equal_pydantic_objects(
                     errors.append(
                         f"objects mismatched on field: [{new_key_prefix}], expected: [{expected_value}], actual: [{actual_value}]"
                     )
+        elif isinstance(expected, list):
+            if not isinstance(actual, list):
+                errors.append(
+                    f"validation error on field: [{current_key_prefix}], expected: [list], actual: [{type(actual).__name__}]"
+                )
+            elif len(expected) != len(actual):
+                errors.append(
+                    f"mismatch length at {current_key_prefix}: expected: [{len(expected)}], actual: [{len(actual)}]"
+                )
+            else:
+                for i, (expected_item, actual_item) in enumerate(zip(expected, actual)):
+                    queue.append(
+                        (expected_item, actual_item, f"{current_key_prefix}[{i}]")
+                    )
         else:
             if expected != actual:
                 errors.append(
                     f"mismatch at {current_key_prefix}: expected: [{expected}], actual: [{actual}]"
                 )
     if errors:
         raise AssertionError("\n".join(errors))
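A hedged usage sketch of the new list support (the trivial `Point` model is an assumption for illustration, not part of the test utils):

from pydantic import BaseModel

from _openmetadata_testutils.pydantic.test_utils import assert_equal_pydantic_objects


class Point(BaseModel):
    x: int
    y: int


assert_equal_pydantic_objects([Point(x=1, y=2)], [Point(x=1, y=2)])  # passes silently

try:
    assert_equal_pydantic_objects([Point(x=1, y=2)], [Point(x=9, y=2)])
except AssertionError as exc:
    print(exc)  # mismatch reported per index and field, e.g. "[0].x"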


@@ -16,7 +16,7 @@ To be used by OpenMetadata class
 import traceback
 from typing import List, Optional, Type, TypeVar

-from pydantic import BaseModel
+from pydantic import BaseModel, validate_call
 from requests.utils import quote

 from metadata.generated.schema.api.data.createTableProfile import (
@@ -227,6 +227,7 @@ class OMetaTableMixin:

         return None

+    @validate_call
     def get_profile_data(
         self,
         fqn: str,
@@ -253,7 +254,6 @@ class OMetaTableMixin:
         Returns:
             EntityList: EntityList list object
         """
         url_after = f"&after={after}" if after else ""
         profile_type_url = profile_type.__name__[0].lower() + profile_type.__name__[1:]
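`validate_call` is pydantic's argument-validation decorator; as a hedged, standalone sketch of its effect on a signature like `get_profile_data` (simplified, not the mixin's real signature):

from pydantic import validate_call


@validate_call
def get_profile_data(fqn: str, start_ts: int, end_ts: int) -> dict:
    # Arguments are validated (and coerced where possible) against the annotations before the body runs.
    return {"fqn": fqn, "start_ts": start_ts, "end_ts": end_ts}


get_profile_data("svc.db.schema.table", 0, 1000)  # ok
try:
    get_profile_data("svc.db.schema.table", "oops", 1000)
except Exception as exc:  # pydantic.ValidationError
    print(type(exc).__name__)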


@@ -26,7 +26,7 @@ from metadata.ingestion.lineage.models import ConnectionTypeDialectMapper, Diale
 from metadata.ingestion.lineage.parser import LineageParser
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.utils.logger import ingestion_logger
-from metadata.utils.time_utils import convert_timestamp_to_milliseconds
+from metadata.utils.time_utils import datetime_to_timestamp

 logger = ingestion_logger()

@@ -46,7 +46,7 @@ def parse_sql_statement(record: TableQuery, dialect: Dialect) -> Optional[Parsed
     start_date = start_time.root.date()
     start_time = datetime.datetime.strptime(str(start_date.isoformat()), "%Y-%m-%d")

-    start_time = convert_timestamp_to_milliseconds(int(start_time.timestamp()))
+    start_time = datetime_to_timestamp(start_time, milliseconds=True)

     lineage_parser = LineageParser(record.query, dialect=dialect)


@@ -36,7 +36,7 @@ from metadata.ingestion.models.topology import TopologyContextManager
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.utils import fqn
 from metadata.utils.logger import ingestion_logger
-from metadata.utils.time_utils import convert_timestamp_to_milliseconds
+from metadata.utils.time_utils import datetime_to_timestamp

 logger = ingestion_logger()

@@ -104,10 +104,8 @@ class LifeCycleQueryMixin:
                 life_cycle = LifeCycle(
                     created=AccessDetails(
                         timestamp=Timestamp(
-                            int(
-                                convert_timestamp_to_milliseconds(
-                                    life_cycle_data.created_at.timestamp()
-                                )
+                            datetime_to_timestamp(
+                                life_cycle_data.created_at, milliseconds=True
                             )
                         )
                     )


@@ -15,11 +15,18 @@ import urllib
 from datetime import datetime
 from typing import List, Optional

-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, TypeAdapter, field_validator
 from requests.utils import quote
+from sqlalchemy import text
+from sqlalchemy.orm import Session

 from metadata.generated.schema.entity.data.storedProcedure import Language
+from metadata.ingestion.source.database.snowflake.queries import (
+    SNOWFLAKE_QUERY_LOG_QUERY,
+)
+from metadata.profiler.metrics.system.dml_operation import DatabaseDMLOperations
 from metadata.utils.logger import ingestion_logger
+from metadata.utils.profiler_utils import QueryResult

 logger = ingestion_logger()

@@ -95,3 +102,44 @@ class SnowflakeTableList(BaseModel):
     def get_not_deleted(self) -> List[SnowflakeTable]:
         return [table for table in self.tables if not table.deleted]
+
+
+class SnowflakeQueryLogEntry(BaseModel):
+    """Entry for a Snowflake query log at SNOWFLAKE.ACCOUNT_USAGE.QUERY_HISTORY
+    More info at: https://docs.snowflake.com/en/sql-reference/account-usage/query_history
+    """
+
+    query_id: str
+    database_name: Optional[str] = None
+    schema_name: Optional[str] = None
+    query_type: str
+    start_time: datetime
+    query_text: Optional[str] = None
+    rows_inserted: Optional[int] = None
+    rows_updated: Optional[int] = None
+    rows_deleted: Optional[int] = None
+
+    @staticmethod
+    def get_for_table(session: Session, tablename: str):
+        rows = session.execute(
+            text(
+                SNOWFLAKE_QUERY_LOG_QUERY.format(
+                    tablename=tablename,  # type: ignore
+                    insert=DatabaseDMLOperations.INSERT.value,
+                    update=DatabaseDMLOperations.UPDATE.value,
+                    delete=DatabaseDMLOperations.DELETE.value,
+                    merge=DatabaseDMLOperations.MERGE.value,
+                )
+            )
+        )
+        return TypeAdapter(List[SnowflakeQueryLogEntry]).validate_python(
+            map(dict, rows)
+        )
+
+
+class SnowflakeQueryResult(QueryResult):
+    """Snowflake system metric query result"""
+
+    rows_inserted: Optional[int] = None
+    rows_updated: Optional[int] = None
+    rows_deleted: Optional[int] = None
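A hedged sketch of what `get_for_table` hands back, assuming rows shaped like QUERY_HISTORY records (the literal values below are made up for illustration):

from datetime import datetime
from typing import List

from pydantic import TypeAdapter

from metadata.ingestion.source.database.snowflake.models import SnowflakeQueryLogEntry

raw_rows = [
    {
        "query_id": "01b2-example",
        "query_type": "INSERT",
        "start_time": datetime(2024, 9, 6, 9, 0),
        "query_text": "INSERT INTO E2E_DB.E2E_TEST.E2E_TABLE VALUES (1)",
        "rows_inserted": 1,
    }
]
entries = TypeAdapter(List[SnowflakeQueryLogEntry]).validate_python(raw_rows)
print(entries[0].query_type, entries[0].rows_inserted)  # INSERT 1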


@@ -0,0 +1,260 @@
import re
import traceback
from typing import List, Optional, Tuple

import sqlalchemy.orm
from sqlalchemy.orm import DeclarativeMeta, Session

from metadata.ingestion.source.database.snowflake.models import (
    SnowflakeQueryLogEntry,
    SnowflakeQueryResult,
)
from metadata.utils.logger import profiler_logger
from metadata.utils.lru_cache import LRU_CACHE_SIZE, LRUCache
from metadata.utils.profiler_utils import get_identifiers_from_string

PUBLIC_SCHEMA = "PUBLIC"
logger = profiler_logger()

RESULT_SCAN = """
    SELECT *
    FROM TABLE(RESULT_SCAN('{query_id}'));
    """
QUERY_PATTERN = r"(?:(INSERT\s*INTO\s*|INSERT\s*OVERWRITE\s*INTO\s*|UPDATE\s*|MERGE\s*INTO\s*|DELETE\s*FROM\s*))([\w._\"\'()]+)(?=[\s*\n])"  # pylint: disable=line-too-long
IDENTIFIER_PATTERN = r"(IDENTIFIER\(\')([\w._\"]+)(\'\))"


def _parse_query(query: str) -> Optional[str]:
    """Parse snowflake queries to extract the identifiers"""
    match = re.match(QUERY_PATTERN, query, re.IGNORECASE)
    try:
        # This will match results like `DATABASE.SCHEMA.TABLE1` or IDENTIFIER('TABLE1')
        # If we have `IDENTIFIER` type of queries coming from Stored Procedures, we'll need to further clean it up.
        identifier = match.group(2)
        match_internal_identifier = re.match(
            IDENTIFIER_PATTERN, identifier, re.IGNORECASE
        )
        internal_identifier = (
            match_internal_identifier.group(2) if match_internal_identifier else None
        )
        if internal_identifier:
            return internal_identifier
        return identifier
    except (IndexError, AttributeError):
        logger.debug("Could not find identifier in query. Skipping row.")
        return None


class SnowflakeTableResovler:
    def __init__(self, session: sqlalchemy.orm.Session):
        self._cache = LRUCache(LRU_CACHE_SIZE)
        self.session = session

    def show_tables(self, db, schema, table):
        return self.session.execute(
            f'SHOW TABLES LIKE \'{table}\' IN SCHEMA "{db}"."{schema}" LIMIT 1;'
        ).fetchone()

    def table_exists(self, db, schema, table):
        """Return True if the table exists in Snowflake. Uses cache to store the results.

        Args:
            db (str): Database name
            schema (str): Schema name
            table (str): Table name
        Returns:
            bool: True if the table exists in Snowflake
        """
        if f"{db}.{schema}.{table}" in self._cache:
            return self._cache.get(f"{db}.{schema}.{table}")
        table = self.show_tables(db, schema, table)
        if table:
            self._cache.put(f"{db}.{schema}.{table}", True)
            return True
        return False
    def resolve_implicit_fqn(
        self,
        context_database: str,
        context_schema: Optional[str],
        table_name: str,
    ) -> Tuple[str, str, str]:
        """Resolve the fully qualified name of the table from Snowflake based on the following logic:
        1. If the schema is provided:
            a. search for the table in the schema
            b. if not found, go to (2)
        2. Search for the table in the public schema.

        Args:
            context_database (str): Database name
            context_schema (Optional[str]): Schema name. If not provided, we'll search in the public schema.
            table_name (str): Table name
        Returns:
            tuple: Tuple of database, schema and table names
        Raises:
            RuntimeError: If the table is not found in Snowflake in either schema
        """
        search_paths = []
        if context_schema and self.table_exists(
            context_database, context_schema, table_name
        ):
            search_paths += ".".join([context_database, context_schema, table_name])
            return context_database, context_schema, table_name
        if context_schema != PUBLIC_SCHEMA and self.table_exists(
            context_database, PUBLIC_SCHEMA, table_name
        ):
            search_paths += ".".join([context_database, PUBLIC_SCHEMA, table_name])
            return context_database, PUBLIC_SCHEMA, table_name
        raise RuntimeError(
            "Could not find the table {search_paths}.".format(
                search_paths=" OR ".join(map(lambda x: f"[{x}]", search_paths))
            )
        )
    def resolve_snowflake_fqn(
        self,
        context_database: str,
        context_schema: Optional[str],
        identifier: str,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Get query identifiers from the query text. If the schema is not provided in the query, we'll look for
        the table under "PUBLIC" in Snowflake.
        The database can be retrieved from the query or the query context.
        If the schema doesn't exist in the query but does in the context, we need to check with Snowflake whether
        the table exists in (1) the context schema or in (2) the public schema, in order to imitate the behavior of
        the query engine. There are edge cases where the table was since deleted (and hence not found). In such cases,
        the function will raise an error. It is advised to set the profiler window such that there will be minimal
        drift between the query execution and the profiler run.

        Args:
            context_database (str): Database name from the query context
            context_schema (Optional[str]): Schema name from the query context
            identifier (str): Identifier string extracted from a query (can be 'db.schema.table', 'schema.table' or just 'table')
        Returns:
            Tuple[Optional[str], Optional[str], Optional[str]]: Tuple of database, schema and table names
        Raises:
            RuntimeError: If the table name is not found in the query or if fqn resolution fails
        """
        (
            database_identifier,
            schema_identifier,
            table_name,
        ) = get_identifiers_from_string(identifier)

        if not table_name:
            raise RuntimeError("Could not extract the table name.")
        if not context_database and not database_identifier:
            logger.debug(
                f"Could not resolve database name. {identifier=}, {context_database=}"
            )
            raise RuntimeError("Could not resolve database name.")
        if schema_identifier is not None:
            return (
                database_identifier or context_database,
                schema_identifier,
                table_name,
            )
        logger.debug(
            "Missing schema info from the query. We'll look for it in Snowflake for [%s] or [%s]",
            (
                ".".join(
                    [
                        database_identifier or context_database,
                        context_schema,
                        table_name,
                    ]
                )
                if context_schema
                else None
            ),
            ".".join(
                [database_identifier or context_database, PUBLIC_SCHEMA, table_name]
            ),
        )
        # If the schema is not explicitly provided in the query, we'll need to resolve it from Snowflake
        # by cascading the search from the context schema to the public schema.
        result = self.resolve_implicit_fqn(
            context_database=context_database,
            context_schema=context_schema,
            table_name=table_name,
        )
        logger.debug("Resolved table [%s]", ".".join(result))
        return result
def get_snowflake_system_queries(
    query_log_entry: SnowflakeQueryLogEntry,
    resolver: SnowflakeTableResovler,
) -> Optional[SnowflakeQueryResult]:
    """
    Run a regex lookup on the query to identify which operation ran against the table.

    If the query does not have the complete set of `database.schema.table` when it runs,
    we'll resolve the missing parts against Snowflake itself, if we can.

    Args:
        query_log_entry (SnowflakeQueryLogEntry): row from the Snowflake query history
        resolver (SnowflakeTableResovler): resolver to get the table identifiers
    Returns:
        Optional[SnowflakeQueryResult]: parsed query result, or None if the row could not be processed
    """
    try:
        logger.debug(f"Trying to parse query [{query_log_entry.query_id}]")
        identifier = _parse_query(query_log_entry.query_text)
        if not identifier:
            raise RuntimeError("Could not identify the table from the query.")

        database_name, schema_name, table_name = resolver.resolve_snowflake_fqn(
            identifier=identifier,
            context_database=query_log_entry.database_name,
            context_schema=query_log_entry.schema_name,
        )

        if not all([database_name, schema_name, table_name]):
            raise RuntimeError(
                f"Could not extract the identifiers from the query [{query_log_entry.query_id}]."
            )

        return SnowflakeQueryResult(
            query_id=query_log_entry.query_id,
            database_name=database_name.lower(),
            schema_name=schema_name.lower(),
            table_name=table_name.lower(),
            query_text=query_log_entry.query_text,
            query_type=query_log_entry.query_type,
            start_time=query_log_entry.start_time,
            rows_inserted=query_log_entry.rows_inserted,
            rows_updated=query_log_entry.rows_updated,
            rows_deleted=query_log_entry.rows_deleted,
        )
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"""Error while processing query with id [{query_log_entry.query_id}]: {exc}\n
            To investigate the query run:
            SELECT * FROM SNOWFLAKE.ACCOUNT_USAGE.QUERY_HISTORY WHERE query_id = '{query_log_entry.query_id}'
            """
        )
        return None


def build_snowflake_query_results(
    session: Session,
    table: DeclarativeMeta,
) -> List[SnowflakeQueryResult]:
    """List and parse snowflake DML query results"""
    query_results = []
    resolver = SnowflakeTableResovler(
        session=session,
    )
    for row in SnowflakeQueryLogEntry.get_for_table(session, table.__tablename__):
        result = get_snowflake_system_queries(
            query_log_entry=row,
            resolver=resolver,
        )
        if result:
            query_results.append(result)
    return query_results
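A rough sketch of the extraction step in isolation, assuming the module above is importable; the resolver then fills in any missing database/schema via the SHOW TABLES cascade:

# Identifier extraction only; schema/database resolution happens in SnowflakeTableResovler.
print(_parse_query("INSERT INTO DB1.SCHEMA1.TABLE1 (id) VALUES (1)"))   # DB1.SCHEMA1.TABLE1
print(_parse_query("MERGE INTO TABLE1 USING src ON ..."))               # TABLE1
print(_parse_query("UPDATE IDENTIFIER('DB1.SCHEMA1.TABLE1') SET x=1"))  # DB1.SCHEMA1.TABLE1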


@@ -355,3 +355,26 @@ ORDER BY PROCEDURE_START_TIME DESC
 SNOWFLAKE_GET_TABLE_DDL = """
 SELECT GET_DDL('TABLE','{table_name}') AS \"text\"
 """
+
+SNOWFLAKE_QUERY_LOG_QUERY = """
+SELECT
+  QUERY_ID,
+  QUERY_TEXT,
+  QUERY_TYPE,
+  START_TIME,
+  DATABASE_NAME,
+  SCHEMA_NAME,
+  ROWS_INSERTED,
+  ROWS_UPDATED,
+  ROWS_DELETED
+FROM "SNOWFLAKE"."ACCOUNT_USAGE"."QUERY_HISTORY"
+WHERE
+  start_time>= DATEADD('DAY', -1, CURRENT_TIMESTAMP)
+  AND QUERY_TEXT ILIKE '%{tablename}%'
+  AND QUERY_TYPE IN (
+    '{insert}',
+    '{update}',
+    '{delete}',
+    '{merge}'
+  )
+  AND EXECUTION_STATUS = 'SUCCESS';
+"""


@@ -29,6 +29,7 @@ from metadata.generated.schema.entity.data.databaseSchema import (
 )
 from metadata.generated.schema.entity.data.table import (
     PartitionProfilerConfig,
+    SystemProfile,
     Table,
     TableData,
 )
@@ -462,7 +463,7 @@ class ProfilerInterface(ABC):
         runner,
         *args,
         **kwargs,
-    ):
+    ) -> List[SystemProfile]:
         """Get metrics"""
         raise NotImplementedError


@@ -442,7 +442,6 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
             sample,
         )
         row = None
         try:
             row = self._get_metric_fn[metric_func.metric_type.value](
                 metric_func.metrics,
@@ -451,11 +450,9 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
                 column=metric_func.column,
                 sample=sample,
             )
-            if row and isinstance(row, dict):
+            if isinstance(row, dict):
                 row = self._validate_nulls(row)
-            # System metrics return a list of dictionaries, with UPDATE, INSERT or DELETE ops results
-            if row and metric_func.metric_type == MetricTypes.System:
+            if isinstance(row, list):
                 row = [
                     self._validate_nulls(r) if isinstance(r, dict) else r
                     for r in row
@@ -537,6 +534,9 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
             logger.debug(traceback.format_exc())
             logger.error(f"Operation was cancelled due to TimeoutError - {exc}")
             raise concurrent.futures.TimeoutError
+        except KeyboardInterrupt:
+            pool.shutdown39(wait=True, cancel_futures=True)
+            raise

         return profile_results
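`shutdown39` itself is not shown in this diff; purely as a hedged guess at the kind of compatibility shim the name suggests (ThreadPoolExecutor.shutdown only accepts cancel_futures on Python 3.9+), it could look like this:

import sys
from concurrent.futures import ThreadPoolExecutor


def shutdown39(pool: ThreadPoolExecutor, wait: bool = True, cancel_futures: bool = False) -> None:
    # Hypothetical helper: pass cancel_futures only where the runtime supports it (3.9+).
    if sys.version_info >= (3, 9):
        pool.shutdown(wait=wait, cancel_futures=cancel_futures)
    else:
        pool.shutdown(wait=wait)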


@@ -19,6 +19,7 @@ from sqlalchemy import text
 from sqlalchemy.orm import Session

 from metadata.utils.profiler_utils import QueryResult
+from metadata.utils.time_utils import datetime_to_timestamp

 STL_QUERY = """
     with data as (
@@ -73,7 +74,7 @@ def get_query_results(
             table_name=row.table,
             query_text=None,
             query_type=operation,
-            timestamp=row.starttime,
+            start_time=row.starttime,
             rows=row.rows,
         )
         for row in cursor
@@ -94,7 +95,7 @@ def get_metric_result(ddls: List[QueryResult], table_name: str) -> List:
     """
     return [
         {
-            "timestamp": int(ddl.timestamp.timestamp() * 1000),
+            "timestamp": datetime_to_timestamp(ddl.start_time, milliseconds=True),
             "operation": ddl.query_type,
             "rowsAffected": ddl.rows,
         }


@@ -1,191 +0,0 @@
#  Copyright 2021 Collate
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  http://www.apache.org/licenses/LICENSE-2.0
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""
Snowflake System Metric Queries and query operations
"""
import re
import traceback
from typing import Optional, Tuple

from sqlalchemy.engine.row import Row

from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.ingestion.lineage.sql_lineage import search_table_entities
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.logger import profiler_logger
from metadata.utils.profiler_utils import (
    SnowflakeQueryResult,
    get_identifiers_from_string,
)

logger = profiler_logger()

INFORMATION_SCHEMA_QUERY = """
SELECT
  QUERY_ID,
  QUERY_TEXT,
  QUERY_TYPE,
  START_TIME,
  ROWS_INSERTED,
  ROWS_UPDATED,
  ROWS_DELETED
FROM "SNOWFLAKE"."ACCOUNT_USAGE"."QUERY_HISTORY"
WHERE
  start_time>= DATEADD('DAY', -1, CURRENT_TIMESTAMP)
  AND QUERY_TEXT ILIKE '%{tablename}%'
  AND QUERY_TYPE IN (
    '{insert}',
    '{update}',
    '{delete}',
    '{merge}'
  )
  AND EXECUTION_STATUS = 'SUCCESS';
"""

RESULT_SCAN = """
    SELECT *
    FROM TABLE(RESULT_SCAN('{query_id}'));
    """

QUERY_PATTERN = r"(?:(INSERT\s*INTO\s*|INSERT\s*OVERWRITE\s*INTO\s*|UPDATE\s*|MERGE\s*INTO\s*|DELETE\s*FROM\s*))([\w._\"\'()]+)(?=[\s*\n])"  # pylint: disable=line-too-long
IDENTIFIER_PATTERN = r"(IDENTIFIER\(\')([\w._\"]+)(\'\))"


def _parse_query(query: str) -> Optional[str]:
    """Parse snowflake queries to extract the identifiers"""
    match = re.match(QUERY_PATTERN, query, re.IGNORECASE)
    try:
        # This will match results like `DATABASE.SCHEMA.TABLE1` or IDENTIFIER('TABLE1')
        # If we have `IDENTIFIER` type of queries coming from Stored Procedures, we'll need to further clean it up.
        identifier = match.group(2)
        match_internal_identifier = re.match(
            IDENTIFIER_PATTERN, identifier, re.IGNORECASE
        )
        internal_identifier = (
            match_internal_identifier.group(2) if match_internal_identifier else None
        )
        if internal_identifier:
            return internal_identifier
        return identifier
    except (IndexError, AttributeError):
        logger.debug("Could not find identifier in query. Skipping row.")
        return None


def get_identifiers(
    identifier: str, ometa_client: OpenMetadata, db_service: DatabaseService
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Get query identifiers and if needed, fetch them from ES"""
    database_name, schema_name, table_name = get_identifiers_from_string(identifier)
    if not table_name:
        logger.debug("Could not extract the table name. Skipping operation.")
        return database_name, schema_name, table_name

    if not all([database_name, schema_name]):
        logger.debug(
            "Missing database or schema info from the query. We'll look for it in ES."
        )
        es_tables = search_table_entities(
            metadata=ometa_client,
            service_name=db_service.fullyQualifiedName.root,
            database=database_name,
            database_schema=schema_name,
            table=table_name,
        )

        if not es_tables:
            logger.debug("No tables match the search criteria.")
            return database_name, schema_name, table_name

        if len(es_tables) > 1:
            logger.debug(
                "Found more than 1 table matching the search criteria."
                " Skipping the computation to not mix system data."
            )
            return database_name, schema_name, table_name

        matched_table = es_tables[0]
        database_name = matched_table.database.name
        schema_name = matched_table.databaseSchema.name

    return database_name, schema_name, table_name


def get_snowflake_system_queries(
    row: Row,
    database: str,
    schema: str,
    ometa_client: OpenMetadata,
    db_service: DatabaseService,
) -> Optional[SnowflakeQueryResult]:
    """
    Run a regex lookup on the query to identify which operation ran against the table.

    If the query does not have the complete set of `database.schema.table` when it runs,
    we'll use ES to pick up the table, if we find it.

    Args:
        row (dict): row from the snowflake system queries table
        database (str): database name
        schema (str): schema name
        ometa_client (OpenMetadata): OpenMetadata client to search against ES
        db_service (DatabaseService): DB service where the process is running against
    Returns:
        QueryResult: namedtuple with the query result
    """
    try:
        dict_row = dict(row)
        query_text = dict_row.get("QUERY_TEXT", dict_row.get("query_text"))
        logger.debug(f"Trying to parse query:\n{query_text}\n")

        identifier = _parse_query(query_text)
        if not identifier:
            return None

        database_name, schema_name, table_name = get_identifiers(
            identifier=identifier,
            ometa_client=ometa_client,
            db_service=db_service,
        )
        if not all([database_name, schema_name, table_name]):
            return None

        if (
            database.lower() == database_name.lower()
            and schema.lower() == schema_name.lower()
        ):
            return SnowflakeQueryResult(
                query_id=dict_row.get("QUERY_ID", dict_row.get("query_id")),
                database_name=database_name.lower(),
                schema_name=schema_name.lower(),
                table_name=table_name.lower(),
                query_text=query_text,
                query_type=dict_row.get("QUERY_TYPE", dict_row.get("query_type")),
                timestamp=dict_row.get("START_TIME", dict_row.get("start_time")),
                rows_inserted=dict_row.get(
                    "ROWS_INSERTED", dict_row.get("rows_inserted")
                ),
                rows_updated=dict_row.get("ROWS_UPDATED", dict_row.get("rows_updated")),
                rows_deleted=dict_row.get("ROWS_DELETED", dict_row.get("rows_deleted")),
            )
    except Exception:
        logger.debug(traceback.format_exc())
        return None

    return None


@@ -17,15 +17,18 @@ import traceback
 from collections import defaultdict
 from typing import Dict, List, Optional

+from pydantic import TypeAdapter
 from sqlalchemy import text
 from sqlalchemy.orm import DeclarativeMeta, Session

 from metadata.generated.schema.configuration.profilerConfiguration import MetricType
+from metadata.generated.schema.entity.data.table import SystemProfile
 from metadata.generated.schema.entity.services.connections.database.bigQueryConnection import (
     BigQueryConnection,
 )
-from metadata.generated.schema.entity.services.databaseService import DatabaseService
-from metadata.ingestion.ometa.ometa_api import OpenMetadata
+from metadata.ingestion.source.database.snowflake.profiler.system_metrics import (
+    build_snowflake_query_results,
+)
 from metadata.profiler.metrics.core import SystemMetric
 from metadata.profiler.metrics.system.dml_operation import (
     DML_OPERATION_MAP,
@@ -41,19 +44,12 @@ from metadata.profiler.metrics.system.queries.redshift import (
     get_metric_result,
     get_query_results,
 )
-from metadata.profiler.metrics.system.queries.snowflake import (
-    INFORMATION_SCHEMA_QUERY,
-    get_snowflake_system_queries,
-)
 from metadata.profiler.orm.registry import Dialects
 from metadata.utils.dispatch import valuedispatch
 from metadata.utils.helpers import deep_size_of_dict
 from metadata.utils.logger import profiler_logger
-from metadata.utils.profiler_utils import (
-    SnowflakeQueryResult,
-    get_value_from_cache,
-    set_cache,
-)
+from metadata.utils.profiler_utils import get_value_from_cache, set_cache
+from metadata.utils.time_utils import datetime_to_timestamp

 logger = profiler_logger()

@@ -75,7 +71,7 @@ def get_system_metrics_for_dialect(
     table: DeclarativeMeta,
     *args,
     **kwargs,
-) -> Optional[Dict]:
+) -> Optional[List[SystemProfile]]:
     """_summary_

     Args:
@@ -91,6 +87,7 @@ def get_system_metrics_for_dialect(
         } else returns None
     """
     logger.debug(f"System metrics not support for {dialect}. Skipping processing.")
+    return None


 @get_system_metrics_for_dialect.register(Dialects.BigQuery)
@@ -101,7 +98,7 @@ def _(
     conn_config: BigQueryConnection,
     *args,
     **kwargs,
-) -> List[Dict]:
+) -> List[SystemProfile]:
     """Compute system metrics for bigquery

     Args:
@@ -190,7 +187,7 @@ def _(
             }
         )

-    return metric_results
+    return TypeAdapter(List[SystemProfile]).validate_python(metric_results)


 @get_system_metrics_for_dialect.register(Dialects.Redshift)
@@ -200,7 +197,7 @@ def _(
     table: DeclarativeMeta,
     *args,
     **kwargs,
-) -> List[Dict]:
+) -> List[SystemProfile]:
     """List all the DML operations for redshift tables

     Args:
@@ -289,42 +286,7 @@ def _(
     )
     metric_results.extend(get_metric_result(updates, table.__tablename__))  # type: ignore

-    return metric_results
-
-
-def _snowflake_build_query_result(
-    session: Session,
-    table: DeclarativeMeta,
-    database: str,
-    schema: str,
-    ometa_client: OpenMetadata,
-    db_service: DatabaseService,
-) -> List[SnowflakeQueryResult]:
-    """List and parse snowflake DML query results"""
-    rows = session.execute(
-        text(
-            INFORMATION_SCHEMA_QUERY.format(
-                tablename=table.__tablename__,  # type: ignore
-                insert=DatabaseDMLOperations.INSERT.value,
-                update=DatabaseDMLOperations.UPDATE.value,
-                delete=DatabaseDMLOperations.DELETE.value,
-                merge=DatabaseDMLOperations.MERGE.value,
-            )
-        )
-    )
-    query_results = []
-    for row in rows:
-        result = get_snowflake_system_queries(
-            row=row,
-            database=database,
-            schema=schema,
-            ometa_client=ometa_client,
-            db_service=db_service,
-        )
-        if result:
-            query_results.append(result)
-    return query_results
+    return TypeAdapter(List[SystemProfile]).validate_python(metric_results)


 @get_system_metrics_for_dialect.register(Dialects.Snowflake)
@@ -332,8 +294,6 @@ def _(
     dialect: str,
     session: Session,
     table: DeclarativeMeta,
-    ometa_client: OpenMetadata,
-    db_service: DatabaseService,
     *args,
     **kwargs,
 ) -> Optional[List[Dict]]:
@@ -354,18 +314,12 @@ def _(
         Dict: system metric
     """
     logger.debug(f"Fetching system metrics for {dialect}")
-    database = session.get_bind().url.database
-    schema = table.__table_args__["schema"]  # type: ignore

     metric_results: List[Dict] = []
-    query_results = _snowflake_build_query_result(
+    query_results = build_snowflake_query_results(
         session=session,
         table=table,
-        database=database,
-        schema=schema,
-        ometa_client=ometa_client,
-        db_service=db_service,
     )

     for query_result in query_results:
@@ -380,7 +334,9 @@ def _(
         if query_result.rows_inserted:
             metric_results.append(
                 {
-                    "timestamp": int(query_result.timestamp.timestamp() * 1000),
+                    "timestamp": datetime_to_timestamp(
+                        query_result.start_time, milliseconds=True
+                    ),
                     "operation": DatabaseDMLOperations.INSERT.value,
                     "rowsAffected": query_result.rows_inserted,
                 }
@@ -388,7 +344,9 @@ def _(
         if query_result.rows_updated:
             metric_results.append(
                 {
-                    "timestamp": int(query_result.timestamp.timestamp() * 1000),
+                    "timestamp": datetime_to_timestamp(
+                        query_result.start_time, milliseconds=True
+                    ),
                     "operation": DatabaseDMLOperations.UPDATE.value,
                     "rowsAffected": query_result.rows_updated,
                 }
@@ -397,13 +355,15 @@ def _(
             metric_results.append(
                 {
-                    "timestamp": int(query_result.timestamp.timestamp() * 1000),
+                    "timestamp": datetime_to_timestamp(
+                        query_result.start_time, milliseconds=True
+                    ),
                     "operation": DML_OPERATION_MAP.get(query_result.query_type),
                     "rowsAffected": rows_affected,
                 }
             )

-    return metric_results
+    return TypeAdapter(List[SystemProfile]).validate_python(metric_results)


 class System(SystemMetric):
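As a hedged illustration of the new return type, each dialect implementation now validates its dict results into SystemProfile objects (field names as used in the E2E assertions further below):

from typing import List

from pydantic import TypeAdapter

from metadata.generated.schema.entity.data.table import SystemProfile

metric_results = [{"timestamp": 1725606310000, "operation": "INSERT", "rowsAffected": 1}]
profiles = TypeAdapter(List[SystemProfile]).validate_python(metric_results)
print(profiles[0].operation, profiles[0].rowsAffected)  # expected: the INSERT enum member and 1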


@@ -274,9 +274,6 @@ class Profiler(Generic[TMetric]):
         Data should be saved under self.results
         """
-        logger.debug("Running post Profiler...")
-
         current_col_results: Dict[str, Any] = self._column_results.get(col.name)
         if not current_col_results:
             logger.debug(


@@ -34,20 +34,12 @@ class QueryResult(BaseModel):
     schema_name: str
     table_name: str
     query_type: str
-    timestamp: datetime
+    start_time: datetime
     query_id: Optional[str] = None
     query_text: Optional[str] = None
     rows: Optional[int] = None


-class SnowflakeQueryResult(QueryResult):
-    """Snowflake system metric query result"""
-
-    rows_inserted: Optional[int] = None
-    rows_updated: Optional[int] = None
-    rows_deleted: Optional[int] = None
-
-
 def clean_up_query(query: str) -> str:
     """remove comments and newlines from query"""
     return sqlparse.format(query, strip_comments=True).replace("\\n", "")


@@ -17,21 +17,25 @@ from datetime import datetime, time, timedelta, timezone
 from math import floor
 from typing import Union

+from metadata.utils.deprecation import deprecated
 from metadata.utils.helpers import datetime_to_ts


-def datetime_to_timestamp(datetime_value, milliseconds=False) -> int:
-    """Convert a datetime object to timestamp integer
+def datetime_to_timestamp(datetime_value: datetime, milliseconds=False) -> int:
+    """Convert a datetime object to timestamp integer. Datetime can be timezone aware or naive. Result
+    will always be in UTC.

     Args:
         datetime_value (_type_): datetime object
         milliseconds (bool, optional): make it a milliseconds timestamp. Defaults to False.

     Returns:
-        int:
+        int: timestamp in seconds or milliseconds
     """
     if not getattr(datetime_value, "timestamp", None):
-        raise TypeError(f"Object of type {datetime_value} has no method `timestamp()`")
+        raise TypeError(
+            f"Object of type {type(datetime_value).__name__} has no method `timestamp()`"
+        )

     tmsap = datetime_value.timestamp()
     if milliseconds:
@@ -115,6 +119,7 @@ def convert_timestamp(timestamp: str) -> Union[int, float]:
     return float(timestamp) / 1000


+@deprecated("Use `datetime_to_timestamp` instead", "1.7.0")
 def convert_timestamp_to_milliseconds(timestamp: Union[int, float]) -> int:
     """convert timestamp to milliseconds

     Args:
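A hedged sketch of the helper's behavior, assuming the implementation above (the expected values hold for the timezone-aware input; naive datetimes are interpreted in the local timezone):

from datetime import datetime, timezone

from metadata.utils.time_utils import datetime_to_timestamp

aware = datetime(2024, 9, 6, 7, 25, 10, tzinfo=timezone.utc)
print(datetime_to_timestamp(aware))                     # 1725607510
print(datetime_to_timestamp(aware, milliseconds=True))  # 1725607510000

try:
    datetime_to_timestamp("2024-09-06")  # strings have no .timestamp()
except TypeError as exc:
    print(exc)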


@@ -13,12 +13,18 @@
 Test database connectors which extend from `CommonDbSourceService` with CLI
 """
 import json
+import os
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Optional

+import yaml
 from sqlalchemy.engine import Engine

+from metadata.generated.schema.entity.services.databaseService import DatabaseService
+from metadata.generated.schema.metadataIngestion.workflow import (
+    OpenMetadataWorkflowConfig,
+)
 from metadata.ingestion.api.status import Status
 from metadata.workflow.metadata import MetadataWorkflow

@@ -45,6 +51,19 @@ class CliCommonDB:
             Path(PATH_TO_RESOURCES + f"/database/{connector}/test.yaml")
         )

+    @classmethod
+    def tearDownClass(cls):
+        workflow = OpenMetadataWorkflowConfig.model_validate(
+            yaml.safe_load(open(cls.config_file_path))
+        )
+        db_service: DatabaseService = cls.openmetadata.get_by_name(
+            DatabaseService, workflow.source.serviceName
+        )
+        if db_service and os.getenv("E2E_CLEAN_DB", "false") == "true":
+            cls.openmetadata.delete(
+                DatabaseService, db_service.id, hard_delete=True, recursive=True
+            )
+
     def tearDown(self) -> None:
         self.engine.dispose()
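The cleanup is opt-in; as a minimal sketch, exporting the flag before the E2E run enables the hard delete (the exact test invocation is up to the runner):

import os

# Opt in to hard-deleting the created DatabaseService (and its children) after the class tears down.
os.environ["E2E_CLEAN_DB"] = "true"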


@@ -12,10 +12,15 @@
 """
 Test Snowflake connector with CLI
 """
+from datetime import datetime
+from time import sleep
 from typing import List

 import pytest

+from _openmetadata_testutils.pydantic.test_utils import assert_equal_pydantic_objects
+from metadata.generated.schema.entity.data.table import DmlOperationType, SystemProfile
+from metadata.generated.schema.type.basic import Timestamp
 from metadata.ingestion.api.status import Status

 from .base.e2e_types import E2EType
@@ -40,6 +45,9 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
         "CREATE OR REPLACE TABLE e2e_test.test_departments(e2e_testdepartment_id INT PRIMARY KEY,e2e_testdepartment_name VARCHAR (30) NOT NULL,e2e_testlocation_id INT);",
         "CREATE OR REPLACE TABLE e2e_test.test_employees(e2e_testemployee_id INT PRIMARY KEY,e2e_testfirst_name VARCHAR (20),e2e_testlast_name VARCHAR (25) NOT NULL,e2e_testemail VARCHAR (100) NOT NULL,e2e_testphone_number VARCHAR (20),e2e_testhire_date DATE NOT NULL,e2e_testjob_id INT NOT NULL,e2e_testsalary DECIMAL (8, 2) NOT NULL,e2e_testmanager_id INT,e2e_testdepartment_id INT);",
         "CREATE OR REPLACE TABLE e2e_test.test_dependents(e2e_testdependent_id INT PRIMARY KEY,e2e_testfirst_name VARCHAR (50) NOT NULL,e2e_testlast_name VARCHAR (50) NOT NULL,e2e_testrelationship VARCHAR (25) NOT NULL,e2e_testemployee_id INT NOT NULL);",
+        "CREATE OR REPLACE TABLE e2e_test.e2e_table(varchar_column VARCHAR(255),int_column INT);",
+        "CREATE OR REPLACE TABLE public.public_table(varchar_column VARCHAR(255),int_column INT);",
+        "CREATE OR REPLACE TABLE public.e2e_table(varchar_column VARCHAR(255),int_column INT);",
     ]

     create_table_query: str = """
@@ -58,6 +66,13 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
     insert_data_queries: List[str] = [
         "INSERT INTO E2E_DB.e2e_test.persons (person_id, full_name) VALUES (1,'Peter Parker');",
         "INSERT INTO E2E_DB.e2e_test.persons (person_id, full_name) VALUES (1, 'Clark Kent');",
+        "INSERT INTO e2e_test.e2e_table (varchar_column, int_column) VALUES ('e2e_test.e2e_table', 1);",
+        "INSERT INTO public.e2e_table (varchar_column, int_column) VALUES ('public.e2e_table', 1);",
+        "INSERT INTO e2e_table (varchar_column, int_column) VALUES ('e2e_table', 1);",
+        "INSERT INTO public.public_table (varchar_column, int_column) VALUES ('public.public_table', 1);",
+        "INSERT INTO public_table (varchar_column, int_column) VALUES ('public_table', 1);",
+        "MERGE INTO public_table USING (SELECT 'public_table' as varchar_column, 2 as int_column) as source ON public_table.varchar_column = source.varchar_column WHEN MATCHED THEN UPDATE SET public_table.int_column = source.int_column WHEN NOT MATCHED THEN INSERT (varchar_column, int_column) VALUES (source.varchar_column, source.int_column);",
+        "DELETE FROM public_table WHERE varchar_column = 'public.public_table';",
     ]

     drop_table_query: str = """
@@ -68,6 +83,19 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
         DROP VIEW IF EXISTS E2E_DB.e2e_test.view_persons;
     """

+    teardown_sql_statements: List[str] = [
+        "DROP TABLE IF EXISTS E2E_DB.e2e_test.e2e_table;",
+        "DROP TABLE IF EXISTS E2E_DB.public.e2e_table;",
+        "DROP TABLE IF EXISTS E2E_DB.public.public_table;",
+    ]
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+        with cls.engine.connect() as connection:
+            for stmt in cls.teardown_sql_statements:
+                connection.execute(stmt)
+
     def setUp(self) -> None:
         with self.engine.connect() as connection:
             for sql_statements in self.prepare_snowflake_e2e:
@@ -83,15 +111,15 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
         self.assertTrue(len(source_status.failures) == 0)
         self.assertTrue(len(source_status.warnings) == 0)
         self.assertTrue(len(source_status.filtered) == 1)
-        self.assertTrue(
-            (len(source_status.records) + len(source_status.updated_records))
-            >= self.expected_tables()
+        self.assertGreaterEqual(
+            (len(source_status.records) + len(source_status.updated_records)),
+            self.expected_tables(),
         )
         self.assertTrue(len(sink_status.failures) == 0)
         self.assertTrue(len(sink_status.warnings) == 0)
-        self.assertTrue(
-            (len(sink_status.records) + len(sink_status.updated_records))
-            > self.expected_tables()
+        self.assertGreater(
+            (len(sink_status.records) + len(sink_status.updated_records)),
+            self.expected_tables(),
        )

     def create_table_and_view(self) -> None:
@@ -130,17 +158,22 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
             # Otherwise the sampling here does not pick up rows
             extra_args={"profileSample": 100},
         )
+        # wait for query log to be updated
+        self.wait_for_query_log()
         # run profiler with new tables
         result = self.run_command("profile")

         sink_status, source_status = self.retrieve_statuses(result)
         self.assert_for_table_with_profiler(source_status, sink_status)
+        self.custom_profiler_assertions()

     @staticmethod
     def expected_tables() -> int:
         return 7

     def inserted_rows_count(self) -> int:
-        return len(self.insert_data_queries)
+        return len(
+            [q for q in self.insert_data_queries if "E2E_DB.e2e_test.persons" in q]
+        )

     def view_column_lineage_count(self) -> int:
         return 2
@@ -171,7 +204,7 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
     @staticmethod
     def expected_filtered_table_includes() -> int:
-        return 5
+        return 8

     @staticmethod
     def expected_filtered_table_excludes() -> int:
@@ -179,7 +212,7 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
     @staticmethod
     def expected_filtered_mix() -> int:
-        return 6
+        return 7

     @staticmethod
     def delete_queries() -> List[str]:
@@ -196,3 +229,90 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
             UPDATE E2E_DB.E2E_TEST.PERSONS SET full_name = 'Bruce Wayne' WHERE full_name = 'Clark Kent'
             """,
         ]
+
+    def custom_profiler_assertions(self):
+        cases = [
+            (
+                "e2e_snowflake.E2E_DB.E2E_TEST.E2E_TABLE",
+                [
+                    SystemProfile(
+                        timestamp=Timestamp(root=0),
+                        operation=DmlOperationType.INSERT,
+                        rowsAffected=1,
+                    ),
+                    SystemProfile(
+                        timestamp=Timestamp(root=0),
+                        operation=DmlOperationType.INSERT,
+                        rowsAffected=1,
+                    ),
+                ],
+            ),
+            (
+                "e2e_snowflake.E2E_DB.PUBLIC.E2E_TABLE",
+                [
+                    SystemProfile(
+                        timestamp=Timestamp(root=0),
+                        operation=DmlOperationType.INSERT,
+                        rowsAffected=1,
+                    )
+                ],
+            ),
+            (
+                "e2e_snowflake.E2E_DB.PUBLIC.PUBLIC_TABLE",
+                [
+                    SystemProfile(
+                        timestamp=Timestamp(root=0),
+                        operation=DmlOperationType.INSERT,
+                        rowsAffected=1,
+                    ),
+                    SystemProfile(
+                        timestamp=Timestamp(root=0),
+                        operation=DmlOperationType.INSERT,
+                        rowsAffected=1,
+                    ),
+                    SystemProfile(
+                        timestamp=Timestamp(root=0),
+                        operation=DmlOperationType.UPDATE,
+                        rowsAffected=1,
+                    ),
+                    SystemProfile(
+                        timestamp=Timestamp(root=0),
+                        operation=DmlOperationType.DELETE,
+                        rowsAffected=1,
+                    ),
+                ],
+            ),
+        ]
+        for table_fqn, expected_profile in cases:
+            actual_profiles = self.openmetadata.get_profile_data(
+                table_fqn,
+                start_ts=int((datetime.now().timestamp() - 600) * 1000),
+                end_ts=int(datetime.now().timestamp() * 1000),
+                profile_type=SystemProfile,
+            ).entities
+            actual_profiles = sorted(actual_profiles, key=lambda x: x.timestamp.root)
+            actual_profiles = actual_profiles[-len(expected_profile) :]
+            actual_profiles = [
+                p.copy(update={"timestamp": Timestamp(root=0)}) for p in actual_profiles
+            ]
+            try:
+                assert_equal_pydantic_objects(expected_profile, actual_profiles)
+            except AssertionError as e:
+                raise AssertionError(f"Table: {table_fqn}\n{e}")
+
+    @classmethod
+    def wait_for_query_log(cls, timeout=600):
+        start = datetime.now().timestamp()
+        cls.engine.execute("SELECT 'e2e_query_log_wait'")
+        latest = 0
+        while latest < start:
+            sleep(5)
+            latest = (
+                cls.engine.execute(
+                    'SELECT max(start_time) FROM "SNOWFLAKE"."ACCOUNT_USAGE"."QUERY_HISTORY"'
+                )
+                .scalar()
+                .timestamp()
+            )
+            if (datetime.now().timestamp() - start) > timeout:
+                raise TimeoutError(f"Query log not updated for {timeout} seconds")


@@ -0,0 +1,114 @@
from unittest.mock import MagicMock, Mock

import pytest

from metadata.ingestion.source.database.snowflake.profiler.system_metrics import (
    PUBLIC_SCHEMA,
    SnowflakeTableResovler,
)
from metadata.utils.profiler_utils import get_identifiers_from_string


@pytest.mark.parametrize(
    "schema_name",
    ["test_schema", PUBLIC_SCHEMA, None],
)
@pytest.mark.parametrize(
    "existing_tables",
    [
        ["db.test_schema.test_table", "db.PUBLIC.test_table"],
        ["db.test_schema.test_table"],
        ["db.PUBLIC.test_table"],
        [],
    ],
)
def test_resolve_snoflake_fqn(schema_name, existing_tables):
    def expected_result(schema_name, existing_tables):
        if len(existing_tables) == 0:
            return RuntimeError
        if schema_name == "test_schema":
            if "db.test_schema.test_table" in existing_tables:
                return "db", "test_schema", "test_table"
            if "db.PUBLIC.test_table" in existing_tables:
                return "db", PUBLIC_SCHEMA, "test_table"
        if (
            schema_name in [None, PUBLIC_SCHEMA]
            and "db.PUBLIC.test_table" in existing_tables
        ):
            return "db", PUBLIC_SCHEMA, "test_table"
        return RuntimeError

    resolver = SnowflakeTableResovler(Mock())

    def mock_show_tables(_, schema, table):
        for t in existing_tables:
            if t == f"db.{schema}.{table}":
                return True

    resolver.show_tables = mock_show_tables
    expected = expected_result(schema_name, existing_tables)
    if expected == RuntimeError:
        with pytest.raises(expected):
            resolver.resolve_implicit_fqn("db", schema_name, "test_table")
    else:
        result = resolver.resolve_implicit_fqn("db", schema_name, "test_table")
        assert result == expected


@pytest.mark.parametrize("context_database", [None, "context_db"])
@pytest.mark.parametrize("context_schema", [None, "context_schema", PUBLIC_SCHEMA])
@pytest.mark.parametrize(
    "identifier",
    [
        "",
        "test_table",
        "id_schema.test_table",
        "PUBLIC.test_table",
        "id_db.test_schema.test_table",
        "id_db.PUBLIC.test_table",
    ],
)
@pytest.mark.parametrize(
    "resolved_schema",
    [
        PUBLIC_SCHEMA,
        "context_schema",
        RuntimeError("could not resolve schema"),
    ],
)
def test_get_identifiers(
    context_database,
    context_schema,
    identifier,
    resolved_schema,
):
    def expected_result():
        if identifier == "":
            return RuntimeError("Could not extract the table name.")
        db, id_schema, table = get_identifiers_from_string(identifier)
        if db is None and context_database is None:
            return RuntimeError("Could not resolve database name.")
        if id_schema is None and isinstance(resolved_schema, RuntimeError):
            return RuntimeError("could not resolve schema")
        return (
            (db or context_database),
            (id_schema or resolved_schema or context_schema),
            table,
        )

    resolver = SnowflakeTableResovler(Mock())
    if isinstance(resolved_schema, RuntimeError):
        resolver.resolve_implicit_fqn = MagicMock(side_effect=resolved_schema)
    else:
        resolver.resolve_implicit_fqn = MagicMock(
            return_value=(context_database, resolved_schema, identifier)
        )
    expected_value = expected_result()
    if isinstance(expected_value, RuntimeError):
        with pytest.raises(type(expected_value), match=str(expected_value)) as e:
            resolver.resolve_snowflake_fqn(context_database, context_schema, identifier)
    else:
        assert expected_value == resolver.resolve_snowflake_fqn(
            context_database, context_schema, identifier
        )


@@ -12,36 +12,21 @@
 """
 Tests utils function for the profiler
 """
-import uuid
 from datetime import datetime
 from unittest import TestCase
-from unittest.mock import patch
+from unittest.mock import Mock

 import pytest
 from sqlalchemy import Column
 from sqlalchemy.orm import declarative_base
 from sqlalchemy.sql.sqltypes import Integer, String

-from metadata.generated.schema.entity.data.table import Column as OMetaColumn
-from metadata.generated.schema.entity.data.table import DataType, Table
-from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
-    AuthProvider,
-    OpenMetadataConnection,
-)
-from metadata.generated.schema.entity.services.databaseService import (
-    DatabaseService,
-    DatabaseServiceType,
-)
-from metadata.generated.schema.security.client.openMetadataJWTClientConfig import (
-    OpenMetadataJWTClientConfig,
-)
-from metadata.generated.schema.type.basic import EntityName, FullyQualifiedEntityName
-from metadata.generated.schema.type.entityReference import EntityReference
-from metadata.ingestion.ometa.ometa_api import OpenMetadata
-from metadata.profiler.metrics.hybrid.histogram import Histogram
-from metadata.profiler.metrics.system.queries.snowflake import (
+from metadata.ingestion.source.database.snowflake.models import SnowflakeQueryLogEntry
+from metadata.ingestion.source.database.snowflake.profiler.system_metrics import (
+    SnowflakeTableResovler,
     get_snowflake_system_queries,
 )
+from metadata.profiler.metrics.hybrid.histogram import Histogram
 from metadata.profiler.metrics.system.system import recursive_dic
 from metadata.utils.profiler_utils import (
     get_identifiers_from_string,
@@ -50,8 +35,6 @@ from metadata.utils.profiler_utils import (
 )
 from metadata.utils.sqa_utils import is_array

-from .conftest import Row
-
 Base = declarative_base()

@@ -125,7 +108,7 @@ def test_is_array():
 def test_get_snowflake_system_queries():
     """Test get snowflake system queries"""
-    row = Row(
+    row = SnowflakeQueryLogEntry(
         query_id="1",
         query_type="INSERT",
         start_time=datetime.now(),
@@ -133,8 +116,10 @@ def test_get_snowflake_system_queries():
     )

     # We don't need the ometa_client nor the db_service if we have all the db.schema.table in the query
+    resolver = SnowflakeTableResovler(Mock())
     query_result = get_snowflake_system_queries(
-        row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=...
+        query_log_entry=row,
+        resolver=resolver,
     )  # type: ignore
     assert query_result
     assert query_result.query_id == "1"
@@ -143,15 +128,16 @@ def test_get_snowflake_system_queries():
     assert query_result.schema_name == "schema"
     assert query_result.table_name == "table1"

-    row = Row(
-        query_id=1,
+    row = SnowflakeQueryLogEntry(
+        query_id="1",
         query_type="INSERT",
         start_time=datetime.now(),
         query_text="INSERT INTO SCHEMA.TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
     )

     query_result = get_snowflake_system_queries(
-        row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=...
+        query_log_entry=row,
+        resolver=resolver,
     )  # type: ignore

     assert not query_result
@@ -184,15 +170,17 @@ def test_get_snowflake_system_queries_all_dll(query, expected):
     """test we can get all dml queries
     reference https://docs.snowflake.com/en/sql-reference/sql-dml
     """
-    row = Row(
+    row = SnowflakeQueryLogEntry(
         query_id="1",
         query_type=expected,
         start_time=datetime.now(),
         query_text=query,
     )
+    resolver = Mock()
+    resolver.resolve_snowflake_fqn = Mock(return_value=("database", "schema", "table1"))
     query_result = get_snowflake_system_queries(
-        row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=...
+        query_log_entry=row,
+        resolver=resolver,
     )  # type: ignore

     assert query_result
@@ -202,7 +190,8 @@ def test_get_snowflake_system_queries_all_dll(query, expected):
     assert query_result.table_name == "table1"

     query_result = get_snowflake_system_queries(
-        row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=...
+        query_log_entry=SnowflakeQueryLogEntry.model_validate(row),
+        resolver=resolver,
     )  # type: ignore

     assert query_result
@@ -212,75 +201,6 @@ def test_get_snowflake_system_queries_all_dll(query, expected):
     assert query_result.table_name == "table1"


-def test_get_snowflake_system_queries_from_es():
-    """Test the ES integration"""
-    ometa_client = OpenMetadata(
-        OpenMetadataConnection(
-            hostPort="http://localhost:8585/api",
-            authProvider=AuthProvider.openmetadata,
-            enableVersionValidation=False,
-            securityConfig=OpenMetadataJWTClientConfig(jwtToken="token"),
-        )
-    )
-
-    db_service = DatabaseService(
-        id=uuid.uuid4(),
-        name=EntityName("service"),
-        fullyQualifiedName=FullyQualifiedEntityName("service"),
-        serviceType=DatabaseServiceType.CustomDatabase,
-    )
-
-    table = Table(
-        id=uuid.uuid4(),
-        name="TABLE",
-        columns=[OMetaColumn(name="id", dataType=DataType.BIGINT)],
-        database=EntityReference(id=uuid.uuid4(), type="database", name="database"),
-        databaseSchema=EntityReference(
-            id=uuid.uuid4(), type="databaseSchema", name="schema"
-        ),
-    )
-
-    # With too many responses, we won't return anything since we don't want false results
-    # that we cannot properly assign
-    with patch.object(OpenMetadata, "es_search_from_fqn", return_value=[table] * 4):
-        row = Row(
-            query_id=1,
-            query_type="INSERT",
-            start_time=datetime.now(),
-            query_text="INSERT INTO TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
-        )
-        query_result = get_snowflake_system_queries(
-            row=row,
-            database="DATABASE",
-            schema="SCHEMA",
-            ometa_client=ometa_client,
-            db_service=db_service,
-        )
-
-        assert not query_result
-
-    # Returning a single table should work fine
-    with patch.object(OpenMetadata, "es_search_from_fqn", return_value=[table]):
-        row = Row(
-            query_id="1",
-            query_type="INSERT",
-            start_time=datetime.now(),
-            query_text="INSERT INTO TABLE2 (col1, col2) VALUES (1, 'a'), (2, 'b')",
-        )
-        query_result = get_snowflake_system_queries(
-            row=row,
-            database="DATABASE",
-            schema="SCHEMA",
-            ometa_client=ometa_client,
-            db_service=db_service,
-        )
-
-        assert query_result
-        assert query_result.query_type == "INSERT"
-        assert query_result.database_name == "database"
-        assert query_result.schema_name == "schema"
-        assert query_result.table_name == "table2"
-
-
 @pytest.mark.parametrize(
     "identifier, expected",
     [