Mirror of https://github.com/open-metadata/OpenMetadata.git, synced 2025-07-24 01:40:00 +00:00
Minor: Optimise Databricks Client (#14776)
parent fd7a3f19ef
commit 492bac32c0
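Streams Databricks job and query-history results through generators instead of accumulating lists in memory: the query-history pagination loop moves into a new _run_query_paginator helper, list_jobs_test_connection now raises DatabricksClientException on non-200 responses, and the lineage, usage, and pipeline sources gain `or []` guards around their iterations.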
@@ -14,7 +14,7 @@ Client to interact with databricks apis
 import json
 import traceback
 from datetime import timedelta
-from typing import List
+from typing import Iterable, List
 
 import requests
 
@@ -31,6 +31,12 @@ API_TIMEOUT = 10
 QUERIES_PATH = "/sql/history/queries"
 
 
+class DatabricksClientException(Exception):
+    """
+    Class to throw auth and other databricks api exceptions.
+    """
+
+
 class DatabricksClient:
     """
     DatabricksClient creates a Databricks connection based on DatabricksCredentials.
@@ -60,14 +66,33 @@ class DatabricksClient:
         if res.status_code != 200:
             raise APIError(res.json)
 
+    def _run_query_paginator(self, data, result, end_time, response):
+        while True:
+            if response:
+                next_page_token = response.get("next_page_token", None)
+                has_next_page = response.get("has_next_page", None)
+                if next_page_token:
+                    data["page_token"] = next_page_token
+                if not has_next_page:
+                    data = {}
+                    break
+            else:
+                break
+
+            if result[-1]["execution_end_time_ms"] <= end_time:
+                response = self.client.get(
+                    self.base_query_url,
+                    data=json.dumps(data),
+                    headers=self.headers,
+                    timeout=API_TIMEOUT,
+                ).json()
+                yield from response.get("res") or []
+
     def list_query_history(self, start_date=None, end_date=None) -> List[dict]:
         """
         Method returns List the history of queries through SQL warehouses
         """
-        query_details = []
         try:
-            next_page_token = None
-            has_next_page = None
 
             data = {}
             daydiff = end_date - start_date
@@ -98,36 +123,15 @@ class DatabricksClient:
             result = response.get("res") or []
             data = {}
 
-            while True:
-                if result:
-                    query_details.extend(result)
-
-                    next_page_token = response.get("next_page_token", None)
-                    has_next_page = response.get("has_next_page", None)
-                    if next_page_token:
-                        data["page_token"] = next_page_token
-
-                    if not has_next_page:
-                        data = {}
-                        break
-                else:
-                    break
-
-                if result[-1]["execution_end_time_ms"] <= end_time:
-                    response = self.client.get(
-                        self.base_query_url,
-                        data=json.dumps(data),
-                        headers=self.headers,
-                        timeout=API_TIMEOUT,
-                    ).json()
-                    result = response.get("res")
+            yield from result
+            yield from self._run_query_paginator(
+                data=data, result=result, end_time=end_time, response=response
+            ) or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
 
-        return query_details
-
     def is_query_valid(self, row) -> bool:
         query_text = row.get("query_text")
         return not (
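With the pagination loop extracted into _run_query_paginator, list_query_history becomes a generator: the first page is yielded directly and the helper streams every following page (note that the declared return type still reads List[dict] even though the body now yields). A minimal consumption sketch, in which the connection object and the process callback are hypothetical:

from datetime import datetime, timedelta

client = DatabricksClient(connection)  # hypothetical: an already-configured connection

end = datetime.now()
start = end - timedelta(days=2)

# Rows stream lazily; the next page is fetched only while iteration continues.
for row in client.list_query_history(start_date=start, end_date=end):
    process(row)  # hypothetical consumer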
@@ -137,18 +141,19 @@ class DatabricksClient:
 
     def list_jobs_test_connection(self) -> None:
         data = {"limit": 1, "expand_tasks": True, "offset": 0}
-        self.client.get(
+        response = self.client.get(
             self.jobs_list_url,
             data=json.dumps(data),
             headers=self.headers,
             timeout=API_TIMEOUT,
-        ).json()
+        )
+        if response.status_code != 200:
+            raise DatabricksClientException(response.text)
 
-    def list_jobs(self) -> List[dict]:
+    def list_jobs(self) -> Iterable[dict]:
         """
         Method returns List all the created jobs in a Databricks Workspace
         """
-        job_list = []
         try:
             data = {"limit": 25, "expand_tasks": True, "offset": 0}
 
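list_jobs_test_connection previously discarded the response body, so a rejected token looked the same as a successful probe. With the new status check, callers can react to the failure; a sketch of the calling side, assuming the module's logger:

try:
    client.list_jobs_test_connection()
except DatabricksClientException as exc:
    # the non-200 response text is carried in the exception
    logger.warning(f"Databricks jobs API check failed: {exc}")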
@@ -159,9 +164,9 @@ class DatabricksClient:
                 timeout=API_TIMEOUT,
             ).json()
 
-            job_list.extend(response.get("jobs") or [])
+            yield from response.get("jobs") or []
 
-            while response["has_more"]:
+            while response and response.get("has_more"):
                 data["offset"] = len(response.get("jobs") or [])
 
                 response = self.client.get(
@@ -171,19 +176,16 @@ class DatabricksClient:
                     timeout=API_TIMEOUT,
                 ).json()
 
-                job_list.extend(response.get("jobs") or [])
+                yield from response.get("jobs") or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
 
-        return job_list
-
     def get_job_runs(self, job_id) -> List[dict]:
         """
         Method returns List of all runs for a job by the specified job_id
         """
-        job_runs = []
         try:
             params = {
                 "job_id": job_id,
@@ -200,7 +202,7 @@ class DatabricksClient:
                 timeout=API_TIMEOUT,
             ).json()
 
-            job_runs.extend(response.get("runs") or [])
+            yield from response.get("runs") or []
 
             while response["has_more"]:
                 params.update({"start_time_to": response["runs"][-1]["start_time"]})
@@ -212,10 +214,8 @@ class DatabricksClient:
                     timeout=API_TIMEOUT,
                 ).json()
 
-                job_runs.extend(response.get("runs" or []))
+                yield from response.get("runs") or []
 
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.error(exc)
-
-        return job_runs
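At this point every list-building method on the client streams results instead of accumulating them, so a failed request degrades to an empty iteration (the except branch logs and the generator simply stops) and callers never pay for pages they do not consume. A short sketch of that benefit, capping a job scan with itertools.islice:

from itertools import islice

# Take at most 100 jobs; later pages are never requested because
# pagination happens lazily inside list_jobs().
first_jobs = list(islice(client.list_jobs(), 100))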
@@ -35,7 +35,7 @@ class DatabricksLineageSource(DatabricksQueryParserSource, LineageSource):
             start_date=self.start,
             end_date=self.end,
         )
-        for row in data:
+        for row in data or []:
             try:
                 if self.client.is_query_valid(row):
                     yield TableQuery(
@@ -39,7 +39,7 @@ class DatabricksUsageSource(DatabricksQueryParserSource, UsageSource):
             start_date=self.start,
             end_date=self.end,
         )
-        for row in data:
+        for row in data or []:
             try:
                 if self.client.is_query_valid(row):
                     queries.append(
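The `for row in data or []` guards added here and in the pipeline source below are a cheap null-safety idiom: a generator object is always truthy, so the fallback only engages if the client ever hands back None, turning the loop into a no-op instead of a TypeError. A minimal illustration with a hypothetical handle function:

rows = None
for row in rows or []:  # falls back to an empty list, no TypeError
    handle(row)         # never runs when rows is None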
@@ -81,7 +81,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
         return cls(config, metadata)
 
     def get_pipelines_list(self) -> Iterable[dict]:
-        for workflow in self.client.list_jobs():
+        for workflow in self.client.list_jobs() or []:
             yield workflow
 
     def get_pipeline_name(self, pipeline_details: dict) -> str:
@@ -195,7 +195,7 @@ class DatabrickspipelineSource(PipelineServiceSource):
         for job_id in self.context.job_id_list:
             try:
                 runs = self.client.get_job_runs(job_id=job_id)
-                for attempt in runs:
+                for attempt in runs or []:
                     for task_run in attempt["tasks"]:
                         task_status = []
                         task_status.append(