diff --git a/ingestion/src/metadata/ingestion/source/trino.py b/ingestion/src/metadata/ingestion/source/trino.py index e7d2c144bbc..ab7dbf005d6 100644 --- a/ingestion/src/metadata/ingestion/source/trino.py +++ b/ingestion/src/metadata/ingestion/source/trino.py @@ -8,11 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Trino source module""" import logging import sys -from typing import Iterable +from typing import Iterable, Optional from urllib.parse import quote_plus import click @@ -30,41 +29,26 @@ logger = logging.getLogger(__name__) class TrinoConfig(SQLConnectionConfig): - """Trinio config class -- extends SQLConnectionConfig class - - Attributes: - host_port: - scheme: - service_type: - catalog: - database: - """ - host_port = "localhost:8080" scheme = "trino" service_type = DatabaseServiceType.Trino.value catalog: str - database: str + include_views = False + params: Optional[dict] = None def get_connection_url(self): url = f"{self.scheme}://" - if self.username is not None: - url += f"{self.username}" - if self.password is not None: + if self.username: + url += f"{quote_plus(self.username)}" + if self.password: url += f":{quote_plus(self.password.get_secret_value())}" url += "@" url += f"{self.host_port}" - if self.catalog is not None: - url += f"/{self.catalog}" - if self.database is not None: - url += f"/{self.database}" - - if self.options is not None: - if self.database is None: - url += "/" + url += f"/{self.catalog}" + if self.params is not None: params = "&".join( f"{key}={quote_plus(value)}" - for (key, value) in self.options.items() + for (key, value) in self.params.items() if value ) url = f"{url}?{params}" @@ -72,15 +56,9 @@ class TrinoConfig(SQLConnectionConfig): class TrinoSource(SQLSource): - """Trino source -- extends SQLSource - - Args: - config: - metadata_config: - ctx - """ - def __init__(self, config, metadata_config, ctx): + self.schema_names = None + self.inspector = None try: from sqlalchemy_trino import ( dbapi, # pylint: disable=import-outside-toplevel,unused-import @@ -102,9 +80,25 @@ class TrinoSource(SQLSource): metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict) return cls(config, metadata_config, ctx) + def prepare(self): + self.inspector = inspect(self.engine) + self.schema_names = ( + self.inspector.get_schema_names() + if not self.config.database + else [self.config.database] + ) + return super().prepare() + def next_record(self) -> Iterable[OMetaDatabaseAndTable]: - inspector = inspect(self.engine) - if self.config.include_tables: - yield from self.fetch_tables(inspector, self.config.database) - if self.config.include_views: - yield from self.fetch_views(inspector, self.config.database) + for schema in self.schema_names: + self.database_source_state.clear() + if not self.sql_config.schema_filter_pattern.included(schema): + self.status.filter(schema, "Schema pattern not allowed") + continue + if self.config.include_tables: + yield from self.fetch_tables(self.inspector, schema) + if self.config.include_views: + yield from self.fetch_views(self.inspector, schema) + if self.config.mark_deleted_tables_as_deleted: + schema_fqdn = f"{self.config.service_name}.{schema}" + yield from self.delete_tables(schema_fqdn)