Merge pull request #96 from open-metadata/ingestion_refactor

Ingestion refactor
commit 3857f72c7b by Suresh Srinivas, 2021-08-11 20:55:14 -07:00 (committed by GitHub)
33 changed files with 199 additions and 557 deletions

View File

@ -9,7 +9,7 @@
"username": "sa",
"password": "test!Password",
"include_pattern": {
"include": ["catalog_test.*"]
"excludes": ["catalog_test.*"]
}
}
},
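
For orientation: this refactor renames the table-filter keys across the example configs, so the old "include"/"filter" pair becomes "includes"/"excludes", and an exclude match always wins over an include. A minimal sketch of the new shape (values are illustrative, reusing the catalog_test pattern above; in a real workflow file this sits under source.config):

include_pattern = {
    "includes": [".*"],              # default: scan everything
    "excludes": ["catalog_test.*"],  # anything matching an exclude is dropped
}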

View File

@ -7,13 +7,7 @@
"host_port": "localhost:5432",
"database": "pagila",
"service_name": "local_postgres",
"service_type": "Postgres",
"include_pattern": {
"filter": [
"pg_catalog.*[a-zA-Z0-9]*",
"information_schema.*[a-zA-Z0-9]*"
]
}
"service_type": "Postgres"
}
},
"processor": {

View File

@ -1,37 +0,0 @@
{
"source": {
"type": "redshift-sql",
"config": {
"host_port": "cluster.name.region.redshift.amazonaws.com:5439",
"username": "username",
"password": "strong_password",
"database": "dev",
"service_name": "aws_redshift",
"service_type": "Redshift"
}
},
"processor": {
"type": "pii-tags",
"config": {
}
},
"sink": {
"type": "metadata-rest-tables",
"config": {
}
},
"metadata_server": {
"type": "metadata-server",
"config": {
"api_endpoint": "http://localhost:8585/api",
"auth_provider_type": "no-auth"
}
},
"cron": {
"minute": "*/5",
"hour": null,
"day": null,
"month": null,
"day_of_week": null
}
}

View File

@ -10,7 +10,7 @@
"service_name": "snowflake",
"service_type": "Snowflake",
"include_pattern": {
"include": [
"includes": [
"(\\w)*.tpcds_sf100tcl.catalog_page",
"(\\w)*.tpcds_sf100tcl.time_dim",
"(\\w)*.tpcds_sf10tcl.catalog_page"

View File

@ -0,0 +1,14 @@
click~=7.1.2
pydantic~=1.7.4
expandvars~=0.6.5
requests~=2.25.1
python-dateutil~=2.8.1
SQLAlchemy~=1.4.5
pandas~=1.2.4
Faker~=8.1.1
elasticsearch~=7.12.0
spacy~=3.0.5
commonregex~=1.5.4
setuptools~=57.0.0
PyHive~=0.6.4
ldap3~=2.9.1

View File

@ -101,19 +101,17 @@ build_options = {"includes": ["_cffi_backend"]}
setup(
name="metadata",
version=get_version(),
url="https://github.com/streamlinedata/metadata",
version="0.2.0",
url="https://github.com/open-metadata/OpenMetadata",
author="Metadata Committers",
license="Apache License 2.0",
description="Ingestion Framework for OpenMetadata",
long_description="Ingestion Framework for OpenMetadata",
long_description="Ingestion Framework for OpenMetadata",
long_description_content_type="text/markdown",
python_requires=">=3.8",
options={"build_exe": build_options},
package_dir={"": "src"},
packages=find_namespace_packages(where='src', exclude=['tests*']),
dependency_links=['git+git://github.com/djacobs/PyAPNs.git#egg=apns',
'git+https://github.com/StreamlineData/sdscheduler.git#egg=simplescheduler'],
entry_points={
"console_scripts": ["metadata = metadata.cmd:metadata"],
"metadata.ingestion.source.plugins": [

View File

@ -1,20 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import click
@click.group()
def check() -> None:
pass

View File

@ -21,7 +21,6 @@ import sys
import click
from pydantic import ValidationError
from metadata.check.check_cli import check
from metadata.config.config_loader import load_config_file
from metadata.ingestion.workflow.workflow import Workflow
@ -35,11 +34,14 @@ BASE_LOGGING_FORMAT = (
)
logging.basicConfig(format=BASE_LOGGING_FORMAT)
@click.group()
def check() -> None:
pass
@click.group()
@click.option("--debug/--no-debug", default=False)
def metadata(debug: bool) -> None:
if debug or os.getenv("METADATA_DEBUG", False):
if os.getenv("METADATA_DEBUG", False):
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("metadata").setLevel(logging.DEBUG)
else:
@ -52,12 +54,11 @@ def metadata(debug: bool) -> None:
"-c",
"--config",
type=click.Path(exists=True, dir_okay=False),
help="Config file in .toml or .yaml format",
help="Workflow config",
required=True,
)
def ingest(config: str) -> None:
"""Main command for ingesting metadata into Metadata"""
config_file = pathlib.Path(config)
workflow_config = load_config_file(config_file)
@ -71,6 +72,7 @@ def ingest(config: str) -> None:
workflow.execute()
ret = workflow.print_status()
workflow.stop()
sys.exit(ret)
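
With the refactor, the metadata console script (declared in the setup.py hunk above) exposes ingest, which now takes a required workflow config and exits with the workflow status. A hedged, in-process invocation sketch using click's test runner; the config path below is made up:

from click.testing import CliRunner
from metadata.cmd import metadata

# equivalent to running: metadata ingest -c pipelines/local_postgres.json
result = CliRunner().invoke(metadata, ["ingest", "-c", "pipelines/local_postgres.json"])
print(result.exit_code, result.output)

Note that after this change the --debug flag no longer enables debug logging on its own; set the METADATA_DEBUG environment variable instead.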

View File

@ -42,48 +42,18 @@ class ConfigModel(BaseModel):
class DynamicTypedConfig(ConfigModel):
type: str
# This config type is declared Optional[Any] here. The eventual parser for the
# specified type is responsible for further validation.
config: Optional[Any]
class MetaError(Exception):
"""A base class for all meta exceptions"""
class WorkflowExecutionError(MetaError):
class WorkflowExecutionError(Exception):
"""An error occurred when executing the workflow"""
class OperationalError(WorkflowExecutionError):
"""An error occurred because of client-provided metadata"""
message: str
info: dict
def __init__(self, message: str, info: dict = None):
self.message = message
if info:
self.info = info
else:
self.info = {}
class ConfigurationError(MetaError):
"""A configuration error has happened"""
class ConfigurationMechanism(ABC):
@abstractmethod
def load_config(self, config_fp: IO) -> dict:
pass
class IncludeFilterPattern(ConfigModel):
"""A class to store allow deny regexes"""
include: List[str] = [".*"]
filter: List[str] = []
includes: List[str] = [".*"]
excludes: List[str] = []
alphabet: str = "[A-Za-z0-9 _.-]"
@property
@ -96,11 +66,11 @@ class IncludeFilterPattern(ConfigModel):
def included(self, string: str) -> bool:
try:
for filter in self.filter:
if re.match(filter, string):
for exclude in self.excludes:
if re.match(exclude, string):
return False
for include in self.include:
for include in self.includes:
if re.match(include, string):
return True
return False
@ -108,17 +78,11 @@ class IncludeFilterPattern(ConfigModel):
raise Exception("Regex Error: {}".format(err))
def is_fully_specified_include_list(self) -> bool:
"""
If the allow patterns are literals and not full regexes, then it is considered
fully specified. This is useful if you want to convert a 'list + filter'
pattern into a 'search for the ones that are allowed' pattern, which can be
much more efficient in some cases.
"""
for include_pattern in self.include:
for include_pattern in self.includes:
if not self.alphabet_pattern.match(include_pattern):
return False
return True
def get_allowed_list(self):
assert self.is_fully_specified_include_list()
return [a for a in self.include if self.included(a)]
return [a for a in self.includes if self.included(a)]
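
The renamed fields also make the matching order explicit: excludes are evaluated first, so an exclude match short-circuits even if an include would also match. A self-contained mirror of the new included() logic (a sketch, not the real class):

import re
from typing import Sequence

def included(string: str, includes: Sequence[str] = (".*",), excludes: Sequence[str] = ()) -> bool:
    # excludes win; otherwise any include match allows the string
    for exclude in excludes:
        if re.match(exclude, string):
            return False
    return any(re.match(include, string) for include in includes)

assert included("public.customers", excludes=["pg_catalog.*", "information_schema.*"])
assert not included("pg_catalog.pg_tables", excludes=["pg_catalog.*", "information_schema.*"])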

View File

@ -28,7 +28,7 @@ class ProcessorStatus(Status):
warnings: List[Any] = field(default_factory=list)
failures: List[Any] = field(default_factory=list)
def records_processed(self, record: Record):
def processed(self, record: Record):
self.records += 1
def warning(self, info: Any) -> None:

View File

@ -28,7 +28,7 @@ class SourceStatus(Status):
warnings: Dict[str, List[str]] = field(default_factory=dict)
failures: Dict[str, List[str]] = field(default_factory=dict)
def records_produced(self, record: Record) -> None:
def scanned(self, record: Record) -> None:
self.records += 1
def warning(self, key: str, reason: str) -> None:

View File

@ -44,6 +44,7 @@ class MetadataUsageBulkSink(BulkSink):
self.client = REST(self.metadata_config)
self.status = BulkSinkStatus()
self.tables_dict = {}
self.table_join_dict = {}
self.__map_tables()
def __map_tables(self):
@ -74,7 +75,8 @@ class MetadataUsageBulkSink(BulkSink):
try:
self.client.publish_usage_for_a_table(table_entity, table_usage_request)
except APIError as err:
logger.error("Failed to update usage and query join {}".format(err))
self.status.failures.append(table_usage_request)
logger.error("Failed to update usage for {} {}".format(table_usage.table, err))
table_join_request = self.__get_table_joins(table_usage)
logger.debug("table join request {}".format(table_join_request))
@ -82,7 +84,8 @@ class MetadataUsageBulkSink(BulkSink):
if table_join_request is not None and len(table_join_request.columnJoins) > 0:
self.client.publish_frequently_joined_with(table_entity, table_join_request)
except APIError as err:
logger.error("Failed to update usage and query join {}".format(err))
self.status.failures.append(table_join_request)
logger.error("Failed to update query join for {}, {}".format(table_usage.table, err))
else:
logger.warning("Table does not exist, skipping usage publish {}, {}".format(table_usage.table,
@ -90,21 +93,32 @@ class MetadataUsageBulkSink(BulkSink):
def __get_table_joins(self, table_usage):
table_joins: TableJoins = TableJoins(columnJoins=[], startDate=table_usage.date)
column_joins_dict = {}
joined_with = {}
for column_join in table_usage.joins:
if column_join.table_column is None or len(column_join.joined_with) == 0:
continue
logger.debug("main column join {}".format(column_join.table_column))
if column_join.table_column.column in column_joins_dict.keys():
joined_with = column_joins_dict[column_join.table_column.column]
else:
column_joins_dict[column_join.table_column.column] = {}
main_column_fqdn = self.__get_column_fqdn(column_join.table_column)
logger.debug("main column fqdn join {}".format(main_column_fqdn))
joined_with = []
for column in column_join.joined_with:
logger.debug("joined column {}".format(column))
joined_column_fqdn = self.__get_column_fqdn(column)
logger.debug("joined column fqdn {}".format(joined_column_fqdn))
if joined_column_fqdn is not None:
joined_with.append(ColumnJoinedWith(fullyQualifiedName=joined_column_fqdn, joinCount=1))
table_joins.columnJoins.append(ColumnJoins(columnName=column_join.table_column.column,
joinedWith=joined_with))
if joined_column_fqdn in joined_with.keys():
column_joined_with = joined_with[joined_column_fqdn]
column_joined_with.joinCount += 1
joined_with[joined_column_fqdn] = column_joined_with
else:
joined_with[joined_column_fqdn] = ColumnJoinedWith(fullyQualifiedName=joined_column_fqdn,
joinCount=1)
column_joins_dict[column_join.table_column.column] = joined_with
for key, value in column_joins_dict.items():
table_joins.columnJoins.append(ColumnJoins(columnName=key,
joinedWith=list(value.values())))
return table_joins
def __get_column_fqdn(self, table_column: TableColumn):
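
The reworked __get_table_joins aggregates repeated joins instead of emitting one entry per occurrence: for each main column it keys the joined columns by fully qualified name and bumps joinCount on repeats. A standalone sketch of that aggregation with plain dicts (the FQDNs are invented; the real code wraps the result in ColumnJoins/ColumnJoinedWith schema objects):

from typing import Dict, List, Tuple

# (main column, joined column FQDN) pairs as the usage stage might hand over
observed_joins: List[Tuple[str, str]] = [
    ("customer_id", "shopify.dim_customer.customer_id"),
    ("customer_id", "shopify.dim_customer.customer_id"),
    ("product_id", "shopify.dim_product.product_id"),
]

column_joins: Dict[str, Dict[str, int]] = {}
for column, joined_fqdn in observed_joins:
    joined_with = column_joins.setdefault(column, {})
    joined_with[joined_fqdn] = joined_with.get(joined_fqdn, 0) + 1

print(column_joins)
# {'customer_id': {'shopify.dim_customer.customer_id': 2},
#  'product_id': {'shopify.dim_product.product_id': 1}}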

View File

@ -31,7 +31,7 @@ from metadata.ingestion.models.table_queries import TableUsageRequest, ColumnJoi
from metadata.ingestion.ometa.auth_provider import MetadataServerConfig, AuthenticationProvider, \
GoogleAuthenticationProvider, NoOpAuthenticationProvider, OktaAuthenticationProvider
from metadata.ingestion.ometa.credentials import URL, get_api_version
from metadata.generated.schema.entity.data.table import TableEntity
from metadata.generated.schema.entity.data.table import TableEntity, TableJoins
from metadata.generated.schema.entity.data.database import DatabaseEntity
logger = logging.getLogger(__name__)
@ -296,11 +296,11 @@ class REST(object):
def publish_usage_for_a_table(self, table: TableEntity, table_usage_request: TableUsageRequest) -> None:
"""publish usage details for a table"""
resp = self.post('/usage/table/{}'.format(table.id.__root__), data=table_usage_request.json())
# self.post('/usage/compute.percentile/table/{}'.format(table.id.__root__), table_usage_request.date)
logger.debug("published table usage {}".format(resp))
def publish_frequently_joined_with(self, table: TableEntity, table_join_request: ColumnJoinsList) -> None:
def publish_frequently_joined_with(self, table: TableEntity, table_join_request: TableJoins) -> None:
"""publish frequently joined with for a table"""
print(table_join_request.json())
logger.debug(table_join_request.json())
logger.info("table join request {}".format(table_join_request.json()))
resp = self.put('/tables/{}/joins'.format(table.id.__root__), data=table_join_request.json())
logger.debug("published frequently joined with {}".format(resp))

View File

@ -52,7 +52,7 @@ class QueryParserProcessor(Processor):
try:
start_date = datetime.datetime.strptime(record.analysis_date, '%Y-%m-%d %H:%M:%S').date()
parser = Parser(record.sql)
columns_dict = {} if parser.columns_dict == None else parser.columns_dict
columns_dict = {} if parser.columns_dict is None else parser.columns_dict
query_parser_data = QueryParserData(tables=parser.tables,
tables_aliases=parser.tables_aliases,
columns=columns_dict,
@ -60,8 +60,8 @@ class QueryParserProcessor(Processor):
sql=record.sql,
date=start_date.strftime('%Y-%m-%d'))
except Exception as err:
logger.error(record.sql)
logger.error(err)
logger.debug(record.sql)
logger.debug(err)
query_parser_data = None
pass

View File

@ -16,11 +16,11 @@
from typing import Optional
from urllib.parse import quote_plus
from .sql_source import SQLAlchemyConfig, SQLAlchemySource
from .sql_source import SQLConnectionConfig, SQLSource
from ..ometa.auth_provider import MetadataServerConfig
class AthenaConfig(SQLAlchemyConfig):
class AthenaConfig(SQLConnectionConfig):
scheme: str = "awsathena+rest"
username: Optional[str] = None
password: Optional[str] = None
@ -29,7 +29,7 @@ class AthenaConfig(SQLAlchemyConfig):
s3_staging_dir: str
work_group: str
def get_sql_alchemy_url(self):
def get_connection_url(self):
url = f"{self.scheme}://"
if self.username:
url += f"{quote_plus(self.username)}"
@ -46,9 +46,9 @@ class AthenaConfig(SQLAlchemyConfig):
return url
class AthenaSource(SQLAlchemySource):
class AthenaSource(SQLSource):
def __init__(self, config, metadata_config, ctx):
super().__init__(config, metadata_config, ctx, "athena")
super().__init__(config, metadata_config, ctx)
@classmethod
def create(cls, config_dict, metadata_config_dict, ctx):

View File

@ -17,26 +17,32 @@ from typing import Optional, Tuple
# This import verifies that the dependencies are available.
from .sql_source import BasicSQLAlchemyConfig, SQLAlchemySource
from .sql_source import SQLConnectionConfig, SQLSource
from ..ometa.auth_provider import MetadataServerConfig
class BigQueryConfig(BasicSQLAlchemyConfig):
class BigQueryConfig(SQLConnectionConfig, SQLSource):
scheme = "bigquery"
project_id: Optional[str] = None
def get_sql_alchemy_url(self):
def get_connection_url(self):
if self.project_id:
return f"{self.scheme}://{self.project_id}"
return f"{self.scheme}://"
def get_identifier(self, schema: str, table: str) -> str:
if self.project_id:
return f"{self.project_id}.{schema}.{table}"
return f"{schema}.{table}"
class BigQuerySource(SQLSource):
def __init__(self, config, metadata_config, ctx):
super().__init__(config, metadata_config, ctx)
@classmethod
def create(cls, config_dict, metadata_config_dict, ctx):
config = BigQueryConfig.parse_obj(config_dict)
metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict)
return cls(config, metadata_config, ctx)
def standardize_schema_table_names(
self, schema: str, table: str
self, schema: str, table: str
) -> Tuple[str, str]:
segments = table.split(".")
if len(segments) != 2:
@ -44,14 +50,3 @@ class BigQueryConfig(BasicSQLAlchemyConfig):
if segments[0] != schema:
raise ValueError(f"schema {schema} does not match table {table}")
return segments[0], segments[1]
class BigQuerySource(SQLAlchemySource):
def __init__(self, config, metadata_config, ctx):
super().__init__(config, metadata_config, ctx, "bigquery")
@classmethod
def create(cls, config_dict, metadata_config_dict, ctx):
config = BigQueryConfig.parse_obj(config_dict)
metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict)
return cls(config, metadata_config, ctx)
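
BigQuerySource now carries standardize_schema_table_names itself rather than the config, since BigQuery reports tables as "<dataset>.<table>" and the dataset half has to agree with the schema being inspected. A standalone sketch of that check (the first error message is paraphrased because the hunk cuts it off; the names in the call are hypothetical):

from typing import Tuple

def standardize_schema_table_names(schema: str, table: str) -> Tuple[str, str]:
    segments = table.split(".")
    if len(segments) != 2:
        raise ValueError(f"unexpected table name {table}")  # paraphrased
    if segments[0] != schema:
        raise ValueError(f"schema {schema} does not match table {table}")
    return segments[0], segments[1]

print(standardize_schema_table_names("my_dataset", "my_dataset.events"))  # ('my_dataset', 'events')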

View File

@ -17,8 +17,8 @@ from pyhive import hive # noqa: F401
from pyhive.sqlalchemy_hive import HiveDate, HiveDecimal, HiveTimestamp
from .sql_source import (
BasicSQLAlchemyConfig,
SQLAlchemySource,
SQLConnectionConfig,
SQLSource,
register_custom_type,
)
from ..ometa.auth_provider import MetadataServerConfig
@ -28,13 +28,16 @@ register_custom_type(HiveTimestamp, "TIME")
register_custom_type(HiveDecimal, "NUMBER")
class HiveConfig(BasicSQLAlchemyConfig):
class HiveConfig(SQLConnectionConfig):
scheme = "hive"
def get_connection_url(self):
return super().get_connection_url()
class HiveSource(SQLAlchemySource):
class HiveSource(SQLSource):
def __init__(self, config, metadata_config, ctx):
super().__init__(config, metadata_config, ctx, "hive")
super().__init__(config, metadata_config, ctx)
@classmethod
def create(cls, config_dict, metadata_config_dict, ctx):

View File

@ -16,24 +16,21 @@
# This import verifies that the dependencies are available.
import sqlalchemy_pytds # noqa: F401
from .sql_source import BasicSQLAlchemyConfig, SQLAlchemySource
from .sql_source import SQLConnectionConfig, SQLSource
from ..ometa.auth_provider import MetadataServerConfig
class SQLServerConfig(BasicSQLAlchemyConfig):
class SQLServerConfig(SQLConnectionConfig):
host_port = "localhost:1433"
scheme = "mssql+pytds"
def get_identifier(self, schema: str, table: str) -> str:
regular = f"{schema}.{table}"
if self.database:
return f"{self.database}.{regular}"
return regular
def get_connection_url(self):
return super().get_connection_url()
class SQLServerSource(SQLAlchemySource):
class SQLServerSource(SQLSource):
def __init__(self, config, metadata_config, ctx):
super().__init__(config, metadata_config, ctx, "mssql")
super().__init__(config, metadata_config, ctx)
@classmethod
def create(cls, config_dict, metadata_config_dict, ctx):

View File

@ -13,19 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pymysql # noqa: F401
from .sql_source import BasicSQLAlchemyConfig, SQLAlchemySource
from .sql_source import SQLSource, SQLConnectionConfig
from ..ometa.auth_provider import MetadataServerConfig
class MySQLConfig(BasicSQLAlchemyConfig):
# defaults
class MySQLConfig(SQLConnectionConfig):
host_port = "localhost:3306"
scheme = "mysql+pymysql"
def get_connection_url(self):
return super().get_connection_url()
class MySQLSource(SQLAlchemySource):
class MySQLSource(SQLSource):
def __init__(self, config, metadata_config, ctx):
super().__init__(config, metadata_config, ctx)

View File

@ -16,18 +16,18 @@
# This import verifies that the dependencies are available.
import cx_Oracle # noqa: F401
from .sql_source import BasicSQLAlchemyConfig, SQLAlchemySource
from .sql_source import SQLSource, SQLConnectionConfig
from ..ometa.auth_provider import MetadataServerConfig
class OracleConfig(BasicSQLAlchemyConfig):
class OracleConfig(SQLConnectionConfig):
# defaults
scheme = "oracle+cx_oracle"
class OracleSource(SQLAlchemySource):
class OracleSource(SQLSource):
def __init__(self, config, metadata_config, ctx):
super().__init__(config, metadata_config, ctx, "oracle")
super().__init__(config, metadata_config, ctx)
@classmethod
def create(cls, config_dict, metadata_config_dict, ctx):

View File

@ -24,8 +24,8 @@ from metadata.ingestion.models.ometa_table_db import OMetaDatabaseAndTable
import pymysql # noqa: F401
from metadata.generated.schema.entity.data.table import TableEntity, Column
from metadata.ingestion.source.sql_source_common import SQLAlchemyHelper, SQLSourceStatus
from .sql_source import BasicSQLAlchemyConfig
from metadata.ingestion.source.sql_alchemy_helper import SQLAlchemyHelper, SQLSourceStatus
from .sql_source import SQLConnectionConfig
from metadata.ingestion.api.source import Source, SourceStatus
from metadata.ingestion.models.table_metadata import DatabaseMetadata
from itertools import groupby
@ -38,27 +38,18 @@ from ...utils.helpers import get_service_or_create
TableKey = namedtuple('TableKey', ['schema', 'table_name'])
class PostgresSourceConfig(BasicSQLAlchemyConfig):
class PostgresSourceConfig(SQLConnectionConfig):
# defaults
scheme = "postgresql+psycopg2"
service_name = "postgres"
service_type = "POSTGRES"
def get_sql_alchemy_url(self):
url = f"{self.scheme}://"
if self.username:
url += f"{self.username}"
if self.password:
url += f":{self.password}"
url += "@"
url += f"{self.host_port}"
if self.database:
url += f"/{self.database}"
return url
def get_service_type(self) -> DatabaseServiceType:
return DatabaseServiceType[self.service_type]
def get_connection_url(self):
return super().get_connection_url()
def get_table_key(row: Dict[str, Any]) -> Union[TableKey, None]:
"""
@ -73,7 +64,6 @@ def get_table_key(row: Dict[str, Any]) -> Union[TableKey, None]:
class PostgresSource(Source):
# SELECT statement from mysql information_schema to extract table and column metadata
SQL_STATEMENT = """
SELECT
c.table_catalog as cluster, c.table_schema as schema, c.table_name as name, pgtd.description as description
@ -106,7 +96,7 @@ class PostgresSource(Source):
self.status = SQLSourceStatus()
self.service = get_service_or_create(config, metadata_config)
self.include_pattern = IncludeFilterPattern
self.pattern = config.include_pattern
self.pattern = config
@classmethod
def create(cls, config_dict, metadata_config_dict, ctx):
@ -131,7 +121,6 @@ class PostgresSource(Source):
Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata
:return:
"""
counter = 0
for key, group in groupby(self._get_raw_extract_iter(), get_table_key):
columns = []
for row in group:
@ -139,7 +128,7 @@ class PostgresSource(Source):
col_type = ''
if row['col_type'].upper() == 'CHARACTER VARYING':
col_type = 'VARCHAR'
elif row['col_type'].upper() == 'CHARACTER':
elif row['col_type'].upper() == 'CHARACTER' or row['col_type'].upper() == 'NAME':
col_type = 'CHAR'
elif row['col_type'].upper() == 'INTEGER':
col_type = 'INT'
@ -149,28 +138,29 @@ class PostgresSource(Source):
col_type = 'DOUBLE'
elif row['col_type'].upper() == 'OID':
col_type = 'NUMBER'
elif row['col_type'].upper() == 'NAME':
col_type = 'CHAR'
elif row['col_type'].upper() == 'ARRAY':
col_type = 'ARRAY'
elif row['col_type'].upper() == 'BOOLEAN':
col_type = 'BOOLEAN'
else:
col_type = row['col_type'].upper()
if not self.include_pattern.included(self.pattern, last_row[1]):
self.status.report_dropped(last_row['name'])
col_type = None
if not self.pattern.include_pattern.included(f'{last_row[1]}.{last_row[2]}'):
self.status.filtered(f'{last_row[1]}.{last_row[2]}', "pattern not allowed", last_row[2])
continue
columns.append(Column(name=row['col_name'], description=row['col_description'],
columnDataType=col_type, ordinalPosition=int(row['col_sort_order'])))
if col_type is not None:
columns.append(Column(name=row['col_name'], description=row['col_description'],
columnDataType=col_type, ordinalPosition=int(row['col_sort_order'])))
table_metadata = TableEntity(name=last_row['name'],
description=last_row['description'],
columns=columns)
self.status.report_table_scanned(table_metadata.name)
self.status.scanned(table_metadata.name.__root__)
dm = DatabaseEntity(id=uuid.uuid4(),
name=row['schema'],
description=row['description'] if row['description'] is not None else ' ',
service=EntityReference(id=self.service.id, type=self.SERVICE_TYPE))
table_and_db = OMetaDatabaseAndTable(table=table_metadata, database=dm)
self.status.records_produced(dm)
yield table_and_db
def close(self):

View File

@ -17,13 +17,13 @@ import logging
from typing import Optional
from metadata.ingestion.ometa.auth_provider import MetadataServerConfig
from metadata.ingestion.source.sql_source import SQLAlchemySource, BasicSQLAlchemyConfig
from metadata.ingestion.source.sql_source import SQLSource, SQLConnectionConfig
from metadata.ingestion.api.source import SourceStatus
logger = logging.getLogger(__name__)
class RedshiftConfig(BasicSQLAlchemyConfig):
class RedshiftConfig(SQLConnectionConfig):
scheme = "postgresql+psycopg2"
where_clause: Optional[str] = None
duration: int = 1
@ -34,8 +34,11 @@ class RedshiftConfig(BasicSQLAlchemyConfig):
return f"{self.database}.{regular}"
return regular
def get_connection_url(self):
return super().get_connection_url()
class RedshiftSource(SQLAlchemySource):
class RedshiftSource(SQLSource):
def __init__(self, config, metadata_config, ctx):
super().__init__(config, metadata_config, ctx)

View File

@ -1,200 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This import verifies that the dependencies are available.
import logging
import uuid
import pymysql # noqa: F401
from pydantic import ValidationError
from metadata.generated.schema.entity.data.table import Column, TableEntity
from metadata.generated.schema.entity.data.database import DatabaseEntity
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.models.ometa_table_db import OMetaDatabaseAndTable
from metadata.ingestion.ometa.auth_provider import MetadataServerConfig
from metadata.ingestion.source.sql_source_common import BasicSQLQueryConfig, SQLAlchemyHelper, SQLSourceStatus
from metadata.ingestion.api.source import Source, SourceStatus
from itertools import groupby
from typing import Iterator, Union, Dict, Any, Iterable
from collections import namedtuple
from metadata.utils.helpers import get_service_or_create
TableKey = namedtuple('TableKey', ['schema', 'table_name'])
class RedshiftConfig(BasicSQLQueryConfig):
scheme = "redshift"
where_clause: str = None
cluster_source: str = "CURRENT_DATABASE()"
api_endpoint: str = None
service_type: str = "REDSHIFT"
service_name: str = "aws_redshift"
def get_table_key(row: Dict[str, Any]) -> Union[TableKey, None]:
"""
Table key consists of schema and table name
:param row:
:return:
"""
if row:
return TableKey(schema=row['schema'], table_name=row['name'])
return None
logger = logging.getLogger(__name__)
class RedshiftSQLSource(Source):
# SELECT statement from mysql information_schema to extract table and column metadata
SQL_STATEMENT = """
SELECT
*
FROM (
SELECT
{cluster_source} as cluster,
c.table_schema as schema,
c.table_name as name,
pgtd.description as description,
c.column_name as col_name,
c.data_type as col_type,
pgcd.description as col_description,
ordinal_position as col_sort_order
FROM INFORMATION_SCHEMA.COLUMNS c
INNER JOIN
pg_catalog.pg_statio_all_tables as st on c.table_schema=st.schemaname and c.table_name=st.relname
LEFT JOIN
pg_catalog.pg_description pgcd on pgcd.objoid=st.relid and pgcd.objsubid=c.ordinal_position
LEFT JOIN
pg_catalog.pg_description pgtd on pgtd.objoid=st.relid and pgtd.objsubid=0
UNION
SELECT
{cluster_source} as cluster,
view_schema as schema,
view_name as name,
NULL as description,
column_name as col_name,
data_type as col_type,
NULL as col_description,
ordinal_position as col_sort_order
FROM
PG_GET_LATE_BINDING_VIEW_COLS()
COLS(view_schema NAME, view_name NAME, column_name NAME, data_type VARCHAR, ordinal_position INT)
)
{where_clause_suffix}
ORDER by cluster, schema, name, col_sort_order ;
"""
# CONFIG KEYS
WHERE_CLAUSE_SUFFIX_KEY = 'where_clause'
CLUSTER_SOURCE = 'cluster_source'
CLUSTER_KEY = 'cluster_key'
USE_CATALOG_AS_CLUSTER_NAME = 'use_catalog_as_cluster_name'
DATABASE_KEY = 'database_key'
SERVICE_TYPE = 'REDSHIFT'
DEFAULT_CLUSTER_SOURCE = 'CURRENT_DATABASE()'
def __init__(self, config, metadata_config, ctx):
super().__init__(ctx)
self.sql_stmt = RedshiftSQLSource.SQL_STATEMENT.format(
where_clause_suffix=config.where_clause,
cluster_source=config.cluster_source,
database=config.database
)
self.alchemy_helper = SQLAlchemyHelper(config, metadata_config, ctx, "Redshift", self.sql_stmt)
self.config = config
self.metadata_config = metadata_config
self._extract_iter: Union[None, Iterator] = None
self._database = 'redshift'
self.report = SQLSourceStatus()
self.service = get_service_or_create(config, metadata_config)
@classmethod
def create(cls, config_dict, metadata_config_dict, ctx):
config = RedshiftConfig.parse_obj(config_dict)
metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict)
return cls(config, metadata_config, ctx)
def prepare(self):
pass
def _get_raw_extract_iter(self) -> Iterable[Dict[str, Any]]:
"""
Provides iterator of result row from SQLAlchemy helper
:return:
"""
rows = self.alchemy_helper.execute_query()
for row in rows:
yield row
def next_record(self) -> Iterable[OMetaDatabaseAndTable]:
"""
Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata
:return:
"""
for key, group in groupby(self._get_raw_extract_iter(), get_table_key):
try:
columns = []
for row in group:
last_row = row
col_type = ''
if row['col_type'].upper() == 'CHARACTER VARYING':
col_type = 'VARCHAR'
elif row['col_type'].upper() == 'CHARACTER':
col_type = 'CHAR'
elif row['col_type'].upper() == 'INTEGER':
col_type = 'INT'
elif row['col_type'].upper() == 'TIMESTAMP WITHOUT TIME ZONE':
col_type = 'TIMESTAMP'
elif row['col_type'].upper() == 'DOUBLE PRECISION':
col_type = 'DOUBLE'
elif row['col_type'].upper() == 'OID':
col_type = 'NUMBER'
elif row['col_type'].upper() == 'NAME':
col_type = 'CHAR'
else:
col_type = row['col_type'].upper()
columns.append(Column(name=row['col_name'], description=row['col_description'],
columnDataType=col_type,
ordinalPosition=int(row['col_sort_order'])))
db = DatabaseEntity(id=uuid.uuid4(),
name=last_row['schema'],
description=last_row['description'] if last_row['description'] is not None else ' ',
service=EntityReference(id=self.service.id, type=self.config.service_type))
table = TableEntity(name=last_row['name'],
columns=columns)
table_and_db = OMetaDatabaseAndTable(table=table, database=db)
self.report.report_table_scanned(table.name)
self.report.records_produced(table.name)
yield table_and_db
except ValidationError as err:
logger.info("Dropped Table {} due to {}".format(row['name'], err))
self.report.report_dropped(row['name'])
continue
def get_report(self):
return self.report
def close(self):
self.alchemy_helper.close()
def get_status(self) -> SourceStatus:
return self.report

View File

@ -17,7 +17,7 @@
import logging
from metadata.ingestion.models.table_queries import TableQuery
from metadata.ingestion.ometa.auth_provider import MetadataServerConfig
from metadata.ingestion.source.sql_source_common import SQLAlchemyHelper, SQLSourceStatus
from metadata.ingestion.source.sql_alchemy_helper import SQLAlchemyHelper, SQLSourceStatus
from metadata.ingestion.api.source import Source, SourceStatus
from typing import Iterator, Union, Dict, Any, Iterable
from metadata.utils.helpers import get_start_and_end
@ -99,8 +99,8 @@ class RedshiftUsageSource(Source):
"""
for row in self._get_raw_extract_iter():
tq = TableQuery(row['query'], row['label'], row['userid'], row['xid'], row['pid'], str(row['starttime']),
str(row['endtime']), str(row['analysis_date']), row['duration'], row['database'], row['aborted'], row['sql'])
self.status.records_produced(tq)
str(row['endtime']), str(row['analysis_date']), row['duration'], row['database'],
row['aborted'], row['sql'])
yield tq
def close(self):

View File

@ -293,8 +293,7 @@ class SampleTableSource(Source):
for table in self.tables['tables']:
table_metadata = TableEntity(**table)
table_and_db = OMetaDatabaseAndTable(table=table_metadata, database=db)
self.status.report_table_scanned(table_metadata.name.__root__)
self.status.records_produced(table_metadata.name.__root__)
self.status.scanned(table_metadata.name.__root__)
yield table_and_db
def close(self):

View File

@ -15,12 +15,11 @@
from typing import Optional
import snowflake.sqlalchemy
from snowflake.sqlalchemy import custom_types
from .sql_source import (
BasicSQLAlchemyConfig,
SQLAlchemySource,
SQLConnectionConfig,
SQLSource,
register_custom_type,
)
from ..ometa.auth_provider import MetadataServerConfig
@ -30,7 +29,7 @@ register_custom_type(custom_types.TIMESTAMP_LTZ, "TIME")
register_custom_type(custom_types.TIMESTAMP_NTZ, "TIME")
class SnowflakeConfig(BasicSQLAlchemyConfig):
class SnowflakeConfig(SQLConnectionConfig):
scheme = "snowflake"
account: str
database: str # database is required
@ -38,8 +37,8 @@ class SnowflakeConfig(BasicSQLAlchemyConfig):
role: Optional[str]
duration: Optional[int]
def get_sql_alchemy_url(self):
connect_string = super().get_sql_alchemy_url()
def get_connection_url(self):
connect_string = super().get_connection_url()
options = {
"account": self.account,
"warehouse": self.warehouse,
@ -50,14 +49,10 @@ class SnowflakeConfig(BasicSQLAlchemyConfig):
connect_string = f"{connect_string}?{params}"
return connect_string
def get_identifier(self, schema: str, table: str) -> str:
regular = super().get_identifier(schema, table)
return f"{self.database}.{regular}"
class SnowflakeSource(SQLAlchemySource):
class SnowflakeSource(SQLSource):
def __init__(self, config, metadata_config, ctx):
super().__init__(config, metadata_config, ctx, "snowflake")
super().__init__(config, metadata_config, ctx)
@classmethod
def create(cls, config_dict, metadata_config_dict, ctx):

View File

@ -16,7 +16,7 @@
# This import verifies that the dependencies are available.
from metadata.ingestion.models.table_queries import TableQuery
from metadata.ingestion.ometa.auth_provider import MetadataServerConfig
from metadata.ingestion.source.sql_source_common import SQLAlchemyHelper, SQLSourceStatus
from metadata.ingestion.source.sql_alchemy_helper import SQLAlchemyHelper, SQLSourceStatus
from metadata.ingestion.api.source import Source, SourceStatus
from typing import Iterator, Union, Dict, Any, Iterable
@ -83,7 +83,7 @@ class SnowflakeUsageSource(Source):
for row in self._get_raw_extract_iter():
tq = TableQuery(row['query'], row['label'], 0, 0, 0, str(row['starttime']),
str(row['endtime']), str(row['starttime'])[0:19], 2, row['database'], 0, row['sql'])
self.report.records_produced(tq)
self.report.scanned(tq)
yield tq
def get_report(self):

View File

@ -13,71 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from metadata.config.common import ConfigModel
from typing import Any, Iterable, List, Optional, Tuple
from dataclasses import dataclass, field
from typing import Any, Iterable
from metadata.ingestion.api.common import WorkflowContext
from metadata.ingestion.api.source import SourceStatus
from sqlalchemy import create_engine
from .sql_source import SQLConnectionConfig, SQLSourceStatus
from metadata.ingestion.ometa.auth_provider import MetadataServerConfig
@dataclass
class SQLSourceStatus(SourceStatus):
tables_scanned = 0
filtered: List[str] = field(default_factory=list)
def report_table_scanned(self, table_name: str) -> None:
self.tables_scanned += 1
def report_dropped(self, table_name: str) -> None:
self.filtered.append(table_name)
class SQLAlchemyConfig(ConfigModel):
options: dict = {}
@abstractmethod
def get_sql_alchemy_url(self):
pass
def get_identifier(self, schema: str, table: str) -> str:
return f"{schema}.{table}"
def standardize_schema_table_names(
self, schema: str, table: str
) -> Tuple[str, str]:
# Some SQLAlchemy dialects need a standardization step to clean the schema
# and table names. See BigQuery for an example of when this is useful.
return schema, table
class BasicSQLQueryConfig(SQLAlchemyConfig):
username: Optional[str] = None
password: Optional[str] = None
host_port: str
database: Optional[str] = None
scheme: str
def get_sql_alchemy_url(self):
url = f"{self.scheme}://"
if self.username:
url += f"{self.username}"
if self.password:
url += f":{self.password}"
url += "@"
url += f"{self.host_port}"
if self.database:
url += f"/{self.database}"
return url
class SQLAlchemyHelper:
"""A helper class for all SQL Sources that use SQLAlchemy to extend"""
def __init__(self, config: SQLAlchemyConfig, metadata_config: MetadataServerConfig,
def __init__(self, config: SQLConnectionConfig, metadata_config: MetadataServerConfig,
ctx: WorkflowContext, platform: str, query: str):
self.config = config
self.platform = platform
@ -89,7 +36,7 @@ class SQLAlchemyHelper:
"""
Create a SQLAlchemy connection to Database
"""
engine = create_engine(self.config.get_sql_alchemy_url())
engine = create_engine(self.config.get_connection_url())
conn = engine.connect()
return conn
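
The helper now asks the config for get_connection_url() instead of get_sql_alchemy_url(); the URL keeps the scheme://user:password@host_port/database shape seen in the removed BasicSQLQueryConfig. A faithful standalone sketch of that builder (credentials below are made up):

from typing import Optional

def build_connection_url(scheme: str, host_port: str, username: Optional[str] = None,
                         password: Optional[str] = None, database: Optional[str] = None) -> str:
    url = f"{scheme}://"
    if username:
        url += username
    if password:
        url += f":{password}"
    url += "@"  # appended even without credentials, matching the original logic
    url += host_port
    if database:
        url += f"/{database}"
    return url

# e.g. "postgresql+psycopg2://ingest:secret@localhost:5432/pagila"
print(build_connection_url("postgresql+psycopg2", "localhost:5432", "ingest", "secret", "pagila"))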

View File

@ -44,40 +44,20 @@ logger: logging.Logger = logging.getLogger(__name__)
@dataclass
class SQLSourceStatus(SourceStatus):
tables_scanned: List[str] = field(default_factory=list)
filtered: List[str] = field(default_factory=list)
success: List[str] = field(default_factory=list)
failures: List[str] = field(default_factory=list)
warnings: List[str] = field(default_factory=list)
def report_table_scanned(self, table_name: str) -> None:
self.tables_scanned.append(table_name)
def scanned(self, table_name: str) -> None:
self.success.append(table_name)
logger.info('Table Scanned: {}'.format(table_name))
def report_dropped(self, table_name: str, err: str, dataset_name: str = None, col_type: str = None) -> None:
self.filtered.append(table_name)
logger.error("Dropped Table {} due to {}".format(dataset_name, err))
logger.error("column type {}".format(col_type))
def filtered(self, table_name: str, err: str, dataset_name: str = None, col_type: str = None) -> None:
self.warnings.append(table_name)
logger.warning("Dropped Table {} due to {}".format(dataset_name, err))
class SQLAlchemyConfig(ConfigModel):
env: str = "PROD"
options: dict = {}
include_pattern: IncludeFilterPattern
@abstractmethod
def get_sql_alchemy_url(self):
pass
def get_identifier(self, schema: str, table: str) -> str:
return f"{schema}.{table}"
def standardize_schema_table_names(
self, schema: str, table: str
) -> Tuple[str, str]:
# Some SQLAlchemy dialects need a standardization step to clean the schema
# and table names. See BigQuery for an example of when this is useful.
return schema, table
class BasicSQLAlchemyConfig(SQLAlchemyConfig):
class SQLConnectionConfig(ConfigModel):
username: Optional[str] = None
password: Optional[str] = None
host_port: str
@ -85,8 +65,11 @@ class BasicSQLAlchemyConfig(SQLAlchemyConfig):
scheme: str
service_name: str
service_type: str
options: dict = {}
include_pattern: IncludeFilterPattern = IncludeFilterPattern.allow_all()
def get_sql_alchemy_url(self):
@abstractmethod
def get_connection_url(self):
url = f"{self.scheme}://"
if self.username:
url += f"{self.username}"
@ -101,8 +84,11 @@ class BasicSQLAlchemyConfig(SQLAlchemyConfig):
def get_service_type(self) -> DatabaseServiceType:
return DatabaseServiceType[self.service_type]
def get_service_name(self) -> str:
return self.service_name
_field_type_mapping: Dict[Type[types.TypeEngine], str] = {
_column_type_mapping: Dict[Type[types.TypeEngine], str] = {
types.Integer: "INT",
types.Numeric: "INT",
types.Boolean: "BOOLEAN",
@ -123,7 +109,7 @@ _field_type_mapping: Dict[Type[types.TypeEngine], str] = {
types.JSON: "JSON"
}
_known_unknown_field_types: Set[Type[types.TypeEngine]] = {
_known_unknown_column_types: Set[Type[types.TypeEngine]] = {
types.Interval,
types.CLOB,
}
@ -133,25 +119,25 @@ def register_custom_type(
tp: Type[types.TypeEngine], output: str = None
) -> None:
if output:
_field_type_mapping[tp] = output
_column_type_mapping[tp] = output
else:
_known_unknown_field_types.add(tp)
_known_unknown_column_types.add(tp)
def get_column_type(sql_report: SQLSourceStatus, dataset_name: str, column_type: Any) -> str:
def get_column_type(status: SQLSourceStatus, dataset_name: str, column_type: Any) -> str:
type_class: Optional[str] = None
for sql_type in _field_type_mapping.keys():
for sql_type in _column_type_mapping.keys():
if isinstance(column_type, sql_type):
type_class = _field_type_mapping[sql_type]
type_class = _column_type_mapping[sql_type]
break
if type_class is None:
for sql_type in _known_unknown_field_types:
for sql_type in _known_unknown_column_types:
if isinstance(column_type, sql_type):
type_class = "NULL"
break
if type_class is None:
sql_report.warning(
status.warning(
dataset_name, f"unable to map type {column_type!r} to metadata schema"
)
type_class = "NULL"
@ -159,10 +145,10 @@ def get_column_type(sql_report: SQLSourceStatus, dataset_name: str, column_type:
return type_class
class SQLAlchemySource(Source):
class SQLSource(Source):
def __init__(self, config: SQLAlchemyConfig, metadata_config: MetadataServerConfig,
ctx: WorkflowContext, connector: str = None):
def __init__(self, config: SQLConnectionConfig, metadata_config: MetadataServerConfig,
ctx: WorkflowContext):
super().__init__(ctx)
self.config = config
self.metadata_config = metadata_config
@ -176,20 +162,25 @@ class SQLAlchemySource(Source):
def create(cls, config_dict: dict, metadata_config_dict: dict, ctx: WorkflowContext):
pass
def standardize_schema_table_names(
self, schema: str, table: str
) -> Tuple[str, str]:
return schema, table
def next_record(self) -> Iterable[OMetaDatabaseAndTable]:
sql_config = self.config
url = sql_config.get_sql_alchemy_url()
url = sql_config.get_connection_url()
logger.debug(f"sql_alchemy_url={url}")
engine = create_engine(url, **sql_config.options)
inspector = inspect(engine)
for schema in inspector.get_schema_names():
if not sql_config.include_pattern.included(schema):
self.status.report_dropped(schema, "Schema pattern not allowed")
self.status.filtered(schema, "Schema pattern not allowed")
continue
logger.debug("total tables {}".format(inspector.get_table_names(schema)))
for table in inspector.get_table_names(schema):
try:
schema, table = sql_config.standardize_schema_table_names(schema, table)
schema, table = self.standardize_schema_table_names(schema, table)
pk_constraints = inspector.get_pk_constraint(table, schema)
pk_columns = pk_constraints['column_constraints'] if len(
pk_constraints) > 0 and "column_constraints" in pk_constraints.keys() else {}
@ -203,11 +194,11 @@ class SQLAlchemySource(Source):
if 'column_names' in constraint.keys():
unique_columns = constraint['column_names']
dataset_name = sql_config.get_identifier(schema, table)
self.status.report_table_scanned('{}.{}'.format(self.config.service_name, dataset_name))
dataset_name = f"{schema}.{table}"
self.status.scanned('{}.{}'.format(self.config.get_service_name(), dataset_name))
if not sql_config.include_pattern.included(dataset_name):
self.status.report_dropped('{}.{}'.format(self.config.service_name, dataset_name),
"Table pattern not allowed")
self.status.filtered('{}.{}'.format(self.config.get_service_name(), dataset_name),
"Table pattern not allowed")
continue
columns = inspector.get_columns(table, schema)
@ -216,15 +207,9 @@ class SQLAlchemySource(Source):
table_info: dict = inspector.get_table_comment(table, schema)
except NotImplementedError:
description: Optional[str] = None
properties: Dict[str, str] = {}
else:
description = table_info["text"]
# The "properties" field is a non-standard addition to SQLAlchemy's interface.
properties = table_info.get("properties", {})
# TODO: capture inspector.get_pk_constraint
# TODO: capture inspector.get_sorted_table_and_fkc_names
table_columns = []
row_order = 1
for column in columns:
@ -255,12 +240,11 @@ class SQLAlchemySource(Source):
columns=table_columns)
table_and_db = OMetaDatabaseAndTable(table=table, database=db)
self.status.records_produced(table.name)
yield table_and_db
except ValidationError as err:
logger.error(err)
self.status.report_dropped('{}.{}'.format(self.config.service_name, dataset_name),
"Validation error")
self.status.filtered('{}.{}'.format(self.config.service_name, dataset_name),
"Validation error")
continue
def close(self):
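
The renamed _column_type_mapping keeps the same resolution order: explicit mapping first, then the known-unknown set, then a warning plus a NULL fallback. A reduced standalone sketch (only a handful of the mapped types from the hunk are shown):

from sqlalchemy import types

_column_type_mapping = {
    types.Integer: "INT",
    types.Numeric: "INT",
    types.Boolean: "BOOLEAN",
    types.JSON: "JSON",
}
_known_unknown_column_types = {types.Interval, types.CLOB}

def get_column_type(column_type) -> str:
    for sql_type, mapped in _column_type_mapping.items():
        if isinstance(column_type, sql_type):
            return mapped
    if any(isinstance(column_type, sql_type) for sql_type in _known_unknown_column_types):
        return "NULL"
    return "NULL"  # unmapped types also fall back to NULL, after a status warning in the real code

print(get_column_type(types.BIGINT()))    # -> "INT" (BIGINT is an Integer subtype)
print(get_column_type(types.Interval()))  # -> "NULL"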

View File

@ -42,7 +42,6 @@ def get_table_column_join(table, table_aliases, joins):
except ValueError as err:
logger.error("Error in parsing sql query joins {}".format(err))
pass
return TableColumnJoin(table_column=table_column, joined_with=joined_with)

View File

@ -118,6 +118,8 @@ class Workflow:
if hasattr(self, 'sink'):
self.sink.write_record(processed_record)
self.report['sink'] = self.sink.get_status().as_obj()
def stop(self):
if hasattr(self, 'processor'):
self.processor.close()
if hasattr(self, 'stage'):

View File

@ -41,7 +41,7 @@ def get_service_or_create(config, metadata_config) -> DatabaseServiceEntity:
if service is not None:
return service
else:
service = {'jdbc': {'connectionUrl': config.get_sql_alchemy_url(), 'driverClass': 'jdbc'},
service = {'jdbc': {'connectionUrl': config.get_connection_url(), 'driverClass': 'jdbc'},
'name': config.service_name, 'description': '', 'serviceType': config.get_service_type()}
created_service = client.create_database_service(CreateDatabaseServiceEntityRequest(**service))
return created_service
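
get_service_or_create only swaps the URL accessor; the payload it registers keeps its shape. An illustrative payload (all values made up, loosely echoing the local Postgres example config earlier in this commit):

# sketch of the dict passed to CreateDatabaseServiceEntityRequest
service = {
    "jdbc": {
        "connectionUrl": "postgresql+psycopg2://ingest:secret@localhost:5432/pagila",
        "driverClass": "jdbc",
    },
    "name": "local_postgres",
    "description": "",
    "serviceType": "Postgres",
}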

View File

@ -24,4 +24,4 @@ services:
volumes:
- ./setup:/setup
ports:
- 51433:1433
- 1433:1433