182 lines
5.8 KiB
Python
Raw Normal View History

import logging
import time
2021-02-09 15:58:26 -08:00
from dataclasses import dataclass, field
2021-02-11 23:14:20 -08:00
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from sqlalchemy import create_engine, types
from sqlalchemy.engine import reflection
2021-02-11 23:14:20 -08:00
from gometa.configuration.common import AllowDenyPattern
from gometa.ingestion.api.source import Source, SourceReport, WorkUnit
from gometa.metadata.com.linkedin.pegasus2avro.common import AuditStamp
from gometa.metadata.com.linkedin.pegasus2avro.metadata.snapshot import DatasetSnapshot
from gometa.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
2021-02-11 12:24:20 -08:00
from gometa.metadata.com.linkedin.pegasus2avro.schema import (
2021-02-11 23:14:20 -08:00
ArrayTypeClass,
2021-02-11 21:34:36 -08:00
BooleanTypeClass,
BytesTypeClass,
EnumTypeClass,
2021-02-11 23:14:20 -08:00
MySqlDDL,
2021-02-11 21:34:36 -08:00
NullTypeClass,
2021-02-11 23:14:20 -08:00
NumberTypeClass,
SchemaField,
SchemaFieldDataType,
SchemaMetadata,
StringTypeClass,
2021-02-11 12:24:20 -08:00
)
logger = logging.getLogger(__name__)  # module-level logger, stdlib logging convention
2021-02-09 15:58:26 -08:00
@dataclass
class SQLSourceReport(SourceReport):
    """Run report for SQL-based ingestion sources.

    Tracks how many tables were examined, which tables were dropped by the
    table allow/deny pattern, and any per-table warnings.
    """

    # BUG FIX: was an un-annotated class attribute (`tables_scanned = 0`),
    # which @dataclass silently ignores; annotate it so it becomes a real
    # dataclass field, consistent with `filtered` and `warnings`.
    tables_scanned: int = 0
    # Fully-qualified names of tables excluded by the table_pattern filter.
    filtered: List[str] = field(default_factory=list)
    # Maps table name -> list of warning messages recorded for that table.
    warnings: Dict[str, List[str]] = field(default_factory=dict)

    def report_warning(self, table_name: str, reason: str) -> None:
        """Record a warning message against a specific table."""
        if table_name not in self.warnings:
            self.warnings[table_name] = []
        self.warnings[table_name].append(reason)

    def report_table_scanned(self, table_name: str) -> None:
        """Count a table as examined (whether or not it was allowed)."""
        self.tables_scanned += 1

    def report_dropped(self, table_name: str) -> None:
        """Record a table that the allow/deny pattern filtered out."""
        self.filtered.append(table_name)
class SQLAlchemyConfig(BaseModel):
    """Connection configuration shared by SQLAlchemy-based sources."""

    username: str
    password: str
    host_port: str  # "host:port" fragment interpolated directly into the URL
    database: str = ""
    scheme: str  # SQLAlchemy dialect/driver prefix, e.g. "mysql+pymysql"
    # Extra keyword arguments forwarded to sqlalchemy.create_engine().
    # NOTE(review): annotated Optional but consumed via **options in
    # get_workunits, which would fail on None — confirm upstream validation.
    options: Optional[dict] = {}
    table_pattern: AllowDenyPattern = AllowDenyPattern.allow_all()

    def get_sql_alchemy_url(self) -> str:
        """Build the SQLAlchemy connection URL from the configured parts."""
        url = f"{self.scheme}://{self.username}:{self.password}@{self.host_port}/{self.database}"
        # BUG FIX: original used logger.debug("sql_alchemy_url={url}") with a
        # missing f-prefix, logging the literal placeholder. Use lazy %-args;
        # the URL embeds the password, so keep this at DEBUG level.
        logger.debug("sql_alchemy_url=%s", url)
        return url
2021-01-31 22:40:30 -08:00
@dataclass
class SqlWorkUnit(WorkUnit):
    """A work unit carrying a single MetadataChangeEvent."""

    mce: MetadataChangeEvent

    def get_metadata(self):
        """Expose the wrapped MCE as a metadata dictionary."""
        metadata = {"mce": self.mce}
        return metadata
2021-02-11 21:34:36 -08:00
2021-02-11 12:24:20 -08:00
# Maps SQLAlchemy type classes to metadata schema type classes.
# get_column_type matches with isinstance() in dict insertion order, so
# narrower entries should precede broader ones (NOTE(review): e.g. SQLAlchemy
# Enum appears to derive from String — confirm ordering is intentional).
_field_type_mapping = {
    types.Integer: NumberTypeClass,
    types.Numeric: NumberTypeClass,
    types.Boolean: BooleanTypeClass,
    types.Enum: EnumTypeClass,
    types._Binary: BytesTypeClass,
    types.PickleType: BytesTypeClass,
    types.ARRAY: ArrayTypeClass,
    types.String: StringTypeClass,
}
2021-02-11 21:34:36 -08:00
2021-02-11 22:48:08 -08:00
def get_column_type(
    sql_report: SQLSourceReport, dataset_name: str, column_type
) -> SchemaFieldDataType:
    """
    Map a SQLAlchemy column type
    (https://docs.sqlalchemy.org/en/13/core/type_basics.html) to the
    corresponding metadata schema type, reporting a warning and falling back
    to the null type when no mapping exists.
    """
    # First mapping entry whose SQLAlchemy type matches wins; dict order is
    # preserved, so this is equivalent to the original break-on-first loop.
    TypeClass: Any = next(
        (
            klass
            for sql_type, klass in _field_type_mapping.items()
            if isinstance(column_type, sql_type)
        ),
        None,
    )

    if TypeClass is None:
        # Unmapped type: surface it in the report rather than failing.
        sql_report.report_warning(
            dataset_name, f"unable to map type {column_type} to metadata schema"
        )
        TypeClass = NullTypeClass

    return SchemaFieldDataType(type=TypeClass())
2021-01-31 22:40:30 -08:00
2021-02-09 01:02:05 -08:00
2021-02-11 22:48:08 -08:00
def get_schema_metadata(
    sql_report: SQLSourceReport, dataset_name: str, platform: str, columns
) -> SchemaMetadata:
    """Convert SQLAlchemy column reflection output into a SchemaMetadata aspect.

    :param sql_report: report used to record per-column type-mapping warnings
    :param dataset_name: fully-qualified dataset name (also used as schemaName)
    :param platform: data platform id, interpolated into the platform URN
    :param columns: column dicts as returned by Inspector.get_columns()
    """
    canonical_schema: List[SchemaField] = []
    for column in columns:
        # BUG FIX: the loop variable was named `field`, shadowing the
        # dataclasses.field import at the top of this file.
        schema_field = SchemaField(
            fieldPath=column["name"],
            nativeDataType=repr(column["type"]),
            type=get_column_type(sql_report, dataset_name, column["type"]),
            description=column.get("comment", None),
        )
        canonical_schema.append(schema_field)

    # Whole-second timestamp expressed in epoch milliseconds, attributed to
    # the synthetic "etl" corp user for both created and lastModified.
    actor, sys_time = "urn:li:corpuser:etl", int(time.time()) * 1000
    schema_metadata = SchemaMetadata(
        schemaName=dataset_name,
        platform=f"urn:li:dataPlatform:{platform}",
        version=0,
        hash="",
        platformSchema=MySqlDDL(tableSchema=""),
        created=AuditStamp(time=sys_time, actor=actor),
        lastModified=AuditStamp(time=sys_time, actor=actor),
        fields=canonical_schema,
    )
    return schema_metadata
2021-01-31 22:40:30 -08:00
2021-02-09 01:02:05 -08:00
class SQLAlchemySource(Source):
    """A Base class for all SQL Sources that use SQLAlchemy to extend"""

    def __init__(self, config, ctx, platform: str):
        super().__init__(ctx)
        # NOTE(review): config is presumably a SQLAlchemyConfig (or subclass)
        # — get_workunits relies on its get_sql_alchemy_url/options/database/
        # table_pattern attributes; confirm against concrete sources.
        self.config = config
        self.platform = platform  # data platform id used in dataset URNs
        self.report = SQLSourceReport()

    def get_workunits(self):
        # Reflects every schema and table over a live DB connection and
        # yields one SqlWorkUnit (wrapping a MetadataChangeEvent) per table
        # that passes the configured allow/deny pattern.
        env: str = "PROD"  # hard-coded environment fragment of the dataset URN
        sql_config = self.config
        platform = self.platform
        url = sql_config.get_sql_alchemy_url()
        # NOTE(review): options is annotated Optional[dict]; **None would
        # raise here — assumes the default {} is always present.
        engine = create_engine(url, **sql_config.options)
        inspector = reflection.Inspector.from_engine(engine)
        database = sql_config.database
        for schema in inspector.get_schema_names():
            for table in inspector.get_table_names(schema):
                # Qualify with the database name only when one was configured.
                if database != "":
                    dataset_name = f"{database}.{schema}.{table}"
                else:
                    dataset_name = f"{schema}.{table}"
                # Count every table seen, allowed or not.
                self.report.report_table_scanned(dataset_name)
                if sql_config.table_pattern.allowed(dataset_name):
                    columns = inspector.get_columns(table, schema)
                    mce = MetadataChangeEvent()
                    dataset_snapshot = DatasetSnapshot()
                    dataset_snapshot.urn = f"urn:li:dataset:(urn:li:dataPlatform:{platform},{dataset_name},{env})"
                    schema_metadata = get_schema_metadata(
                        self.report, dataset_name, platform, columns
                    )
                    dataset_snapshot.aspects.append(schema_metadata)
                    mce.proposedSnapshot = dataset_snapshot
                    wu = SqlWorkUnit(id=dataset_name, mce=mce)
                    # Record the work unit before yielding so the report stays
                    # accurate even if the consumer stops iterating early.
                    self.report.report_workunit(wu)
                    yield wu
                else:
                    # Table excluded by the allow/deny pattern.
                    self.report.report_dropped(dataset_name)

    def get_report(self):
        # Returns the SQLSourceReport accumulated during get_workunits.
        return self.report

    def close(self):
        # No explicit cleanup; the SQLAlchemy engine is not disposed here.
        pass