Mirror of https://github.com/open-metadata/OpenMetadata.git (synced 2025-08-19 06:28:03 +00:00)
Fix 17698: use resolution logic for snowflake system metrics profiler (#17699)
* fix(profiler): snowflake resolve tables using the snowflake engine instead of OpenMetadata
* added env for cleaning up dbs in E2E
* moved system metric method to profiler; all the rest stays in snowflake
* format
* revert unnecessary changes
* removed test for previous resolution method
* use shutdown39
This commit is contained in:
parent b2f21fa070
commit 84be1a3162
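In short, the profiler no longer asks OpenMetadata/Elasticsearch which table a DML query touched; it asks Snowflake itself. A minimal sketch of the new flow, using the names introduced in the diff below and assuming `session` is a SQLAlchemy session bound to a Snowflake engine:

from sqlalchemy.orm import Session

from metadata.ingestion.source.database.snowflake.models import SnowflakeQueryLogEntry
from metadata.ingestion.source.database.snowflake.profiler.system_metrics import (
    SnowflakeTableResovler,
    get_snowflake_system_queries,
)


def collect_dml_results(session: Session, tablename: str):
    # Pull recent DML entries from ACCOUNT_USAGE.QUERY_HISTORY and resolve each
    # one against Snowflake (context schema first, then PUBLIC).
    resolver = SnowflakeTableResovler(session=session)
    results = []
    for entry in SnowflakeQueryLogEntry.get_for_table(session, tablename):
        parsed = get_snowflake_system_queries(query_log_entry=entry, resolver=resolver)
        if parsed:
            results.append(parsed)
    return results

This is essentially what build_snowflake_query_results in the new profiler module below does.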
@@ -1,10 +1,13 @@
from collections import deque
from typing import List, Union

from pydantic import BaseModel


def assert_equal_pydantic_objects(
    expected: BaseModel, actual: BaseModel, ignore_none=True
    expected: Union[BaseModel, List[BaseModel]],
    actual: Union[BaseModel, List[BaseModel]],
    ignore_none=True,
):
    """Compare 2 pydantic objects recursively and raise an AssertionError if they are not equal along with all
    the differences by field. If `ignore_none` is set to True, expected None values will be ignored. This can be
@@ -32,6 +35,10 @@ def assert_equal_pydantic_objects(
    Traceback (most recent call last):
    ...
    AssertionError: objects mismatched on field: [b.a], expected: [1], actual: [2]
    >>> assert_equal_pydantic_objects([a1, a2], [a2, a1])
    Traceback (most recent call last):
    ...
    AssertionError: objects mismatched on field: [0].a, expected: [1], actual: [2]

    Args:
        expected (BaseModel): The expected pydantic object.
@@ -69,11 +76,24 @@ def assert_equal_pydantic_objects(
                errors.append(
                    f"objects mismatched on field: [{new_key_prefix}], expected: [{expected_value}], actual: [{actual_value}]"
                )
        elif isinstance(expected, list):
            if not isinstance(actual, list):
                errors.append(
                    f"validation error on field: [{current_key_prefix}], expected: [list], actual: [{type(actual).__name__}]"
                )
            elif len(expected) != len(actual):
                errors.append(
                    f"mismatch length at {current_key_prefix}: expected: [{len(expected)}], actual: [{len(actual)}]"
                )
            else:
                for i, (expected_item, actual_item) in enumerate(zip(expected, actual)):
                    queue.append(
                        (expected_item, actual_item, f"{current_key_prefix}[{i}]")
                    )
        else:
            if expected != actual:
                errors.append(
                    f"mismatch at {current_key_prefix}: expected: [{expected}], actual: [{actual}]"
                )

    if errors:
        raise AssertionError("\n".join(errors))
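A usage sketch of the extended helper; the model and instances are illustrative and mirror the a1/a2 examples in the docstring:

from pydantic import BaseModel

from _openmetadata_testutils.pydantic.test_utils import assert_equal_pydantic_objects


class A(BaseModel):
    a: int


a1, a2 = A(a=1), A(a=2)
assert_equal_pydantic_objects(a1, a1)              # passes
assert_equal_pydantic_objects([a1, a2], [a1, a2])  # lists are now compared element by element
# assert_equal_pydantic_objects([a1, a2], [a2, a1])  # raises AssertionError listing every mismatch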
@@ -16,7 +16,7 @@ To be used by OpenMetadata class
import traceback
from typing import List, Optional, Type, TypeVar

from pydantic import BaseModel
from pydantic import BaseModel, validate_call
from requests.utils import quote

from metadata.generated.schema.api.data.createTableProfile import (
@@ -227,6 +227,7 @@ class OMetaTableMixin:

        return None

    @validate_call
    def get_profile_data(
        self,
        fqn: str,
@@ -253,7 +254,6 @@ class OMetaTableMixin:
        Returns:
            EntityList: EntityList list object
        """

        url_after = f"&after={after}" if after else ""
        profile_type_url = profile_type.__name__[0].lower() + profile_type.__name__[1:]
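The new @validate_call decorator makes pydantic validate the arguments of get_profile_data at call time. A standalone sketch of the general behavior (the function here is hypothetical, not part of the mixin):

from pydantic import validate_call


@validate_call
def fetch(fqn: str, limit: int = 10) -> str:
    # Arguments are coerced/validated against the annotations before the body runs.
    return f"{fqn}:{limit}"


fetch("db.schema.table", limit=5)  # OK
# fetch(fqn=None)                  # raises pydantic.ValidationError at call time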
@@ -26,7 +26,7 @@ from metadata.ingestion.lineage.models import ConnectionTypeDialectMapper, Diale
from metadata.ingestion.lineage.parser import LineageParser
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.logger import ingestion_logger
from metadata.utils.time_utils import convert_timestamp_to_milliseconds
from metadata.utils.time_utils import datetime_to_timestamp

logger = ingestion_logger()

@@ -46,7 +46,7 @@ def parse_sql_statement(record: TableQuery, dialect: Dialect) -> Optional[Parsed
    start_date = start_time.root.date()
    start_time = datetime.datetime.strptime(str(start_date.isoformat()), "%Y-%m-%d")

    start_time = convert_timestamp_to_milliseconds(int(start_time.timestamp()))
    start_time = datetime_to_timestamp(start_time, milliseconds=True)

    lineage_parser = LineageParser(record.query, dialect=dialect)
@@ -36,7 +36,7 @@ from metadata.ingestion.models.topology import TopologyContextManager
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils import fqn
from metadata.utils.logger import ingestion_logger
from metadata.utils.time_utils import convert_timestamp_to_milliseconds
from metadata.utils.time_utils import datetime_to_timestamp

logger = ingestion_logger()

@@ -104,10 +104,8 @@ class LifeCycleQueryMixin:
            life_cycle = LifeCycle(
                created=AccessDetails(
                    timestamp=Timestamp(
                        int(
                            convert_timestamp_to_milliseconds(
                                life_cycle_data.created_at.timestamp()
                            )
                            datetime_to_timestamp(
                                life_cycle_data.created_at, milliseconds=True
                            )
                        )
                    )
@@ -15,11 +15,18 @@ import urllib
from datetime import datetime
from typing import List, Optional

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from requests.utils import quote
from sqlalchemy import text
from sqlalchemy.orm import Session

from metadata.generated.schema.entity.data.storedProcedure import Language
from metadata.ingestion.source.database.snowflake.queries import (
    SNOWFLAKE_QUERY_LOG_QUERY,
)
from metadata.profiler.metrics.system.dml_operation import DatabaseDMLOperations
from metadata.utils.logger import ingestion_logger
from metadata.utils.profiler_utils import QueryResult

logger = ingestion_logger()

@@ -95,3 +102,44 @@ class SnowflakeTableList(BaseModel):

    def get_not_deleted(self) -> List[SnowflakeTable]:
        return [table for table in self.tables if not table.deleted]


class SnowflakeQueryLogEntry(BaseModel):
    """Entry for a Snowflake query log at SNOWFLAKE.ACCOUNT_USAGE.QUERY_HISTORY
    More info at: https://docs.snowflake.com/en/sql-reference/account-usage/query_history
    """

    query_id: str
    database_name: Optional[str] = None
    schema_name: Optional[str] = None
    query_type: str
    start_time: datetime
    query_text: Optional[str] = None
    rows_inserted: Optional[int] = None
    rows_updated: Optional[int] = None
    rows_deleted: Optional[int] = None

    @staticmethod
    def get_for_table(session: Session, tablename: str):
        rows = session.execute(
            text(
                SNOWFLAKE_QUERY_LOG_QUERY.format(
                    tablename=tablename,  # type: ignore
                    insert=DatabaseDMLOperations.INSERT.value,
                    update=DatabaseDMLOperations.UPDATE.value,
                    delete=DatabaseDMLOperations.DELETE.value,
                    merge=DatabaseDMLOperations.MERGE.value,
                )
            )
        )
        return TypeAdapter(List[SnowflakeQueryLogEntry]).validate_python(
            map(dict, rows)
        )


class SnowflakeQueryResult(QueryResult):
    """Snowflake system metric query result"""

    rows_inserted: Optional[int] = None
    rows_updated: Optional[int] = None
    rows_deleted: Optional[int] = None
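A sketch of the validation step used by get_for_table above: raw row mappings are coerced into typed models with TypeAdapter (sample values are illustrative):

from datetime import datetime
from typing import List

from pydantic import TypeAdapter

from metadata.ingestion.source.database.snowflake.models import SnowflakeQueryLogEntry

raw_rows = [
    {
        "query_id": "01ab-example",
        "query_type": "INSERT",
        "start_time": datetime(2024, 9, 1, 12, 0, 0),
        "rows_inserted": 10,
    }
]
entries = TypeAdapter(List[SnowflakeQueryLogEntry]).validate_python(raw_rows)
assert entries[0].rows_inserted == 10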
@@ -0,0 +1,260 @@
import re
import traceback
from typing import List, Optional, Tuple

import sqlalchemy.orm
from sqlalchemy.orm import DeclarativeMeta, Session

from metadata.ingestion.source.database.snowflake.models import (
    SnowflakeQueryLogEntry,
    SnowflakeQueryResult,
)
from metadata.utils.logger import profiler_logger
from metadata.utils.lru_cache import LRU_CACHE_SIZE, LRUCache
from metadata.utils.profiler_utils import get_identifiers_from_string

PUBLIC_SCHEMA = "PUBLIC"
logger = profiler_logger()
RESULT_SCAN = """
SELECT *
FROM TABLE(RESULT_SCAN('{query_id}'));
"""
QUERY_PATTERN = r"(?:(INSERT\s*INTO\s*|INSERT\s*OVERWRITE\s*INTO\s*|UPDATE\s*|MERGE\s*INTO\s*|DELETE\s*FROM\s*))([\w._\"\'()]+)(?=[\s*\n])"  # pylint: disable=line-too-long
IDENTIFIER_PATTERN = r"(IDENTIFIER\(\')([\w._\"]+)(\'\))"


def _parse_query(query: str) -> Optional[str]:
    """Parse snowflake queries to extract the identifiers"""
    match = re.match(QUERY_PATTERN, query, re.IGNORECASE)
    try:
        # This will match results like `DATABASE.SCHEMA.TABLE1` or IDENTIFIER('TABLE1')
        # If we have `IDENTIFIER` type of queries coming from Stored Procedures, we'll need to further clean it up.
        identifier = match.group(2)

        match_internal_identifier = re.match(
            IDENTIFIER_PATTERN, identifier, re.IGNORECASE
        )
        internal_identifier = (
            match_internal_identifier.group(2) if match_internal_identifier else None
        )
        if internal_identifier:
            return internal_identifier

        return identifier
    except (IndexError, AttributeError):
        logger.debug("Could not find identifier in query. Skipping row.")
        return None


class SnowflakeTableResovler:
    def __init__(self, session: sqlalchemy.orm.Session):
        self._cache = LRUCache(LRU_CACHE_SIZE)
        self.session = session

    def show_tables(self, db, schema, table):
        return self.session.execute(
            f'SHOW TABLES LIKE \'{table}\' IN SCHEMA "{db}"."{schema}" LIMIT 1;'
        ).fetchone()

    def table_exists(self, db, schema, table):
        """Return True if the table exists in Snowflake. Uses cache to store the results.

        Args:
            db (str): Database name
            schema (str): Schema name
            table (str): Table name

        Returns:
            bool: True if the table exists in Snowflake
        """
        if f"{db}.{schema}.{table}" in self._cache:
            return self._cache.get(f"{db}.{schema}.{table}")
        table = self.show_tables(db, schema, table)
        if table:
            self._cache.put(f"{db}.{schema}.{table}", True)
            return True
        return False

    def resolve_implicit_fqn(
        self,
        context_database: str,
        context_schema: Optional[str],
        table_name: str,
    ) -> Tuple[str, str, str]:
        """Resolve the fully qualified name of the table from snowflake based on the following logic:
        1. If the schema is provided:
          a. search for the table in the schema
          b. if not found, go to (2)
        2. Search for the table in the public schema.

        Args:
            context_database (str): Database name
            context_schema (Optional[str]): Schema name. If not provided, we'll search in the public schema.
            table_name (str): Table name
        Returns:
            tuple: Tuple of database, schema and table names
        Raises:
            RuntimeError: If the table is not found in the metadata or if there are duplicate results (there shouldn't be)

        """
        search_paths = []
        if context_schema and self.table_exists(
            context_database, context_schema, table_name
        ):
            search_paths += ".".join([context_database, context_schema, table_name])
            return context_database, context_schema, table_name
        if context_schema != PUBLIC_SCHEMA and self.table_exists(
            context_database, PUBLIC_SCHEMA, table_name
        ):
            search_paths += ".".join([context_database, PUBLIC_SCHEMA, table_name])
            return context_database, PUBLIC_SCHEMA, table_name
        raise RuntimeError(
            "Could not find the table {search_paths}.".format(
                search_paths=" OR ".join(map(lambda x: f"[{x}]", search_paths))
            )
        )

    def resolve_snowflake_fqn(
        self,
        context_database: str,
        context_schema: Optional[str],
        identifier: str,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Get query identifiers from the query text. If the schema is not provided in the query, we'll look for
        the table under "PUBLIC" in Snowflake.
        Database can be retrieved from the query or the query context.
        If the schema doesnt exist in the query but does in the context, we need to check with Snowflake if table
        exists in (1) the context schema or ib (2) the public schema in order to imitate the behavior of the query
        engine. There are edge cases where the table was deleted (and hence not found in the metadata). In such cases,
        the function will raise an error. It is advised to set the profier window such that there will be minimal
        drift between the query execution and the profiler run.

        Args:
            context_database (str): Database name from the query context
            context_schema (Optional[str]): Schema name from the query context
            identifier (str): Identifier string extracted from a query (can be 'db.schema.table', 'schema.table' or just 'table')
        Returns:
            Tuple[Optional[str], Optional[str], Optional[str]]: Tuple of database, schema and table names
        Raises:
            RuntimeError: If the table name is not found in the query or if fqn resolution fails
        """
        (
            database_identifier,
            schema_identifier,
            table_name,
        ) = get_identifiers_from_string(identifier)
        if not table_name:
            raise RuntimeError("Could not extract the table name.")
        if not context_database and not database_identifier:
            logger.debug(
                f"Could not resolve database name. {identifier=}, {context_database=}"
            )
            raise RuntimeError("Could not resolve database name.")
        if schema_identifier is not None:
            return (
                database_identifier or context_database,
                schema_identifier,
                table_name,
            )
        logger.debug(
            "Missing schema info from the query. We'll look for it in Snowflake for [%s] or [%s]",
            (
                ".".join(
                    [
                        database_identifier or context_database,
                        context_schema,
                        table_name,
                    ]
                )
                if context_schema
                else None
            ),
            ".".join(
                [database_identifier or context_database, PUBLIC_SCHEMA, table_name]
            ),
        )
        # If the schema is not explicitly provided in the query, we'll need to resolve it from OpenMetadata
        # by cascading the search from the context to the public schema.
        result = self.resolve_implicit_fqn(
            context_database=context_database,
            context_schema=context_schema,
            table_name=table_name,
        )
        logger.debug("Resolved table [%s]", ".".join(result))
        return result


def get_snowflake_system_queries(
    query_log_entry: SnowflakeQueryLogEntry,
    resolver: SnowflakeTableResovler,
) -> Optional[SnowflakeQueryResult]:
    """
    Run a regex lookup on the query to identify which operation ran against the table.

    If the query does not have the complete set of `database.schema.table` when it runs,
    we'll use ES to pick up the table, if we find it.

    Args:
        query_log_entry (dict): row from the snowflake system queries table
        resolver (SnowflakeTableResolver): resolver to get the table identifiers
    Returns:
        QueryResult: namedtuple with the query result
    """

    try:
        logger.debug(f"Trying to parse query [{query_log_entry.query_id}]")
        identifier = _parse_query(query_log_entry.query_text)
        if not identifier:
            raise RuntimeError("Could not identify the table from the query.")

        database_name, schema_name, table_name = resolver.resolve_snowflake_fqn(
            identifier=identifier,
            context_database=query_log_entry.database_name,
            context_schema=query_log_entry.schema_name,
        )

        if not all([database_name, schema_name, table_name]):
            raise RuntimeError(
                f"Could not extract the identifiers from the query [{query_log_entry.query_id}]."
            )

        return SnowflakeQueryResult(
            query_id=query_log_entry.query_id,
            database_name=database_name.lower(),
            schema_name=schema_name.lower(),
            table_name=table_name.lower(),
            query_text=query_log_entry.query_text,
            query_type=query_log_entry.query_type,
            start_time=query_log_entry.start_time,
            rows_inserted=query_log_entry.rows_inserted,
            rows_updated=query_log_entry.rows_updated,
            rows_deleted=query_log_entry.rows_deleted,
        )
    except Exception as exc:
        logger.debug(traceback.format_exc())
        logger.warning(
            f"""Error while processing query with id [{query_log_entry.query_id}]: {exc}\n
            To investigate the query run:
            SELECT * FROM SNOWFLAKE.ACCOUNT_USAGE.QUERY_HISTORY WHERE query_id = '{query_log_entry.query_id}'
            """
        )
        return None


def build_snowflake_query_results(
    session: Session,
    table: DeclarativeMeta,
) -> List[SnowflakeQueryResult]:
    """List and parse snowflake DML query results"""
    query_results = []
    resolver = SnowflakeTableResovler(
        session=session,
    )
    for row in SnowflakeQueryLogEntry.get_for_table(session, table.__tablename__):
        result = get_snowflake_system_queries(
            query_log_entry=row,
            resolver=resolver,
        )
        if result:
            query_results.append(result)
    return query_results
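A sketch of the resolution cascade implemented above, assuming `session` is a Snowflake-bound SQLAlchemy session; the database, schema, and table names are illustrative:

resolver = SnowflakeTableResovler(session=session)

# The query text already carries a schema: the identifier wins and no lookup is made.
resolver.resolve_snowflake_fqn(
    context_database="E2E_DB",
    context_schema="E2E_TEST",
    identifier="OTHER_SCHEMA.MY_TABLE",
)

# Bare table name: SHOW TABLES is issued against E2E_DB.E2E_TEST first, then
# E2E_DB.PUBLIC; a RuntimeError is raised if the table exists in neither.
resolver.resolve_snowflake_fqn(
    context_database="E2E_DB",
    context_schema="E2E_TEST",
    identifier="MY_TABLE",
)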
@@ -355,3 +355,26 @@ ORDER BY PROCEDURE_START_TIME DESC
SNOWFLAKE_GET_TABLE_DDL = """
SELECT GET_DDL('TABLE','{table_name}') AS \"text\"
"""
SNOWFLAKE_QUERY_LOG_QUERY = """
SELECT
  QUERY_ID,
  QUERY_TEXT,
  QUERY_TYPE,
  START_TIME,
  DATABASE_NAME,
  SCHEMA_NAME,
  ROWS_INSERTED,
  ROWS_UPDATED,
  ROWS_DELETED
FROM "SNOWFLAKE"."ACCOUNT_USAGE"."QUERY_HISTORY"
WHERE
  start_time>= DATEADD('DAY', -1, CURRENT_TIMESTAMP)
  AND QUERY_TEXT ILIKE '%{tablename}%'
  AND QUERY_TYPE IN (
    '{insert}',
    '{update}',
    '{delete}',
    '{merge}'
  )
  AND EXECUTION_STATUS = 'SUCCESS';
"""
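For reference, a sketch of how this template is rendered by SnowflakeQueryLogEntry.get_for_table (the table name is illustrative):

from metadata.ingestion.source.database.snowflake.queries import SNOWFLAKE_QUERY_LOG_QUERY
from metadata.profiler.metrics.system.dml_operation import DatabaseDMLOperations

sql = SNOWFLAKE_QUERY_LOG_QUERY.format(
    tablename="MY_TABLE",
    insert=DatabaseDMLOperations.INSERT.value,
    update=DatabaseDMLOperations.UPDATE.value,
    delete=DatabaseDMLOperations.DELETE.value,
    merge=DatabaseDMLOperations.MERGE.value,
)
# The rendered SQL keeps only successful INSERT/UPDATE/DELETE/MERGE statements from
# the last day whose text mentions MY_TABLE.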
@@ -29,6 +29,7 @@ from metadata.generated.schema.entity.data.databaseSchema import (
)
from metadata.generated.schema.entity.data.table import (
    PartitionProfilerConfig,
    SystemProfile,
    Table,
    TableData,
)
@@ -462,7 +463,7 @@ class ProfilerInterface(ABC):
        runner,
        *args,
        **kwargs,
    ):
    ) -> List[SystemProfile]:
        """Get metrics"""
        raise NotImplementedError
@@ -442,7 +442,6 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
            sample,
        )
        row = None

        try:
            row = self._get_metric_fn[metric_func.metric_type.value](
                metric_func.metrics,
@@ -451,11 +450,9 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
                column=metric_func.column,
                sample=sample,
            )
            if row and isinstance(row, dict):
            if isinstance(row, dict):
                row = self._validate_nulls(row)

            # System metrics return a list of dictionaries, with UPDATE, INSERT or DELETE ops results
            if row and metric_func.metric_type == MetricTypes.System:
            if isinstance(row, list):
                row = [
                    self._validate_nulls(r) if isinstance(r, dict) else r
                    for r in row
@@ -537,6 +534,9 @@ class SQAProfilerInterface(ProfilerInterface, SQAInterfaceMixin):
            logger.debug(traceback.format_exc())
            logger.error(f"Operation was cancelled due to TimeoutError - {exc}")
            raise concurrent.futures.TimeoutError
        except KeyboardInterrupt:
            pool.shutdown39(wait=True, cancel_futures=True)
            raise

        return profile_results
@@ -19,6 +19,7 @@ from sqlalchemy import text
from sqlalchemy.orm import Session

from metadata.utils.profiler_utils import QueryResult
from metadata.utils.time_utils import datetime_to_timestamp

STL_QUERY = """
with data as (
@@ -73,7 +74,7 @@ def get_query_results(
            table_name=row.table,
            query_text=None,
            query_type=operation,
            timestamp=row.starttime,
            start_time=row.starttime,
            rows=row.rows,
        )
        for row in cursor
@@ -94,7 +95,7 @@ def get_metric_result(ddls: List[QueryResult], table_name: str) -> List:
    """
    return [
        {
            "timestamp": int(ddl.timestamp.timestamp() * 1000),
            "timestamp": datetime_to_timestamp(ddl.start_time, milliseconds=True),
            "operation": ddl.query_type,
            "rowsAffected": ddl.rows,
        }
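A sketch of the metric payload built from one parsed row after the rename; the `ddl` object is illustrative, standing in for a QueryResult with a datetime start_time:

from datetime import datetime

from metadata.utils.profiler_utils import QueryResult
from metadata.utils.time_utils import datetime_to_timestamp

ddl = QueryResult(
    database_name="dev",
    schema_name="public",
    table_name="orders",
    query_type="INSERT",
    start_time=datetime(2024, 9, 1, 12, 0, 0),
    rows=42,
)
metric = {
    "timestamp": datetime_to_timestamp(ddl.start_time, milliseconds=True),
    "operation": ddl.query_type,
    "rowsAffected": ddl.rows,
}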
@@ -1,191 +0,0 @@
#  Copyright 2021 Collate
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  http://www.apache.org/licenses/LICENSE-2.0
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
Snowflake System Metric Queries and query operations
"""

import re
import traceback
from typing import Optional, Tuple

from sqlalchemy.engine.row import Row

from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.ingestion.lineage.sql_lineage import search_table_entities
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.logger import profiler_logger
from metadata.utils.profiler_utils import (
    SnowflakeQueryResult,
    get_identifiers_from_string,
)

logger = profiler_logger()

INFORMATION_SCHEMA_QUERY = """
SELECT
  QUERY_ID,
  QUERY_TEXT,
  QUERY_TYPE,
  START_TIME,
  ROWS_INSERTED,
  ROWS_UPDATED,
  ROWS_DELETED
FROM "SNOWFLAKE"."ACCOUNT_USAGE"."QUERY_HISTORY"
WHERE
  start_time>= DATEADD('DAY', -1, CURRENT_TIMESTAMP)
  AND QUERY_TEXT ILIKE '%{tablename}%'
  AND QUERY_TYPE IN (
    '{insert}',
    '{update}',
    '{delete}',
    '{merge}'
  )
  AND EXECUTION_STATUS = 'SUCCESS';
"""

RESULT_SCAN = """
SELECT *
FROM TABLE(RESULT_SCAN('{query_id}'));
"""


QUERY_PATTERN = r"(?:(INSERT\s*INTO\s*|INSERT\s*OVERWRITE\s*INTO\s*|UPDATE\s*|MERGE\s*INTO\s*|DELETE\s*FROM\s*))([\w._\"\'()]+)(?=[\s*\n])"  # pylint: disable=line-too-long
IDENTIFIER_PATTERN = r"(IDENTIFIER\(\')([\w._\"]+)(\'\))"


def _parse_query(query: str) -> Optional[str]:
    """Parse snowflake queries to extract the identifiers"""
    match = re.match(QUERY_PATTERN, query, re.IGNORECASE)
    try:
        # This will match results like `DATABASE.SCHEMA.TABLE1` or IDENTIFIER('TABLE1')
        # If we have `IDENTIFIER` type of queries coming from Stored Procedures, we'll need to further clean it up.
        identifier = match.group(2)

        match_internal_identifier = re.match(
            IDENTIFIER_PATTERN, identifier, re.IGNORECASE
        )
        internal_identifier = (
            match_internal_identifier.group(2) if match_internal_identifier else None
        )
        if internal_identifier:
            return internal_identifier

        return identifier
    except (IndexError, AttributeError):
        logger.debug("Could not find identifier in query. Skipping row.")
        return None


def get_identifiers(
    identifier: str, ometa_client: OpenMetadata, db_service: DatabaseService
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Get query identifiers and if needed, fetch them from ES"""
    database_name, schema_name, table_name = get_identifiers_from_string(identifier)

    if not table_name:
        logger.debug("Could not extract the table name. Skipping operation.")
        return database_name, schema_name, table_name

    if not all([database_name, schema_name]):
        logger.debug(
            "Missing database or schema info from the query. We'll look for it in ES."
        )
        es_tables = search_table_entities(
            metadata=ometa_client,
            service_name=db_service.fullyQualifiedName.root,
            database=database_name,
            database_schema=schema_name,
            table=table_name,
        )

        if not es_tables:
            logger.debug("No tables match the search criteria.")
            return database_name, schema_name, table_name

        if len(es_tables) > 1:
            logger.debug(
                "Found more than 1 table matching the search criteria."
                " Skipping the computation to not mix system data."
            )
            return database_name, schema_name, table_name

        matched_table = es_tables[0]
        database_name = matched_table.database.name
        schema_name = matched_table.databaseSchema.name

    return database_name, schema_name, table_name


def get_snowflake_system_queries(
    row: Row,
    database: str,
    schema: str,
    ometa_client: OpenMetadata,
    db_service: DatabaseService,
) -> Optional[SnowflakeQueryResult]:
    """
    Run a regex lookup on the query to identify which operation ran against the table.

    If the query does not have the complete set of `database.schema.table` when it runs,
    we'll use ES to pick up the table, if we find it.

    Args:
        row (dict): row from the snowflake system queries table
        database (str): database name
        schema (str): schema name
        ometa_client (OpenMetadata): OpenMetadata client to search against ES
        db_service (DatabaseService): DB service where the process is running against
    Returns:
        QueryResult: namedtuple with the query result
    """

    try:
        dict_row = dict(row)
        query_text = dict_row.get("QUERY_TEXT", dict_row.get("query_text"))
        logger.debug(f"Trying to parse query:\n{query_text}\n")

        identifier = _parse_query(query_text)
        if not identifier:
            return None

        database_name, schema_name, table_name = get_identifiers(
            identifier=identifier,
            ometa_client=ometa_client,
            db_service=db_service,
        )

        if not all([database_name, schema_name, table_name]):
            return None

        if (
            database.lower() == database_name.lower()
            and schema.lower() == schema_name.lower()
        ):
            return SnowflakeQueryResult(
                query_id=dict_row.get("QUERY_ID", dict_row.get("query_id")),
                database_name=database_name.lower(),
                schema_name=schema_name.lower(),
                table_name=table_name.lower(),
                query_text=query_text,
                query_type=dict_row.get("QUERY_TYPE", dict_row.get("query_type")),
                timestamp=dict_row.get("START_TIME", dict_row.get("start_time")),
                rows_inserted=dict_row.get(
                    "ROWS_INSERTED", dict_row.get("rows_inserted")
                ),
                rows_updated=dict_row.get("ROWS_UPDATED", dict_row.get("rows_updated")),
                rows_deleted=dict_row.get("ROWS_DELETED", dict_row.get("rows_deleted")),
            )
    except Exception:
        logger.debug(traceback.format_exc())
        return None

    return None
@@ -17,15 +17,18 @@ import traceback
from collections import defaultdict
from typing import Dict, List, Optional

from pydantic import TypeAdapter
from sqlalchemy import text
from sqlalchemy.orm import DeclarativeMeta, Session

from metadata.generated.schema.configuration.profilerConfiguration import MetricType
from metadata.generated.schema.entity.data.table import SystemProfile
from metadata.generated.schema.entity.services.connections.database.bigQueryConnection import (
    BigQueryConnection,
)
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.snowflake.profiler.system_metrics import (
    build_snowflake_query_results,
)
from metadata.profiler.metrics.core import SystemMetric
from metadata.profiler.metrics.system.dml_operation import (
    DML_OPERATION_MAP,
@@ -41,19 +44,12 @@ from metadata.profiler.metrics.system.queries.redshift import (
    get_metric_result,
    get_query_results,
)
from metadata.profiler.metrics.system.queries.snowflake import (
    INFORMATION_SCHEMA_QUERY,
    get_snowflake_system_queries,
)
from metadata.profiler.orm.registry import Dialects
from metadata.utils.dispatch import valuedispatch
from metadata.utils.helpers import deep_size_of_dict
from metadata.utils.logger import profiler_logger
from metadata.utils.profiler_utils import (
    SnowflakeQueryResult,
    get_value_from_cache,
    set_cache,
)
from metadata.utils.profiler_utils import get_value_from_cache, set_cache
from metadata.utils.time_utils import datetime_to_timestamp

logger = profiler_logger()

@@ -75,7 +71,7 @@ def get_system_metrics_for_dialect(
    table: DeclarativeMeta,
    *args,
    **kwargs,
) -> Optional[Dict]:
) -> Optional[List[SystemProfile]]:
    """_summary_

    Args:
@@ -91,6 +87,7 @@ def get_system_metrics_for_dialect(
        } else returns None
    """
    logger.debug(f"System metrics not support for {dialect}. Skipping processing.")
    return None


@get_system_metrics_for_dialect.register(Dialects.BigQuery)
@@ -101,7 +98,7 @@ def _(
    conn_config: BigQueryConnection,
    *args,
    **kwargs,
) -> List[Dict]:
) -> List[SystemProfile]:
    """Compute system metrics for bigquery

    Args:
@@ -190,7 +187,7 @@ def _(
            }
        )

    return metric_results
    return TypeAdapter(List[SystemProfile]).validate_python(metric_results)


@get_system_metrics_for_dialect.register(Dialects.Redshift)
@@ -200,7 +197,7 @@ def _(
    table: DeclarativeMeta,
    *args,
    **kwargs,
) -> List[Dict]:
) -> List[SystemProfile]:
    """List all the DML operations for reshifts tables

    Args:
@@ -289,42 +286,7 @@ def _(
    )
    metric_results.extend(get_metric_result(updates, table.__tablename__))  # type: ignore

    return metric_results


def _snowflake_build_query_result(
    session: Session,
    table: DeclarativeMeta,
    database: str,
    schema: str,
    ometa_client: OpenMetadata,
    db_service: DatabaseService,
) -> List[SnowflakeQueryResult]:
    """List and parse snowflake DML query results"""
    rows = session.execute(
        text(
            INFORMATION_SCHEMA_QUERY.format(
                tablename=table.__tablename__,  # type: ignore
                insert=DatabaseDMLOperations.INSERT.value,
                update=DatabaseDMLOperations.UPDATE.value,
                delete=DatabaseDMLOperations.DELETE.value,
                merge=DatabaseDMLOperations.MERGE.value,
            )
        )
    )
    query_results = []
    for row in rows:
        result = get_snowflake_system_queries(
            row=row,
            database=database,
            schema=schema,
            ometa_client=ometa_client,
            db_service=db_service,
        )
        if result:
            query_results.append(result)

    return query_results
    return TypeAdapter(List[SystemProfile]).validate_python(metric_results)


@get_system_metrics_for_dialect.register(Dialects.Snowflake)
@@ -332,8 +294,6 @@ def _(
    dialect: str,
    session: Session,
    table: DeclarativeMeta,
    ometa_client: OpenMetadata,
    db_service: DatabaseService,
    *args,
    **kwargs,
) -> Optional[List[Dict]]:
@@ -354,18 +314,12 @@ def _(
        Dict: system metric
    """
    logger.debug(f"Fetching system metrics for {dialect}")
    database = session.get_bind().url.database
    schema = table.__table_args__["schema"]  # type: ignore

    metric_results: List[Dict] = []

    query_results = _snowflake_build_query_result(
    query_results = build_snowflake_query_results(
        session=session,
        table=table,
        database=database,
        schema=schema,
        ometa_client=ometa_client,
        db_service=db_service,
    )

    for query_result in query_results:
@@ -380,7 +334,9 @@ def _(
        if query_result.rows_inserted:
            metric_results.append(
                {
                    "timestamp": int(query_result.timestamp.timestamp() * 1000),
                    "timestamp": datetime_to_timestamp(
                        query_result.start_time, milliseconds=True
                    ),
                    "operation": DatabaseDMLOperations.INSERT.value,
                    "rowsAffected": query_result.rows_inserted,
                }
@@ -388,7 +344,9 @@ def _(
        if query_result.rows_updated:
            metric_results.append(
                {
                    "timestamp": int(query_result.timestamp.timestamp() * 1000),
                    "timestamp": datetime_to_timestamp(
                        query_result.start_time, milliseconds=True
                    ),
                    "operation": DatabaseDMLOperations.UPDATE.value,
                    "rowsAffected": query_result.rows_updated,
                }
@@ -397,13 +355,15 @@ def _(

            metric_results.append(
                {
                    "timestamp": int(query_result.timestamp.timestamp() * 1000),
                    "timestamp": datetime_to_timestamp(
                        query_result.start_time, milliseconds=True
                    ),
                    "operation": DML_OPERATION_MAP.get(query_result.query_type),
                    "rowsAffected": rows_affected,
                }
            )

    return metric_results
    return TypeAdapter(List[SystemProfile]).validate_python(metric_results)


class System(SystemMetric):
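A sketch of the new return-type contract above: the dialect implementations now validate their raw metric dictionaries into SystemProfile objects instead of returning plain dicts (the sample values are illustrative):

from typing import List

from pydantic import TypeAdapter

from metadata.generated.schema.entity.data.table import SystemProfile

metric_results = [
    {"timestamp": 1725192000000, "operation": "INSERT", "rowsAffected": 10},
    {"timestamp": 1725192060000, "operation": "UPDATE", "rowsAffected": 3},
]
profiles: List[SystemProfile] = TypeAdapter(List[SystemProfile]).validate_python(
    metric_results
)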
@@ -274,9 +274,6 @@ class Profiler(Generic[TMetric]):

        Data should be saved under self.results
        """

        logger.debug("Running post Profiler...")

        current_col_results: Dict[str, Any] = self._column_results.get(col.name)
        if not current_col_results:
            logger.debug(
@@ -34,20 +34,12 @@ class QueryResult(BaseModel):
    schema_name: str
    table_name: str
    query_type: str
    timestamp: datetime
    start_time: datetime
    query_id: Optional[str] = None
    query_text: Optional[str] = None
    rows: Optional[int] = None


class SnowflakeQueryResult(QueryResult):
    """Snowflake system metric query result"""

    rows_inserted: Optional[int] = None
    rows_updated: Optional[int] = None
    rows_deleted: Optional[int] = None


def clean_up_query(query: str) -> str:
    """remove comments and newlines from query"""
    return sqlparse.format(query, strip_comments=True).replace("\\n", "")
@@ -17,21 +17,25 @@ from datetime import datetime, time, timedelta, timezone
from math import floor
from typing import Union

from metadata.utils.deprecation import deprecated
from metadata.utils.helpers import datetime_to_ts


def datetime_to_timestamp(datetime_value, milliseconds=False) -> int:
    """Convert a datetime object to timestamp integer
def datetime_to_timestamp(datetime_value: datetime, milliseconds=False) -> int:
    """Convert a datetime object to timestamp integer. Datetime can be timezone aware or naive. Result
    will always be in UTC.

    Args:
        datetime_value (_type_): datetime object
        milliseconds (bool, optional): make it a milliseconds timestamp. Defaults to False.

    Returns:
        int:
        int : timestamp in seconds or milliseconds
    """
    if not getattr(datetime_value, "timestamp", None):
        raise TypeError(f"Object of type {datetime_value} has not method `timestamp()`")
        raise TypeError(
            f"Object of type {type(datetime_value).__name__} has not method `timestamp()`"
        )

    tmsap = datetime_value.timestamp()
    if milliseconds:
@@ -115,6 +119,7 @@ def convert_timestamp(timestamp: str) -> Union[int, float]:
    return float(timestamp) / 1000


@deprecated("Use `datetime_to_timestamp` instead", "1.7.0")
def convert_timestamp_to_milliseconds(timestamp: Union[int, float]) -> int:
    """convert timestamp to milliseconds
    Args:
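A quick sketch of the clarified helper (values are illustrative): both naive and timezone-aware datetimes are accepted, and anything without a .timestamp() method is rejected:

from datetime import datetime, timezone

from metadata.utils.time_utils import datetime_to_timestamp

aware = datetime(2024, 9, 1, 12, 0, 0, tzinfo=timezone.utc)
datetime_to_timestamp(aware)                     # seconds since the epoch
datetime_to_timestamp(aware, milliseconds=True)  # same instant, in milliseconds
# datetime_to_timestamp("2024-09-01")            # raises TypeError: no timestamp() method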
@@ -13,12 +13,18 @@
Test database connectors which extend from `CommonDbSourceService` with CLI
"""
import json
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional

import yaml
from sqlalchemy.engine import Engine

from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.metadataIngestion.workflow import (
    OpenMetadataWorkflowConfig,
)
from metadata.ingestion.api.status import Status
from metadata.workflow.metadata import MetadataWorkflow

@@ -45,6 +51,19 @@ class CliCommonDB:
            Path(PATH_TO_RESOURCES + f"/database/{connector}/test.yaml")
        )

        @classmethod
        def tearDownClass(cls):
            workflow = OpenMetadataWorkflowConfig.model_validate(
                yaml.safe_load(open(cls.config_file_path))
            )
            db_service: DatabaseService = cls.openmetadata.get_by_name(
                DatabaseService, workflow.source.serviceName
            )
            if db_service and os.getenv("E2E_CLEAN_DB", "false") == "true":
                cls.openmetadata.delete(
                    DatabaseService, db_service.id, hard_delete=True, recursive=True
                )

        def tearDown(self) -> None:
            self.engine.dispose()
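The cleanup above only runs when the E2E_CLEAN_DB flag introduced in this commit is set; a small sketch of opting in from the environment before the suite runs (the guard mirrors the diff, the rest is illustrative):

import os

# Opt in to dropping the database service (hard delete, recursive) after the E2E run.
os.environ["E2E_CLEAN_DB"] = "true"

should_clean = os.getenv("E2E_CLEAN_DB", "false") == "true"
assert should_clean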
@@ -12,10 +12,15 @@
"""
Test Snowflake connector with CLI
"""
from datetime import datetime
from time import sleep
from typing import List

import pytest

from _openmetadata_testutils.pydantic.test_utils import assert_equal_pydantic_objects
from metadata.generated.schema.entity.data.table import DmlOperationType, SystemProfile
from metadata.generated.schema.type.basic import Timestamp
from metadata.ingestion.api.status import Status

from .base.e2e_types import E2EType
@@ -40,6 +45,9 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
        "CREATE OR REPLACE TABLE e2e_test.test_departments(e2e_testdepartment_id INT PRIMARY KEY,e2e_testdepartment_name VARCHAR (30) NOT NULL,e2e_testlocation_id INT);",
        "CREATE OR REPLACE TABLE e2e_test.test_employees(e2e_testemployee_id INT PRIMARY KEY,e2e_testfirst_name VARCHAR (20),e2e_testlast_name VARCHAR (25) NOT NULL,e2e_testemail VARCHAR (100) NOT NULL,e2e_testphone_number VARCHAR (20),e2e_testhire_date DATE NOT NULL,e2e_testjob_id INT NOT NULL,e2e_testsalary DECIMAL (8, 2) NOT NULL,e2e_testmanager_id INT,e2e_testdepartment_id INT);",
        "CREATE OR REPLACE TABLE e2e_test.test_dependents(e2e_testdependent_id INT PRIMARY KEY,e2e_testfirst_name VARCHAR (50) NOT NULL,e2e_testlast_name VARCHAR (50) NOT NULL,e2e_testrelationship VARCHAR (25) NOT NULL,e2e_testemployee_id INT NOT NULL);",
        "CREATE OR REPLACE TABLE e2e_test.e2e_table(varchar_column VARCHAR(255),int_column INT);",
        "CREATE OR REPLACE TABLE public.public_table(varchar_column VARCHAR(255),int_column INT);",
        "CREATE OR REPLACE TABLE public.e2e_table(varchar_column VARCHAR(255),int_column INT);",
    ]

    create_table_query: str = """
@@ -58,6 +66,13 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
    insert_data_queries: List[str] = [
        "INSERT INTO E2E_DB.e2e_test.persons (person_id, full_name) VALUES (1,'Peter Parker');",
        "INSERT INTO E2E_DB.e2e_test.persons (person_id, full_name) VALUES (1, 'Clark Kent');",
        "INSERT INTO e2e_test.e2e_table (varchar_column, int_column) VALUES ('e2e_test.e2e_table', 1);",
        "INSERT INTO public.e2e_table (varchar_column, int_column) VALUES ('public.e2e_table', 1);",
        "INSERT INTO e2e_table (varchar_column, int_column) VALUES ('e2e_table', 1);",
        "INSERT INTO public.public_table (varchar_column, int_column) VALUES ('public.public_table', 1);",
        "INSERT INTO public_table (varchar_column, int_column) VALUES ('public_table', 1);",
        "MERGE INTO public_table USING (SELECT 'public_table' as varchar_column, 2 as int_column) as source ON public_table.varchar_column = source.varchar_column WHEN MATCHED THEN UPDATE SET public_table.int_column = source.int_column WHEN NOT MATCHED THEN INSERT (varchar_column, int_column) VALUES (source.varchar_column, source.int_column);",
        "DELETE FROM public_table WHERE varchar_column = 'public.public_table';",
    ]

    drop_table_query: str = """
@@ -68,6 +83,19 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
    DROP VIEW IF EXISTS E2E_DB.e2e_test.view_persons;
    """

    teardown_sql_statements: List[str] = [
        "DROP TABLE IF EXISTS E2E_DB.e2e_test.e2e_table;",
        "DROP TABLE IF EXISTS E2E_DB.public.e2e_table;",
        "DROP TABLE IF EXISTS E2E_DB.public.public_table;",
    ]

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        with cls.engine.connect() as connection:
            for stmt in cls.teardown_sql_statements:
                connection.execute(stmt)

    def setUp(self) -> None:
        with self.engine.connect() as connection:
            for sql_statements in self.prepare_snowflake_e2e:
@@ -83,15 +111,15 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
        self.assertTrue(len(source_status.failures) == 0)
        self.assertTrue(len(source_status.warnings) == 0)
        self.assertTrue(len(source_status.filtered) == 1)
        self.assertTrue(
            (len(source_status.records) + len(source_status.updated_records))
            >= self.expected_tables()
        self.assertGreaterEqual(
            (len(source_status.records) + len(source_status.updated_records)),
            self.expected_tables(),
        )
        self.assertTrue(len(sink_status.failures) == 0)
        self.assertTrue(len(sink_status.warnings) == 0)
        self.assertTrue(
            (len(sink_status.records) + len(sink_status.updated_records))
            > self.expected_tables()
        self.assertGreater(
            (len(sink_status.records) + len(sink_status.updated_records)),
            self.expected_tables(),
        )

    def create_table_and_view(self) -> None:
@@ -130,17 +158,22 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
            # Otherwise the sampling here does not pick up rows
            extra_args={"profileSample": 100},
        )
        # wait for query log to be updated
        self.wait_for_query_log()
        # run profiler with new tables
        result = self.run_command("profile")
        sink_status, source_status = self.retrieve_statuses(result)
        self.assert_for_table_with_profiler(source_status, sink_status)
        self.custom_profiler_assertions()

    @staticmethod
    def expected_tables() -> int:
        return 7

    def inserted_rows_count(self) -> int:
        return len(self.insert_data_queries)
        return len(
            [q for q in self.insert_data_queries if "E2E_DB.e2e_test.persons" in q]
        )

    def view_column_lineage_count(self) -> int:
        return 2
@@ -171,7 +204,7 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):

    @staticmethod
    def expected_filtered_table_includes() -> int:
        return 5
        return 8

    @staticmethod
    def expected_filtered_table_excludes() -> int:
@@ -179,7 +212,7 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):

    @staticmethod
    def expected_filtered_mix() -> int:
        return 6
        return 7

    @staticmethod
    def delete_queries() -> List[str]:
@@ -196,3 +229,90 @@ class SnowflakeCliTest(CliCommonDB.TestSuite, SQACommonMethods):
            UPDATE E2E_DB.E2E_TEST.PERSONS SET full_name = 'Bruce Wayne' WHERE full_name = 'Clark Kent'
            """,
        ]

    def custom_profiler_assertions(self):
        cases = [
            (
                "e2e_snowflake.E2E_DB.E2E_TEST.E2E_TABLE",
                [
                    SystemProfile(
                        timestamp=Timestamp(root=0),
                        operation=DmlOperationType.INSERT,
                        rowsAffected=1,
                    ),
                    SystemProfile(
                        timestamp=Timestamp(root=0),
                        operation=DmlOperationType.INSERT,
                        rowsAffected=1,
                    ),
                ],
            ),
            (
                "e2e_snowflake.E2E_DB.PUBLIC.E2E_TABLE",
                [
                    SystemProfile(
                        timestamp=Timestamp(root=0),
                        operation=DmlOperationType.INSERT,
                        rowsAffected=1,
                    )
                ],
            ),
            (
                "e2e_snowflake.E2E_DB.PUBLIC.PUBLIC_TABLE",
                [
                    SystemProfile(
                        timestamp=Timestamp(root=0),
                        operation=DmlOperationType.INSERT,
                        rowsAffected=1,
                    ),
                    SystemProfile(
                        timestamp=Timestamp(root=0),
                        operation=DmlOperationType.INSERT,
                        rowsAffected=1,
                    ),
                    SystemProfile(
                        timestamp=Timestamp(root=0),
                        operation=DmlOperationType.UPDATE,
                        rowsAffected=1,
                    ),
                    SystemProfile(
                        timestamp=Timestamp(root=0),
                        operation=DmlOperationType.DELETE,
                        rowsAffected=1,
                    ),
                ],
            ),
        ]
        for table_fqn, expected_profile in cases:
            actual_profiles = self.openmetadata.get_profile_data(
                table_fqn,
                start_ts=int((datetime.now().timestamp() - 600) * 1000),
                end_ts=int(datetime.now().timestamp() * 1000),
                profile_type=SystemProfile,
            ).entities
            actual_profiles = sorted(actual_profiles, key=lambda x: x.timestamp.root)
            actual_profiles = actual_profiles[-len(expected_profile) :]
            actual_profiles = [
                p.copy(update={"timestamp": Timestamp(root=0)}) for p in actual_profiles
            ]
            try:
                assert_equal_pydantic_objects(expected_profile, actual_profiles)
            except AssertionError as e:
                raise AssertionError(f"Table: {table_fqn}\n{e}")

    @classmethod
    def wait_for_query_log(cls, timeout=600):
        start = datetime.now().timestamp()
        cls.engine.execute("SELECT 'e2e_query_log_wait'")
        latest = 0
        while latest < start:
            sleep(5)
            latest = (
                cls.engine.execute(
                    'SELECT max(start_time) FROM "SNOWFLAKE"."ACCOUNT_USAGE"."QUERY_HISTORY"'
                )
                .scalar()
                .timestamp()
            )
            if (datetime.now().timestamp() - start) > timeout:
                raise TimeoutError(f"Query log not updated for {timeout} seconds")
@@ -0,0 +1,114 @@
from unittest.mock import MagicMock, Mock

import pytest

from metadata.ingestion.source.database.snowflake.profiler.system_metrics import (
    PUBLIC_SCHEMA,
    SnowflakeTableResovler,
)
from metadata.utils.profiler_utils import get_identifiers_from_string


@pytest.mark.parametrize(
    "schema_name",
    ["test_schema", PUBLIC_SCHEMA, None],
)
@pytest.mark.parametrize(
    "existing_tables",
    [
        ["db.test_schema.test_table", "db.PUBLIC.test_table"],
        ["db.test_schema.test_table"],
        ["db.PUBLIC.test_table"],
        [],
    ],
)
def test_resolve_snoflake_fqn(schema_name, existing_tables):
    def expected_result(schema_name, existing_tables):
        if len(existing_tables) == 0:
            return RuntimeError
        if schema_name == "test_schema":
            if "db.test_schema.test_table" in existing_tables:
                return "db", "test_schema", "test_table"
            if "db.PUBLIC.test_table" in existing_tables:
                return "db", PUBLIC_SCHEMA, "test_table"
        if (
            schema_name in [None, PUBLIC_SCHEMA]
            and "db.PUBLIC.test_table" in existing_tables
        ):
            return "db", PUBLIC_SCHEMA, "test_table"
        return RuntimeError

    resolver = SnowflakeTableResovler(Mock())

    def mock_show_tables(_, schema, table):
        for t in existing_tables:
            if t == f"db.{schema}.{table}":
                return True

    resolver.show_tables = mock_show_tables
    expected = expected_result(schema_name, existing_tables)
    if expected == RuntimeError:
        with pytest.raises(expected):
            resolver.resolve_implicit_fqn("db", schema_name, "test_table")
    else:
        result = resolver.resolve_implicit_fqn("db", schema_name, "test_table")
        assert result == expected


@pytest.mark.parametrize("context_database", [None, "context_db"])
@pytest.mark.parametrize("context_schema", [None, "context_schema", PUBLIC_SCHEMA])
@pytest.mark.parametrize(
    "identifier",
    [
        "",
        "test_table",
        "id_schema.test_table",
        "PUBLIC.test_table",
        "id_db.test_schema.test_table",
        "id_db.PUBLIC.test_table",
    ],
)
@pytest.mark.parametrize(
    "resolved_schema",
    [
        PUBLIC_SCHEMA,
        "context_schema",
        RuntimeError("could not resolve schema"),
    ],
)
def test_get_identifiers(
    context_database,
    context_schema,
    identifier,
    resolved_schema,
):
    def expected_result():
        if identifier == "":
            return RuntimeError("Could not extract the table name.")
        db, id_schema, table = get_identifiers_from_string(identifier)
        if db is None and context_database is None:
            return RuntimeError("Could not resolve database name.")
        if id_schema is None and isinstance(resolved_schema, RuntimeError):
            return RuntimeError("could not resolve schema")
        return (
            (db or context_database),
            (id_schema or resolved_schema or context_schema),
            table,
        )

    resolver = SnowflakeTableResovler(Mock())
    if isinstance(resolved_schema, RuntimeError):
        resolver.resolve_implicit_fqn = MagicMock(side_effect=resolved_schema)
    else:
        resolver.resolve_implicit_fqn = MagicMock(
            return_value=(context_database, resolved_schema, identifier)
        )

    expected_value = expected_result()
    if isinstance(expected_value, RuntimeError):
        with pytest.raises(type(expected_value), match=str(expected_value)) as e:
            resolver.resolve_snowflake_fqn(context_database, context_schema, identifier)
    else:
        assert expected_value == resolver.resolve_snowflake_fqn(
            context_database, context_schema, identifier
        )
@@ -12,36 +12,21 @@
"""
Tests utils function for the profiler
"""
import uuid
from datetime import datetime
from unittest import TestCase
from unittest.mock import patch
from unittest.mock import Mock

import pytest
from sqlalchemy import Column
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql.sqltypes import Integer, String

from metadata.generated.schema.entity.data.table import Column as OMetaColumn
from metadata.generated.schema.entity.data.table import DataType, Table
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
    AuthProvider,
    OpenMetadataConnection,
)
from metadata.generated.schema.entity.services.databaseService import (
    DatabaseService,
    DatabaseServiceType,
)
from metadata.generated.schema.security.client.openMetadataJWTClientConfig import (
    OpenMetadataJWTClientConfig,
)
from metadata.generated.schema.type.basic import EntityName, FullyQualifiedEntityName
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.profiler.metrics.hybrid.histogram import Histogram
from metadata.profiler.metrics.system.queries.snowflake import (
from metadata.ingestion.source.database.snowflake.models import SnowflakeQueryLogEntry
from metadata.ingestion.source.database.snowflake.profiler.system_metrics import (
    SnowflakeTableResovler,
    get_snowflake_system_queries,
)
from metadata.profiler.metrics.hybrid.histogram import Histogram
from metadata.profiler.metrics.system.system import recursive_dic
from metadata.utils.profiler_utils import (
    get_identifiers_from_string,
@@ -50,8 +35,6 @@ from metadata.utils.profiler_utils import (
)
from metadata.utils.sqa_utils import is_array

from .conftest import Row

Base = declarative_base()


@@ -125,7 +108,7 @@ def test_is_array():

def test_get_snowflake_system_queries():
    """Test get snowflake system queries"""
    row = Row(
    row = SnowflakeQueryLogEntry(
        query_id="1",
        query_type="INSERT",
        start_time=datetime.now(),
@@ -133,8 +116,10 @@ def test_get_snowflake_system_queries():
    )

    # We don't need the ometa_client nor the db_service if we have all the db.schema.table in the query
    resolver = SnowflakeTableResovler(Mock())
    query_result = get_snowflake_system_queries(
        row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=...
        query_log_entry=row,
        resolver=resolver,
    )  # type: ignore
    assert query_result
    assert query_result.query_id == "1"
@@ -143,15 +128,16 @@ def test_get_snowflake_system_queries():
    assert query_result.schema_name == "schema"
    assert query_result.table_name == "table1"

    row = Row(
        query_id=1,
    row = SnowflakeQueryLogEntry(
        query_id="1",
        query_type="INSERT",
        start_time=datetime.now(),
        query_text="INSERT INTO SCHEMA.TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
    )

    query_result = get_snowflake_system_queries(
        row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=...
        query_log_entry=row,
        resolver=resolver,
    )  # type: ignore

    assert not query_result
@@ -184,15 +170,17 @@ def test_get_snowflake_system_queries_all_dll(query, expected):
    """test we ca get all ddl queries
    reference https://docs.snowflake.com/en/sql-reference/sql-dml
    """
    row = Row(
    row = SnowflakeQueryLogEntry(
        query_id="1",
        query_type=expected,
        start_time=datetime.now(),
        query_text=query,
    )

    resolver = Mock()
    resolver.resolve_snowflake_fqn = Mock(return_value=("database", "schema", "table1"))
    query_result = get_snowflake_system_queries(
        row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=...
        query_log_entry=row,
        resolver=resolver,
    )  # type: ignore

    assert query_result
@@ -202,7 +190,8 @@ def test_get_snowflake_system_queries_all_dll(query, expected):
    assert query_result.table_name == "table1"

    query_result = get_snowflake_system_queries(
        row=row, database="DATABASE", schema="SCHEMA", ometa_client=..., db_service=...
        query_log_entry=SnowflakeQueryLogEntry.model_validate(row),
        resolver=resolver,
    )  # type: ignore

    assert query_result
@@ -212,75 +201,6 @@ def test_get_snowflake_system_queries_all_dll(query, expected):
    assert query_result.table_name == "table1"


def test_get_snowflake_system_queries_from_es():
    """Test the ES integration"""

    ometa_client = OpenMetadata(
        OpenMetadataConnection(
            hostPort="http://localhost:8585/api",
            authProvider=AuthProvider.openmetadata,
            enableVersionValidation=False,
            securityConfig=OpenMetadataJWTClientConfig(jwtToken="token"),
        )
    )

    db_service = DatabaseService(
        id=uuid.uuid4(),
        name=EntityName("service"),
        fullyQualifiedName=FullyQualifiedEntityName("service"),
        serviceType=DatabaseServiceType.CustomDatabase,
    )

    table = Table(
        id=uuid.uuid4(),
        name="TABLE",
        columns=[OMetaColumn(name="id", dataType=DataType.BIGINT)],
        database=EntityReference(id=uuid.uuid4(), type="database", name="database"),
        databaseSchema=EntityReference(
            id=uuid.uuid4(), type="databaseSchema", name="schema"
        ),
    )

    # With too many responses, we won't return anything since we don't want false results
    # that we cannot properly assign
    with patch.object(OpenMetadata, "es_search_from_fqn", return_value=[table] * 4):
        row = Row(
            query_id=1,
            query_type="INSERT",
            start_time=datetime.now(),
            query_text="INSERT INTO TABLE1 (col1, col2) VALUES (1, 'a'), (2, 'b')",
        )
        query_result = get_snowflake_system_queries(
            row=row,
            database="DATABASE",
            schema="SCHEMA",
            ometa_client=ometa_client,
            db_service=db_service,
        )
        assert not query_result

    # Returning a single table should work fine
    with patch.object(OpenMetadata, "es_search_from_fqn", return_value=[table]):
        row = Row(
            query_id="1",
            query_type="INSERT",
            start_time=datetime.now(),
            query_text="INSERT INTO TABLE2 (col1, col2) VALUES (1, 'a'), (2, 'b')",
        )
        query_result = get_snowflake_system_queries(
            row=row,
            database="DATABASE",
            schema="SCHEMA",
            ometa_client=ometa_client,
            db_service=db_service,
        )
        assert query_result
        assert query_result.query_type == "INSERT"
        assert query_result.database_name == "database"
        assert query_result.schema_name == "schema"
        assert query_result.table_name == "table2"


@pytest.mark.parametrize(
    "identifier, expected",
    [