Reflection Cache Implementation (#2016)

* Reflection Cache for BigQuery and Redshift

* Overrode a few sqlalchemy packages

* Added Geography Support

* Reformatted files

* Implemented error handling for DBT models

* Geography type added as a custom sqlalchemy datatype

* GEOGRAPHY and VARIANT added as custom SQL types

* Implemented file formatting using black

Ayush Shah 2022-01-11 14:58:03 +05:30 committed by GitHub
parent cf6f438531
commit f379b35279
9 changed files with 430 additions and 111 deletions
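
The headline change wraps the dialects' reflection helpers in SQLAlchemy's reflection.cache decorator and patches them back onto the dialect classes. A minimal usage sketch of what the cache buys (the DSN below is a placeholder, not part of this commit):

from sqlalchemy import create_engine, inspect

engine = create_engine("redshift+psycopg2://user:secret@host:5439/dev")  # placeholder DSN
inspector = inspect(engine)

# Both calls resolve through the same cached _get_all_relation_info, so the
# pg_catalog query behind it runs once per inspector instead of once per call.
tables = inspector.get_table_names(schema="public")
views = inspector.get_view_names(schema="public")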

View File

@@ -61,7 +61,7 @@ base_plugins = {
plugins: Dict[str, Set[str]] = {
"amundsen": {"neo4j~=4.4.0"},
"athena": {"PyAthena[SQLAlchemy]"},
"bigquery": {"openmetadata-sqlalchemy-bigquery==0.2.2"},
"bigquery": {"sqlalchemy-bigquery==1.2.2"},
"bigquery-usage": {"google-cloud-logging", "cachetools"},
"docker": {"docker==5.0.3"},
"dbt": {},
@@ -69,10 +69,11 @@ plugins: Dict[str, Set[str]] = {
"elasticsearch": {"elasticsearch~=7.13.1"},
"glue": {"boto3~=1.19.12"},
"hive": {
"openmetadata-sqlalchemy-hive==0.2.0",
"pyhive~=0.6.3",
"thrift~=0.13.0",
"sasl==0.3.1",
"thrift-sasl==0.4.3",
"presto-types-parser==0.0.2"
},
"kafka": {"confluent_kafka>=1.5.0", "fastavro>=1.2.0"},
"ldap-users": {"ldap3==2.9.1"},
@@ -86,12 +87,12 @@ plugins: Dict[str, Set[str]] = {
"postgres": {"pymysql>=1.0.2", "psycopg2-binary", "GeoAlchemy2"},
"redash": {"redash-toolbelt==0.1.4"},
"redshift": {
"openmetadata-sqlalchemy-redshift==0.2.1",
"sqlalchemy-redshift==0.8.9",
"psycopg2-binary",
"GeoAlchemy2",
},
"redshift-usage": {
"openmetadata-sqlalchemy-redshift==0.2.1",
"sqlalchemy-redshift==0.8.9",
"psycopg2-binary",
"GeoAlchemy2",
},

View File

@@ -14,6 +14,40 @@ from typing import Optional, Tuple
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
from metadata.utils.column_helpers import create_sqlalchemy_type
from sqlalchemy_bigquery import _types
from sqlalchemy_bigquery._struct import STRUCT
from sqlalchemy_bigquery._types import (
_get_sqla_column_type,
_get_transitive_schema_fields,
)
GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
_types._type_map["GEOGRAPHY"] = GEOGRAPHY
def get_columns(bq_schema):
fields = _get_transitive_schema_fields(bq_schema)
col_list = []
for field in fields:
col_obj = {
"name": field.name,
"type": _get_sqla_column_type(field)
if "STRUCT" or "RECORD" not in field
else STRUCT,
"nullable": field.mode == "NULLABLE" or field.mode == "REPEATED",
"comment": field.description,
"default": None,
"precision": field.precision,
"scale": field.scale,
"max_length": field.max_length,
"raw_data_type": repr(_get_sqla_column_type(field)),
}
col_list.append(col_obj)
return col_list
_types.get_columns = get_columns
class BigQueryConfig(SQLConnectionConfig, SQLSource):
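
With sqlalchemy_bigquery._types.get_columns replaced, every reflected BigQuery column dict now carries the extra keys that sql_source.py reads (raw_data_type, precision, scale, max_length). An illustrative entry with placeholder values, not real reflection output:

# Illustrative shape only; a real entry holds the SQLAlchemy type object in "type".
example_bigquery_column = {
    "name": "location",
    "type": "GEOGRAPHY()",  # shown as its repr here for readability
    "nullable": True,
    "comment": None,
    "default": None,
    "precision": None,
    "scale": None,
    "max_length": None,
    "raw_data_type": "GEOGRAPHY()",
}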

View File

@@ -97,6 +97,9 @@ class BigqueryUsageSource(Source[TableQuery]):
jobStats["startTime"][0:19], "%Y-%m-%dT%H:%M:%S"
).strftime("%Y-%m-%d %H:%M:%S")
)
logger.debug(
f"Query :{statementType}:{queryConfig['query']}"
)
tq = TableQuery(
query=statementType,
user_name=entry.resource.labels["project_id"],

View File

@@ -9,18 +9,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Optional
from pyhive import hive # noqa: F401
from pyhive.sqlalchemy_hive import HiveDate, HiveDecimal, HiveTimestamp
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
from metadata.utils.column_helpers import register_custom_type
from pyhive.sqlalchemy_hive import HiveDialect, _type_map
from sqlalchemy import types, util
register_custom_type(HiveDate, "DATE")
register_custom_type(HiveTimestamp, "TIME")
register_custom_type(HiveDecimal, "NUMBER")
complex_data_types = ["struct", "map", "array", "union"]
def get_columns(self, connection, table_name, schema=None, **kw):
rows = self._get_table_columns(connection, table_name, schema)
rows = [[col.strip() if col else None for col in row] for row in rows]
rows = [row for row in rows if row[0] and row[0] != "# col_name"]
result = []
for (col_name, col_type, _comment) in rows:
if col_name == "# Partition Information":
break
col_raw_type = col_type
col_type = re.search(r"^\w+", col_type).group(0)
try:
coltype = _type_map[col_type]
except KeyError:
util.warn(
"Did not recognize type '%s' of column '%s'" % (col_type, col_name)
)
coltype = types.NullType
result.append(
{
"name": col_name,
"type": coltype,
"nullable": True,
"default": None,
"raw_data_type": col_raw_type
if col_type in complex_data_types
else None,
}
)
return result
HiveDialect.get_columns = get_columns
class HiveConfig(SQLConnectionConfig):
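
The Hive override keeps the full type string in raw_data_type only for complex columns (struct, map, array, union) so the source can later expand their children. An illustrative entry for a struct column; the concrete "type" depends on pyhive's _type_map and is left as a placeholder:

# Illustrative shape only: raw_data_type survives for complex types and is None
# for scalar columns.
example_hive_column = {
    "name": "payload",
    "type": None,  # placeholder for whatever pyhive maps "struct" to
    "nullable": True,
    "default": None,
    "raw_data_type": "struct<a:int,b:string>",
}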

View File

@@ -10,8 +10,20 @@
# limitations under the License.
import logging
import re
from collections import defaultdict
from typing import Optional
import sqlalchemy as sa
from packaging.version import Version
sa_version = Version(sa.__version__)
from sqlalchemy import inspect
from sqlalchemy.engine import reflection
from sqlalchemy.types import CHAR, VARCHAR, NullType
from sqlalchemy_redshift.dialect import RedshiftDialectMixin, RelationKey
from metadata.ingestion.api.source import SourceStatus
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
@@ -19,6 +31,212 @@ from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
logger = logging.getLogger(__name__)
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
return self._get_table_or_view_names(["r", "e"], connection, schema, **kw)
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
return self._get_table_or_view_names(["v"], connection, schema, **kw)
@reflection.cache
def _get_table_or_view_names(self, relkinds, connection, schema=None, **kw):
default_schema = inspect(connection).default_schema_name
if not schema:
schema = default_schema
info_cache = kw.get("info_cache")
all_relations = self._get_all_relation_info(connection, info_cache=info_cache)
relation_names = []
for key, relation in all_relations.items():
if key.schema == schema and relation.relkind in relkinds:
relation_names.append(key.name)
return relation_names
def _get_column_info(self, *args, **kwargs):
kw = kwargs.copy()
encode = kw.pop("encode", None)
if sa_version >= Version("1.3.16"):
kw["generated"] = ""
if sa_version < Version("1.4.0") and "identity" in kw:
del kw["identity"]
elif sa_version >= Version("1.4.0") and "identity" not in kw:
kw["identity"] = None
column_info = super(RedshiftDialectMixin, self)._get_column_info(*args, **kw)
column_info["raw_data_type"] = kw["format_type"]
if isinstance(column_info["type"], VARCHAR):
if column_info["type"].length is None:
column_info["type"] = NullType()
if re.match("char", column_info["raw_data_type"]):
column_info["type"] = CHAR
if "info" not in column_info:
column_info["info"] = {}
if encode and encode != "none":
column_info["info"]["encode"] = encode
return column_info
@reflection.cache
def _get_all_relation_info(self, connection, **kw):
result = connection.execute(
"""
SELECT
c.relkind,
n.oid as "schema_oid",
n.nspname as "schema",
c.oid as "rel_oid",
c.relname,
CASE c.reldiststyle
WHEN 0 THEN 'EVEN' WHEN 1 THEN 'KEY' WHEN 8 THEN 'ALL' END
AS "diststyle",
c.relowner AS "owner_id",
u.usename AS "owner_name",
TRIM(TRAILING ';' FROM pg_catalog.pg_get_viewdef(c.oid, true))
AS "view_definition",
pg_catalog.array_to_string(c.relacl, '\n') AS "privileges"
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
JOIN pg_catalog.pg_user u ON u.usesysid = c.relowner
WHERE c.relkind IN ('r', 'v', 'm', 'S', 'f')
AND n.nspname !~ '^pg_'
ORDER BY c.relkind, n.oid, n.nspname;
"""
)
relations = {}
for rel in result:
key = RelationKey(rel.relname, rel.schema, connection)
relations[key] = rel
result = connection.execute(
"""
SELECT
schemaname as "schema",
tablename as "relname",
'e' as relkind
FROM svv_external_tables;
"""
)
for rel in result:
key = RelationKey(rel.relname, rel.schema, connection)
relations[key] = rel
return relations
@reflection.cache
def _get_schema_column_info(self, connection, schema=None, **kw):
schema_clause = "AND schema = '{schema}'".format(schema=schema) if schema else ""
all_columns = defaultdict(list)
with connection.connect() as cc:
result = cc.execute(
"""
SELECT
n.nspname as "schema",
c.relname as "table_name",
att.attname as "name",
format_encoding(att.attencodingtype::integer) as "encode",
format_type(att.atttypid, att.atttypmod) as "type",
att.attisdistkey as "distkey",
att.attsortkeyord as "sortkey",
att.attnotnull as "notnull",
pg_catalog.col_description(att.attrelid, att.attnum)
as "comment",
adsrc,
attnum,
pg_catalog.format_type(att.atttypid, att.atttypmod),
pg_catalog.pg_get_expr(ad.adbin, ad.adrelid) AS DEFAULT,
n.oid as "schema_oid",
c.oid as "table_oid"
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n
ON n.oid = c.relnamespace
JOIN pg_catalog.pg_attribute att
ON att.attrelid = c.oid
LEFT JOIN pg_catalog.pg_attrdef ad
ON (att.attrelid, att.attnum) = (ad.adrelid, ad.adnum)
WHERE n.nspname !~ '^pg_'
AND att.attnum > 0
AND NOT att.attisdropped
{schema_clause}
UNION
SELECT
view_schema as "schema",
view_name as "table_name",
col_name as "name",
null as "encode",
col_type as "type",
null as "distkey",
0 as "sortkey",
null as "notnull",
null as "comment",
null as "adsrc",
null as "attnum",
col_type as "format_type",
null as "default",
null as "schema_oid",
null as "table_oid"
FROM pg_get_late_binding_view_cols() cols(
view_schema name,
view_name name,
col_name name,
col_type varchar,
col_num int)
WHERE 1 {schema_clause}
UNION
SELECT schemaname AS "schema",
tablename AS "table_name",
columnname AS "name",
null AS "encode",
-- Spectrum represents data types differently.
-- Standardize, so we can infer types.
CASE
WHEN external_type = 'int' THEN 'integer'
ELSE
replace(
replace(external_type, 'decimal', 'numeric'),
'varchar', 'character varying')
END
AS "type",
null AS "distkey",
0 AS "sortkey",
null AS "notnull",
null AS "comment",
null AS "adsrc",
null AS "attnum",
CASE
WHEN external_type = 'int' THEN 'integer'
ELSE
replace(
replace(external_type, 'decimal', 'numeric'),
'varchar', 'character varying')
END
AS "format_type",
null AS "default",
null AS "schema_oid",
null AS "table_oid"
FROM svv_external_columns
ORDER BY "schema", "table_name", "attnum";
""".format(
schema_clause=schema_clause
)
)
for col in result:
key = RelationKey(col.table_name, col.schema, connection)
all_columns[key].append(col)
return dict(all_columns)
RedshiftDialectMixin._get_table_or_view_names = _get_table_or_view_names
RedshiftDialectMixin.get_view_names = get_view_names
RedshiftDialectMixin.get_table_names = get_table_names
RedshiftDialectMixin._get_column_info = _get_column_info
RedshiftDialectMixin._get_all_relation_info = _get_all_relation_info
RedshiftDialectMixin._get_schema_column_info = _get_schema_column_info
class RedshiftConfig(SQLConnectionConfig):
scheme = "redshift+psycopg2"
where_clause: Optional[str] = None
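
Because the overrides above are assigned at module import time, importing the Redshift source module is enough to patch RedshiftDialectMixin before any inspection runs. A rough sketch (the module path is assumed from this repository's layout):

# Side-effect import: executing the module applies the cached reflection
# overrides to RedshiftDialectMixin.
import metadata.ingestion.source.redshift  # noqa: F401

from sqlalchemy import create_engine, inspect

engine = create_engine("redshift+psycopg2://user:secret@host:5439/dev")  # placeholder DSN
print(inspect(engine).get_table_names(schema="public"))  # now also lists external ('e') tables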

View File

@ -11,15 +11,15 @@
from typing import Optional
from snowflake.sqlalchemy import custom_types
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.source.sql_source import SQLConnectionConfig, SQLSource
from metadata.utils.column_helpers import register_custom_type
from metadata.utils.column_helpers import create_sqlalchemy_type
from snowflake.sqlalchemy.custom_types import VARIANT
from snowflake.sqlalchemy.snowdialect import ischema_names
register_custom_type(custom_types.TIMESTAMP_TZ, "TIME")
register_custom_type(custom_types.TIMESTAMP_LTZ, "TIME")
register_custom_type(custom_types.TIMESTAMP_NTZ, "TIME")
GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
ischema_names["VARIANT"] = VARIANT
ischema_names["GEOGRAPHY"] = GEOGRAPHY
class SnowflakeConfig(SQLConnectionConfig):
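
Registering VARIANT and GEOGRAPHY in snowdialect.ischema_names lets reflected Snowflake columns of those types resolve to concrete SQLAlchemy types instead of being reported as unrecognized. The same mechanism extends to any other type name; a hypothetical example (the type name below is invented for illustration):

# Hypothetical registration of an additional type name, mirroring the pattern above.
from sqlalchemy.types import TypeEngine
from snowflake.sqlalchemy.snowdialect import ischema_names

MY_CUSTOM_TYPE = type("MY_CUSTOM_TYPE", (TypeEngine,), {"__repr__": lambda self: "MY_CUSTOM_TYPE()"})
ischema_names["MY_CUSTOM_TYPE"] = MY_CUSTOM_TYPE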

View File

@@ -22,11 +22,6 @@ from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple
from urllib.parse import quote_plus
from pydantic import SecretStr
from sqlalchemy import create_engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.inspection import inspect
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.table import (
Column,
@@ -54,6 +49,10 @@ from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.utils.column_helpers import check_column_complex_type, get_column_type
from metadata.utils.helpers import get_database_service_or_create
from pydantic import SecretStr
from sqlalchemy import create_engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.inspection import inspect
logger: logging.Logger = logging.getLogger(__name__)
@@ -271,7 +270,8 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
def next_record(self) -> Iterable[Entity]:
inspector = inspect(self.engine)
for schema in inspector.get_schema_names():
schema_names = inspector.get_schema_names()
for schema in schema_names:
# clear any previous source database state
self.database_source_state.clear()
if not self.sql_config.schema_filter_pattern.included(schema):
@@ -292,7 +292,8 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
Scrape an SQL schema and prepare Database and Table
OpenMetadata Entities
"""
for table_name in inspector.get_table_names(schema):
tables = inspector.get_table_names(schema)
for table_name in tables:
try:
schema, table_name = self.standardize_schema_table_names(
schema, table_name
@@ -303,8 +304,6 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
"Table pattern not allowed",
)
continue
self.status.scanned(f"{self.config.get_service_name()}.{table_name}")
description = _get_table_description(schema, table_name, inspector)
fqn = f"{self.config.service_name}.{schema}.{table_name}"
self.database_source_state.add(fqn)
@@ -338,10 +337,15 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
table=table_entity, database=self._get_database(schema)
)
yield table_and_db
# Catch any errors during the ingestion and continue
except Exception as err: # pylint: disable=broad-except
self.status.scanned(
"{}.{}".format(self.config.get_service_name(), table_name)
)
except Exception as err:
traceback.print_exc()
logger.error(err)
self.status.warnings.append(f"{self.config.service_name}.{table_name}")
self.status.failures.append(
"{}.{}".format(self.config.service_name, table_name)
)
continue
def fetch_views(
@@ -426,31 +430,36 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
}
for key, mnode in manifest_entities.items():
name = mnode["alias"] if "alias" in mnode.keys() else mnode["name"]
cnode = catalog_entities.get(key)
columns = (
self._parse_data_model_columns(name, mnode, cnode) if cnode else []
)
try:
name = mnode["alias"] if "alias" in mnode.keys() else mnode["name"]
cnode = catalog_entities.get(key)
columns = (
self._parse_data_model_columns(name, mnode, cnode)
if cnode
else []
)
if mnode["resource_type"] == "test":
continue
upstream_nodes = self._parse_data_model_upstream(mnode)
model_name = (
mnode["alias"] if "alias" in mnode.keys() else mnode["name"]
)
model_name = model_name.replace(".", "_DOT_")
schema = mnode["schema"]
raw_sql = mnode.get("raw_sql", "")
model = DataModel(
modelType=ModelType.DBT,
description=mnode.get("description", ""),
path=f"{mnode['root_path']}/{mnode['original_file_path']}",
rawSql=raw_sql,
sql=mnode.get("compiled_sql", raw_sql),
columns=columns,
upstream=upstream_nodes,
)
model_fqdn = f"{schema}.{model_name}"
if mnode["resource_type"] == "test":
continue
upstream_nodes = self._parse_data_model_upstream(mnode)
model_name = (
mnode["alias"] if "alias" in mnode.keys() else mnode["name"]
)
model_name = model_name.replace(".", "_DOT_")
schema = mnode["schema"]
raw_sql = mnode.get("raw_sql", "")
model = DataModel(
modelType=ModelType.DBT,
description=mnode.get("description", ""),
path=f"{mnode['root_path']}/{mnode['original_file_path']}",
rawSql=raw_sql,
sql=mnode.get("compiled_sql", raw_sql),
columns=columns,
upstream=upstream_nodes,
)
model_fqdn = f"{schema}.{model_name}"
except Exception as err:
print(err)
self.data_models[model_fqdn] = model
def _parse_data_model_upstream(self, mnode):
@@ -507,7 +516,7 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
def _get_database(self, schema: str) -> Database:
return Database(
name=schema,
name=schema.replace(".", "_DOT_"),
service=EntityReference(id=self.service.id, type=self.config.service_type),
)
@@ -561,48 +570,61 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
dataset_name = f"{schema}.{table}"
table_columns = []
columns = inspector.get_columns(table, schema)
try:
for row_order, column in enumerate(inspector.get_columns(table, schema)):
if "." in column["name"]:
logger.info(f"Found '.' in {column['name']}")
column["name"] = column["name"].replace(".", "_DOT_")
children = None
data_type_display = None
col_data_length = None
arr_data_type = None
if "raw_data_type" in column and column["raw_data_type"] is not None:
(
col_type,
data_type_display,
arr_data_type,
children,
) = check_column_complex_type(
self.status,
dataset_name,
column["raw_data_type"],
column["name"],
)
else:
col_type = get_column_type(
self.status, dataset_name, column["type"]
)
if col_type == "ARRAY" and re.match(
r"(?:\w*)(?:\()(\w*)(?:.*)", str(column["type"])
):
arr_data_type = re.match(
r"(?:\w*)(?:[(]*)(\w*)(?:.*)", str(column["type"])
).groups()
data_type_display = column["type"]
col_constraint = self._get_column_constraints(
column, pk_columns, unique_columns
)
if col_type.upper() in {"CHAR", "VARCHAR", "BINARY", "VARBINARY"}:
col_data_length = column["type"].length
if col_data_length is None:
col_data_length = 1
for row_order, column in enumerate(columns):
try:
if "." in column["name"]:
logger.info(
f"Found '.' in {column['name']}, changing '.' to '_DOT_'"
)
column["name"] = column["name"].replace(".", "_DOT_")
children = None
data_type_display = None
col_data_length = None
arr_data_type = None
if (
"raw_data_type" in column
and column["raw_data_type"] is not None
):
column["raw_data_type"] = self.parse_raw_data_type(
column["raw_data_type"]
)
(
col_type,
data_type_display,
arr_data_type,
children,
) = check_column_complex_type(
self.status,
dataset_name,
column["raw_data_type"],
column["name"],
)
else:
col_type = get_column_type(
self.status, dataset_name, column["type"]
)
if col_type == "ARRAY" and re.match(
r"(?:\w*)(?:\()(\w*)(?:.*)", str(column["type"])
):
arr_data_type = re.match(
r"(?:\w*)(?:[(]*)(\w*)(?:.*)", str(column["type"])
).groups()
data_type_display = column["type"]
if repr(column["type"]).upper().startswith("ARRAY("):
arr_data_type = "STRUCT"
data_type_display = (
repr(column["type"])
.replace("(", "<")
.replace(")", ">")
.lower()
)
col_constraint = self._get_column_constraints(
column, pk_columns, unique_columns
)
if col_type.upper() in {"CHAR", "VARCHAR", "BINARY", "VARBINARY"}:
col_data_length = column["type"].length
if col_type == "NULL":
col_type = "VARCHAR"
data_type_display = "varchar"
@@ -613,23 +635,24 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
name=column["name"],
description=column.get("comment", None),
dataType=col_type,
dataTypeDisplay=f"{col_type}({col_data_length})"
dataTypeDisplay="{}({})".format(
col_type, 1 if col_data_length is None else col_data_length
)
if data_type_display is None
else f"{data_type_display}",
dataLength=col_data_length,
dataLength=1 if col_data_length is None else col_data_length,
constraint=col_constraint,
ordinalPosition=row_order + 1, # enumerate starts at 0
children=children,
ordinalPosition=row_order,
children=children if children is not None else None,
arrayDataType=arr_data_type,
)
except Exception as err: # pylint: disable=broad-except
logger.error(traceback.format_exc())
except Exception as err:
logger.error(traceback.print_exc())
logger.error(f"{err} : {column}")
continue
table_columns.append(om_column)
return table_columns
except Exception as err: # pylint: disable=broad-except
except Exception as err:
logger.error(f"{repr(err)}: {table} {err}")
return None
@@ -661,6 +684,9 @@ class SQLSource(Source[OMetaDatabaseAndTable]):
logger.debug(f"Finished profiling {dataset_name}")
return profile
def parse_raw_data_type(self, raw_data_type):
return raw_data_type
def _build_database_state(self, schema_fqdn: str) -> [EntityReference]:
after = None
tables = []
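
parse_raw_data_type is added as a no-op hook so dialect-specific sources can normalize the raw type string before complex-type parsing. A hedged sketch of how a subclass might override it (illustrative only, not code from this commit):

from metadata.ingestion.source.sql_source import SQLSource

class MyDialectSource(SQLSource):  # hypothetical subclass
    def parse_raw_data_type(self, raw_data_type):
        # e.g. normalize verbose catalog type names before check_column_complex_type runs
        return raw_data_type.replace("character varying", "varchar")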

View File

@@ -1,22 +1,26 @@
import re
from typing import Any, Dict, Optional, Set, Type
from sqlalchemy.sql import sqltypes as types
from metadata.ingestion.api.source import SourceStatus
def register_custom_type(tp: Type[types.TypeEngine], output: str = None) -> None:
if output:
_column_type_mapping[tp] = output
else:
_known_unknown_column_types.add(tp)
from sqlalchemy.sql import sqltypes as types
from sqlalchemy.types import TypeEngine
def register_custom_str_type(tp: str, output: str) -> None:
_column_string_mapping[tp] = output
def create_sqlalchemy_type(name: str):
sqlalchemy_type = type(
name,
(TypeEngine,),
{
"__repr__": lambda self: f"{name}()",
},
)
return sqlalchemy_type
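
create_sqlalchemy_type builds a bare TypeEngine subclass whose repr is just the type name, which is what the BigQuery and Snowflake sources register into their dialect type maps. A quick usage check, assuming the helper as defined above:

from metadata.utils.column_helpers import create_sqlalchemy_type

GEOGRAPHY = create_sqlalchemy_type("GEOGRAPHY")
print(repr(GEOGRAPHY()))  # prints: GEOGRAPHY()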
_column_type_mapping: Dict[Type[types.TypeEngine], str] = {
types.Integer: "INT",
types.Numeric: "INT",
@@ -123,6 +127,8 @@ _column_string_mapping = {
"XML": "BINARY",
"XMLTYPE": "BINARY",
"CURSOR": "BINARY",
"TIMESTAMP_LTZ": "TIMESTAMP",
"TIMESTAMP_TZ": "TIMESTAMP",
}
_known_unknown_column_types: Set[Type[types.TypeEngine]] = {

View File

@@ -14,7 +14,6 @@ from openmetadata.common.database_common import (
DatabaseCommon,
SQLConnectionConfig,
SQLExpressions,
register_custom_type,
)
logger = logging.getLogger(__name__)