From 7b09571a22d88ce9e4c38d6e4c6d13a5d321ec1d Mon Sep 17 00:00:00 2001 From: Sriharsha Chintalapani Date: Tue, 26 Oct 2021 20:01:10 -0700 Subject: [PATCH] Profiler: Module to calculate measurements for a table (#933) * Profiler: Module to calculate measurements for a table * Profiler: refactor sql expressions * Profiler: fix supported types * Update README.md Co-authored-by: Ayush Shah --- .../ingestion/models/table_queries.py | 3 +- .../ingestion/processor/query_parser.py | 3 +- profiler/CHANGELOG | 0 profiler/README.md | 25 + profiler/configs/database.yaml | 11 + profiler/requirements.txt | 14 + profiler/setup.cfg | 49 ++ profiler/setup.py | 96 ++++ profiler/src/openmetadata/__init__.py | 0 profiler/src/openmetadata/cmd.py | 78 +++ profiler/src/openmetadata/common/__init__.py | 0 profiler/src/openmetadata/common/config.py | 123 +++++ profiler/src/openmetadata/common/database.py | 72 +++ .../openmetadata/common/database_common.py | 328 +++++++++++++ profiler/src/openmetadata/common/metric.py | 63 +++ .../src/openmetadata/databases/__init__.py | 0 .../src/openmetadata/databases/postgres.py | 64 +++ .../src/openmetadata/databases/redshift.py | 108 +++++ .../src/openmetadata/exceptions/exceptions.py | 18 + .../src/openmetadata/profiler/__init__.py | 0 .../src/openmetadata/profiler/profiler.py | 447 ++++++++++++++++++ .../profiler/profiler_metadata.py | 120 +++++ .../openmetadata/profiler/profiler_runner.py | 85 ++++ 23 files changed, 1703 insertions(+), 4 deletions(-) create mode 100644 profiler/CHANGELOG create mode 100644 profiler/README.md create mode 100644 profiler/configs/database.yaml create mode 100644 profiler/requirements.txt create mode 100644 profiler/setup.cfg create mode 100644 profiler/setup.py create mode 100644 profiler/src/openmetadata/__init__.py create mode 100644 profiler/src/openmetadata/cmd.py create mode 100644 profiler/src/openmetadata/common/__init__.py create mode 100644 profiler/src/openmetadata/common/config.py create mode 100644 
profiler/src/openmetadata/common/database.py create mode 100644 profiler/src/openmetadata/common/database_common.py create mode 100644 profiler/src/openmetadata/common/metric.py create mode 100644 profiler/src/openmetadata/databases/__init__.py create mode 100644 profiler/src/openmetadata/databases/postgres.py create mode 100644 profiler/src/openmetadata/databases/redshift.py create mode 100644 profiler/src/openmetadata/exceptions/exceptions.py create mode 100644 profiler/src/openmetadata/profiler/__init__.py create mode 100644 profiler/src/openmetadata/profiler/profiler.py create mode 100644 profiler/src/openmetadata/profiler/profiler_metadata.py create mode 100644 profiler/src/openmetadata/profiler/profiler_runner.py diff --git a/ingestion/src/metadata/ingestion/models/table_queries.py b/ingestion/src/metadata/ingestion/models/table_queries.py index fb6875e92aa..116ca27c10a 100644 --- a/ingestion/src/metadata/ingestion/models/table_queries.py +++ b/ingestion/src/metadata/ingestion/models/table_queries.py @@ -15,10 +15,9 @@ from typing import Dict, List, Optional -from pydantic import BaseModel - from metadata.generated.schema.entity.data.table import ColumnJoins from metadata.ingestion.models.json_serializable import JsonSerializable +from pydantic import BaseModel class TableQuery(JsonSerializable): diff --git a/ingestion/src/metadata/ingestion/processor/query_parser.py b/ingestion/src/metadata/ingestion/processor/query_parser.py index c08819cd6d1..926c4cefd56 100644 --- a/ingestion/src/metadata/ingestion/processor/query_parser.py +++ b/ingestion/src/metadata/ingestion/processor/query_parser.py @@ -18,13 +18,12 @@ import logging import traceback from typing import Optional -from sql_metadata import Parser - from metadata.config.common import ConfigModel from metadata.ingestion.api.common import WorkflowContext from metadata.ingestion.api.processor import Processor, ProcessorStatus from metadata.ingestion.models.table_queries import QueryParserData, TableQuery 
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig +from sql_metadata import Parser class QueryParserProcessorConfig(ConfigModel): diff --git a/profiler/CHANGELOG b/profiler/CHANGELOG new file mode 100644 index 00000000000..e69de29bb2d diff --git a/profiler/README.md b/profiler/README.md new file mode 100644 index 00000000000..ddde411fa99 --- /dev/null +++ b/profiler/README.md @@ -0,0 +1,25 @@ +--- +This guide will help you setup the Data Profiler +--- + +![Python version 3.8+](https://img.shields.io/badge/python-3.8%2B-blue) + +OpenMetadata Data profiler to collect measurements on various data sources +publishes them to OpenMetadata. This framework runs the tests and profiler + +**Prerequisites** + +- Python >= 3.8.x + +### Install From PyPI + +```text +python3 -m pip install --upgrade pip wheel setuptools openmetadata-dataprofiler +``` + +#### Generate Redshift Data + +```text +profiler test -c ./examples/workflows/redshift.json +``` + diff --git a/profiler/configs/database.yaml b/profiler/configs/database.yaml new file mode 100644 index 00000000000..03b3804b116 --- /dev/null +++ b/profiler/configs/database.yaml @@ -0,0 +1,11 @@ +profiler: + type: redshift + config: + sql_connection: + host_port: host:port + username: username + password: password + database: warehouse + db_schema: public + service_name: redshift + table_name: sales diff --git a/profiler/requirements.txt b/profiler/requirements.txt new file mode 100644 index 00000000000..0c82a7ac1f0 --- /dev/null +++ b/profiler/requirements.txt @@ -0,0 +1,14 @@ +pip~=21.3 +PyYAML~=6.0 +MarkupSafe~=2.0.1 +urllib3~=1.26.7 +six~=1.16.0 +python-dateutil~=2.8.2 +greenlet~=1.1.2 +psycopg2~=2.9.1 +setuptools~=58.2.0 +SQLAlchemy~=1.4.26 +click~=8.0.3 +Jinja2~=3.0.2 +expandvars~=0.7.0 +pydantic~=1.8.2 \ No newline at end of file diff --git a/profiler/setup.cfg b/profiler/setup.cfg new file mode 100644 index 00000000000..15c158b9056 --- /dev/null +++ b/profiler/setup.cfg @@ -0,0 +1,49 @@ +[flake8] 
+# We ignore the line length issues here, since black will take care of them. +max-line-length = 150 +max-complexity = 15 +ignore = + # Ignore: 1 blank line required before class docstring. + D203, + W503 +exclude = + .git, + __pycache__ +per-file-ignores = + # imported but unused + __init__.py: F401 +[metadata] +license_files = LICENSE +[mypy] +mypy_path = src +plugins = + sqlmypy, + pydantic.mypy +ignore_missing_imports = yes +namespace_packages = true +strict_optional = yes +check_untyped_defs = yes +# eventually we'd like to enable these +disallow_untyped_defs = no +disallow_incomplete_defs = no + +[isort] +profile = black +indent=' ' +sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER + +[tool:pytest] +addopts = --cov src --cov-report term --cov-config setup.cfg --strict-markers +markers = + slow: marks tests as slow (deselect with '-m "not slow"') +testpaths = + tests/unit + +[options] +packages = find: +package_dir = + =src + +[options.packages.find] +where = src +include = * \ No newline at end of file diff --git a/profiler/setup.py b/profiler/setup.py new file mode 100644 index 00000000000..b838272b676 --- /dev/null +++ b/profiler/setup.py @@ -0,0 +1,96 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os
from typing import Dict, Set

from setuptools import find_namespace_packages, setup


def get_version():
    """Return the version string: the first line of the CHANGELOG file.

    NOTE(review): currently unused -- setup() below hard-codes version="0.1",
    and CHANGELOG is empty, so switching to get_version() would break the build.
    """
    root = os.path.dirname(__file__)
    changelog = os.path.join(root, "CHANGELOG")
    with open(changelog) as f:
        return f.readline().strip()


def get_long_description():
    """Concatenate README.md and CHANGELOG into the PyPI long description."""
    root = os.path.dirname(__file__)
    with open(os.path.join(root, "README.md")) as f:
        description = f.read()
    description += "\n\nChangelog\n=========\n\n"
    with open(os.path.join(root, "CHANGELOG")) as f:
        description += f.read()
    return description


# Every entry below must be a *separate* string. The original set relied on
# accidental implicit string concatenation (missing commas), which silently
# merged several requirements into single invalid specifiers such as
# "requests>=2.23.0, <3.0idna<3,>=2.5".
base_requirements = {
    "sqlalchemy>=1.3.24",
    "Jinja2>=2.11.3, <3.0",
    # The original listed two overlapping click pins; this is their intersection.
    "click>=7.1.2, <7.2.0",
    "cryptography==3.3.2",
    "pyyaml>=5.4.1, <6.0",
    "requests>=2.23.0, <3.0",
    "idna<3,>=2.5",
    "expandvars>=0.6.5",
    "dataclasses>=0.8",
    "typing_extensions>=3.7.4",
    "mypy_extensions>=0.4.3",
    "typing-inspect",
    "pydantic==1.7.4",
    "pymysql>=1.0.2",
    "GeoAlchemy2",
    "psycopg2-binary>=2.8.5, <3.0",
    "openmetadata-sqlalchemy-redshift==0.2.1",
}

# Optional per-database extras: pip install openmetadata-data-profiler[redshift]
plugins: Dict[str, Set[str]] = {
    "redshift": {
        "openmetadata-sqlalchemy-redshift==0.2.1",
        "psycopg2-binary",
        "GeoAlchemy2",
    },
    "postgres": {"pymysql>=1.0.2", "psycopg2-binary", "GeoAlchemy2"},
}

# cx_Freeze-style option; _cffi_backend must be bundled for cryptography.
build_options = {"includes": ["_cffi_backend"]}
setup(
    name="openmetadata-data-profiler",
    version="0.1",
    url="https://open-metadata.org/",
    author="OpenMetadata Committers",
    license="Apache License 2.0",
    description="Data Profiler and Testing Framework for OpenMetadata",
    long_description=get_long_description(),
    long_description_content_type="text/markdown",
    python_requires=">=3.8",
    options={"build_exe": build_options},
    package_dir={"": "src"},
    zip_safe=False,
    dependency_links=[],
    project_urls={
        "Documentation": "https://docs.open-metadata.org/",
        "Source": "https://github.com/open-metadata/OpenMetadata",
    },
    packages=find_namespace_packages(where="./src", exclude=["tests*"]),
    entry_points={
        "console_scripts": ["openmetadata = openmetadata.cmd:openmetadata"],
    },
    install_requires=list(base_requirements),
    extras_require={
        "base": list(base_requirements),
        **{plugin: list(dependencies) for (plugin, dependencies) in plugins.items()},
        # "all" = base plus every plugin's dependencies.
        "all": list(base_requirements.union(*plugins.values())),
    },
)
BASE_LOGGING_FORMAT = (
    "[%(asctime)s] %(levelname)-8s {%(name)s:%(lineno)d} - %(message)s"
)
logging.basicConfig(format=BASE_LOGGING_FORMAT)


@click.group()
def check() -> None:
    """Placeholder group for health/consistency checks (no subcommands yet)."""
    pass


@click.group()
@click.option("--debug/--no-debug", default=False)
def openmetadata(debug: bool) -> None:
    """Root CLI group; --debug raises the profiler log verbosity."""
    if debug:
        logging.getLogger().setLevel(logging.INFO)
        logging.getLogger("openmetadata").setLevel(logging.DEBUG)
    else:
        logging.getLogger().setLevel(logging.WARNING)
        logging.getLogger("openmetadata").setLevel(logging.INFO)


@openmetadata.command()
@click.option(
    "-c",
    "--config",
    type=click.Path(exists=True, dir_okay=False),
    help="Profiler config",
    required=True,
)
def profiler(config: str) -> None:
    """Run the data profiler against the table described in the config file."""
    try:
        config_file = pathlib.Path(config)
        profiler_config = load_config_file(config_file)
        try:
            # Lazy %-style args: the message is only formatted if emitted.
            # (The originals were f-strings, several with no placeholder at all.)
            logger.info("Using config: %s", profiler_config)
            profiler_runner = ProfilerRunner.create(profiler_config)
        except ValidationError as e:
            # The config did not match the expected schema: report and abort.
            click.echo(e, err=True)
            sys.exit(1)

        logger.info("Running Profiler for %s ...", profiler_runner.table_name)
        profile_result: ProfileResult = profiler_runner.execute()
        logger.info("Profiler Results")
        logger.info("%s", profile_result.json())

    except Exception as e:
        # Top-level boundary: log the full traceback, then exit non-zero.
        logger.exception("Scan failed: %s", str(e))
        logger.info("Exiting with code 1")
        sys.exit(1)


openmetadata.add_command(check)
import io
import json
import pathlib
import re  # Fix: IncludeFilterPattern uses re.compile/re.match but "re" was never imported.
from abc import ABC, abstractmethod
from typing import IO, Any, List, Optional

import yaml
from expandvars import expandvars
from pydantic import BaseModel


class ConfigModel(BaseModel):
    """Base model for all profiler configs: unknown keys are rejected."""

    class Config:
        extra = "forbid"


class DynamicTypedConfig(ConfigModel):
    """A config node of the form ``{type: ..., config: {...}}``."""

    type: str
    config: Optional[Any]


class WorkflowExecutionError(Exception):
    """An error occurred when executing the workflow"""


class ConfigurationError(Exception):
    """A configuration error has happened"""


class ConfigurationMechanism(ABC):
    """Strategy interface: parse an open config stream into a dict."""

    @abstractmethod
    def load_config(self, config_fp: IO) -> dict:
        pass


class YamlConfigurationMechanism(ConfigurationMechanism):
    """load configuration from yaml files"""

    def load_config(self, config_fp: IO) -> dict:
        config = yaml.safe_load(config_fp)
        return config


class JsonConfigurationMechanism(ConfigurationMechanism):
    """load configuration from json files"""

    def load_config(self, config_fp: IO) -> dict:
        config = json.load(config_fp)
        return config


def load_config_file(config_file: pathlib.Path) -> dict:
    """Parse a .yaml/.yml/.json config file into a dict.

    Environment-variable references are expanded first (nounset=True, so a
    missing variable fails fast). Raises ConfigurationError when the file is
    missing or has an unsupported extension.
    """
    if not config_file.is_file():
        raise ConfigurationError(f"Cannot open config file {config_file}")

    config_mech: ConfigurationMechanism
    if config_file.suffix in [".yaml", ".yml"]:
        config_mech = YamlConfigurationMechanism()
    elif config_file.suffix == ".json":
        config_mech = JsonConfigurationMechanism()
    else:
        raise ConfigurationError(
            "Only .json and .yml are supported. Cannot process file type {}".format(
                config_file.suffix
            )
        )

    with config_file.open() as raw_config_file:
        raw_config = raw_config_file.read()

    expanded_config_file = expandvars(raw_config, nounset=True)
    config_fp = io.StringIO(expanded_config_file)
    config = config_mech.load_config(config_fp)

    return config


class IncludeFilterPattern(ConfigModel):
    """A class to store allow deny regexes"""

    includes: List[str] = [".*"]
    excludes: List[str] = []
    alphabet: str = "[A-Za-z0-9 _.-]"

    @property
    def alphabet_pattern(self):
        return re.compile(f"^{self.alphabet}+$")

    @classmethod
    def allow_all(cls):
        return IncludeFilterPattern()

    def included(self, string: str) -> bool:
        """Return True when *string* matches some include and no exclude."""
        try:
            # Excludes win over includes.
            for exclude in self.excludes:
                if re.match(exclude, string):
                    return False

            for include in self.includes:
                if re.match(include, string):
                    return True
            return False
        except Exception as err:
            raise Exception("Regex Error: {}".format(err))

    def is_fully_specified_include_list(self) -> bool:
        """True when every include is a plain name (no regex metacharacters)."""
        for filter_pattern in self.includes:
            if not self.alphabet_pattern.match(filter_pattern):
                return False
        return True

    def get_allowed_list(self):
        assert self.is_fully_specified_include_list()
        return [a for a in self.includes if self.included(a)]
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import List


class Closeable:
    """Anything holding a resource that must be released when done."""

    @abstractmethod
    def close(self):
        pass


@dataclass
class Database(Closeable, metaclass=ABCMeta):
    """Abstract interface a profiler database backend must implement."""

    @classmethod
    @abstractmethod
    def create(cls, config_dict: dict) -> "Database":
        """Build a concrete Database from a raw config mapping."""
        pass

    @property
    @abstractmethod
    def sql_exprs(self):
        """Dialect-specific SQL expression templates for this backend."""
        pass

    @abstractmethod
    def table_metadata_query(self, table_name: str) -> str:
        """Return the SQL that lists the table's column metadata."""
        pass

    @abstractmethod
    def qualify_table_name(self, table_name: str) -> str:
        # Default passthrough; dialects add schema qualification/quoting.
        return table_name

    @abstractmethod
    def qualify_column_name(self, column_name: str):
        # Default passthrough; dialects add identifier quoting.
        return column_name

    @abstractmethod
    def is_text(self, column_type: str):
        pass

    @abstractmethod
    def is_number(self, column_type: str):
        pass

    @abstractmethod
    def is_time(self, column_type: str):
        pass

    @abstractmethod
    def sql_fetchone(self, sql) -> tuple:
        """Execute *sql* and return the first row."""
        pass

    @abstractmethod
    def sql_fetchone_description(self, sql) -> tuple:
        """Execute *sql* and return (first row, cursor description)."""
        pass

    @abstractmethod
    def sql_fetchall(self, sql) -> List[tuple]:
        """Execute *sql* and return all rows."""
        pass

    @abstractmethod
    def sql_fetchall_description(self, sql) -> tuple:
        """Execute *sql* and return (all rows, cursor description)."""
        pass
+# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import logging +import re +from abc import abstractmethod +from datetime import date, datetime +from numbers import Number +from typing import List, Optional +from urllib.parse import quote_plus + +from pydantic import BaseModel +from sqlalchemy import create_engine + +from openmetadata.common.config import ConfigModel, IncludeFilterPattern +from openmetadata.common.database import Database +from openmetadata.profiler.profiler_metadata import Column, SupportedDataType + +logger = logging.getLogger(__name__) + + +class SQLConnectionConfig(ConfigModel): + username: Optional[str] = None + password: Optional[str] = None + host_port: str + database: Optional[str] = None + db_schema: Optional[str] = None + scheme: str + service_name: str + service_type: str + options: dict = {} + profiler_date: Optional[str] = datetime.now().strftime("%Y-%m-%d") + profiler_offset: Optional[int] = 0 + profiler_limit: Optional[int] = 50000 + filter_pattern: IncludeFilterPattern = IncludeFilterPattern.allow_all() + + @abstractmethod + def get_connection_url(self): + url = f"{self.scheme}://" + if self.username is not None: + url += f"{quote_plus(self.username)}" + if self.password is not None: + url += f":{quote_plus(self.password)}" + url += "@" + url += f"{self.host_port}" + if self.database: + url += f"/{self.database}" + logger.info(url) + return url + + +_numeric_types = [ + "SMALLINT", + "INTEGER", + "BIGINT", + "DECIMAL", + "NUMERIC", + "REAL", + "DOUBLE PRECISION", + "SMALLSERIAL", + "SERIAL", + "BIGSERIAL", +] + +_text_types = ["CHARACTER VARYING", "CHARACTER", "CHAR", "VARCHAR" "TEXT"] + +_time_types = [ + "TIMESTAMP", + "DATE", + "TIME", + "TIMESTAMP WITH TIME ZONE", + "TIMESTAMP WITHOUT TIME ZONE", + "TIME WITH TIME ZONE", + "TIME WITHOUT TIME ZONE", +] + + +def register_custom_type( + data_types: List[str], type_category: 
SupportedDataType +) -> None: + if type_category == SupportedDataType.TIME: + _time_types.extend(data_types) + elif type_category == SupportedDataType.TEXT: + _text_types.extend(data_types) + elif type_category == SupportedDataType.NUMERIC: + _numeric_types.extend(data_types) + else: + raise Exception(f"Unsupported {type_category}") + + +class SQLExpressions(BaseModel): + count_all_expr: str = "COUNT(*)" + count_expr: str = "COUNT({})" + distinct_expr: str = "DISTINCT({})" + min_expr: str = "MIN({})" + max_expr: str = "MAX({})" + length_expr: str = "LENGTH({})" + avg_expr: str = "AVG({})" + sum_expr: str = "SUM({})" + variance_expr: str = "VARIANCE({})" + stddev_expr: str = "STDDEV({})" + limit_expr: str = "LIMIT {}" + count_conditional_expr: str = "COUNT(CASE WHEN {} THEN 1 END)" + conditional_expr: str = "CASE WHEN {} THEN {} END" + equal_expr: str = "{} == {}" + less_than_expr: str = "{} < {}" + less_than_or_equal_expr: str = "{} <= {}" + greater_than_expr: str = "{} > {}" + greater_than_or_equal_expr: str = "{} >= {}" + var_in_expr: str = "{} in {}" + regex_like_pattern_expr: str = "REGEXP_LIKE({}, '{}')" + contains_expr: str = "{} LIKE '%{}%'" + starts_with_expr: str = "{} LIKE '%{}'" + ends_with_expr: str = "{} LIKE '{}%'" + + @staticmethod + def escape_metacharacters(value: str): + return re.sub(r"(\\.)", r"\\\1", value) + + def literal_number(self, value: Number): + if value is None: + return None + return str(value) + + def literal_string(self, value: str): + if value is None: + return None + return "'" + self.escape_metacharacters(value) + "'" + + def literal_list(self, l: list): + if l is None: + return None + return "(" + (",".join([self.literal(e) for e in l])) + ")" + + def count(self, expr: str): + return self.count_expr.format(expr) + + def distinct(self, expr: str): + return self.distinct_expr.format(expr) + + def min(self, expr: str): + return self.min_expr.format(expr) + + def max(self, expr: str): + return self.max_expr.format(expr) + + def 
length(self, expr: str): + return self.length_expr.format(expr) + + def avg(self, expr: str): + return self.avg_expr.format(expr) + + def sum(self, expr: str): + return self.sum_expr.format(expr) + + def variance(self, expr: str): + return self.variance_expr.format(expr) + + def stddev(self, expr: str): + return self.stddev_expr.format(expr) + + def limit(self, expr: str): + return self.limit_expr.format(expr) + + def regex_like(self, expr: str, pattern: str): + return self.regex_like_pattern_expr.format(expr, pattern) + + def equal(self, left: str, right: str): + if right == "null": + return f"{left} IS NULL" + else: + return self.equal_expr.format(right, left) + + def less_than(self, left, right): + return self.less_than_expr.format(left, right) + + def less_than_or_equal(self, left, right): + return self.less_than_or_equal_expr.format(left, right) + + def greater_than(self, left, right): + return self.greater_than_expr.format(left, right) + + def greater_than_or_equal(self, left, right): + return self.greater_than_or_equal_expr.format(left, right) + + def var_in(self, left, right): + return self.var_in_expr.format(left, right) + + def contains(self, value, substring): + return self.contains_expr.format(value, substring) + + def starts_with(self, value, substring): + return self.starts_with_expr.format(value, substring) + + def ends_with(self, value, substring): + return self.ends_with_expr.format(value, substring) + + def count_conditional(self, condition: str): + return self.count_conditional_expr.format(condition) + + def conditional(self, condition: str, expr: str): + return self.conditional_expr.format(condition, expr) + + def literal_date_expr(self, date_expr: date): + date_string = date_expr.strftime("%Y-%m-%d") + return f"DATE '{date_string}'" + + def literal(self, o: object): + if isinstance(o, Number): + return self.literal_number(o) + elif isinstance(o, str): + return self.literal_string(o) + elif isinstance(o, list) or isinstance(o, set) or 
isinstance(o, tuple): + return self.literal_list(o) + raise RuntimeError(f"Cannot convert type {type(o)} to a SQL literal: {o}") + + def list_expr(self, column: Column, values: List[str]) -> str: + if column.is_text(): + sql_values = [self.literal_string(value) for value in values] + elif column.is_number(): + sql_values = [self.literal_number(value) for value in values] + else: + raise RuntimeError( + f"Couldn't format list {str(values)} for column {str(column)}" + ) + return "(" + ",".join(sql_values) + ")" + + +class DatabaseCommon(Database): + data_type_varchar_255 = "VARCHAR(255)" + data_type_integer = "INTEGER" + data_type_bigint = "BIGINT" + data_type_decimal = "REAL" + data_type_date = "DATE" + config: SQLConnectionConfig = None + sql_exprs: SQLExpressions = SQLExpressions() + + def __init__(self, config: SQLConnectionConfig): + self.config = config + self.connection_string = self.config.get_connection_url() + self.engine = create_engine(self.connection_string, **self.config.options) + self.connection = self.engine.raw_connection() + + @classmethod + def create(cls, config_dict: dict): + pass + + def table_metadata_query(self, table_name: str) -> str: + pass + + def qualify_table_name(self, table_name: str) -> str: + return table_name + + def qualify_column_name(self, column_name: str): + return column_name + + def is_text(self, column_type: str): + return column_type.upper() in _text_types + + def is_number(self, column_type: str): + return column_type.upper() in _numeric_types + + def is_time(self, column_type: str): + return column_type.upper() in _time_types + + def sql_fetchone(self, sql: str) -> tuple: + """ + Only returns the tuple obtained by cursor.fetchone() + """ + return self.sql_fetchone_description(sql)[0] + + def sql_fetchone_description(self, sql: str) -> tuple: + """ + Returns a tuple with 2 elements: + 1) the tuple obtained by cursor.fetchone() + 2) the cursor.description + """ + cursor = self.connection.cursor() + try: + 
logger.debug(f"Executing SQL query: \n{sql}") + start = datetime.now() + cursor.execute(sql) + row_tuple = cursor.fetchone() + description = cursor.description + delta = datetime.now() - start + logger.debug(f"SQL took {str(delta)}") + return row_tuple, description + finally: + cursor.close() + + def sql_fetchall(self, sql: str) -> List[tuple]: + """ + Only returns the tuples obtained by cursor.fetchall() + """ + return self.sql_fetchall_description(sql)[0] + + def sql_fetchall_description(self, sql: str) -> tuple: + """ + Returns a tuple with 2 elements: + 1) the tuples obtained by cursor.fetchall() + 2) the cursor.description + """ + cursor = self.connection.cursor() + try: + logger.debug(f"Executing SQL query: \n{sql}") + start = datetime.now() + cursor.execute(sql) + rows = cursor.fetchall() + delta = datetime.now() - start + logger.debug(f"SQL took {str(delta)}") + return rows, cursor.description + finally: + cursor.close() + + def close(self): + if self.connection: + try: + self.connection.close() + except Exception as e: + logger.error(f"Closing connection failed: {str(e)}") diff --git a/profiler/src/openmetadata/common/metric.py b/profiler/src/openmetadata/common/metric.py new file mode 100644 index 00000000000..505ba1ff37a --- /dev/null +++ b/profiler/src/openmetadata/common/metric.py @@ -0,0 +1,63 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class Metric:
    """Names of the measurements the profiler can compute for a table/column."""

    ROW_COUNT = "row_count"
    AVG = "avg"
    AVG_LENGTH = "avg_length"
    DISTINCT = "distinct"
    DUPLICATE_COUNT = "duplicate_count"
    FREQUENT_VALUES = "frequent_values"
    HISTOGRAM = "histogram"
    INVALID_COUNT = "invalid_count"
    INVALID_PERCENTAGE = "invalid_percentage"
    MAX = "max"
    MAX_LENGTH = "max_length"
    MIN = "min"
    MIN_LENGTH = "min_length"
    MISSING_COUNT = "missing_count"
    MISSING_PERCENTAGE = "missing_percentage"
    STDDEV = "stddev"
    SUM = "sum"
    UNIQUENESS = "uniqueness"
    UNIQUE_COUNT = "unique_count"
    VALID_COUNT = "valid_count"
    VALID_PERCENTAGE = "valid_percentage"
    VALUES_COUNT = "values_count"
    VALUES_PERCENTAGE = "values_percentage"
    VARIANCE = "variance"

    # Every known metric name, in declaration order.
    METRIC_TYPES = [
        ROW_COUNT,
        AVG,
        AVG_LENGTH,
        DISTINCT,
        DUPLICATE_COUNT,
        FREQUENT_VALUES,
        HISTOGRAM,
        INVALID_COUNT,
        INVALID_PERCENTAGE,
        MAX,
        MAX_LENGTH,
        MIN,
        MIN_LENGTH,
        MISSING_COUNT,
        MISSING_PERCENTAGE,
        STDDEV,
        SUM,
        UNIQUENESS,
        UNIQUE_COUNT,
        VALID_COUNT,
        VALID_PERCENTAGE,
        VALUES_COUNT,
        VALUES_PERCENTAGE,
        VARIANCE,
    ]
import logging

# register_custom_type is re-exported here for parity with the redshift
# module; Postgres itself relies on the default type lists in database_common.
from openmetadata.common.database_common import (
    DatabaseCommon,
    SQLConnectionConfig,
    SQLExpressions,
    register_custom_type,
)

logger = logging.getLogger(__name__)


class PostgresConnectionConfig(SQLConnectionConfig):
    """Connection config for Postgres; only the URL scheme is fixed here."""

    # Fix: SQLAlchemy 1.4 (pinned in requirements.txt) removed the legacy
    # "postgres" dialect name -- "postgres+psycopg2" makes create_engine fail.
    # The supported scheme is "postgresql".
    scheme = "postgresql+psycopg2"

    def get_connection_url(self):
        return super().get_connection_url()


class PostgresSQLExpressions(SQLExpressions):
    """Postgres overrides: there is no REGEXP_LIKE; use the ~* operator."""

    regex_like_pattern_expr: str = "{} ~* '{}'"


class Postgres(DatabaseCommon):
    """Postgres implementation of the profiler Database interface."""

    config: PostgresConnectionConfig = None
    sql_exprs: PostgresSQLExpressions = PostgresSQLExpressions()

    def __init__(self, config):
        super().__init__(config)
        self.config = config

    @classmethod
    def create(cls, config_dict):
        """Build a Postgres database from a raw config mapping."""
        config = PostgresConnectionConfig.parse_obj(config_dict)
        return cls(config)

    def table_metadata_query(self, table_name: str) -> str:
        """Column-metadata lookup; *table_name* is compared lower-cased.

        NOTE(review): identifiers are interpolated directly into the SQL --
        acceptable for operator-supplied config, not for untrusted input.
        """
        sql = (
            f"SELECT column_name, data_type, is_nullable \n"
            f"FROM information_schema.columns \n"
            f"WHERE lower(table_name) = '{table_name}'"
        )
        if self.config.database:
            sql += f" \n AND table_catalog = '{self.config.database}'"
        if self.config.db_schema:
            sql += f" \n AND table_schema = '{self.config.db_schema}'"
        return sql

    def qualify_table_name(self, table_name: str) -> str:
        # Double-quote identifiers; prefix with the schema when configured.
        if self.config.db_schema:
            return f'"{self.config.db_schema}"."{table_name}"'
        return f'"{table_name}"'

    def qualify_column_name(self, column_name: str):
        return f'"{column_name}"'
from typing import Optional

from openmetadata.common.database_common import (
    DatabaseCommon,
    SQLConnectionConfig,
    SQLExpressions,
    register_custom_type,
)
from openmetadata.profiler.profiler_metadata import SupportedDataType

# Redshift type names mapped onto the profiler's three logical categories.
_TEXT_TYPES = [
    "CHAR",
    "CHARACTER",
    "BPCHAR",
    "VARCHAR",
    "CHARACTER VARYING",
    "NVARCHAR",
    "TEXT",
]
_NUMERIC_TYPES = [
    "SMALLINT",
    "INT2",
    "INTEGER",
    "INT",
    "INT4",
    "BIGINT",
    "INT8",
    "DECIMAL",
    "NUMERIC",
    "REAL",
    "FLOAT4",
    "DOUBLE PRECISION",
    "FLOAT8",
    "FLOAT",
]
_TIME_TYPES = [
    "DATE",
    "TIMESTAMP",
    "TIMESTAMP WITHOUT TIME ZONE",
    "TIMESTAMPTZ",
    "TIMESTAMP WITH TIME ZONE",
    "TIME",
    "TIME WITHOUT TIME ZONE",
    "TIMETZ",
    "TIME WITH TIME ZONE",
]

register_custom_type(_TEXT_TYPES, SupportedDataType.TEXT)
register_custom_type(_NUMERIC_TYPES, SupportedDataType.NUMERIC)
register_custom_type(_TIME_TYPES, SupportedDataType.TIME)


class RedshiftConnectionConfig(SQLConnectionConfig):
    """Connection settings for a Redshift profiling target."""

    scheme = "redshift+psycopg2"
    where_clause: Optional[str] = None
    duration: int = 1
    service_type = "Redshift"

    def get_connection_url(self):
        """Delegate URL construction to the common SQL connection config."""
        return super().get_connection_url()


class RedshiftSQLExpressions(SQLExpressions):
    """Redshift-specific SQL expression overrides."""

    avg_expr = "AVG({})"
    sum_expr = "SUM({})"
    # Redshift case-insensitive regex match operator.
    regex_like_pattern_expr: str = "{} ~* '{}'"


class Redshift(DatabaseCommon):
    """Redshift implementation of the common database profiling interface."""

    config: RedshiftConnectionConfig = None
    sql_exprs: RedshiftSQLExpressions = RedshiftSQLExpressions()

    def __init__(self, config):
        super().__init__(config)
        self.config = config

    @classmethod
    def create(cls, config_dict):
        """Alternate constructor: build the database from a raw config dict."""
        parsed = RedshiftConnectionConfig.parse_obj(config_dict)
        return cls(parsed)

    def table_metadata_query(self, table_name: str) -> str:
        """Return the information_schema query describing *table_name*'s columns.

        NOTE(review): duplicates the Postgres implementation; values are
        interpolated directly into the SQL text — trusted config only.
        """
        parts = [
            f"SELECT column_name, data_type, is_nullable \n"
            f"FROM information_schema.columns \n"
            f"WHERE lower(table_name) = '{table_name}'"
        ]
        if self.config.database:
            parts.append(f" \n AND table_catalog = '{self.config.database}'")
        if self.config.db_schema:
            parts.append(f" \n AND table_schema = '{self.config.db_schema}'")
        return "".join(parts)
ERROR_CODE_GENERIC = "generic_error"


class ProfilerSqlError(Exception):
    """Raised when a profiler SQL operation fails.

    Wraps the driver/engine exception so callers keep access to the root
    cause via ``original_exception`` while seeing a combined message.
    """

    def __init__(self, msg, original_exception):
        # Python 3 zero-argument super() replaces the legacy
        # super(ProfilerSqlError, self) two-argument form.
        super().__init__(f"{msg}: {str(original_exception)}")
        # Coarse machine-readable code; only the generic code exists today.
        self.error_code = ERROR_CODE_GENERIC
        # Keep the root cause for callers that need to inspect it.
        self.original_exception = original_exception
import logging
from datetime import datetime
from math import ceil, floor
from typing import List

from openmetadata.common.database import Database
from openmetadata.common.metric import Metric
from openmetadata.profiler.profiler_metadata import (
    Column,
    ColumnProfileResult,
    MetricMeasurement,
    ProfileResult,
    SupportedDataType,
    Table,
    TableProfileResult,
    get_group_by_cte,
    get_group_by_cte_numeric_value_expression,
)

logger = logging.getLogger(__name__)


class Profiler:
    """Computes table- and column-level measurements for a single table.

    Runs four phases over the configured database: column-metadata discovery,
    a single aggregate query (counts/min/max/avg/...), a group-by query per
    column (distinct/unique/duplicate counts), and a histogram query per
    numeric column. Results accumulate in ``self.profiler_result``.
    """

    def __init__(
        self,
        database: Database,
        table_name: str,
        # NOTE(review): mutable default argument — the same list object is
        # shared across calls; harmless while it is never mutated, but
        # ``None`` with an in-body default would be safer.
        excluded_columns: List[str] = [],
        profile_time: str = None,
    ):
        self.database = database
        self.table = Table(name=table_name)
        self.excluded_columns = excluded_columns
        self.time = profile_time
        self.qualified_table_name = self.database.qualify_table_name(table_name)
        self.scan_reference = None
        self.columns: List[Column] = []
        self.start_time = None
        self.queries_executed = 0

        self.profiler_result = ProfileResult(
            profile_date=self.time,
            table_result=TableProfileResult(name=self.table.name),
        )

    def execute(self) -> ProfileResult:
        """Run all profiling phases and return the accumulated result.

        Exceptions are logged, not re-raised: a failed scan still returns a
        (possibly partial) ProfileResult. The connection is always closed.
        """
        self.start_time = datetime.now()

        try:
            self._table_metadata()
            self._profile_aggregations()
            self._query_group_by_value()
            self._query_histograms()
            logger.debug(
                f"Executed {self.queries_executed} queries in {(datetime.now() - self.start_time)}"
            )
        except Exception as e:
            # Best-effort profiling: swallow and log; partial results stand.
            logger.exception("Exception during scan")

        finally:
            self.database.close()

        return self.profiler_result

    def _table_metadata(self):
        """Discover the table's columns and classify them by logical type.

        Columns whose data type is not recognized as numeric, time, or text
        are skipped (with an info log) and excluded from all later phases.
        """
        sql = self.database.table_metadata_query(self.table.name)
        columns = self.database.sql_fetchall(sql)
        self.queries_executed += 1
        # NOTE(review): self.table_columns is initialized but never appended
        # to; self.columns is the list actually populated and used.
        self.table_columns = []
        for column in columns:
            name = column[0]
            data_type = column[1]
            nullable = "YES" == column[2].upper()
            if self.database.is_number(data_type):
                logical_type = SupportedDataType.NUMERIC
            elif self.database.is_time(data_type):
                logical_type = SupportedDataType.TIME
            elif self.database.is_text(data_type):
                logical_type = SupportedDataType.TEXT
            else:
                logger.info(f" {name} ({data_type}) not supported.")
                continue
            self.columns.append(
                Column(
                    name=name,
                    data_type=data_type,
                    nullable=nullable,
                    logical_type=logical_type,
                )
            )

        self.column_names: List[str] = [column.name for column in self.columns]
        logger.debug(str(len(self.columns)) + " columns:")
        self.profiler_result.table_result.col_count = len(self.columns)

    def _profile_aggregations(self):
        """Build and run ONE aggregate SELECT covering all columns.

        ``fields`` (SQL expressions) and ``measurements`` (their labels) are
        appended in lockstep, so result-tuple position i maps to
        measurements[i]. Derived metrics (missing/valid percentages) are
        computed afterwards from the raw counts.
        """
        measurements: List[MetricMeasurement] = []
        fields: List[str] = []

        # Compute Row Count
        fields.append(self.database.sql_exprs.count_all_expr)
        measurements.append(MetricMeasurement(name=Metric.ROW_COUNT, col_name="table"))
        # Maps lower-cased column name -> {metric key -> index into measurements}.
        column_metric_indices = {}

        try:
            for column in self.columns:
                metric_indices = {}
                column_metric_indices[column.name.lower()] = metric_indices
                column_name = column.name
                qualified_column_name = self.database.qualify_column_name(column_name)

                ## values_count
                metric_indices["non_missing"] = len(measurements)
                fields.append(self.database.sql_exprs.count(qualified_column_name))
                measurements.append(
                    MetricMeasurement(name=Metric.VALUES_COUNT, col_name=column_name)
                )

                # Valid Count
                # NOTE(review): identical expression to values_count — no
                # validity predicate is applied yet, so VALID_COUNT currently
                # equals VALUES_COUNT.
                fields.append(self.database.sql_exprs.count(qualified_column_name))
                measurements.append(
                    MetricMeasurement(name=Metric.VALID_COUNT, col_name=column_name)
                )
                if column.logical_type == SupportedDataType.TEXT:
                    # Text columns: average/min/max string length.
                    length_expr = self.database.sql_exprs.length(qualified_column_name)
                    fields.append(self.database.sql_exprs.avg(length_expr))
                    measurements.append(
                        MetricMeasurement(name=Metric.AVG_LENGTH, col_name=column_name)
                    )

                    # Min Length
                    fields.append(self.database.sql_exprs.min(length_expr))
                    measurements.append(
                        MetricMeasurement(name=Metric.MIN_LENGTH, col_name=column_name)
                    )

                    # Max Length
                    fields.append(self.database.sql_exprs.max(length_expr))
                    measurements.append(
                        MetricMeasurement(name=Metric.MAX_LENGTH, col_name=column_name)
                    )

                if column.logical_type == SupportedDataType.NUMERIC:
                    # Numeric columns: basic descriptive statistics.
                    # Min
                    fields.append(self.database.sql_exprs.min(qualified_column_name))
                    measurements.append(
                        MetricMeasurement(name=Metric.MIN, col_name=column_name)
                    )

                    # Max
                    fields.append(self.database.sql_exprs.max(qualified_column_name))
                    measurements.append(
                        MetricMeasurement(name=Metric.MAX, col_name=column_name)
                    )

                    # AVG
                    fields.append(self.database.sql_exprs.avg(qualified_column_name))
                    measurements.append(
                        MetricMeasurement(name=Metric.AVG, col_name=column_name)
                    )

                    # SUM
                    fields.append(self.database.sql_exprs.sum(qualified_column_name))
                    measurements.append(
                        MetricMeasurement(name=Metric.SUM, col_name=column_name)
                    )

                    # VARIANCE
                    fields.append(
                        self.database.sql_exprs.variance(qualified_column_name)
                    )
                    measurements.append(
                        MetricMeasurement(name=Metric.VARIANCE, col_name=column_name)
                    )

                    # STDDEV
                    fields.append(self.database.sql_exprs.stddev(qualified_column_name))
                    measurements.append(
                        MetricMeasurement(name=Metric.STDDEV, col_name=column_name)
                    )

            if len(fields) > 0:
                # Single round-trip: one SELECT with every aggregate field.
                sql = (
                    "SELECT \n " + ",\n ".join(fields) + " \n"
                    "FROM " + self.qualified_table_name
                )
                query_result_tuple = self.database.sql_fetchone(sql)
                self.queries_executed += 1

                # Positional mapping: result column i belongs to measurements[i].
                for i in range(0, len(measurements)):
                    measurement = measurements[i]
                    measurement.value = query_result_tuple[i]
                    self._add_measurement(measurement)

                # Calculating derived measurements
                row_count_measurement = next(
                    (m for m in measurements if m.name == Metric.ROW_COUNT), None
                )

                if row_count_measurement:
                    row_count = row_count_measurement.value
                    self.profiler_result.table_result.row_count = row_count
                    for column in self.columns:
                        column_name = column.name
                        metric_indices = column_metric_indices[column_name.lower()]
                        non_missing_index = metric_indices.get("non_missing")
                        if non_missing_index is not None:
                            values_count = measurements[non_missing_index].value
                            missing_count = row_count - values_count
                            missing_percentage = (
                                missing_count * 100 / row_count
                                if row_count > 0
                                else None
                            )
                            values_percentage = (
                                values_count * 100 / row_count
                                if row_count > 0
                                else None
                            )
                            self._add_measurement(
                                MetricMeasurement(
                                    name=Metric.MISSING_PERCENTAGE,
                                    col_name=column_name,
                                    value=missing_percentage,
                                )
                            )
                            self._add_measurement(
                                MetricMeasurement(
                                    name=Metric.MISSING_COUNT,
                                    col_name=column_name,
                                    value=missing_count,
                                )
                            )
                            self._add_measurement(
                                MetricMeasurement(
                                    name=Metric.VALUES_PERCENTAGE,
                                    col_name=column_name,
                                    value=values_percentage,
                                )
                            )

                            # NOTE(review): metric_indices only ever receives
                            # the "non_missing" key above, so valid_index is
                            # always None and this branch is currently dead
                            # code (it also relies on missing_count from the
                            # enclosing if). Presumably reserved for a future
                            # validity predicate — confirm before removing.
                            valid_index = metric_indices.get("valid")
                            if valid_index is not None:
                                valid_count = measurements[valid_index].value
                                invalid_count = row_count - missing_count - valid_count
                                invalid_percentage = (
                                    invalid_count * 100 / row_count
                                    if row_count > 0
                                    else None
                                )
                                valid_percentage = (
                                    valid_count * 100 / row_count
                                    if row_count > 0
                                    else None
                                )
                                self._add_measurement(
                                    MetricMeasurement(
                                        name=Metric.INVALID_PERCENTAGE,
                                        col_name=column_name,
                                        value=invalid_percentage,
                                    )
                                )
                                self._add_measurement(
                                    MetricMeasurement(
                                        name=Metric.INVALID_COUNT,
                                        col_name=column_name,
                                        value=invalid_count,
                                    )
                                )
                                self._add_measurement(
                                    MetricMeasurement(
                                        name=Metric.VALID_PERCENTAGE,
                                        col_name=column_name,
                                        value=valid_percentage,
                                    )
                                )
        except Exception as e:
            logger.error(f"Exception during aggregation query", exc_info=e)

    def _query_group_by_value(self):
        """Per column: count distinct/unique/duplicate values via a CTE.

        Uses a frequency CTE (value, COUNT(*)) per column; a failure on one
        column is logged and the remaining columns still run.
        """
        for column in self.columns:
            try:
                # NOTE(review): this local list is never used.
                measurements = []
                column_name = column.name
                group_by_cte = get_group_by_cte(
                    self.database.qualify_column_name(column.name),
                    self.database.qualify_table_name(self.table.name),
                )
                ## Compute Distinct, Unique, Unique_Count, Duplicate_count
                sql = (
                    f"{group_by_cte} \n"
                    f"SELECT COUNT(*), \n"
                    f"       COUNT(CASE WHEN frequency = 1 THEN 1 END), \n"
                    f"       SUM(frequency) \n"
                    f"FROM group_by_value"
                )

                query_result_tuple = self.database.sql_fetchone(sql)
                self.queries_executed += 1

                distinct_count = query_result_tuple[0]
                unique_count = query_result_tuple[1]
                # SUM(frequency) is NULL on an empty table; normalize to 0.
                valid_count = query_result_tuple[2] if query_result_tuple[2] else 0
                duplicate_count = distinct_count - unique_count
                self._add_measurement(
                    MetricMeasurement(
                        name=Metric.DISTINCT, col_name=column_name, value=distinct_count
                    )
                )
                self._add_measurement(
                    MetricMeasurement(
                        name=Metric.UNIQUE_COUNT,
                        col_name=column_name,
                        value=unique_count,
                    )
                )
                self._add_measurement(
                    MetricMeasurement(
                        name=Metric.DUPLICATE_COUNT,
                        col_name=column_name,
                        value=duplicate_count,
                    )
                )
                if valid_count > 1:
                    # Scaled so an all-distinct column scores 100.
                    uniqueness = (distinct_count - 1) * 100 / (valid_count - 1)
                    self._add_measurement(
                        MetricMeasurement(
                            name=Metric.UNIQUENESS,
                            col_name=column_name,
                            value=uniqueness,
                        )
                    )
            except Exception as e:
                logger.error(
                    f"Exception during column group by value queries", exc_info=e
                )

    def _query_histograms(self):
        """Per numeric column: build a 20-bucket histogram between MIN and MAX.

        Depends on the MIN/MAX measurements produced by
        _profile_aggregations; skips columns without a usable range.
        """
        for column in self.columns:
            column_name = column.name
            try:
                if column.is_number():
                    # NOTE(review): this local list is never used.
                    measurements = []
                    buckets: int = 20
                    column_results = self.profiler_result.columns_result[column_name]
                    min_value = column_results.measurements.get(Metric.MIN).value
                    max_value = column_results.measurements.get(Metric.MAX).value

                    # NOTE(review): truthiness check — a legitimate MIN of 0
                    # (or MAX of 0) skips the histogram; an explicit
                    # `is not None` test would be stricter.
                    if (
                        column.is_number()
                        and min_value
                        and max_value
                        and min_value < max_value
                    ):
                        # Snap the range outward to 3 decimal places.
                        min_value = floor(min_value * 1000) / 1000
                        max_value = ceil(max_value * 1000) / 1000
                        bucket_width = (max_value - min_value) / buckets

                        boundary = min_value
                        boundaries = [min_value]
                        for i in range(0, buckets):
                            boundary += bucket_width
                            boundaries.append(round(boundary, 3))

                        # NOTE(review): unlike _query_group_by_value, the raw
                        # (unqualified) column and table names are used here —
                        # confirm the inconsistency is intentional.
                        group_by_cte = get_group_by_cte(column_name, self.table.name)
                        numeric_value_expr = get_group_by_cte_numeric_value_expression(
                            column, self.database, None
                        )

                        # One SUM(CASE ...) per bucket; first/last buckets are
                        # open-ended (no lower/upper bound respectively).
                        field_clauses = []
                        for i in range(0, buckets):
                            lower_bound = (
                                ""
                                if i == 0
                                else f"{boundaries[i]} <= {numeric_value_expr}"
                            )
                            upper_bound = (
                                ""
                                if i == buckets - 1
                                else f"{numeric_value_expr} < {boundaries[i + 1]}"
                            )
                            optional_and = (
                                ""
                                if lower_bound == "" or upper_bound == ""
                                else " and "
                            )
                            field_clauses.append(
                                f"SUM(CASE WHEN {lower_bound}{optional_and}{upper_bound} THEN frequency END)"
                            )

                        fields = ",\n ".join(field_clauses)

                        sql = (
                            f"{group_by_cte} \n"
                            f"SELECT \n"
                            f" {fields} \n"
                            f"FROM group_by_value"
                        )

                        row = self.database.sql_fetchone(sql)
                        self.queries_executed += 1

                        # Process the histogram query
                        # Empty buckets come back NULL; normalize to 0.
                        frequencies = []
                        for i in range(0, buckets):
                            frequency = row[i]
                            frequencies.append(0 if not frequency else int(frequency))
                        histogram = {
                            "boundaries": boundaries,
                            "frequencies": frequencies,
                        }
                        self._add_measurement(
                            MetricMeasurement(
                                name=Metric.HISTOGRAM,
                                col_name=column_name,
                                value=histogram,
                            )
                        )
            except Exception as e:
                logger.error(f"Exception during aggregation query", exc_info=e)

    def _add_measurement(self, measurement):
        """Record a measurement under its column, creating the column result on first use."""
        logger.debug(f"measurement: {measurement}")
        if measurement.col_name in self.profiler_result.columns_result.keys():
            # Existing column: merge into its measurement map (keyed by metric name).
            col_result = self.profiler_result.columns_result[measurement.col_name]
            col_measurements = col_result.measurements
            col_measurements[measurement.name] = measurement
            self.profiler_result.columns_result[
                measurement.col_name
            ].measurements = col_measurements
        else:
            # First measurement for this column: create its result entry.
            col_measurements = {measurement.name: measurement}
            self.profiler_result.columns_result[
                measurement.col_name
            ] = ColumnProfileResult(
                name=measurement.col_name, measurements=col_measurements
            )
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum +from typing import Dict, List, Optional + +from pydantic import BaseModel + + +def get_group_by_cte(column_name, table_name): + where = f"{column_name} IS NOT NULL" + + return ( + f"WITH group_by_value AS ( \n" + f" SELECT \n" + f" {column_name} AS value, \n" + f" COUNT(*) AS frequency \n" + f" FROM {table_name} \n" + f" WHERE {where} \n" + f" GROUP BY {column_name} \n" + f")" + ) + + +def get_group_by_cte_numeric_value_expression(column, database, validity_format): + if column.is_number or column.is_time: + return "value" + if column.is_column_numeric_text_format: + return database.sql_expr_cast_text_to_number("value", validity_format) + + +def get_order_by_cte_value_expression( + column, database, validity_format, numeric_value_expr: str +): + if column.is_number or column.is_time: + return "value" + if column.is_column_numeric_text_format: + return database.sql_expr_cast_text_to_number("value", validity_format) + elif column.is_text: + return "value" + return None + + +class SupportedDataType(Enum): + NUMERIC = 1 + TEXT = 2 + TIME = 3 + + +class Column(BaseModel): + """Column Metadata""" + + name: str + nullable: bool = None + data_type: str + logical_type: SupportedDataType + + def is_text(self) -> bool: + return self.logical_type == SupportedDataType.TEXT + + def is_number(self) -> bool: + return self.logical_type == SupportedDataType.NUMERIC + + def is_time(self) -> bool: + return self.logical_type == SupportedDataType.TIME + + +class Table(BaseModel): + 
"""Table Metadata""" + + name: str + columns: List[Column] = [] + + +class GroupValue(BaseModel): + """Metrinc Group Values""" + + group: Dict = {} + value: object + + class Config: + arbitrary_types_allowed = True + + +class MetricMeasurement(BaseModel): + """Metric Measurement""" + + name: str + col_name: str + value: object = None + + class Config: + arbitrary_types_allowed = True + + +class TableProfileResult(BaseModel): + """Table Profile Result""" + + name: str + row_count: int = None + col_count: int = None + + +class ColumnProfileResult(BaseModel): + name: str + measurements: Dict[str, MetricMeasurement] = {} + + +class ProfileResult(BaseModel): + """Profile Run Result""" + + profile_date: str + table_result: TableProfileResult + columns_result: Dict[str, ColumnProfileResult] = {} diff --git a/profiler/src/openmetadata/profiler/profiler_runner.py b/profiler/src/openmetadata/profiler/profiler_runner.py new file mode 100644 index 00000000000..cc17c7dc719 --- /dev/null +++ b/profiler/src/openmetadata/profiler/profiler_runner.py @@ -0,0 +1,85 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import importlib
import logging
import sys
from datetime import datetime, timezone
from typing import Type, TypeVar

from openmetadata.common.config import ConfigModel, DynamicTypedConfig
from openmetadata.common.database import Database
from openmetadata.common.database_common import SQLConnectionConfig
from openmetadata.profiler.profiler import Profiler
from openmetadata.profiler.profiler_metadata import ProfileResult

logger = logging.getLogger(__name__)

T = TypeVar("T")


class ProfilerConfig(ConfigModel):
    """Top-level runner config: one dynamically-typed `profiler` section."""

    profiler: DynamicTypedConfig


def type_class_fetch(clazz_type: str, is_file: bool):
    """Map a dashed type name to a module name (is_file=True) or a
    CamelCase class name (is_file=False)."""
    normalized = clazz_type.replace("-", "_")
    if is_file:
        return normalized
    return "".join(part.title() for part in normalized.split("_"))


def get_clazz(key: str) -> Type[T]:
    """Import and return the class named by a dotted path.

    Returns None (implicitly, matching previous behavior) when *key*
    contains no dot.
    """
    if "." not in key:
        return None
    module_name, class_name = key.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)


class ProfilerRunner:
    """Wires a configured database implementation to a Profiler run."""

    config: ProfilerConfig
    database: Database

    def __init__(self, config: ProfilerConfig):
        self.config = config
        # Resolve e.g. type "redshift" to openmetadata.databases.redshift.Redshift.
        database_type = self.config.profiler.type
        module_part = type_class_fetch(database_type, True)
        class_part = type_class_fetch(database_type, False)
        database_class = get_clazz(
            "openmetadata.databases.{}.{}".format(module_part, class_part)
        )
        self.profiler_config = self.config.profiler.dict().get("config", {})
        connection_config = self.profiler_config.get("sql_connection", {})
        self.database: Database = database_class.create(connection_config)
        self.table_name = self.profiler_config.get("table_name")
        self.variables: dict = {}
        # Profile timestamp: UTC, second precision, ISO-8601.
        self.time = datetime.now(tz=timezone.utc).isoformat(timespec="seconds")

    @classmethod
    def create(cls, config_dict: dict) -> "ProfilerRunner":
        """Alternate constructor: parse a raw dict into ProfilerConfig."""
        parsed = ProfilerConfig.parse_obj(config_dict)
        return cls(parsed)

    def execute(self):
        """Run the profiler once; exit the process with code 1 on failure."""
        try:
            profiler = Profiler(
                database=self.database,
                table_name=self.table_name,
                profile_time=self.time,
            )
            profile_result: ProfileResult = profiler.execute()
            return profile_result
        except Exception as e:
            logger.exception(f"Profiler failed: {str(e)}")
            logger.info(f"Exiting with code 1")
            sys.exit(1)