diff --git a/metadata-ingestion/src/datahub/ingestion/source/openapi.py b/metadata-ingestion/src/datahub/ingestion/source/openapi.py index 1b3a6dc4be..54affafdcc 100755 --- a/metadata-ingestion/src/datahub/ingestion/source/openapi.py +++ b/metadata-ingestion/src/datahub/ingestion/source/openapi.py @@ -4,9 +4,10 @@ import warnings from abc import ABC from typing import Dict, Iterable, Optional, Tuple +from pydantic import validator from pydantic.fields import Field -from datahub.configuration.common import ConfigModel +from datahub.configuration.common import ConfigModel, ConfigurationError from datahub.emitter.mce_builder import make_tag_urn from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.decorators import ( @@ -74,14 +75,33 @@ class OpenApiConfig(ConfigModel): token: Optional[str] = Field( default=None, description="Token for endpoint authentication." ) + bearer_token: Optional[str] = Field( + default=None, description="Bearer token for endpoint authentication." + ) get_token: dict = Field( default={}, description="Retrieving a token from the endpoint." ) + @validator("bearer_token", always=True) + def ensure_only_one_token( + cls, bearer_token: Optional[str], values: Dict + ) -> Optional[str]: + if bearer_token is not None and values.get("token") is not None: + raise ConfigurationError( + "Unable to use 'token' and 'bearer_token' together." + ) + return bearer_token + def get_swagger(self) -> Dict: - if self.get_token or self.token is not None: - if self.token is not None: - ... + if self.get_token or self.token or self.bearer_token is not None: + if self.token: + pass + elif self.bearer_token: + # TRICKY: To avoid passing a bunch of different token types around, we set the + # token's value to the properly formatted bearer token. + # TODO: We should just create a requests.Session and set all the auth + # details there once, and then use that session for all requests. + self.token = f"Bearer {self.bearer_token}" else: assert ( "url_complement" in self.get_token.keys() @@ -283,10 +303,11 @@ class APISource(Source, ABC): "{" not in endpoint_k ): # if the API does not explicitly require parameters tot_url = clean_url(config.url + self.url_basepath + endpoint_k) - if config.token: response = request_call( - tot_url, token=config.token, proxies=config.proxies + tot_url, + token=config.token, + proxies=config.proxies, ) else: response = request_call( @@ -314,7 +335,9 @@ class APISource(Source, ABC): tot_url = clean_url(config.url + self.url_basepath + url_guess) if config.token: response = request_call( - tot_url, token=config.token, proxies=config.proxies + tot_url, + token=config.token, + proxies=config.proxies, ) else: response = request_call( @@ -342,7 +365,9 @@ class APISource(Source, ABC): tot_url = clean_url(config.url + self.url_basepath + composed_url) if config.token: response = request_call( - tot_url, token=config.token, proxies=config.proxies + tot_url, + token=config.token, + proxies=config.proxies, ) else: response = request_call( diff --git a/metadata-ingestion/src/datahub/ingestion/source/openapi_parser.py b/metadata-ingestion/src/datahub/ingestion/source/openapi_parser.py index c1caca18fe..5bacafaa3f 100755 --- a/metadata-ingestion/src/datahub/ingestion/source/openapi_parser.py +++ b/metadata-ingestion/src/datahub/ingestion/source/openapi_parser.py @@ -54,12 +54,10 @@ def request_call( proxies: Optional[dict] = None, ) -> requests.Response: headers = {"accept": "application/json"} - if username is not None and password is not None: return requests.get( url, headers=headers, auth=HTTPBasicAuth(username, password) ) - elif token is not None: headers["Authorization"] = f"{token}" return requests.get(url, proxies=proxies, headers=headers) @@ -76,12 +74,9 @@ def get_swag_json( proxies: Optional[dict] = None, ) -> Dict: tot_url = url + swagger_file - if token is not None: - response = request_call(url=tot_url, token=token, proxies=proxies) - else: - response = request_call( - url=tot_url, username=username, password=password, proxies=proxies - ) + response = request_call( + url=tot_url, token=token, username=username, password=password, proxies=proxies + ) if response.status_code != 200: raise Exception(f"Unable to retrieve {tot_url}, error {response.status_code}")