Minor: Optimise Databricks Client (#14776)

Author: Mayur Singal, 2024-01-23 11:28:02 +05:30 (committed by GitHub)
Parent: fd7a3f19ef
Commit: 492bac32c0
4 changed files with 47 additions and 47 deletions
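
This change swaps list accumulation for generators across the Databricks client: query history, job listings, and job runs are now streamed to callers page by page instead of being buffered in memory, the query-history pagination loop moves into a reusable _run_query_paginator helper, and list_jobs_test_connection now raises a dedicated DatabricksClientException on non-200 responses. A minimal sketch of the before/after pattern (illustrative names, not code from this repo):

    from typing import Iterable, List

    def collect_rows_buffered(pages: List[dict]) -> List[dict]:
        # before: every row stays in memory until the full list is built
        rows = []
        for page in pages:
            rows.extend(page.get("res") or [])
        return rows

    def collect_rows_streamed(pages: List[dict]) -> Iterable[dict]:
        # after: each row is handed to the caller as soon as its page arrives
        for page in pages:
            yield from page.get("res") or []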


@@ -14,7 +14,7 @@ Client to interact with databricks apis
 import json
 import traceback
 from datetime import timedelta
-from typing import List
+from typing import Iterable, List
 
 import requests
@@ -31,6 +31,12 @@ API_TIMEOUT = 10
 QUERIES_PATH = "/sql/history/queries"
 
 
+class DatabricksClientException(Exception):
+    """
+    Class to throw auth and other databricks api exceptions.
+    """
+
+
 class DatabricksClient:
     """
     DatabricksClient creates a Databricks connection based on DatabricksCredentials.
@@ -60,14 +66,33 @@ class DatabricksClient:
         if res.status_code != 200:
             raise APIError(res.json)
 
+    def _run_query_paginator(self, data, result, end_time, response):
+        while True:
+            if response:
+                next_page_token = response.get("next_page_token", None)
+                has_next_page = response.get("has_next_page", None)
+                if next_page_token:
+                    data["page_token"] = next_page_token
+                if not has_next_page:
+                    data = {}
+                    break
+            else:
+                break
+
+            if result[-1]["execution_end_time_ms"] <= end_time:
+                response = self.client.get(
+                    self.base_query_url,
+                    data=json.dumps(data),
+                    headers=self.headers,
+                    timeout=API_TIMEOUT,
+                ).json()
+                yield from response.get("res") or []
+
     def list_query_history(self, start_date=None, end_date=None) -> List[dict]:
         """
         Method returns List the history of queries through SQL warehouses
         """
-        query_details = []
         try:
-            next_page_token = None
-            has_next_page = None
-
             data = {}
             daydiff = end_date - start_date
@@ -98,36 +123,15 @@ class DatabricksClient:
             result = response.get("res") or []
             data = {}
 
-            while True:
-                if result:
-                    query_details.extend(result)
-
-                    next_page_token = response.get("next_page_token", None)
-                    has_next_page = response.get("has_next_page", None)
-                    if next_page_token:
-                        data["page_token"] = next_page_token
-                    if not has_next_page:
-                        data = {}
-                        break
-                else:
-                    break
-
-                if result[-1]["execution_end_time_ms"] <= end_time:
-                    response = self.client.get(
-                        self.base_query_url,
-                        data=json.dumps(data),
-                        headers=self.headers,
-                        timeout=API_TIMEOUT,
-                    ).json()
-                    result = response.get("res")
-
+            yield from result
+
+            yield from self._run_query_paginator(
+                data=data, result=result, end_time=end_time, response=response
+            ) or []
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
 
-        return query_details
-
     def is_query_valid(self, row) -> bool:
         query_text = row.get("query_text")
         return not (
@@ -137,18 +141,19 @@ class DatabricksClient:
 
     def list_jobs_test_connection(self) -> None:
         data = {"limit": 1, "expand_tasks": True, "offset": 0}
-        self.client.get(
+        response = self.client.get(
            self.jobs_list_url,
            data=json.dumps(data),
            headers=self.headers,
            timeout=API_TIMEOUT,
-        ).json()
+        )
+        if response.status_code != 200:
+            raise DatabricksClientException(response.text)
 
-    def list_jobs(self) -> List[dict]:
+    def list_jobs(self) -> Iterable[dict]:
         """
         Method returns List all the created jobs in a Databricks Workspace
         """
-        job_list = []
         try:
             data = {"limit": 25, "expand_tasks": True, "offset": 0}
@@ -159,9 +164,9 @@ class DatabricksClient:
                 timeout=API_TIMEOUT,
             ).json()
 
-            job_list.extend(response.get("jobs") or [])
+            yield from response.get("jobs") or []
 
-            while response["has_more"]:
+            while response and response.get("has_more"):
                 data["offset"] = len(response.get("jobs") or [])
 
                 response = self.client.get(
@@ -171,19 +176,16 @@ class DatabricksClient:
                     timeout=API_TIMEOUT,
                 ).json()
 
-                job_list.extend(response.get("jobs") or [])
+                yield from response.get("jobs") or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
 
-        return job_list
-
     def get_job_runs(self, job_id) -> List[dict]:
         """
         Method returns List of all runs for a job by the specified job_id
         """
-        job_runs = []
         try:
             params = {
                 "job_id": job_id,
@@ -200,7 +202,7 @@ class DatabricksClient:
                timeout=API_TIMEOUT,
            ).json()
 
-            job_runs.extend(response.get("runs") or [])
+            yield from response.get("runs") or []
 
            while response["has_more"]:
                params.update({"start_time_to": response["runs"][-1]["start_time"]})
@@ -212,10 +214,8 @@ class DatabricksClient:
                    timeout=API_TIMEOUT,
                ).json()
 
-                job_runs.extend(response.get("runs" or []))
+                yield from response.get("runs") or []
 
        except Exception as exc:
            logger.debug(traceback.format_exc())
            logger.error(exc)
-
-        return job_runs
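
Because list_jobs and get_job_runs are now generators, callers drive the pagination: each 25-job page (and each page of runs) is requested only when iteration reaches it. A hedged consumption sketch; dump_job_runs is hypothetical, and client is assumed to be an already-initialised DatabricksClient from this module:

    def dump_job_runs(client) -> None:
        # hypothetical helper: `client` is an initialised DatabricksClient
        for job in client.list_jobs():  # next page is fetched lazily
            for run in client.get_job_runs(job_id=job["job_id"]):
                print(run["run_id"])  # stand-in for real processing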


@@ -35,7 +35,7 @@ class DatabricksLineageSource(DatabricksQueryParserSource, LineageSource):
             start_date=self.start,
             end_date=self.end,
         )
-        for row in data:
+        for row in data or []:
             try:
                 if self.client.is_query_valid(row):
                     yield TableQuery(
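
One subtlety of the new `for row in data or []` guard: `data` is now a generator object, and generators are truthy even when they will yield nothing, so the `[]` fallback only matters if the client method ever returns None outright. A self-contained illustration:

    def empty():
        yield from ()  # a generator function that yields no items

    data = empty()
    print(bool(data))        # True: generator objects are always truthy
    print(list(data or []))  # []: iterating the empty generator is safe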


@@ -39,7 +39,7 @@ class DatabricksUsageSource(DatabricksQueryParserSource, UsageSource):
             start_date=self.start,
             end_date=self.end,
         )
-        for row in data:
+        for row in data or []:
             try:
                 if self.client.is_query_valid(row):
                     queries.append(


@@ -81,7 +81,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
         return cls(config, metadata)
 
     def get_pipelines_list(self) -> Iterable[dict]:
-        for workflow in self.client.list_jobs():
+        for workflow in self.client.list_jobs() or []:
             yield workflow
 
     def get_pipeline_name(self, pipeline_details: dict) -> str:
@@ -195,7 +195,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
         for job_id in self.context.job_id_list:
             try:
                 runs = self.client.get_job_runs(job_id=job_id)
-                for attempt in runs:
+                for attempt in runs or []:
                     for task_run in attempt["tasks"]:
                         task_status = []
                         task_status.append(