Profiler: Module to calculate measurements for a table (#933)

* Profiler: Module to calculate measurements for a table

* Profiler: refactor sql expressions

* Profiler: fix supported types

* Update README.md

Co-authored-by: Ayush Shah <ayush@getcollate.io>
This commit is contained in:
Sriharsha Chintalapani 2021-10-26 20:01:10 -07:00 committed by GitHub
parent 710675d51a
commit 7b09571a22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1703 additions and 4 deletions

View File

@ -15,10 +15,9 @@
from typing import Dict, List, Optional
from pydantic import BaseModel
from metadata.generated.schema.entity.data.table import ColumnJoins
from metadata.ingestion.models.json_serializable import JsonSerializable
from pydantic import BaseModel
class TableQuery(JsonSerializable):

View File

@ -18,13 +18,12 @@ import logging
import traceback
from typing import Optional
from sql_metadata import Parser
from metadata.config.common import ConfigModel
from metadata.ingestion.api.common import WorkflowContext
from metadata.ingestion.api.processor import Processor, ProcessorStatus
from metadata.ingestion.models.table_queries import QueryParserData, TableQuery
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from sql_metadata import Parser
class QueryParserProcessorConfig(ConfigModel):

0
profiler/CHANGELOG Normal file
View File

25
profiler/README.md Normal file
View File

@ -0,0 +1,25 @@
---
This guide will help you set up the Data Profiler
---
![Python version 3.8+](https://img.shields.io/badge/python-3.8%2B-blue)
The OpenMetadata Data Profiler collects measurements on various data sources and
publishes them to OpenMetadata. This framework runs both the tests and the profiler.
**Prerequisites**
- Python >= 3.8.x
### Install From PyPI
```text
python3 -m pip install --upgrade pip wheel setuptools openmetadata-dataprofiler
```
#### Generate Redshift Data
```text
profiler test -c ./examples/workflows/redshift.json
```

View File

@ -0,0 +1,11 @@
profiler:
type: redshift
config:
sql_connection:
host_port: host:port
username: username
password: password
database: warehouse
db_schema: public
service_name: redshift
table_name: sales

14
profiler/requirements.txt Normal file
View File

@ -0,0 +1,14 @@
pip~=21.3
PyYAML~=6.0
MarkupSafe~=2.0.1
urllib3~=1.26.7
six~=1.16.0
python-dateutil~=2.8.2
greenlet~=1.1.2
psycopg2~=2.9.1
setuptools~=58.2.0
SQLAlchemy~=1.4.26
click~=8.0.3
Jinja2~=3.0.2
expandvars~=0.7.0
pydantic~=1.8.2

49
profiler/setup.cfg Normal file
View File

@ -0,0 +1,49 @@
[flake8]
# We ignore the line length issues here, since black will take care of them.
max-line-length = 150
max-complexity = 15
ignore =
# Ignore: 1 blank line required before class docstring.
D203,
W503
exclude =
.git,
__pycache__
per-file-ignores =
# imported but unused
__init__.py: F401
[metadata]
license_files = LICENSE
[mypy]
mypy_path = src
plugins =
sqlmypy,
pydantic.mypy
ignore_missing_imports = yes
namespace_packages = true
strict_optional = yes
check_untyped_defs = yes
# eventually we'd like to enable these
disallow_untyped_defs = no
disallow_incomplete_defs = no
[isort]
profile = black
indent=' '
sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
[tool:pytest]
addopts = --cov src --cov-report term --cov-config setup.cfg --strict-markers
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
testpaths =
tests/unit
[options]
packages = find:
package_dir =
=src
[options.packages.find]
where = src
include = *

96
profiler/setup.py Normal file
View File

@ -0,0 +1,96 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict, Set
from setuptools import find_namespace_packages, setup
def get_version():
    """Return the package version: the first line of the CHANGELOG file."""
    changelog_path = os.path.join(os.path.dirname(__file__), "CHANGELOG")
    with open(changelog_path) as changelog_file:
        return changelog_file.readline().strip()
def get_long_description():
    """Build the PyPI long description from README.md plus the CHANGELOG."""
    root = os.path.dirname(__file__)
    with open(os.path.join(root, "README.md")) as readme_file:
        parts = [readme_file.read(), "\n\nChangelog\n=========\n\n"]
    with open(os.path.join(root, "CHANGELOG")) as changelog_file:
        parts.append(changelog_file.read())
    return "".join(parts)
# Core runtime dependencies.
# BUG FIX: several entries were missing commas, so Python silently
# concatenated adjacent string literals into invalid requirement specifiers
# (e.g. "requests>=2.23.0, <3.0idna<3,>=2.5" and
# "expandvars>=0.6.5dataclasses>=0.8typing_extensions>=3.7.4mypy_extensions>=0.4.3"),
# which makes `pip install` fail. Each requirement is now its own element.
base_requirements = {
    "sqlalchemy>=1.3.24",
    "Jinja2>=2.11.3, <3.0",
    "click>=7.1.2, <8.0",
    "cryptography==3.3.2",
    "pyyaml>=5.4.1, <6.0",
    "requests>=2.23.0, <3.0",
    "idna<3,>=2.5",
    "click<7.2.0,>=7.1.1",
    "expandvars>=0.6.5",
    "dataclasses>=0.8",
    "typing_extensions>=3.7.4",
    "mypy_extensions>=0.4.3",
    "typing-inspect",
    "pydantic==1.7.4",
    "pymysql>=1.0.2",
    "GeoAlchemy2",
    "psycopg2-binary>=2.8.5, <3.0",
    "openmetadata-sqlalchemy-redshift==0.2.1",
}
# Optional per-dialect dependency sets, exposed as pip extras
# (e.g. `pip install openmetadata-data-profiler[redshift]`).
plugins: Dict[str, Set[str]] = {
    "redshift": {
        "openmetadata-sqlalchemy-redshift==0.2.1",
        "psycopg2-binary",
        "GeoAlchemy2",
    },
    "postgres": {"pymysql>=1.0.2", "psycopg2-binary", "GeoAlchemy2"},
}
# Passed to the "build_exe" option below — presumably for a cx_Freeze-style
# frozen build; _cffi_backend must be bundled explicitly. TODO confirm.
build_options = {"includes": ["_cffi_backend"]}
setup(
    name="openmetadata-data-profiler",
    # NOTE(review): version is hard-coded even though get_version() exists
    # above; the CHANGELOG added in this commit is empty, so get_version()
    # would currently return "" — confirm which is intended.
    version="0.1",
    url="https://open-metadata.org/",
    author="OpenMetadata Committers",
    license="Apache License 2.0",
    description="Data Profiler and Testing Framework for OpenMetadata",
    long_description=get_long_description(),
    long_description_content_type="text/markdown",
    python_requires=">=3.8",
    options={"build_exe": build_options},
    package_dir={"": "src"},
    zip_safe=False,
    dependency_links=[],
    project_urls={
        "Documentation": "https://docs.open-metadata.org/",
        "Source": "https://github.com/open-metadata/OpenMetadata",
    },
    packages=find_namespace_packages(where="./src", exclude=["tests*"]),
    entry_points={
        "console_scripts": ["openmetadata = openmetadata.cmd:openmetadata"],
    },
    install_requires=list(base_requirements),
    extras_require={
        "base": list(base_requirements),
        # One extra per plugin, plus "all" = base union every plugin set.
        **{plugin: list(dependencies) for (plugin, dependencies) in plugins.items()},
        "all": list(
            base_requirements.union(
                *[requirements for plugin, requirements in plugins.items()]
            )
        ),
    },
)

View File

View File

@ -0,0 +1,78 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pathlib
import sys
import click
from pydantic import ValidationError
from openmetadata.common.config import load_config_file
from openmetadata.profiler.profiler_metadata import ProfileResult
from openmetadata.profiler.profiler_runner import ProfilerRunner
logger = logging.getLogger(__name__)
# Configure logger.
BASE_LOGGING_FORMAT = (
"[%(asctime)s] %(levelname)-8s {%(name)s:%(lineno)d} - %(message)s"
)
logging.basicConfig(format=BASE_LOGGING_FORMAT)
# Placeholder sub-command group; attached to the main CLI at the bottom of
# this module via `openmetadata.add_command(check)`.
@click.group()
def check() -> None:
    pass
# Root CLI group. `--debug` raises log verbosity for the whole process.
@click.group()
@click.option("--debug/--no-debug", default=False)
def openmetadata(debug: bool) -> None:
    if debug:
        # Debug mode: INFO globally, DEBUG for this package's own loggers.
        logging.getLogger().setLevel(logging.INFO)
        logging.getLogger("openmetadata").setLevel(logging.DEBUG)
    else:
        logging.getLogger().setLevel(logging.WARNING)
        logging.getLogger("openmetadata").setLevel(logging.INFO)
@openmetadata.command()
@click.option(
    "-c",
    "--config",
    type=click.Path(exists=True, dir_okay=False),
    help="Profiler config",
    required=True,
)
def profiler(config: str) -> None:
    """Run the data profiler (and its tests) against the configured table.

    Exits with status 1 on a configuration validation error or any
    unexpected failure during the run.
    """
    try:
        config_file = pathlib.Path(config)
        profiler_config = load_config_file(config_file)
        try:
            logger.info(f"Using config: {profiler_config}")
            profiler_runner = ProfilerRunner.create(profiler_config)
        except ValidationError as e:
            # Bad configuration: show the validation error, no stack trace.
            click.echo(e, err=True)
            sys.exit(1)
        logger.info(f"Running Profiler for {profiler_runner.table_name} ...")
        profile_result: ProfileResult = profiler_runner.execute()
        # f-prefixes removed from constant messages (no placeholders, W1309).
        logger.info("Profiler Results")
        logger.info(f"{profile_result.json()}")
    except Exception as e:
        logger.exception(f"Scan failed: {str(e)}")
        logger.info("Exiting with code 1")
        sys.exit(1)
openmetadata.add_command(check)

View File

@ -0,0 +1,123 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import json
import pathlib
import re  # used by IncludeFilterPattern's regex helpers below
from abc import ABC, abstractmethod
from typing import IO, Any, List, Optional

import yaml
from expandvars import expandvars
from pydantic import BaseModel
class ConfigModel(BaseModel):
    """Base model for all config sections; unknown keys are rejected."""

    class Config:
        # Reject unexpected fields so config typos fail loudly.
        extra = "forbid"
class DynamicTypedConfig(ConfigModel):
    """A config section selected by `type`, with an arbitrary nested `config`."""

    type: str
    config: Optional[Any]
# Raised by workflow runners when a pipeline step fails.
class WorkflowExecutionError(Exception):
    """An error occurred when executing the workflow"""
# Raised by load_config_file below for missing/unsupported config files.
class ConfigurationError(Exception):
    """A configuration error has happened"""
class ConfigurationMechanism(ABC):
    """Strategy interface: parse an open config file object into a dict."""

    @abstractmethod
    def load_config(self, config_fp: IO) -> dict:
        pass
class YamlConfigurationMechanism(ConfigurationMechanism):
    """Load configuration from YAML files."""

    def load_config(self, config_fp: IO) -> dict:
        # safe_load: never construct arbitrary Python objects from config.
        config = yaml.safe_load(config_fp)
        return config
class JsonConfigurationMechanism(ConfigurationMechanism):
    """Load configuration from JSON files."""  # was wrongly "yaml files"

    def load_config(self, config_fp: IO) -> dict:
        config = json.load(config_fp)
        return config
def load_config_file(config_file: pathlib.Path) -> dict:
    """Load a profiler config file (.yaml/.yml or .json) into a dict.

    Environment variables referenced in the file are expanded (via
    expandvars, nounset=True) before parsing.

    Raises:
        ConfigurationError: if the file does not exist or has an
            unsupported suffix.
    """
    if not config_file.is_file():
        raise ConfigurationError(f"Cannot open config file {config_file}")
    config_mech: ConfigurationMechanism
    if config_file.suffix in [".yaml", ".yml"]:
        config_mech = YamlConfigurationMechanism()
    elif config_file.suffix == ".json":
        config_mech = JsonConfigurationMechanism()
    else:
        # Message previously omitted ".yaml" even though it is accepted above.
        raise ConfigurationError(
            "Only .json, .yaml and .yml are supported. Cannot process file type {}".format(
                config_file.suffix
            )
        )
    with config_file.open() as raw_config_file:
        raw_config = raw_config_file.read()
    expanded_config_file = expandvars(raw_config, nounset=True)
    config_fp = io.StringIO(expanded_config_file)
    config = config_mech.load_config(config_fp)
    return config
class IncludeFilterPattern(ConfigModel):
    """A class to store allow deny regexes"""

    # NOTE(review): the methods below use `re`, which is not in this module's
    # visible imports — confirm `import re` exists at the top of the file.
    includes: List[str] = [".*"]
    excludes: List[str] = []
    # Character class of "safe" literal-name characters.
    alphabet: str = "[A-Za-z0-9 _.-]"

    @property
    def alphabet_pattern(self):
        # Matches strings composed solely of the safe alphabet.
        return re.compile(f"^{self.alphabet}+$")

    @classmethod
    def allow_all(cls):
        """Return a pattern that includes everything."""
        return IncludeFilterPattern()

    def included(self, string: str) -> bool:
        """True if `string` matches some include regex and no exclude regex."""
        try:
            # Excludes take precedence over includes.
            for exclude in self.excludes:
                if re.match(exclude, string):
                    return False
            for include in self.includes:
                if re.match(include, string):
                    return True
            return False
        except Exception as err:
            raise Exception("Regex Error: {}".format(err))

    def is_fully_specified_include_list(self) -> bool:
        # True when every include is a plain name (no regex metacharacters).
        for filter_pattern in self.includes:
            if not self.alphabet_pattern.match(filter_pattern):
                return False
        return True

    def get_allowed_list(self):
        # Only meaningful when includes are literal names, not regexes.
        assert self.is_fully_specified_include_list()
        return [a for a in self.includes if self.included(a)]

View File

@ -0,0 +1,72 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import List
class Closeable:
    """Interface for resources that must be released via close().

    NOTE(review): declared without ABCMeta, so @abstractmethod is only
    enforced when a subclass supplies the metaclass (as Database does).
    """

    @abstractmethod
    def close(self):
        pass
@dataclass
class Database(Closeable, metaclass=ABCMeta):
    """Abstract interface every profiled database dialect must implement."""

    @classmethod
    @abstractmethod
    def create(cls, config_dict: dict) -> "Database":
        # Alternate constructor from a raw config dict.
        pass

    @property
    @abstractmethod
    def sql_exprs(self):
        # Dialect-specific SQL expression templates.
        pass

    @abstractmethod
    def table_metadata_query(self, table_name: str) -> str:
        # SQL returning (column_name, data_type, is_nullable) rows.
        pass

    @abstractmethod
    def qualify_table_name(self, table_name: str) -> str:
        return table_name

    @abstractmethod
    def qualify_column_name(self, column_name: str):
        return column_name

    @abstractmethod
    def is_text(self, column_type: str):
        pass

    @abstractmethod
    def is_number(self, column_type: str):
        pass

    @abstractmethod
    def is_time(self, column_type: str):
        pass

    @abstractmethod
    def sql_fetchone(self, sql) -> tuple:
        pass

    @abstractmethod
    def sql_fetchone_description(self, sql) -> tuple:
        pass

    @abstractmethod
    def sql_fetchall(self, sql) -> List[tuple]:
        pass

    @abstractmethod
    def sql_fetchall_description(self, sql) -> tuple:
        pass

View File

@ -0,0 +1,328 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import re
from abc import abstractmethod
from datetime import date, datetime
from numbers import Number
from typing import List, Optional
from urllib.parse import quote_plus
from pydantic import BaseModel
from sqlalchemy import create_engine
from openmetadata.common.config import ConfigModel, IncludeFilterPattern
from openmetadata.common.database import Database
from openmetadata.profiler.profiler_metadata import Column, SupportedDataType
logger = logging.getLogger(__name__)
class SQLConnectionConfig(ConfigModel):
    """Common SQL connection settings shared by all dialect configs.

    NOTE(review): @abstractmethod is used on a pydantic model that does not
    use ABCMeta, so it is not actually enforced — subclasses just override
    get_connection_url and call super().
    """

    username: Optional[str] = None
    password: Optional[str] = None
    host_port: str
    database: Optional[str] = None
    db_schema: Optional[str] = None
    scheme: str
    service_name: str
    service_type: str
    options: dict = {}
    # NOTE(review): this default is evaluated once at import time, not per
    # run — confirm that is intended.
    profiler_date: Optional[str] = datetime.now().strftime("%Y-%m-%d")
    profiler_offset: Optional[int] = 0
    profiler_limit: Optional[int] = 50000
    filter_pattern: IncludeFilterPattern = IncludeFilterPattern.allow_all()

    @abstractmethod
    def get_connection_url(self):
        """Build an SQLAlchemy URL: scheme://user:pass@host_port/database."""
        url = f"{self.scheme}://"
        if self.username is not None:
            url += f"{quote_plus(self.username)}"
        if self.password is not None:
            url += f":{quote_plus(self.password)}"
        url += "@"
        url += f"{self.host_port}"
        if self.database:
            url += f"/{self.database}"
        # NOTE(review): this logs the URL including the password — confirm
        # that leaking credentials to the log is acceptable.
        logger.info(url)
        return url
# Default data-type classifications used by DatabaseCommon.is_number /
# is_text / is_time. Dialect modules extend these via register_custom_type().
_numeric_types = [
    "SMALLINT",
    "INTEGER",
    "BIGINT",
    "DECIMAL",
    "NUMERIC",
    "REAL",
    "DOUBLE PRECISION",
    "SMALLSERIAL",
    "SERIAL",
    "BIGSERIAL",
]
# BUG FIX: a missing comma previously fused "VARCHAR" "TEXT" into the single,
# never-matching entry "VARCHARTEXT"; both types are now listed separately.
_text_types = ["CHARACTER VARYING", "CHARACTER", "CHAR", "VARCHAR", "TEXT"]
_time_types = [
    "TIMESTAMP",
    "DATE",
    "TIME",
    "TIMESTAMP WITH TIME ZONE",
    "TIMESTAMP WITHOUT TIME ZONE",
    "TIME WITH TIME ZONE",
    "TIME WITHOUT TIME ZONE",
]
def register_custom_type(
    data_types: List[str], type_category: SupportedDataType
) -> None:
    """Extend the module-level type registry for the given category."""
    registry = {
        SupportedDataType.TIME: _time_types,
        SupportedDataType.TEXT: _text_types,
        SupportedDataType.NUMERIC: _numeric_types,
    }
    target = registry.get(type_category)
    if target is None:
        raise Exception(f"Unsupported {type_category}")
    target.extend(data_types)
class SQLExpressions(BaseModel):
    """Format-string templates and helpers for building profiler SQL.

    Dialect subclasses override individual ``*_expr`` templates (e.g.
    Postgres/Redshift override ``regex_like_pattern_expr``).
    """

    count_all_expr: str = "COUNT(*)"
    count_expr: str = "COUNT({})"
    distinct_expr: str = "DISTINCT({})"
    min_expr: str = "MIN({})"
    max_expr: str = "MAX({})"
    length_expr: str = "LENGTH({})"
    avg_expr: str = "AVG({})"
    sum_expr: str = "SUM({})"
    variance_expr: str = "VARIANCE({})"
    stddev_expr: str = "STDDEV({})"
    limit_expr: str = "LIMIT {}"
    count_conditional_expr: str = "COUNT(CASE WHEN {} THEN 1 END)"
    conditional_expr: str = "CASE WHEN {} THEN {} END"
    # BUG FIX: SQL equality is a single "="; the previous "{} == {}" template
    # is invalid on the dialects targeted here (Postgres/Redshift).
    equal_expr: str = "{} = {}"
    less_than_expr: str = "{} < {}"
    less_than_or_equal_expr: str = "{} <= {}"
    greater_than_expr: str = "{} > {}"
    greater_than_or_equal_expr: str = "{} >= {}"
    var_in_expr: str = "{} in {}"
    regex_like_pattern_expr: str = "REGEXP_LIKE({}, '{}')"
    contains_expr: str = "{} LIKE '%{}%'"
    # BUG FIX: the two templates below were swapped — LIKE '{}%' matches a
    # prefix (starts-with) and LIKE '%{}' matches a suffix (ends-with).
    starts_with_expr: str = "{} LIKE '{}%'"
    ends_with_expr: str = "{} LIKE '%{}'"

    @staticmethod
    def escape_metacharacters(value: str):
        """Double the backslash of every backslash-escaped sequence."""
        return re.sub(r"(\\.)", r"\\\1", value)

    def literal_number(self, value: Number):
        """Render a number as a SQL literal; None stays None."""
        if value is None:
            return None
        return str(value)

    def literal_string(self, value: str):
        """Render a string as a quoted SQL literal; None stays None."""
        if value is None:
            return None
        return "'" + self.escape_metacharacters(value) + "'"

    def literal_list(self, l: list):
        """Render a list of values as a parenthesized SQL tuple."""
        if l is None:
            return None
        return "(" + (",".join([self.literal(e) for e in l])) + ")"

    def count(self, expr: str):
        return self.count_expr.format(expr)

    def distinct(self, expr: str):
        return self.distinct_expr.format(expr)

    def min(self, expr: str):
        return self.min_expr.format(expr)

    def max(self, expr: str):
        return self.max_expr.format(expr)

    def length(self, expr: str):
        return self.length_expr.format(expr)

    def avg(self, expr: str):
        return self.avg_expr.format(expr)

    def sum(self, expr: str):
        return self.sum_expr.format(expr)

    def variance(self, expr: str):
        return self.variance_expr.format(expr)

    def stddev(self, expr: str):
        return self.stddev_expr.format(expr)

    def limit(self, expr: str):
        return self.limit_expr.format(expr)

    def regex_like(self, expr: str, pattern: str):
        return self.regex_like_pattern_expr.format(expr, pattern)

    def equal(self, left: str, right: str):
        """Equality test; comparing with the literal "null" emits IS NULL."""
        if right == "null":
            return f"{left} IS NULL"
        # Arguments now formatted in (left, right) order; the old code passed
        # (right, left), which only worked because equality is symmetric.
        return self.equal_expr.format(left, right)

    def less_than(self, left, right):
        return self.less_than_expr.format(left, right)

    def less_than_or_equal(self, left, right):
        return self.less_than_or_equal_expr.format(left, right)

    def greater_than(self, left, right):
        return self.greater_than_expr.format(left, right)

    def greater_than_or_equal(self, left, right):
        return self.greater_than_or_equal_expr.format(left, right)

    def var_in(self, left, right):
        return self.var_in_expr.format(left, right)

    def contains(self, value, substring):
        return self.contains_expr.format(value, substring)

    def starts_with(self, value, substring):
        return self.starts_with_expr.format(value, substring)

    def ends_with(self, value, substring):
        return self.ends_with_expr.format(value, substring)

    def count_conditional(self, condition: str):
        return self.count_conditional_expr.format(condition)

    def conditional(self, condition: str, expr: str):
        return self.conditional_expr.format(condition, expr)

    def literal_date_expr(self, date_expr: date):
        """Render a date as an ANSI DATE literal."""
        date_string = date_expr.strftime("%Y-%m-%d")
        return f"DATE '{date_string}'"

    def literal(self, o: object):
        """Dispatch to the right literal_* helper based on the value's type."""
        if isinstance(o, Number):
            return self.literal_number(o)
        elif isinstance(o, str):
            return self.literal_string(o)
        elif isinstance(o, list) or isinstance(o, set) or isinstance(o, tuple):
            return self.literal_list(o)
        raise RuntimeError(f"Cannot convert type {type(o)} to a SQL literal: {o}")

    def list_expr(self, column: Column, values: List[str]) -> str:
        """Render `values` as a SQL tuple, typed according to `column`."""
        if column.is_text():
            sql_values = [self.literal_string(value) for value in values]
        elif column.is_number():
            sql_values = [self.literal_number(value) for value in values]
        else:
            raise RuntimeError(
                f"Couldn't format list {str(values)} for column {str(column)}"
            )
        return "(" + ",".join(sql_values) + ")"
class DatabaseCommon(Database):
    """Shared SQLAlchemy-backed implementation of the Database interface."""

    # Canonical SQL type names available to dialect subclasses.
    data_type_varchar_255 = "VARCHAR(255)"
    data_type_integer = "INTEGER"
    data_type_bigint = "BIGINT"
    data_type_decimal = "REAL"
    data_type_date = "DATE"
    config: SQLConnectionConfig = None
    sql_exprs: SQLExpressions = SQLExpressions()

    def __init__(self, config: SQLConnectionConfig):
        self.config = config
        self.connection_string = self.config.get_connection_url()
        self.engine = create_engine(self.connection_string, **self.config.options)
        # Raw DBAPI connection so cursor.description is accessible below.
        self.connection = self.engine.raw_connection()

    @classmethod
    def create(cls, config_dict: dict):
        # Overridden by dialect subclasses; the base implementation is a stub.
        pass

    def table_metadata_query(self, table_name: str) -> str:
        # Overridden by dialect subclasses; the base implementation is a stub.
        pass

    def qualify_table_name(self, table_name: str) -> str:
        return table_name

    def qualify_column_name(self, column_name: str):
        return column_name

    def is_text(self, column_type: str):
        # Membership test against the module-level registry lists.
        return column_type.upper() in _text_types

    def is_number(self, column_type: str):
        return column_type.upper() in _numeric_types

    def is_time(self, column_type: str):
        return column_type.upper() in _time_types

    def sql_fetchone(self, sql: str) -> tuple:
        """
        Only returns the tuple obtained by cursor.fetchone()
        """
        return self.sql_fetchone_description(sql)[0]

    def sql_fetchone_description(self, sql: str) -> tuple:
        """
        Returns a tuple with 2 elements:
        1) the tuple obtained by cursor.fetchone()
        2) the cursor.description
        """
        cursor = self.connection.cursor()
        try:
            logger.debug(f"Executing SQL query: \n{sql}")
            start = datetime.now()
            cursor.execute(sql)
            row_tuple = cursor.fetchone()
            description = cursor.description
            delta = datetime.now() - start
            logger.debug(f"SQL took {str(delta)}")
            return row_tuple, description
        finally:
            # Always release the cursor, even if execute/fetch raised.
            cursor.close()

    def sql_fetchall(self, sql: str) -> List[tuple]:
        """
        Only returns the tuples obtained by cursor.fetchall()
        """
        return self.sql_fetchall_description(sql)[0]

    def sql_fetchall_description(self, sql: str) -> tuple:
        """
        Returns a tuple with 2 elements:
        1) the tuples obtained by cursor.fetchall()
        2) the cursor.description
        """
        cursor = self.connection.cursor()
        try:
            logger.debug(f"Executing SQL query: \n{sql}")
            start = datetime.now()
            cursor.execute(sql)
            rows = cursor.fetchall()
            delta = datetime.now() - start
            logger.debug(f"SQL took {str(delta)}")
            return rows, cursor.description
        finally:
            cursor.close()

    def close(self):
        # Best-effort close of the raw DBAPI connection; errors are logged,
        # not raised.
        if self.connection:
            try:
                self.connection.close()
            except Exception as e:
                logger.error(f"Closing connection failed: {str(e)}")

View File

@ -0,0 +1,63 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Metric:
    """String constants naming every measurement the profiler can compute."""

    ROW_COUNT = "row_count"
    AVG = "avg"
    AVG_LENGTH = "avg_length"
    DISTINCT = "distinct"
    DUPLICATE_COUNT = "duplicate_count"
    FREQUENT_VALUES = "frequent_values"
    HISTOGRAM = "histogram"
    INVALID_COUNT = "invalid_count"
    INVALID_PERCENTAGE = "invalid_percentage"
    MAX = "max"
    MAX_LENGTH = "max_length"
    MIN = "min"
    MIN_LENGTH = "min_length"
    MISSING_COUNT = "missing_count"
    MISSING_PERCENTAGE = "missing_percentage"
    STDDEV = "stddev"
    SUM = "sum"
    UNIQUENESS = "uniqueness"
    UNIQUE_COUNT = "unique_count"
    VALID_COUNT = "valid_count"
    VALID_PERCENTAGE = "valid_percentage"
    VALUES_COUNT = "values_count"
    VALUES_PERCENTAGE = "values_percentage"
    VARIANCE = "variance"

    # Flat list of every metric constant above, in the same order.
    METRIC_TYPES = [
        ROW_COUNT,
        AVG,
        AVG_LENGTH,
        DISTINCT,
        DUPLICATE_COUNT,
        FREQUENT_VALUES,
        HISTOGRAM,
        INVALID_COUNT,
        INVALID_PERCENTAGE,
        MAX,
        MAX_LENGTH,
        MIN,
        MIN_LENGTH,
        MISSING_COUNT,
        MISSING_PERCENTAGE,
        STDDEV,
        SUM,
        UNIQUENESS,
        UNIQUE_COUNT,
        VALID_COUNT,
        VALID_PERCENTAGE,
        VALUES_COUNT,
        VALUES_PERCENTAGE,
        VARIANCE,
    ]

View File

@ -0,0 +1,64 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from openmetadata.common.database_common import (
DatabaseCommon,
SQLConnectionConfig,
SQLExpressions,
register_custom_type,
)
logger = logging.getLogger(__name__)
class PostgresConnectionConfig(SQLConnectionConfig):
    """Connection settings for profiling a PostgreSQL database."""

    # BUG FIX: SQLAlchemy 1.4 (pinned in profiler/requirements.txt) removed
    # the legacy "postgres" dialect alias, so "postgres+psycopg2" raises
    # NoSuchModuleError; the canonical scheme is "postgresql+psycopg2".
    scheme = "postgresql+psycopg2"

    def get_connection_url(self):
        return super().get_connection_url()
class PostgresSQLExpressions(SQLExpressions):
    # Postgres regex matching uses the case-insensitive `~*` operator
    # instead of the default REGEXP_LIKE() function.
    regex_like_pattern_expr: str = "{} ~* '{}'"
class Postgres(DatabaseCommon):
    """Postgres implementation of the common database interface."""

    config: PostgresConnectionConfig = None
    sql_exprs: PostgresSQLExpressions = PostgresSQLExpressions()

    def __init__(self, config):
        super().__init__(config)
        self.config = config

    @classmethod
    def create(cls, config_dict):
        """Alternate constructor from a raw config dict."""
        config = PostgresConnectionConfig.parse_obj(config_dict)
        return cls(config)

    def table_metadata_query(self, table_name: str) -> str:
        """Build the information_schema query for a table's columns.

        WARNING(review): table/database/schema names are interpolated into
        the SQL string unescaped — safe only for trusted config values.
        """
        sql = (
            f"SELECT column_name, data_type, is_nullable \n"
            f"FROM information_schema.columns \n"
            f"WHERE lower(table_name) = '{table_name}'"
        )
        if self.config.database:
            sql += f" \n AND table_catalog = '{self.config.database}'"
        if self.config.db_schema:
            sql += f" \n AND table_schema = '{self.config.db_schema}'"
        return sql

    def qualify_table_name(self, table_name: str) -> str:
        # Double-quote identifiers so mixed-case/reserved names work.
        if self.config.db_schema:
            return f'"{self.config.db_schema}"."{table_name}"'
        return f'"{table_name}"'

    def qualify_column_name(self, column_name: str):
        return f'"{column_name}"'

View File

@ -0,0 +1,108 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from openmetadata.common.database_common import (
DatabaseCommon,
SQLConnectionConfig,
SQLExpressions,
register_custom_type,
)
from openmetadata.profiler.profiler_metadata import SupportedDataType
# Map Redshift's type names onto the profiler's three supported categories
# (TEXT / NUMERIC / TIME) so DatabaseCommon.is_* recognizes them.
register_custom_type(
    [
        "CHAR",
        "CHARACTER",
        "BPCHAR",
        "VARCHAR",
        "CHARACTER VARYING",
        "NVARCHAR",
        "TEXT",
    ],
    SupportedDataType.TEXT,
)
register_custom_type(
    [
        "SMALLINT",
        "INT2",
        "INTEGER",
        "INT",
        "INT4",
        "BIGINT",
        "INT8",
        "DECIMAL",
        "NUMERIC",
        "REAL",
        "FLOAT4",
        "DOUBLE PRECISION",
        "FLOAT8",
        "FLOAT",
    ],
    SupportedDataType.NUMERIC,
)
register_custom_type(
    [
        "DATE",
        "TIMESTAMP",
        "TIMESTAMP WITHOUT TIME ZONE",
        "TIMESTAMPTZ",
        "TIMESTAMP WITH TIME ZONE",
        "TIME",
        "TIME WITHOUT TIME ZONE",
        "TIMETZ",
        "TIME WITH TIME ZONE",
    ],
    SupportedDataType.TIME,
)
class RedshiftConnectionConfig(SQLConnectionConfig):
    """Connection settings for profiling a Redshift cluster."""

    scheme = "redshift+psycopg2"
    # Optional extra WHERE clause applied by callers; not used in this module.
    where_clause: Optional[str] = None
    duration: int = 1
    service_type = "Redshift"

    def get_connection_url(self):
        return super().get_connection_url()
class RedshiftSQLExpressions(SQLExpressions):
    """Redshift-specific SQL expression overrides."""

    # Annotated like every other overridden pydantic field in this module;
    # the previous unannotated form relied on pydantic implicitly reusing
    # the inherited field definition. Values are unchanged.
    avg_expr: str = "AVG({})"
    sum_expr: str = "SUM({})"
    # Redshift regex matching uses the case-insensitive `~*` operator.
    regex_like_pattern_expr: str = "{} ~* '{}'"
class Redshift(DatabaseCommon):
    """Redshift implementation of the common database interface."""

    config: RedshiftConnectionConfig = None
    sql_exprs: RedshiftSQLExpressions = RedshiftSQLExpressions()

    def __init__(self, config):
        super().__init__(config)
        self.config = config

    @classmethod
    def create(cls, config_dict):
        """Alternate constructor from a raw config dict."""
        config = RedshiftConnectionConfig.parse_obj(config_dict)
        return cls(config)

    def table_metadata_query(self, table_name: str) -> str:
        """Build the information_schema query for a table's columns.

        WARNING(review): table/database/schema names are interpolated into
        the SQL string unescaped — safe only for trusted config values.
        """
        sql = (
            f"SELECT column_name, data_type, is_nullable \n"
            f"FROM information_schema.columns \n"
            f"WHERE lower(table_name) = '{table_name}'"
        )
        if self.config.database:
            sql += f" \n AND table_catalog = '{self.config.database}'"
        if self.config.db_schema:
            sql += f" \n AND table_schema = '{self.config.db_schema}'"
        return sql

View File

@ -0,0 +1,18 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Generic error code attached to SQL failures raised by the profiler.
ERROR_CODE_GENERIC = "generic_error"


class ProfilerSqlError(Exception):
    """Raised when a profiler SQL statement fails; wraps the original error."""

    def __init__(self, msg, original_exception):
        super().__init__(f"{msg}: {original_exception}")
        self.error_code = ERROR_CODE_GENERIC
        self.original_exception = original_exception

View File

@ -0,0 +1,447 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from datetime import datetime
from math import ceil, floor
from typing import List
from openmetadata.common.database import Database
from openmetadata.common.metric import Metric
from openmetadata.profiler.profiler_metadata import (
Column,
ColumnProfileResult,
MetricMeasurement,
ProfileResult,
SupportedDataType,
Table,
TableProfileResult,
get_group_by_cte,
get_group_by_cte_numeric_value_expression,
)
logger = logging.getLogger(__name__)
class Profiler:
    """Profile a single table.

    Runs three query phases against the database — aggregate metrics,
    group-by (distinct/unique/duplicate) metrics, and numeric histograms —
    and accumulates every measurement into a ProfileResult.
    """

    def __init__(
        self,
        database: Database,
        table_name: str,
        excluded_columns: List[str] = None,
        profile_time: str = None,
    ):
        """
        Args:
            database: Database wrapper used to run all profiling queries.
            table_name: Name of the table to profile.
            excluded_columns: Column names to skip while profiling.
                Defaults to None instead of a mutable [] (a shared mutable
                default would leak between Profiler instances).
            profile_time: Timestamp string recorded on the ProfileResult.
        """
        self.database = database
        self.table = Table(name=table_name)
        # Normalize the None sentinel to an empty list.
        self.excluded_columns = [] if excluded_columns is None else excluded_columns
        self.time = profile_time
        self.qualified_table_name = self.database.qualify_table_name(table_name)
        self.scan_reference = None
        self.columns: List[Column] = []
        self.start_time = None
        self.queries_executed = 0
        self.profiler_result = ProfileResult(
            profile_date=self.time,
            table_result=TableProfileResult(name=self.table.name),
        )

    def execute(self) -> ProfileResult:
        """Run all profiling phases; always closes the database connection.

        Returns:
            The accumulated ProfileResult (possibly partial when a phase fails).
        """
        self.start_time = datetime.now()
        try:
            self._table_metadata()
            self._profile_aggregations()
            self._query_group_by_value()
            self._query_histograms()
            logger.debug(
                f"Executed {self.queries_executed} queries in {(datetime.now() - self.start_time)}"
            )
        except Exception:
            # Best-effort: log the failure and return whatever was collected.
            logger.exception("Exception during scan")
        finally:
            self.database.close()
        return self.profiler_result

    def _table_metadata(self):
        """Fetch column names/types and keep only supported logical types."""
        sql = self.database.table_metadata_query(self.table.name)
        columns = self.database.sql_fetchall(sql)
        self.queries_executed += 1
        # NOTE(review): self.table_columns is assigned but never filled or
        # read in this module; kept for compatibility — confirm before removal.
        self.table_columns = []
        for column in columns:
            name = column[0]
            data_type = column[1]
            nullable = "YES" == column[2].upper()
            # Honor the excluded_columns constructor argument
            # (it was previously accepted but never applied).
            if name in self.excluded_columns:
                continue
            if self.database.is_number(data_type):
                logical_type = SupportedDataType.NUMERIC
            elif self.database.is_time(data_type):
                logical_type = SupportedDataType.TIME
            elif self.database.is_text(data_type):
                logical_type = SupportedDataType.TEXT
            else:
                logger.info(f"  {name} ({data_type}) not supported.")
                continue
            self.columns.append(
                Column(
                    name=name,
                    data_type=data_type,
                    nullable=nullable,
                    logical_type=logical_type,
                )
            )
        self.column_names: List[str] = [column.name for column in self.columns]
        logger.debug(str(len(self.columns)) + " columns:")
        self.profiler_result.table_result.col_count = len(self.columns)

    def _profile_aggregations(self):
        """Compute all aggregate metrics with a single wide SELECT, then
        derive missing/valid counts and percentages per column."""
        measurements: List[MetricMeasurement] = []
        fields: List[str] = []
        # Compute Row Count
        fields.append(self.database.sql_exprs.count_all_expr)
        measurements.append(MetricMeasurement(name=Metric.ROW_COUNT, col_name="table"))
        # Per column, remember which measurement index holds which base
        # metric so the derived metrics below can look the values up.
        column_metric_indices = {}
        try:
            for column in self.columns:
                metric_indices = {}
                column_metric_indices[column.name.lower()] = metric_indices
                column_name = column.name
                qualified_column_name = self.database.qualify_column_name(column_name)
                ## values_count
                metric_indices["non_missing"] = len(measurements)
                fields.append(self.database.sql_exprs.count(qualified_column_name))
                measurements.append(
                    MetricMeasurement(name=Metric.VALUES_COUNT, col_name=column_name)
                )
                # Valid Count — record its index so the valid/invalid derived
                # metrics below are actually produced (this index was
                # previously never stored, so metric_indices.get("valid")
                # always returned None and those metrics were skipped).
                metric_indices["valid"] = len(measurements)
                fields.append(self.database.sql_exprs.count(qualified_column_name))
                measurements.append(
                    MetricMeasurement(name=Metric.VALID_COUNT, col_name=column_name)
                )
                if column.logical_type == SupportedDataType.TEXT:
                    length_expr = self.database.sql_exprs.length(qualified_column_name)
                    # Avg Length
                    fields.append(self.database.sql_exprs.avg(length_expr))
                    measurements.append(
                        MetricMeasurement(name=Metric.AVG_LENGTH, col_name=column_name)
                    )
                    # Min Length
                    fields.append(self.database.sql_exprs.min(length_expr))
                    measurements.append(
                        MetricMeasurement(name=Metric.MIN_LENGTH, col_name=column_name)
                    )
                    # Max Length
                    fields.append(self.database.sql_exprs.max(length_expr))
                    measurements.append(
                        MetricMeasurement(name=Metric.MAX_LENGTH, col_name=column_name)
                    )
                if column.logical_type == SupportedDataType.NUMERIC:
                    # Min
                    fields.append(self.database.sql_exprs.min(qualified_column_name))
                    measurements.append(
                        MetricMeasurement(name=Metric.MIN, col_name=column_name)
                    )
                    # Max
                    fields.append(self.database.sql_exprs.max(qualified_column_name))
                    measurements.append(
                        MetricMeasurement(name=Metric.MAX, col_name=column_name)
                    )
                    # AVG
                    fields.append(self.database.sql_exprs.avg(qualified_column_name))
                    measurements.append(
                        MetricMeasurement(name=Metric.AVG, col_name=column_name)
                    )
                    # SUM
                    fields.append(self.database.sql_exprs.sum(qualified_column_name))
                    measurements.append(
                        MetricMeasurement(name=Metric.SUM, col_name=column_name)
                    )
                    # VARIANCE
                    fields.append(
                        self.database.sql_exprs.variance(qualified_column_name)
                    )
                    measurements.append(
                        MetricMeasurement(name=Metric.VARIANCE, col_name=column_name)
                    )
                    # STDDEV
                    fields.append(self.database.sql_exprs.stddev(qualified_column_name))
                    measurements.append(
                        MetricMeasurement(name=Metric.STDDEV, col_name=column_name)
                    )
            if len(fields) > 0:
                sql = (
                    "SELECT \n  " + ",\n  ".join(fields) + " \n"
                    "FROM " + self.qualified_table_name
                )
                query_result_tuple = self.database.sql_fetchone(sql)
                self.queries_executed += 1
                # Result columns come back in the same order the expressions
                # were appended, so result index i maps to measurements[i].
                for i in range(0, len(measurements)):
                    measurement = measurements[i]
                    measurement.value = query_result_tuple[i]
                    self._add_measurement(measurement)
                # Calculating derived measurements
                row_count_measurement = next(
                    (m for m in measurements if m.name == Metric.ROW_COUNT), None
                )
                if row_count_measurement:
                    row_count = row_count_measurement.value
                    self.profiler_result.table_result.row_count = row_count
                    for column in self.columns:
                        column_name = column.name
                        metric_indices = column_metric_indices[column_name.lower()]
                        non_missing_index = metric_indices.get("non_missing")
                        if non_missing_index is not None:
                            values_count = measurements[non_missing_index].value
                            missing_count = row_count - values_count
                            missing_percentage = (
                                missing_count * 100 / row_count
                                if row_count > 0
                                else None
                            )
                            values_percentage = (
                                values_count * 100 / row_count
                                if row_count > 0
                                else None
                            )
                            self._add_measurement(
                                MetricMeasurement(
                                    name=Metric.MISSING_PERCENTAGE,
                                    col_name=column_name,
                                    value=missing_percentage,
                                )
                            )
                            self._add_measurement(
                                MetricMeasurement(
                                    name=Metric.MISSING_COUNT,
                                    col_name=column_name,
                                    value=missing_count,
                                )
                            )
                            self._add_measurement(
                                MetricMeasurement(
                                    name=Metric.VALUES_PERCENTAGE,
                                    col_name=column_name,
                                    value=values_percentage,
                                )
                            )
                        valid_index = metric_indices.get("valid")
                        if valid_index is not None:
                            valid_count = measurements[valid_index].value
                            invalid_count = row_count - missing_count - valid_count
                            invalid_percentage = (
                                invalid_count * 100 / row_count
                                if row_count > 0
                                else None
                            )
                            valid_percentage = (
                                valid_count * 100 / row_count
                                if row_count > 0
                                else None
                            )
                            self._add_measurement(
                                MetricMeasurement(
                                    name=Metric.INVALID_PERCENTAGE,
                                    col_name=column_name,
                                    value=invalid_percentage,
                                )
                            )
                            self._add_measurement(
                                MetricMeasurement(
                                    name=Metric.INVALID_COUNT,
                                    col_name=column_name,
                                    value=invalid_count,
                                )
                            )
                            self._add_measurement(
                                MetricMeasurement(
                                    name=Metric.VALID_PERCENTAGE,
                                    col_name=column_name,
                                    value=valid_percentage,
                                )
                            )
        except Exception as e:
            logger.error("Exception during aggregation query", exc_info=e)

    def _query_group_by_value(self):
        """Compute distinct/unique/duplicate counts (and uniqueness) per
        column from a shared GROUP BY CTE; one query per column."""
        for column in self.columns:
            try:
                column_name = column.name
                group_by_cte = get_group_by_cte(
                    self.database.qualify_column_name(column.name),
                    self.database.qualify_table_name(self.table.name),
                )
                ## Compute Distinct, Unique, Unique_Count, Duplicate_count
                sql = (
                    f"{group_by_cte} \n"
                    f"SELECT COUNT(*), \n"
                    f"       COUNT(CASE WHEN frequency = 1 THEN 1 END), \n"
                    f"       SUM(frequency) \n"
                    f"FROM group_by_value"
                )
                query_result_tuple = self.database.sql_fetchone(sql)
                self.queries_executed += 1
                distinct_count = query_result_tuple[0]
                unique_count = query_result_tuple[1]
                # SUM(frequency) is NULL when the table has no non-NULL rows.
                valid_count = query_result_tuple[2] if query_result_tuple[2] else 0
                duplicate_count = distinct_count - unique_count
                self._add_measurement(
                    MetricMeasurement(
                        name=Metric.DISTINCT, col_name=column_name, value=distinct_count
                    )
                )
                self._add_measurement(
                    MetricMeasurement(
                        name=Metric.UNIQUE_COUNT,
                        col_name=column_name,
                        value=unique_count,
                    )
                )
                self._add_measurement(
                    MetricMeasurement(
                        name=Metric.DUPLICATE_COUNT,
                        col_name=column_name,
                        value=duplicate_count,
                    )
                )
                if valid_count > 1:
                    uniqueness = (distinct_count - 1) * 100 / (valid_count - 1)
                    self._add_measurement(
                        MetricMeasurement(
                            name=Metric.UNIQUENESS,
                            col_name=column_name,
                            value=uniqueness,
                        )
                    )
            except Exception as e:
                logger.error(
                    "Exception during column group by value queries", exc_info=e
                )

    def _query_histograms(self):
        """Build a 20-bucket histogram for each numeric column whose MIN/MAX
        measurements from the aggregation phase form a valid range."""
        for column in self.columns:
            column_name = column.name
            try:
                if column.is_number():
                    buckets: int = 20
                    column_results = self.profiler_result.columns_result[column_name]
                    min_value = column_results.measurements.get(Metric.MIN).value
                    max_value = column_results.measurements.get(Metric.MAX).value
                    if (
                        column.is_number()
                        and min_value
                        and max_value
                        and min_value < max_value
                    ):
                        # Snap the range outward to 3 decimals so every value
                        # falls inside [min_value, max_value].
                        min_value = floor(min_value * 1000) / 1000
                        max_value = ceil(max_value * 1000) / 1000
                        bucket_width = (max_value - min_value) / buckets
                        boundary = min_value
                        boundaries = [min_value]
                        for i in range(0, buckets):
                            boundary += bucket_width
                            boundaries.append(round(boundary, 3))
                        # Qualify names for consistency with
                        # _query_group_by_value (previously this CTE used
                        # unqualified column/table names).
                        group_by_cte = get_group_by_cte(
                            self.database.qualify_column_name(column_name),
                            self.database.qualify_table_name(self.table.name),
                        )
                        numeric_value_expr = get_group_by_cte_numeric_value_expression(
                            column, self.database, None
                        )
                        field_clauses = []
                        for i in range(0, buckets):
                            # First bucket is open below, last is open above.
                            lower_bound = (
                                ""
                                if i == 0
                                else f"{boundaries[i]} <= {numeric_value_expr}"
                            )
                            upper_bound = (
                                ""
                                if i == buckets - 1
                                else f"{numeric_value_expr} < {boundaries[i + 1]}"
                            )
                            optional_and = (
                                ""
                                if lower_bound == "" or upper_bound == ""
                                else " and "
                            )
                            field_clauses.append(
                                f"SUM(CASE WHEN {lower_bound}{optional_and}{upper_bound} THEN frequency END)"
                            )
                        fields = ",\n  ".join(field_clauses)
                        sql = (
                            f"{group_by_cte} \n"
                            f"SELECT \n"
                            f"  {fields} \n"
                            f"FROM group_by_value"
                        )
                        row = self.database.sql_fetchone(sql)
                        self.queries_executed += 1
                        # Process the histogram query
                        frequencies = []
                        for i in range(0, buckets):
                            frequency = row[i]
                            # SUM over an empty bucket yields NULL -> 0.
                            frequencies.append(0 if not frequency else int(frequency))
                        histogram = {
                            "boundaries": boundaries,
                            "frequencies": frequencies,
                        }
                        self._add_measurement(
                            MetricMeasurement(
                                name=Metric.HISTOGRAM,
                                col_name=column_name,
                                value=histogram,
                            )
                        )
            except Exception as e:
                logger.error("Exception during histogram query", exc_info=e)

    def _add_measurement(self, measurement):
        """Store *measurement* under its column in the ProfileResult,
        creating the ColumnProfileResult on first use."""
        logger.debug(f"measurement: {measurement}")
        if measurement.col_name in self.profiler_result.columns_result.keys():
            col_result = self.profiler_result.columns_result[measurement.col_name]
            col_measurements = col_result.measurements
            col_measurements[measurement.name] = measurement
            self.profiler_result.columns_result[
                measurement.col_name
            ].measurements = col_measurements
        else:
            col_measurements = {measurement.name: measurement}
            self.profiler_result.columns_result[
                measurement.col_name
            ] = ColumnProfileResult(
                name=measurement.col_name, measurements=col_measurements
            )

View File

@ -0,0 +1,120 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Dict, List, Optional
from pydantic import BaseModel
def get_group_by_cte(column_name, table_name):
    """Build the shared group-by CTE: one row per distinct non-NULL value
    of *column_name* in *table_name*, with its occurrence count as
    ``frequency``."""
    not_null_filter = f"{column_name} IS NOT NULL"
    cte_lines = [
        "WITH group_by_value AS ( \n",
        "  SELECT \n",
        f"    {column_name} AS value, \n",
        "    COUNT(*) AS frequency \n",
        f"  FROM {table_name} \n",
        f"  WHERE {not_null_filter} \n",
        f"  GROUP BY {column_name} \n",
        ")",
    ]
    return "".join(cte_lines)
def get_group_by_cte_numeric_value_expression(column, database, validity_format):
    """Return the SQL expression yielding a numeric value for the CTE's
    ``value`` column, or None when no numeric interpretation exists.

    Bug fix: ``is_number``/``is_time`` are methods and were previously
    referenced without calling them, so the bound-method objects were always
    truthy and the first branch was taken for every column.
    """
    if column.is_number() or column.is_time():
        return "value"
    # NOTE(review): Column does not define is_column_numeric_text_format in
    # this module; guard the lookup so columns without it return None.
    if getattr(column, "is_column_numeric_text_format", False):
        return database.sql_expr_cast_text_to_number("value", validity_format)
    return None
def get_order_by_cte_value_expression(
    column, database, validity_format, numeric_value_expr: str
):
    """Return the SQL expression used to order the CTE's ``value`` column,
    or None when the column type has no ordering expression.

    Bug fix: ``is_number``/``is_time``/``is_text`` are methods and were
    previously referenced without calling them, so the bound-method objects
    were always truthy and the first branch was taken for every column.
    """
    if column.is_number() or column.is_time():
        return "value"
    # NOTE(review): Column does not define is_column_numeric_text_format in
    # this module; guard the lookup so the check degrades gracefully.
    if getattr(column, "is_column_numeric_text_format", False):
        return database.sql_expr_cast_text_to_number("value", validity_format)
    elif column.is_text():
        return "value"
    return None
class SupportedDataType(Enum):
    """Logical type buckets the profiler can measure.

    Columns whose database type maps to none of these are skipped
    during metadata collection.
    """

    NUMERIC = 1
    TEXT = 2
    TIME = 3
class Column(BaseModel):
    """Column Metadata: the database column plus the logical type bucket
    that decides which metrics the profiler computes for it."""

    # Column name as reported by the database.
    name: str
    # Whether the column accepts NULLs; None when unknown.
    nullable: Optional[bool] = None
    # Raw database type string (e.g. "varchar").
    data_type: str
    # Logical bucket (NUMERIC/TEXT/TIME) derived from data_type.
    logical_type: SupportedDataType

    def is_text(self) -> bool:
        """True when the column is profiled as text."""
        return self.logical_type == SupportedDataType.TEXT

    def is_number(self) -> bool:
        """True when the column is profiled as numeric."""
        return self.logical_type == SupportedDataType.NUMERIC

    def is_time(self) -> bool:
        """True when the column is profiled as a date/time."""
        return self.logical_type == SupportedDataType.TIME
class Table(BaseModel):
    """Table Metadata: the table name plus its profiled columns."""

    # Table name (qualification is handled by the Database wrapper).
    name: str
    # Columns selected for profiling; pydantic copies mutable defaults
    # per instance, so the [] default is safe here.
    columns: List[Column] = []
class GroupValue(BaseModel):
    """Metric Group Values: one group-by key mapped to its measured value."""

    # Group-by key(s) identifying this group.
    group: Dict = {}
    # Measured value; typed object, hence arbitrary_types_allowed below.
    value: object

    class Config:
        arbitrary_types_allowed = True
class MetricMeasurement(BaseModel):
    """Metric Measurement: a single named metric value for one column."""

    # Metric name.
    name: str
    # Column the measurement belongs to ("table" for table-level metrics).
    col_name: str
    # Measured value; filled in after the query runs, hence the None default.
    value: object = None

    class Config:
        arbitrary_types_allowed = True
class TableProfileResult(BaseModel):
    """Table Profile Result: table-level counts for one profiling run."""

    # Table name.
    name: str
    # Total row count; None until the row-count query has run.
    row_count: Optional[int] = None
    # Number of profiled columns; None until metadata has been read.
    col_count: Optional[int] = None
class ColumnProfileResult(BaseModel):
    """Per-column profile result: measurements keyed by metric name."""

    name: str
    measurements: Dict[str, MetricMeasurement] = {}
class ProfileResult(BaseModel):
    """Profile Run Result: everything measured in one run over one table."""

    # Timestamp string of the run.
    profile_date: str
    # Table-level counts.
    table_result: TableProfileResult
    # Per-column results keyed by column name.
    columns_result: Dict[str, ColumnProfileResult] = {}

View File

@ -0,0 +1,85 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging
import sys
from datetime import datetime, timezone
from typing import Type, TypeVar
from openmetadata.common.config import ConfigModel, DynamicTypedConfig
from openmetadata.common.database import Database
from openmetadata.common.database_common import SQLConnectionConfig
from openmetadata.profiler.profiler import Profiler
from openmetadata.profiler.profiler_metadata import ProfileResult
# Module-level logger for the profiler runner.
logger = logging.getLogger(__name__)
# Generic type variable used by get_clazz's return annotation.
T = TypeVar("T")
class ProfilerConfig(ConfigModel):
    """Top-level profiler workflow configuration.

    The `profiler` entry names the database type and carries its
    nested config (sql_connection, table_name, ...).
    """

    profiler: DynamicTypedConfig
def type_class_fetch(clazz_type: str, is_file: bool):
    """Translate a hyphenated type name into either a module file name
    (snake_case, when *is_file*) or a class name (CamelCase)."""
    normalized = clazz_type.replace("-", "_")
    if is_file:
        return normalized
    return "".join(part.title() for part in normalized.split("_"))
def get_clazz(key: str) -> Type[T]:
    """Dynamically import and return the class named by a dotted path.

    Args:
        key: Fully qualified "package.module.ClassName" path.

    Returns:
        The class object.

    Raises:
        ValueError: If *key* contains no dot. The original implementation
            silently fell through and returned None, deferring the failure
            to the caller's first use of the "class".
    """
    if key.find(".") < 0:
        raise ValueError(f"Expected a dotted 'module.Class' path, got: {key!r}")
    # The key is an import path: load the module, then fetch the attribute.
    module_name, class_name = key.rsplit(".", 1)
    clazz = getattr(importlib.import_module(module_name), class_name)
    return clazz
class ProfilerRunner:
    """Loads the configured database implementation dynamically and runs
    the Profiler against the configured table."""

    config: ProfilerConfig
    database: Database

    def __init__(self, config: ProfilerConfig):
        self.config = config
        database_type = self.config.profiler.type
        # Resolve the type name to a class, e.g. type "redshift-sql" ->
        # openmetadata.databases.redshift_sql.RedshiftSql.
        database_class = get_clazz(
            "openmetadata.databases.{}.{}".format(
                type_class_fetch(database_type, True),
                type_class_fetch(database_type, False),
            )
        )
        self.profiler_config = self.config.profiler.dict().get("config", {})
        # Instantiate the database from its sql_connection config section.
        self.database: Database = database_class.create(
            self.profiler_config.get("sql_connection", {})
        )
        self.table_name = self.profiler_config.get("table_name")
        # NOTE(review): self.variables appears unused in this class — confirm.
        self.variables: dict = {}
        # Profile timestamp recorded on the result (UTC, second precision).
        self.time = datetime.now(tz=timezone.utc).isoformat(timespec="seconds")

    @classmethod
    def create(cls, config_dict: dict) -> "ProfilerRunner":
        """Alternate constructor: build a runner from a raw config dict."""
        config = ProfilerConfig.parse_obj(config_dict)
        return cls(config)

    def execute(self):
        """Run the profiler once and return its ProfileResult.

        Exits the process with code 1 on any failure.
        """
        try:
            profiler = Profiler(
                database=self.database,
                table_name=self.table_name,
                profile_time=self.time,
            )
            profile_result: ProfileResult = profiler.execute()
            return profile_result
        except Exception as e:
            logger.exception(f"Profiler failed: {str(e)}")
            logger.info(f"Exiting with code 1")
            sys.exit(1)