mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2026-01-06 04:26:57 +00:00
This commit is contained in:
parent
64bbdf5533
commit
9f14ef7fab
@ -17,6 +17,7 @@ from typing import Optional, Union
|
||||
|
||||
from databricks.sdk import WorkspaceClient
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import DatabaseError
|
||||
from sqlalchemy.inspection import inspect
|
||||
|
||||
from metadata.generated.schema.entity.automations.workflow import (
|
||||
@ -33,7 +34,6 @@ from metadata.ingestion.connections.builders import (
|
||||
from metadata.ingestion.connections.test_connections import (
|
||||
test_connection_engine_step,
|
||||
test_connection_steps,
|
||||
test_query,
|
||||
)
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
from metadata.ingestion.source.database.databricks.client import DatabricksClient
|
||||
@ -42,6 +42,9 @@ from metadata.ingestion.source.database.databricks.queries import (
|
||||
DATABRICKS_GET_CATALOGS,
|
||||
)
|
||||
from metadata.utils.db_utils import get_host_from_host_port
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
|
||||
logger = ingestion_logger()
|
||||
|
||||
|
||||
def get_connection_url(connection: DatabricksConnection) -> str:
|
||||
@ -84,6 +87,18 @@ def test_connection(
|
||||
"""
|
||||
client = DatabricksClient(service_connection)
|
||||
|
||||
def test_database_query(engine: Engine, statement: str):
    """
    Execute the given statement and fetch one row to check whether the
    user has access to the objects referenced in the SQL statement.

    A ``DatabaseError`` raised while connecting or executing is logged at
    debug level and swallowed, so this function never raises it.
    NOTE(review): presumably deliberate so that workspaces that reject the
    statement do not fail the connection test - confirm with the caller.

    Args:
        engine: SQLAlchemy engine used to open the connection.
        statement: SQL statement to execute.
    """
    try:
        # The context manager closes the connection on success and on
        # failure; the previous version never released it.
        with engine.connect() as connection:
            connection.execute(statement).fetchone()
    except DatabaseError as soe:
        logger.debug(f"Failed to fetch catalogs due to: {soe}")
|
||||
|
||||
if service_connection.useUnityCatalog:
|
||||
table_obj = DatabricksTable()
|
||||
|
||||
@ -121,7 +136,7 @@ def test_connection(
|
||||
"GetTables": inspector.get_table_names,
|
||||
"GetViews": inspector.get_view_names,
|
||||
"GetDatabases": partial(
|
||||
test_query,
|
||||
test_database_query,
|
||||
engine=connection,
|
||||
statement=DATABRICKS_GET_CATALOGS,
|
||||
),
|
||||
|
||||
@ -13,11 +13,12 @@
|
||||
import re
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from typing import Iterable
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from pyhive.sqlalchemy_hive import _type_map
|
||||
from sqlalchemy import types, util
|
||||
from sqlalchemy.engine import reflection
|
||||
from sqlalchemy.exc import DatabaseError
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.sql.sqltypes import String
|
||||
from sqlalchemy_databricks._dialect import DatabricksDialect
|
||||
@ -35,10 +36,13 @@ from metadata.ingestion.source.connections import get_connection
|
||||
from metadata.ingestion.source.database.column_type_parser import create_sqlalchemy_type
|
||||
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
|
||||
from metadata.ingestion.source.database.databricks.queries import (
|
||||
DATABRICKS_GET_CATALOGS,
|
||||
DATABRICKS_GET_TABLE_COMMENTS,
|
||||
DATABRICKS_VIEW_DEFINITIONS,
|
||||
)
|
||||
from metadata.ingestion.source.database.multi_db_source import MultiDBSource
|
||||
from metadata.utils import fqn
|
||||
from metadata.utils.constants import DEFAULT_DATABASE
|
||||
from metadata.utils.filters import filter_by_database
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
from metadata.utils.sqlalchemy_utils import (
|
||||
@ -158,7 +162,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
|
||||
@reflection.cache
|
||||
def get_schema_names(self, connection, **kw): # pylint: disable=unused-argument
|
||||
# Equivalent to SHOW DATABASES
|
||||
if kw.get("database"):
|
||||
if kw.get("database") and kw.get("is_old_version") is not True:
|
||||
connection.execute(f"USE CATALOG '{kw.get('database')}'")
|
||||
return [row[0] for row in connection.execute("SHOW SCHEMAS")]
|
||||
|
||||
@ -238,13 +242,26 @@ DatabricksDialect.get_all_view_definitions = get_all_view_definitions
|
||||
reflection.Inspector.get_schema_names = get_schema_names_reflection
|
||||
|
||||
|
||||
class DatabricksLegacySource(CommonDbSourceService):
|
||||
class DatabricksLegacySource(CommonDbSourceService, MultiDBSource):
|
||||
"""
|
||||
Implements the necessary methods to extract
|
||||
Database metadata from Databricks Source using
|
||||
the legacy hive metastore method
|
||||
"""
|
||||
|
||||
def __init__(self, config: WorkflowSource, metadata: OpenMetadata) -> None:
    """
    Initialise the parent source, then run the version probe.

    Args:
        config: workflow source configuration forwarded to the parent.
        metadata: OpenMetadata client forwarded to the parent.
    """
    super().__init__(config, metadata)
    # Default value; `_init_version` overwrites it once its probe query
    # has either succeeded or raised a DatabaseError.
    self.is_older_version: bool = False
    self._init_version()
|
||||
|
||||
def _init_version(self):
    """
    Run DATABRICKS_GET_CATALOGS once and record the outcome in
    ``self.is_older_version``.

    The flag is set to False when the query runs and to True when it
    raises ``DatabaseError`` (which is logged at debug level). Any other
    exception propagates and leaves the flag untouched.
    """
    try:
        self.connection.execute(DATABRICKS_GET_CATALOGS).fetchone()
    except DatabaseError as soe:
        logger.debug(f"Failed to fetch catalogs due to: {soe}")
        catalogs_query_failed = True
    else:
        catalogs_query_failed = False
    self.is_older_version = catalogs_query_failed
|
||||
|
||||
@classmethod
|
||||
def create(cls, config_dict, metadata: OpenMetadata):
|
||||
config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
|
||||
@ -268,44 +285,55 @@ class DatabricksLegacySource(CommonDbSourceService):
|
||||
self.engine = get_connection(new_service_connection)
|
||||
self.inspector = inspect(self.engine)
|
||||
|
||||
def get_configured_database(self) -> Optional[str]:
    """Return the ``catalog`` value of the service connection."""
    configured_catalog = self.service_connection.catalog
    return configured_catalog
|
||||
|
||||
def get_database_names_raw(self) -> Iterable[str]:
    """
    Yield the unfiltered database (catalog) names.

    When ``self.is_older_version`` is set, only DEFAULT_DATABASE is
    yielded. Otherwise DATABRICKS_GET_CATALOGS is executed and the first
    column of every non-empty row is yielded.
    """
    if self.is_older_version:
        yield DEFAULT_DATABASE
        return

    for record in self.connection.execute(DATABRICKS_GET_CATALOGS):
        if not record:
            continue
        yield list(record)[0]
|
||||
|
||||
def get_database_names(self) -> Iterable[str]:
|
||||
configured_catalog = self.service_connection.__dict__.get("catalog")
|
||||
configured_catalog = self.service_connection.catalog
|
||||
if configured_catalog:
|
||||
self.set_inspector(database_name=configured_catalog)
|
||||
yield configured_catalog
|
||||
else:
|
||||
results = self.connection.execute("SHOW CATALOGS")
|
||||
for res in results:
|
||||
if res:
|
||||
new_catalog = res[0]
|
||||
database_fqn = fqn.build(
|
||||
self.metadata,
|
||||
entity_type=Database,
|
||||
service_name=self.context.database_service.name.__root__,
|
||||
database_name=new_catalog,
|
||||
for new_catalog in self.get_database_names_raw():
|
||||
database_fqn = fqn.build(
|
||||
self.metadata,
|
||||
entity_type=Database,
|
||||
service_name=self.context.database_service.name.__root__,
|
||||
database_name=new_catalog,
|
||||
)
|
||||
if filter_by_database(
|
||||
self.source_config.databaseFilterPattern,
|
||||
database_fqn
|
||||
if self.source_config.useFqnForFiltering
|
||||
else new_catalog,
|
||||
):
|
||||
self.status.filter(database_fqn, "Database Filtered Out")
|
||||
continue
|
||||
try:
|
||||
self.set_inspector(database_name=new_catalog)
|
||||
yield new_catalog
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"Error trying to process database {new_catalog}: {exc}"
|
||||
)
|
||||
if filter_by_database(
|
||||
self.source_config.databaseFilterPattern,
|
||||
database_fqn
|
||||
if self.source_config.useFqnForFiltering
|
||||
else new_catalog,
|
||||
):
|
||||
self.status.filter(database_fqn, "Database Filtered Out")
|
||||
continue
|
||||
try:
|
||||
self.set_inspector(database_name=new_catalog)
|
||||
yield new_catalog
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"Error trying to process database {new_catalog}: {exc}"
|
||||
)
|
||||
|
||||
def get_raw_database_schema_names(self) -> Iterable[str]:
|
||||
if self.service_connection.__dict__.get("databaseSchema"):
|
||||
yield self.service_connection.databaseSchema
|
||||
else:
|
||||
for schema_name in self.inspector.get_schema_names(
|
||||
database=self.context.database.name.__root__
|
||||
database=self.context.database.name.__root__,
|
||||
is_old_version=self.is_older_version,
|
||||
):
|
||||
yield schema_name
|
||||
|
||||
@ -61,6 +61,7 @@ from metadata.ingestion.source.database.databricks.models import (
|
||||
ForeignConstrains,
|
||||
Type,
|
||||
)
|
||||
from metadata.ingestion.source.database.multi_db_source import MultiDBSource
|
||||
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
|
||||
from metadata.ingestion.source.models import TableView
|
||||
from metadata.utils import fqn
|
||||
@ -84,7 +85,7 @@ def from_dict(cls, dct: Dict[str, Any]) -> "TableConstraintList":
|
||||
TableConstraintList.from_dict = from_dict
|
||||
|
||||
|
||||
class DatabricksUnityCatalogSource(DatabaseServiceSource):
|
||||
class DatabricksUnityCatalogSource(DatabaseServiceSource, MultiDBSource):
|
||||
"""
|
||||
Implements the necessary methods to extract
|
||||
Database metadata from Databricks Source using
|
||||
@ -107,6 +108,13 @@ class DatabricksUnityCatalogSource(DatabaseServiceSource):
|
||||
self.table_constraints = []
|
||||
self.test_connection()
|
||||
|
||||
def get_configured_database(self) -> Optional[str]:
    """Return the ``catalog`` value of the service connection."""
    configured_catalog = self.service_connection.catalog
    return configured_catalog
|
||||
|
||||
def get_database_names_raw(self) -> Iterable[str]:
    """
    Yield the ``name`` of every catalog returned by
    ``self.client.catalogs.list()``, in the order the client returns them.
    """
    yield from (entry.name for entry in self.client.catalogs.list())
|
||||
|
||||
@classmethod
|
||||
def create(cls, config_dict, metadata: OpenMetadata):
|
||||
config: WorkflowSource = WorkflowSource.parse_obj(config_dict)
|
||||
@ -131,31 +139,31 @@ class DatabricksUnityCatalogSource(DatabaseServiceSource):
|
||||
if self.service_connection.catalog:
|
||||
yield self.service_connection.catalog
|
||||
else:
|
||||
for catalog in self.client.catalogs.list():
|
||||
for catalog_name in self.get_database_names_raw():
|
||||
try:
|
||||
database_fqn = fqn.build(
|
||||
self.metadata,
|
||||
entity_type=Database,
|
||||
service_name=self.context.database_service.name.__root__,
|
||||
database_name=catalog.name,
|
||||
database_name=catalog_name,
|
||||
)
|
||||
if filter_by_database(
|
||||
self.config.sourceConfig.config.databaseFilterPattern,
|
||||
database_fqn
|
||||
if self.config.sourceConfig.config.useFqnForFiltering
|
||||
else catalog.name,
|
||||
else catalog_name,
|
||||
):
|
||||
self.status.filter(
|
||||
database_fqn,
|
||||
"Database (Catalog ID) Filtered Out",
|
||||
)
|
||||
continue
|
||||
yield catalog.name
|
||||
yield catalog_name
|
||||
except Exception as exc:
|
||||
self.status.failed(
|
||||
StackTraceError(
|
||||
name=catalog.name,
|
||||
error=f"Unexpected exception to get database name [{catalog.name}]: {exc}",
|
||||
name=catalog_name,
|
||||
error=f"Unexpected exception to get database name [{catalog_name}]: {exc}",
|
||||
stack_trace=traceback.format_exc(),
|
||||
)
|
||||
)
|
||||
|
||||
@ -1,3 +1,18 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Test databricks using the topology
|
||||
"""
|
||||
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -20,6 +35,7 @@ from metadata.generated.schema.type.basic import FullyQualifiedEntityName
|
||||
from metadata.generated.schema.type.entityReference import EntityReference
|
||||
from metadata.ingestion.source.database.databricks.metadata import DatabricksSource
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
mock_databricks_config = {
|
||||
"source": {
|
||||
"type": "databricks",
|
||||
@ -230,12 +246,20 @@ EXPTECTED_TABLE = [
|
||||
|
||||
|
||||
class DatabricksUnitTest(TestCase):
|
||||
"""
|
||||
Databricks unit tests
|
||||
"""
|
||||
|
||||
@patch(
|
||||
"metadata.ingestion.source.database.common_db_source.CommonDbSourceService.test_connection"
|
||||
)
|
||||
def __init__(self, methodName, test_connection) -> None:
|
||||
@patch(
|
||||
"metadata.ingestion.source.database.databricks.legacy.metadata.DatabricksLegacySource._init_version"
|
||||
)
|
||||
def __init__(self, methodName, test_connection, db_init_version) -> None:
|
||||
super().__init__(methodName)
|
||||
test_connection.return_value = False
|
||||
db_init_version.return_value = None
|
||||
|
||||
self.config = OpenMetadataWorkflowConfig.parse_obj(mock_databricks_config)
|
||||
self.databricks_source = DatabricksSource.create(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user