Fix #5885 - Provider loading fails for airflow <2.3 (#5927)

Fix #5885 - Provider loading fails for airflow <2.3 (#5927)
This commit is contained in:
Pere Miquel Brull 2022-07-14 15:07:39 +02:00 committed by GitHub
parent 81b8710c2a
commit 479a8de486
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 360 additions and 17 deletions

View File

@ -156,6 +156,12 @@ test = {
"pandas==1.3.5",
# great_expectations tests
"great-expectations~=0.15.0",
# Airflow tests
"apache-airflow==2.1.4",
"marshmallow-sqlalchemy>=0.26.0",
"SQLAlchemy-Utils>=0.38.0",
"pymysql>=1.0.2",
"requests==2.26.0",
}
build_options = {"includes": ["_cffi_backend"]}

View File

@ -15,7 +15,7 @@ OpenMetadata Airflow Lineage Backend
import json
import os
from airflow.configuration import conf
from airflow.configuration import AirflowConfigParser
from pydantic import BaseModel
from airflow_provider_openmetadata.lineage.config.commons import LINEAGE
@ -34,7 +34,9 @@ class AirflowLineageConfig(BaseModel):
metadata_config: OpenMetadataConnection
def parse_airflow_config(airflow_service_name: str) -> AirflowLineageConfig:
def parse_airflow_config(
airflow_service_name: str, conf: AirflowConfigParser
) -> AirflowLineageConfig:
"""
Get airflow config from airflow.cfg and parse it
to the config model
@ -53,7 +55,7 @@ def parse_airflow_config(airflow_service_name: str) -> AirflowLineageConfig:
raise InvalidAirflowProviderException(
f"Cannot find {auth_provider_type} in airflow providers registry."
)
security_config = load_security_config_fn()
security_config = load_security_config_fn(conf)
return AirflowLineageConfig(
airflow_service_name=airflow_service_name,
@ -75,9 +77,11 @@ def get_lineage_config() -> AirflowLineageConfig:
a JSON file path configured as env in OPENMETADATA_LINEAGE_CONFIG
or return a default config.
"""
from airflow.configuration import conf
airflow_service_name = conf.get(LINEAGE, "airflow_service_name", fallback=None)
if airflow_service_name:
return parse_airflow_config(airflow_service_name)
return parse_airflow_config(airflow_service_name, conf=conf)
openmetadata_config_file = os.getenv("OPENMETADATA_LINEAGE_CONFIG")

View File

@ -12,8 +12,9 @@
"""
OpenMetadata Airflow Lineage Backend security providers config
"""
import json
from airflow.configuration import conf
from airflow.configuration import AirflowConfigParser
from airflow_provider_openmetadata.lineage.config.commons import LINEAGE
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
@ -50,7 +51,7 @@ class InvalidAirflowProviderException(Exception):
@provider_config_registry.add(AuthProvider.google.value)
def load_google_auth() -> GoogleSSOClientConfig:
def load_google_auth(conf: AirflowConfigParser) -> GoogleSSOClientConfig:
"""
Load config for Google Auth
"""
@ -63,7 +64,7 @@ def load_google_auth() -> GoogleSSOClientConfig:
@provider_config_registry.add(AuthProvider.okta.value)
def load_okta_auth() -> OktaSSOClientConfig:
def load_okta_auth(conf: AirflowConfigParser) -> OktaSSOClientConfig:
"""
Load config for Okta Auth
"""
@ -72,12 +73,13 @@ def load_okta_auth() -> OktaSSOClientConfig:
orgURL=conf.get(LINEAGE, "org_url"),
privateKey=conf.get(LINEAGE, "private_key"),
email=conf.get(LINEAGE, "email"),
scopes=conf.getjson(LINEAGE, "scopes", fallback=[]),
# conf.getjson is only available for airflow 2.3+. Parsing the JSON manually for lower versions
scopes=json.loads(conf.get(LINEAGE, "scopes", fallback="[]")),
)
@provider_config_registry.add(AuthProvider.auth0.value)
def load_auth0_auth() -> Auth0SSOClientConfig:
def load_auth0_auth(conf: AirflowConfigParser) -> Auth0SSOClientConfig:
"""
Load config for Auth0 Auth
"""
@ -89,7 +91,7 @@ def load_auth0_auth() -> Auth0SSOClientConfig:
@provider_config_registry.add(AuthProvider.azure.value)
def load_azure_auth() -> AzureSSOClientConfig:
def load_azure_auth(conf: AirflowConfigParser) -> AzureSSOClientConfig:
"""
Load config for Azure Auth
"""
@ -97,12 +99,12 @@ def load_azure_auth() -> AzureSSOClientConfig:
clientSecret=conf.get(LINEAGE, "client_secret"),
authority=conf.get(LINEAGE, "authority"),
clientId=conf.get(LINEAGE, "client_id"),
scopes=conf.getjson(LINEAGE, "scopes", fallback=[]),
scopes=json.loads(conf.get(LINEAGE, "scopes", fallback="[]")),
)
@provider_config_registry.add(AuthProvider.openmetadata.value)
def load_om_auth() -> OpenMetadataJWTClientConfig:
def load_om_auth(conf: AirflowConfigParser) -> OpenMetadataJWTClientConfig:
"""
Load config for OpenMetadata JWT Auth
"""
@ -110,7 +112,7 @@ def load_om_auth() -> OpenMetadataJWTClientConfig:
@provider_config_registry.add(AuthProvider.custom_oidc.value)
def load_custom_oidc_auth() -> CustomOIDCSSOClientConfig:
def load_custom_oidc_auth(conf: AirflowConfigParser) -> CustomOIDCSSOClientConfig:
"""
Load config for Custom OIDC Auth
"""

View File

@ -18,7 +18,6 @@ from collections import defaultdict
from logging.config import DictConfigurator
from typing import Dict, List, Optional, Tuple
from sqllineage.exceptions import SQLLineageException
from sqlparse.sql import Comparison, Identifier, Statement
from metadata.config.common import ConfigModel
@ -37,6 +36,7 @@ from metadata.utils.logger import ingestion_logger
configure = DictConfigurator.configure
DictConfigurator.configure = lambda _: None
from sqllineage.core import models
from sqllineage.exceptions import SQLLineageException
from sqllineage.runner import LineageRunner
# Reverting changes after import is done

View File

@ -11,10 +11,9 @@
"""Mode source module"""
import traceback
from logging.config import DictConfigurator
from typing import Iterable, List, Optional
from sqllineage.runner import LineageRunner
from metadata.generated.schema.api.data.createChart import CreateChartRequest
from metadata.generated.schema.api.data.createDashboard import CreateDashboardRequest
from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest
@ -37,10 +36,18 @@ from metadata.ingestion.api.source import InvalidSourceException
from metadata.ingestion.source.dashboard.dashboard_service import DashboardServiceSource
from metadata.utils import fqn, mode_client
from metadata.utils.filters import filter_by_chart
from metadata.utils.helpers import get_chart_entities_from_id
from metadata.utils.logger import ingestion_logger
from metadata.utils.sql_lineage import search_table_entities
# Prevent sqllineage from modifying the logger config:
# DictConfigurator.configure is swapped for a no-op only while
# LineageRunner is being imported.
configure = DictConfigurator.configure
DictConfigurator.configure = lambda _: None
try:
    from sqllineage.runner import LineageRunner
finally:
    # Restore the real method even when the import raises, so a failed
    # import cannot leave the no-op in place of DictConfigurator.configure
    DictConfigurator.configure = configure
logger = ingestion_logger()

View File

View File

@ -0,0 +1,324 @@
# Copyright 2021 Collate
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mock providers and check custom load
"""
from unittest import TestCase
from airflow.configuration import AirflowConfigParser
from airflow_provider_openmetadata.lineage.config.loader import (
AirflowLineageConfig,
parse_airflow_config,
)
from metadata.generated.schema.entity.services.connections.metadata.openMetadataConnection import (
AuthProvider,
OpenMetadataConnection,
)
from metadata.generated.schema.security.client.auth0SSOClientConfig import (
Auth0SSOClientConfig,
)
from metadata.generated.schema.security.client.azureSSOClientConfig import (
AzureSSOClientConfig,
)
from metadata.generated.schema.security.client.customOidcSSOClientConfig import (
CustomOIDCSSOClientConfig,
)
from metadata.generated.schema.security.client.googleSSOClientConfig import (
GoogleSSOClientConfig,
)
from metadata.generated.schema.security.client.oktaSSOClientConfig import (
OktaSSOClientConfig,
)
from metadata.generated.schema.security.client.openMetadataJWTClientConfig import (
OpenMetadataJWTClientConfig,
)
# Service name passed to parse_airflow_config by the tests in this module
AIRFLOW_SERVICE_NAME = "test-service"
class TestAirflowAuthProviders(TestCase):
    """
    Check that parse_airflow_config returns the expected
    AirflowLineageConfig for each auth provider type
    """

    @staticmethod
    def _parse(raw_config):
        """
        Feed an airflow.cfg snippet to a fresh AirflowConfigParser
        and return what parse_airflow_config builds from it
        """
        # stand-in for the airflow conf object
        return parse_airflow_config(
            AIRFLOW_SERVICE_NAME,
            AirflowConfigParser(default_config=raw_config),
        )

    @staticmethod
    def _expected(auth_provider, security_config):
        """
        Build the AirflowLineageConfig the tests compare against.
        Only the auth provider and its security config vary between tests.
        """
        return AirflowLineageConfig(
            airflow_service_name=AIRFLOW_SERVICE_NAME,
            metadata_config=OpenMetadataConnection(
                hostPort="http://localhost:8585/api",
                authProvider=auth_provider,
                securityConfig=security_config,
            ),
        )

    def test_google_sso(self):
        """Google provider configured with a secret_key"""
        raw_config = """
[lineage]
backend = airflow_provider_openmetadata.lineage.openmetadata.OpenMetadataLineageBackend
airflow_service_name = local_airflow
openmetadata_api_endpoint = http://localhost:8585/api
auth_provider_type = google
secret_key = path/to/key
"""
        parsed = self._parse(raw_config)
        self.assertEqual(
            parsed,
            self._expected(
                AuthProvider.google.value,
                GoogleSSOClientConfig(secretKey="path/to/key"),
            ),
        )

    def test_okta_sso(self):
        """Okta provider, with explicit scopes and with scopes left out"""
        with_scopes = """
[lineage]
backend = airflow_provider_openmetadata.lineage.openmetadata.OpenMetadataLineageBackend
airflow_service_name = local_airflow
openmetadata_api_endpoint = http://localhost:8585/api
auth_provider_type = okta
client_id = client_id
org_url = org_url
private_key = private_key
email = email
scopes = ["scope1", "scope2"]
"""
        parsed = self._parse(with_scopes)
        self.assertEqual(
            parsed,
            self._expected(
                AuthProvider.okta.value,
                OktaSSOClientConfig(
                    clientId="client_id",
                    orgURL="org_url",
                    privateKey="private_key",
                    email="email",
                    scopes=["scope1", "scope2"],
                ),
            ),
        )

        # Without a scopes entry the expected value is an empty list
        without_scopes = """
[lineage]
backend = airflow_provider_openmetadata.lineage.openmetadata.OpenMetadataLineageBackend
airflow_service_name = local_airflow
openmetadata_api_endpoint = http://localhost:8585/api
auth_provider_type = okta
client_id = client_id
org_url = org_url
private_key = private_key
email = email
"""
        parsed = self._parse(without_scopes)
        self.assertEqual(
            parsed,
            self._expected(
                AuthProvider.okta.value,
                OktaSSOClientConfig(
                    clientId="client_id",
                    orgURL="org_url",
                    privateKey="private_key",
                    email="email",
                    scopes=[],
                ),
            ),
        )

    def test_auth0_sso(self):
        """Auth0 provider configured with client_id, secret_key and domain"""
        raw_config = """
[lineage]
backend = airflow_provider_openmetadata.lineage.openmetadata.OpenMetadataLineageBackend
airflow_service_name = local_airflow
openmetadata_api_endpoint = http://localhost:8585/api
auth_provider_type = auth0
client_id = client_id
secret_key = secret_key
domain = domain
"""
        parsed = self._parse(raw_config)
        self.assertEqual(
            parsed,
            self._expected(
                AuthProvider.auth0.value,
                Auth0SSOClientConfig(
                    clientId="client_id",
                    secretKey="secret_key",
                    domain="domain",
                ),
            ),
        )

    def test_azure_sso(self):
        """Azure provider, with explicit scopes and with scopes left out"""
        with_scopes = """
[lineage]
backend = airflow_provider_openmetadata.lineage.openmetadata.OpenMetadataLineageBackend
airflow_service_name = local_airflow
openmetadata_api_endpoint = http://localhost:8585/api
auth_provider_type = azure
client_id = client_id
client_secret = client_secret
authority = authority
scopes = ["scope1", "scope2"]
"""
        parsed = self._parse(with_scopes)
        self.assertEqual(
            parsed,
            self._expected(
                AuthProvider.azure.value,
                AzureSSOClientConfig(
                    clientId="client_id",
                    clientSecret="client_secret",
                    authority="authority",
                    scopes=["scope1", "scope2"],
                ),
            ),
        )

        # Without a scopes entry the expected value is an empty list
        without_scopes = """
[lineage]
backend = airflow_provider_openmetadata.lineage.openmetadata.OpenMetadataLineageBackend
airflow_service_name = local_airflow
openmetadata_api_endpoint = http://localhost:8585/api
auth_provider_type = azure
client_id = client_id
client_secret = client_secret
authority = authority
"""
        parsed = self._parse(without_scopes)
        self.assertEqual(
            parsed,
            self._expected(
                AuthProvider.azure.value,
                AzureSSOClientConfig(
                    clientId="client_id",
                    clientSecret="client_secret",
                    authority="authority",
                    scopes=[],
                ),
            ),
        )

    def test_om_sso(self):
        """OpenMetadata provider configured with a jwt_token"""
        raw_config = """
[lineage]
backend = airflow_provider_openmetadata.lineage.openmetadata.OpenMetadataLineageBackend
airflow_service_name = local_airflow
openmetadata_api_endpoint = http://localhost:8585/api
auth_provider_type = openmetadata
jwt_token = jwt_token
"""
        parsed = self._parse(raw_config)
        self.assertEqual(
            parsed,
            self._expected(
                AuthProvider.openmetadata.value,
                OpenMetadataJWTClientConfig(
                    jwtToken="jwt_token",
                ),
            ),
        )

    def test_custom_oidc_sso(self):
        """Custom OIDC provider with client_id, secret_key and token_endpoint"""
        raw_config = """
[lineage]
backend = airflow_provider_openmetadata.lineage.openmetadata.OpenMetadataLineageBackend
airflow_service_name = local_airflow
openmetadata_api_endpoint = http://localhost:8585/api
auth_provider_type = custom-oidc
client_id = client_id
secret_key = secret_key
token_endpoint = token_endpoint
"""
        parsed = self._parse(raw_config)
        self.assertEqual(
            parsed,
            self._expected(
                AuthProvider.custom_oidc.value,
                CustomOIDCSSOClientConfig(
                    clientId="client_id",
                    secretKey="secret_key",
                    tokenEndpoint="token_endpoint",
                ),
            ),
        )