diff --git a/ingestion/examples/auth_examples/okta_example.json b/ingestion/examples/auth_examples/okta_example.json new file mode 100644 index 00000000000..f9369068fa4 --- /dev/null +++ b/ingestion/examples/auth_examples/okta_example.json @@ -0,0 +1,26 @@ +{ + "source": { + "type": "sample-data", + "config": { + "sample_data_folder": "./examples/sample_data" + } + }, + "sink": { + "type": "metadata-rest", + "config": {} + }, + "metadata_server": { + "type": "metadata-server", + "config": { + "api_endpoint": "http://localhost:8585/api", + "auth_provider_type": "okta", + "client_id": "{client_id}", + "org_url": "https://{Issuer URI}/v1/token", + "private_key": "{'p':'','kty': 'RSA','q': '','d': '','e': '','use': 'sig','kid': '','qi': '','dp': '','alg': 'RS256','dq': '','n': ''}", + "email": "email", + "scopes": [ + "Authorization Server Scopes" + ] + } + } +} diff --git a/ingestion/setup.py b/ingestion/setup.py index 07bdb67e96d..1286dd8fde0 100644 --- a/ingestion/setup.py +++ b/ingestion/setup.py @@ -72,7 +72,6 @@ plugins: Dict[str, Set[str]] = { "google-cloud-datacatalog==3.6.2", }, "bigquery-usage": {"google-cloud-logging", "cachetools"}, - # "docker": {"docker==5.0.3"}, "docker": {"python_on_whales==0.34.0"}, "backup": {"boto3~=1.19.12"}, "dbt": {}, diff --git a/ingestion/src/metadata/cmd.py b/ingestion/src/metadata/cmd.py index 4e0d389eed9..d8d494059cd 100644 --- a/ingestion/src/metadata/cmd.py +++ b/ingestion/src/metadata/cmd.py @@ -13,6 +13,7 @@ import logging import os import pathlib import sys +import traceback from typing import List, Optional, Tuple import click @@ -82,6 +83,7 @@ def ingest(config: str) -> None: workflow = Workflow.create(workflow_config) except ValidationError as e: click.echo(e, err=True) + logger.debug(traceback.print_exc()) sys.exit(1) workflow.execute() diff --git a/ingestion/src/metadata/ingestion/api/workflow.py b/ingestion/src/metadata/ingestion/api/workflow.py index 47fbb9e6806..8caf7fde797 100644 --- a/ingestion/src/metadata/ingestion/api/workflow.py +++ b/ingestion/src/metadata/ingestion/api/workflow.py @@ -162,6 +162,9 @@ class Workflow: if hasattr(self, "sink"): self.sink.write_record(processed_record) self.report["sink"] = self.sink.get_status().as_obj() + if hasattr(self, "bulk_sink"): + self.bulk_sink.write_records() + self.report["Bulk_Sink"] = self.bulk_sink.get_status().as_obj() def stop(self): if hasattr(self, "processor"): @@ -169,8 +172,6 @@ class Workflow: if hasattr(self, "stage"): self.stage.close() if hasattr(self, "bulk_sink"): - self.bulk_sink.write_records() - self.report["Bulk_Sink"] = self.bulk_sink.get_status().as_obj() self.bulk_sink.close() if hasattr(self, "sink"): self.sink.close() diff --git a/ingestion/src/metadata/ingestion/ometa/auth_provider.py b/ingestion/src/metadata/ingestion/ometa/auth_provider.py index e804c4a1c8c..80c442ff184 100644 --- a/ingestion/src/metadata/ingestion/ometa/auth_provider.py +++ b/ingestion/src/metadata/ingestion/ometa/auth_provider.py @@ -44,3 +44,12 @@ class AuthenticationProvider(metaclass=ABCMeta): Returns: str """ + + @abstractmethod + def get_access_token(self): + """ + Authentication token + + Returns: + str + """ diff --git a/ingestion/src/metadata/ingestion/ometa/client.py b/ingestion/src/metadata/ingestion/ometa/client.py index 7d66a8f6e5c..ea15e5b1e3c 100644 --- a/ingestion/src/metadata/ingestion/ometa/client.py +++ b/ingestion/src/metadata/ingestion/ometa/client.py @@ -13,7 +13,7 @@ Python API REST wrapper and helpers """ import logging import time -from typing import List, Optional +from typing import Callable, List, Optional import requests from requests.exceptions import HTTPError @@ -95,7 +95,9 @@ class ClientConfig(ConfigModel): retry: Optional[int] = 3 retry_wait: Optional[int] = 30 retry_codes: List[int] = [429, 504] - auth_token: Optional[str] = None + auth_token: Optional[Callable] = None + access_token: Optional[str] = None + expires_in: Optional[int] = None auth_header: Optional[str] = None raw_data: Optional[bool] = False allow_redirects: Optional[bool] = False @@ -127,8 +129,15 @@ class REST: version = api_version if api_version else self._api_version url: URL = URL(base_url + "/" + version + path) headers = {"Content-type": "application/json"} - if self._auth_token is not None and self._auth_token != "no_token": - headers[self.config.auth_header] = f"Bearer {self._auth_token}" + if ( + self.config.expires_in + and time.time() >= self.config.expires_in + or not self.config.access_token + ): + self.config.access_token, expiry = self._auth_token() + if not self.config.access_token == "no_token": + self.config.expires_in = time.time() + expiry - 120 + headers[self.config.auth_header] = f"Bearer {self.config.access_token}" opts = { "headers": headers, # Since we allow users to set endpoint URL via env var, @@ -181,6 +190,8 @@ class REST: raise APIError(error, http_error) from http_error else: raise + except Exception as err: + print(err) if resp.text != "": return resp.json() return None diff --git a/ingestion/src/metadata/ingestion/ometa/ometa_api.py b/ingestion/src/metadata/ingestion/ometa/ometa_api.py index 2d296ffb845..0b73fe27cbd 100644 --- a/ingestion/src/metadata/ingestion/ometa/ometa_api.py +++ b/ingestion/src/metadata/ingestion/ometa/ometa_api.py @@ -156,7 +156,7 @@ class OpenMetadata( base_url=self.config.api_endpoint, api_version=self.config.api_version, auth_header="Authorization", - auth_token=self._auth_provider.auth_token(), + auth_token=lambda: self._auth_provider.get_access_token(), ) self.client = REST(client_config) self._use_raw_data = raw_data diff --git a/ingestion/src/metadata/ingestion/ometa/openmetadata_rest.py b/ingestion/src/metadata/ingestion/ometa/openmetadata_rest.py index fe1321cac4f..4473228ae3c 100644 --- a/ingestion/src/metadata/ingestion/ometa/openmetadata_rest.py +++ b/ingestion/src/metadata/ingestion/ometa/openmetadata_rest.py @@ -15,14 +15,10 @@ server configuration and auth. import http.client import json import logging -import time -import uuid +import sys +import traceback from typing import List -import google.auth -import google.auth.transport.requests -from google.oauth2 import service_account -from jose import jwt from pydantic import BaseModel from metadata.config.common import ConfigModel @@ -34,6 +30,7 @@ from metadata.generated.schema.entity.data.topic import Topic from metadata.generated.schema.entity.services.databaseService import DatabaseService from metadata.generated.schema.entity.tags.tagCategory import Tag from metadata.ingestion.ometa.auth_provider import AuthenticationProvider +from metadata.ingestion.ometa.client import APIError logger = logging.getLogger(__name__) @@ -101,6 +98,7 @@ class MetadataServerConfig(ConfigModel): email: str = None audience: str = "https://www.googleapis.com/oauth2/v4/token" auth_header: str = "Authorization" + scopes: List = [] class NoOpAuthenticationProvider(AuthenticationProvider): @@ -121,8 +119,11 @@ class NoOpAuthenticationProvider(AuthenticationProvider): def create(cls, config: MetadataServerConfig): return cls(config) - def auth_token(self) -> str: - return "no_token" + def auth_token(self): + pass + + def get_access_token(self): + return ("no_token", None) class GoogleAuthenticationProvider(AuthenticationProvider): @@ -144,12 +145,21 @@ class GoogleAuthenticationProvider(AuthenticationProvider): return cls(config) def auth_token(self) -> str: + import google.auth + import google.auth.transport.requests + from google.oauth2 import service_account + credentials = service_account.IDTokenCredentials.from_service_account_file( self.config.secret_key, target_audience=self.config.audience ) request = google.auth.transport.requests.Request() credentials.refresh(request) - return credentials.token + self.generated_auth_token = credentials.token + self.expiry = credentials.expiry + + def get_access_token(self): + self.auth_token() + return (self.generated_auth_token, self.expiry) class OktaAuthenticationProvider(AuthenticationProvider): @@ -164,30 +174,86 @@ class OktaAuthenticationProvider(AuthenticationProvider): def create(cls, config: MetadataServerConfig): return cls(config) - def auth_token(self) -> str: - from okta.jwt import JWT # pylint: disable=import-outside-toplevel + async def auth_token(self) -> str: + import time + import uuid + from urllib.parse import quote, urlencode - _, my_jwk = JWT.get_PEM_JWK(self.config.private_key) - claims = { - "sub": self.config.client_id, - "iat": time.time(), - "exp": time.time() + JWT.ONE_HOUR, - "iss": self.config.client_id, - "aud": self.config.org_url + JWT.OAUTH_ENDPOINT, - "jti": uuid.uuid4(), - "email": self.config.email, - } - token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM) - return token + from okta.cache.okta_cache import OktaCache + from okta.jwt import JWT, jwt + from okta.request_executor import RequestExecutor + + try: + my_pem, my_jwk = JWT.get_PEM_JWK(self.config.private_key) + issued_time = int(time.time()) + expiry_time = issued_time + JWT.ONE_HOUR + generated_jwt_id = str(uuid.uuid4()) + claims = { + "sub": self.config.client_id, + "iat": issued_time, + "exp": expiry_time, + "iss": self.config.client_id, + "aud": self.config.org_url, + "jti": generated_jwt_id, + } + token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM) + config = { + "client": { + "orgUrl": self.config.org_url, + "authorizationMode": "BEARER", + "rateLimit": {}, + "privateKey": self.config.private_key, + "clientId": self.config.client_id, + "token": token, + "scopes": self.config.scopes, + } + } + request_exec = RequestExecutor( + config=config, cache=OktaCache(ttl=expiry_time, tti=issued_time) + ) + parameters = { + "grant_type": "client_credentials", + "scope": " ".join(config["client"]["scopes"]), + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_assertion": token, + } + encoded_parameters = urlencode(parameters, quote_via=quote) + url = f"{self.config.org_url}?" + encoded_parameters + token_request_object = await request_exec.create_request( + "POST", + url, + None, + { + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded", + }, + oauth=True, + ) + _, res_details, res_json, err = await request_exec.fire_request( + token_request_object[0] + ) + if err: + raise APIError(f"{err}") + response_dict = json.loads(res_json) + self.generated_auth_token = response_dict.get("access_token") + self.expiry = response_dict.get("expires_in") + except Exception as err: + logger.debug(traceback.print_exc()) + logger.error(err) + sys.exit() + + def get_access_token(self): + import asyncio + + asyncio.run(self.auth_token()) + return (self.generated_auth_token, self.expiry) class Auth0AuthenticationProvider(AuthenticationProvider): """ OAuth authentication implementation - Args: config (MetadataServerConfig): - Attributes: config (MetadataServerConfig) """ @@ -203,12 +269,16 @@ class Auth0AuthenticationProvider(AuthenticationProvider): conn = http.client.HTTPSConnection(self.config.domain) payload = ( f"grant_type=client_credentials&client_id={self.config.client_id}" - f"&client_secret={self.config.secret_key}" - f"&audience=https://{self.config.domain}/api/v2/" + f"&client_secret={self.config.secret_key}&audience=https://{self.config.domain}/api/v2/" ) headers = {"content-type": "application/x-www-form-urlencoded"} conn.request("POST", f"/{self.config.domain}/oauth/token", payload, headers) res = conn.getresponse() data = res.read() token = json.loads(data.decode("utf-8")) - return token["access_token"] + self.generated_auth_token = token["access_token"] + self.expiry = token["expires_in"] + + def get_access_token(self): + self.auth_token() + return (self.generated_auth_token, self.expiry)