Fix Okta Authentication and Validation - Ingestion (#2955)

This commit is contained in:
Ayush Shah 2022-02-26 21:49:36 +05:30 committed by GitHub
parent e79be68bea
commit 412d61a875
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 154 additions and 36 deletions

View File

@ -0,0 +1,26 @@
{
"source": {
"type": "sample-data",
"config": {
"sample_data_folder": "./examples/sample_data"
}
},
"sink": {
"type": "metadata-rest",
"config": {}
},
"metadata_server": {
"type": "metadata-server",
"config": {
"api_endpoint": "http://localhost:8585/api",
"auth_provider_type": "okta",
"client_id": "{client_id}",
"org_url": "https://{Issuer URI}/v1/token",
"private_key": "{'p':'<value>','kty': 'RSA','q': '<value>','d': '<value>','e': '<value>','use': 'sig','kid': '<value / key id>','qi': '<value>','dp': '<value>','alg': 'RS256','dq': '<value>','n': '<value>'}",
"email": "email",
"scopes": [
"Authorization Server Scopes"
]
}
}
}

View File

@ -72,7 +72,6 @@ plugins: Dict[str, Set[str]] = {
"google-cloud-datacatalog==3.6.2",
},
"bigquery-usage": {"google-cloud-logging", "cachetools"},
# "docker": {"docker==5.0.3"},
"docker": {"python_on_whales==0.34.0"},
"backup": {"boto3~=1.19.12"},
"dbt": {},

View File

@ -13,6 +13,7 @@ import logging
import os
import pathlib
import sys
import traceback
from typing import List, Optional, Tuple
import click
@ -82,6 +83,7 @@ def ingest(config: str) -> None:
workflow = Workflow.create(workflow_config)
except ValidationError as e:
click.echo(e, err=True)
logger.debug(traceback.print_exc())
sys.exit(1)
workflow.execute()

View File

@ -162,6 +162,9 @@ class Workflow:
if hasattr(self, "sink"):
self.sink.write_record(processed_record)
self.report["sink"] = self.sink.get_status().as_obj()
if hasattr(self, "bulk_sink"):
self.bulk_sink.write_records()
self.report["Bulk_Sink"] = self.bulk_sink.get_status().as_obj()
def stop(self):
if hasattr(self, "processor"):
@ -169,8 +172,6 @@ class Workflow:
if hasattr(self, "stage"):
self.stage.close()
if hasattr(self, "bulk_sink"):
self.bulk_sink.write_records()
self.report["Bulk_Sink"] = self.bulk_sink.get_status().as_obj()
self.bulk_sink.close()
if hasattr(self, "sink"):
self.sink.close()

View File

@ -44,3 +44,12 @@ class AuthenticationProvider(metaclass=ABCMeta):
Returns:
str
"""
@abstractmethod
def get_access_token(self):
    """
    Fetch the access token for this authentication provider.

    Returns:
        tuple: a ``(token, expiry)`` pair. The concrete providers in this
        change return the token together with its expiry information
        (for example ``("no_token", None)`` when authentication is off),
        not a bare string.
    """

View File

@ -13,7 +13,7 @@ Python API REST wrapper and helpers
"""
import logging
import time
from typing import List, Optional
from typing import Callable, List, Optional
import requests
from requests.exceptions import HTTPError
@ -95,7 +95,9 @@ class ClientConfig(ConfigModel):
retry: Optional[int] = 3
retry_wait: Optional[int] = 30
retry_codes: List[int] = [429, 504]
auth_token: Optional[str] = None
auth_token: Optional[Callable] = None
access_token: Optional[str] = None
expires_in: Optional[int] = None
auth_header: Optional[str] = None
raw_data: Optional[bool] = False
allow_redirects: Optional[bool] = False
@ -127,8 +129,15 @@ class REST:
version = api_version if api_version else self._api_version
url: URL = URL(base_url + "/" + version + path)
headers = {"Content-type": "application/json"}
if self._auth_token is not None and self._auth_token != "no_token":
headers[self.config.auth_header] = f"Bearer {self._auth_token}"
if (
self.config.expires_in
and time.time() >= self.config.expires_in
or not self.config.access_token
):
self.config.access_token, expiry = self._auth_token()
if not self.config.access_token == "no_token":
self.config.expires_in = time.time() + expiry - 120
headers[self.config.auth_header] = f"Bearer {self.config.access_token}"
opts = {
"headers": headers,
# Since we allow users to set endpoint URL via env var,
@ -181,6 +190,8 @@ class REST:
raise APIError(error, http_error) from http_error
else:
raise
except Exception as err:
print(err)
if resp.text != "":
return resp.json()
return None

View File

@ -156,7 +156,7 @@ class OpenMetadata(
base_url=self.config.api_endpoint,
api_version=self.config.api_version,
auth_header="Authorization",
auth_token=self._auth_provider.auth_token(),
auth_token=lambda: self._auth_provider.get_access_token(),
)
self.client = REST(client_config)
self._use_raw_data = raw_data

View File

@ -15,14 +15,10 @@ server configuration and auth.
import http.client
import json
import logging
import time
import uuid
import sys
import traceback
from typing import List
import google.auth
import google.auth.transport.requests
from google.oauth2 import service_account
from jose import jwt
from pydantic import BaseModel
from metadata.config.common import ConfigModel
@ -34,6 +30,7 @@ from metadata.generated.schema.entity.data.topic import Topic
from metadata.generated.schema.entity.services.databaseService import DatabaseService
from metadata.generated.schema.entity.tags.tagCategory import Tag
from metadata.ingestion.ometa.auth_provider import AuthenticationProvider
from metadata.ingestion.ometa.client import APIError
logger = logging.getLogger(__name__)
@ -101,6 +98,7 @@ class MetadataServerConfig(ConfigModel):
email: str = None
audience: str = "https://www.googleapis.com/oauth2/v4/token"
auth_header: str = "Authorization"
scopes: List = []
class NoOpAuthenticationProvider(AuthenticationProvider):
@ -121,8 +119,11 @@ class NoOpAuthenticationProvider(AuthenticationProvider):
def create(cls, config: MetadataServerConfig):
return cls(config)
def auth_token(self) -> str:
return "no_token"
def auth_token(self):
    """Do nothing: the no-op provider has no token to generate."""
    return None
def get_access_token(self):
    """
    Return the placeholder credentials used when authentication is disabled.

    Returns:
        tuple: ``("no_token", None)`` - a sentinel token and no expiry.
    """
    placeholder_token = "no_token"
    no_expiry = None
    return placeholder_token, no_expiry
class GoogleAuthenticationProvider(AuthenticationProvider):
@ -144,12 +145,21 @@ class GoogleAuthenticationProvider(AuthenticationProvider):
return cls(config)
def auth_token(self) -> str:
    """
    Refresh Google ID-token credentials and cache the result on the instance.

    Loads service-account credentials from ``self.config.secret_key`` for the
    audience ``self.config.audience``, refreshes them, and stores the outcome
    in ``self.generated_auth_token`` and ``self.expiry``.

    Returns:
        str: the freshly issued token (kept for callers that still use the
        return value directly).
    """
    # Imported inside the method, as in the original block; presumably so
    # google-auth is only needed when this provider is actually selected.
    import google.auth
    import google.auth.transport.requests
    from google.oauth2 import service_account

    credentials = service_account.IDTokenCredentials.from_service_account_file(
        self.config.secret_key, target_audience=self.config.audience
    )
    request = google.auth.transport.requests.Request()
    credentials.refresh(request)

    # Cache BEFORE returning: an early return here would leave the attributes
    # unset and break get_access_token(), which reads them.
    self.generated_auth_token = credentials.token
    self.expiry = credentials.expiry
    return self.generated_auth_token
def get_access_token(self):
    """
    Refresh the Google token and return it with its remaining lifetime.

    Returns:
        tuple: ``(token, expiry)``. When the cached expiry is a ``datetime``
        it is converted to whole seconds remaining from now (never negative);
        any other value, including ``None``, is passed through unchanged.
    """
    from datetime import datetime, timezone

    self.auth_token()
    expiry = self.expiry
    if isinstance(expiry, datetime):
        # The REST client adds the expiry to time.time(), so it needs a number
        # of seconds; a datetime there raises TypeError.
        # NOTE(review): google-auth documents expiry as a naive UTC datetime;
        # a naive value is therefore interpreted as UTC here - confirm.
        if expiry.tzinfo is None:
            expiry = expiry.replace(tzinfo=timezone.utc)
        remaining = (expiry - datetime.now(timezone.utc)).total_seconds()
        expiry = max(int(remaining), 0)
    return (self.generated_auth_token, expiry)
class OktaAuthenticationProvider(AuthenticationProvider):
@ -164,30 +174,86 @@ class OktaAuthenticationProvider(AuthenticationProvider):
def create(cls, config: MetadataServerConfig):
return cls(config)
async def auth_token(self) -> None:
    """
    Request an Okta access token using a signed JWT client assertion.

    Builds a JWT from ``self.config.client_id`` / ``self.config.org_url`` and
    signs it with the JWK derived from ``self.config.private_key``, then POSTs
    a ``client_credentials`` grant to ``self.config.org_url`` with the scopes
    from ``self.config.scopes``. On success the response's ``access_token``
    and ``expires_in`` are stored in ``self.generated_auth_token`` and
    ``self.expiry``.

    Any failure is logged and terminates the process with exit status 1.
    """
    import time
    import uuid
    from urllib.parse import quote, urlencode

    from okta.cache.okta_cache import OktaCache
    from okta.jwt import JWT, jwt
    from okta.request_executor import RequestExecutor

    try:
        _, my_jwk = JWT.get_PEM_JWK(self.config.private_key)
        issued_time = int(time.time())
        expiry_time = issued_time + JWT.ONE_HOUR
        generated_jwt_id = str(uuid.uuid4())

        claims = {
            "sub": self.config.client_id,
            "iat": issued_time,
            "exp": expiry_time,
            "iss": self.config.client_id,
            "aud": self.config.org_url,
            "jti": generated_jwt_id,
        }
        token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM)

        config = {
            "client": {
                "orgUrl": self.config.org_url,
                "authorizationMode": "BEARER",
                "rateLimit": {},
                "privateKey": self.config.private_key,
                "clientId": self.config.client_id,
                "token": token,
                "scopes": self.config.scopes,
            }
        }
        request_exec = RequestExecutor(
            config=config, cache=OktaCache(ttl=expiry_time, tti=issued_time)
        )

        parameters = {
            "grant_type": "client_credentials",
            "scope": " ".join(config["client"]["scopes"]),
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": token,
        }
        encoded_parameters = urlencode(parameters, quote_via=quote)
        url = f"{self.config.org_url}?" + encoded_parameters

        token_request_object = await request_exec.create_request(
            "POST",
            url,
            None,
            {
                "Accept": "application/json",
                "Content-Type": "application/x-www-form-urlencoded",
            },
            oauth=True,
        )
        _, _, res_json, err = await request_exec.fire_request(
            token_request_object[0]
        )
        if err:
            raise APIError(f"{err}")

        response_dict = json.loads(res_json)
        self.generated_auth_token = response_dict.get("access_token")
        self.expiry = response_dict.get("expires_in")
    except Exception as err:
        # format_exc() returns the traceback text; print_exc() returns None
        # and would make this debug record useless.
        logger.debug(traceback.format_exc())
        logger.error(err)
        # Exit non-zero so a failed authentication is not reported as success.
        sys.exit(1)
def get_access_token(self):
    """
    Drive the async Okta token request to completion and return its result.

    Returns:
        tuple: ``(generated_auth_token, expiry)`` as cached by ``auth_token``.
    """
    from asyncio import run as run_to_completion

    run_to_completion(self.auth_token())
    token_and_expiry = (self.generated_auth_token, self.expiry)
    return token_and_expiry
class Auth0AuthenticationProvider(AuthenticationProvider):
"""
OAuth authentication implementation
Args:
config (MetadataServerConfig):
Attributes:
config (MetadataServerConfig)
"""
@ -203,12 +269,16 @@ class Auth0AuthenticationProvider(AuthenticationProvider):
conn = http.client.HTTPSConnection(self.config.domain)
payload = (
f"grant_type=client_credentials&client_id={self.config.client_id}"
f"&client_secret={self.config.secret_key}"
f"&audience=https://{self.config.domain}/api/v2/"
f"&client_secret={self.config.secret_key}&audience=https://{self.config.domain}/api/v2/"
)
headers = {"content-type": "application/x-www-form-urlencoded"}
conn.request("POST", f"/{self.config.domain}/oauth/token", payload, headers)
res = conn.getresponse()
data = res.read()
token = json.loads(data.decode("utf-8"))
return token["access_token"]
self.generated_auth_token = token["access_token"]
self.expiry = token["expires_in"]
def get_access_token(self):
    """
    Request a fresh Auth0 token and return it together with its expiry.

    Returns:
        tuple: ``(generated_auth_token, expiry)`` as cached by ``auth_token``.
    """
    self.auth_token()
    access_token, lifetime = self.generated_auth_token, self.expiry
    return access_token, lifetime