From ad95bdb9c65ca123e96b8043b0e4fe7e80462fea Mon Sep 17 00:00:00 2001 From: Pere Miquel Brull Date: Tue, 16 Aug 2022 18:47:50 +0200 Subject: [PATCH] Fix #6735 - Improve source validation for inner connections (#6738) * Add inner connection validations and throw better error * Use custom classes when retrieving metadata from Airflow to allow more versions --- .../src/metadata/ingestion/api/parser.py | 133 +++++++++++---- .../ingestion/source/pipeline/airflow.py | 154 +++++++++++++++--- .../source/pipeline/pipeline_service.py | 4 +- ingestion/tests/unit/test_workflow_parse.py | 2 +- 4 files changed, 234 insertions(+), 59 deletions(-) diff --git a/ingestion/src/metadata/ingestion/api/parser.py b/ingestion/src/metadata/ingestion/api/parser.py index ddf43e2c4d5..2164b4728f9 100644 --- a/ingestion/src/metadata/ingestion/api/parser.py +++ b/ingestion/src/metadata/ingestion/api/parser.py @@ -83,6 +83,9 @@ logger = ingestion_logger() T = TypeVar("T", bound=BaseModel) +# Sources which contain inner connections to validate +HAS_INNER_CONNECTION = {"Airflow"} + def get_service_type( source_type: str, @@ -91,6 +94,8 @@ def get_service_type( Type[DatabaseConnection], Type[MessagingConnection], Type[MetadataConnection], + Type[PipelineConnection], + Type[MlModelConnection], ]: """ Return the service type for a source string @@ -116,13 +121,13 @@ def get_service_type( def get_source_config_class( source_config_type: str, ) -> Union[ - Type[DatabaseMetadataConfigType], - Type[ProfilerConfigType], - Type[DatabaseUsageConfigType], - Type[DashboardMetadataConfigType], - Type[MessagingMetadataConfigType], - Type[MlModelMetadataConfigType], - Type[PipelineMetadataConfigType], + Type[DashboardServiceMetadataPipeline], + Type[DatabaseServiceProfilerPipeline], + Type[DatabaseServiceQueryUsagePipeline], + Type[MessagingServiceMetadataPipeline], + Type[PipelineServiceMetadataPipeline], + Type[MlModelServiceMetadataPipeline], + Type[DatabaseServiceMetadataPipeline], ]: """ Return the source config type for a source string @@ -180,6 +185,81 @@ def get_connection_class( return connection_class +def _unsafe_parse_config(config: dict, cls: T, message: str) -> None: + """ + Given a config dictionary and the class it should match, + try to parse it or log the given message + """ + + # Parse the service connection dictionary with the scoped class + try: + cls.parse_obj(config) + except ValidationError as err: + logger.error(message) + logger.error( + f"The supported properties for {cls.__name__} are {list(cls.__fields__.keys())}" + ) + raise err + + +def parse_service_connection(config_dict: dict) -> None: + """ + Parse the service connection and raise any scoped + errors during the validation process + + :param config_dict: JSON configuration + """ + # Unsafe access to the keys. 
Allow a KeyError if the config is not well formatted + source_type = config_dict["source"]["serviceConnection"]["config"]["type"] + + logger.error( + f"Error parsing the Workflow Configuration for {source_type} ingestion" + ) + + service_type = get_service_type(source_type) + connection_class = get_connection_class(source_type, service_type) + + if source_type in HAS_INNER_CONNECTION: + # We will first parse the inner `connection` configuration + inner_source_type = config_dict["source"]["serviceConnection"]["config"][ + "connection" + ]["type"] + inner_service_type = get_service_type(inner_source_type) + inner_connection_class = get_connection_class( + inner_source_type, inner_service_type + ) + _unsafe_parse_config( + config=config_dict["source"]["serviceConnection"]["config"]["connection"], + cls=inner_connection_class, + message=f"Error parsing the inner service connection for {source_type}", + ) + + # Parse the service connection dictionary with the scoped class + _unsafe_parse_config( + config=config_dict["source"]["serviceConnection"]["config"], + cls=connection_class, + message="Error parsing the service connection", + ) + + +def parse_source_config(config_dict: dict) -> None: + """ + Parse the sourceConfig to help catch any config + misconfigurations + + :param config_dict: JSON configuration + """ + # Parse the source config + source_config_type = config_dict["source"]["sourceConfig"]["config"]["type"] + source_config_class = get_source_config_class(source_config_type) + + _unsafe_parse_config( + config=config_dict["source"]["sourceConfig"]["config"], + cls=source_config_class, + message="Error parsing the source config", + ) + + def parse_workflow_source(config_dict: dict) -> None: """ Validate the parsing of the source in the config dict. @@ -188,22 +268,8 @@ def parse_workflow_source(config_dict: dict) -> None: :param config_dict: JSON configuration """ - # Unsafe access to the keys. 
Allow a KeyError if the config is not well formatted - source_type = config_dict["source"]["serviceConnection"]["config"]["type"] - logger.error( - f"Error parsing the Workflow Configuration for {source_type} ingestion" - ) - - service_type = get_service_type(source_type) - connection_class = get_connection_class(source_type, service_type) - - # Parse the dictionary with the scoped class - connection_class.parse_obj(config_dict["source"]["serviceConnection"]["config"]) - - # Parse the source config - source_config_type = config_dict["source"]["sourceConfig"]["config"]["type"] - source_config_class = get_source_config_class(source_config_type) - source_config_class.parse_obj(config_dict["source"]["sourceConfig"]["config"]) + parse_service_connection(config_dict) + parse_source_config(config_dict) def parse_server_config(config_dict: dict) -> None: @@ -224,13 +290,16 @@ def parse_server_config(config_dict: dict) -> None: # If the error comes from the security config: auth_class = PROVIDER_CLASS_MAP.get(auth_provider) - security_config = ( - config_dict.get("workflowConfig") - .get("openMetadataServerConfig") - .get("securityConfig") + # throw an error if the keys are not present + security_config = config_dict["workflowConfig"]["openMetadataServerConfig"][ + "securityConfig" + ] + + _unsafe_parse_config( + config=security_config, + cls=auth_class, + message="Error parsing the workflow security config", ) - if auth_class and security_config: - auth_class.parse_obj(security_config) # If the security config is properly configured, let's raise the ValidationError of the whole WorkflowConfig WorkflowConfig.parse_obj(config_dict["workflowConfig"]) @@ -284,4 +353,8 @@ def parse_test_connection_request_gracefully( connection_class = get_connection_class(source_type, service_type) # Parse the dictionary with the scoped class - connection_class.parse_obj(config_dict["connection"]["config"]) + _unsafe_parse_config( + config=config_dict["connection"]["config"], + cls=connection_class, + message="Error parsing the connection config", + ) diff --git a/ingestion/src/metadata/ingestion/source/pipeline/airflow.py b/ingestion/src/metadata/ingestion/source/pipeline/airflow.py index 2010a6e39f7..405f59b900c 100644 --- a/ingestion/src/metadata/ingestion/source/pipeline/airflow.py +++ b/ingestion/src/metadata/ingestion/source/pipeline/airflow.py @@ -12,11 +12,13 @@ Airflow source to extract metadata from OM UI """ import traceback +from datetime import datetime from typing import Any, Iterable, List, Optional, cast -from airflow.models import BaseOperator, DagRun +from airflow.models import BaseOperator, DagRun, TaskInstance from airflow.models.serialized_dag import SerializedDagModel from airflow.serialization.serialized_objects import SerializedDAG +from pydantic import BaseModel from sqlalchemy.orm import Session from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest @@ -54,6 +56,35 @@ STATUS_MAP = { "queued": StatusType.Pending.value, } +IGNORE_DAG_RUN_COL = {"log_template_id"} + + +class OMSerializedDagDetails(BaseModel): + """ + Custom model we get from the Airflow db + as a scoped SELECT from SerializedDagModel + """ + + dag_id: str + data: Any + fileloc: str + + # We don't have a validator for SerializedDag + class Config: + arbitrary_types_allowed = True + + +class OMTaskInstance(BaseModel): + """ + Custom model we get from the Airflow db + as a scoped SELECT from TaskInstance + """ + + task_id: str + state: str + start_date: Optional[datetime] + end_date: Optional[datetime] 
+ class AirflowSource(PipelineServiceSource): """ @@ -87,29 +118,81 @@ class AirflowSource(PipelineServiceSource): Return the SQLAlchemy session from the engine """ if not self._session: - self._session = create_and_bind_session(self.engine) + self._session = create_and_bind_session(self.connection) return self._session - def get_pipeline_status(self, dag_id: str) -> DagRun: - dag_run_list: DagRun = ( - self.session.query(DagRun) + def get_pipeline_status(self, dag_id: str) -> List[DagRun]: + + dag_run_list = ( + self.session.query( + *[c for c in DagRun.__table__.c if c.name not in IGNORE_DAG_RUN_COL] + ) .filter(DagRun.dag_id == dag_id) .order_by(DagRun.execution_date.desc()) .limit(self.config.serviceConnection.__root__.config.numberOfStatus) + .all() ) - return dag_run_list + + dag_run_dict = [dict(elem) for elem in dag_run_list] + + # Build DagRun manually to not fall into new/old columns from + # different Airflow versions + return [ + DagRun( + dag_id=elem.get("dag_id"), + run_id=elem.get("run_id"), + queued_at=elem.get("queued_at"), + execution_date=elem.get("execution_date"), + start_date=elem.get("start_date"), + state=elem.get("state"), + ) + for elem in dag_run_dict + ] + + def get_task_instances(self, dag_id: str, run_id: str) -> List[OMTaskInstance]: + """ + We are building our own scoped TaskInstance + class to only focus on core properties required + by the metadata ingestion. + + This makes the versioning more flexible on which Airflow + sources we support. + """ + task_instance_list = ( + self.session.query( + TaskInstance.task_id, + TaskInstance.state, + TaskInstance.start_date, + TaskInstance.end_date, + TaskInstance.run_id, + ) + .filter(TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id) + .all() + ) + + task_instance_dict = [dict(elem) for elem in task_instance_list] + + return [ + OMTaskInstance( + task_id=elem.get("task_id"), + state=elem.get("state"), + start_date=elem.get("start_date"), + end_date=elem.get("end_date"), + ) + for elem in task_instance_dict + ] def yield_pipeline_status( self, pipeline_details: SerializedDAG ) -> OMetaPipelineStatus: dag_run_list = self.get_pipeline_status(pipeline_details.dag_id) - for dag in dag_run_list: - if isinstance(dag.task_instances, Iterable): - tasks = dag.task_instances - else: - tasks = [dag.task_instances] + for dag_run in dag_run_list: + tasks = self.get_task_instances( + dag_id=dag_run.dag_id, + run_id=dag_run.run_id, + ) task_statuses = [ TaskStatus( @@ -121,30 +204,38 @@ class AirflowSource(PipelineServiceSource): endTime=datetime_to_ts( task.end_date ), # Might be None for running tasks - logLink=task.log_url, - ) + ) # Log link might not be present in all Airflow versions for task in tasks ] pipeline_status = PipelineStatus( taskStatus=task_statuses, - executionStatus=STATUS_MAP.get(dag._state, StatusType.Pending.value), - timestamp=dag.execution_date.timestamp(), + executionStatus=STATUS_MAP.get(dag_run.state, StatusType.Pending.value), + timestamp=dag_run.execution_date.timestamp(), ) yield OMetaPipelineStatus( pipeline_fqn=self.context.pipeline.fullyQualifiedName.__root__, pipeline_status=pipeline_status, ) - def get_pipelines_list(self) -> Iterable[SerializedDagModel]: + def get_pipelines_list(self) -> Iterable[OMSerializedDagDetails]: """ List all DAGs from the metadata db. 
We are using the SerializedDagModel as it helps us retrieve all the task and inlets/outlets information """ - for serialized_dag in self.session.query(SerializedDagModel).all(): - yield serialized_dag + for serialized_dag in self.session.query( + SerializedDagModel.dag_id, + SerializedDagModel._data, + SerializedDagModel.fileloc, + ).all(): + + yield OMSerializedDagDetails( + dag_id=serialized_dag[0], + data=serialized_dag[1], + fileloc=serialized_dag[2], + ) def get_pipeline_name(self, pipeline_details: SerializedDAG) -> str: """ @@ -173,15 +264,29 @@ class AirflowSource(PipelineServiceSource): for task in cast(Iterable[BaseOperator], dag.tasks) ] + @staticmethod + def _build_dag(data: Any) -> SerializedDAG: + """ + Use the queried data to fetch the DAG + :param data: from SQA query + :return: SerializedDAG + """ + + if isinstance(data, dict): + return SerializedDAG.from_dict(data) + + return SerializedDAG.from_json(data) + def yield_pipeline( - self, pipeline_details: SerializedDagModel + self, pipeline_details: OMSerializedDagDetails ) -> Iterable[CreatePipelineRequest]: """ Convert a DAG into a Pipeline Entity - :param serialized_dag: SerializedDAG from airflow metadata DB + :param pipeline_details: SerializedDAG from airflow metadata DB :return: Create Pipeline request with tasks """ - dag: SerializedDAG = pipeline_details.dag + + dag: SerializedDAG = self._build_dag(pipeline_details.data) yield CreatePipelineRequest( name=pipeline_details.dag_id, description=dag.description, @@ -242,14 +347,14 @@ class AirflowSource(PipelineServiceSource): return None def yield_pipeline_lineage_details( - self, pipeline_details: SerializedDagModel + self, pipeline_details: OMSerializedDagDetails ) -> Optional[Iterable[AddLineageRequest]]: """ Parse xlets and add lineage between Pipelines and Tables :param pipeline_details: SerializedDAG from airflow metadata DB :return: Lineage from inlets and outlets """ - dag: SerializedDAG = pipeline_details.dag + dag: SerializedDAG = self._build_dag(pipeline_details.data) for task in dag.tasks: for table_fqn in self.get_inlets(task) or []: @@ -294,6 +399,3 @@ class AirflowSource(PipelineServiceSource): def close(self): self.session.close() - - def test_connection(self) -> None: - test_connection(self.engine) diff --git a/ingestion/src/metadata/ingestion/source/pipeline/pipeline_service.py b/ingestion/src/metadata/ingestion/source/pipeline/pipeline_service.py index 7266bebaa32..7e299708e9b 100644 --- a/ingestion/src/metadata/ingestion/source/pipeline/pipeline_service.py +++ b/ingestion/src/metadata/ingestion/source/pipeline/pipeline_service.py @@ -194,7 +194,7 @@ class PipelineServiceSource(TopologyRunnerMixin, Source, ABC): def close(self): """ - Method to implement any required logic after the ingesion process is completed + Method to implement any required logic after the ingestion process is completed """ def get_services(self) -> Iterable[WorkflowSource]: @@ -223,5 +223,5 @@ class PipelineServiceSource(TopologyRunnerMixin, Source, ABC): def prepare(self): """ - Method to implement any required logic before starting the ingesion process + Method to implement any required logic before starting the ingestion process """ diff --git a/ingestion/tests/unit/test_workflow_parse.py b/ingestion/tests/unit/test_workflow_parse.py index 81fd6fe0dbf..e313250079e 100644 --- a/ingestion/tests/unit/test_workflow_parse.py +++ b/ingestion/tests/unit/test_workflow_parse.py @@ -67,7 +67,7 @@ class TestWorkflowParse(TestCase): def test_get_service_type(self): """ - Test that we 
can get the service type of source
+        Test that we can get the service type of a source
         """
         database_service = get_service_type("Mysql")
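
For context, here is a minimal sketch (not part of the patch itself) of how the scoped validation added in parse_service_connection could be exercised against an Airflow service that wraps an inner Mysql connection. Only the source/serviceConnection/config/connection nesting and the behaviour of parse_service_connection and _unsafe_parse_config come from the hunks above; the individual connection fields (hostPort, username) are illustrative assumptions.

# Illustrative sketch only; field values are assumptions, while the key
# nesting mirrors what parse_service_connection reads in the patch above.
from pydantic import ValidationError

from metadata.ingestion.api.parser import parse_service_connection

config_dict = {
    "source": {
        "serviceConnection": {
            "config": {
                "type": "Airflow",
                "hostPort": "http://localhost:8080",  # assumed AirflowConnection field
                "connection": {
                    # "Airflow" is listed in HAS_INNER_CONNECTION, so this inner
                    # block is validated first, against the Mysql connection class
                    "type": "Mysql",
                    "username": "openmetadata_user",
                    # a required field (e.g. hostPort) intentionally left out
                },
            }
        }
    }
}

try:
    parse_service_connection(config_dict)
except ValidationError:
    # _unsafe_parse_config logs "Error parsing the inner service connection
    # for Airflow" plus the supported properties of the inner connection
    # class, then re-raises the original ValidationError.
    pass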