Fix Stored Procedures Lineage for multi-db processes (#13655)

This commit is contained in:
Pere Miquel Brull 2023-10-20 09:14:08 +02:00 committed by GitHub
parent c9c6c94ddf
commit 660bf01a5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 200 additions and 344 deletions

View File

@ -13,7 +13,7 @@ We require Taxonomy Admin permissions to fetch all Policy Tags
"""
import os
import traceback
from typing import Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
from google import auth
from google.cloud.datacatalog_v1 import PolicyTagManagerClient
@ -622,26 +622,18 @@ class BigquerySource(StoredProcedureMixin, CommonDbSourceService):
)
)
def get_stored_procedure_queries(self) -> Iterable[QueryByProcedure]:
def get_stored_procedure_queries_dict(self) -> Dict[str, List[QueryByProcedure]]:
"""
Pick the stored procedure name from the context
and return the list of associated queries
"""
# Only process if we have actually yielded a stored procedure
if self.context.stored_procedure:
start, _ = get_start_and_end(self.source_config.queryLogDuration)
query = BIGQUERY_GET_STORED_PROCEDURE_QUERIES.format(
start_date=start,
region=self.service_connection.usageLocation,
)
queries_dict = self.procedure_queries_dict(
query=query,
schema_name=self.context.database_schema.name.__root__,
database_name=self.context.database.name.__root__,
)
start, _ = get_start_and_end(self.source_config.queryLogDuration)
query = BIGQUERY_GET_STORED_PROCEDURE_QUERIES.format(
start_date=start,
region=self.service_connection.usageLocation,
)
queries_dict = self.procedure_queries_dict(
query=query,
)
for query_by_procedure in (
queries_dict.get(self.context.stored_procedure.name.__root__.lower())
or []
):
yield query_by_procedure
return queries_dict

View File

@ -113,6 +113,8 @@ SELECT
Q.query_type as query_type,
SP.query_text as procedure_text,
Q.query_text as query_text,
null as query_database_name,
null as query_schema_name,
SP.start_time as procedure_start_time,
SP.end_time as procedure_end_time,
Q.start_time as query_start_time,

View File

@ -14,7 +14,7 @@ Generic source to build SQL connectors.
import traceback
from abc import ABC
from copy import deepcopy
from typing import Any, Iterable, List, Optional, Tuple
from typing import Any, Iterable, List, Optional, Tuple, Union
from pydantic import BaseModel
from sqlalchemy.engine import Connection
@ -50,12 +50,10 @@ from metadata.ingestion.lineage.sql_lineage import get_column_fqn
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_connection
from metadata.ingestion.source.database.database_service import (
DatabaseServiceSource,
QueryByProcedure,
)
from metadata.ingestion.source.database.database_service import DatabaseServiceSource
from metadata.ingestion.source.database.sql_column_handler import SqlColumnHandlerMixin
from metadata.ingestion.source.database.sqlalchemy_source import SqlAlchemySource
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.ingestion.source.models import TableView
from metadata.utils import fqn
from metadata.utils.db_utils import get_view_lineage
@ -370,15 +368,11 @@ class CommonDbSourceService(
def get_stored_procedure_queries(self) -> Iterable[QueryByProcedure]:
"""Not Implemented"""
def yield_procedure_query(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[CreateQueryRequest]]:
"""Not implemented"""
def yield_procedure_lineage(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[AddLineageRequest]]:
"""Not implemented"""
def yield_procedure_lineage_and_queries(
self,
) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]:
"""
Stored procedure lineage/query hook.
Not implemented here: the generator yields nothing.
"""
# Delegating to an empty list keeps this a generator that produces no items
yield from []
@calculate_execution_time_generator
def yield_table(

View File

@ -38,10 +38,8 @@ from metadata.ingestion.api.models import Either, StackTraceError
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_connection
from metadata.ingestion.source.database.database_service import (
DatabaseServiceSource,
QueryByProcedure,
)
from metadata.ingestion.source.database.database_service import DatabaseServiceSource
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.utils import fqn
from metadata.utils.constants import DEFAULT_DATABASE
from metadata.utils.datalake.datalake_utils import get_columns
@ -260,15 +258,11 @@ class CommonNoSQLSource(DatabaseServiceSource, ABC):
def get_stored_procedure_queries(self) -> Iterable[QueryByProcedure]:
"""Not Implemented"""
def yield_procedure_query(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[CreateQueryRequest]]:
"""Not implemented"""
def yield_procedure_lineage(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[AddLineageRequest]]:
"""Not implemented"""
def yield_procedure_lineage_and_queries(
self,
) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]:
"""
Stored procedure lineage/query hook.
Not implemented for this NoSQL base source: the generator yields nothing.
"""
# Delegating to an empty list keeps this a generator that produces no items
yield from []
def get_source_url(
self,

View File

@ -12,7 +12,7 @@
Base class for ingesting database services
"""
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional, Set, Tuple
from typing import Any, Iterable, List, Optional, Set, Tuple, Union
from pydantic import BaseModel
from sqlalchemy.engine import Inspector
@ -32,7 +32,6 @@ from metadata.generated.schema.api.services.createDatabaseService import (
)
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.entity.data.query import Query
from metadata.generated.schema.entity.data.storedProcedure import StoredProcedure
from metadata.generated.schema.entity.data.table import (
Column,
@ -65,7 +64,6 @@ from metadata.ingestion.models.topology import (
create_source_context,
)
from metadata.ingestion.source.connections import get_test_connection_fn
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.utils import fqn
from metadata.utils.filters import filter_by_schema
from metadata.utils.logger import ingestion_logger
@ -104,7 +102,10 @@ class DatabaseServiceTopology(ServiceTopology):
),
],
children=["database"],
post_process=["yield_view_lineage"],
# Note how we have `yield_view_lineage` and `yield_procedure_lineage_and_queries`
# as post_process steps. This is because we cannot ensure proper lineage processing
# until we have finished ingesting all the metadata from the source.
post_process=["yield_view_lineage", "yield_procedure_lineage_and_queries"],
)
database = TopologyNode(
producer="get_database_names",
@ -166,26 +167,10 @@ class DatabaseServiceTopology(ServiceTopology):
stages=[
NodeStage(
type_=StoredProcedure,
context="stored_procedure",
context="stored_procedures",
processor="yield_stored_procedure",
consumer=["database_service", "database", "database_schema"],
),
],
children=["stored_procedure_queries"],
)
stored_procedure_queries = TopologyNode(
producer="get_stored_procedure_queries",
stages=[
NodeStage(
type_=AddLineageRequest,
processor="yield_procedure_lineage",
context="stored_procedure_query_lineage", # Used to flag if the query has had processed lineage
nullable=True,
),
NodeStage(
type_=Query,
processor="yield_procedure_query",
nullable=True,
cache_all=True,
),
],
)
@ -339,20 +324,10 @@ class DatabaseServiceSource(
"""Process the stored procedure information"""
@abstractmethod
def get_stored_procedure_queries(self) -> Iterable[QueryByProcedure]:
"""List the queries associated to a stored procedure"""
@abstractmethod
def yield_procedure_query(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[CreateQueryRequest]]:
"""Process the stored procedure query"""
@abstractmethod
def yield_procedure_lineage(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[AddLineageRequest]]:
"""Add procedure lineage from its query"""
def yield_procedure_lineage_and_queries(
self,
) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]:
"""
Extracts the lineage information from Stored Procedures.
Per the return annotation, each yielded item is an Either wrapping
an AddLineageRequest or a CreateQueryRequest.
"""
def get_raw_database_schema_names(self) -> Iterable[str]:
"""

View File

@ -13,7 +13,7 @@ Databricks Unity Catalog Source source methods.
"""
import json
import traceback
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from databricks.sdk.service.catalog import ColumnInfo
from databricks.sdk.service.catalog import TableConstraint as DBTableConstraint
@ -53,10 +53,7 @@ from metadata.ingestion.lineage.sql_lineage import get_column_fqn
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser
from metadata.ingestion.source.database.database_service import (
DatabaseServiceSource,
QueryByProcedure,
)
from metadata.ingestion.source.database.database_service import DatabaseServiceSource
from metadata.ingestion.source.database.databricks.connection import get_connection
from metadata.ingestion.source.database.databricks.models import (
ColumnJson,
@ -64,6 +61,7 @@ from metadata.ingestion.source.database.databricks.models import (
ForeignConstrains,
Type,
)
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.ingestion.source.models import TableView
from metadata.utils import fqn
from metadata.utils.db_utils import get_view_lineage
@ -500,15 +498,11 @@ class DatabricksUnityCatalogSource(DatabaseServiceSource):
def get_stored_procedure_queries(self) -> Iterable[QueryByProcedure]:
"""Not Implemented"""
def yield_procedure_query(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[CreateQueryRequest]]:
"""Not implemented"""
def yield_procedure_lineage(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[AddLineageRequest]]:
"""Not implemented"""
def yield_procedure_lineage_and_queries(
self,
) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]:
"""
Stored procedure lineage/query hook.
Not implemented for the Unity Catalog source: the generator yields nothing.
"""
# Delegating to an empty list keeps this a generator that produces no items
yield from []
def close(self):
"""Nothing to close"""

View File

@ -14,7 +14,7 @@ DataLake connector to fetch metadata from a files stored s3, gcs and Hdfs
"""
import json
import traceback
from typing import Any, Iterable, Tuple
from typing import Any, Iterable, Tuple, Union
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.api.data.createDatabaseSchema import (
@ -54,10 +54,8 @@ from metadata.ingestion.api.steps import InvalidSourceException
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_connection
from metadata.ingestion.source.database.database_service import (
DatabaseServiceSource,
QueryByProcedure,
)
from metadata.ingestion.source.database.database_service import DatabaseServiceSource
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.ingestion.source.storage.storage_service import (
OPENMETADATA_TEMPLATE_FILE_NAME,
)
@ -382,15 +380,11 @@ class DatalakeSource(DatabaseServiceSource):
def get_stored_procedure_queries(self) -> Iterable[QueryByProcedure]:
"""Not Implemented"""
def yield_procedure_query(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[CreateQueryRequest]]:
"""Not implemented"""
def yield_procedure_lineage(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[AddLineageRequest]]:
"""Not implemented"""
def yield_procedure_lineage_and_queries(
self,
) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]:
"""
Stored procedure lineage/query hook.
Not implemented for the Datalake source: the generator yields nothing.
"""
# Delegating to an empty list keeps this a generator that produces no items
yield from []
def standardize_table_name(
self, schema: str, table: str # pylint: disable=unused-argument

View File

@ -14,7 +14,7 @@ Deltalake source methods.
import re
import traceback
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from pyspark.sql.utils import AnalysisException, ParseException
@ -45,10 +45,8 @@ from metadata.ingestion.models.ometa_classification import OMetaTagAndClassifica
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_connection
from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser
from metadata.ingestion.source.database.database_service import (
DatabaseServiceSource,
QueryByProcedure,
)
from metadata.ingestion.source.database.database_service import DatabaseServiceSource
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.utils import fqn
from metadata.utils.constants import DEFAULT_DATABASE
from metadata.utils.filters import filter_by_schema, filter_by_table
@ -413,15 +411,11 @@ class DeltalakeSource(DatabaseServiceSource):
def get_stored_procedure_queries(self) -> Iterable[QueryByProcedure]:
"""Not Implemented"""
def yield_procedure_query(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[CreateQueryRequest]]:
"""Not implemented"""
def yield_procedure_lineage(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[AddLineageRequest]]:
"""Not implemented"""
def yield_procedure_lineage_and_queries(
self,
) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]:
"""
Stored procedure lineage/query hook.
Not implemented for the Deltalake source: the generator yields nothing.
"""
# Delegating to an empty list keeps this a generator that produces no items
yield from []
def close(self):
"""No client to close"""

View File

@ -14,7 +14,7 @@ Domo Database source to extract metadata
"""
import traceback
from typing import Any, Iterable, List, Optional, Tuple
from typing import Any, Iterable, List, Optional, Tuple, Union
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.api.data.createDatabaseSchema import (
@ -42,16 +42,14 @@ from metadata.ingestion.api.steps import InvalidSourceException
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_connection
from metadata.ingestion.source.database.database_service import (
DatabaseServiceSource,
QueryByProcedure,
)
from metadata.ingestion.source.database.database_service import DatabaseServiceSource
from metadata.ingestion.source.database.domodatabase.models import (
OutputDataset,
Owner,
SchemaColumn,
User,
)
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.utils import fqn
from metadata.utils.constants import DEFAULT_DATABASE
from metadata.utils.filters import filter_by_table
@ -232,15 +230,11 @@ class DomodatabaseSource(DatabaseServiceSource):
def get_stored_procedure_queries(self) -> Iterable[QueryByProcedure]:
"""Not Implemented"""
def yield_procedure_query(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[CreateQueryRequest]]:
"""Not implemented"""
def yield_procedure_lineage(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[AddLineageRequest]]:
"""Not implemented"""
def yield_procedure_lineage_and_queries(
self,
) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]:
"""
Stored procedure lineage/query hook.
Not implemented for the Domo database source: the generator yields nothing.
"""
# Delegating to an empty list keeps this a generator that produces no items
yield from []
def yield_view_lineage(self) -> Iterable[Either[AddLineageRequest]]:
yield from []

View File

@ -12,7 +12,7 @@
Glue source methods.
"""
import traceback
from typing import Any, Iterable, Optional, Tuple
from typing import Any, Iterable, Optional, Tuple, Union
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.api.data.createDatabaseSchema import (
@ -43,16 +43,14 @@ from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_connection
from metadata.ingestion.source.database.column_helpers import truncate_column_name
from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser
from metadata.ingestion.source.database.database_service import (
DatabaseServiceSource,
QueryByProcedure,
)
from metadata.ingestion.source.database.database_service import DatabaseServiceSource
from metadata.ingestion.source.database.glue.models import Column as GlueColumn
from metadata.ingestion.source.database.glue.models import (
DatabasePage,
StorageDetails,
TablePage,
)
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.utils import fqn
from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table
from metadata.utils.logger import ingestion_logger
@ -363,15 +361,11 @@ class GlueSource(DatabaseServiceSource):
def get_stored_procedure_queries(self) -> Iterable[QueryByProcedure]:
"""Not Implemented"""
def yield_procedure_query(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[CreateQueryRequest]]:
"""Not implemented"""
def yield_procedure_lineage(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[AddLineageRequest]]:
"""Not implemented"""
def yield_procedure_lineage_and_queries(
self,
) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]:
"""
Stored procedure lineage/query hook.
Not implemented for the Glue source: the generator yields nothing.
"""
# Delegating to an empty list keeps this a generator that produces no items
yield from []
def get_source_url(
self,

View File

@ -14,7 +14,7 @@ Redshift source ingestion
import re
import traceback
from typing import Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
from sqlalchemy import inspect, sql
from sqlalchemy.dialects.postgresql.base import PGDialect
@ -271,27 +271,19 @@ class RedshiftSource(StoredProcedureMixin, CommonDbSourceService):
)
)
def get_stored_procedure_queries(self) -> Iterable[QueryByProcedure]:
def get_stored_procedure_queries_dict(self) -> Dict[str, List[QueryByProcedure]]:
"""
Pick the stored procedure name from the context
and return the list of associated queries
Return the dictionary associating stored procedures to the
queries they triggered
"""
# Only process if we have actually yielded a stored procedure
if self.context.stored_procedure:
start, _ = get_start_and_end(self.source_config.queryLogDuration)
query = REDSHIFT_GET_STORED_PROCEDURE_QUERIES.format(
start_date=start,
database_name=self.context.database.name.__root__,
)
start, _ = get_start_and_end(self.source_config.queryLogDuration)
query = REDSHIFT_GET_STORED_PROCEDURE_QUERIES.format(
start_date=start,
database_name=self.context.database.name.__root__,
)
queries_dict = self.procedure_queries_dict(
query=query,
schema_name=self.context.database_schema.name.__root__,
database_name=self.context.database.name.__root__,
)
queries_dict = self.procedure_queries_dict(
query=query,
)
for query_by_procedure in (
queries_dict.get(self.context.stored_procedure.name.__root__.lower())
or []
):
yield query_by_procedure
return queries_dict

View File

@ -306,8 +306,7 @@ with SP_HISTORY as (
endtime as procedure_end_time,
pid as procedure_session_id
from SVL_STORED_PROC_CALL
where database = '{database_name}'
and aborted = 0
where aborted = 0
and starttime >= '{start_date}'
),
Q_HISTORY as (
@ -320,6 +319,7 @@ Q_HISTORY as (
when querytxt ilike '%%CREATE%%AS%%' then 'CREATE_TABLE_AS_SELECT'
when querytxt ilike '%%INSERT%%' then 'INSERT'
else 'UNKNOWN' end query_type,
database as query_database_name,
pid as query_session_id,
starttime as query_start_time,
endtime as query_end_time,
@ -330,7 +330,6 @@ Q_HISTORY as (
where label not in ('maintenance', 'metrics', 'health')
and querytxt not like '/* {{"app": "OpenMetadata", %%}} */%%'
and querytxt not like '/* {{"app": "dbt", %%}} */%%'
and database = '{database_name}'
and starttime >= '{start_date}'
and userid <> 1
)
@ -342,6 +341,8 @@ select
q.query_id,
q.query_text,
q.query_type,
q.query_database_name,
null as query_schema_name,
q.query_start_time,
q.query_end_time,
q.query_user_name

View File

@ -12,7 +12,7 @@
Salesforce source ingestion
"""
import traceback
from typing import Any, Iterable, Optional, Tuple
from typing import Any, Iterable, Optional, Tuple, Union
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.api.data.createDatabaseSchema import (
@ -45,10 +45,8 @@ from metadata.ingestion.api.steps import InvalidSourceException
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_connection, get_test_connection_fn
from metadata.ingestion.source.database.database_service import (
DatabaseServiceSource,
QueryByProcedure,
)
from metadata.ingestion.source.database.database_service import DatabaseServiceSource
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.utils import fqn
from metadata.utils.constants import DEFAULT_DATABASE
from metadata.utils.filters import filter_by_table
@ -282,15 +280,11 @@ class SalesforceSource(DatabaseServiceSource):
def get_stored_procedure_queries(self) -> Iterable[QueryByProcedure]:
"""Not Implemented"""
def yield_procedure_query(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[CreateQueryRequest]]:
"""Not implemented"""
def yield_procedure_lineage(
self, query_by_procedure: QueryByProcedure
) -> Iterable[Either[AddLineageRequest]]:
"""Not implemented"""
def yield_procedure_lineage_and_queries(
self,
) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]:
"""
Stored procedure lineage/query hook.
Not implemented for the Salesforce source: the generator yields nothing.
"""
# Delegating to an empty list keeps this a generator that produces no items
yield from []
def standardize_table_name( # pylint: disable=unused-argument
self, schema: str, table: str

View File

@ -1,57 +0,0 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Snowflake module to define constants
"""
# Lookup table from a lower-case, cloud-prefixed Snowflake region id
# (e.g. "aws_us_west_2") to the region string stored as its value.
# Check out the region reference map here:
# https://docs.snowflake.com/en/user-guide/admin-account-identifier#region-ids
# NOTE(review): most values are bare region names, but a few carry extra
# qualifiers ("fhplus.us-gov-west-1.aws", "central-india.azure") or an extra
# hyphen ("uk-south") - presumably taken from the reference table; confirm.
# NOTE(review): "aws_us_east_1_gov" maps to the same value as "aws_us_east_1";
# confirm this is intended.
SNOWFLAKE_REGION_ID_MAP = {
# AWS regions
"aws_us_west_2": "us-west-2",
"aws_us_gov_west_1": "us-gov-west-1",
"aws_us_gov_west_1_fhplus": "fhplus.us-gov-west-1.aws",
"aws_us_east_2": "us-east-2",
"aws_us_east_1": "us-east-1",
"aws_us_east_1_gov": "us-east-1",
"aws_ca_central_1": "ca-central-1",
"aws_sa_east_1": "sa-east-1",
"aws_eu_west_2": "eu-west-2",
"aws_eu_west_3": "eu-west-3",
"aws_eu_central_1": "eu-central-1",
"aws_eu_north_1": "eu-north-1",
"aws_ap_northeast_1": "ap-northeast-1",
"aws_ap_northeast_2": "ap-northeast-2",
"aws_ap_northeast_3": "ap-northeast-3",
"aws_ap_south_1": "ap-south-1",
"aws_ap_southeast_1": "ap-southeast-1",
"aws_ap_southeast_2": "ap-southeast-2",
"aws_ap_southeast_3": "ap-southeast-3",
# GCP regions
"gcp_us_central1": "us-central1",
"gcp_us_east4": "us-east4",
"gcp_europe_west2": "europe-west2",
# Azure regions
"azure_westus2": "westus2",
"azure_centralus": "centralus",
"azure_southcentralus": "southcentralus",
"azure_eastus2": "eastus2",
"azure_usgovvirginia": "usgovvirginia",
"azure_canadacentral": "canadacentral",
"azure_uksouth": "uk-south",
"azure_northeurope": "northeurope",
"azure_westeurope": "westeurope",
"azure_switzerlandnorth": "switzerlandnorth",
"azure_uaenorth": "uaenorth",
"azure_centralindia": "central-india.azure",
"azure_japaneast": "japaneast",
"azure_southeastasia": "southeastasia",
"azure_australiaeast": "australiaeast",
}

View File

@ -13,10 +13,9 @@ Snowflake source module
"""
import json
import traceback
from typing import Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
import sqlparse
from requests.utils import quote
from snowflake.sqlalchemy.custom_types import VARIANT
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect, ischema_names
from sqlalchemy.engine.reflection import Inspector
@ -53,9 +52,6 @@ from metadata.ingestion.source.database.common_db_source import (
from metadata.ingestion.source.database.life_cycle_query_mixin import (
LifeCycleQueryMixin,
)
from metadata.ingestion.source.database.snowflake.constants import (
SNOWFLAKE_REGION_ID_MAP,
)
from metadata.ingestion.source.database.snowflake.models import (
STORED_PROC_LANGUAGE_MAP,
SnowflakeStoredProcedure,
@ -64,9 +60,9 @@ from metadata.ingestion.source.database.snowflake.queries import (
SNOWFLAKE_FETCH_ALL_TAGS,
SNOWFLAKE_GET_CLUSTER_KEY,
SNOWFLAKE_GET_CURRENT_ACCOUNT,
SNOWFLAKE_GET_CURRENT_REGION,
SNOWFLAKE_GET_DATABASE_COMMENTS,
SNOWFLAKE_GET_DATABASES,
SNOWFLAKE_GET_ORGANIZATION_NAME,
SNOWFLAKE_GET_SCHEMA_COMMENTS,
SNOWFLAKE_GET_STORED_PROCEDURE_QUERIES,
SNOWFLAKE_GET_STORED_PROCEDURES,
@ -139,7 +135,7 @@ class SnowflakeSource(LifeCycleQueryMixin, StoredProcedureMixin, CommonDbSourceS
self.database_desc_map = {}
self._account: Optional[str] = None
self._region: Optional[str] = None
self._org_name: Optional[str] = None
@classmethod
def create(cls, config_dict, metadata: OpenMetadata):
@ -153,32 +149,25 @@ class SnowflakeSource(LifeCycleQueryMixin, StoredProcedureMixin, CommonDbSourceS
@property
def account(self) -> Optional[str]:
"""Query the account information"""
"""
Query the account information
ref https://docs.snowflake.com/en/sql-reference/functions/current_account_name
"""
if self._account is None:
self._account = self._get_current_account()
return self._account
@property
def region(self) -> Optional[str]:
def org_name(self) -> Optional[str]:
"""
Query the region information
Region id can be a vanilla id like "AWS_US_WEST_2"
and in case of multi region group it can be like "PUBLIC.AWS_US_WEST_2"
in such cases this method will extract vanilla region id and return the
region name from constant map SNOWFLAKE_REGION_ID_MAP
for more info checkout this doc:
https://docs.snowflake.com/en/sql-reference/functions/current_region
Query the Organization information.
ref https://docs.snowflake.com/en/sql-reference/functions/current_organization_name
"""
if self._region is None:
raw_region = self._get_current_region()
if raw_region:
clean_region_id = raw_region.split(".")[-1]
self._region = SNOWFLAKE_REGION_ID_MAP.get(clean_region_id.lower())
if self._org_name is None:
self._org_name = self._get_org_name()
return self._region
return self._org_name
def set_session_query_tag(self) -> None:
"""
@ -418,14 +407,14 @@ class SnowflakeSource(LifeCycleQueryMixin, StoredProcedureMixin, CommonDbSourceS
return table_list
def _get_current_region(self) -> Optional[str]:
def _get_org_name(self) -> Optional[str]:
try:
res = self.engine.execute(SNOWFLAKE_GET_CURRENT_REGION).one()
res = self.engine.execute(SNOWFLAKE_GET_ORGANIZATION_NAME).one()
if res:
return res.REGION
return res.NAME
except Exception as exc:
logger.debug(traceback.format_exc())
logger.debug(f"Failed to fetch current region due to: {exc}")
logger.debug(f"Failed to fetch Organization name due to: {exc}")
return None
def _get_current_account(self) -> Optional[str]:
@ -442,7 +431,7 @@ class SnowflakeSource(LifeCycleQueryMixin, StoredProcedureMixin, CommonDbSourceS
self, database_name: Optional[str] = None, schema_name: Optional[str] = None
) -> str:
url = (
f"https://app.snowflake.com/{self.region.lower()}"
f"https://app.snowflake.com/{self.org_name.lower()}"
f"/{self.account.lower()}/#/data/databases/{database_name}"
)
if schema_name:
@ -461,7 +450,7 @@ class SnowflakeSource(LifeCycleQueryMixin, StoredProcedureMixin, CommonDbSourceS
Method to get the source url for snowflake
"""
try:
if self.account and self.region:
if self.account and self.org_name:
tab_type = "view" if table_type == TableType.View else "table"
url = self._get_source_url_root(
database_name=database_name, schema_name=schema_name
@ -542,7 +531,7 @@ class SnowflakeSource(LifeCycleQueryMixin, StoredProcedureMixin, CommonDbSourceS
schema_name=self.context.database_schema.name.__root__,
)
+ f"/procedure/{stored_procedure.name}"
+ f"{quote(stored_procedure.signature) if stored_procedure.signature else ''}"
+ f"{stored_procedure.signature if stored_procedure.signature else ''}"
),
)
)
@ -555,29 +544,19 @@ class SnowflakeSource(LifeCycleQueryMixin, StoredProcedureMixin, CommonDbSourceS
)
)
def get_stored_procedure_queries(self) -> Iterable[QueryByProcedure]:
def get_stored_procedure_queries_dict(self) -> Dict[str, List[QueryByProcedure]]:
"""
Pick the stored procedure name from the context
and return the list of associated queries
Return the dictionary associating stored procedures to the
queries they triggered
"""
# Only process if we have actually yielded a stored procedure
if self.context.stored_procedure:
start, _ = get_start_and_end(self.source_config.queryLogDuration)
query = SNOWFLAKE_GET_STORED_PROCEDURE_QUERIES.format(
start_date=start,
warehouse=self.service_connection.warehouse,
schema_name=self.context.database_schema.name.__root__,
database_name=self.context.database.name.__root__,
)
start, _ = get_start_and_end(self.source_config.queryLogDuration)
query = SNOWFLAKE_GET_STORED_PROCEDURE_QUERIES.format(
start_date=start,
warehouse=self.service_connection.warehouse,
)
queries_dict = self.procedure_queries_dict(
query=query,
schema_name=self.context.database_schema.name.__root__,
database_name=self.context.database.name.__root__,
)
queries_dict = self.procedure_queries_dict(
query=query,
)
for query_by_procedure in (
queries_dict.get(self.context.stored_procedure.name.__root__.lower())
or []
):
yield query_by_procedure
return queries_dict

View File

@ -14,6 +14,7 @@ Snowflake models
from typing import Optional
from pydantic import BaseModel, Field, validator
from requests.utils import quote
from metadata.generated.schema.entity.data.storedProcedure import Language
from metadata.utils.logger import ingestion_logger
@ -51,7 +52,8 @@ class SnowflakeStoredProcedure(BaseModel):
A signature may look like `(TABLE_NAME VARCHAR, NAME VARCHAR)`
We want to keep only `(VARCHAR, VARCHAR)`.
This is needed to build the source URL of the procedure
This is needed to build the source URL of the procedure, so we'll
directly parse the quoted signature
"""
try:
clean_signature = signature.replace("(", "").replace(")", "")
@ -61,7 +63,7 @@ class SnowflakeStoredProcedure(BaseModel):
signature_list = clean_signature.split(",")
clean_signature_list = [elem.split(" ")[-1] for elem in signature_list]
return f"({','.join(clean_signature_list)})"
return f"({quote(', '.join(clean_signature_list))})"
except Exception as exc:
logger.warning(f"Error cleaning up Stored Procedure signature - [{exc}]")
return signature

View File

@ -143,9 +143,9 @@ SELECT /* sqlalchemy:_get_schema_columns */
ORDER BY ic.ordinal_position
"""
SNOWFLAKE_GET_CURRENT_REGION = "SELECT CURRENT_REGION() AS region"
SNOWFLAKE_GET_ORGANIZATION_NAME = "SELECT CURRENT_ORGANIZATION_NAME() AS NAME"
SNOWFLAKE_GET_CURRENT_ACCOUNT = "SELECT CURRENT_ACCOUNT() AS account"
SNOWFLAKE_GET_CURRENT_ACCOUNT = "SELECT CURRENT_ACCOUNT_NAME() AS ACCOUNT"
SNOWFLAKE_LIFE_CYCLE_QUERY = textwrap.dedent(
"""
@ -186,8 +186,6 @@ WITH SP_HISTORY AS (
WHERE QUERY_TYPE = 'CALL'
AND START_TIME >= '{start_date}'
AND WAREHOUSE_NAME = '{warehouse}'
AND SCHEMA_NAME = '{schema_name}'
AND DATABASE_NAME = '{database_name}'
),
Q_HISTORY AS (
SELECT
@ -198,20 +196,22 @@ Q_HISTORY AS (
START_TIME,
END_TIME,
TOTAL_ELAPSED_TIME/1000 AS DURATION,
USER_NAME
USER_NAME,
SCHEMA_NAME,
DATABASE_NAME
FROM SNOWFLAKE.ACCOUNT_USAGE.QUERY_HISTORY SP
WHERE QUERY_TYPE <> 'CALL'
AND QUERY_TEXT NOT LIKE '/* {{"app": "OpenMetadata", %%}} */%%'
AND QUERY_TEXT NOT LIKE '/* {{"app": "dbt", %%}} */%%'
AND START_TIME >= '{start_date}'
AND WAREHOUSE_NAME = '{warehouse}'
AND SCHEMA_NAME = '{schema_name}'
AND DATABASE_NAME = '{database_name}'
)
SELECT
SP.QUERY_ID AS PROCEDURE_ID,
Q.QUERY_ID AS QUERY_ID,
Q.QUERY_TYPE AS QUERY_TYPE,
Q.DATABASE_NAME AS QUERY_DATABASE_NAME,
Q.SCHEMA_NAME AS QUERY_SCHEMA_NAME,
SP.QUERY_TEXT AS PROCEDURE_TEXT,
SP.START_TIME AS PROCEDURE_START_TIME,
SP.END_TIME AS PROCEDURE_END_TIME,

View File

@ -13,16 +13,17 @@ Mixin class with common Stored Procedures logic aimed at lineage.
"""
import re
import traceback
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime
from functools import lru_cache
from typing import Dict, Iterable, List, Optional
from typing import Dict, Iterable, List, Optional, Union
from pydantic import BaseModel, Field
from sqlalchemy.engine import Engine
from metadata.generated.schema.api.data.createQuery import CreateQueryRequest
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
from metadata.generated.schema.entity.data.storedProcedure import StoredProcedure
from metadata.generated.schema.metadataIngestion.databaseServiceMetadataPipeline import (
DatabaseServiceMetadataPipeline,
)
@ -35,9 +36,12 @@ from metadata.ingestion.lineage.models import ConnectionTypeDialectMapper
from metadata.ingestion.lineage.sql_lineage import get_lineage_by_query
from metadata.ingestion.models.topology import TopologyContext
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.logger import ingestion_logger
from metadata.utils.stored_procedures import get_procedure_name_from_call
from metadata.utils.time_utils import convert_timestamp_to_milliseconds
logger = ingestion_logger()
class QueryByProcedure(BaseModel):
"""
@ -47,6 +51,8 @@ class QueryByProcedure(BaseModel):
procedure_id: str = Field(..., alias="PROCEDURE_ID")
query_id: str = Field(..., alias="QUERY_ID")
query_type: str = Field(..., alias="QUERY_TYPE")
query_database_name: str = Field(None, alias="QUERY_DATABASE_NAME")
query_schema_name: str = Field(None, alias="QUERY_SCHEMA_NAME")
procedure_text: str = Field(..., alias="PROCEDURE_TEXT")
procedure_start_time: datetime = Field(..., alias="PROCEDURE_START_TIME")
procedure_end_time: datetime = Field(..., alias="PROCEDURE_END_TIME")
@ -59,7 +65,7 @@ class QueryByProcedure(BaseModel):
allow_population_by_field_name = True
class StoredProcedureMixin:
class StoredProcedureMixin(ABC):
"""
The full flow is:
1. List Stored Procedures
@ -79,12 +85,14 @@ class StoredProcedureMixin:
engine: Engine
metadata: OpenMetadata
@lru_cache(
maxsize=1
) # Limit the caching since it cannot be repeated due to the topology ordering
def procedure_queries_dict(
self, query: str, schema_name: str, database_name: str
) -> Dict[str, List[QueryByProcedure]]:
@abstractmethod
def get_stored_procedure_queries_dict(self) -> Dict[str, List[QueryByProcedure]]:
"""
Return the dictionary associating stored procedures to the
queries they triggered
"""
def procedure_queries_dict(self, query: str) -> Dict[str, List[QueryByProcedure]]:
"""
Cache the queries run for the stored procedures in the last `queryLogDuration` days.
@ -100,8 +108,6 @@ class StoredProcedureMixin:
query_by_procedure = QueryByProcedure.parse_obj(dict(row))
procedure_name = get_procedure_name_from_call(
query_text=query_by_procedure.procedure_text,
schema_name=schema_name,
database_name=database_name,
)
queries_dict[procedure_name].append(query_by_procedure)
except Exception as exc:
@ -130,21 +136,23 @@ class StoredProcedureMixin:
return False
def yield_procedure_lineage(
self, query_by_procedure: QueryByProcedure
self, query_by_procedure: QueryByProcedure, procedure: StoredProcedure
) -> Iterable[Either[AddLineageRequest]]:
"""Add procedure lineage from its query"""
self.context.stored_procedure_query_lineage = False
if self.is_lineage_query(
query_type=query_by_procedure.query_type,
query_text=query_by_procedure.query_text,
):
self.context.stored_procedure_query_lineage = True
for either_lineage in get_lineage_by_query(
self.metadata,
query=query_by_procedure.query_text,
service_name=self.context.database_service.name.__root__,
database_name=self.context.database.name.__root__,
schema_name=self.context.database_schema.name.__root__,
database_name=query_by_procedure.query_database_name,
schema_name=query_by_procedure.query_schema_name,
dialect=ConnectionTypeDialectMapper.dialect_of(
self.context.database_service.serviceType.value
),
@ -153,14 +161,14 @@ class StoredProcedureMixin:
):
if either_lineage.right.edge.lineageDetails:
either_lineage.right.edge.lineageDetails.pipeline = EntityReference(
id=self.context.stored_procedure.id,
id=procedure.id,
type="storedProcedure",
)
yield either_lineage
def yield_procedure_query(
self, query_by_procedure: QueryByProcedure
self, query_by_procedure: QueryByProcedure, procedure: StoredProcedure
) -> Iterable[Either[CreateQueryRequest]]:
"""Check the queries triggered by the procedure and add their lineage, if any"""
@ -175,10 +183,30 @@ class StoredProcedureMixin:
)
),
triggeredBy=EntityReference(
id=self.context.stored_procedure.id,
id=procedure.id,
type="storedProcedure",
),
processedLineage=bool(self.context.stored_procedure_query_lineage),
service=self.context.database_service.name.__root__,
)
)
def yield_procedure_lineage_and_queries(
    self,
) -> Iterable[Either[Union[AddLineageRequest, CreateQueryRequest]]]:
    """
    Emit lineage and query requests for every stored procedure in the context.

    Nothing is produced when `self.context.stored_procedures` is empty.
    Otherwise the query history is fetched once through
    `get_stored_procedure_queries_dict` and, for each procedure, every query
    stored under its lower-cased name is handed first to
    `yield_procedure_lineage` and then to `yield_procedure_query`.
    """
    if not self.context.stored_procedures:
        return

    logger.info("Processing Lineage for Stored Procedures")

    # Single fetch of the whole query history, reused for all procedures
    history = self.get_stored_procedure_queries_dict()

    for procedure in self.context.stored_procedures:
        # The history is looked up by the lower-cased procedure name;
        # a missing (or empty) entry simply means nothing to process
        triggered = history.get(procedure.name.__root__.lower()) or []
        for query_by_procedure in triggered:
            yield from self.yield_procedure_lineage(
                query_by_procedure=query_by_procedure, procedure=procedure
            )
            yield from self.yield_procedure_query(
                query_by_procedure=query_by_procedure, procedure=procedure
            )

View File

@ -23,7 +23,7 @@ NAME_PATTERN = r"(?<=call)(.*)(?=\()"
def get_procedure_name_from_call(
query_text: str, schema_name: str, database_name: str, sensitive_match: bool = False
query_text: str, sensitive_match: bool = False
) -> Optional[str]:
"""
In the query text we'll have:
@ -47,8 +47,8 @@ def get_procedure_name_from_call(
res.group(0) # Get the first match
.strip() # Remove whitespace
.lower() # Replace all the lowercase variants of the procedure name prefixes
.replace(f"{database_name.lower()}.", "")
.replace(f"{schema_name.lower()}.", "")
.replace("`", "") # Clean weird characters from escaping the SQL
.split(".")[-1]
)
except Exception as exc:
logger.warning(

View File

@ -74,8 +74,8 @@ MOCK_SCHEMA_NAME_1 = "INFORMATION_SCHEMA"
MOCK_SCHEMA_NAME_2 = "TPCDS_SF10TCL"
MOCK_VIEW_NAME = "COLUMNS"
MOCK_TABLE_NAME = "CALL_CENTER"
EXPECTED_SNOW_URL_VIEW = "https://app.snowflake.com/us-west-2/random_account/#/data/databases/SNOWFLAKE_SAMPLE_DATA/schemas/INFORMATION_SCHEMA/view/COLUMNS"
EXPECTED_SNOW_URL_TABLE = "https://app.snowflake.com/us-west-2/random_account/#/data/databases/SNOWFLAKE_SAMPLE_DATA/schemas/TPCDS_SF10TCL/table/CALL_CENTER"
EXPECTED_SNOW_URL_VIEW = "https://app.snowflake.com/random_org/random_account/#/data/databases/SNOWFLAKE_SAMPLE_DATA/schemas/INFORMATION_SCHEMA/view/COLUMNS"
EXPECTED_SNOW_URL_TABLE = "https://app.snowflake.com/random_org/random_account/#/data/databases/SNOWFLAKE_SAMPLE_DATA/schemas/TPCDS_SF10TCL/table/CALL_CENTER"
class SnowflakeUnitTest(TestCase):
@ -135,15 +135,15 @@ class SnowflakeUnitTest(TestCase):
):
with patch.object(
SnowflakeSource,
"region",
return_value="us-west-2",
"org_name",
return_value="random_org",
new_callable=PropertyMock,
):
self._assert_urls()
with patch.object(
SnowflakeSource,
"region",
"org_name",
new_callable=PropertyMock,
return_value=None,
):

View File

@ -24,8 +24,6 @@ class StoredProceduresTests(TestCase):
self.assertEquals(
get_procedure_name_from_call(
query_text="CALL db.schema.procedure_name(...)",
schema_name="schema",
database_name="db",
),
"procedure_name",
)
@ -33,8 +31,6 @@ class StoredProceduresTests(TestCase):
self.assertEquals(
get_procedure_name_from_call(
query_text="CALL schema.procedure_name(...)",
schema_name="schema",
database_name="db",
),
"procedure_name",
)
@ -42,8 +38,6 @@ class StoredProceduresTests(TestCase):
self.assertEquals(
get_procedure_name_from_call(
query_text="CALL procedure_name(...)",
schema_name="schema",
database_name="db",
),
"procedure_name",
)
@ -51,8 +45,6 @@ class StoredProceduresTests(TestCase):
self.assertEquals(
get_procedure_name_from_call(
query_text="CALL DB.SCHEMA.PROCEDURE_NAME(...)",
schema_name="SCHEMA",
database_name="DB",
),
"procedure_name",
)
@ -60,7 +52,5 @@ class StoredProceduresTests(TestCase):
self.assertIsNone(
get_procedure_name_from_call(
query_text="something very random",
schema_name="schema",
database_name="db",
)
)