fix(azure_ad): fix infinite loop on request error (#10679)

This commit is contained in:
Davi Arnaut 2024-06-11 13:41:36 -07:00 committed by GitHub
parent b9e71a61b1
commit 52ac3143a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9,6 +9,7 @@ from typing import Any, Dict, Generator, Iterable, List, Optional
import click import click
import requests import requests
from pydantic.fields import Field from pydantic.fields import Field
from requests.adapters import HTTPAdapter, Retry
from datahub.configuration.common import AllowDenyPattern from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.source_common import DatasetSourceConfigMixin from datahub.configuration.source_common import DatasetSourceConfigMixin
@ -268,6 +269,14 @@ class AzureADSource(StatefulIngestionSourceBase):
self.report = AzureADSourceReport( self.report = AzureADSourceReport(
filtered_tracking=self.config.filtered_tracking filtered_tracking=self.config.filtered_tracking
) )
session = requests.Session()
retries = Retry(
total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504]
)
adapter = HTTPAdapter(max_retries=retries)
session.mount("http://", adapter)
session.mount("https://", adapter)
self.session = session
self.token_data = { self.token_data = {
"grant_type": "client_credentials", "grant_type": "client_credentials",
"client_id": self.config.client_id, "client_id": self.config.client_id,
@ -494,7 +503,7 @@ class AzureADSource(StatefulIngestionSourceBase):
while True: while True:
if not url: if not url:
break break
response = requests.get(url, headers=headers) response = self.session.get(url, headers=headers)
if response.status_code == 200: if response.status_code == 200:
json_data = json.loads(response.text) json_data = json.loads(response.text)
try: try:
@ -512,7 +521,7 @@ class AzureADSource(StatefulIngestionSourceBase):
logger.debug(f"URL = {url}") logger.debug(f"URL = {url}")
logger.error(error_str) logger.error(error_str)
self.report.report_failure("_get_azure_ad_data_", error_str) self.report.report_failure("_get_azure_ad_data_", error_str)
continue raise Exception(f"Unable to get {url}, error {response.status_code}")
def _map_identity_to_urn(self, func, id_to_extract, mapping_identifier, id_type): def _map_identity_to_urn(self, func, id_to_extract, mapping_identifier, id_type):
result, error_str = None, None result, error_str = None, None