Minor: Optimise Databricks Client (#14776)

Mayur Singal 2024-01-23 11:28:02 +05:30 committed by GitHub
parent fd7a3f19ef
commit 492bac32c0
4 changed files with 47 additions and 47 deletions


@@ -14,7 +14,7 @@ Client to interact with databricks apis
 import json
 import traceback
 from datetime import timedelta
-from typing import List
+from typing import Iterable, List
 
 import requests
@@ -31,6 +31,12 @@ API_TIMEOUT = 10
 QUERIES_PATH = "/sql/history/queries"
 
 
+class DatabricksClientException(Exception):
+    """
+    Class to throw auth and other databricks api exceptions.
+    """
+
+
 class DatabricksClient:
     """
     DatabricksClient creates a Databricks connection based on DatabricksCredentials.
@@ -60,14 +66,33 @@ class DatabricksClient:
         if res.status_code != 200:
             raise APIError(res.json)
 
+    def _run_query_paginator(self, data, result, end_time, response):
+        while True:
+            if response:
+                next_page_token = response.get("next_page_token", None)
+                has_next_page = response.get("has_next_page", None)
+                if next_page_token:
+                    data["page_token"] = next_page_token
+                if not has_next_page:
+                    data = {}
+                    break
+            else:
+                break
+
+            if result[-1]["execution_end_time_ms"] <= end_time:
+                response = self.client.get(
+                    self.base_query_url,
+                    data=json.dumps(data),
+                    headers=self.headers,
+                    timeout=API_TIMEOUT,
+                ).json()
+
+                yield from response.get("res") or []
+
     def list_query_history(self, start_date=None, end_date=None) -> List[dict]:
         """
         Method returns List the history of queries through SQL warehouses
         """
-        query_details = []
         try:
-            next_page_token = None
-            has_next_page = None
-
             data = {}
             daydiff = end_date - start_date
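Note on the new helper: `_run_query_paginator` turns token-based paging into a generator, so each page is yielded as it arrives instead of being collected in a list. A minimal standalone sketch of the same pattern; the `paginate` helper and `fetch_page` callable below are hypothetical, not part of this commit:

```python
from typing import Callable, Dict, Iterator


def paginate(fetch_page: Callable[[Dict], dict], data: Dict) -> Iterator[dict]:
    """Yield rows page by page; nothing accumulates in memory."""
    response = fetch_page(data)
    yield from response.get("res") or []
    while response.get("has_next_page"):
        # Carry the token forward and request the next page lazily.
        data["page_token"] = response["next_page_token"]
        response = fetch_page(data)
        yield from response.get("res") or []
```

Because the helper yields, a caller that stops iterating early never triggers the remaining page requests.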
@@ -98,36 +123,15 @@ class DatabricksClient:
             result = response.get("res") or []
             data = {}
 
-            while True:
-                if result:
-                    query_details.extend(result)
-
-                    next_page_token = response.get("next_page_token", None)
-                    has_next_page = response.get("has_next_page", None)
-                    if next_page_token:
-                        data["page_token"] = next_page_token
-                    if not has_next_page:
-                        data = {}
-                        break
-                else:
-                    break
-
-                if result[-1]["execution_end_time_ms"] <= end_time:
-                    response = self.client.get(
-                        self.base_query_url,
-                        data=json.dumps(data),
-                        headers=self.headers,
-                        timeout=API_TIMEOUT,
-                    ).json()
-
-                    result = response.get("res")
+            yield from result
+            yield from self._run_query_paginator(
+                data=data, result=result, end_time=end_time, response=response
+            ) or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
 
-        return query_details
-
     def is_query_valid(self, row) -> bool:
         query_text = row.get("query_text")
         return not (
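With `list_query_history` now yielding instead of returning the accumulated `query_details` list, rows stream one page at a time and memory stays flat regardless of how much history comes back. A rough illustration with a fake page source; `fake_pages` and `stream_history` are invented for the sketch:

```python
def fake_pages(n: int):
    """Stand-in for paginated API responses."""
    for i in range(n):
        yield {"res": [{"query_id": i}], "has_next_page": i < n - 1}


def stream_history(pages):
    for page in pages:
        yield from page.get("res") or []


# Only the first page is materialised to produce the first row:
print(next(stream_history(fake_pages(1000))))  # {'query_id': 0}
```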
@@ -137,18 +141,19 @@ class DatabricksClient:
 
     def list_jobs_test_connection(self) -> None:
         data = {"limit": 1, "expand_tasks": True, "offset": 0}
-        self.client.get(
+        response = self.client.get(
             self.jobs_list_url,
             data=json.dumps(data),
             headers=self.headers,
             timeout=API_TIMEOUT,
-        ).json()
+        )
+        if response.status_code != 200:
+            raise DatabricksClientException(response.text)
 
-    def list_jobs(self) -> List[dict]:
+    def list_jobs(self) -> Iterable[dict]:
         """
         Method returns List all the created jobs in a Databricks Workspace
         """
-        job_list = []
         try:
             data = {"limit": 25, "expand_tasks": True, "offset": 0}
@@ -159,9 +164,9 @@ class DatabricksClient:
                 timeout=API_TIMEOUT,
             ).json()
 
-            job_list.extend(response.get("jobs") or [])
+            yield from response.get("jobs") or []
 
-            while response["has_more"]:
+            while response and response.get("has_more"):
                 data["offset"] = len(response.get("jobs") or [])
 
                 response = self.client.get(
@@ -171,19 +176,16 @@ class DatabricksClient:
                     timeout=API_TIMEOUT,
                 ).json()
 
-                job_list.extend(response.get("jobs") or [])
+                yield from response.get("jobs") or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
 
-        return job_list
-
     def get_job_runs(self, job_id) -> List[dict]:
         """
         Method returns List of all runs for a job by the specified job_id
         """
-        job_runs = []
         try:
             params = {
                 "job_id": job_id,
@@ -200,7 +202,7 @@ class DatabricksClient:
                 timeout=API_TIMEOUT,
             ).json()
 
-            job_runs.extend(response.get("runs") or [])
+            yield from response.get("runs") or []
 
             while response["has_more"]:
                 params.update({"start_time_to": response["runs"][-1]["start_time"]})
@@ -212,10 +214,8 @@ class DatabricksClient:
                 timeout=API_TIMEOUT,
             ).json()
 
-            job_runs.extend(response.get("runs" or []))
+            yield from response.get("runs") or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
-
-        return job_runs
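`get_job_runs` pages by time cursor rather than token: each follow-up request moves `start_time_to` back to the start time of the oldest run already seen. Note the deleted line also carried a bug, `response.get("runs" or [])`, which evaluated `"runs" or []` before the lookup; the rewrite fixes the parenthesisation as a side effect. A sketch of the cursor pattern, with a hypothetical `fetch` callable:

```python
def iter_job_runs(fetch, job_id: int, start_ms: int, end_ms: int):
    params = {"job_id": job_id, "start_time_from": start_ms, "start_time_to": end_ms}
    response = fetch(params)
    yield from response.get("runs") or []
    while response["has_more"]:
        # Cursor: the next page ends where the oldest fetched run began.
        params["start_time_to"] = response["runs"][-1]["start_time"]
        response = fetch(params)
        yield from response.get("runs") or []
```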


@@ -35,7 +35,7 @@ class DatabricksLineageSource(DatabricksQueryParserSource, LineageSource):
             start_date=self.start,
             end_date=self.end,
         )
-        for row in data:
+        for row in data or []:
             try:
                 if self.client.is_query_valid(row):
                     yield TableQuery(
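The `or []` guards added in the sources protect against the client handing back `None`. One caveat worth recording: they do not detect an *empty* generator, because generator objects are always truthy; `data or []` only catches an explicit `None`. A quick demonstration:

```python
def empty_gen():
    if False:
        yield  # makes this a generator function; it never yields


g = empty_gen()
print(bool(g))         # True: generator objects are truthy even when empty
print(list(g))         # []
print((g or []) is g)  # True: `or []` passes the generator straight through
```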


@@ -39,7 +39,7 @@ class DatabricksUsageSource(DatabricksQueryParserSource, UsageSource):
             start_date=self.start,
             end_date=self.end,
         )
-        for row in data:
+        for row in data or []:
             try:
                 if self.client.is_query_valid(row):
                     queries.append(


@@ -81,7 +81,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
         return cls(config, metadata)
 
     def get_pipelines_list(self) -> Iterable[dict]:
-        for workflow in self.client.list_jobs():
+        for workflow in self.client.list_jobs() or []:
             yield workflow
 
     def get_pipeline_name(self, pipeline_details: dict) -> str:
@@ -195,7 +195,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
         for job_id in self.context.job_id_list:
             try:
                 runs = self.client.get_job_runs(job_id=job_id)
-                for attempt in runs:
+                for attempt in runs or []:
                     for task_run in attempt["tasks"]:
                         task_status = []
                         task_status.append(