diff --git a/ingestion/src/metadata/ingestion/source/snowflake.py b/ingestion/src/metadata/ingestion/source/snowflake.py
index 5e8fcd0ac69..285c7fa28a5 100644
--- a/ingestion/src/metadata/ingestion/source/snowflake.py
+++ b/ingestion/src/metadata/ingestion/source/snowflake.py
@@ -10,7 +10,7 @@
 # limitations under the License.
 import logging
 import os
-from typing import Optional
+from typing import Iterable, Optional
 
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
@@ -18,16 +18,21 @@ from cryptography.hazmat.primitives.asymmetric import dsa, rsa
 from snowflake.sqlalchemy.custom_types import VARIANT
 from snowflake.sqlalchemy.snowdialect import SnowflakeDialect, ischema_names
 from sqlalchemy.engine import reflection
+from sqlalchemy.engine.reflection import Inspector
+from sqlalchemy.inspection import inspect
 from sqlalchemy.sql import text
 
+from metadata.generated.schema.entity.data.database import Database
 from metadata.generated.schema.entity.data.table import TableData
 from metadata.generated.schema.entity.services.databaseService import (
     DatabaseServiceType,
 )
+from metadata.generated.schema.type.entityReference import EntityReference
 from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
 from metadata.ingestion.source.sql_source import SQLSource
 from metadata.ingestion.source.sql_source_common import SQLConnectionConfig
 from metadata.utils.column_type_parser import create_sqlalchemy_type
+from metadata.utils.engines import get_engine
 
 GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
 ischema_names["VARIANT"] = VARIANT
@@ -39,8 +44,8 @@ logger: logging.Logger = logging.getLogger(__name__)
 class SnowflakeConfig(SQLConnectionConfig):
     scheme = "snowflake"
     account: str
-    database: str
-    warehouse: str
+    database: Optional[str]
+    warehouse: Optional[str]
     result_limit: int = 1000
     role: Optional[str]
     duration: Optional[int]
@@ -76,6 +81,31 @@ class SnowflakeSource(SQLSource):
             config.connect_args["private_key"] = pkb
         super().__init__(config, metadata_config, ctx)
 
+    def get_databases(self) -> Iterable[Inspector]:
+        if self.config.database != None:
+            yield from super().get_databases()
+        else:
+            query = "SHOW DATABASES"
+            results = self.connection.execute(query)
+            for res in results:
+
+                row = list(res)
+                use_db_query = f"USE DATABASE {row[1]}"
+                self.connection.execute(use_db_query)
+                logger.info(f"Ingesting from database: {row[1]}")
+                self.config.database = row[1]
+                self.engine = get_engine(self.config)
+                yield inspect(self.engine)
+
+    def get_table_fqn(self, service_name, schema, table_name) -> str:
+        return f"{service_name}.{self.config.database}_{schema}.{table_name}"
+
+    def _get_database(self, schema: str) -> Database:
+        return Database(
+            name=self.config.database + "_" + schema.replace(".", "_DOT_"),
+            service=EntityReference(id=self.service.id, type=self.config.service_type),
+        )
+
     def fetch_sample_data(self, schema: str, table: str) -> Optional[TableData]:
         resp_sample_data = super().fetch_sample_data(schema, table)
         if not resp_sample_data:
@@ -116,4 +146,10 @@ def _get_table_comment(self, connection, table_name, schema=None, **kw):
     return cursor.fetchone()
 
 # pylint: disable=protected-access
+@reflection.cache
+def get_unique_constraints(self, connection, table_name, schema=None, **kw):
+    return []
+
+
 SnowflakeDialect._get_table_comment = _get_table_comment
+SnowflakeDialect.get_unique_constraints = get_unique_constraints
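The intent of the new SnowflakeSource.get_databases / get_table_fqn overrides can be sketched in isolation: when config.database is unset, each database returned by SHOW DATABASES is made the active database in turn, a fresh engine/inspector is built for it, and table FQNs fold the database name into the schema segment. The snippet below is a standalone illustration with stand-in objects (FakeConfig, the hard-coded rows), not code from the patch, and it assumes row[1] carries the database name as in the override above.

```python
# Standalone sketch (not from the patch): FakeConfig and the sample rows are made up.
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple


@dataclass
class FakeConfig:
    service_name: str
    database: Optional[str] = None


def discover_databases(rows: List[Tuple[str, str]], config: FakeConfig) -> Iterable[str]:
    # Mirrors the SHOW DATABASES branch: row[1] is assumed to carry the database name.
    for row in rows:
        config.database = row[1]
        yield config.database


def table_fqn(service_name: str, database: str, schema: str, table: str) -> str:
    # Matches the overridden get_table_fqn: the database is folded into the schema segment.
    return f"{service_name}.{database}_{schema}.{table}"


rows = [("2021-01-01", "ANALYTICS"), ("2021-01-01", "RAW")]
config = FakeConfig(service_name="snowflake_prod")
for db in discover_databases(rows, config):
    print(table_fqn(config.service_name, db, "PUBLIC", "ORDERS"))
# snowflake_prod.ANALYTICS_PUBLIC.ORDERS
# snowflake_prod.RAW_PUBLIC.ORDERS
```

Note that because the override always reads self.config.database, Snowflake FQNs follow the service.<database>_<schema>.<table> scheme even when a single database is configured explicitly.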
diff --git a/ingestion/src/metadata/ingestion/source/sql_source.py b/ingestion/src/metadata/ingestion/source/sql_source.py
index 584e7d58c4c..084267b470a 100644
--- a/ingestion/src/metadata/ingestion/source/sql_source.py
+++ b/ingestion/src/metadata/ingestion/source/sql_source.py
@@ -196,23 +196,29 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
             logger.error(f"Failed to generate sample data for {table} - {err}")
         return None
 
-    def next_record(self) -> Iterable[Entity]:
-        inspector = inspect(self.engine)
-        schema_names = inspector.get_schema_names()
+    def get_databases(self) -> Iterable[Inspector]:
+        yield inspect(self.engine)
 
-        for schema in schema_names:
-            # clear any previous source database state
-            self.database_source_state.clear()
-            if not self.sql_config.schema_filter_pattern.included(schema):
-                self.status.filter(schema, "Schema pattern not allowed")
-                continue
-            if self.config.include_tables:
-                yield from self.fetch_tables(inspector, schema)
-            if self.config.include_views:
-                yield from self.fetch_views(inspector, schema)
-            if self.config.mark_deleted_tables_as_deleted:
-                schema_fqdn = f"{self.config.service_name}.{schema}"
-                yield from self.delete_tables(schema_fqdn)
+    def get_table_fqn(self, service_name, schema, table_name) -> str:
+        return f"{service_name}.{schema}.{table_name}"
+
+    def next_record(self) -> Iterable[Entity]:
+        inspectors = self.get_databases()
+        for inspector in inspectors:
+            schema_names = inspector.get_schema_names()
+            for schema in schema_names:
+                # clear any previous source database state
+                self.database_source_state.clear()
+                if not self.sql_config.schema_filter_pattern.included(schema):
+                    self.status.filter(schema, "Schema pattern not allowed")
+                    continue
+                if self.config.include_tables:
+                    yield from self.fetch_tables(inspector, schema)
+                if self.config.include_views:
+                    yield from self.fetch_views(inspector, schema)
+                if self.config.mark_deleted_tables_as_deleted:
+                    schema_fqdn = f"{self.config.service_name}.{schema}"
+                    yield from self.delete_tables(schema_fqdn)
 
     def fetch_tables(
         self, inspector: Inspector, schema: str
@@ -240,7 +246,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
                     )
                     continue
                 description = _get_table_description(schema, table_name, inspector)
-                fqn = f"{self.config.service_name}.{schema}.{table_name}"
+                fqn = self.get_table_fqn(self.config.service_name, schema, table_name)
                 self.database_source_state.add(fqn)
                 self.table_constraints = None
                 table_columns = self._get_columns(schema, table_name, inspector)
@@ -322,7 +328,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
                 )
             except NotImplementedError:
                 view_definition = ""
-            fqn = f"{self.config.service_name}.{schema}.{view_name}"
+            fqn = self.get_table_fqn(self.config.service_name, schema, view_name)
            self.database_source_state.add(fqn)
             table = Table(
                 id=uuid.uuid4(),
@@ -425,7 +431,9 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
                 try:
                     _, database, table = node.split(".", 2)
                     table = table.replace(".", "_DOT_")
-                    table_fqn = f"{self.config.service_name}.{database}.{table}"
+                    table_fqn = self.get_table_fqn(
+                        self.config.service_name, database, table
+                    )
                     upstream_nodes.append(table_fqn)
                 except Exception as err:  # pylint: disable=broad-except
                     logger.error(
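A rough sketch of the generator layering that the refactored next_record introduces: get_databases yields one inspector per database, the base implementation yields exactly one, and the per-schema loop is otherwise unchanged. The "inspectors" here are plain lists of schema names so the example stays self-contained; the patch itself yields real SQLAlchemy Inspector objects, and BaseSource / MultiDatabaseSource are hypothetical stand-ins for SQLSource and a multi-database subclass such as SnowflakeSource.

```python
# Sketch of the control flow only; "inspectors" are plain lists of schema names here,
# whereas the patch yields real SQLAlchemy Inspector objects.
from typing import Iterable, List


class BaseSource:
    def __init__(self, schemas: List[str]):
        self.schemas = schemas

    def get_databases(self) -> Iterable[List[str]]:
        # Default behaviour: a single "inspector" for the one configured engine.
        yield self.schemas

    def next_record(self) -> Iterable[str]:
        for inspector in self.get_databases():   # outer loop added by the patch
            for schema in inspector:             # unchanged per-schema processing
                yield f"ingested schema: {schema}"


class MultiDatabaseSource(BaseSource):
    def __init__(self, schemas_by_database: List[List[str]]):
        super().__init__([])
        self.schemas_by_database = schemas_by_database

    def get_databases(self) -> Iterable[List[str]]:
        # A source such as Snowflake can yield one inspector per database.
        yield from self.schemas_by_database


print(list(BaseSource(["public"]).next_record()))
# ['ingested schema: public']
print(list(MultiDatabaseSource([["public", "sales"], ["staging"]]).next_record()))
# ['ingested schema: public', 'ingested schema: sales', 'ingested schema: staging']
```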