diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/client.py b/ingestion/src/metadata/ingestion/source/database/databricks/client.py
index 86a44271705..addbc926f5e 100644
--- a/ingestion/src/metadata/ingestion/source/database/databricks/client.py
+++ b/ingestion/src/metadata/ingestion/source/database/databricks/client.py
@@ -14,7 +14,7 @@ Client to interact with databricks apis
 import json
 import traceback
 from datetime import timedelta
-from typing import List
+from typing import Iterable, List
 
 import requests
 
@@ -31,6 +31,12 @@ API_TIMEOUT = 10
 QUERIES_PATH = "/sql/history/queries"
 
 
+class DatabricksClientException(Exception):
+    """
+    Class to throw auth and other databricks api exceptions.
+    """
+
+
 class DatabricksClient:
     """
     DatabricksClient creates a Databricks connection based on DatabricksCredentials.
@@ -60,14 +66,33 @@ class DatabricksClient:
         if res.status_code != 200:
             raise APIError(res.json)
 
+    def _run_query_paginator(self, data, result, end_time, response):
+        while True:
+            if response:
+                next_page_token = response.get("next_page_token", None)
+                has_next_page = response.get("has_next_page", None)
+                if next_page_token:
+                    data["page_token"] = next_page_token
+                if not has_next_page:
+                    data = {}
+                    break
+            else:
+                break
+
+            if result[-1]["execution_end_time_ms"] <= end_time:
+                response = self.client.get(
+                    self.base_query_url,
+                    data=json.dumps(data),
+                    headers=self.headers,
+                    timeout=API_TIMEOUT,
+                ).json()
+                yield from response.get("res") or []
+
     def list_query_history(self, start_date=None, end_date=None) -> List[dict]:
         """
         Method returns List the history of queries through SQL warehouses
         """
-        query_details = []
         try:
-            next_page_token = None
-            has_next_page = None
 
             data = {}
             daydiff = end_date - start_date
@@ -98,36 +123,15 @@ class DatabricksClient:
                 result = response.get("res") or []
                 data = {}
 
-            while True:
-                if result:
-                    query_details.extend(result)
-
-                    next_page_token = response.get("next_page_token", None)
-                    has_next_page = response.get("has_next_page", None)
-                    if next_page_token:
-                        data["page_token"] = next_page_token
-
-                    if not has_next_page:
-                        data = {}
-                        break
-                else:
-                    break
-
-                if result[-1]["execution_end_time_ms"] <= end_time:
-                    response = self.client.get(
-                        self.base_query_url,
-                        data=json.dumps(data),
-                        headers=self.headers,
-                        timeout=API_TIMEOUT,
-                    ).json()
-                    result = response.get("res")
+            yield from result
+            yield from self._run_query_paginator(
+                data=data, result=result, end_time=end_time, response=response
+            ) or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
 
-        return query_details
-
     def is_query_valid(self, row) -> bool:
         query_text = row.get("query_text")
         return not (
@@ -137,18 +141,19 @@ class DatabricksClient:
 
     def list_jobs_test_connection(self) -> None:
         data = {"limit": 1, "expand_tasks": True, "offset": 0}
-        self.client.get(
+        response = self.client.get(
             self.jobs_list_url,
             data=json.dumps(data),
             headers=self.headers,
             timeout=API_TIMEOUT,
-        ).json()
+        )
+        if response.status_code != 200:
+            raise DatabricksClientException(response.text)
 
-    def list_jobs(self) -> List[dict]:
+    def list_jobs(self) -> Iterable[dict]:
         """
         Method returns List all the created jobs in a Databricks Workspace
         """
-        job_list = []
         try:
             data = {"limit": 25, "expand_tasks": True, "offset": 0}
 
@@ -159,9 +164,9 @@ class DatabricksClient:
                 timeout=API_TIMEOUT,
             ).json()
 
-            job_list.extend(response.get("jobs") or [])
+            yield from response.get("jobs") or []
 
-            while response["has_more"]:
+            while response and response.get("has_more"):
                 data["offset"] = len(response.get("jobs") or [])
 
                 response = self.client.get(
@@ -171,19 +176,16 @@ class DatabricksClient:
                     timeout=API_TIMEOUT,
                 ).json()
 
-                job_list.extend(response.get("jobs") or [])
+                yield from response.get("jobs") or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
 
-        return job_list
-
     def get_job_runs(self, job_id) -> List[dict]:
         """
         Method returns List of all runs for a job by the specified job_id
         """
-        job_runs = []
         try:
             params = {
                 "job_id": job_id,
@@ -200,7 +202,7 @@ class DatabricksClient:
                 timeout=API_TIMEOUT,
             ).json()
 
-            job_runs.extend(response.get("runs") or [])
+            yield from response.get("runs") or []
 
             while response["has_more"]:
                 params.update({"start_time_to": response["runs"][-1]["start_time"]})
@@ -212,10 +214,8 @@ class DatabricksClient:
                     timeout=API_TIMEOUT,
                 ).json()
 
-                job_runs.extend(response.get("runs" or []))
+                yield from response.get("runs") or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
-
-        return job_runs
diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/lineage.py b/ingestion/src/metadata/ingestion/source/database/databricks/lineage.py
index 9c948213e42..7795f6a262c 100644
--- a/ingestion/src/metadata/ingestion/source/database/databricks/lineage.py
+++ b/ingestion/src/metadata/ingestion/source/database/databricks/lineage.py
@@ -35,7 +35,7 @@ class DatabricksLineageSource(DatabricksQueryParserSource, LineageSource):
             start_date=self.start,
             end_date=self.end,
         )
-        for row in data:
+        for row in data or []:
             try:
                 if self.client.is_query_valid(row):
                     yield TableQuery(
diff --git a/ingestion/src/metadata/ingestion/source/database/databricks/usage.py b/ingestion/src/metadata/ingestion/source/database/databricks/usage.py
index 9199c79274b..77f2cecd351 100644
--- a/ingestion/src/metadata/ingestion/source/database/databricks/usage.py
+++ b/ingestion/src/metadata/ingestion/source/database/databricks/usage.py
@@ -39,7 +39,7 @@ class DatabricksUsageSource(DatabricksQueryParserSource, UsageSource):
             start_date=self.start,
             end_date=self.end,
         )
-        for row in data:
+        for row in data or []:
             try:
                 if self.client.is_query_valid(row):
                     queries.append(
diff --git a/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py b/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py
index a8910e877f8..bda3964b32a 100644
--- a/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py
+++ b/ingestion/src/metadata/ingestion/source/pipeline/databrickspipeline/metadata.py
@@ -81,7 +81,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
         return cls(config, metadata)
 
     def get_pipelines_list(self) -> Iterable[dict]:
-        for workflow in self.client.list_jobs():
+        for workflow in self.client.list_jobs() or []:
             yield workflow
 
     def get_pipeline_name(self, pipeline_details: dict) -> str:
@@ -195,7 +195,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
         for job_id in self.context.job_id_list:
             try:
                 runs = self.client.get_job_runs(job_id=job_id)
-                for attempt in runs:
+                for attempt in runs or []:
                     for task_run in attempt["tasks"]:
                         task_status = []
                         task_status.append(