mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-09-21 14:59:57 +00:00
parent
3651efd7f5
commit
8a7fcf0e54
@ -11,11 +11,28 @@
|
|||||||
"""
|
"""
|
||||||
Interface definition for an Auth provider
|
Interface definition for an Auth provider
|
||||||
"""
|
"""
|
||||||
|
import http.client
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
from metadata.config.common import ConfigModel
|
from metadata.config.common import ConfigModel
|
||||||
|
from metadata.generated.schema.metadataIngestion.workflow import (
|
||||||
|
Auth0SSOConfig,
|
||||||
|
CustomOidcSSOConfig,
|
||||||
|
GoogleSSOConfig,
|
||||||
|
OktaSSOConfig,
|
||||||
|
OpenMetadataServerConfig,
|
||||||
|
)
|
||||||
|
from metadata.ingestion.ometa.client import APIError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(init=False) # type: ignore[misc]
|
@dataclass(init=False) # type: ignore[misc]
|
||||||
@ -53,3 +70,291 @@ class AuthenticationProvider(metaclass=ABCMeta):
|
|||||||
Returns:
|
Returns:
|
||||||
str
|
str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class NoOpAuthenticationProvider(AuthenticationProvider):
|
||||||
|
"""
|
||||||
|
Extends AuthenticationProvider class
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (MetadataServerConfig):
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config (MetadataServerConfig)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: OpenMetadataServerConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, config: OpenMetadataServerConfig):
|
||||||
|
return cls(config)
|
||||||
|
|
||||||
|
def auth_token(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_access_token(self):
|
||||||
|
return "no_token", None
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleAuthenticationProvider(AuthenticationProvider):
|
||||||
|
"""
|
||||||
|
Google authentication implementation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (MetadataServerConfig):
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config (MetadataServerConfig)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: OpenMetadataServerConfig):
|
||||||
|
self.config = config
|
||||||
|
self.security_config: GoogleSSOConfig = self.config.securityConfig
|
||||||
|
|
||||||
|
self.generated_auth_token = None
|
||||||
|
self.expiry = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, config: OpenMetadataServerConfig):
|
||||||
|
return cls(config)
|
||||||
|
|
||||||
|
def auth_token(self) -> None:
|
||||||
|
import google.auth
|
||||||
|
import google.auth.transport.requests
|
||||||
|
from google.oauth2 import service_account
|
||||||
|
|
||||||
|
credentials = service_account.IDTokenCredentials.from_service_account_file(
|
||||||
|
self.security_config.secretKey,
|
||||||
|
target_audience=self.security_config.audience,
|
||||||
|
)
|
||||||
|
request = google.auth.transport.requests.Request()
|
||||||
|
credentials.refresh(request)
|
||||||
|
self.generated_auth_token = credentials.token
|
||||||
|
self.expiry = credentials.expiry
|
||||||
|
|
||||||
|
def get_access_token(self):
|
||||||
|
self.auth_token()
|
||||||
|
return self.generated_auth_token, self.expiry
|
||||||
|
|
||||||
|
|
||||||
|
class OktaAuthenticationProvider(AuthenticationProvider):
|
||||||
|
"""
|
||||||
|
Prepare the Json Web Token for Okta auth
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: OpenMetadataServerConfig):
|
||||||
|
self.config = config
|
||||||
|
self.security_config: OktaSSOConfig = self.config.securityConfig
|
||||||
|
|
||||||
|
self.generated_auth_token = None
|
||||||
|
self.expiry = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, config: OpenMetadataServerConfig):
|
||||||
|
return cls(config)
|
||||||
|
|
||||||
|
async def auth_token(self) -> None:
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from urllib.parse import quote, urlencode
|
||||||
|
|
||||||
|
from okta.cache.okta_cache import OktaCache
|
||||||
|
from okta.jwt import JWT, jwt
|
||||||
|
from okta.request_executor import RequestExecutor
|
||||||
|
|
||||||
|
try:
|
||||||
|
my_pem, my_jwk = JWT.get_PEM_JWK(self.security_config.privateKey)
|
||||||
|
issued_time = int(time.time())
|
||||||
|
expiry_time = issued_time + JWT.ONE_HOUR
|
||||||
|
generated_jwt_id = str(uuid.uuid4())
|
||||||
|
claims = {
|
||||||
|
"sub": self.security_config.clientId,
|
||||||
|
"iat": issued_time,
|
||||||
|
"exp": expiry_time,
|
||||||
|
"iss": self.security_config.clientId,
|
||||||
|
"aud": self.security_config.orgURL,
|
||||||
|
"jti": generated_jwt_id,
|
||||||
|
}
|
||||||
|
token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM)
|
||||||
|
config = {
|
||||||
|
"client": {
|
||||||
|
"orgUrl": self.security_config.orgURL,
|
||||||
|
"authorizationMode": "BEARER",
|
||||||
|
"rateLimit": {},
|
||||||
|
"privateKey": self.security_config.privateKey,
|
||||||
|
"clientId": self.security_config.clientId,
|
||||||
|
"token": token,
|
||||||
|
"scopes": self.security_config.scopes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
request_exec = RequestExecutor(
|
||||||
|
config=config, cache=OktaCache(ttl=expiry_time, tti=issued_time)
|
||||||
|
)
|
||||||
|
parameters = {
|
||||||
|
"grant_type": "client_credentials",
|
||||||
|
"scope": " ".join(config["client"]["scopes"]),
|
||||||
|
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||||
|
"client_assertion": token,
|
||||||
|
}
|
||||||
|
encoded_parameters = urlencode(parameters, quote_via=quote)
|
||||||
|
url = f"{self.security_config.orgURL}?" + encoded_parameters
|
||||||
|
token_request_object = await request_exec.create_request(
|
||||||
|
"POST",
|
||||||
|
url,
|
||||||
|
None,
|
||||||
|
{
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
},
|
||||||
|
oauth=True,
|
||||||
|
)
|
||||||
|
_, res_details, res_json, err = await request_exec.fire_request(
|
||||||
|
token_request_object[0]
|
||||||
|
)
|
||||||
|
if err:
|
||||||
|
raise APIError(f"{err}")
|
||||||
|
response_dict = json.loads(res_json)
|
||||||
|
self.generated_auth_token = response_dict.get("access_token")
|
||||||
|
self.expiry = response_dict.get("expires_in")
|
||||||
|
except Exception as err:
|
||||||
|
logger.debug(traceback.print_exc())
|
||||||
|
logger.error(err)
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
def get_access_token(self):
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(self.auth_token())
|
||||||
|
return self.generated_auth_token, self.expiry
|
||||||
|
|
||||||
|
|
||||||
|
class Auth0AuthenticationProvider(AuthenticationProvider):
|
||||||
|
"""
|
||||||
|
OAuth authentication implementation
|
||||||
|
Args:
|
||||||
|
config (MetadataServerConfig):
|
||||||
|
Attributes:
|
||||||
|
config (MetadataServerConfig)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: OpenMetadataServerConfig):
|
||||||
|
self.config = config
|
||||||
|
self.security_config: Auth0SSOConfig = self.config.securityConfig
|
||||||
|
|
||||||
|
self.generated_auth_token = None
|
||||||
|
self.expiry = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, config: OpenMetadataServerConfig):
|
||||||
|
return cls(config)
|
||||||
|
|
||||||
|
def auth_token(self) -> None:
|
||||||
|
conn = http.client.HTTPSConnection(self.security_config.domain)
|
||||||
|
payload = (
|
||||||
|
f"grant_type=client_credentials&client_id={self.security_config.clientId}"
|
||||||
|
f"&client_secret={self.security_config.secretKey}&audience=https://{self.security_config.domain}/api/v2/"
|
||||||
|
)
|
||||||
|
headers = {"content-type": "application/x-www-form-urlencoded"}
|
||||||
|
conn.request(
|
||||||
|
"POST", f"/{self.security_config.domain}/oauth/token", payload, headers
|
||||||
|
)
|
||||||
|
res = conn.getresponse()
|
||||||
|
data = res.read()
|
||||||
|
token = json.loads(data.decode("utf-8"))
|
||||||
|
self.generated_auth_token = token["access_token"]
|
||||||
|
self.expiry = token["expires_in"]
|
||||||
|
|
||||||
|
def get_access_token(self):
|
||||||
|
self.auth_token()
|
||||||
|
return self.generated_auth_token, self.expiry
|
||||||
|
|
||||||
|
|
||||||
|
class AzureAuthenticationProvider(AuthenticationProvider):
|
||||||
|
"""
|
||||||
|
Prepare the Json Web Token for Azure auth
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: Prepare JSON for Azure Auth
|
||||||
|
def __init__(self, config: OpenMetadataServerConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.generated_auth_token = None
|
||||||
|
self.expiry = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, config: OpenMetadataServerConfig):
|
||||||
|
return cls(config)
|
||||||
|
|
||||||
|
def auth_token(self) -> None:
|
||||||
|
from msal import (
|
||||||
|
ConfidentialClientApplication, # pylint: disable=import-outside-toplevel
|
||||||
|
)
|
||||||
|
|
||||||
|
app = ConfidentialClientApplication(
|
||||||
|
client_id=self.config.client_id,
|
||||||
|
client_credential=self.config.secret_key,
|
||||||
|
authority=self.config.authority,
|
||||||
|
)
|
||||||
|
token = app.acquire_token_for_client(scopes=self.config.scopes)
|
||||||
|
try:
|
||||||
|
self.generated_auth_token = token["access_token"]
|
||||||
|
self.expiry = token["expires_in"]
|
||||||
|
|
||||||
|
except KeyError as err:
|
||||||
|
logger.error(f"Invalid Credentials - {err}")
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
logger.debug(traceback.print_exc())
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def get_access_token(self):
|
||||||
|
self.auth_token()
|
||||||
|
return self.generated_auth_token, self.expiry
|
||||||
|
|
||||||
|
|
||||||
|
class CustomOIDCAuthenticationProvider(AuthenticationProvider):
|
||||||
|
"""
|
||||||
|
Custom OIDC authentication implementation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (MetadataServerConfig):
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config (MetadataServerConfig)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: OpenMetadataServerConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.security_config: CustomOidcSSOConfig = self.config.securityConfig
|
||||||
|
|
||||||
|
self.generated_auth_token = None
|
||||||
|
self.expiry = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls, config: OpenMetadataServerConfig
|
||||||
|
) -> "CustomOIDCAuthenticationProvider":
|
||||||
|
return cls(config)
|
||||||
|
|
||||||
|
def auth_token(self) -> None:
|
||||||
|
data = {
|
||||||
|
"grant_type": "client_credentials",
|
||||||
|
"client_id": self.security_config.clientId,
|
||||||
|
"client_secret": self.security_config.secretKey,
|
||||||
|
}
|
||||||
|
response = requests.post(
|
||||||
|
url=self.security_config.tokenEndpoint,
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
|
if response.ok:
|
||||||
|
response_json = response.json()
|
||||||
|
self.generated_auth_token = response_json["access_token"]
|
||||||
|
self.expiry = response_json["expires_in"]
|
||||||
|
else:
|
||||||
|
raise APIError(
|
||||||
|
error={"message": response.text}, http_error=response.status_code
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_access_token(self) -> Tuple[str, int]:
|
||||||
|
self.auth_token()
|
||||||
|
return self.generated_auth_token, self.expiry
|
||||||
|
@ -60,13 +60,9 @@ from metadata.ingestion.ometa.mixins.pipeline_mixin import OMetaPipelineMixin
|
|||||||
from metadata.ingestion.ometa.mixins.table_mixin import OMetaTableMixin
|
from metadata.ingestion.ometa.mixins.table_mixin import OMetaTableMixin
|
||||||
from metadata.ingestion.ometa.mixins.tag_mixin import OMetaTagMixin
|
from metadata.ingestion.ometa.mixins.tag_mixin import OMetaTagMixin
|
||||||
from metadata.ingestion.ometa.mixins.version_mixin import OMetaVersionMixin
|
from metadata.ingestion.ometa.mixins.version_mixin import OMetaVersionMixin
|
||||||
from metadata.ingestion.ometa.openmetadata_rest import (
|
from metadata.ingestion.ometa.provider_registry import (
|
||||||
Auth0AuthenticationProvider,
|
InvalidAuthProviderException,
|
||||||
AzureAuthenticationProvider,
|
auth_provider_registry,
|
||||||
CustomOIDCAuthenticationProvider,
|
|
||||||
GoogleAuthenticationProvider,
|
|
||||||
NoOpAuthenticationProvider,
|
|
||||||
OktaAuthenticationProvider,
|
|
||||||
)
|
)
|
||||||
from metadata.ingestion.ometa.utils import get_entity_type, model_str
|
from metadata.ingestion.ometa.utils import get_entity_type, model_str
|
||||||
|
|
||||||
@ -142,30 +138,18 @@ class OpenMetadata(
|
|||||||
|
|
||||||
def __init__(self, config: OpenMetadataServerConfig, raw_data: bool = False):
|
def __init__(self, config: OpenMetadataServerConfig, raw_data: bool = False):
|
||||||
self.config = config
|
self.config = config
|
||||||
if self.config.authProvider.value == "google":
|
|
||||||
self._auth_provider: AuthenticationProvider = (
|
# Load the auth provider init from the registry
|
||||||
GoogleAuthenticationProvider.create(self.config)
|
auth_provider_fn = auth_provider_registry.registry.get(
|
||||||
)
|
self.config.authProvider.value
|
||||||
elif self.config.authProvider.value == "okta":
|
)
|
||||||
self._auth_provider: AuthenticationProvider = (
|
if not auth_provider_fn:
|
||||||
OktaAuthenticationProvider.create(self.config)
|
raise InvalidAuthProviderException(
|
||||||
)
|
f"Cannot find {self.config.authProvider.value} in {auth_provider_registry.registry}"
|
||||||
elif self.config.authProvider.value == "auth0":
|
|
||||||
self._auth_provider: AuthenticationProvider = (
|
|
||||||
Auth0AuthenticationProvider.create(self.config)
|
|
||||||
)
|
|
||||||
elif self.config.authProvider.value == "azure":
|
|
||||||
self._auth_provider: AuthenticationProvider = (
|
|
||||||
AzureAuthenticationProvider.create(self.config)
|
|
||||||
)
|
|
||||||
elif self.config.authProvider.value == "custom-oidc":
|
|
||||||
self._auth_provider: AuthenticationProvider = (
|
|
||||||
CustomOIDCAuthenticationProvider.create(self.config)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._auth_provider: AuthenticationProvider = (
|
|
||||||
NoOpAuthenticationProvider.create(self.config)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._auth_provider = auth_provider_fn(self.config)
|
||||||
|
|
||||||
client_config: ClientConfig = ClientConfig(
|
client_config: ClientConfig = ClientConfig(
|
||||||
base_url=self.config.hostPort,
|
base_url=self.config.hostPort,
|
||||||
api_version=self.config.apiVersion,
|
api_version=self.config.apiVersion,
|
||||||
|
@ -1,375 +0,0 @@
|
|||||||
# Copyright 2021 Collate
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
Helper classes to model OpenMetadata Entities,
|
|
||||||
server configuration and auth.
|
|
||||||
"""
|
|
||||||
import http.client
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from metadata.generated.schema.entity.data.dashboard import Dashboard
|
|
||||||
from metadata.generated.schema.entity.data.database import Database
|
|
||||||
from metadata.generated.schema.entity.data.pipeline import Pipeline
|
|
||||||
from metadata.generated.schema.entity.data.table import Table, TableProfile
|
|
||||||
from metadata.generated.schema.entity.data.topic import Topic
|
|
||||||
from metadata.generated.schema.entity.services.databaseService import DatabaseService
|
|
||||||
from metadata.generated.schema.entity.tags.tagCategory import Tag
|
|
||||||
from metadata.generated.schema.metadataIngestion.workflow import (
|
|
||||||
Auth0SSOConfig,
|
|
||||||
CustomOidcSSOConfig,
|
|
||||||
GoogleSSOConfig,
|
|
||||||
OktaSSOConfig,
|
|
||||||
OpenMetadataServerConfig,
|
|
||||||
)
|
|
||||||
from metadata.ingestion.ometa.auth_provider import AuthenticationProvider
|
|
||||||
from metadata.ingestion.ometa.client import APIError
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
DatabaseServiceEntities = List[DatabaseService]
|
|
||||||
DatabaseEntities = List[Database]
|
|
||||||
Tags = List[Tag]
|
|
||||||
TableProfiles = List[TableProfile]
|
|
||||||
|
|
||||||
|
|
||||||
class TableEntities(BaseModel):
|
|
||||||
"""
|
|
||||||
Table entity pydantic model
|
|
||||||
"""
|
|
||||||
|
|
||||||
tables: List[Table]
|
|
||||||
total: int
|
|
||||||
after: str = None
|
|
||||||
|
|
||||||
|
|
||||||
class TopicEntities(BaseModel):
|
|
||||||
"""
|
|
||||||
Topic entity pydantic model
|
|
||||||
"""
|
|
||||||
|
|
||||||
topics: List[Topic]
|
|
||||||
total: int
|
|
||||||
after: str = None
|
|
||||||
|
|
||||||
|
|
||||||
class DashboardEntities(BaseModel):
|
|
||||||
"""
|
|
||||||
Dashboard entity pydantic model
|
|
||||||
"""
|
|
||||||
|
|
||||||
dashboards: List[Dashboard]
|
|
||||||
total: int
|
|
||||||
after: str = None
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineEntities(BaseModel):
|
|
||||||
"""
|
|
||||||
Pipeline entity pydantic model
|
|
||||||
"""
|
|
||||||
|
|
||||||
pipelines: List[Pipeline]
|
|
||||||
total: int
|
|
||||||
after: str = None
|
|
||||||
|
|
||||||
|
|
||||||
class NoOpAuthenticationProvider(AuthenticationProvider):
|
|
||||||
"""
|
|
||||||
Extends AuthenticationProvider class
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (MetadataServerConfig):
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
config (MetadataServerConfig)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: OpenMetadataServerConfig):
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(cls, config: OpenMetadataServerConfig):
|
|
||||||
return cls(config)
|
|
||||||
|
|
||||||
def auth_token(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_access_token(self):
|
|
||||||
return ("no_token", None)
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleAuthenticationProvider(AuthenticationProvider):
|
|
||||||
"""
|
|
||||||
Google authentication implementation
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (MetadataServerConfig):
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
config (MetadataServerConfig)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: OpenMetadataServerConfig):
|
|
||||||
self.config = config
|
|
||||||
self.security_config: GoogleSSOConfig = self.config.securityConfig
|
|
||||||
|
|
||||||
self.generated_auth_token = None
|
|
||||||
self.expiry = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(cls, config: OpenMetadataServerConfig):
|
|
||||||
return cls(config)
|
|
||||||
|
|
||||||
def auth_token(self) -> None:
|
|
||||||
import google.auth
|
|
||||||
import google.auth.transport.requests
|
|
||||||
from google.oauth2 import service_account
|
|
||||||
|
|
||||||
credentials = service_account.IDTokenCredentials.from_service_account_file(
|
|
||||||
self.security_config.secretKey,
|
|
||||||
target_audience=self.security_config.audience,
|
|
||||||
)
|
|
||||||
request = google.auth.transport.requests.Request()
|
|
||||||
credentials.refresh(request)
|
|
||||||
self.generated_auth_token = credentials.token
|
|
||||||
self.expiry = credentials.expiry
|
|
||||||
|
|
||||||
def get_access_token(self):
|
|
||||||
self.auth_token()
|
|
||||||
return self.generated_auth_token, self.expiry
|
|
||||||
|
|
||||||
|
|
||||||
class OktaAuthenticationProvider(AuthenticationProvider):
|
|
||||||
"""
|
|
||||||
Prepare the Json Web Token for Okta auth
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: OpenMetadataServerConfig):
|
|
||||||
self.config = config
|
|
||||||
self.security_config: OktaSSOConfig = self.config.securityConfig
|
|
||||||
|
|
||||||
self.generated_auth_token = None
|
|
||||||
self.expiry = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(cls, config: OpenMetadataServerConfig):
|
|
||||||
return cls(config)
|
|
||||||
|
|
||||||
async def auth_token(self) -> None:
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from urllib.parse import quote, urlencode
|
|
||||||
|
|
||||||
from okta.cache.okta_cache import OktaCache
|
|
||||||
from okta.jwt import JWT, jwt
|
|
||||||
from okta.request_executor import RequestExecutor
|
|
||||||
|
|
||||||
try:
|
|
||||||
my_pem, my_jwk = JWT.get_PEM_JWK(self.security_config.privateKey)
|
|
||||||
issued_time = int(time.time())
|
|
||||||
expiry_time = issued_time + JWT.ONE_HOUR
|
|
||||||
generated_jwt_id = str(uuid.uuid4())
|
|
||||||
claims = {
|
|
||||||
"sub": self.security_config.clientId,
|
|
||||||
"iat": issued_time,
|
|
||||||
"exp": expiry_time,
|
|
||||||
"iss": self.security_config.clientId,
|
|
||||||
"aud": self.security_config.orgURL,
|
|
||||||
"jti": generated_jwt_id,
|
|
||||||
}
|
|
||||||
token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM)
|
|
||||||
config = {
|
|
||||||
"client": {
|
|
||||||
"orgUrl": self.security_config.orgURL,
|
|
||||||
"authorizationMode": "BEARER",
|
|
||||||
"rateLimit": {},
|
|
||||||
"privateKey": self.security_config.privateKey,
|
|
||||||
"clientId": self.security_config.clientId,
|
|
||||||
"token": token,
|
|
||||||
"scopes": self.security_config.scopes,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
request_exec = RequestExecutor(
|
|
||||||
config=config, cache=OktaCache(ttl=expiry_time, tti=issued_time)
|
|
||||||
)
|
|
||||||
parameters = {
|
|
||||||
"grant_type": "client_credentials",
|
|
||||||
"scope": " ".join(config["client"]["scopes"]),
|
|
||||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
|
||||||
"client_assertion": token,
|
|
||||||
}
|
|
||||||
encoded_parameters = urlencode(parameters, quote_via=quote)
|
|
||||||
url = f"{self.security_config.orgURL}?" + encoded_parameters
|
|
||||||
token_request_object = await request_exec.create_request(
|
|
||||||
"POST",
|
|
||||||
url,
|
|
||||||
None,
|
|
||||||
{
|
|
||||||
"Accept": "application/json",
|
|
||||||
"Content-Type": "application/x-www-form-urlencoded",
|
|
||||||
},
|
|
||||||
oauth=True,
|
|
||||||
)
|
|
||||||
_, res_details, res_json, err = await request_exec.fire_request(
|
|
||||||
token_request_object[0]
|
|
||||||
)
|
|
||||||
if err:
|
|
||||||
raise APIError(f"{err}")
|
|
||||||
response_dict = json.loads(res_json)
|
|
||||||
self.generated_auth_token = response_dict.get("access_token")
|
|
||||||
self.expiry = response_dict.get("expires_in")
|
|
||||||
except Exception as err:
|
|
||||||
logger.debug(traceback.print_exc())
|
|
||||||
logger.error(err)
|
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
def get_access_token(self):
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.run(self.auth_token())
|
|
||||||
return self.generated_auth_token, self.expiry
|
|
||||||
|
|
||||||
|
|
||||||
class Auth0AuthenticationProvider(AuthenticationProvider):
|
|
||||||
"""
|
|
||||||
OAuth authentication implementation
|
|
||||||
Args:
|
|
||||||
config (MetadataServerConfig):
|
|
||||||
Attributes:
|
|
||||||
config (MetadataServerConfig)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: OpenMetadataServerConfig):
|
|
||||||
self.config = config
|
|
||||||
self.security_config: Auth0SSOConfig = self.config.securityConfig
|
|
||||||
|
|
||||||
self.generated_auth_token = None
|
|
||||||
self.expiry = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(cls, config: OpenMetadataServerConfig):
|
|
||||||
return cls(config)
|
|
||||||
|
|
||||||
def auth_token(self) -> None:
|
|
||||||
conn = http.client.HTTPSConnection(self.security_config.domain)
|
|
||||||
payload = (
|
|
||||||
f"grant_type=client_credentials&client_id={self.security_config.clientId}"
|
|
||||||
f"&client_secret={self.security_config.secretKey}&audience=https://{self.security_config.domain}/api/v2/"
|
|
||||||
)
|
|
||||||
headers = {"content-type": "application/x-www-form-urlencoded"}
|
|
||||||
conn.request(
|
|
||||||
"POST", f"/{self.security_config.domain}/oauth/token", payload, headers
|
|
||||||
)
|
|
||||||
res = conn.getresponse()
|
|
||||||
data = res.read()
|
|
||||||
token = json.loads(data.decode("utf-8"))
|
|
||||||
self.generated_auth_token = token["access_token"]
|
|
||||||
self.expiry = token["expires_in"]
|
|
||||||
|
|
||||||
def get_access_token(self):
|
|
||||||
self.auth_token()
|
|
||||||
return self.generated_auth_token, self.expiry
|
|
||||||
|
|
||||||
|
|
||||||
class AzureAuthenticationProvider(AuthenticationProvider):
|
|
||||||
"""
|
|
||||||
Prepare the Json Web Token for Azure auth
|
|
||||||
"""
|
|
||||||
|
|
||||||
# TODO: Prepare JSON for Azure Auth
|
|
||||||
def __init__(self, config: OpenMetadataServerConfig):
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
self.generated_auth_token = None
|
|
||||||
self.expiry = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(cls, config: OpenMetadataServerConfig):
|
|
||||||
return cls(config)
|
|
||||||
|
|
||||||
def auth_token(self) -> None:
|
|
||||||
from msal import (
|
|
||||||
ConfidentialClientApplication, # pylint: disable=import-outside-toplevel
|
|
||||||
)
|
|
||||||
|
|
||||||
app = ConfidentialClientApplication(
|
|
||||||
client_id=self.config.client_id,
|
|
||||||
client_credential=self.config.secret_key,
|
|
||||||
authority=self.config.authority,
|
|
||||||
)
|
|
||||||
token = app.acquire_token_for_client(scopes=self.config.scopes)
|
|
||||||
try:
|
|
||||||
self.generated_auth_token = token["access_token"]
|
|
||||||
self.expiry = token["expires_in"]
|
|
||||||
|
|
||||||
except KeyError as err:
|
|
||||||
logger.error(f"Invalid Credentials - {err}")
|
|
||||||
logger.debug(traceback.format_exc())
|
|
||||||
logger.debug(traceback.print_exc())
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
def get_access_token(self):
|
|
||||||
self.auth_token()
|
|
||||||
return self.generated_auth_token, self.expiry
|
|
||||||
|
|
||||||
|
|
||||||
class CustomOIDCAuthenticationProvider(AuthenticationProvider):
|
|
||||||
"""
|
|
||||||
Custom OIDC authentication implementation
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (MetadataServerConfig):
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
config (MetadataServerConfig)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: OpenMetadataServerConfig) -> None:
|
|
||||||
self.config = config
|
|
||||||
self.security_config: CustomOidcSSOConfig = self.config.securityConfig
|
|
||||||
|
|
||||||
self.generated_auth_token = None
|
|
||||||
self.expiry = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(
|
|
||||||
cls, config: OpenMetadataServerConfig
|
|
||||||
) -> "CustomOIDCAuthenticationProvider":
|
|
||||||
return cls(config)
|
|
||||||
|
|
||||||
def auth_token(self) -> None:
|
|
||||||
data = {
|
|
||||||
"grant_type": "client_credentials",
|
|
||||||
"client_id": self.security_config.clientId,
|
|
||||||
"client_secret": self.security_config.secretKey,
|
|
||||||
}
|
|
||||||
response = requests.post(
|
|
||||||
url=self.security_config.tokenEndpoint,
|
|
||||||
data=data,
|
|
||||||
)
|
|
||||||
if response.ok:
|
|
||||||
response_json = response.json()
|
|
||||||
self.generated_auth_token = response_json["access_token"]
|
|
||||||
self.expiry = response_json["expires_in"]
|
|
||||||
else:
|
|
||||||
raise APIError(
|
|
||||||
error={"message": response.text}, http_error=response.status_code
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_access_token(self) -> Tuple[str, int]:
|
|
||||||
self.auth_token()
|
|
||||||
return self.generated_auth_token, self.expiry
|
|
67
ingestion/src/metadata/ingestion/ometa/provider_registry.py
Normal file
67
ingestion/src/metadata/ingestion/ometa/provider_registry.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
# Copyright 2021 Collate
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Register auth provider init functions here
|
||||||
|
"""
|
||||||
|
from metadata.generated.schema.metadataIngestion.workflow import (
|
||||||
|
AuthProvider,
|
||||||
|
OpenMetadataServerConfig,
|
||||||
|
)
|
||||||
|
from metadata.ingestion.ometa.auth_provider import (
|
||||||
|
Auth0AuthenticationProvider,
|
||||||
|
AuthenticationProvider,
|
||||||
|
AzureAuthenticationProvider,
|
||||||
|
CustomOIDCAuthenticationProvider,
|
||||||
|
GoogleAuthenticationProvider,
|
||||||
|
NoOpAuthenticationProvider,
|
||||||
|
OktaAuthenticationProvider,
|
||||||
|
)
|
||||||
|
from metadata.utils.dispatch import enum_register
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidAuthProviderException(Exception):
|
||||||
|
"""
|
||||||
|
Raised when we cannot find a valid auth provider
|
||||||
|
in the registry
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
auth_provider_registry = enum_register()
|
||||||
|
|
||||||
|
|
||||||
|
@auth_provider_registry.add(AuthProvider.no_auth.value)
|
||||||
|
def no_auth_init(config: OpenMetadataServerConfig) -> AuthenticationProvider:
|
||||||
|
return NoOpAuthenticationProvider.create(config)
|
||||||
|
|
||||||
|
|
||||||
|
@auth_provider_registry.add(AuthProvider.google.value)
|
||||||
|
def google_auth_init(config: OpenMetadataServerConfig) -> AuthenticationProvider:
|
||||||
|
return GoogleAuthenticationProvider.create(config)
|
||||||
|
|
||||||
|
|
||||||
|
@auth_provider_registry.add(AuthProvider.okta.value)
|
||||||
|
def okta_auth_init(config: OpenMetadataServerConfig) -> AuthenticationProvider:
|
||||||
|
return OktaAuthenticationProvider.create(config)
|
||||||
|
|
||||||
|
|
||||||
|
@auth_provider_registry.add(AuthProvider.auth0.value)
|
||||||
|
def auth0_auth_init(config: OpenMetadataServerConfig) -> AuthenticationProvider:
|
||||||
|
return Auth0AuthenticationProvider.create(config)
|
||||||
|
|
||||||
|
|
||||||
|
@auth_provider_registry.add("azure") # TODO: update JSON
|
||||||
|
def azure_auth_init(config: OpenMetadataServerConfig) -> AuthenticationProvider:
|
||||||
|
return AzureAuthenticationProvider.create(config)
|
||||||
|
|
||||||
|
|
||||||
|
@auth_provider_registry.add(AuthProvider.custom_oidc.value)
|
||||||
|
def custom_oidc_auth_init(config: OpenMetadataServerConfig) -> AuthenticationProvider:
|
||||||
|
return CustomOIDCAuthenticationProvider.create(config)
|
33
ingestion/src/metadata/utils/dispatch.py
Normal file
33
ingestion/src/metadata/utils/dispatch.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
# Copyright 2021 Collate
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Helper that implements custom dispatcher logic
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
|
||||||
|
def enum_register():
|
||||||
|
"""
|
||||||
|
Helps us register custom function for enum values
|
||||||
|
"""
|
||||||
|
registry = dict()
|
||||||
|
|
||||||
|
def add(name: str):
|
||||||
|
def inner(fn):
|
||||||
|
registry[name] = fn
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
Register = namedtuple("Register", ["add", "registry"])
|
||||||
|
return Register(add, registry)
|
Loading…
x
Reference in New Issue
Block a user