Mirror of https://github.com/open-metadata/OpenMetadata.git (synced 2025-08-18 14:06:59 +00:00)
* Add inner connection validations and throw better error
* Use custom classes when retrieving metadata from Airflow to allow more versions
parent 2b8f721094
commit ad95bdb9c6
@@ -83,6 +83,9 @@ logger = ingestion_logger()
T = TypeVar("T", bound=BaseModel)

# Sources which contain inner connections to validate
HAS_INNER_CONNECTION = {"Airflow"}


def get_service_type(
    source_type: str,
@@ -91,6 +94,8 @@ def get_service_type(
    Type[DatabaseConnection],
    Type[MessagingConnection],
    Type[MetadataConnection],
    Type[PipelineConnection],
    Type[MlModelConnection],
]:
    """
    Return the service type for a source string
@@ -116,13 +121,13 @@ def get_service_type(
def get_source_config_class(
    source_config_type: str,
) -> Union[
    Type[DatabaseMetadataConfigType],
    Type[ProfilerConfigType],
    Type[DatabaseUsageConfigType],
    Type[DashboardMetadataConfigType],
    Type[MessagingMetadataConfigType],
    Type[MlModelMetadataConfigType],
    Type[PipelineMetadataConfigType],
    Type[DashboardServiceMetadataPipeline],
    Type[DatabaseServiceProfilerPipeline],
    Type[DatabaseServiceQueryUsagePipeline],
    Type[MessagingServiceMetadataPipeline],
    Type[PipelineServiceMetadataPipeline],
    Type[MlModelServiceMetadataPipeline],
    Type[DatabaseServiceMetadataPipeline],
]:
    """
    Return the source config type for a source string
@@ -180,6 +185,81 @@ def get_connection_class(
    return connection_class


def _unsafe_parse_config(config: dict, cls: T, message: str) -> None:
    """
    Given a config dictionary and the class it should match,
    try to parse it or log the given message
    """

    # Parse the service connection dictionary with the scoped class
    try:
        cls.parse_obj(config)
    except ValidationError as err:
        logger.error(message)
        logger.error(
            f"The supported properties for {cls.__name__} are {list(cls.__fields__.keys())}"
        )
        raise err


def parse_service_connection(config_dict: dict) -> None:
    """
    Parse the service connection and raise any scoped
    errors during the validation process

    :param config_dict: JSON configuration
    """
    # Unsafe access to the keys. Allow a KeyError if the config is not well formatted
    source_type = config_dict["source"]["serviceConnection"]["config"]["type"]

    logger.error(
        f"Error parsing the Workflow Configuration for {source_type} ingestion"
    )

    service_type = get_service_type(source_type)
    connection_class = get_connection_class(source_type, service_type)

    if source_type in HAS_INNER_CONNECTION:
        # We will first parse the inner `connection` configuration
        inner_source_type = config_dict["source"]["serviceConnection"]["config"][
            "connection"
        ]["type"]
        inner_service_type = get_service_type(inner_source_type)
        inner_connection_class = get_connection_class(
            inner_source_type, inner_service_type
        )
        _unsafe_parse_config(
            config=config_dict["source"]["serviceConnection"]["config"]["connection"],
            cls=inner_connection_class,
            message=f"Error parsing the inner service connection for {source_type}",
        )

    # Parse the service connection dictionary with the scoped class
    _unsafe_parse_config(
        config=config_dict["source"]["serviceConnection"]["config"],
        cls=connection_class,
        message="Error parsing the service connection",
    )


def parse_source_config(config_dict: dict) -> None:
    """
    Parse the sourceConfig to help catch any config
    misconfigurations

    :param config_dict: JSON configuration
    """
    # Parse the source config
    source_config_type = config_dict["source"]["sourceConfig"]["config"]["type"]
    source_config_class = get_source_config_class(source_config_type)

    _unsafe_parse_config(
        config=config_dict["source"]["sourceConfig"]["config"],
        cls=source_config_class,
        message="Error parsing the source config",
    )


def parse_workflow_source(config_dict: dict) -> None:
    """
    Validate the parsing of the source in the config dict.
@@ -188,22 +268,8 @@ def parse_workflow_source(config_dict: dict) -> None:

    :param config_dict: JSON configuration
    """
    # Unsafe access to the keys. Allow a KeyError if the config is not well formatted
    source_type = config_dict["source"]["serviceConnection"]["config"]["type"]
    logger.error(
        f"Error parsing the Workflow Configuration for {source_type} ingestion"
    )

    service_type = get_service_type(source_type)
    connection_class = get_connection_class(source_type, service_type)

    # Parse the dictionary with the scoped class
    connection_class.parse_obj(config_dict["source"]["serviceConnection"]["config"])

    # Parse the source config
    source_config_type = config_dict["source"]["sourceConfig"]["config"]["type"]
    source_config_class = get_source_config_class(source_config_type)
    source_config_class.parse_obj(config_dict["source"]["sourceConfig"]["config"])
    parse_service_connection(config_dict)
    parse_source_config(config_dict)


def parse_server_config(config_dict: dict) -> None:
@@ -224,13 +290,16 @@ def parse_server_config(config_dict: dict) -> None:

    # If the error comes from the security config:
    auth_class = PROVIDER_CLASS_MAP.get(auth_provider)
    security_config = (
        config_dict.get("workflowConfig")
        .get("openMetadataServerConfig")
        .get("securityConfig")
    # throw an error if the keys are not present
    security_config = config_dict["workflowConfig"]["openMetadataServerConfig"][
        "securityConfig"
    ]

    _unsafe_parse_config(
        config=security_config,
        cls=auth_class,
        message="Error parsing the workflow security config",
    )
    if auth_class and security_config:
        auth_class.parse_obj(security_config)

    # If the security config is properly configured, let's raise the ValidationError of the whole WorkflowConfig
    WorkflowConfig.parse_obj(config_dict["workflowConfig"])
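For illustration only, this is roughly the shape of the config dictionary that parse_service_connection above walks for a source listed in HAS_INNER_CONNECTION. Every key and value other than the nested "type" and "connection" entries (the Airflow hostPort, the Mysql backend, the "PipelineMetadata" source config) is an assumed placeholder, not something taken from this change.

# Hypothetical workflow snippet: an Airflow service whose backend is a Mysql
# database. The inner "connection" block is validated first, then the outer
# Airflow block, so validation errors point at the right schema.
config_dict = {
    "source": {
        "serviceConnection": {
            "config": {
                "type": "Airflow",
                "hostPort": "http://localhost:8080",
                "connection": {
                    "type": "Mysql",
                    "username": "airflow_user",
                    "hostPort": "localhost:3306",
                },
            }
        },
        "sourceConfig": {"config": {"type": "PipelineMetadata"}},
    }
}

Passing such a dict to parse_service_connection would try the inner Mysql block against the Mysql connection class before touching the Airflow one, so a typo inside "connection" surfaces as a Mysql-scoped ValidationError instead of a generic Airflow failure.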
@@ -284,4 +353,8 @@ def parse_test_connection_request_gracefully(
    connection_class = get_connection_class(source_type, service_type)

    # Parse the dictionary with the scoped class
    connection_class.parse_obj(config_dict["connection"]["config"])
    _unsafe_parse_config(
        config=config_dict["connection"]["config"],
        cls=connection_class,
        message="Error parsing the connection config",
    )
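As a standalone sketch of the error-scoping pattern that _unsafe_parse_config introduces (pydantic v1 style): DemoConnection and demo_unsafe_parse_config are made-up stand-ins, not generated OpenMetadata classes.

from pydantic import BaseModel, ValidationError


class DemoConnection(BaseModel):
    """Stand-in for a generated connection class."""

    hostPort: str
    username: str


def demo_unsafe_parse_config(config: dict, cls, message: str) -> None:
    # Same idea as _unsafe_parse_config: try to parse, print scoped context, re-raise
    try:
        cls.parse_obj(config)
    except ValidationError as err:
        print(message)
        print(f"The supported properties for {cls.__name__} are {list(cls.__fields__.keys())}")
        raise err


try:
    # "user" is a typo for "username", so validation fails with a scoped message
    demo_unsafe_parse_config(
        config={"hostPort": "localhost:3306", "user": "me"},
        cls=DemoConnection,
        message="Error parsing the demo connection",
    )
except ValidationError:
    pass  # the real workflow would abort here, with the context already logged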
@@ -12,11 +12,13 @@
Airflow source to extract metadata from OM UI
"""
import traceback
from datetime import datetime
from typing import Any, Iterable, List, Optional, cast

from airflow.models import BaseOperator, DagRun
from airflow.models import BaseOperator, DagRun, TaskInstance
from airflow.models.serialized_dag import SerializedDagModel
from airflow.serialization.serialized_objects import SerializedDAG
from pydantic import BaseModel
from sqlalchemy.orm import Session

from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest
@@ -54,6 +56,35 @@ STATUS_MAP = {
    "queued": StatusType.Pending.value,
}

IGNORE_DAG_RUN_COL = {"log_template_id"}


class OMSerializedDagDetails(BaseModel):
    """
    Custom model we get from the Airflow db
    as a scoped SELECT from SerializedDagModel
    """

    dag_id: str
    data: Any
    fileloc: str

    # We don't have a validator for SerializedDag
    class Config:
        arbitrary_types_allowed = True


class OMTaskInstance(BaseModel):
    """
    Custom model we get from the Airflow db
    as a scoped SELECT from TaskInstance
    """

    task_id: str
    state: str
    start_date: Optional[datetime]
    end_date: Optional[datetime]


class AirflowSource(PipelineServiceSource):
    """
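A quick, assumed usage of these scoped models; the literal values below are placeholders, in the source they are filled from the SQLAlchemy queries further down.

dag_details = OMSerializedDagDetails(
    dag_id="daily_sales",
    data={"dag": {"dag_id": "daily_sales", "tasks": []}},  # dict or JSON string, hence `data: Any`
    fileloc="/opt/airflow/dags/daily_sales.py",
)

task = OMTaskInstance(
    task_id="extract_orders",
    state="success",
    start_date=None,  # Optional: may be missing while a task is still queued
    end_date=None,
)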
@@ -87,29 +118,81 @@ class AirflowSource(PipelineServiceSource):
        Return the SQLAlchemy session from the engine
        """
        if not self._session:
            self._session = create_and_bind_session(self.engine)
            self._session = create_and_bind_session(self.connection)

        return self._session

    def get_pipeline_status(self, dag_id: str) -> DagRun:
        dag_run_list: DagRun = (
            self.session.query(DagRun)
    def get_pipeline_status(self, dag_id: str) -> List[DagRun]:

        dag_run_list = (
            self.session.query(
                *[c for c in DagRun.__table__.c if c.name not in IGNORE_DAG_RUN_COL]
            )
            .filter(DagRun.dag_id == dag_id)
            .order_by(DagRun.execution_date.desc())
            .limit(self.config.serviceConnection.__root__.config.numberOfStatus)
            .all()
        )
        return dag_run_list

        dag_run_dict = [dict(elem) for elem in dag_run_list]

        # Build DagRun manually to not fall into new/old columns from
        # different Airflow versions
        return [
            DagRun(
                dag_id=elem.get("dag_id"),
                run_id=elem.get("run_id"),
                queued_at=elem.get("queued_at"),
                execution_date=elem.get("execution_date"),
                start_date=elem.get("start_date"),
                state=elem.get("state"),
            )
            for elem in dag_run_dict
        ]

    def get_task_instances(self, dag_id: str, run_id: str) -> List[OMTaskInstance]:
        """
        We are building our own scoped TaskInstance
        class to only focus on core properties required
        by the metadata ingestion.

        This makes the versioning more flexible on which Airflow
        sources we support.
        """
        task_instance_list = (
            self.session.query(
                TaskInstance.task_id,
                TaskInstance.state,
                TaskInstance.start_date,
                TaskInstance.end_date,
                TaskInstance.run_id,
            )
            .filter(TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id)
            .all()
        )

        task_instance_dict = [dict(elem) for elem in task_instance_list]

        return [
            OMTaskInstance(
                task_id=elem.get("task_id"),
                state=elem.get("state"),
                start_date=elem.get("start_date"),
                end_date=elem.get("end_date"),
            )
            for elem in task_instance_dict
        ]

    def yield_pipeline_status(
        self, pipeline_details: SerializedDAG
    ) -> OMetaPipelineStatus:
        dag_run_list = self.get_pipeline_status(pipeline_details.dag_id)

        for dag in dag_run_list:
            if isinstance(dag.task_instances, Iterable):
                tasks = dag.task_instances
            else:
                tasks = [dag.task_instances]
        for dag_run in dag_run_list:
            tasks = self.get_task_instances(
                dag_id=dag_run.dag_id,
                run_id=dag_run.run_id,
            )

            task_statuses = [
                TaskStatus(
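The column-filtering trick in get_pipeline_status (select only the columns that exist in every supported Airflow version, then rebuild objects by hand) can be reproduced in isolation. Below is a minimal sketch with an assumed SQLAlchemy 1.4+ toy table (DagRunDemo, IGNORE_COLS and all values are made up), not OpenMetadata code.

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class DagRunDemo(Base):
    """Toy stand-in for Airflow's DagRun table."""

    __tablename__ = "dag_run_demo"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String)
    state = Column(String)
    log_template_id = Column(Integer)  # pretend this column only exists in newer versions


IGNORE_COLS = {"log_template_id"}

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(DagRunDemo(dag_id="daily_sales", state="success"))
    session.commit()

    # Query only the columns we know about, so columns added or dropped in
    # other versions of the schema cannot break the mapping
    rows = session.query(
        *[c for c in DagRunDemo.__table__.c if c.name not in IGNORE_COLS]
    ).all()
    safe_dicts = [dict(row._mapping) for row in rows]
    print(safe_dicts)  # [{'id': 1, 'dag_id': 'daily_sales', 'state': 'success'}]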
@@ -121,30 +204,38 @@ class AirflowSource(PipelineServiceSource):
                    endTime=datetime_to_ts(
                        task.end_date
                    ),  # Might be None for running tasks
                    logLink=task.log_url,
                )
                )  # Log link might not be present in all Airflow versions
                for task in tasks
            ]

            pipeline_status = PipelineStatus(
                taskStatus=task_statuses,
                executionStatus=STATUS_MAP.get(dag._state, StatusType.Pending.value),
                timestamp=dag.execution_date.timestamp(),
                executionStatus=STATUS_MAP.get(dag_run.state, StatusType.Pending.value),
                timestamp=dag_run.execution_date.timestamp(),
            )
            yield OMetaPipelineStatus(
                pipeline_fqn=self.context.pipeline.fullyQualifiedName.__root__,
                pipeline_status=pipeline_status,
            )

    def get_pipelines_list(self) -> Iterable[SerializedDagModel]:
    def get_pipelines_list(self) -> Iterable[OMSerializedDagDetails]:
        """
        List all DAGs from the metadata db.

        We are using the SerializedDagModel as it helps
        us retrieve all the task and inlets/outlets information
        """
        for serialized_dag in self.session.query(SerializedDagModel).all():
            yield serialized_dag
        for serialized_dag in self.session.query(
            SerializedDagModel.dag_id,
            SerializedDagModel._data,
            SerializedDagModel.fileloc,
        ).all():

            yield OMSerializedDagDetails(
                dag_id=serialized_dag[0],
                data=serialized_dag[1],
                fileloc=serialized_dag[2],
            )

    def get_pipeline_name(self, pipeline_details: SerializedDAG) -> str:
        """
@@ -173,15 +264,29 @@ class AirflowSource(PipelineServiceSource):
            for task in cast(Iterable[BaseOperator], dag.tasks)
        ]

    @staticmethod
    def _build_dag(data: Any) -> SerializedDAG:
        """
        Use the queried data to fetch the DAG
        :param data: from SQA query
        :return: SerializedDAG
        """

        if isinstance(data, dict):
            return SerializedDAG.from_dict(data)

        return SerializedDAG.from_json(data)

    def yield_pipeline(
        self, pipeline_details: SerializedDagModel
        self, pipeline_details: OMSerializedDagDetails
    ) -> Iterable[CreatePipelineRequest]:
        """
        Convert a DAG into a Pipeline Entity
        :param serialized_dag: SerializedDAG from airflow metadata DB
        :param pipeline_details: SerializedDAG from airflow metadata DB
        :return: Create Pipeline request with tasks
        """
        dag: SerializedDAG = pipeline_details.dag

        dag: SerializedDAG = self._build_dag(pipeline_details.data)
        yield CreatePipelineRequest(
            name=pipeline_details.dag_id,
            description=dag.description,
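The dict-or-JSON dispatch in _build_dag is the same defensive pattern one would write with the standard json module. A self-contained sketch follows; load_serialized is a hypothetical helper with SerializedDAG swapped for a plain dict.

import json
from typing import Any


def load_serialized(data: Any) -> dict:
    # Depending on how the serialized DAG was stored, it may already be a dict
    # or still be a JSON string; accept both, like _build_dag does
    if isinstance(data, dict):
        return data
    return json.loads(data)


assert load_serialized({"dag_id": "daily_sales"}) == load_serialized('{"dag_id": "daily_sales"}')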
@@ -242,14 +347,14 @@ class AirflowSource(PipelineServiceSource):
            return None

    def yield_pipeline_lineage_details(
        self, pipeline_details: SerializedDagModel
        self, pipeline_details: OMSerializedDagDetails
    ) -> Optional[Iterable[AddLineageRequest]]:
        """
        Parse xlets and add lineage between Pipelines and Tables
        :param pipeline_details: SerializedDAG from airflow metadata DB
        :return: Lineage from inlets and outlets
        """
        dag: SerializedDAG = pipeline_details.dag
        dag: SerializedDAG = self._build_dag(pipeline_details.data)

        for task in dag.tasks:
            for table_fqn in self.get_inlets(task) or []:
@@ -294,6 +399,3 @@ class AirflowSource(PipelineServiceSource):

    def close(self):
        self.session.close()

    def test_connection(self) -> None:
        test_connection(self.engine)
@@ -194,7 +194,7 @@ class PipelineServiceSource(TopologyRunnerMixin, Source, ABC):

    def close(self):
        """
        Method to implement any required logic after the ingesion process is completed
        Method to implement any required logic after the ingestion process is completed
        """

    def get_services(self) -> Iterable[WorkflowSource]:
@@ -223,5 +223,5 @@ class PipelineServiceSource(TopologyRunnerMixin, Source, ABC):

    def prepare(self):
        """
        Method to implement any required logic before starting the ingesion process
        Method to implement any required logic before starting the ingestion process
        """
@@ -67,7 +67,7 @@ class TestWorkflowParse(TestCase):

    def test_get_service_type(self):
        """
        Test that we can get the service type of a source
        Test that we can get the service type of source
        """

        database_service = get_service_type("Mysql")