Fix #3118: Allow multiple databases for snowflake (#3464)

This commit is contained in:
Mayur Singal 2022-03-21 23:01:17 +05:30 committed by GitHub
parent 548a0ab722
commit d6d9afa8be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 22 deletions

View File

@ -10,7 +10,7 @@
# limitations under the License.
import logging
import os
from typing import Optional
from typing import Iterable, Optional
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
@ -18,16 +18,21 @@ from cryptography.hazmat.primitives.asymmetric import dsa, rsa
from snowflake.sqlalchemy.custom_types import VARIANT
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect, ischema_names
from sqlalchemy.engine import reflection
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.inspection import inspect
from sqlalchemy.sql import text
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.table import TableData
from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType,
)
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.source.sql_source import SQLSource
from metadata.ingestion.source.sql_source_common import SQLConnectionConfig
from metadata.utils.column_type_parser import create_sqlalchemy_type
from metadata.utils.engines import get_engine
GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
ischema_names["VARIANT"] = VARIANT
@ -39,8 +44,8 @@ logger: logging.Logger = logging.getLogger(__name__)
class SnowflakeConfig(SQLConnectionConfig):
scheme = "snowflake"
account: str
database: str
warehouse: str
database: Optional[str]
warehouse: Optional[str]
result_limit: int = 1000
role: Optional[str]
duration: Optional[int]
@ -76,6 +81,31 @@ class SnowflakeSource(SQLSource):
config.connect_args["private_key"] = pkb
super().__init__(config, metadata_config, ctx)
def get_databases(self) -> Iterable[Inspector]:
    """Yield one SQLAlchemy Inspector per database to ingest.

    If the config pins a specific database, fall back to the parent
    single-database behaviour. Otherwise enumerate every database the
    connection can see via ``SHOW DATABASES``, switch the connection and
    engine to each in turn, and yield a fresh inspector for it.

    Note: this mutates ``self.config.database`` and ``self.engine`` as a
    side effect while iterating.
    """
    if self.config.database is not None:
        # A database was explicitly configured: single-database ingestion.
        yield from super().get_databases()
        return
    results = self.connection.execute("SHOW DATABASES")
    for res in results:
        row = list(res)
        database_name = row[1]  # SHOW DATABASES: column 1 is the database name
        # NOTE(review): the name is interpolated unquoted into SQL; identifiers
        # containing special characters would need quoting — confirm upstream.
        self.connection.execute(f"USE DATABASE {database_name}")
        # Lazy %-formatting so the message is only built if INFO is enabled.
        logger.info("Ingesting from database: %s", database_name)
        self.config.database = database_name
        self.engine = get_engine(self.config)
        yield inspect(self.engine)
def get_table_fqn(self, service_name, schema, table_name) -> str:
    """Build the fully qualified table name, folding the configured
    database into the schema segment as ``<database>_<schema>``."""
    database_schema = f"{self.config.database}_{schema}"
    return ".".join((service_name, database_schema, table_name))
def _get_database(self, schema: str) -> Database:
    """Create a Database entity named ``<database>_<schema>``, escaping any
    dots in the schema name as ``_DOT_`` so the FQN stays unambiguous."""
    escaped_schema = schema.replace(".", "_DOT_")
    service_ref = EntityReference(
        id=self.service.id, type=self.config.service_type
    )
    return Database(
        name=self.config.database + "_" + escaped_schema,
        service=service_ref,
    )
def fetch_sample_data(self, schema: str, table: str) -> Optional[TableData]:
resp_sample_data = super().fetch_sample_data(schema, table)
if not resp_sample_data:
@ -116,4 +146,10 @@ def _get_table_comment(self, connection, table_name, schema=None, **kw):
return cursor.fetchone() # pylint: disable=protected-access
@reflection.cache
def get_unique_constraints(self, connection, table_name, schema=None, **kw):
    # Patched onto SnowflakeDialect below; Snowflake exposes no unique
    # constraints via this reflection hook, so return an empty list rather
    # than letting reflection fail.
    return []
# Monkey-patch the third-party Snowflake dialect: use the cached
# table-comment lookup defined above and stub out unique-constraint
# reflection, which Snowflake does not support.
SnowflakeDialect._get_table_comment = _get_table_comment
SnowflakeDialect.get_unique_constraints = get_unique_constraints

View File

@ -196,23 +196,29 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
logger.error(f"Failed to generate sample data for {table} - {err}")
return None
def next_record(self) -> Iterable[Entity]:
inspector = inspect(self.engine)
schema_names = inspector.get_schema_names()
def get_databases(self) -> Iterable[Inspector]:
    """Default single-database behaviour: yield one inspector over the
    already-configured engine. Subclasses (e.g. Snowflake) override this
    to iterate several databases."""
    engine_inspector = inspect(self.engine)
    yield engine_inspector
for schema in schema_names:
# clear any previous source database state
self.database_source_state.clear()
if not self.sql_config.schema_filter_pattern.included(schema):
self.status.filter(schema, "Schema pattern not allowed")
continue
if self.config.include_tables:
yield from self.fetch_tables(inspector, schema)
if self.config.include_views:
yield from self.fetch_views(inspector, schema)
if self.config.mark_deleted_tables_as_deleted:
schema_fqdn = f"{self.config.service_name}.{schema}"
yield from self.delete_tables(schema_fqdn)
def get_table_fqn(self, service_name, schema, table_name) -> str:
    """Join service, schema and table into a dotted fully qualified name."""
    return ".".join((service_name, schema, table_name))
def next_record(self) -> Iterable[Entity]:
    """Walk every database inspector and, per schema, emit table records,
    view records and deletion records according to the source config."""
    for inspector in self.get_databases():
        for schema_name in inspector.get_schema_names():
            # Reset the per-schema set of tables seen this pass.
            self.database_source_state.clear()
            if not self.sql_config.schema_filter_pattern.included(schema_name):
                self.status.filter(schema_name, "Schema pattern not allowed")
                continue
            if self.config.include_tables:
                yield from self.fetch_tables(inspector, schema_name)
            if self.config.include_views:
                yield from self.fetch_views(inspector, schema_name)
            if self.config.mark_deleted_tables_as_deleted:
                yield from self.delete_tables(
                    f"{self.config.service_name}.{schema_name}"
                )
def fetch_tables(
self, inspector: Inspector, schema: str
@ -240,7 +246,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
)
continue
description = _get_table_description(schema, table_name, inspector)
fqn = f"{self.config.service_name}.{schema}.{table_name}"
fqn = self.get_table_fqn(self.config.service_name, schema, table_name)
self.database_source_state.add(fqn)
self.table_constraints = None
table_columns = self._get_columns(schema, table_name, inspector)
@ -322,7 +328,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
)
except NotImplementedError:
view_definition = ""
fqn = f"{self.config.service_name}.{schema}.{view_name}"
fqn = self.get_table_fqn(self.config.service_name, schema, view_name)
self.database_source_state.add(fqn)
table = Table(
id=uuid.uuid4(),
@ -425,7 +431,9 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
try:
_, database, table = node.split(".", 2)
table = table.replace(".", "_DOT_")
table_fqn = f"{self.config.service_name}.{database}.{table}"
table_fqn = self.get_table_fqn(
self.config.service_name, database, table
)
upstream_nodes.append(table_fqn)
except Exception as err: # pylint: disable=broad-except
logger.error(