mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-08-06 16:18:05 +00:00
parent
548a0ab722
commit
d6d9afa8be
@ -10,7 +10,7 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
@ -18,16 +18,21 @@ from cryptography.hazmat.primitives.asymmetric import dsa, rsa
|
||||
from snowflake.sqlalchemy.custom_types import VARIANT
|
||||
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect, ischema_names
|
||||
from sqlalchemy.engine import reflection
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
from metadata.generated.schema.entity.data.database import Database
|
||||
from metadata.generated.schema.entity.data.table import TableData
|
||||
from metadata.generated.schema.entity.services.databaseService import (
|
||||
DatabaseServiceType,
|
||||
)
|
||||
from metadata.generated.schema.type.entityReference import EntityReference
|
||||
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
|
||||
from metadata.ingestion.source.sql_source import SQLSource
|
||||
from metadata.ingestion.source.sql_source_common import SQLConnectionConfig
|
||||
from metadata.utils.column_type_parser import create_sqlalchemy_type
|
||||
from metadata.utils.engines import get_engine
|
||||
|
||||
GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
|
||||
ischema_names["VARIANT"] = VARIANT
|
||||
@ -39,8 +44,8 @@ logger: logging.Logger = logging.getLogger(__name__)
|
||||
class SnowflakeConfig(SQLConnectionConfig):
|
||||
scheme = "snowflake"
|
||||
account: str
|
||||
database: str
|
||||
warehouse: str
|
||||
database: Optional[str]
|
||||
warehouse: Optional[str]
|
||||
result_limit: int = 1000
|
||||
role: Optional[str]
|
||||
duration: Optional[int]
|
||||
@ -76,6 +81,31 @@ class SnowflakeSource(SQLSource):
|
||||
config.connect_args["private_key"] = pkb
|
||||
super().__init__(config, metadata_config, ctx)
|
||||
|
||||
def get_databases(self) -> Iterable[Inspector]:
|
||||
if self.config.database != None:
|
||||
yield from super().get_databases()
|
||||
else:
|
||||
query = "SHOW DATABASES"
|
||||
results = self.connection.execute(query)
|
||||
for res in results:
|
||||
|
||||
row = list(res)
|
||||
use_db_query = f"USE DATABASE {row[1]}"
|
||||
self.connection.execute(use_db_query)
|
||||
logger.info(f"Ingesting from database: {row[1]}")
|
||||
self.config.database = row[1]
|
||||
self.engine = get_engine(self.config)
|
||||
yield inspect(self.engine)
|
||||
|
||||
def get_table_fqn(self, service_name, schema, table_name) -> str:
|
||||
return f"{service_name}.{self.config.database}_{schema}.{table_name}"
|
||||
|
||||
def _get_database(self, schema: str) -> Database:
|
||||
return Database(
|
||||
name=self.config.database + "_" + schema.replace(".", "_DOT_"),
|
||||
service=EntityReference(id=self.service.id, type=self.config.service_type),
|
||||
)
|
||||
|
||||
def fetch_sample_data(self, schema: str, table: str) -> Optional[TableData]:
|
||||
resp_sample_data = super().fetch_sample_data(schema, table)
|
||||
if not resp_sample_data:
|
||||
@ -116,4 +146,10 @@ def _get_table_comment(self, connection, table_name, schema=None, **kw):
|
||||
return cursor.fetchone() # pylint: disable=protected-access
|
||||
|
||||
|
||||
@reflection.cache
|
||||
def get_unique_constraints(self, connection, table_name, schema=None, **kw):
|
||||
return []
|
||||
|
||||
|
||||
SnowflakeDialect._get_table_comment = _get_table_comment
|
||||
SnowflakeDialect.get_unique_constraints = get_unique_constraints
|
||||
|
@ -196,23 +196,29 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
|
||||
logger.error(f"Failed to generate sample data for {table} - {err}")
|
||||
return None
|
||||
|
||||
def next_record(self) -> Iterable[Entity]:
|
||||
inspector = inspect(self.engine)
|
||||
schema_names = inspector.get_schema_names()
|
||||
def get_databases(self) -> Iterable[Inspector]:
|
||||
yield inspect(self.engine)
|
||||
|
||||
for schema in schema_names:
|
||||
# clear any previous source database state
|
||||
self.database_source_state.clear()
|
||||
if not self.sql_config.schema_filter_pattern.included(schema):
|
||||
self.status.filter(schema, "Schema pattern not allowed")
|
||||
continue
|
||||
if self.config.include_tables:
|
||||
yield from self.fetch_tables(inspector, schema)
|
||||
if self.config.include_views:
|
||||
yield from self.fetch_views(inspector, schema)
|
||||
if self.config.mark_deleted_tables_as_deleted:
|
||||
schema_fqdn = f"{self.config.service_name}.{schema}"
|
||||
yield from self.delete_tables(schema_fqdn)
|
||||
def get_table_fqn(self, service_name, schema, table_name) -> str:
|
||||
return f"{service_name}.{schema}.{table_name}"
|
||||
|
||||
def next_record(self) -> Iterable[Entity]:
|
||||
inspectors = self.get_databases()
|
||||
for inspector in inspectors:
|
||||
schema_names = inspector.get_schema_names()
|
||||
for schema in schema_names:
|
||||
# clear any previous source database state
|
||||
self.database_source_state.clear()
|
||||
if not self.sql_config.schema_filter_pattern.included(schema):
|
||||
self.status.filter(schema, "Schema pattern not allowed")
|
||||
continue
|
||||
if self.config.include_tables:
|
||||
yield from self.fetch_tables(inspector, schema)
|
||||
if self.config.include_views:
|
||||
yield from self.fetch_views(inspector, schema)
|
||||
if self.config.mark_deleted_tables_as_deleted:
|
||||
schema_fqdn = f"{self.config.service_name}.{schema}"
|
||||
yield from self.delete_tables(schema_fqdn)
|
||||
|
||||
def fetch_tables(
|
||||
self, inspector: Inspector, schema: str
|
||||
@ -240,7 +246,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
|
||||
)
|
||||
continue
|
||||
description = _get_table_description(schema, table_name, inspector)
|
||||
fqn = f"{self.config.service_name}.{schema}.{table_name}"
|
||||
fqn = self.get_table_fqn(self.config.service_name, schema, table_name)
|
||||
self.database_source_state.add(fqn)
|
||||
self.table_constraints = None
|
||||
table_columns = self._get_columns(schema, table_name, inspector)
|
||||
@ -322,7 +328,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
|
||||
)
|
||||
except NotImplementedError:
|
||||
view_definition = ""
|
||||
fqn = f"{self.config.service_name}.{schema}.{view_name}"
|
||||
fqn = self.get_table_fqn(self.config.service_name, schema, view_name)
|
||||
self.database_source_state.add(fqn)
|
||||
table = Table(
|
||||
id=uuid.uuid4(),
|
||||
@ -425,7 +431,9 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
|
||||
try:
|
||||
_, database, table = node.split(".", 2)
|
||||
table = table.replace(".", "_DOT_")
|
||||
table_fqn = f"{self.config.service_name}.{database}.{table}"
|
||||
table_fqn = self.get_table_fqn(
|
||||
self.config.service_name, database, table
|
||||
)
|
||||
upstream_nodes.append(table_fqn)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
logger.error(
|
||||
|
Loading…
x
Reference in New Issue
Block a user