[WIP] Issue #285: Add support for Dashboard Entities; Superset connector

This commit is contained in:
Suresh Srinivas 2021-08-24 13:47:41 -07:00
parent 723b9ba5d6
commit 6a28ae988f
10 changed files with 477 additions and 149 deletions

View File

@ -0,0 +1,24 @@
{
"source": {
"type": "superset",
"config": {
"url": "http://localhost:8088",
"username": "admin",
"password": "admin"
}
},
"metadata_server": {
"type": "metadata-server",
"config": {
"api_endpoint": "http://localhost:8585/api",
"auth_provider_type": "no-auth"
}
},
"cron": {
"minute": "*/5",
"hour": null,
"day": null,
"month": null,
"day_of_week": null
}
}

View File

@ -13,4 +13,7 @@ setuptools~=57.0.0
PyHive~=0.6.4
ldap3~=2.9.1
confluent_kafka>=1.5.0
fastavro>=1.2.0
fastavro>=1.2.0
google~=3.0.0
okta~=2.0.0
PyMySQL~=1.0.2

View File

@ -93,7 +93,8 @@ plugins: Dict[str, Set[str]] = {
"snowflake": {"snowflake-sqlalchemy<=1.2.4"},
"snowflake-usage": {"snowflake-sqlalchemy<=1.2.4"},
"sample-tables": {"faker~=8.1.1", },
"sample-topics": {}
"sample-topics": {},
"superset": {}
}
build_options = {"includes": ["_cffi_backend"]}

View File

@ -32,11 +32,10 @@ class SourceStatus(Status):
self.records += 1
def warning(self, key: str, reason: str) -> None:
self.warnings.append({key:reason})
self.warnings.append({key: reason})
def failure(self, key: str, reason: str) -> None:
self.failures.append({key:reason})
self.failures.append({key: reason})
@dataclass # type: ignore[misc]

View File

@ -20,8 +20,7 @@ from typing import Any, Dict, Iterable, List, Optional, Union
from pydantic import BaseModel
from metadata.generated.schema.entity.data.table import Table
from metadata.ingestion.models.json_serializable import JsonSerializable, NODE_KEY, NODE_LABEL
from metadata.ingestion.models.json_serializable import JsonSerializable
DESCRIPTION_NODE_LABEL_VAL = 'Description'
DESCRIPTION_NODE_LABEL = DESCRIPTION_NODE_LABEL_VAL
@ -207,3 +206,32 @@ class TopicESDocument(BaseModel):
schema_description: Optional[str] = None
owner: str
followers: List[str]
class DashboardOwner(BaseModel):
"""Dashboard owner"""
username: str
first_name: str
last_name: str
class Chart(BaseModel):
"""Chart"""
name: str
description: str
chart_type: str
url: str
owners: List[DashboardOwner]
lastModified: int
datasource_fqn: str
custom_props: Dict[Any, Any]
class Dashboard(BaseModel):
"""Dashboard"""
name: str
description: str
url: str
owners: List
charts: List
lastModified: int

View File

@ -13,37 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from google.oauth2 import service_account
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from metadata.config.common import ConfigModel
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field
import google.auth
import google.auth.transport.requests
from google.oauth2 import service_account
import time
import uuid
from jose import jwt
from okta.client import Client as OktaClient
import asyncio
from okta.jwt import JWT
class MetadataServerConfig(ConfigModel):
api_endpoint: str
api_version: str = 'v1'
retry: int = 3
retry_wait: int = 3
auth_provider_type: str = None
secret_key: str = None
org_url: str = None
client_id: str = None
private_key: str = None
email: str = None
audience: str = 'https://www.googleapis.com/oauth2/v4/token'
auth_header: str = 'X-Catalog-Source'
@dataclass # type: ignore[misc]
@ -51,7 +24,7 @@ class AuthenticationProvider(metaclass=ABCMeta):
@classmethod
@abstractmethod
def create(cls, config: MetadataServerConfig) -> "AuthenticationProvider":
def create(cls, config: ConfigModel) -> "AuthenticationProvider":
pass
@abstractmethod
@ -59,53 +32,5 @@ class AuthenticationProvider(metaclass=ABCMeta):
pass
class NoOpAuthenticationProvider(AuthenticationProvider):
def __init__(self, config: MetadataServerConfig):
self.config = config
@classmethod
def create(cls, config: MetadataServerConfig):
return cls(config)
def auth_token(self) -> str:
return "no_token"
class GoogleAuthenticationProvider(AuthenticationProvider):
def __init__(self, config: MetadataServerConfig):
self.config = config
@classmethod
def create(cls, config: MetadataServerConfig):
return cls(config)
def auth_token(self) -> str:
credentials = service_account.IDTokenCredentials.from_service_account_file(
self.config.secret_key,
target_audience=self.config.audience)
request = google.auth.transport.requests.Request()
credentials.refresh(request)
return credentials.token
class OktaAuthenticationProvider(AuthenticationProvider):
def __init__(self, config: MetadataServerConfig):
self.config = config
@classmethod
def create(cls, config: MetadataServerConfig):
return cls(config)
def auth_token(self) -> str:
my_pem, my_jwk = JWT.get_PEM_JWK(self.config.private_key)
claims = {
'sub': self.config.client_id,
'iat': time.time(),
'exp': time.time() + JWT.ONE_HOUR,
'iss': self.config.client_id,
'aud': self.config.org_url + JWT.OAUTH_ENDPOINT,
'jti': uuid.uuid4(),
'email': self.config.email
}
token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM)
return token

View File

@ -14,15 +14,14 @@
# limitations under the License.
import logging
import os
from typing import Optional
from typing import Optional, List
import requests
from requests.exceptions import HTTPError
import time
from enum import Enum
from metadata.ingestion.ometa.auth_provider import MetadataServerConfig
from metadata.config.common import ConfigModel
from metadata.ingestion.ometa.credentials import URL, get_api_version
logger = logging.getLogger(__name__)
@ -71,26 +70,33 @@ class TimeFrame(Enum):
Sec = "1Sec"
class REST(object):
def __init__(self,
config: MetadataServerConfig,
raw_data: bool = False,
auth_token: Optional[str] = None
):
"""
class ClientConfig(ConfigModel):
"""
:param raw_data: should we return api response raw or wrap it with
Entity objects.
"""
Entity objects.
"""
base_url: str
api_version: Optional[str] = "v1"
retry: Optional[int] = 3
retry_wait: Optional[int] = 30
retry_codes: List[int] = [429, 504]
auth_token: Optional[str] = None
auth_header: Optional[str] = None
raw_data: Optional[bool] = False
allow_redirects: Optional[bool] = False
class REST(object):
def __init__(self, config: ClientConfig):
self.config = config
self._base_url: URL = URL(self.config.api_endpoint)
self._base_url: URL = URL(self.config.base_url)
self._api_version = get_api_version(self.config.api_version)
self._session = requests.Session()
self._use_raw_data = raw_data
self._use_raw_data = self.config.raw_data
self._retry = self.config.retry
self._retry_wait = self.config.retry_wait
self._retry_codes = [int(o) for o in os.environ.get(
'OMETA_RETRY_CODES', '429,504').split(',')]
self._auth_token = auth_token
self._retry_codes = self.config.retry_codes
self._auth_token = self.config.auth_token
def _request(self,
method,
@ -105,14 +111,13 @@ class REST(object):
headers = {'Content-type': 'application/json'}
if self._auth_token is not None:
headers[self.config.auth_header] = self._auth_token
opts = {
'headers': headers,
# Since we allow users to set endpoint URL via env var,
# human error to put non-SSL endpoint could exploit
# uncanny issues in non-GET request redirecting http->https.
# It's better to fail early if the URL isn't right.
'allow_redirects': False,
'allow_redirects': self.config.allow_redirects,
}
if method.upper() == 'GET':
opts['params'] = data

View File

@ -1,6 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List
from metadata.config.common import ConfigModel
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseEntityRequest
from metadata.generated.schema.api.data.createTable import CreateTableEntityRequest
from metadata.generated.schema.api.data.createTopic import CreateTopic
@ -13,9 +29,17 @@ from metadata.generated.schema.entity.services.databaseService import DatabaseSe
from metadata.generated.schema.entity.services.messagingService import MessagingService
from metadata.generated.schema.entity.tags.tagCategory import Tag
from metadata.ingestion.models.table_queries import TableUsageRequest
from metadata.ingestion.ometa.auth_provider import MetadataServerConfig, AuthenticationProvider, \
GoogleAuthenticationProvider, OktaAuthenticationProvider, NoOpAuthenticationProvider
from metadata.ingestion.ometa.client import REST
from metadata.ingestion.ometa.auth_provider import AuthenticationProvider
from metadata.ingestion.ometa.client import REST, ClientConfig
import google.auth
import google.auth.transport.requests
from google.oauth2 import service_account
import time
import uuid
from jose import jwt
from okta.jwt import JWT
logger = logging.getLogger(__name__)
DatabaseServiceEntities = List[DatabaseService]
@ -25,8 +49,76 @@ Tags = List[Tag]
Topics = List[Topic]
class MetadataServerConfig(ConfigModel):
api_endpoint: str
api_version: str = 'v1'
retry: int = 3
retry_wait: int = 3
auth_provider_type: str = None
secret_key: str = None
org_url: str = None
client_id: str = None
private_key: str = None
email: str = None
audience: str = 'https://www.googleapis.com/oauth2/v4/token'
auth_header: str = 'X-Catalog-Source'
class NoOpAuthenticationProvider(AuthenticationProvider):
def __init__(self, config: MetadataServerConfig):
self.config = config
@classmethod
def create(cls, config: MetadataServerConfig):
return cls(config)
def auth_token(self) -> str:
return "no_token"
class GoogleAuthenticationProvider(AuthenticationProvider):
def __init__(self, config: MetadataServerConfig):
self.config = config
@classmethod
def create(cls, config: MetadataServerConfig):
return cls(config)
def auth_token(self) -> str:
credentials = service_account.IDTokenCredentials.from_service_account_file(
self.config.secret_key,
target_audience=self.config.audience)
request = google.auth.transport.requests.Request()
credentials.refresh(request)
return credentials.token
class OktaAuthenticationProvider(AuthenticationProvider):
def __init__(self, config: MetadataServerConfig):
self.config = config
@classmethod
def create(cls, config: MetadataServerConfig):
return cls(config)
def auth_token(self) -> str:
my_pem, my_jwk = JWT.get_PEM_JWK(self.config.private_key)
claims = {
'sub': self.config.client_id,
'iat': time.time(),
'exp': time.time() + JWT.ONE_HOUR,
'iss': self.config.client_id,
'aud': self.config.org_url + JWT.OAUTH_ENDPOINT,
'jti': uuid.uuid4(),
'email': self.config.email
}
token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM)
return token
class OpenMetadataAPIClient(object):
client: REST
_auth_provider: AuthenticationProvider
def __init__(self,
config: MetadataServerConfig,
@ -39,7 +131,10 @@ class OpenMetadataAPIClient(object):
self._auth_provider: AuthenticationProvider = OktaAuthenticationProvider.create(self.config)
else:
self._auth_provider: AuthenticationProvider = NoOpAuthenticationProvider.create(self.config)
self.client = REST(config, raw_data, self._auth_provider.auth_token())
client_config: ClientConfig = ClientConfig(base_url=self.config.api_endpoint,
api_version=self.config.api_version,
auth_token=self._auth_provider.auth_token())
self.client = REST(client_config)
self._use_raw_data = raw_data
def get_database_service(self, service_name: str) -> DatabaseService:

View File

@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import Optional
from metadata.config.common import ConfigModel
from metadata.ingestion.ometa.auth_provider import AuthenticationProvider
from metadata.ingestion.ometa.client import REST, ClientConfig
logger = logging.getLogger(__name__)
class SupersetConfig(ConfigModel):
url: str = "localhost:8088"
username: Optional[str] = None
password: Optional[str] = None
provider: str = "db"
options: dict = {}
class SupersetAuthenticationProvider(AuthenticationProvider):
def __init__(self, config: SupersetConfig):
self.config = config
client_config = ClientConfig(base_url=config.url, api_version="api/v1")
self.client = REST(client_config)
@classmethod
def create(cls, config: SupersetConfig):
return cls(config)
def auth_token(self) -> str:
login_request = self._login_request()
login_response = self.client.post("/security/login", login_request)
return login_response['access_token']
def _login_request(self) -> str:
auth_request = {'username': self.config.username,
'password': self.config.password,
'refresh': True,
'provider': self.config.provider}
return json.dumps(auth_request)
class SupersetAPIClient(object):
client: REST
_auth_provider: AuthenticationProvider
def __init__(self, config: SupersetConfig):
self.config = config
self._auth_provider = SupersetAuthenticationProvider.create(config)
client_config = ClientConfig(base_url=config.url, api_version="api/v1",
auth_token=f"Bearer {self._auth_provider.auth_token()}",
auth_header="Authorization", allow_redirects=True)
self.client = REST(client_config)
def fetch_total_dashboards(self) -> int:
params = "q=(page:0,page_size:1)"
response = self.client.get(f"/dashboard", data=params)
return response.get("count") or 0
def fetch_dashboards(self, current_page: int, page_size: int):
params = f"'q=(page:{current_page},page_size:{page_size})'"
response = self.client.get(f"/dashboard", data=params)
return response
def fetch_total_charts(self) -> int:
params = "q=(page:0,page_size:1)"
response = self.client.get(f"/chart", data=params)
return response.get("count") or 0
def fetch_charts(self, current_page: int, page_size: int):
params = f"'q=(page:{current_page},page_size:{page_size})'"
response = self.client.get(f"/chart", data=params)
return response
def fetch_datasource(self, datasource_id: str):
response = self.client.get(f"/dataset/{datasource_id}")
return response
def fetch_database(self, database_id: str):
response = self.client.get(f"/database/{database_id}")
return response

View File

@ -1,62 +1,213 @@
from typing import Optional
import json
from typing import Iterable, Tuple
import dateutil.parser as dateparser
from metadata.config.common import ConfigModel
from metadata.ingestion.api.common import WorkflowContext
from metadata.ingestion.api.common import WorkflowContext, Record
from metadata.ingestion.api.source import Source, SourceStatus
from metadata.ingestion.ometa.auth_provider import MetadataServerConfig
from metadata.ingestion.models.table_metadata import DashboardOwner, Dashboard, Chart
from metadata.ingestion.ometa.openmetadata_rest import MetadataServerConfig
from metadata.ingestion.ometa.superset_rest import SupersetConfig, SupersetAPIClient
class SupersetSourceConfig(ConfigModel):
url: str = "localhost:8088"
username: Optional[str] = None
password: Optional[str] = None
provider: str = "db"
options: dict = {}
def get_metric_name(metric):
if not metric:
return ""
if isinstance(metric, str):
return metric
label = metric.get("label")
if label:
return label
def get_filter_name(filter_obj):
sql_expression = filter_obj.get("sqlExpression")
if sql_expression:
return sql_expression
clause = filter_obj.get("clause")
column = filter_obj.get("subject")
operator = filter_obj.get("operator")
comparator = filter_obj.get("comparator")
return f"{clause} {column} {operator} {comparator}"
def get_owners(owners_obj):
owners = []
for owner in owners_obj:
dashboard_owner = DashboardOwner(first_name=owner['first_name'],
last_name=owner['last_name'],
username=owner['username'])
owners.append(dashboard_owner)
return owners
def get_service_type_from_database_uri(uri: str) -> str:
if uri.startswith("bigquery"):
return "bigquery"
if uri.startswith("druid"):
return "druid"
if uri.startswith("mssql"):
return "mssql"
if (
uri.startswith("jdbc:postgres:")
and uri.index("redshift.amazonaws") > 0
):
return "redshift"
if uri.startswith("snowflake"):
return "snowflake"
if uri.startswith("presto"):
return "presto"
if uri.startswith("postgresql"):
return "postgres"
if uri.startswith("pinot"):
return "pinot"
if uri.startswith("oracle"):
return "oracle"
if uri.startswith("mysql"):
return "mysql"
if uri.startswith("mongodb"):
return "mongodb"
if uri.startswith("hive"):
return "hive"
return "external"
class SupersetSource(Source):
config: SupersetSourceConfig
config: SupersetConfig
metadata_config: MetadataServerConfig
status: SourceStatus
platform = "superset"
def __init__(self, config: SupersetSourceConfig, metadata_config: MetadataServerConfig, ctx: WorkflowContext):
def __init__(self, config: SupersetConfig, metadata_config: MetadataServerConfig, ctx: WorkflowContext):
super().__init__(ctx)
self.config = config
self.metadata_config = metadata_config
self.status = SourceStatus()
self.client = SupersetAPIClient(self.config)
self.charts_dict = {}
login_response = requests.post(
f"{self.config.connect_uri}/api/v1/security/login",
None,
{
"username": self.config.username,
"password": self.config.password,
"refresh": True,
"provider": self.config.provider,
},
)
self.access_token = login_response.json()["access_token"]
self.session = requests.Session()
self.session.headers.update(
{
"Authorization": f"Bearer {self.access_token}",
"Content-Type": "application/json",
"Accept": "*/*",
}
)
# Test the connection
test_response = self.session.get(f"{self.config.connect_uri}/api/v1/database")
if test_response.status_code == 200:
pass
# TODO(Gabe): how should we message about this error?
@classmethod
@classmethod
def create(cls, config_dict: dict, metadata_config_dict: dict, ctx: WorkflowContext):
config = SupersetSourceConfig.parse_obj(config_dict)
config = SupersetConfig.parse_obj(config_dict)
metadata_config = MetadataServerConfig.parse_obj(metadata_config_dict)
return cls(config, metadata_config, ctx)
def prepare(self):
self._fetch_charts()
def next_record(self) -> Iterable[Record]:
yield from self._fetch_dashboards()
def _build_dashboard(self, dashboard_json) -> Dashboard:
name = dashboard_json['dashboard_title']
dashboard_url = f"{self.config.url[:-1]}{dashboard_json['url']}"
last_modified = dateparser.parse(dashboard_json.get("changed_on_utc", "now")).timestamp() * 1000
owners = get_owners(dashboard_json['owners'])
raw_position_data = dashboard_json.get("position_json", "{}")
charts = []
if raw_position_data is not None:
position_data = json.loads(raw_position_data)
for key, value in position_data.items():
if not key.startswith("CHART-"):
continue
chart_id = value.get('meta', {}).get('chartId', 'unknown')
if chart_id in self.charts_dict.keys():
charts.append(self.charts_dict[chart_id])
return Dashboard(name=name,
description="",
url=dashboard_url,
owners=owners,
charts=charts,
lastModified=last_modified)
def _fetch_dashboards(self) -> Iterable[Record]:
current_page = 0
page_size = 10
total_dashboards = self.client.fetch_total_dashboards()
while current_page * page_size <= total_dashboards:
dashboards = self.client.fetch_dashboards(current_page, page_size)
current_page += 1
for dashboard_json in dashboards['result']:
dashboard = self._build_dashboard(dashboard_json)
print(dashboard.json())
yield dashboard
def _get_service_type_from_database_id(self, database_id):
database_json = self.client.fetch_database(database_id)
sqlalchemy_uri = database_json.get("result", {}).get("sqlalchemy_uri")
return get_service_type_from_database_uri(sqlalchemy_uri)
def _get_datasource_from_id(self, datasource_id):
datasource_json = self.client.fetch_datasource(datasource_id)
schema_name = datasource_json.get("result", {}).get("schema")
table_name = datasource_json.get("result", {}).get("table_name")
database_id = datasource_json.get("result", {}).get("database", {}).get("id")
database_name = (
datasource_json.get("result", {}).get("database", {}).get("database_name")
)
if database_id and table_name:
platform = self._get_service_type_from_database_id(database_id)
dataset_fqn = (
f"{platform}.{database_name + '.' if database_name else ''}"
f"{schema_name + '.' if schema_name else ''}"
f"{table_name}"
)
return dataset_fqn
return None
def _build_chart(self, chart_json) -> Tuple[int, Chart]:
chart_id = chart_json['id']
name = chart_json['slice_name']
last_modified = dateparser.parse(chart_json.get("changed_on_utc", "now")).timestamp() * 1000
chart_type = chart_json["viz_type"]
chart_url = f"{self.config.url}{chart_json['url']}"
datasource_id = chart_json["datasource_id"]
datasource_fqn = self._get_datasource_from_id(datasource_id)
owners = get_owners(chart_json['owners'])
params = json.loads(chart_json["params"])
metrics = [
get_metric_name(metric)
for metric in (params.get("metrics", []) or [params.get("metric")])
]
filters = [
get_filter_name(filter_obj)
for filter_obj in params.get("adhoc_filters", [])
]
group_bys = params.get("groupby", []) or []
if isinstance(group_bys, str):
group_bys = [group_bys]
custom_properties = {
"Metrics": ", ".join(metrics),
"Filters": ", ".join(filters),
"Dimensions": ", ".join(group_bys),
}
chart = Chart(name=name,
description="",
chart_type=chart_type,
url=chart_url,
owners=owners,
datasource_fqn=datasource_fqn,
lastModified=last_modified,
custom_props=custom_properties)
return chart_id, chart
def _fetch_charts(self):
current_page = 0
page_size = 10
total_charts = self.client.fetch_total_charts()
while current_page * page_size <= total_charts:
charts = self.client.fetch_charts(current_page, page_size)
current_page += 1
for chart_json in charts['result']:
chart_id, chart = self._build_chart(chart_json)
self.charts_dict[chart_id] = chart
def get_status(self):
return self.status
def close(self):
pass