Issue 997-01 (#2476)

* Fix linting

* Fixed pyformat errors

* Address comments from PR review

* Added back phony in makefile

* Added back comment
This commit is contained in:
Teddy 2022-01-28 03:45:45 +01:00 committed by GitHub
parent 10a94b265c
commit 4f3e330dd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 886 additions and 137 deletions

View File

@ -22,3 +22,6 @@ disable=no-name-in-module
[TYPECHECK] [TYPECHECK]
ignored-classes=optparse.Values,thread._local,_thread._local,SQLAlchemyHelper,FieldInfo ignored-classes=optparse.Values,thread._local,_thread._local,SQLAlchemyHelper,FieldInfo
[FORMAT]
max-line-length=88

View File

@ -31,10 +31,9 @@ precommit_install: ## Install the project's precommit hooks from .pre-commit-co
@echo "Make sure to first run install_test first" @echo "Make sure to first run install_test first"
pre-commit install pre-commit install
## Python Checkstyle
.PHONY: lint .PHONY: lint
lint: ## Run pylint on the Python sources to analyze the codebase lint: ## Run pylint on the Python sources to analyze the codebase
find $(PY_SOURCE) -path $(PY_SOURCE)/metadata/generated -prune -false -o -type f -name "*.py" | xargs pylint find $(PY_SOURCE) -path $(PY_SOURCE)/metadata/generated -prune -false -o -type f -name "*.py" | xargs pylint --ignore-paths=$(PY_SOURCE)/metadata_server/
.PHONY: py_format .PHONY: py_format
py_format: ## Run black and isort to format the Python codebase py_format: ## Run black and isort to format the Python codebase

View File

@ -12,10 +12,14 @@
Airflow backend lineage module Airflow backend lineage module
""" """
import metadata
def get_provider_config(): def get_provider_config():
"""
Get provider configuration
Returns
dict:
"""
return { return {
"name": "OpenMetadata", "name": "OpenMetadata",
"description": "OpenMetadata <https://open-metadata.org/>", "description": "OpenMetadata <https://open-metadata.org/>",

View File

@ -52,4 +52,4 @@ def lineage_callback(context: Dict[str, str]) -> None:
) )
except Exception as exc: # pylint: disable=broad-except except Exception as exc: # pylint: disable=broad-except
logging.error(f"Lineage Callback exception {exc}") logging.error("Lineage Callback exception %s", exc)

View File

@ -23,6 +23,16 @@ from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
class OpenMetadataLineageConfig(ConfigModel): class OpenMetadataLineageConfig(ConfigModel):
"""
Base class for OpenMetada lineage config
Attributes
airflow_service_name (str): name of the service
api_endpoint (str): the endpoint for the API
auth_provider_type (str):
secret_key (str):
"""
airflow_service_name: str = "airflow" airflow_service_name: str = "airflow"
api_endpoint: str = "http://localhost:8585" api_endpoint: str = "http://localhost:8585"
auth_provider_type: str = "no-auth" auth_provider_type: str = "no-auth"

View File

@ -13,55 +13,75 @@
OpenMetadata Airflow Lineage Backend OpenMetadata Airflow Lineage Backend
""" """
import ast
import json
import os
import traceback import traceback
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union from typing import TYPE_CHECKING, Dict, List, Optional
from airflow.configuration import conf
from airflow.lineage.backend import LineageBackend from airflow.lineage.backend import LineageBackend
from airflow_provider_openmetadata.lineage.config import ( from airflow_provider_openmetadata.lineage.config import (
OpenMetadataLineageConfig,
get_lineage_config, get_lineage_config,
get_metadata_config, get_metadata_config,
) )
from airflow_provider_openmetadata.lineage.utils import ( from airflow_provider_openmetadata.lineage.utils import (
ALLOWED_FLOW_KEYS,
ALLOWED_TASK_KEYS,
create_pipeline_entity,
get_or_create_pipeline_service,
get_properties,
get_xlets, get_xlets,
is_airflow_version_1,
parse_lineage_to_openmetadata, parse_lineage_to_openmetadata,
) )
from metadata.config.common import ConfigModel
from metadata.generated.schema.api.data.createPipeline import (
CreatePipelineEntityRequest,
)
from metadata.generated.schema.api.lineage.addLineage import AddLineage
from metadata.generated.schema.api.services.createPipelineService import (
CreatePipelineServiceEntityRequest,
)
from metadata.generated.schema.entity.data.pipeline import Pipeline, Task
from metadata.generated.schema.entity.data.table import Table
from metadata.generated.schema.entity.services.pipelineService import (
PipelineService,
PipelineServiceType,
)
from metadata.generated.schema.type.entityLineage import EntitiesEdge
from metadata.generated.schema.type.entityReference import EntityReference
from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.utils.helpers import convert_epoch_to_iso
if TYPE_CHECKING: if TYPE_CHECKING:
from airflow import DAG
from airflow.models.baseoperator import BaseOperator from airflow.models.baseoperator import BaseOperator
allowed_task_keys = [
"_downstream_task_ids",
"_inlets",
"_outlets",
"_task_type",
"_task_module",
"depends_on_past",
"email",
"label",
"execution_timeout",
"end_date",
"start_date",
"sla",
"sql",
"task_id",
"trigger_rule",
"wait_for_downstream",
]
allowed_flow_keys = [
"_access_control",
"_concurrency",
"_default_view",
"catchup",
"fileloc",
"is_paused_upon_creation",
"start_date",
"tags",
"timezone",
]
# pylint: disable=import-outside-toplevel, unused-import
def is_airflow_version_1() -> bool:
"""
Manage airflow submodule import based airflow version
Returns
bool
"""
try:
from airflow.hooks.base import BaseHook
return False
except ModuleNotFoundError:
from airflow.hooks.base_hook import BaseHook
return True
# pylint: disable=too-few-public-methods
class OpenMetadataLineageBackend(LineageBackend): class OpenMetadataLineageBackend(LineageBackend):
""" """
Sends lineage data from tasks to OpenMetadata. Sends lineage data from tasks to OpenMetadata.
@ -75,13 +95,21 @@ class OpenMetadataLineageBackend(LineageBackend):
auth_provider_type = no-auth # use google here if you are auth_provider_type = no-auth # use google here if you are
configuring google as SSO configuring google as SSO
secret_key = google-client-secret-key # it needs to be configured secret_key = google-client-secret-key # it needs to be configured
only if you are using google as SSO only if you are using google as SSO the one configured in openMetadata
openmetadata_api_endpoint = http://localhost:8585
auth_provider_type = no-auth # use google here if you are configuring google as SSO
secret_key = google-client-secret-key # it needs to be configured
only if you are using google as SSO
""" """
def __init__(self) -> None: def __init__(self) -> None:
"""
Instantiate a superclass object and run lineage config function
"""
super().__init__() super().__init__()
_ = get_lineage_config() _ = get_lineage_config()
# pylint: disable=protected-access
@staticmethod @staticmethod
def send_lineage( def send_lineage(
operator: "BaseOperator", operator: "BaseOperator",
@ -89,6 +117,17 @@ class OpenMetadataLineageBackend(LineageBackend):
outlets: Optional[List] = None, outlets: Optional[List] = None,
context: Dict = None, context: Dict = None,
) -> None: ) -> None:
"""
Send lineage to OpenMetadata
Args
operator (BaseOperator):
inlets (Optional[List]):
outlets (Optional[List]):
context (Dict):
Returns
None
"""
try: try:
config = get_lineage_config() config = get_lineage_config()

View File

@ -133,6 +133,7 @@ def get_xlets(
return None return None
# pylint: disable=too-many-arguments
def iso_dag_start_date(props: Dict[str, Any]) -> Optional[str]: def iso_dag_start_date(props: Dict[str, Any]) -> Optional[str]:
""" """
Given a properties dict, return the start_date Given a properties dict, return the start_date
@ -229,6 +230,7 @@ def create_pipeline_entity(
return client.create_or_update(create_pipeline) return client.create_or_update(create_pipeline)
# pylint: disable=too-many-arguments,too-many-locals
def parse_lineage_to_openmetadata( def parse_lineage_to_openmetadata(
config: OpenMetadataLineageConfig, config: OpenMetadataLineageConfig,
context: Dict, context: Dict,

View File

@ -18,13 +18,29 @@ from dataclasses import dataclass
from metadata.config.common import ConfigModel from metadata.config.common import ConfigModel
@dataclass # type: ignore[misc] @dataclass(init=False) # type: ignore[misc]
class AuthenticationProvider(metaclass=ABCMeta): class AuthenticationProvider(metaclass=ABCMeta):
"""
Interface definition for an Authentification provider
"""
@classmethod @classmethod
@abstractmethod @abstractmethod
def create(cls, config: ConfigModel) -> "AuthenticationProvider": def create(cls, config: ConfigModel) -> "AuthenticationProvider":
pass """
Create authentication
Arguments:
config (ConfigModel): configuration
Returns:
AuthenticationProvider
"""
@abstractmethod @abstractmethod
def auth_token(self) -> str: def auth_token(self) -> str:
pass """
Authentication token
Returns:
str
"""

View File

@ -25,7 +25,9 @@ logger = logging.getLogger(__name__)
class RetryException(Exception): class RetryException(Exception):
pass """
API Client retry exception
"""
class APIError(Exception): class APIError(Exception):
@ -41,10 +43,19 @@ class APIError(Exception):
@property @property
def code(self): def code(self):
"""
Return error code
"""
return self._error["code"] return self._error["code"]
@property @property
def status_code(self): def status_code(self):
"""
Return response status code
Returns:
int
"""
http_error = self._http_error http_error = self._http_error
if http_error is not None and hasattr(http_error, "response"): if http_error is not None and hasattr(http_error, "response"):
return http_error.response.status_code return http_error.response.status_code
@ -53,6 +64,9 @@ class APIError(Exception):
@property @property
def request(self): def request(self):
"""
Handle requests error
"""
if self._http_error is not None: if self._http_error is not None:
return self._http_error.request return self._http_error.request
@ -60,6 +74,10 @@ class APIError(Exception):
@property @property
def response(self): def response(self):
"""
Handle response error
:return:
"""
if self._http_error is not None: if self._http_error is not None:
return self._http_error.response return self._http_error.response
@ -83,6 +101,7 @@ class ClientConfig(ConfigModel):
allow_redirects: Optional[bool] = False allow_redirects: Optional[bool] = False
# pylint: disable=too-many-instance-attributes
class REST: class REST:
""" """
REST client wrapper to manage requests with REST client wrapper to manage requests with
@ -100,6 +119,7 @@ class REST:
self._retry_codes = self.config.retry_codes self._retry_codes = self.config.retry_codes
self._auth_token = self.config.auth_token self._auth_token = self.config.auth_token
# pylint: disable=too-many-arguments
def _request( def _request(
self, method, path, data=None, base_url: URL = None, api_version: str = None self, method, path, data=None, base_url: URL = None, api_version: str = None
): ):
@ -126,14 +146,16 @@ class REST:
retry = total_retries retry = total_retries
while retry >= 0: while retry >= 0:
try: try:
logger.debug("URL {}, method {}".format(url, method)) logger.debug("URL %s, method %s", url, method)
logger.debug("Data {}".format(opts)) logger.debug("Data %s", opts)
return self._one_request(method, url, opts, retry) return self._one_request(method, url, opts, retry)
except RetryException: except RetryException:
retry_wait = self._retry_wait * (total_retries - retry + 1) retry_wait = self._retry_wait * (total_retries - retry + 1)
logger.warning( logger.warning(
"sleep {} seconds and retrying {} " "sleep %s seconds and retrying %s " "%s more time(s)...",
"{} more time(s)...".format(retry_wait, url, retry) retry_wait,
url,
retry,
) )
time.sleep(retry_wait) time.sleep(retry_wait)
retry -= 1 retry -= 1
@ -152,11 +174,11 @@ class REST:
except HTTPError as http_error: except HTTPError as http_error:
# retry if we hit Rate Limit # retry if we hit Rate Limit
if resp.status_code in retry_codes and retry > 0: if resp.status_code in retry_codes and retry > 0:
raise RetryException() raise RetryException() from http_error
if "code" in resp.text: if "code" in resp.text:
error = resp.json() error = resp.json()
if "code" in error: if "code" in error:
raise APIError(error, http_error) raise APIError(error, http_error) from http_error
else: else:
raise raise
if resp.text != "": if resp.text != "":
@ -164,24 +186,77 @@ class REST:
return None return None
def get(self, path, data=None): def get(self, path, data=None):
"""
GET method
Parameters:
path (str):
data ():
Returns:
Response
"""
return self._request("GET", path, data) return self._request("GET", path, data)
def post(self, path, data=None): def post(self, path, data=None):
"""
POST method
Parameters:
path (str):
data ():
Returns:
Response
"""
return self._request("POST", path, data) return self._request("POST", path, data)
def put(self, path, data=None): def put(self, path, data=None):
"""
PUT method
Parameters:
path (str):
data ():
Returns:
Response
"""
return self._request("PUT", path, data) return self._request("PUT", path, data)
def patch(self, path, data=None): def patch(self, path, data=None):
"""
PATCH method
Parameters:
path (str):
data ():
Returns:
Response
"""
return self._request("PATCH", path, data) return self._request("PATCH", path, data)
def delete(self, path, data=None): def delete(self, path, data=None):
"""
DELETE method
Parameters:
path (str):
data ():
Returns:
Response
"""
return self._request("DELETE", path, data) return self._request("DELETE", path, data)
def __enter__(self): def __enter__(self):
return self return self
def close(self): def close(self):
"""
Close requests session
"""
self._session.close() self._session.close()
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):

View File

@ -22,6 +22,16 @@ Credentials = Tuple[str, str, str]
class URL(str): class URL(str):
"""
Handle URL for creds retrieval
Args:
value (tuple):
Attributes:
value (value):
"""
def __new__(cls, *value): def __new__(cls, *value):
""" """
note: we use *value and v0 to allow an empty URL string note: we use *value and v0 to allow an empty URL string
@ -60,8 +70,8 @@ class DATE(str):
try: try:
dateutil.parser.parse(value) dateutil.parser.parse(value)
except Exception as exc: except Exception as exc:
msg = f"{value} is not a valid date string: {exc}" msg = f"{value} is not a valid date string"
raise Exception(msg) raise Exception(msg) from exc
return str.__new__(cls, value) return str.__new__(cls, value)
@ -85,6 +95,16 @@ class FLOAT(str):
def get_credentials( def get_credentials(
key_id: str = None, secret_key: str = None, oauth: str = None key_id: str = None, secret_key: str = None, oauth: str = None
) -> Credentials: ) -> Credentials:
"""
Get credentials
Args:
key_id (str):
secret_key (str):
oauth (oauth):
Returns:
Credentials
"""
oauth = oauth or os.environ.get("OMETA_API_OAUTH_TOKEN") oauth = oauth or os.environ.get("OMETA_API_OAUTH_TOKEN")
key_id = key_id or os.environ.get("OMETA_API_KEY_ID") key_id = key_id or os.environ.get("OMETA_API_KEY_ID")
@ -105,6 +125,14 @@ def get_credentials(
def get_api_version(api_version: str) -> str: def get_api_version(api_version: str) -> str:
"""
Get version API
Args:
api_version (str):
Returns:
str
"""
api_version = api_version or os.environ.get("APCA_API_VERSION") api_version = api_version or os.environ.get("APCA_API_VERSION")
if api_version is None: if api_version is None:
api_version = "v1" api_version = "v1"

View File

@ -14,7 +14,7 @@ from metadata.ingestion.ometa.utils import get_entity_type
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel) # pylint: disable=invalid-name
class OMetaLineageMixin(Generic[T]): class OMetaLineageMixin(Generic[T]):
@ -35,7 +35,7 @@ class OMetaLineageMixin(Generic[T]):
self.client.put(self.get_suffix(AddLineage), data=data.json()) self.client.put(self.get_suffix(AddLineage), data=data.json())
except APIError as err: except APIError as err:
logger.error( logger.error(
f"Error {err.status_code} trying to PUT lineage for {data.json()}" "Error %s trying to PUT lineage for %s", err.status_code, data.json()
) )
raise err raise err

View File

@ -87,7 +87,8 @@ class OMetaMlModelMixin(OMetaLineageMixin):
except ModuleNotFoundError as exc: except ModuleNotFoundError as exc:
logger.error( logger.error(
"Cannot import BaseEstimator, please install sklearn plugin: " "Cannot import BaseEstimator, please install sklearn plugin: "
+ f"pip install openmetadata-ingestion[sklearn], {exc}" "pip install openmetadata-ingestion[sklearn], %s",
exc,
) )
raise exc raise exc

View File

@ -97,7 +97,7 @@ class OMetaTableMixin:
resp = self.client.post( resp = self.client.post(
f"/usage/table/{table.id.__root__}", data=table_usage_request.json() f"/usage/table/{table.id.__root__}", data=table_usage_request.json()
) )
logger.debug("published table usage {}".format(resp)) logger.debug("published table usage %s", resp)
def publish_frequently_joined_with( def publish_frequently_joined_with(
self, table: Table, table_join_request: TableJoins self, table: Table, table_join_request: TableJoins
@ -108,9 +108,9 @@ class OMetaTableMixin:
:param table: Table Entity to update :param table: Table Entity to update
:param table_join_request: Join data to add :param table_join_request: Join data to add
""" """
logger.info("table join request {}".format(table_join_request.json())) logger.info("table join request %s", table_join_request.json())
resp = self.client.put( resp = self.client.put(
f"{self.get_suffix(Table)}/{table.id.__root__}/joins", f"{self.get_suffix(Table)}/{table.id.__root__}/joins",
data=table_join_request.json(), data=table_join_request.json(),
) )
logger.debug("published frequently joined with {}".format(resp)) logger.debug("published frequently joined with %s", resp)

View File

@ -15,7 +15,7 @@ from metadata.generated.schema.type.entityHistory import EntityVersionHistory
from metadata.ingestion.ometa.client import REST from metadata.ingestion.ometa.client import REST
from metadata.ingestion.ometa.utils import uuid_to_str from metadata.ingestion.ometa.utils import uuid_to_str
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel) # pylint: disable=invalid-name
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -81,6 +81,15 @@ class InvalidEntityException(Exception):
class EntityList(Generic[T], BaseModel): class EntityList(Generic[T], BaseModel):
"""
Pydantic Entity list model
Attributes
entities (List): list of entities
total (int):
after (str):
"""
entities: List[T] entities: List[T]
total: int total: int
after: str = None after: str = None
@ -380,8 +389,11 @@ class OpenMetadata(
return entity(**resp) return entity(**resp)
except APIError as err: except APIError as err:
logger.error( logger.error(
f"GET {entity.__name__} for {path}. " "GET %s for %s." "Error %s - %s",
+ f"Error {err.status_code} - {err}" entity.__name__,
path,
err.status_code,
err,
) )
return None return None
@ -405,9 +417,10 @@ class OpenMetadata(
href=instance.href, href=instance.href,
) )
logger.error(f"Cannot find the Entity {fqdn}") logger.error("Cannot find the Entity %s", fqdn)
return None return None
# pylint: disable=too-many-arguments,dangerous-default-value
def list_entities( def list_entities(
self, self,
entity: Type[T], entity: Type[T],
@ -450,8 +463,7 @@ class OpenMetadata(
if self._use_raw_data: if self._use_raw_data:
return resp return resp
else: return EntityVersionHistory(**resp)
return EntityVersionHistory(**resp)
def list_services(self, entity: Type[T]) -> List[EntityList[T]]: def list_services(self, entity: Type[T]) -> List[EntityList[T]]:
""" """
@ -465,6 +477,15 @@ class OpenMetadata(
return [entity(**p) for p in resp["data"]] return [entity(**p) for p in resp["data"]]
def delete(self, entity: Type[T], entity_id: Union[str, basic.Uuid]) -> None: def delete(self, entity: Type[T], entity_id: Union[str, basic.Uuid]) -> None:
"""
API call to delete an entity from entity ID
Args
entity (T): entity Type
entity_id (basic.Uuid): entity ID
Returns
None
"""
self.client.delete(f"{self.get_suffix(entity)}/{uuid_to_str(entity_id)}") self.client.delete(f"{self.get_suffix(entity)}/{uuid_to_str(entity_id)}")
def compute_percentile(self, entity: Union[Type[T], str], date: str) -> None: def compute_percentile(self, entity: Union[Type[T], str], date: str) -> None:
@ -473,7 +494,7 @@ class OpenMetadata(
""" """
entity_name = get_entity_type(entity) entity_name = get_entity_type(entity)
resp = self.client.post(f"/usage/compute.percentile/{entity_name}/{date}") resp = self.client.post(f"/usage/compute.percentile/{entity_name}/{date}")
logger.debug("published compute percentile {}".format(resp)) logger.debug("published compute percentile %s", resp)
def list_tags_by_category(self, category: str) -> List[Tag]: def list_tags_by_category(self, category: str) -> List[Tag]:
""" """
@ -489,4 +510,10 @@ class OpenMetadata(
return self.client.get("/health-check")["status"] == "healthy" return self.client.get("/health-check")["status"] == "healthy"
def close(self): def close(self):
"""
Closing connection
Returns
None
"""
self.client.close() self.client.close()

View File

@ -44,30 +44,50 @@ TableProfiles = List[TableProfile]
class TableEntities(BaseModel): class TableEntities(BaseModel):
"""
Table entity pydantic model
"""
tables: List[Table] tables: List[Table]
total: int total: int
after: str = None after: str = None
class TopicEntities(BaseModel): class TopicEntities(BaseModel):
"""
Topic entity pydantic model
"""
topics: List[Topic] topics: List[Topic]
total: int total: int
after: str = None after: str = None
class DashboardEntities(BaseModel): class DashboardEntities(BaseModel):
"""
Dashboard entity pydantic model
"""
dashboards: List[Dashboard] dashboards: List[Dashboard]
total: int total: int
after: str = None after: str = None
class PipelineEntities(BaseModel): class PipelineEntities(BaseModel):
"""
Pipeline entity pydantic model
"""
pipelines: List[Pipeline] pipelines: List[Pipeline]
total: int total: int
after: str = None after: str = None
class MetadataServerConfig(ConfigModel): class MetadataServerConfig(ConfigModel):
"""
Metadata Server pydantic config model
"""
api_endpoint: str api_endpoint: str
api_version: str = "v1" api_version: str = "v1"
retry: int = 3 retry: int = 3
@ -84,6 +104,16 @@ class MetadataServerConfig(ConfigModel):
class NoOpAuthenticationProvider(AuthenticationProvider): class NoOpAuthenticationProvider(AuthenticationProvider):
"""
Extends AuthenticationProvider class
Args:
config (MetadataServerConfig):
Attributes:
config (MetadataServerConfig)
"""
def __init__(self, config: MetadataServerConfig): def __init__(self, config: MetadataServerConfig):
self.config = config self.config = config
@ -96,6 +126,16 @@ class NoOpAuthenticationProvider(AuthenticationProvider):
class GoogleAuthenticationProvider(AuthenticationProvider): class GoogleAuthenticationProvider(AuthenticationProvider):
"""
Google authentication implementation
Args:
config (MetadataServerConfig):
Attributes:
config (MetadataServerConfig)
"""
def __init__(self, config: MetadataServerConfig): def __init__(self, config: MetadataServerConfig):
self.config = config self.config = config
@ -142,6 +182,16 @@ class OktaAuthenticationProvider(AuthenticationProvider):
class Auth0AuthenticationProvider(AuthenticationProvider): class Auth0AuthenticationProvider(AuthenticationProvider):
"""
OAuth authentication implementation
Args:
config (MetadataServerConfig):
Attributes:
config (MetadataServerConfig)
"""
def __init__(self, config: MetadataServerConfig): def __init__(self, config: MetadataServerConfig):
self.config = config self.config = config
@ -153,7 +203,8 @@ class Auth0AuthenticationProvider(AuthenticationProvider):
conn = http.client.HTTPSConnection(self.config.domain) conn = http.client.HTTPSConnection(self.config.domain)
payload = ( payload = (
f"grant_type=client_credentials&client_id={self.config.client_id}" f"grant_type=client_credentials&client_id={self.config.client_id}"
f"&client_secret={self.config.secret_key}&audience=https://{self.config.domain}/api/v2/" f"&client_secret={self.config.secret_key}"
f"&audience=https://{self.config.domain}/api/v2/"
) )
headers = {"content-type": "application/x-www-form-urlencoded"} headers = {"content-type": "application/x-www-form-urlencoded"}
conn.request("POST", f"/{self.config.domain}/oauth/token", payload, headers) conn.request("POST", f"/{self.config.domain}/oauth/token", payload, headers)

View File

@ -25,6 +25,19 @@ logger = logging.getLogger(__name__)
class SupersetConfig(ConfigModel): class SupersetConfig(ConfigModel):
"""
Superset Configuration class
Attributes:
url (str):
username (Optional[str]):
password (Optional[str]):
service_name (str):
service_type (str):
provider (str):
options (dict):
"""
url: str = "localhost:8088" url: str = "localhost:8088"
username: Optional[str] = None username: Optional[str] = None
password: Optional[SecretStr] = None password: Optional[SecretStr] = None
@ -43,6 +56,7 @@ class SupersetAuthenticationProvider(AuthenticationProvider):
self.config = config self.config = config
client_config = ClientConfig(base_url=config.url, api_version="api/v1") client_config = ClientConfig(base_url=config.url, api_version="api/v1")
self.client = REST(client_config) self.client = REST(client_config)
super().__init__()
@classmethod @classmethod
def create(cls, config: SupersetConfig): def create(cls, config: SupersetConfig):
@ -84,29 +98,77 @@ class SupersetAPIClient:
self.client = REST(client_config) self.client = REST(client_config)
def fetch_total_dashboards(self) -> int: def fetch_total_dashboards(self) -> int:
"""
Fetch total dahsboard
Returns:
int
"""
response = self.client.get("/dashboard?q=(page:0,page_size:1)") response = self.client.get("/dashboard?q=(page:0,page_size:1)")
return response.get("count") or 0 return response.get("count") or 0
def fetch_dashboards(self, current_page: int, page_size: int): def fetch_dashboards(self, current_page: int, page_size: int):
"""
Fetch dashboards
Args:
current_page (int): current page number
page_size (int): total number of pages
Returns:
requests.Response
"""
response = self.client.get( response = self.client.get(
f"/dashboard?q=(page:{current_page},page_size:{page_size})" f"/dashboard?q=(page:{current_page},page_size:{page_size})"
) )
return response return response
def fetch_total_charts(self) -> int: def fetch_total_charts(self) -> int:
"""
Fetch the total number of charts
Returns:
int
"""
response = self.client.get("/chart?q=(page:0,page_size:1)") response = self.client.get("/chart?q=(page:0,page_size:1)")
return response.get("count") or 0 return response.get("count") or 0
def fetch_charts(self, current_page: int, page_size: int): def fetch_charts(self, current_page: int, page_size: int):
"""
Fetch charts
Args:
current_page (str):
page_size (str):
Returns:
requests.Response
"""
response = self.client.get( response = self.client.get(
f"/chart?q=(page:{current_page},page_size:{page_size})" f"/chart?q=(page:{current_page},page_size:{page_size})"
) )
return response return response
def fetch_datasource(self, datasource_id: str): def fetch_datasource(self, datasource_id: str):
"""
Fetch data source
Args:
datasource_id (str):
Returns:
requests.Response
"""
response = self.client.get(f"/dataset/{datasource_id}") response = self.client.get(f"/dataset/{datasource_id}")
return response return response
def fetch_database(self, database_id: str): def fetch_database(self, database_id: str):
"""
Fetch database
Args:
database_id (str):
Returns:
requests.Response
"""
response = self.client.get(f"/database/{database_id}") response = self.client.get(f"/database/{database_id}")
return response return response

View File

@ -20,7 +20,7 @@ from pydantic import BaseModel
from metadata.generated.schema.type import basic from metadata.generated.schema.type import basic
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel) # pylint: disable=invalid-name
def format_name(name: str) -> str: def format_name(name: str) -> str:

View File

@ -8,10 +8,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Query parser implementation
"""
import datetime import datetime
import logging import logging
import traceback
from typing import Optional from typing import Optional
from sql_metadata import Parser from sql_metadata import Parser
@ -24,6 +26,10 @@ from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
class QueryParserProcessorConfig(ConfigModel): class QueryParserProcessorConfig(ConfigModel):
"""
Query parser pydantic configuration model
"""
filter: Optional[str] = None filter: Optional[str] = None
@ -31,6 +37,20 @@ logger = logging.getLogger(__name__)
class QueryParserProcessor(Processor): class QueryParserProcessor(Processor):
"""
Extension of the `Processor` class
Args:
ctx (WorkflowContext):
config (QueryParserProcessorConfig):
metadata_config (MetadataServerConfig):
Attributes:
config (QueryParserProcessorConfig):
metadata_config (MetadataServerConfig):
status (ProcessorStatus):
"""
config: QueryParserProcessorConfig config: QueryParserProcessorConfig
status: ProcessorStatus status: ProcessorStatus
@ -69,11 +89,11 @@ class QueryParserProcessor(Processor):
date=start_date.strftime("%Y-%m-%d"), date=start_date.strftime("%Y-%m-%d"),
service_name=record.service_name, service_name=record.service_name,
) )
# pylint: disable=broad-except
except Exception as err: except Exception as err:
logger.debug(record.sql) logger.debug(record.sql)
logger.error(err) logger.error(err)
query_parser_data = None query_parser_data = None
pass
return query_parser_data return query_parser_data

View File

@ -8,6 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""gc source module"""
import logging import logging
import uuid import uuid
@ -37,16 +38,33 @@ logger: logging.Logger = logging.getLogger(__name__)
class GcsSourceConfig(ConfigModel): class GcsSourceConfig(ConfigModel):
"""GCS source pydantic config module"""
service_name: str service_name: str
class GcsSource(Source[Entity]): class GcsSource(Source[Entity]):
"""GCS source entity
Args:
config:
GcsSourceConfig:
metadata_config:
ctx:
Attributes:
config:
status:
service:
gcs:
"""
config: GcsSourceConfig config: GcsSourceConfig
status: SourceStatus status: SourceStatus
def __init__( def __init__(
self, config: GcsSourceConfig, metadata_config: MetadataServerConfig, ctx self, config: GcsSourceConfig, metadata_config: MetadataServerConfig, ctx
): ):
super().__init__(ctx)
self.config = config self.config = config
self.status = SourceStatus() self.status = SourceStatus()
self.service = get_storage_service_or_create( self.service = get_storage_service_or_create(
@ -105,8 +123,8 @@ class GcsSource(Source[Entity]):
location=location, location=location,
policy=policy, policy=policy,
) )
except Exception as e: except Exception as err: # pylint: disable=broad-except
self.status.failure("error", str(e)) self.status.failure("error", str(err))
def get_status(self) -> SourceStatus: def get_status(self) -> SourceStatus:
return self.status return self.status
@ -128,7 +146,7 @@ class GcsSource(Source[Entity]):
actions: List[Union[LifecycleDeleteAction, LifecycleMoveAction]] = [] actions: List[Union[LifecycleDeleteAction, LifecycleMoveAction]] = []
if "action" not in rule or "type" not in rule["action"]: if "action" not in rule or "type" not in rule["action"]:
return return None
name = policy_name name = policy_name
@ -156,6 +174,7 @@ class GcsSource(Source[Entity]):
return LifecycleRule( return LifecycleRule(
actions=actions, actions=actions,
enabled=True, # gcs bucket lifecycle policies do not have an enabled field, hence True. # gcs bucket lifecycle policies do not have an enabled field, hence True.
enabled=True,
name=name, name=name,
) )

View File

@ -8,6 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Metabase source module"""
import json import json
import logging import logging
@ -37,7 +38,6 @@ from metadata.ingestion.api.source import Source, SourceStatus
from metadata.ingestion.models.table_metadata import Chart, Dashboard from metadata.ingestion.models.table_metadata import Chart, Dashboard
from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.source.sql_alchemy_helper import SQLSourceStatus
from metadata.ingestion.source.sql_source import SQLSourceStatus from metadata.ingestion.source.sql_source import SQLSourceStatus
from metadata.utils.helpers import get_dashboard_service_or_create from metadata.utils.helpers import get_dashboard_service_or_create
@ -48,6 +48,8 @@ logger: logging.Logger = logging.getLogger(__name__)
class MetabaseSourceConfig(ConfigModel): class MetabaseSourceConfig(ConfigModel):
"""Metabase pydantic config model"""
username: str username: str
password: SecretStr password: SecretStr
host_port: str host_port: str
@ -58,10 +60,26 @@ class MetabaseSourceConfig(ConfigModel):
database_service_name: str = None database_service_name: str = None
def get_connection_url(self):
    """Connection URLs do not apply to Metabase; intentionally a no-op."""
class MetabaseSource(Source[Entity]): class MetabaseSource(Source[Entity]):
"""Metabase entity class
Args:
config:
metadata_config:
ctx:
Attributes:
config:
metadata_config:
status:
metabase_session:
dashboard_service:
charts:
metric_charts:
"""
config: MetabaseSourceConfig config: MetabaseSourceConfig
metadata_config: MetadataServerConfig metadata_config: MetadataServerConfig
status: SQLSourceStatus status: SQLSourceStatus
@ -86,7 +104,7 @@ class MetabaseSource(Source[Entity]):
headers=HEADERS, headers=HEADERS,
) )
except Exception as err: except Exception as err:
raise ConnectionError(f"{err}") raise ConnectionError() from err
session_id = resp.json()["id"] session_id = resp.json()["id"]
self.metabase_session = {"X-Metabase-Session": session_id} self.metabase_session = {"X-Metabase-Session": session_id}
self.dashboard_service = get_dashboard_service_or_create( self.dashboard_service = get_dashboard_service_or_create(
@ -102,6 +120,15 @@ class MetabaseSource(Source[Entity]):
@classmethod @classmethod
def create(cls, config_dict, metadata_config_dict, ctx): def create(cls, config_dict, metadata_config_dict, ctx):
"""Instantiate object
Args:
config_dict:
metadata_config_dict:
ctx:
Returns:
MetabaseSource
"""
config = MetabaseSourceConfig.parse_obj(config_dict) config = MetabaseSourceConfig.parse_obj(config_dict)
metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict) metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict)
return cls(config, metadata_config, ctx) return cls(config, metadata_config, ctx)
@ -110,11 +137,18 @@ class MetabaseSource(Source[Entity]):
yield from self.get_dashboards() yield from self.get_dashboards()
def get_charts(self, charts) -> Iterable[Chart]: def get_charts(self, charts) -> Iterable[Chart]:
"""Get chart method
Args:
charts:
Returns:
Iterable[Chart]
"""
for chart in charts: for chart in charts:
try: try:
chart_details = chart["card"] chart_details = chart["card"]
if not self.config.chart_pattern.included(chart_details["name"]): if not self.config.chart_pattern.included(chart_details["name"]):
self.status.filter(chart_details["name"]) self.status.filter(chart_details["name"], None)
continue continue
yield Chart( yield Chart(
id=uuid.uuid4(), id=uuid.uuid4(),
@ -131,12 +165,13 @@ class MetabaseSource(Source[Entity]):
) )
self.charts.append(chart_details["name"]) self.charts.append(chart_details["name"])
self.status.scanned(chart_details["name"]) self.status.scanned(chart_details["name"])
except Exception as err: except Exception as err: # pylint: disable=broad-except
logger.error(repr(err)) logger.error(repr(err))
traceback.print_exc() traceback.print_exc()
continue continue
def get_dashboards(self): def get_dashboards(self):
"""Get dashboard method"""
resp_dashboards = self.req_get("/api/dashboard") resp_dashboards = self.req_get("/api/dashboard")
if resp_dashboards.status_code == 200: if resp_dashboards.status_code == 200:
for dashboard in resp_dashboards.json(): for dashboard in resp_dashboards.json():
@ -146,7 +181,7 @@ class MetabaseSource(Source[Entity]):
if not self.config.dashboard_pattern.included( if not self.config.dashboard_pattern.included(
dashboard_details["name"] dashboard_details["name"]
): ):
self.status.filter(dashboard_details["name"]) self.status.filter(dashboard_details["name"], None)
continue continue
yield from self.get_charts(dashboard_details["ordered_cards"]) yield from self.get_charts(dashboard_details["ordered_cards"])
yield Dashboard( yield Dashboard(
@ -167,6 +202,12 @@ class MetabaseSource(Source[Entity]):
) )
def get_lineage(self, chart_list, dashboard_name): def get_lineage(self, chart_list, dashboard_name):
"""Get lineage method
Args:
chart_list:
dashboard_name
"""
metadata = OpenMetadata(self.metadata_config) metadata = OpenMetadata(self.metadata_config)
for chart in chart_list: for chart in chart_list:
try: try:
@ -174,7 +215,8 @@ class MetabaseSource(Source[Entity]):
resp_tables = self.req_get(f"/api/table/{chart_details['table_id']}") resp_tables = self.req_get(f"/api/table/{chart_details['table_id']}")
if resp_tables.status_code == 200: if resp_tables.status_code == 200:
table = resp_tables.json() table = resp_tables.json()
table_fqdn = f"{self.config.database_service_name}.{table['schema']}.{table['name']}" table_fqdn = f"{self.config.database_service_name}.\
{table['schema']}.{table['name']}"
dashboard_fqdn = ( dashboard_fqdn = (
f"{self.dashboard_service.name}.{quote(dashboard_name)}" f"{self.dashboard_service.name}.{quote(dashboard_name)}"
) )
@ -182,7 +224,7 @@ class MetabaseSource(Source[Entity]):
chart_entity = metadata.get_by_name( chart_entity = metadata.get_by_name(
entity=Model_Dashboard, fqdn=dashboard_fqdn entity=Model_Dashboard, fqdn=dashboard_fqdn
) )
logger.debug("from entity {}".format(table_entity)) logger.debug("from entity %s", table_entity)
lineage = AddLineage( lineage = AddLineage(
edge=EntitiesEdge( edge=EntitiesEdge(
fromEntity=EntityReference( fromEntity=EntityReference(
@ -194,10 +236,15 @@ class MetabaseSource(Source[Entity]):
) )
) )
yield lineage yield lineage
except Exception as err: except Exception as err: # pylint: disable=broad-except,unused-variable
logger.debug(traceback.print_exc()) logger.error(traceback.print_exc())
def req_get(self, path):
    """Issue an authenticated GET against the Metabase API.

    Args:
        path: API path, appended to the configured host_port.

    Returns:
        The requests.Response of the GET call.
    """
    url = self.config.host_port + path
    return requests.get(url, headers=self.metabase_session)
def get_status(self) -> SourceStatus: def get_status(self) -> SourceStatus:

View File

@ -8,6 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Metadata source module"""
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -27,6 +28,8 @@ logger = logging.getLogger(__name__)
class MetadataTablesRestSourceConfig(ConfigModel): class MetadataTablesRestSourceConfig(ConfigModel):
"""Metadata Table Rest pydantic config model"""
include_tables: Optional[bool] = True include_tables: Optional[bool] = True
include_topics: Optional[bool] = True include_topics: Optional[bool] = True
include_dashboards: Optional[bool] = True include_dashboards: Optional[bool] = True
@ -36,31 +39,78 @@ class MetadataTablesRestSourceConfig(ConfigModel):
@dataclass
class MetadataSourceStatus(SourceStatus):
    """Accumulates per-entity outcomes while the metadata source runs.

    Attributes:
        success: names of entities scanned successfully.
        failures: names of entities that failed.
        warnings: names of entities dropped or filtered.
    """

    success: List[str] = field(default_factory=list)
    failures: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    def scanned_table(self, table_name: str) -> None:
        """Record a successfully scanned table."""
        self.success.append(table_name)
        logger.info("Table Scanned: %s", table_name)

    def scanned_topic(self, topic_name: str) -> None:
        """Record a successfully scanned topic."""
        self.success.append(topic_name)
        logger.info("Topic Scanned: %s", topic_name)

    def scanned_dashboard(self, dashboard_name: str) -> None:
        """Record a successfully scanned dashboard."""
        self.success.append(dashboard_name)
        logger.info("Dashboard Scanned: %s", dashboard_name)

    # dataset_name / col_type kept for interface parity with other sources
    # pylint: disable=unused-argument
    def filtered(
        self, table_name: str, err: str, dataset_name: str = None, col_type: str = None
    ) -> None:
        """Record an entity dropped from ingestion, with the reason."""
        self.warnings.append(table_name)
        logger.warning("Dropped Entity %s due to %s", table_name, err)
class MetadataSource(Source[Entity]): class MetadataSource(Source[Entity]):
"""Metadata source class
Args:
config:
metadata_config:
ctx:
Attributes:
config:
report:
metadata_config:
status:
wrote_something:
metadata:
tables:
topics:
"""
config: MetadataTablesRestSourceConfig config: MetadataTablesRestSourceConfig
report: SourceStatus report: SourceStatus
@ -97,6 +147,11 @@ class MetadataSource(Source[Entity]):
yield from self.fetch_pipeline() yield from self.fetch_pipeline()
def fetch_table(self) -> Table: def fetch_table(self) -> Table:
"""Fetch table method
Returns:
Table
"""
if self.config.include_tables: if self.config.include_tables:
after = None after = None
while True: while True:
@ -121,6 +176,11 @@ class MetadataSource(Source[Entity]):
after = table_entities.after after = table_entities.after
def fetch_topic(self) -> Topic: def fetch_topic(self) -> Topic:
"""fetch topic method
Returns:
Topic
"""
if self.config.include_topics: if self.config.include_topics:
after = None after = None
while True: while True:
@ -138,6 +198,11 @@ class MetadataSource(Source[Entity]):
after = topic_entities.after after = topic_entities.after
def fetch_dashboard(self) -> Dashboard: def fetch_dashboard(self) -> Dashboard:
"""fetch dashboard method
Returns:
Dashboard:
"""
if self.config.include_dashboards: if self.config.include_dashboards:
after = None after = None
while True: while True:
@ -161,6 +226,11 @@ class MetadataSource(Source[Entity]):
after = dashboard_entities.after after = dashboard_entities.after
def fetch_pipeline(self) -> Pipeline: def fetch_pipeline(self) -> Pipeline:
"""fetch pipeline method
Returns:
Pipeline:
"""
if self.config.include_pipelines: if self.config.include_pipelines:
after = None after = None
while True: while True:

View File

@ -8,6 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""ml flow source module"""
import ast import ast
import logging import logging
@ -41,26 +42,26 @@ class MlFlowStatus(SourceStatus):
failures: List[str] = field(default_factory=list) failures: List[str] = field(default_factory=list)
warnings: List[str] = field(default_factory=list) warnings: List[str] = field(default_factory=list)
def scanned(self, model_name: str) -> None: def scanned(self, record: str) -> None:
""" """
Log successful ML Model scans Log successful ML Model scans
""" """
self.success.append(model_name) self.success.append(record)
logger.info(f"ML Model scanned: {model_name}") logger.info("ML Model scanned: %s", record)
def failed(self, model_name: str, reason: str) -> None: def failed(self, model_name: str, reason: str) -> None:
""" """
Log failed ML Model scans Log failed ML Model scans
""" """
self.failures.append(model_name) self.failures.append(model_name)
logger.error(f"ML Model failed: {model_name} - {reason}") logger.error("ML Model failed: %s - %s", model_name, reason)
def warned(self, model_name: str, reason: str) -> None: def warned(self, model_name: str, reason: str) -> None:
""" """
Log Ml Model with warnings Log Ml Model with warnings
""" """
self.warnings.append(model_name) self.warnings.append(model_name)
logger.warning(f"ML Model warning: {model_name} - {reason}") logger.warning("ML Model warning: %s - %s", model_name, reason)
class MlFlowConnectionConfig(ConfigModel): class MlFlowConnectionConfig(ConfigModel):
@ -195,6 +196,7 @@ class MlflowSource(Source[CreateMlModelEntityRequest]):
for feature in features for feature in features
] ]
# pylint: disable=broad-except)
except Exception as exc: except Exception as exc:
reason = f"Cannot extract properties from RunData {exc}" reason = f"Cannot extract properties from RunData {exc}"
logging.warning(reason) logging.warning(reason)
@ -209,4 +211,3 @@ class MlflowSource(Source[CreateMlModelEntityRequest]):
""" """
Don't need to close the client Don't need to close the client
""" """
pass

View File

@ -8,8 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""MSSQL source module"""
import sqlalchemy_pytds # noqa: F401
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.source.sql_source import SQLSource from metadata.ingestion.source.sql_source import SQLSource
@ -17,6 +16,8 @@ from metadata.ingestion.source.sql_source_common import SQLConnectionConfig
class MssqlConfig(SQLConnectionConfig): class MssqlConfig(SQLConnectionConfig):
"""MSSQL config -- extends SQLConnectionConfig class"""
host_port = "localhost:1433" host_port = "localhost:1433"
scheme = "mssql+pytds" scheme = "mssql+pytds"
service_type = "MSSQL" service_type = "MSSQL"
@ -28,17 +29,23 @@ class MssqlConfig(SQLConnectionConfig):
if self.use_pyodbc: if self.use_pyodbc:
self.scheme = "mssql+pyodbc" self.scheme = "mssql+pyodbc"
return f"{self.scheme}://{self.uri_string}" return f"{self.scheme}://{self.uri_string}"
elif self.use_pymssql: if self.use_pymssql:
self.scheme = "mssql+pymssql" self.scheme = "mssql+pymssql"
return super().get_connection_url() return super().get_connection_url()
class MssqlSource(SQLSource):
    """MSSQL metadata source.

    Args:
        config: parsed MssqlConfig.
        metadata_config: metadata-server configuration.
        ctx: ingestion context.
    """

    @classmethod
    def create(cls, config_dict, metadata_config_dict, ctx):
        """Build an MssqlSource from raw configuration dictionaries."""
        parsed_config = MssqlConfig.parse_obj(config_dict)
        parsed_server = MetadataServerConfig.parse_obj(metadata_config_dict)
        return cls(parsed_config, parsed_server, ctx)

View File

@ -8,6 +8,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Redshift source ingestion
"""
import logging import logging
import re import re
@ -16,11 +20,6 @@ from typing import Optional
import sqlalchemy as sa import sqlalchemy as sa
from packaging.version import Version from packaging.version import Version
from metadata.ingestion.source.sql_source_common import SQLConnectionConfig
sa_version = Version(sa.__version__)
from sqlalchemy import inspect from sqlalchemy import inspect
from sqlalchemy.engine import reflection from sqlalchemy.engine import reflection
from sqlalchemy.types import CHAR, VARCHAR, NullType from sqlalchemy.types import CHAR, VARCHAR, NullType
@ -29,31 +28,68 @@ from sqlalchemy_redshift.dialect import RedshiftDialectMixin, RelationKey
from metadata.ingestion.api.source import SourceStatus from metadata.ingestion.api.source import SourceStatus
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.source.sql_source import SQLSource from metadata.ingestion.source.sql_source import SQLSource
from metadata.ingestion.source.sql_source_common import SQLConnectionConfig
from metadata.utils.sql_queries import ( from metadata.utils.sql_queries import (
REDSHIFT_GET_ALL_RELATION_INFO, REDSHIFT_GET_ALL_RELATION_INFO,
REDSHIFT_GET_SCHEMA_COLUMN_INFO, REDSHIFT_GET_SCHEMA_COLUMN_INFO,
) )
sa_version = Version(sa.__version__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
    """Return the names of ordinary ("r") and external ("e") tables.

    Args:
        connection: SQLAlchemy connection to reflect against.
        schema: schema to inspect; dialect default schema when None.
        **kw: reflection keyword args (e.g. info_cache).

    Returns:
        List of table names in the schema.
    """
    # The pragma must sit on the statement's own line: placed after the
    # closing parenthesis on a continuation line (as before) it is inert.
    # pylint: disable=protected-access
    return self._get_table_or_view_names(["r", "e"], connection, schema, **kw)
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
    """Return the names of views ("v" relations) in the schema.

    Args:
        connection: SQLAlchemy connection to reflect against.
        schema: schema to inspect; dialect default schema when None.
        **kw: reflection keyword args (e.g. info_cache).

    Returns:
        List of view names in the schema.
    """
    # The pragma must sit on the statement's own line: placed after the
    # closing parenthesis on a continuation line (as before) it is inert.
    # pylint: disable=protected-access
    return self._get_table_or_view_names(["v"], connection, schema, **kw)
@reflection.cache @reflection.cache
def _get_table_or_view_names(self, relkinds, connection, schema=None, **kw): def _get_table_or_view_names(self, relkinds, connection, schema=None, **kw):
"""
Get table or view name
Args:
relkinds:
connection:
schema:
**kw:
Returns
"""
default_schema = inspect(connection).default_schema_name default_schema = inspect(connection).default_schema_name
if not schema: if not schema:
schema = default_schema schema = default_schema
info_cache = kw.get("info_cache") info_cache = kw.get("info_cache")
all_relations = self._get_all_relation_info(connection, info_cache=info_cache) all_relations = self._get_all_relation_info(
connection, info_cache=info_cache
) # pylint: disable=protected-access
relation_names = [] relation_names = []
for key, relation in all_relations.items(): for key, relation in all_relations.items():
if key.schema == schema and relation.relkind in relkinds: if key.schema == schema and relation.relkind in relkinds:
@ -62,16 +98,26 @@ def _get_table_or_view_names(self, relkinds, connection, schema=None, **kw):
def _get_column_info(self, *args, **kwargs): def _get_column_info(self, *args, **kwargs):
kw = kwargs.copy() """
encode = kw.pop("encode", None) Get column info
Args:
*args:
**kwargs:
Returns
"""
kwdrs = kwargs.copy()
encode = kwdrs.pop("encode", None)
if sa_version >= Version("1.3.16"): if sa_version >= Version("1.3.16"):
kw["generated"] = "" kwdrs["generated"] = ""
if sa_version < Version("1.4.0") and "identity" in kw: if sa_version < Version("1.4.0") and "identity" in kwdrs:
del kw["identity"] del kwdrs["identity"]
elif sa_version >= Version("1.4.0") and "identity" not in kw: elif sa_version >= Version("1.4.0") and "identity" not in kwdrs:
kw["identity"] = None kwdrs["identity"] = None
column_info = super(RedshiftDialectMixin, self)._get_column_info(*args, **kw) column_info = super(RedshiftDialectMixin, self)._get_column_info(
column_info["raw_data_type"] = kw["format_type"] *args, **kwdrs
) # pylint: disable=protected-access
column_info["raw_data_type"] = kwdrs["format_type"]
if isinstance(column_info["type"], VARCHAR): if isinstance(column_info["type"], VARCHAR):
if column_info["type"].length is None: if column_info["type"].length is None:
@ -86,8 +132,17 @@ def _get_column_info(self, *args, **kwargs):
return column_info return column_info
# pylint: disable=unused-argument
@reflection.cache @reflection.cache
def _get_all_relation_info(self, connection, **kw): def _get_all_relation_info(self, connection, **kw):
"""
Get all relation info
Args:
connection:
**kw:
Returns
"""
result = connection.execute(REDSHIFT_GET_ALL_RELATION_INFO) result = connection.execute(REDSHIFT_GET_ALL_RELATION_INFO)
relations = {} relations = {}
for rel in result: for rel in result:
@ -111,10 +166,19 @@ def _get_all_relation_info(self, connection, **kw):
@reflection.cache @reflection.cache
def _get_schema_column_info(self, connection, schema=None, **kw): def _get_schema_column_info(self, connection, schema=None, **kw):
schema_clause = "AND schema = '{schema}'".format(schema=schema) if schema else "" """
Get schema column info
Args:
connection:
schema:
**kw:
Returns:
"""
schema_clause = f"AND schema = '{schema if schema else ''}'"
all_columns = defaultdict(list) all_columns = defaultdict(list)
with connection.connect() as cc: with connection.connect() as cnct:
result = cc.execute( result = cnct.execute(
REDSHIFT_GET_SCHEMA_COLUMN_INFO.format(schema_clause=schema_clause) REDSHIFT_GET_SCHEMA_COLUMN_INFO.format(schema_clause=schema_clause)
) )
for col in result: for col in result:
@ -123,39 +187,98 @@ def _get_schema_column_info(self, connection, schema=None, **kw):
return dict(all_columns) return dict(all_columns)
# Monkey-patch the Redshift dialect so SQLAlchemy reflection uses the
# customized helpers defined above. The previous parenthesized form only
# existed to host `# pylint: disable=protected-access` comments, but on a
# continuation line those pragmas suppress nothing (and the public
# get_table_names/get_view_names assignments need no suppression at all).
# A single block-scope disable covers every private-attribute assignment.
# pylint: disable=protected-access
RedshiftDialectMixin._get_table_or_view_names = _get_table_or_view_names
RedshiftDialectMixin.get_view_names = get_view_names
RedshiftDialectMixin.get_table_names = get_table_names
RedshiftDialectMixin._get_column_info = _get_column_info
RedshiftDialectMixin._get_all_relation_info = _get_all_relation_info
RedshiftDialectMixin._get_schema_column_info = _get_schema_column_info
# pylint: disable=useless-super-delegation
class RedshiftConfig(SQLConnectionConfig):
    """Connection configuration for the Redshift source.

    Attributes:
        scheme: SQLAlchemy driver scheme.
        where_clause: optional filter clause for queries.
        duration: lookback window, in days.
        service_type: fixed service-type label.
    """

    scheme = "redshift+psycopg2"
    where_clause: Optional[str] = None
    duration: int = 1
    service_type = "Redshift"

    def get_identifier(self, schema: str, table: str) -> str:
        """Build the fully qualified identifier for a table.

        Args:
            schema: schema name.
            table: table name.

        Returns:
            "<database>.<schema>.<table>" when a database is set,
            otherwise "<schema>.<table>".
        """
        qualified = f"{schema}.{table}"
        if self.database:
            qualified = f"{self.database}.{qualified}"
        return qualified

    def get_connection_url(self):
        """Delegate to the base class URL builder."""
        return super().get_connection_url()
# pylint: disable=useless-super-delegation
class RedshiftSource(SQLSource):
    """Redshift metadata source.

    Args:
        config: parsed RedshiftConfig.
        metadata_config: metadata-server configuration.
        ctx: ingestion context.
    """

    def __init__(self, config, metadata_config, ctx):
        super().__init__(config, metadata_config, ctx)

    @classmethod
    def create(cls, config_dict, metadata_config_dict, ctx):
        """Build a RedshiftSource from raw configuration dictionaries.

        Args:
            config_dict: raw source configuration.
            metadata_config_dict: raw metadata-server configuration.
            ctx: ingestion context.

        Returns:
            RedshiftSource
        """
        parsed_config = RedshiftConfig.parse_obj(config_dict)
        parsed_server = MetadataServerConfig.parse_obj(metadata_config_dict)
        return cls(parsed_config, parsed_server, ctx)

    def get_status(self) -> SourceStatus:
        """Return the accumulated ingestion status."""
        return self.status

View File

@ -8,6 +8,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Snowflake usage module
"""
from typing import Any, Dict, Iterable, Iterator, Union from typing import Any, Dict, Iterable, Iterator, Union
@ -26,7 +29,33 @@ from metadata.utils.sql_queries import SNOWFLAKE_SQL_STATEMENT
class SnowflakeUsageSource(Source[TableQuery]): class SnowflakeUsageSource(Source[TableQuery]):
# SELECT statement from mysql information_schema to extract table and column metadata """
Snowflake Usage source
Args:
config:
metadata_config:
ctx:
Attributes:
config:
analysis_date:
sql_stmt:
alchemy_helper:
_extract_iter:
_database:
report:
SQL_STATEMENT (str):
WHERE_CLAUSE_SUFFIX_KEY (str):
CLUSTER_SOURCE (str):
USE_CATALOG_AS_CLUSTER_NAME (str):
DATABASE_KEY (str):
SERVICE_TYPE (str):
DEFAULT_CLUSTER_SOURCE (str):
"""
# SELECT statement from mysql information_schema
# to extract table and column metadata
SQL_STATEMENT = SNOWFLAKE_SQL_STATEMENT SQL_STATEMENT = SNOWFLAKE_SQL_STATEMENT
# CONFIG KEYS # CONFIG KEYS
@ -73,17 +102,18 @@ class SnowflakeUsageSource(Source[TableQuery]):
def next_record(self) -> Iterable[TableQuery]: def next_record(self) -> Iterable[TableQuery]:
""" """
Using itertools.groupby and raw level iterator, it groups to table and yields TableMetadata Using itertools.groupby and raw level iterator,
it groups to table and yields TableMetadata
:return: :return:
""" """
for row in self._get_raw_extract_iter(): for row in self._get_raw_extract_iter():
tq = TableQuery( table_query = TableQuery(
query=row["query_type"], query=row["query_type"],
user_name=row["user_name"], user_name=row["user_name"],
starttime=str(row["start_time"]), starttime=str(row["start_time"]),
endtime=str(row["end_time"]), endtime=str(row["end_time"]),
analysis_date=self.analysis_date, analysis_date=self.analysis_date,
aborted=True if "1969" in str(row["end_time"]) else False, aborted="1969" in str(row["end_time"]),
database=row["database_name"], database=row["database_name"],
sql=row["query_text"], sql=row["query_text"],
service_name=self.config.service_name, service_name=self.config.service_name,
@ -92,9 +122,14 @@ class SnowflakeUsageSource(Source[TableQuery]):
self.report.scanned(f"{row['database_name']}.{row['schema_name']}") self.report.scanned(f"{row['database_name']}.{row['schema_name']}")
else: else:
self.report.scanned(f"{row['database_name']}") self.report.scanned(f"{row['database_name']}")
yield tq yield table_query
def get_report(self):
    """Return the usage report accumulated while ingesting queries."""
    return self.report
def close(self): def close(self):

View File

@ -8,6 +8,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Superset source module
"""
import json import json
from typing import Iterable from typing import Iterable
@ -27,16 +31,32 @@ from metadata.utils.helpers import get_dashboard_service_or_create
def get_metric_name(metric):
    """Extract a display name from a Superset metric definition.

    Args:
        metric: None, a metric name string, or a metric dict.

    Returns:
        "" when metric is falsy; the string itself when metric is a str;
        the dict's truthy "label" value; otherwise None.
    """
    if not metric:
        return ""
    if isinstance(metric, str):
        return metric
    return metric.get("label") or None
def get_filter_name(filter_obj): def get_filter_name(filter_obj):
"""
Get filter name
Args:
filter_obj:
Returns:
str
"""
sql_expression = filter_obj.get("sqlExpression") sql_expression = filter_obj.get("sqlExpression")
if sql_expression: if sql_expression:
return sql_expression return sql_expression
@ -49,6 +69,14 @@ def get_filter_name(filter_obj):
def get_owners(owners_obj): def get_owners(owners_obj):
"""
Get owner
Args:
owners_obj:
Returns:
list
"""
owners = [] owners = []
for owner in owners_obj: for owner in owners_obj:
dashboard_owner = DashboardOwner( dashboard_owner = DashboardOwner(
@ -60,7 +88,17 @@ def get_owners(owners_obj):
return owners return owners
# pylint: disable=too-many-return-statements, too-many-branches
def get_service_type_from_database_uri(uri: str) -> str: def get_service_type_from_database_uri(uri: str) -> str:
"""
Get service type from database URI
Args:
uri (str):
Returns:
str
"""
if uri.startswith("bigquery"): if uri.startswith("bigquery"):
return "bigquery" return "bigquery"
if uri.startswith("druid"): if uri.startswith("druid"):
@ -91,6 +129,24 @@ def get_service_type_from_database_uri(uri: str) -> str:
class SupersetSource(Source[Entity]): class SupersetSource(Source[Entity]):
"""
Superset source class
Args:
config:
metadata_config:
ctx:
Attributes:
config:
metadata_config:
status:
platform:
service_type:
service:
"""
config: SupersetConfig config: SupersetConfig
metadata_config: MetadataServerConfig metadata_config: MetadataServerConfig
status: SourceStatus status: SourceStatus
@ -197,6 +253,7 @@ class SupersetSource(Source[Entity]):
return dataset_fqn return dataset_fqn
return None return None
# pylint: disable=too-many-locals
def _build_chart(self, chart_json) -> Chart: def _build_chart(self, chart_json) -> Chart:
chart_id = chart_json["id"] chart_id = chart_json["id"]
name = chart_json["slice_name"] name = chart_json["slice_name"]

View File

@ -8,10 +8,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Tableau source module
"""
import logging import logging
import uuid import uuid
from typing import Iterable, Optional from typing import Iterable, List, Optional
import dateutil.parser as dateparser import dateutil.parser as dateparser
from pydantic import SecretStr from pydantic import SecretStr
@ -37,6 +40,8 @@ logger = logging.getLogger(__name__)
class TableauSourceConfig(ConfigModel): class TableauSourceConfig(ConfigModel):
"""Tableau pydantic source model"""
username: str username: str
password: SecretStr password: SecretStr
server: str server: str
@ -51,6 +56,22 @@ class TableauSourceConfig(ConfigModel):
class TableauSource(Source[Entity]): class TableauSource(Source[Entity]):
"""Tableau source entity class
Args:
config:
metadata_config:
ctx:
Attributes:
config:
metadata_config:
status:
service:
dashboard:
all_dashboard_details:
"""
config: TableauSourceConfig config: TableauSourceConfig
metadata_config: MetadataServerConfig metadata_config: MetadataServerConfig
status: SourceStatus status: SourceStatus
@ -78,6 +99,10 @@ class TableauSource(Source[Entity]):
self.all_dashboard_details = get_views_dataframe(self.client).to_dict() self.all_dashboard_details = get_views_dataframe(self.client).to_dict()
def tableau_client(self): def tableau_client(self):
"""Tableau client method
Returns:
"""
tableau_server_config = { tableau_server_config = {
f"{self.config.env}": { f"{self.config.env}": {
"server": self.config.server, "server": self.config.server,
@ -93,8 +118,8 @@ class TableauSource(Source[Entity]):
config_json=tableau_server_config, env="tableau_prod" config_json=tableau_server_config, env="tableau_prod"
) )
conn.sign_in().json() conn.sign_in().json()
except Exception as err: except Exception as err: # pylint: disable=broad-except
logger.error(f"{repr(err)}: {err}") logger.error("%s: %s", repr(err), err)
return conn return conn
@classmethod @classmethod
@ -113,7 +138,14 @@ class TableauSource(Source[Entity]):
yield from self._get_tableau_dashboard() yield from self._get_tableau_dashboard()
@staticmethod @staticmethod
def get_owner(owner) -> DashboardOwner: def get_owner(owner) -> List[DashboardOwner]:
"""Get dashboard owner
Args:
owner:
Returns:
List[DashboardOwner]
"""
return [ return [
DashboardOwner( DashboardOwner(
first_name=owner["fullName"].split(" ")[0], first_name=owner["fullName"].split(" ")[0],

View File

@ -8,6 +8,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Trino source module"""
import logging import logging
import sys import sys
from typing import Iterable from typing import Iterable
@ -25,6 +27,16 @@ logger = logging.getLogger(__name__)
class TrinoConfig(SQLConnectionConfig): class TrinoConfig(SQLConnectionConfig):
"""Trinio config class -- extends SQLConnectionConfig class
Attributes:
host_port:
scheme:
service_type:
catalog:
database:
"""
host_port = "localhost:8080" host_port = "localhost:8080"
scheme = "trino" scheme = "trino"
service_type = "Trino" service_type = "Trino"
@ -57,9 +69,19 @@ class TrinoConfig(SQLConnectionConfig):
class TrinoSource(SQLSource): class TrinoSource(SQLSource):
"""Trino source -- extends SQLSource
Args:
config:
metadata_config:
ctx
"""
def __init__(self, config, metadata_config, ctx): def __init__(self, config, metadata_config, ctx):
try: try:
from sqlalchemy_trino import dbapi from sqlalchemy_trino import (
dbapi, # pylint: disable=import-outside-toplevel,unused-import
)
except ModuleNotFoundError: except ModuleNotFoundError:
click.secho( click.secho(
"Trino source dependencies are missing. Please run\n" "Trino source dependencies are missing. Please run\n"
@ -68,8 +90,7 @@ class TrinoSource(SQLSource):
) )
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
raise raise
else: sys.exit(1)
sys.exit(1)
super().__init__(config, metadata_config, ctx) super().__init__(config, metadata_config, ctx)
@classmethod @classmethod