Fix #3507: Trino connector, support optional database field (#3551)

* Fix #3507: Trino connector, support optional database field

* override prepare and next_record

* fix formatting

* fix return type
This commit is contained in:
Alberto Miorin 2022-03-22 16:03:00 +01:00 committed by GitHub
parent eb906589fd
commit 5309dae08d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,11 +8,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Trino source module"""
import logging import logging
import sys import sys
from typing import Iterable from typing import Iterable, Optional
from urllib.parse import quote_plus from urllib.parse import quote_plus
import click import click
@ -30,41 +29,26 @@ logger = logging.getLogger(__name__)
class TrinoConfig(SQLConnectionConfig): class TrinoConfig(SQLConnectionConfig):
"""Trinio config class -- extends SQLConnectionConfig class
Attributes:
host_port:
scheme:
service_type:
catalog:
database:
"""
host_port = "localhost:8080" host_port = "localhost:8080"
scheme = "trino" scheme = "trino"
service_type = DatabaseServiceType.Trino.value service_type = DatabaseServiceType.Trino.value
catalog: str catalog: str
database: str include_views = False
params: Optional[dict] = None
def get_connection_url(self): def get_connection_url(self):
url = f"{self.scheme}://" url = f"{self.scheme}://"
if self.username is not None: if self.username:
url += f"{self.username}" url += f"{quote_plus(self.username)}"
if self.password is not None: if self.password:
url += f":{quote_plus(self.password.get_secret_value())}" url += f":{quote_plus(self.password.get_secret_value())}"
url += "@" url += "@"
url += f"{self.host_port}" url += f"{self.host_port}"
if self.catalog is not None:
url += f"/{self.catalog}" url += f"/{self.catalog}"
if self.database is not None: if self.params is not None:
url += f"/{self.database}"
if self.options is not None:
if self.database is None:
url += "/"
params = "&".join( params = "&".join(
f"{key}={quote_plus(value)}" f"{key}={quote_plus(value)}"
for (key, value) in self.options.items() for (key, value) in self.params.items()
if value if value
) )
url = f"{url}?{params}" url = f"{url}?{params}"
@ -72,15 +56,9 @@ class TrinoConfig(SQLConnectionConfig):
class TrinoSource(SQLSource): class TrinoSource(SQLSource):
"""Trino source -- extends SQLSource
Args:
config:
metadata_config:
ctx
"""
def __init__(self, config, metadata_config, ctx): def __init__(self, config, metadata_config, ctx):
self.schema_names = None
self.inspector = None
try: try:
from sqlalchemy_trino import ( from sqlalchemy_trino import (
dbapi, # pylint: disable=import-outside-toplevel,unused-import dbapi, # pylint: disable=import-outside-toplevel,unused-import
@ -102,9 +80,25 @@ class TrinoSource(SQLSource):
metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict) metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict)
return cls(config, metadata_config, ctx) return cls(config, metadata_config, ctx)
def prepare(self):
self.inspector = inspect(self.engine)
self.schema_names = (
self.inspector.get_schema_names()
if not self.config.database
else [self.config.database]
)
return super().prepare()
def next_record(self) -> Iterable[OMetaDatabaseAndTable]: def next_record(self) -> Iterable[OMetaDatabaseAndTable]:
inspector = inspect(self.engine) for schema in self.schema_names:
self.database_source_state.clear()
if not self.sql_config.schema_filter_pattern.included(schema):
self.status.filter(schema, "Schema pattern not allowed")
continue
if self.config.include_tables: if self.config.include_tables:
yield from self.fetch_tables(inspector, self.config.database) yield from self.fetch_tables(self.inspector, schema)
if self.config.include_views: if self.config.include_views:
yield from self.fetch_views(inspector, self.config.database) yield from self.fetch_views(self.inspector, schema)
if self.config.mark_deleted_tables_as_deleted:
schema_fqdn = f"{self.config.service_name}.{schema}"
yield from self.delete_tables(schema_fqdn)