mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-09-21 14:59:57 +00:00
parent
3651efd7f5
commit
8a7fcf0e54
@ -11,11 +11,28 @@
|
||||
"""
|
||||
Interface definition for an Auth provider
|
||||
"""
|
||||
|
||||
import http.client
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from metadata.config.common import ConfigModel
|
||||
from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
Auth0SSOConfig,
|
||||
CustomOidcSSOConfig,
|
||||
GoogleSSOConfig,
|
||||
OktaSSOConfig,
|
||||
OpenMetadataServerConfig,
|
||||
)
|
||||
from metadata.ingestion.ometa.client import APIError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(init=False) # type: ignore[misc]
|
||||
@ -53,3 +70,291 @@ class AuthenticationProvider(metaclass=ABCMeta):
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
|
||||
|
||||
class NoOpAuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
Extends AuthenticationProvider class
|
||||
|
||||
Args:
|
||||
config (MetadataServerConfig):
|
||||
|
||||
Attributes:
|
||||
config (MetadataServerConfig)
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenMetadataServerConfig):
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: OpenMetadataServerConfig):
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self):
|
||||
pass
|
||||
|
||||
def get_access_token(self):
|
||||
return "no_token", None
|
||||
|
||||
|
||||
class GoogleAuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
Google authentication implementation
|
||||
|
||||
Args:
|
||||
config (MetadataServerConfig):
|
||||
|
||||
Attributes:
|
||||
config (MetadataServerConfig)
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenMetadataServerConfig):
|
||||
self.config = config
|
||||
self.security_config: GoogleSSOConfig = self.config.securityConfig
|
||||
|
||||
self.generated_auth_token = None
|
||||
self.expiry = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: OpenMetadataServerConfig):
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self) -> None:
|
||||
import google.auth
|
||||
import google.auth.transport.requests
|
||||
from google.oauth2 import service_account
|
||||
|
||||
credentials = service_account.IDTokenCredentials.from_service_account_file(
|
||||
self.security_config.secretKey,
|
||||
target_audience=self.security_config.audience,
|
||||
)
|
||||
request = google.auth.transport.requests.Request()
|
||||
credentials.refresh(request)
|
||||
self.generated_auth_token = credentials.token
|
||||
self.expiry = credentials.expiry
|
||||
|
||||
def get_access_token(self):
|
||||
self.auth_token()
|
||||
return self.generated_auth_token, self.expiry
|
||||
|
||||
|
||||
class OktaAuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
Prepare the Json Web Token for Okta auth
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenMetadataServerConfig):
|
||||
self.config = config
|
||||
self.security_config: OktaSSOConfig = self.config.securityConfig
|
||||
|
||||
self.generated_auth_token = None
|
||||
self.expiry = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: OpenMetadataServerConfig):
|
||||
return cls(config)
|
||||
|
||||
async def auth_token(self) -> None:
|
||||
import time
|
||||
import uuid
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
from okta.cache.okta_cache import OktaCache
|
||||
from okta.jwt import JWT, jwt
|
||||
from okta.request_executor import RequestExecutor
|
||||
|
||||
try:
|
||||
my_pem, my_jwk = JWT.get_PEM_JWK(self.security_config.privateKey)
|
||||
issued_time = int(time.time())
|
||||
expiry_time = issued_time + JWT.ONE_HOUR
|
||||
generated_jwt_id = str(uuid.uuid4())
|
||||
claims = {
|
||||
"sub": self.security_config.clientId,
|
||||
"iat": issued_time,
|
||||
"exp": expiry_time,
|
||||
"iss": self.security_config.clientId,
|
||||
"aud": self.security_config.orgURL,
|
||||
"jti": generated_jwt_id,
|
||||
}
|
||||
token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM)
|
||||
config = {
|
||||
"client": {
|
||||
"orgUrl": self.security_config.orgURL,
|
||||
"authorizationMode": "BEARER",
|
||||
"rateLimit": {},
|
||||
"privateKey": self.security_config.privateKey,
|
||||
"clientId": self.security_config.clientId,
|
||||
"token": token,
|
||||
"scopes": self.security_config.scopes,
|
||||
}
|
||||
}
|
||||
request_exec = RequestExecutor(
|
||||
config=config, cache=OktaCache(ttl=expiry_time, tti=issued_time)
|
||||
)
|
||||
parameters = {
|
||||
"grant_type": "client_credentials",
|
||||
"scope": " ".join(config["client"]["scopes"]),
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"client_assertion": token,
|
||||
}
|
||||
encoded_parameters = urlencode(parameters, quote_via=quote)
|
||||
url = f"{self.security_config.orgURL}?" + encoded_parameters
|
||||
token_request_object = await request_exec.create_request(
|
||||
"POST",
|
||||
url,
|
||||
None,
|
||||
{
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
oauth=True,
|
||||
)
|
||||
_, res_details, res_json, err = await request_exec.fire_request(
|
||||
token_request_object[0]
|
||||
)
|
||||
if err:
|
||||
raise APIError(f"{err}")
|
||||
response_dict = json.loads(res_json)
|
||||
self.generated_auth_token = response_dict.get("access_token")
|
||||
self.expiry = response_dict.get("expires_in")
|
||||
except Exception as err:
|
||||
logger.debug(traceback.print_exc())
|
||||
logger.error(err)
|
||||
sys.exit()
|
||||
|
||||
def get_access_token(self):
|
||||
import asyncio
|
||||
|
||||
asyncio.run(self.auth_token())
|
||||
return self.generated_auth_token, self.expiry
|
||||
|
||||
|
||||
class Auth0AuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
OAuth authentication implementation
|
||||
Args:
|
||||
config (MetadataServerConfig):
|
||||
Attributes:
|
||||
config (MetadataServerConfig)
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenMetadataServerConfig):
|
||||
self.config = config
|
||||
self.security_config: Auth0SSOConfig = self.config.securityConfig
|
||||
|
||||
self.generated_auth_token = None
|
||||
self.expiry = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: OpenMetadataServerConfig):
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self) -> None:
|
||||
conn = http.client.HTTPSConnection(self.security_config.domain)
|
||||
payload = (
|
||||
f"grant_type=client_credentials&client_id={self.security_config.clientId}"
|
||||
f"&client_secret={self.security_config.secretKey}&audience=https://{self.security_config.domain}/api/v2/"
|
||||
)
|
||||
headers = {"content-type": "application/x-www-form-urlencoded"}
|
||||
conn.request(
|
||||
"POST", f"/{self.security_config.domain}/oauth/token", payload, headers
|
||||
)
|
||||
res = conn.getresponse()
|
||||
data = res.read()
|
||||
token = json.loads(data.decode("utf-8"))
|
||||
self.generated_auth_token = token["access_token"]
|
||||
self.expiry = token["expires_in"]
|
||||
|
||||
def get_access_token(self):
|
||||
self.auth_token()
|
||||
return self.generated_auth_token, self.expiry
|
||||
|
||||
|
||||
class AzureAuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
Prepare the Json Web Token for Azure auth
|
||||
"""
|
||||
|
||||
# TODO: Prepare JSON for Azure Auth
|
||||
def __init__(self, config: OpenMetadataServerConfig):
|
||||
self.config = config
|
||||
|
||||
self.generated_auth_token = None
|
||||
self.expiry = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: OpenMetadataServerConfig):
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self) -> None:
|
||||
from msal import (
|
||||
ConfidentialClientApplication, # pylint: disable=import-outside-toplevel
|
||||
)
|
||||
|
||||
app = ConfidentialClientApplication(
|
||||
client_id=self.config.client_id,
|
||||
client_credential=self.config.secret_key,
|
||||
authority=self.config.authority,
|
||||
)
|
||||
token = app.acquire_token_for_client(scopes=self.config.scopes)
|
||||
try:
|
||||
self.generated_auth_token = token["access_token"]
|
||||
self.expiry = token["expires_in"]
|
||||
|
||||
except KeyError as err:
|
||||
logger.error(f"Invalid Credentials - {err}")
|
||||
logger.debug(traceback.format_exc())
|
||||
logger.debug(traceback.print_exc())
|
||||
sys.exit(1)
|
||||
|
||||
def get_access_token(self):
|
||||
self.auth_token()
|
||||
return self.generated_auth_token, self.expiry
|
||||
|
||||
|
||||
class CustomOIDCAuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
Custom OIDC authentication implementation
|
||||
|
||||
Args:
|
||||
config (MetadataServerConfig):
|
||||
|
||||
Attributes:
|
||||
config (MetadataServerConfig)
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenMetadataServerConfig) -> None:
|
||||
self.config = config
|
||||
self.security_config: CustomOidcSSOConfig = self.config.securityConfig
|
||||
|
||||
self.generated_auth_token = None
|
||||
self.expiry = None
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, config: OpenMetadataServerConfig
|
||||
) -> "CustomOIDCAuthenticationProvider":
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self) -> None:
|
||||
data = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.security_config.clientId,
|
||||
"client_secret": self.security_config.secretKey,
|
||||
}
|
||||
response = requests.post(
|
||||
url=self.security_config.tokenEndpoint,
|
||||
data=data,
|
||||
)
|
||||
if response.ok:
|
||||
response_json = response.json()
|
||||
self.generated_auth_token = response_json["access_token"]
|
||||
self.expiry = response_json["expires_in"]
|
||||
else:
|
||||
raise APIError(
|
||||
error={"message": response.text}, http_error=response.status_code
|
||||
)
|
||||
|
||||
def get_access_token(self) -> Tuple[str, int]:
|
||||
self.auth_token()
|
||||
return self.generated_auth_token, self.expiry
|
||||
|
@ -60,13 +60,9 @@ from metadata.ingestion.ometa.mixins.pipeline_mixin import OMetaPipelineMixin
|
||||
from metadata.ingestion.ometa.mixins.table_mixin import OMetaTableMixin
|
||||
from metadata.ingestion.ometa.mixins.tag_mixin import OMetaTagMixin
|
||||
from metadata.ingestion.ometa.mixins.version_mixin import OMetaVersionMixin
|
||||
from metadata.ingestion.ometa.openmetadata_rest import (
|
||||
Auth0AuthenticationProvider,
|
||||
AzureAuthenticationProvider,
|
||||
CustomOIDCAuthenticationProvider,
|
||||
GoogleAuthenticationProvider,
|
||||
NoOpAuthenticationProvider,
|
||||
OktaAuthenticationProvider,
|
||||
from metadata.ingestion.ometa.provider_registry import (
|
||||
InvalidAuthProviderException,
|
||||
auth_provider_registry,
|
||||
)
|
||||
from metadata.ingestion.ometa.utils import get_entity_type, model_str
|
||||
|
||||
@ -142,30 +138,18 @@ class OpenMetadata(
|
||||
|
||||
def __init__(self, config: OpenMetadataServerConfig, raw_data: bool = False):
|
||||
self.config = config
|
||||
if self.config.authProvider.value == "google":
|
||||
self._auth_provider: AuthenticationProvider = (
|
||||
GoogleAuthenticationProvider.create(self.config)
|
||||
)
|
||||
elif self.config.authProvider.value == "okta":
|
||||
self._auth_provider: AuthenticationProvider = (
|
||||
OktaAuthenticationProvider.create(self.config)
|
||||
)
|
||||
elif self.config.authProvider.value == "auth0":
|
||||
self._auth_provider: AuthenticationProvider = (
|
||||
Auth0AuthenticationProvider.create(self.config)
|
||||
)
|
||||
elif self.config.authProvider.value == "azure":
|
||||
self._auth_provider: AuthenticationProvider = (
|
||||
AzureAuthenticationProvider.create(self.config)
|
||||
)
|
||||
elif self.config.authProvider.value == "custom-oidc":
|
||||
self._auth_provider: AuthenticationProvider = (
|
||||
CustomOIDCAuthenticationProvider.create(self.config)
|
||||
)
|
||||
else:
|
||||
self._auth_provider: AuthenticationProvider = (
|
||||
NoOpAuthenticationProvider.create(self.config)
|
||||
|
||||
# Load the auth provider init from the registry
|
||||
auth_provider_fn = auth_provider_registry.registry.get(
|
||||
self.config.authProvider.value
|
||||
)
|
||||
if not auth_provider_fn:
|
||||
raise InvalidAuthProviderException(
|
||||
f"Cannot find {self.config.authProvider.value} in {auth_provider_registry.registry}"
|
||||
)
|
||||
|
||||
self._auth_provider = auth_provider_fn(self.config)
|
||||
|
||||
client_config: ClientConfig = ClientConfig(
|
||||
base_url=self.config.hostPort,
|
||||
api_version=self.config.apiVersion,
|
||||
|
@ -1,375 +0,0 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Helper classes to model OpenMetadata Entities,
|
||||
server configuration and auth.
|
||||
"""
|
||||
import http.client
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List, Tuple
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metadata.generated.schema.entity.data.dashboard import Dashboard
|
||||
from metadata.generated.schema.entity.data.database import Database
|
||||
from metadata.generated.schema.entity.data.pipeline import Pipeline
|
||||
from metadata.generated.schema.entity.data.table import Table, TableProfile
|
||||
from metadata.generated.schema.entity.data.topic import Topic
|
||||
from metadata.generated.schema.entity.services.databaseService import DatabaseService
|
||||
from metadata.generated.schema.entity.tags.tagCategory import Tag
|
||||
from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
Auth0SSOConfig,
|
||||
CustomOidcSSOConfig,
|
||||
GoogleSSOConfig,
|
||||
OktaSSOConfig,
|
||||
OpenMetadataServerConfig,
|
||||
)
|
||||
from metadata.ingestion.ometa.auth_provider import AuthenticationProvider
|
||||
from metadata.ingestion.ometa.client import APIError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DatabaseServiceEntities = List[DatabaseService]
|
||||
DatabaseEntities = List[Database]
|
||||
Tags = List[Tag]
|
||||
TableProfiles = List[TableProfile]
|
||||
|
||||
|
||||
class TableEntities(BaseModel):
|
||||
"""
|
||||
Table entity pydantic model
|
||||
"""
|
||||
|
||||
tables: List[Table]
|
||||
total: int
|
||||
after: str = None
|
||||
|
||||
|
||||
class TopicEntities(BaseModel):
|
||||
"""
|
||||
Topic entity pydantic model
|
||||
"""
|
||||
|
||||
topics: List[Topic]
|
||||
total: int
|
||||
after: str = None
|
||||
|
||||
|
||||
class DashboardEntities(BaseModel):
|
||||
"""
|
||||
Dashboard entity pydantic model
|
||||
"""
|
||||
|
||||
dashboards: List[Dashboard]
|
||||
total: int
|
||||
after: str = None
|
||||
|
||||
|
||||
class PipelineEntities(BaseModel):
|
||||
"""
|
||||
Pipeline entity pydantic model
|
||||
"""
|
||||
|
||||
pipelines: List[Pipeline]
|
||||
total: int
|
||||
after: str = None
|
||||
|
||||
|
||||
class NoOpAuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
Extends AuthenticationProvider class
|
||||
|
||||
Args:
|
||||
config (MetadataServerConfig):
|
||||
|
||||
Attributes:
|
||||
config (MetadataServerConfig)
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenMetadataServerConfig):
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: OpenMetadataServerConfig):
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self):
|
||||
pass
|
||||
|
||||
def get_access_token(self):
|
||||
return ("no_token", None)
|
||||
|
||||
|
||||
class GoogleAuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
Google authentication implementation
|
||||
|
||||
Args:
|
||||
config (MetadataServerConfig):
|
||||
|
||||
Attributes:
|
||||
config (MetadataServerConfig)
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenMetadataServerConfig):
|
||||
self.config = config
|
||||
self.security_config: GoogleSSOConfig = self.config.securityConfig
|
||||
|
||||
self.generated_auth_token = None
|
||||
self.expiry = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: OpenMetadataServerConfig):
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self) -> None:
|
||||
import google.auth
|
||||
import google.auth.transport.requests
|
||||
from google.oauth2 import service_account
|
||||
|
||||
credentials = service_account.IDTokenCredentials.from_service_account_file(
|
||||
self.security_config.secretKey,
|
||||
target_audience=self.security_config.audience,
|
||||
)
|
||||
request = google.auth.transport.requests.Request()
|
||||
credentials.refresh(request)
|
||||
self.generated_auth_token = credentials.token
|
||||
self.expiry = credentials.expiry
|
||||
|
||||
def get_access_token(self):
|
||||
self.auth_token()
|
||||
return self.generated_auth_token, self.expiry
|
||||
|
||||
|
||||
class OktaAuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
Prepare the Json Web Token for Okta auth
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenMetadataServerConfig):
|
||||
self.config = config
|
||||
self.security_config: OktaSSOConfig = self.config.securityConfig
|
||||
|
||||
self.generated_auth_token = None
|
||||
self.expiry = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: OpenMetadataServerConfig):
|
||||
return cls(config)
|
||||
|
||||
async def auth_token(self) -> None:
|
||||
import time
|
||||
import uuid
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
from okta.cache.okta_cache import OktaCache
|
||||
from okta.jwt import JWT, jwt
|
||||
from okta.request_executor import RequestExecutor
|
||||
|
||||
try:
|
||||
my_pem, my_jwk = JWT.get_PEM_JWK(self.security_config.privateKey)
|
||||
issued_time = int(time.time())
|
||||
expiry_time = issued_time + JWT.ONE_HOUR
|
||||
generated_jwt_id = str(uuid.uuid4())
|
||||
claims = {
|
||||
"sub": self.security_config.clientId,
|
||||
"iat": issued_time,
|
||||
"exp": expiry_time,
|
||||
"iss": self.security_config.clientId,
|
||||
"aud": self.security_config.orgURL,
|
||||
"jti": generated_jwt_id,
|
||||
}
|
||||
token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM)
|
||||
config = {
|
||||
"client": {
|
||||
"orgUrl": self.security_config.orgURL,
|
||||
"authorizationMode": "BEARER",
|
||||
"rateLimit": {},
|
||||
"privateKey": self.security_config.privateKey,
|
||||
"clientId": self.security_config.clientId,
|
||||
"token": token,
|
||||
"scopes": self.security_config.scopes,
|
||||
}
|
||||
}
|
||||
request_exec = RequestExecutor(
|
||||
config=config, cache=OktaCache(ttl=expiry_time, tti=issued_time)
|
||||
)
|
||||
parameters = {
|
||||
"grant_type": "client_credentials",
|
||||
"scope": " ".join(config["client"]["scopes"]),
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"client_assertion": token,
|
||||
}
|
||||
encoded_parameters = urlencode(parameters, quote_via=quote)
|
||||
url = f"{self.security_config.orgURL}?" + encoded_parameters
|
||||
token_request_object = await request_exec.create_request(
|
||||
"POST",
|
||||
url,
|
||||
None,
|
||||
{
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
oauth=True,
|
||||
)
|
||||
_, res_details, res_json, err = await request_exec.fire_request(
|
||||
token_request_object[0]
|
||||
)
|
||||
if err:
|
||||
raise APIError(f"{err}")
|
||||
response_dict = json.loads(res_json)
|
||||
self.generated_auth_token = response_dict.get("access_token")
|
||||
self.expiry = response_dict.get("expires_in")
|
||||
except Exception as err:
|
||||
logger.debug(traceback.print_exc())
|
||||
logger.error(err)
|
||||
sys.exit()
|
||||
|
||||
def get_access_token(self):
|
||||
import asyncio
|
||||
|
||||
asyncio.run(self.auth_token())
|
||||
return self.generated_auth_token, self.expiry
|
||||
|
||||
|
||||
class Auth0AuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
OAuth authentication implementation
|
||||
Args:
|
||||
config (MetadataServerConfig):
|
||||
Attributes:
|
||||
config (MetadataServerConfig)
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenMetadataServerConfig):
|
||||
self.config = config
|
||||
self.security_config: Auth0SSOConfig = self.config.securityConfig
|
||||
|
||||
self.generated_auth_token = None
|
||||
self.expiry = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: OpenMetadataServerConfig):
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self) -> None:
|
||||
conn = http.client.HTTPSConnection(self.security_config.domain)
|
||||
payload = (
|
||||
f"grant_type=client_credentials&client_id={self.security_config.clientId}"
|
||||
f"&client_secret={self.security_config.secretKey}&audience=https://{self.security_config.domain}/api/v2/"
|
||||
)
|
||||
headers = {"content-type": "application/x-www-form-urlencoded"}
|
||||
conn.request(
|
||||
"POST", f"/{self.security_config.domain}/oauth/token", payload, headers
|
||||
)
|
||||
res = conn.getresponse()
|
||||
data = res.read()
|
||||
token = json.loads(data.decode("utf-8"))
|
||||
self.generated_auth_token = token["access_token"]
|
||||
self.expiry = token["expires_in"]
|
||||
|
||||
def get_access_token(self):
|
||||
self.auth_token()
|
||||
return self.generated_auth_token, self.expiry
|
||||
|
||||
|
||||
class AzureAuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
Prepare the Json Web Token for Azure auth
|
||||
"""
|
||||
|
||||
# TODO: Prepare JSON for Azure Auth
|
||||
def __init__(self, config: OpenMetadataServerConfig):
|
||||
self.config = config
|
||||
|
||||
self.generated_auth_token = None
|
||||
self.expiry = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: OpenMetadataServerConfig):
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self) -> None:
|
||||
from msal import (
|
||||
ConfidentialClientApplication, # pylint: disable=import-outside-toplevel
|
||||
)
|
||||
|
||||
app = ConfidentialClientApplication(
|
||||
client_id=self.config.client_id,
|
||||
client_credential=self.config.secret_key,
|
||||
authority=self.config.authority,
|
||||
)
|
||||
token = app.acquire_token_for_client(scopes=self.config.scopes)
|
||||
try:
|
||||
self.generated_auth_token = token["access_token"]
|
||||
self.expiry = token["expires_in"]
|
||||
|
||||
except KeyError as err:
|
||||
logger.error(f"Invalid Credentials - {err}")
|
||||
logger.debug(traceback.format_exc())
|
||||
logger.debug(traceback.print_exc())
|
||||
sys.exit(1)
|
||||
|
||||
def get_access_token(self):
|
||||
self.auth_token()
|
||||
return self.generated_auth_token, self.expiry
|
||||
|
||||
|
||||
class CustomOIDCAuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
Custom OIDC authentication implementation
|
||||
|
||||
Args:
|
||||
config (MetadataServerConfig):
|
||||
|
||||
Attributes:
|
||||
config (MetadataServerConfig)
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenMetadataServerConfig) -> None:
|
||||
self.config = config
|
||||
self.security_config: CustomOidcSSOConfig = self.config.securityConfig
|
||||
|
||||
self.generated_auth_token = None
|
||||
self.expiry = None
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, config: OpenMetadataServerConfig
|
||||
) -> "CustomOIDCAuthenticationProvider":
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self) -> None:
|
||||
data = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.security_config.clientId,
|
||||
"client_secret": self.security_config.secretKey,
|
||||
}
|
||||
response = requests.post(
|
||||
url=self.security_config.tokenEndpoint,
|
||||
data=data,
|
||||
)
|
||||
if response.ok:
|
||||
response_json = response.json()
|
||||
self.generated_auth_token = response_json["access_token"]
|
||||
self.expiry = response_json["expires_in"]
|
||||
else:
|
||||
raise APIError(
|
||||
error={"message": response.text}, http_error=response.status_code
|
||||
)
|
||||
|
||||
def get_access_token(self) -> Tuple[str, int]:
|
||||
self.auth_token()
|
||||
return self.generated_auth_token, self.expiry
|
67
ingestion/src/metadata/ingestion/ometa/provider_registry.py
Normal file
67
ingestion/src/metadata/ingestion/ometa/provider_registry.py
Normal file
@ -0,0 +1,67 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Register auth provider init functions here
|
||||
"""
|
||||
from metadata.generated.schema.metadataIngestion.workflow import (
|
||||
AuthProvider,
|
||||
OpenMetadataServerConfig,
|
||||
)
|
||||
from metadata.ingestion.ometa.auth_provider import (
|
||||
Auth0AuthenticationProvider,
|
||||
AuthenticationProvider,
|
||||
AzureAuthenticationProvider,
|
||||
CustomOIDCAuthenticationProvider,
|
||||
GoogleAuthenticationProvider,
|
||||
NoOpAuthenticationProvider,
|
||||
OktaAuthenticationProvider,
|
||||
)
|
||||
from metadata.utils.dispatch import enum_register
|
||||
|
||||
|
||||
class InvalidAuthProviderException(Exception):
|
||||
"""
|
||||
Raised when we cannot find a valid auth provider
|
||||
in the registry
|
||||
"""
|
||||
|
||||
|
||||
auth_provider_registry = enum_register()
|
||||
|
||||
|
||||
@auth_provider_registry.add(AuthProvider.no_auth.value)
|
||||
def no_auth_init(config: OpenMetadataServerConfig) -> AuthenticationProvider:
|
||||
return NoOpAuthenticationProvider.create(config)
|
||||
|
||||
|
||||
@auth_provider_registry.add(AuthProvider.google.value)
|
||||
def google_auth_init(config: OpenMetadataServerConfig) -> AuthenticationProvider:
|
||||
return GoogleAuthenticationProvider.create(config)
|
||||
|
||||
|
||||
@auth_provider_registry.add(AuthProvider.okta.value)
|
||||
def okta_auth_init(config: OpenMetadataServerConfig) -> AuthenticationProvider:
|
||||
return OktaAuthenticationProvider.create(config)
|
||||
|
||||
|
||||
@auth_provider_registry.add(AuthProvider.auth0.value)
|
||||
def auth0_auth_init(config: OpenMetadataServerConfig) -> AuthenticationProvider:
|
||||
return Auth0AuthenticationProvider.create(config)
|
||||
|
||||
|
||||
@auth_provider_registry.add("azure") # TODO: update JSON
|
||||
def azure_auth_init(config: OpenMetadataServerConfig) -> AuthenticationProvider:
|
||||
return AzureAuthenticationProvider.create(config)
|
||||
|
||||
|
||||
@auth_provider_registry.add(AuthProvider.custom_oidc.value)
|
||||
def custom_oidc_auth_init(config: OpenMetadataServerConfig) -> AuthenticationProvider:
|
||||
return CustomOIDCAuthenticationProvider.create(config)
|
33
ingestion/src/metadata/utils/dispatch.py
Normal file
33
ingestion/src/metadata/utils/dispatch.py
Normal file
@ -0,0 +1,33 @@
|
||||
# Copyright 2021 Collate
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Helper that implements custom dispatcher logic
|
||||
"""
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
def enum_register():
|
||||
"""
|
||||
Helps us register custom function for enum values
|
||||
"""
|
||||
registry = dict()
|
||||
|
||||
def add(name: str):
|
||||
def inner(fn):
|
||||
registry[name] = fn
|
||||
return fn
|
||||
|
||||
return inner
|
||||
|
||||
Register = namedtuple("Register", ["add", "registry"])
|
||||
return Register(add, registry)
|
Loading…
x
Reference in New Issue
Block a user