FIX: Trino&Presto catalogs as databases (#8189)

* FIX: Trino&Presto catalogs as databases

* FIX: Trino&Presto catalogs as databases

* Change based on comments
Milan Bariya 2022-10-18 20:00:17 +05:30 committed by GitHub
parent 7af8c5418c
commit 1565aa7733
6 changed files with 131 additions and 13 deletions

View File

@@ -14,12 +14,14 @@ Presto source module
import re
import traceback
from copy import deepcopy
from typing import Iterable
from pyhive.sqlalchemy_presto import PrestoDialect, _type_map
from sqlalchemy import inspect, types, util
from sqlalchemy.engine import reflection
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.services.connections.database.prestoConnection import (
PrestoConnection,
)
@@ -31,6 +33,9 @@ from metadata.generated.schema.metadataIngestion.workflow import (
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.utils import fqn
from metadata.utils.connections import get_connection
from metadata.utils.filters import filter_by_database
from metadata.utils.logger import ometa_logger
logger = ometa_logger()
@@ -106,6 +111,48 @@ class PrestoSource(CommonDbSourceService):
)
return cls(config, metadata_config)
    def get_database_names(self) -> Iterable[str]:
        yield self.service_connection.catalog
    def set_inspector(self, database_name: str) -> None:
        """
        When sources override `get_database_names`, they will need
        to setup multiple inspectors. They can use this function.
        :param database_name: new database to set
        """
        logger.info(f"Ingesting from catalog: {database_name}")
        new_service_connection = deepcopy(self.service_connection)
        new_service_connection.catalog = database_name
        self.engine = get_connection(new_service_connection)
        self.inspector = inspect(self.engine)
    def get_database_names(self) -> Iterable[str]:
        configured_catalog = self.service_connection.catalog
        if configured_catalog:
            self.set_inspector(database_name=configured_catalog)
            yield configured_catalog
        else:
            results = self.connection.execute("SHOW CATALOGS")
            for res in results:
                new_catalog = res[0]
                database_fqn = fqn.build(
                    self.metadata,
                    entity_type=Database,
                    service_name=self.context.database_service.name.__root__,
                    database_name=new_catalog,
                )
                if filter_by_database(
                    self.source_config.databaseFilterPattern,
                    database_fqn
                    if self.source_config.useFqnForFiltering
                    else new_catalog,
                ):
                    self.status.filter(database_fqn, "Database Filtered Out")
                    continue
                try:
                    self.set_inspector(database_name=new_catalog)
                    yield new_catalog
                except Exception as exc:
                    logger.debug(traceback.format_exc())
                    logger.warning(
                        f"Error trying to connect to database {new_catalog}: {exc}"
                    )
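
For reference, the new get_database_names above boils down to three steps: run SHOW CATALOGS, drop catalogs that match the database filter pattern, and re-point the inspector at each surviving catalog before yielding it. The standalone sketch below reproduces only the discovery-and-filter step with plain SQLAlchemy; list_catalogs, catalog_is_filtered, and the regex-list filter are illustrative assumptions, not part of the OpenMetadata helpers.

import re
from typing import Iterable, List, Optional

from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine


def catalog_is_filtered(
    catalog: str,
    includes: Optional[List[str]] = None,
    excludes: Optional[List[str]] = None,
) -> bool:
    """Return True when the catalog should be skipped (hypothetical filter)."""
    if includes and not any(re.match(pattern, catalog) for pattern in includes):
        return True
    if excludes and any(re.match(pattern, catalog) for pattern in excludes):
        return True
    return False


def list_catalogs(
    engine: Engine,
    includes: Optional[List[str]] = None,
    excludes: Optional[List[str]] = None,
) -> Iterable[str]:
    """Yield every catalog visible to the connection, minus filtered ones."""
    with engine.connect() as conn:
        for row in conn.execute(text("SHOW CATALOGS")):
            catalog = row[0]
            if catalog_is_filtered(catalog, includes, excludes):
                continue
            yield catalog


if __name__ == "__main__":
    # Assumes a reachable coordinator and the trino SQLAlchemy dialect installed.
    engine = create_engine("trino://admin@localhost:8080")
    for catalog in list_catalogs(engine, excludes=["^system$"]):
        print(catalog)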

View File

@@ -14,6 +14,8 @@ Trino source implementation.
import logging
import re
import sys
import traceback
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Tuple
import click
@@ -23,6 +25,7 @@ from sqlalchemy.sql import sqltypes
from trino.sqlalchemy import datatype
from trino.sqlalchemy.dialect import TrinoDialect
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
TrinoConnection,
)
@@ -34,6 +37,9 @@ from metadata.generated.schema.metadataIngestion.workflow import (
)
from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
from metadata.utils import fqn
from metadata.utils.connections import get_connection
from metadata.utils.filters import filter_by_database
from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
@@ -154,6 +160,48 @@ class TrinoSource(CommonDbSourceService):
)
return cls(config, metadata_config)
    def get_database_names(self) -> Iterable[str]:
        yield self.trino_connection.catalog
    def set_inspector(self, database_name: str) -> None:
        """
        When sources override `get_database_names`, they will need
        to setup multiple inspectors. They can use this function.
        :param database_name: new database to set
        """
        logger.info(f"Ingesting from catalog: {database_name}")
        new_service_connection = deepcopy(self.service_connection)
        new_service_connection.catalog = database_name
        self.engine = get_connection(new_service_connection)
        self.inspector = inspect(self.engine)
    def get_database_names(self) -> Iterable[str]:
        configured_catalog = self.trino_connection.catalog
        if configured_catalog:
            self.set_inspector(database_name=configured_catalog)
            yield configured_catalog
        else:
            results = self.connection.execute("SHOW CATALOGS")
            for res in results:
                new_catalog = res[0]
                database_fqn = fqn.build(
                    self.metadata,
                    entity_type=Database,
                    service_name=self.context.database_service.name.__root__,
                    database_name=new_catalog,
                )
                if filter_by_database(
                    self.source_config.databaseFilterPattern,
                    database_fqn
                    if self.source_config.useFqnForFiltering
                    else new_catalog,
                ):
                    self.status.filter(database_fqn, "Database Filtered Out")
                    continue
                try:
                    self.set_inspector(database_name=new_catalog)
                    yield new_catalog
                except Exception as exc:
                    logger.debug(traceback.format_exc())
                    logger.warning(
                        f"Error trying to connect to database {new_catalog}: {exc}"
                    )
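
The same pattern seen from the set_inspector side: each catalog gets its own engine, with the catalog baked into the connection, and a fresh SQLAlchemy inspector built from it. A rough sketch under that assumption follows; make_inspector and the URL layout are illustrative stand-ins, not the connector's real interface.

from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.reflection import Inspector


def make_inspector(host_port: str, user: str, catalog: str) -> Inspector:
    """Build an engine scoped to a single catalog and return its inspector."""
    engine = create_engine(f"trino://{user}@{host_port}/{catalog}")
    return inspect(engine)


# One inspector per catalog, mirroring the loop in get_database_names above.
# Assumes a reachable coordinator and the trino SQLAlchemy dialect installed.
for catalog in ("hive", "postgresql"):
    inspector = make_inspector("localhost:8080", "admin", catalog)
    print(catalog, inspector.get_schema_names())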

View File

@@ -230,7 +230,8 @@ def _(connection: TrinoConnection):
url += f":{quote_plus(connection.password.get_secret_value())}"
url += "@"
url += f"{connection.hostPort}"
url += f"/{connection.catalog}"
if connection.catalog:
url += f"/{connection.catalog}"
if connection.params is not None:
params = "&".join(
f"{key}={quote_plus(value)}"
@@ -256,7 +257,8 @@ def _(connection: PrestoConnection):
url += f":{quote_plus(connection.password.get_secret_value())}"
url += "@"
url += f"{connection.hostPort}"
url += f"/{connection.catalog}"
if connection.catalog:
url += f"/{connection.catalog}"
if connection.databaseSchema:
url += f"?schema={quote_plus(connection.databaseSchema)}"
return url
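
The effect of the two `if connection.catalog:` guards above is easiest to see on the final URL strings. A minimal sketch, assuming plain string inputs rather than the real connection objects:

from urllib.parse import quote_plus


def build_url(scheme, username, host_port, password=None, catalog=None, schema=None):
    """Hypothetical stand-in for get_connection_url; the catalog segment is now optional."""
    url = f"{scheme}://{quote_plus(username)}"
    if password:
        url += f":{quote_plus(password)}"
    url += f"@{host_port}"
    if catalog:
        url += f"/{catalog}"
    if schema:
        url += f"?schema={quote_plus(schema)}"
    return url


print(build_url("trino", "username", "localhost:443", password="pass"))
# trino://username:pass@localhost:443  (no trailing /catalog segment)
print(build_url("presto", "username", "localhost:8080", password="pass", catalog="hive"))
# presto://username:pass@localhost:8080/hive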

View File

@@ -152,7 +152,7 @@ class SouceConnectionTest(TestCase):
hostPort="localhost:10000",
connectionArguments={"auth": "CUSTOM"},
)
print("get_connection_url(hive_conn_obj)...", get_connection_url(hive_conn_obj))
assert expected_result == get_connection_url(hive_conn_obj)
def test_hive_url_conn_options_with_db(self):
@@ -314,6 +314,18 @@ class SouceConnectionTest(TestCase):
== get_connection_args(trino_conn_obj).get("http_session").proxies
)
def test_trino_without_catalog(self):
# Test trino url without catalog
expected_url = "trino://username:pass@localhost:443"
trino_conn_obj = TrinoConnection(
scheme=TrinoScheme.trino,
hostPort="localhost:443",
username="username",
password="pass",
)
assert expected_url == get_connection_url(trino_conn_obj)
def test_vertica_url(self):
expected_url = (
"vertica+vertica_python://username:password@localhost:5443/database"
@@ -521,10 +533,7 @@ class SouceConnectionTest(TestCase):
warehouse="COMPUTE_WH",
account="ue18849.us-east-2.aws",
)
print(
"get_connection_url(snowflake_conn_obj),,,,,",
get_connection_url(snowflake_conn_obj),
)
assert expected_url == get_connection_url(snowflake_conn_obj)
# connection arguments with db
@@ -819,6 +828,18 @@ class SouceConnectionTest(TestCase):
assert expected_url == get_connection_url(presto_conn_obj)
def test_presto_without_catalog(self):
# Test presto url without catalog
expected_url = "presto://username:pass@localhost:8080"
presto_conn_obj = PrestoConnection(
scheme=PrestoScheme.presto,
hostPort="localhost:8080",
username="username",
password="pass",
)
assert expected_url == get_connection_url(presto_conn_obj)
def test_oracle_url(self):
# oracle with db
expected_url = "oracle+cx_oracle://admin:password@localhost:1541/testdb"

View File

@@ -84,5 +84,5 @@
}
},
"additionalProperties": false,
"required": ["hostPort", "username", "catalog"]
"required": ["hostPort", "username"]
}

View File

@@ -100,5 +100,5 @@
}
},
"additionalProperties": false,
"required": ["hostPort", "username", "catalog"]
"required": ["hostPort", "username"]
}
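
With `catalog` dropped from both `required` lists, a connection that names only the host and user now validates. A minimal example, shown here as a Python dict with field names taken from the schemas above (any surrounding workflow config is omitted):

minimal_trino_connection = {
    "scheme": "trino",
    "hostPort": "localhost:443",
    "username": "username",
    # "catalog" may now be omitted; every catalog is discovered at ingestion time
}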