mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2026-01-08 05:26:19 +00:00
Fix Okta Authentication and Validation - Ingestion (#2955)
This commit is contained in:
parent
e79be68bea
commit
412d61a875
26
ingestion/examples/auth_examples/okta_example.json
Normal file
26
ingestion/examples/auth_examples/okta_example.json
Normal file
@ -0,0 +1,26 @@
|
||||
{
|
||||
"source": {
|
||||
"type": "sample-data",
|
||||
"config": {
|
||||
"sample_data_folder": "./examples/sample_data"
|
||||
}
|
||||
},
|
||||
"sink": {
|
||||
"type": "metadata-rest",
|
||||
"config": {}
|
||||
},
|
||||
"metadata_server": {
|
||||
"type": "metadata-server",
|
||||
"config": {
|
||||
"api_endpoint": "http://localhost:8585/api",
|
||||
"auth_provider_type": "okta",
|
||||
"client_id": "{client_id}",
|
||||
"org_url": "https://{Issuer URI}/v1/token",
|
||||
"private_key": "{'p':'<value>','kty': 'RSA','q': '<value>','d': '<value>','e': '<value>','use': 'sig','kid': '<value / key id>','qi': '<value>','dp': '<value>','alg': 'RS256','dq': '<value>','n': '<value>'}",
|
||||
"email": "email",
|
||||
"scopes": [
|
||||
"Authorization Server Scopes"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -72,7 +72,6 @@ plugins: Dict[str, Set[str]] = {
|
||||
"google-cloud-datacatalog==3.6.2",
|
||||
},
|
||||
"bigquery-usage": {"google-cloud-logging", "cachetools"},
|
||||
# "docker": {"docker==5.0.3"},
|
||||
"docker": {"python_on_whales==0.34.0"},
|
||||
"backup": {"boto3~=1.19.12"},
|
||||
"dbt": {},
|
||||
|
||||
@ -13,6 +13,7 @@ import logging
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import click
|
||||
@ -82,6 +83,7 @@ def ingest(config: str) -> None:
|
||||
workflow = Workflow.create(workflow_config)
|
||||
except ValidationError as e:
|
||||
click.echo(e, err=True)
|
||||
logger.debug(traceback.print_exc())
|
||||
sys.exit(1)
|
||||
|
||||
workflow.execute()
|
||||
|
||||
@ -162,6 +162,9 @@ class Workflow:
|
||||
if hasattr(self, "sink"):
|
||||
self.sink.write_record(processed_record)
|
||||
self.report["sink"] = self.sink.get_status().as_obj()
|
||||
if hasattr(self, "bulk_sink"):
|
||||
self.bulk_sink.write_records()
|
||||
self.report["Bulk_Sink"] = self.bulk_sink.get_status().as_obj()
|
||||
|
||||
def stop(self):
|
||||
if hasattr(self, "processor"):
|
||||
@ -169,8 +172,6 @@ class Workflow:
|
||||
if hasattr(self, "stage"):
|
||||
self.stage.close()
|
||||
if hasattr(self, "bulk_sink"):
|
||||
self.bulk_sink.write_records()
|
||||
self.report["Bulk_Sink"] = self.bulk_sink.get_status().as_obj()
|
||||
self.bulk_sink.close()
|
||||
if hasattr(self, "sink"):
|
||||
self.sink.close()
|
||||
|
||||
@ -44,3 +44,12 @@ class AuthenticationProvider(metaclass=ABCMeta):
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_access_token(self):
|
||||
"""
|
||||
Authentication token
|
||||
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
|
||||
@ -13,7 +13,7 @@ Python API REST wrapper and helpers
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import requests
|
||||
from requests.exceptions import HTTPError
|
||||
@ -95,7 +95,9 @@ class ClientConfig(ConfigModel):
|
||||
retry: Optional[int] = 3
|
||||
retry_wait: Optional[int] = 30
|
||||
retry_codes: List[int] = [429, 504]
|
||||
auth_token: Optional[str] = None
|
||||
auth_token: Optional[Callable] = None
|
||||
access_token: Optional[str] = None
|
||||
expires_in: Optional[int] = None
|
||||
auth_header: Optional[str] = None
|
||||
raw_data: Optional[bool] = False
|
||||
allow_redirects: Optional[bool] = False
|
||||
@ -127,8 +129,15 @@ class REST:
|
||||
version = api_version if api_version else self._api_version
|
||||
url: URL = URL(base_url + "/" + version + path)
|
||||
headers = {"Content-type": "application/json"}
|
||||
if self._auth_token is not None and self._auth_token != "no_token":
|
||||
headers[self.config.auth_header] = f"Bearer {self._auth_token}"
|
||||
if (
|
||||
self.config.expires_in
|
||||
and time.time() >= self.config.expires_in
|
||||
or not self.config.access_token
|
||||
):
|
||||
self.config.access_token, expiry = self._auth_token()
|
||||
if not self.config.access_token == "no_token":
|
||||
self.config.expires_in = time.time() + expiry - 120
|
||||
headers[self.config.auth_header] = f"Bearer {self.config.access_token}"
|
||||
opts = {
|
||||
"headers": headers,
|
||||
# Since we allow users to set endpoint URL via env var,
|
||||
@ -181,6 +190,8 @@ class REST:
|
||||
raise APIError(error, http_error) from http_error
|
||||
else:
|
||||
raise
|
||||
except Exception as err:
|
||||
print(err)
|
||||
if resp.text != "":
|
||||
return resp.json()
|
||||
return None
|
||||
|
||||
@ -156,7 +156,7 @@ class OpenMetadata(
|
||||
base_url=self.config.api_endpoint,
|
||||
api_version=self.config.api_version,
|
||||
auth_header="Authorization",
|
||||
auth_token=self._auth_provider.auth_token(),
|
||||
auth_token=lambda: self._auth_provider.get_access_token(),
|
||||
)
|
||||
self.client = REST(client_config)
|
||||
self._use_raw_data = raw_data
|
||||
|
||||
@ -15,14 +15,10 @@ server configuration and auth.
|
||||
import http.client
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
import google.auth
|
||||
import google.auth.transport.requests
|
||||
from google.oauth2 import service_account
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metadata.config.common import ConfigModel
|
||||
@ -34,6 +30,7 @@ from metadata.generated.schema.entity.data.topic import Topic
|
||||
from metadata.generated.schema.entity.services.databaseService import DatabaseService
|
||||
from metadata.generated.schema.entity.tags.tagCategory import Tag
|
||||
from metadata.ingestion.ometa.auth_provider import AuthenticationProvider
|
||||
from metadata.ingestion.ometa.client import APIError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -101,6 +98,7 @@ class MetadataServerConfig(ConfigModel):
|
||||
email: str = None
|
||||
audience: str = "https://www.googleapis.com/oauth2/v4/token"
|
||||
auth_header: str = "Authorization"
|
||||
scopes: List = []
|
||||
|
||||
|
||||
class NoOpAuthenticationProvider(AuthenticationProvider):
|
||||
@ -121,8 +119,11 @@ class NoOpAuthenticationProvider(AuthenticationProvider):
|
||||
def create(cls, config: MetadataServerConfig):
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self) -> str:
|
||||
return "no_token"
|
||||
def auth_token(self):
|
||||
pass
|
||||
|
||||
def get_access_token(self):
|
||||
return ("no_token", None)
|
||||
|
||||
|
||||
class GoogleAuthenticationProvider(AuthenticationProvider):
|
||||
@ -144,12 +145,21 @@ class GoogleAuthenticationProvider(AuthenticationProvider):
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self) -> str:
|
||||
import google.auth
|
||||
import google.auth.transport.requests
|
||||
from google.oauth2 import service_account
|
||||
|
||||
credentials = service_account.IDTokenCredentials.from_service_account_file(
|
||||
self.config.secret_key, target_audience=self.config.audience
|
||||
)
|
||||
request = google.auth.transport.requests.Request()
|
||||
credentials.refresh(request)
|
||||
return credentials.token
|
||||
self.generated_auth_token = credentials.token
|
||||
self.expiry = credentials.expiry
|
||||
|
||||
def get_access_token(self):
|
||||
self.auth_token()
|
||||
return (self.generated_auth_token, self.expiry)
|
||||
|
||||
|
||||
class OktaAuthenticationProvider(AuthenticationProvider):
|
||||
@ -164,30 +174,86 @@ class OktaAuthenticationProvider(AuthenticationProvider):
|
||||
def create(cls, config: MetadataServerConfig):
|
||||
return cls(config)
|
||||
|
||||
def auth_token(self) -> str:
|
||||
from okta.jwt import JWT # pylint: disable=import-outside-toplevel
|
||||
async def auth_token(self) -> str:
|
||||
import time
|
||||
import uuid
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
_, my_jwk = JWT.get_PEM_JWK(self.config.private_key)
|
||||
claims = {
|
||||
"sub": self.config.client_id,
|
||||
"iat": time.time(),
|
||||
"exp": time.time() + JWT.ONE_HOUR,
|
||||
"iss": self.config.client_id,
|
||||
"aud": self.config.org_url + JWT.OAUTH_ENDPOINT,
|
||||
"jti": uuid.uuid4(),
|
||||
"email": self.config.email,
|
||||
}
|
||||
token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM)
|
||||
return token
|
||||
from okta.cache.okta_cache import OktaCache
|
||||
from okta.jwt import JWT, jwt
|
||||
from okta.request_executor import RequestExecutor
|
||||
|
||||
try:
|
||||
my_pem, my_jwk = JWT.get_PEM_JWK(self.config.private_key)
|
||||
issued_time = int(time.time())
|
||||
expiry_time = issued_time + JWT.ONE_HOUR
|
||||
generated_jwt_id = str(uuid.uuid4())
|
||||
claims = {
|
||||
"sub": self.config.client_id,
|
||||
"iat": issued_time,
|
||||
"exp": expiry_time,
|
||||
"iss": self.config.client_id,
|
||||
"aud": self.config.org_url,
|
||||
"jti": generated_jwt_id,
|
||||
}
|
||||
token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM)
|
||||
config = {
|
||||
"client": {
|
||||
"orgUrl": self.config.org_url,
|
||||
"authorizationMode": "BEARER",
|
||||
"rateLimit": {},
|
||||
"privateKey": self.config.private_key,
|
||||
"clientId": self.config.client_id,
|
||||
"token": token,
|
||||
"scopes": self.config.scopes,
|
||||
}
|
||||
}
|
||||
request_exec = RequestExecutor(
|
||||
config=config, cache=OktaCache(ttl=expiry_time, tti=issued_time)
|
||||
)
|
||||
parameters = {
|
||||
"grant_type": "client_credentials",
|
||||
"scope": " ".join(config["client"]["scopes"]),
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"client_assertion": token,
|
||||
}
|
||||
encoded_parameters = urlencode(parameters, quote_via=quote)
|
||||
url = f"{self.config.org_url}?" + encoded_parameters
|
||||
token_request_object = await request_exec.create_request(
|
||||
"POST",
|
||||
url,
|
||||
None,
|
||||
{
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
oauth=True,
|
||||
)
|
||||
_, res_details, res_json, err = await request_exec.fire_request(
|
||||
token_request_object[0]
|
||||
)
|
||||
if err:
|
||||
raise APIError(f"{err}")
|
||||
response_dict = json.loads(res_json)
|
||||
self.generated_auth_token = response_dict.get("access_token")
|
||||
self.expiry = response_dict.get("expires_in")
|
||||
except Exception as err:
|
||||
logger.debug(traceback.print_exc())
|
||||
logger.error(err)
|
||||
sys.exit()
|
||||
|
||||
def get_access_token(self):
|
||||
import asyncio
|
||||
|
||||
asyncio.run(self.auth_token())
|
||||
return (self.generated_auth_token, self.expiry)
|
||||
|
||||
|
||||
class Auth0AuthenticationProvider(AuthenticationProvider):
|
||||
"""
|
||||
OAuth authentication implementation
|
||||
|
||||
Args:
|
||||
config (MetadataServerConfig):
|
||||
|
||||
Attributes:
|
||||
config (MetadataServerConfig)
|
||||
"""
|
||||
@ -203,12 +269,16 @@ class Auth0AuthenticationProvider(AuthenticationProvider):
|
||||
conn = http.client.HTTPSConnection(self.config.domain)
|
||||
payload = (
|
||||
f"grant_type=client_credentials&client_id={self.config.client_id}"
|
||||
f"&client_secret={self.config.secret_key}"
|
||||
f"&audience=https://{self.config.domain}/api/v2/"
|
||||
f"&client_secret={self.config.secret_key}&audience=https://{self.config.domain}/api/v2/"
|
||||
)
|
||||
headers = {"content-type": "application/x-www-form-urlencoded"}
|
||||
conn.request("POST", f"/{self.config.domain}/oauth/token", payload, headers)
|
||||
res = conn.getresponse()
|
||||
data = res.read()
|
||||
token = json.loads(data.decode("utf-8"))
|
||||
return token["access_token"]
|
||||
self.generated_auth_token = token["access_token"]
|
||||
self.expiry = token["expires_in"]
|
||||
|
||||
def get_access_token(self):
|
||||
self.auth_token()
|
||||
return (self.generated_auth_token, self.expiry)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user