Fix #3118: Allow multiple databases for Snowflake (#3464)

This commit is contained in:
Mayur Singal 2022-03-21 23:01:17 +05:30 committed by GitHub
parent 548a0ab722
commit d6d9afa8be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 22 deletions

View File

@ -10,7 +10,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import os import os
from typing import Optional from typing import Iterable, Optional
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
@ -18,16 +18,21 @@ from cryptography.hazmat.primitives.asymmetric import dsa, rsa
from snowflake.sqlalchemy.custom_types import VARIANT from snowflake.sqlalchemy.custom_types import VARIANT
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect, ischema_names from snowflake.sqlalchemy.snowdialect import SnowflakeDialect, ischema_names
from sqlalchemy.engine import reflection from sqlalchemy.engine import reflection
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.inspection import inspect
from sqlalchemy.sql import text from sqlalchemy.sql import text
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.table import TableData from metadata.generated.schema.entity.data.table import TableData
from metadata.generated.schema.entity.services.databaseService import ( from metadata.generated.schema.entity.services.databaseService import (
DatabaseServiceType, DatabaseServiceType,
) )
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.source.sql_source import SQLSource from metadata.ingestion.source.sql_source import SQLSource
from metadata.ingestion.source.sql_source_common import SQLConnectionConfig from metadata.ingestion.source.sql_source_common import SQLConnectionConfig
from metadata.utils.column_type_parser import create_sqlalchemy_type from metadata.utils.column_type_parser import create_sqlalchemy_type
from metadata.utils.engines import get_engine
GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY") GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
ischema_names["VARIANT"] = VARIANT ischema_names["VARIANT"] = VARIANT
@ -39,8 +44,8 @@ logger: logging.Logger = logging.getLogger(__name__)
class SnowflakeConfig(SQLConnectionConfig): class SnowflakeConfig(SQLConnectionConfig):
scheme = "snowflake" scheme = "snowflake"
account: str account: str
database: str database: Optional[str]
warehouse: str warehouse: Optional[str]
result_limit: int = 1000 result_limit: int = 1000
role: Optional[str] role: Optional[str]
duration: Optional[int] duration: Optional[int]
@ -76,6 +81,31 @@ class SnowflakeSource(SQLSource):
config.connect_args["private_key"] = pkb config.connect_args["private_key"] = pkb
super().__init__(config, metadata_config, ctx) super().__init__(config, metadata_config, ctx)
def get_databases(self) -> Iterable[Inspector]:
    """Yield one SQLAlchemy Inspector per database to ingest.

    If the config pins a specific database, fall back to the parent
    behaviour (a single inspector over the configured engine). Otherwise
    enumerate every database via ``SHOW DATABASES`` and yield a fresh
    inspector bound to each database in turn.
    """
    # Fix: compare against None with identity (`is not None`), not `!=`.
    if self.config.database is not None:
        yield from super().get_databases()
    else:
        results = self.connection.execute("SHOW DATABASES")
        for res in results:
            row = list(res)
            # Column index 1 of SHOW DATABASES output holds the name.
            database_name = row[1]
            # NOTE(review): the name is interpolated unquoted into the
            # statement; confirm it never needs identifier quoting.
            self.connection.execute(f"USE DATABASE {database_name}")
            # Lazy %-style args avoid formatting when INFO is disabled.
            logger.info("Ingesting from database: %s", database_name)
            # Point the config at the current database so helpers such as
            # get_table_fqn/_get_database build names for this database,
            # then rebuild the engine against it.
            self.config.database = database_name
            self.engine = get_engine(self.config)
            yield inspect(self.engine)
def get_table_fqn(self, service_name, schema, table_name) -> str:
    """Build the fully-qualified table name for this Snowflake source.

    The current database name is folded into the schema segment
    (``<database>_<schema>``) so tables from different databases do not
    collide under a single service.
    """
    qualified_schema = f"{self.config.database}_{schema}"
    return f"{service_name}.{qualified_schema}.{table_name}"
def _get_database(self, schema: str) -> Database:
    """Return a Database entity for *schema* under the current database.

    Dots inside the schema name are escaped as ``_DOT_`` so they do not
    break FQN parsing, and the database name is prefixed to keep schemas
    from different databases distinct.
    """
    escaped_schema = schema.replace(".", "_DOT_")
    service_ref = EntityReference(
        id=self.service.id, type=self.config.service_type
    )
    return Database(
        name=self.config.database + "_" + escaped_schema,
        service=service_ref,
    )
def fetch_sample_data(self, schema: str, table: str) -> Optional[TableData]: def fetch_sample_data(self, schema: str, table: str) -> Optional[TableData]:
resp_sample_data = super().fetch_sample_data(schema, table) resp_sample_data = super().fetch_sample_data(schema, table)
if not resp_sample_data: if not resp_sample_data:
@ -116,4 +146,10 @@ def _get_table_comment(self, connection, table_name, schema=None, **kw):
return cursor.fetchone() # pylint: disable=protected-access return cursor.fetchone() # pylint: disable=protected-access
@reflection.cache
def get_unique_constraints(self, connection, table_name, schema=None, **kw):
    """Report no unique constraints for a Snowflake table.

    The Snowflake dialect does not implement unique-constraint
    reflection; returning an empty list keeps inspector callers working
    instead of raising NotImplementedError.
    """
    constraints = []
    return constraints
SnowflakeDialect._get_table_comment = _get_table_comment SnowflakeDialect._get_table_comment = _get_table_comment
# Monkeypatch the dialect so SQLAlchemy reflection calls our stub instead
# of the dialect's (unimplemented) unique-constraint lookup.
SnowflakeDialect.get_unique_constraints = get_unique_constraints

View File

@ -196,10 +196,16 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
logger.error(f"Failed to generate sample data for {table} - {err}") logger.error(f"Failed to generate sample data for {table} - {err}")
return None return None
def next_record(self) -> Iterable[Entity]: def get_databases(self) -> Iterable[Inspector]:
inspector = inspect(self.engine) yield inspect(self.engine)
schema_names = inspector.get_schema_names()
def get_table_fqn(self, service_name, schema, table_name) -> str:
    """Join service, schema and table into a dotted fully-qualified name.

    Subclasses (e.g. the Snowflake source) may override this to embed
    extra qualifiers such as the database name.
    """
    parts = (service_name, schema, table_name)
    return ".".join(parts)
def next_record(self) -> Iterable[Entity]:
inspectors = self.get_databases()
for inspector in inspectors:
schema_names = inspector.get_schema_names()
for schema in schema_names: for schema in schema_names:
# clear any previous source database state # clear any previous source database state
self.database_source_state.clear() self.database_source_state.clear()
@ -240,7 +246,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
) )
continue continue
description = _get_table_description(schema, table_name, inspector) description = _get_table_description(schema, table_name, inspector)
fqn = f"{self.config.service_name}.{schema}.{table_name}" fqn = self.get_table_fqn(self.config.service_name, schema, table_name)
self.database_source_state.add(fqn) self.database_source_state.add(fqn)
self.table_constraints = None self.table_constraints = None
table_columns = self._get_columns(schema, table_name, inspector) table_columns = self._get_columns(schema, table_name, inspector)
@ -322,7 +328,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
) )
except NotImplementedError: except NotImplementedError:
view_definition = "" view_definition = ""
fqn = f"{self.config.service_name}.{schema}.{view_name}" fqn = self.get_table_fqn(self.config.service_name, schema, view_name)
self.database_source_state.add(fqn) self.database_source_state.add(fqn)
table = Table( table = Table(
id=uuid.uuid4(), id=uuid.uuid4(),
@ -425,7 +431,9 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
try: try:
_, database, table = node.split(".", 2) _, database, table = node.split(".", 2)
table = table.replace(".", "_DOT_") table = table.replace(".", "_DOT_")
table_fqn = f"{self.config.service_name}.{database}.{table}" table_fqn = self.get_table_fqn(
self.config.service_name, database, table
)
upstream_nodes.append(table_fqn) upstream_nodes.append(table_fqn)
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
logger.error( logger.error(