mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-12-25 14:38:29 +00:00
FIX: Trino&Presto catalogs as databases (#8189)
* FIX: Trino&Presto catalogs as databases
* FIX: Trino&Presto catalogs as databases
* Change based on comments
This commit is contained in:
parent
7af8c5418c
commit
1565aa7733
@ -14,12 +14,14 @@ Presto source module
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from typing import Iterable
|
||||
|
||||
from pyhive.sqlalchemy_presto import PrestoDialect, _type_map
|
||||
from sqlalchemy import inspect, types, util
|
||||
from sqlalchemy.engine import reflection
|
||||
|
||||
from metadata.generated.schema.entity.data.database import Database
|
||||
from metadata.generated.schema.entity.services.connections.database.prestoConnection import (
|
||||
PrestoConnection,
|
||||
)
|
||||
@ -31,6 +33,9 @@ from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
)
|
||||
from metadata.ingestion.api.source import InvalidSourceException
|
||||
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
|
||||
from metadata.utils import fqn
|
||||
from metadata.utils.connections import get_connection
|
||||
from metadata.utils.filters import filter_by_database
|
||||
from metadata.utils.logger import ometa_logger
|
||||
|
||||
logger = ometa_logger()
|
||||
@ -106,6 +111,48 @@ class PrestoSource(CommonDbSourceService):
|
||||
)
|
||||
return cls(config, metadata_config)
|
||||
|
||||
def get_database_names(self) -> Iterable[str]:
|
||||
def set_inspector(self, database_name: str) -> None:
|
||||
"""
|
||||
When sources override `get_database_names`, they will need
|
||||
to setup multiple inspectors. They can use this function.
|
||||
:param database_name: new database to set
|
||||
"""
|
||||
logger.info(f"Ingesting from catalog: {database_name}")
|
||||
|
||||
new_service_connection = deepcopy(self.service_connection)
|
||||
new_service_connection.catalog = database_name
|
||||
self.engine = get_connection(new_service_connection)
|
||||
self.inspector = inspect(self.engine)
|
||||
yield self.service_connection.catalog
|
||||
|
||||
def get_database_names(self) -> Iterable[str]:
|
||||
configured_catalog = self.service_connection.catalog
|
||||
if configured_catalog:
|
||||
self.set_inspector(database_name=configured_catalog)
|
||||
yield configured_catalog
|
||||
else:
|
||||
results = self.connection.execute("SHOW CATALOGS")
|
||||
for res in results:
|
||||
new_catalog = res[0]
|
||||
database_fqn = fqn.build(
|
||||
self.metadata,
|
||||
entity_type=Database,
|
||||
service_name=self.context.database_service.name.__root__,
|
||||
database_name=new_catalog,
|
||||
)
|
||||
if filter_by_database(
|
||||
self.source_config.databaseFilterPattern,
|
||||
database_fqn
|
||||
if self.source_config.useFqnForFiltering
|
||||
else new_catalog,
|
||||
):
|
||||
self.status.filter(database_fqn, "Database Filtered Out")
|
||||
continue
|
||||
|
||||
try:
|
||||
self.set_inspector(database_name=new_catalog)
|
||||
yield new_catalog
|
||||
except Exception as exc:
|
||||
logger.debug(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"Error trying to connect to database {new_catalog}: {exc}"
|
||||
)
|
||||
|
||||
@ -14,6 +14,8 @@ Trino source implementation.
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import click
|
||||
@ -23,6 +25,7 @@ from sqlalchemy.sql import sqltypes
|
||||
from trino.sqlalchemy import datatype
|
||||
from trino.sqlalchemy.dialect import TrinoDialect
|
||||
|
||||
from metadata.generated.schema.entity.data.database import Database
|
||||
from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
|
||||
TrinoConnection,
|
||||
)
|
||||
@ -34,6 +37,9 @@ from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
)
|
||||
from metadata.ingestion.api.source import InvalidSourceException
|
||||
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
|
||||
from metadata.utils import fqn
|
||||
from metadata.utils.connections import get_connection
|
||||
from metadata.utils.filters import filter_by_database
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
|
||||
logger = ingestion_logger()
|
||||
@ -154,6 +160,48 @@ class TrinoSource(CommonDbSourceService):
|
||||
)
|
||||
return cls(config, metadata_config)
|
||||
|
||||
def get_database_names(self) -> Iterable[str]:
|
||||
def set_inspector(self, database_name: str) -> None:
|
||||
"""
|
||||
When sources override `get_database_names`, they will need
|
||||
to setup multiple inspectors. They can use this function.
|
||||
:param database_name: new database to set
|
||||
"""
|
||||
logger.info(f"Ingesting from catalog: {database_name}")
|
||||
|
||||
new_service_connection = deepcopy(self.service_connection)
|
||||
new_service_connection.catalog = database_name
|
||||
self.engine = get_connection(new_service_connection)
|
||||
self.inspector = inspect(self.engine)
|
||||
yield self.trino_connection.catalog
|
||||
|
||||
def get_database_names(self) -> Iterable[str]:
|
||||
configured_catalog = self.trino_connection.catalog
|
||||
if configured_catalog:
|
||||
self.set_inspector(database_name=configured_catalog)
|
||||
yield configured_catalog
|
||||
else:
|
||||
results = self.connection.execute("SHOW CATALOGS")
|
||||
for res in results:
|
||||
new_catalog = res[0]
|
||||
database_fqn = fqn.build(
|
||||
self.metadata,
|
||||
entity_type=Database,
|
||||
service_name=self.context.database_service.name.__root__,
|
||||
database_name=new_catalog,
|
||||
)
|
||||
if filter_by_database(
|
||||
self.source_config.databaseFilterPattern,
|
||||
database_fqn
|
||||
if self.source_config.useFqnForFiltering
|
||||
else new_catalog,
|
||||
):
|
||||
self.status.filter(database_fqn, "Database Filtered Out")
|
||||
continue
|
||||
|
||||
try:
|
||||
self.set_inspector(database_name=new_catalog)
|
||||
yield new_catalog
|
||||
except Exception as exc:
|
||||
logger.debug(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"Error trying to connect to database {new_catalog}: {exc}"
|
||||
)
|
||||
|
||||
@ -230,7 +230,8 @@ def _(connection: TrinoConnection):
|
||||
url += f":{quote_plus(connection.password.get_secret_value())}"
|
||||
url += "@"
|
||||
url += f"{connection.hostPort}"
|
||||
url += f"/{connection.catalog}"
|
||||
if connection.catalog:
|
||||
url += f"/{connection.catalog}"
|
||||
if connection.params is not None:
|
||||
params = "&".join(
|
||||
f"{key}={quote_plus(value)}"
|
||||
@ -256,7 +257,8 @@ def _(connection: PrestoConnection):
|
||||
url += f":{quote_plus(connection.password.get_secret_value())}"
|
||||
url += "@"
|
||||
url += f"{connection.hostPort}"
|
||||
url += f"/{connection.catalog}"
|
||||
if connection.catalog:
|
||||
url += f"/{connection.catalog}"
|
||||
if connection.databaseSchema:
|
||||
url += f"?schema={quote_plus(connection.databaseSchema)}"
|
||||
return url
|
||||
|
||||
@ -152,7 +152,7 @@ class SouceConnectionTest(TestCase):
|
||||
hostPort="localhost:10000",
|
||||
connectionArguments={"auth": "CUSTOM"},
|
||||
)
|
||||
print("get_connection_url(hive_conn_obj)...", get_connection_url(hive_conn_obj))
|
||||
|
||||
assert expected_result == get_connection_url(hive_conn_obj)
|
||||
|
||||
def test_hive_url_conn_options_with_db(self):
|
||||
@ -314,6 +314,18 @@ class SouceConnectionTest(TestCase):
|
||||
== get_connection_args(trino_conn_obj).get("http_session").proxies
|
||||
)
|
||||
|
||||
def test_trino_without_catalog(self):
|
||||
# Test trino url without catalog
|
||||
expected_url = "trino://username:pass@localhost:443"
|
||||
trino_conn_obj = TrinoConnection(
|
||||
scheme=TrinoScheme.trino,
|
||||
hostPort="localhost:443",
|
||||
username="username",
|
||||
password="pass",
|
||||
)
|
||||
|
||||
assert expected_url == get_connection_url(trino_conn_obj)
|
||||
|
||||
def test_vertica_url(self):
|
||||
expected_url = (
|
||||
"vertica+vertica_python://username:password@localhost:5443/database"
|
||||
@ -521,10 +533,7 @@ class SouceConnectionTest(TestCase):
|
||||
warehouse="COMPUTE_WH",
|
||||
account="ue18849.us-east-2.aws",
|
||||
)
|
||||
print(
|
||||
"get_connection_url(snowflake_conn_obj),,,,,",
|
||||
get_connection_url(snowflake_conn_obj),
|
||||
)
|
||||
|
||||
assert expected_url == get_connection_url(snowflake_conn_obj)
|
||||
|
||||
# connection arguments with db
|
||||
@ -819,6 +828,18 @@ class SouceConnectionTest(TestCase):
|
||||
|
||||
assert expected_url == get_connection_url(presto_conn_obj)
|
||||
|
||||
def test_presto_without_catalog(self):
|
||||
# Test presto url without catalog
|
||||
expected_url = "presto://username:pass@localhost:8080"
|
||||
presto_conn_obj = PrestoConnection(
|
||||
scheme=PrestoScheme.presto,
|
||||
hostPort="localhost:8080",
|
||||
username="username",
|
||||
password="pass",
|
||||
)
|
||||
|
||||
assert expected_url == get_connection_url(presto_conn_obj)
|
||||
|
||||
def test_oracle_url(self):
|
||||
# oracle with db
|
||||
expected_url = "oracle+cx_oracle://admin:password@localhost:1541/testdb"
|
||||
|
||||
@ -84,5 +84,5 @@
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": ["hostPort", "username", "catalog"]
|
||||
"required": ["hostPort", "username"]
|
||||
}
|
||||
|
||||
@ -100,5 +100,5 @@
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": ["hostPort", "username", "catalog"]
|
||||
"required": ["hostPort", "username"]
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user