Fix #6735 - Improve source validation for inner connections (#6738)

* Add inner connection validations and throw a better error

* Use custom classes when retrieving metadata from Airflow to support more Airflow versions
Pere Miquel Brull 2022-08-16 18:47:50 +02:00 committed by GitHub
parent 2b8f721094
commit ad95bdb9c6
4 changed files with 234 additions and 59 deletions


@@ -83,6 +83,9 @@ logger = ingestion_logger()
T = TypeVar("T", bound=BaseModel)
# Sources which contain inner connections to validate
HAS_INNER_CONNECTION = {"Airflow"}
def get_service_type(
source_type: str,
@@ -91,6 +94,8 @@ def get_service_type(
Type[DatabaseConnection],
Type[MessagingConnection],
Type[MetadataConnection],
Type[PipelineConnection],
Type[MlModelConnection],
]:
"""
Return the service type for a source string
@@ -116,13 +121,13 @@ def get_service_type(
def get_source_config_class(
source_config_type: str,
) -> Union[
Type[DatabaseMetadataConfigType],
Type[ProfilerConfigType],
Type[DatabaseUsageConfigType],
Type[DashboardMetadataConfigType],
Type[MessagingMetadataConfigType],
Type[MlModelMetadataConfigType],
Type[PipelineMetadataConfigType],
Type[DashboardServiceMetadataPipeline],
Type[DatabaseServiceProfilerPipeline],
Type[DatabaseServiceQueryUsagePipeline],
Type[MessagingServiceMetadataPipeline],
Type[PipelineServiceMetadataPipeline],
Type[MlModelServiceMetadataPipeline],
Type[DatabaseServiceMetadataPipeline],
]:
"""
Return the source config type for a source string
@@ -180,6 +185,81 @@ def get_connection_class(
return connection_class
def _unsafe_parse_config(config: dict, cls: T, message: str) -> None:
"""
Given a config dictionary and the class it should match,
try to parse it or log the given message
"""
# Parse the service connection dictionary with the scoped class
try:
cls.parse_obj(config)
except ValidationError as err:
logger.error(message)
logger.error(
f"The supported properties for {cls.__name__} are {list(cls.__fields__.keys())}"
)
raise err
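For illustration, a minimal sketch of what this helper surfaces when a key is misspelled; DummyConnection is a made-up model used only for this example:
from pydantic import BaseModel, ValidationError
class DummyConnection(BaseModel):
    hostPort: str
    username: str
try:
    _unsafe_parse_config(
        config={"hostPort": "localhost:3306", "user": "admin"},  # 'user' is not a valid key
        cls=DummyConnection,
        message="Error parsing the dummy connection",
    )
except ValidationError:
    pass  # the logs now list the supported properties: ['hostPort', 'username']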
def parse_service_connection(config_dict: dict) -> None:
"""
Parse the service connection and raise any scoped
errors during the validation process
:param config_dict: JSON configuration
"""
# Unsafe access to the keys. Allow a KeyError if the config is not well formatted
source_type = config_dict["source"]["serviceConnection"]["config"]["type"]
logger.error(
f"Error parsing the Workflow Configuration for {source_type} ingestion"
)
service_type = get_service_type(source_type)
connection_class = get_connection_class(source_type, service_type)
if source_type in HAS_INNER_CONNECTION:
# We will first parse the inner `connection` configuration
inner_source_type = config_dict["source"]["serviceConnection"]["config"][
"connection"
]["type"]
inner_service_type = get_service_type(inner_source_type)
inner_connection_class = get_connection_class(
inner_source_type, inner_service_type
)
_unsafe_parse_config(
config=config_dict["source"]["serviceConnection"]["config"]["connection"],
cls=inner_connection_class,
message=f"Error parsing the inner service connection for {source_type}",
)
# Parse the service connection dictionary with the scoped class
_unsafe_parse_config(
config=config_dict["source"]["serviceConnection"]["config"],
cls=connection_class,
message="Error parsing the service connection",
)
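As an illustrative sketch of what this enables, an Airflow service carrying an inner Mysql connection is now validated twice, inner connection first and outer connection second. The key names below are placeholders; the exact fields come from the generated AirflowConnection and MysqlConnection schemas:
airflow_config = {
    "source": {
        "serviceConnection": {
            "config": {
                "type": "Airflow",
                "hostPort": "http://localhost:8080",
                "connection": {
                    "type": "Mysql",
                    "username": "airflow_user",
                    "hostPort": "localhost:3306",
                },
            }
        }
    }
}
parse_service_connection(airflow_config)  # parses the inner Mysql config, then the Airflow config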
def parse_source_config(config_dict: dict) -> None:
"""
Parse the sourceConfig to help catch any config
misconfigurations
:param config_dict: JSON configuration
"""
# Parse the source config
source_config_type = config_dict["source"]["sourceConfig"]["config"]["type"]
source_config_class = get_source_config_class(source_config_type)
_unsafe_parse_config(
config=config_dict["source"]["sourceConfig"]["config"],
cls=source_config_class,
message="Error parsing the source config",
)
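A matching sketch for the sourceConfig side, assuming a pipeline metadata ingestion; the includeLineage flag is shown only as an example option:
parse_source_config(
    {
        "source": {
            "sourceConfig": {
                "config": {
                    "type": "PipelineMetadata",
                    "includeLineage": True,
                }
            }
        }
    }
)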
def parse_workflow_source(config_dict: dict) -> None:
"""
Validate the parsing of the source in the config dict.
@@ -188,22 +268,8 @@ def parse_workflow_source(config_dict: dict) -> None:
:param config_dict: JSON configuration
"""
# Unsafe access to the keys. Allow a KeyError if the config is not well formatted
source_type = config_dict["source"]["serviceConnection"]["config"]["type"]
logger.error(
f"Error parsing the Workflow Configuration for {source_type} ingestion"
)
service_type = get_service_type(source_type)
connection_class = get_connection_class(source_type, service_type)
# Parse the dictionary with the scoped class
connection_class.parse_obj(config_dict["source"]["serviceConnection"]["config"])
# Parse the source config
source_config_type = config_dict["source"]["sourceConfig"]["config"]["type"]
source_config_class = get_source_config_class(source_config_type)
source_config_class.parse_obj(config_dict["source"]["sourceConfig"]["config"])
parse_service_connection(config_dict)
parse_source_config(config_dict)
def parse_server_config(config_dict: dict) -> None:
@@ -224,13 +290,16 @@ def parse_server_config(config_dict: dict) -> None:
# If the error comes from the security config:
auth_class = PROVIDER_CLASS_MAP.get(auth_provider)
security_config = (
config_dict.get("workflowConfig")
.get("openMetadataServerConfig")
.get("securityConfig")
# throw an error if the keys are not present
security_config = config_dict["workflowConfig"]["openMetadataServerConfig"][
"securityConfig"
]
_unsafe_parse_config(
config=security_config,
cls=auth_class,
message="Error parsing the workflow security config",
)
if auth_class and security_config:
auth_class.parse_obj(security_config)
# If the security config is properly configured, let's raise the ValidationError of the whole WorkflowConfig
WorkflowConfig.parse_obj(config_dict["workflowConfig"])
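For reference, a hedged sketch of the workflowConfig shape this validates; the authProvider and securityConfig values depend on the deployment and are placeholders here:
parse_server_config(
    {
        "workflowConfig": {
            "openMetadataServerConfig": {
                "hostPort": "http://localhost:8585/api",
                "authProvider": "openmetadata",
                "securityConfig": {"jwtToken": "<jwt-token>"},
            }
        }
    }
)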
@@ -284,4 +353,8 @@ def parse_test_connection_request_gracefully(
connection_class = get_connection_class(source_type, service_type)
# Parse the dictionary with the scoped class
connection_class.parse_obj(config_dict["connection"]["config"])
_unsafe_parse_config(
config=config_dict["connection"]["config"],
cls=connection_class,
message="Error parsing the connection config",
)


@@ -12,11 +12,13 @@
Airflow source to extract metadata from OM UI
"""
import traceback
from datetime import datetime
from typing import Any, Iterable, List, Optional, cast
from airflow.models import BaseOperator, DagRun
from airflow.models import BaseOperator, DagRun, TaskInstance
from airflow.models.serialized_dag import SerializedDagModel
from airflow.serialization.serialized_objects import SerializedDAG
from pydantic import BaseModel
from sqlalchemy.orm import Session
from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest
@@ -54,6 +56,35 @@ STATUS_MAP = {
"queued": StatusType.Pending.value,
}
IGNORE_DAG_RUN_COL = {"log_template_id"}
class OMSerializedDagDetails(BaseModel):
"""
Custom model we get from the Airflow db
as a scoped SELECT from SerializedDagModel
"""
dag_id: str
data: Any
fileloc: str
# We don't have a validator for SerializedDag
class Config:
arbitrary_types_allowed = True
class OMTaskInstance(BaseModel):
"""
Custom model we get from the Airflow db
as a scoped SELECT from TaskInstance
"""
task_id: str
state: str
start_date: Optional[datetime]
end_date: Optional[datetime]
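For illustration, the scoped models carry only the fields the ingestion needs, regardless of which columns a given Airflow version adds to the underlying tables; the values below are placeholders:
task = OMTaskInstance(
    task_id="load_orders",
    state="success",
    start_date=datetime(2022, 8, 16, 10, 0),
    end_date=datetime(2022, 8, 16, 10, 5),
)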
class AirflowSource(PipelineServiceSource):
"""
@@ -87,29 +118,81 @@ class AirflowSource(PipelineServiceSource):
Return the SQLAlchemy session from the engine
"""
if not self._session:
self._session = create_and_bind_session(self.engine)
self._session = create_and_bind_session(self.connection)
return self._session
def get_pipeline_status(self, dag_id: str) -> DagRun:
dag_run_list: DagRun = (
self.session.query(DagRun)
def get_pipeline_status(self, dag_id: str) -> List[DagRun]:
dag_run_list = (
self.session.query(
*[c for c in DagRun.__table__.c if c.name not in IGNORE_DAG_RUN_COL]
)
.filter(DagRun.dag_id == dag_id)
.order_by(DagRun.execution_date.desc())
.limit(self.config.serviceConnection.__root__.config.numberOfStatus)
.all()
)
return dag_run_list
dag_run_dict = [dict(elem) for elem in dag_run_list]
# Build DagRun manually so we do not depend on columns that differ
# across Airflow versions
return [
DagRun(
dag_id=elem.get("dag_id"),
run_id=elem.get("run_id"),
queued_at=elem.get("queued_at"),
execution_date=elem.get("execution_date"),
start_date=elem.get("start_date"),
state=elem.get("state"),
)
for elem in dag_run_dict
]
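The same pattern in isolation, as a sketch: query only the DagRun columns shared across the supported Airflow versions and rebuild lightweight DagRun objects from the rows. Here, session is assumed to be a Session bound to the Airflow metadata database and "example_dag" is a placeholder:
shared_columns = [c for c in DagRun.__table__.c if c.name not in IGNORE_DAG_RUN_COL]
rows = session.query(*shared_columns).filter(DagRun.dag_id == "example_dag").all()
runs = [DagRun(dag_id=row.dag_id, run_id=row.run_id, state=row.state) for row in rows]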
def get_task_instances(self, dag_id: str, run_id: str) -> List[OMTaskInstance]:
"""
We are building our own scoped TaskInstance
class to only focus on core properties required
by the metadata ingestion.
This makes the ingestion more flexible across the Airflow
versions we support.
"""
task_instance_list = (
self.session.query(
TaskInstance.task_id,
TaskInstance.state,
TaskInstance.start_date,
TaskInstance.end_date,
TaskInstance.run_id,
)
.filter(TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id)
.all()
)
task_instance_dict = [dict(elem) for elem in task_instance_list]
return [
OMTaskInstance(
task_id=elem.get("task_id"),
state=elem.get("state"),
start_date=elem.get("start_date"),
end_date=elem.get("end_date"),
)
for elem in task_instance_dict
]
def yield_pipeline_status(
self, pipeline_details: SerializedDAG
) -> OMetaPipelineStatus:
dag_run_list = self.get_pipeline_status(pipeline_details.dag_id)
for dag in dag_run_list:
if isinstance(dag.task_instances, Iterable):
tasks = dag.task_instances
else:
tasks = [dag.task_instances]
for dag_run in dag_run_list:
tasks = self.get_task_instances(
dag_id=dag_run.dag_id,
run_id=dag_run.run_id,
)
task_statuses = [
TaskStatus(
@@ -121,30 +204,38 @@ class AirflowSource(PipelineServiceSource):
endTime=datetime_to_ts(
task.end_date
), # Might be None for running tasks
logLink=task.log_url,
)
) # Log link might not be present in all Airflow versions
for task in tasks
]
pipeline_status = PipelineStatus(
taskStatus=task_statuses,
executionStatus=STATUS_MAP.get(dag._state, StatusType.Pending.value),
timestamp=dag.execution_date.timestamp(),
executionStatus=STATUS_MAP.get(dag_run.state, StatusType.Pending.value),
timestamp=dag_run.execution_date.timestamp(),
)
yield OMetaPipelineStatus(
pipeline_fqn=self.context.pipeline.fullyQualifiedName.__root__,
pipeline_status=pipeline_status,
)
def get_pipelines_list(self) -> Iterable[SerializedDagModel]:
def get_pipelines_list(self) -> Iterable[OMSerializedDagDetails]:
"""
List all DAGs from the metadata db.
We are using the SerializedDagModel as it helps
us retrieve all the task and inlets/outlets information
"""
for serialized_dag in self.session.query(SerializedDagModel).all():
yield serialized_dag
for serialized_dag in self.session.query(
SerializedDagModel.dag_id,
SerializedDagModel._data,
SerializedDagModel.fileloc,
).all():
yield OMSerializedDagDetails(
dag_id=serialized_dag[0],
data=serialized_dag[1],
fileloc=serialized_dag[2],
)
def get_pipeline_name(self, pipeline_details: SerializedDAG) -> str:
"""
@@ -173,15 +264,29 @@ class AirflowSource(PipelineServiceSource):
for task in cast(Iterable[BaseOperator], dag.tasks)
]
@staticmethod
def _build_dag(data: Any) -> SerializedDAG:
"""
Use the queried data to fetch the DAG
:param data: from SQA query
:return: SerializedDAG
"""
if isinstance(data, dict):
return SerializedDAG.from_dict(data)
return SerializedDAG.from_json(data)
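A minimal sketch of the two paths, assuming pipeline_details is one of the OMSerializedDagDetails objects yielded above; the stored data may come back as a dict or as a JSON string depending on the Airflow version:
raw = pipeline_details.data
dag = SerializedDAG.from_dict(raw) if isinstance(raw, dict) else SerializedDAG.from_json(raw)
# equivalent to the helper above:
dag = AirflowSource._build_dag(raw)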
def yield_pipeline(
self, pipeline_details: SerializedDagModel
self, pipeline_details: OMSerializedDagDetails
) -> Iterable[CreatePipelineRequest]:
"""
Convert a DAG into a Pipeline Entity
:param serialized_dag: SerializedDAG from airflow metadata DB
:param pipeline_details: SerializedDAG from airflow metadata DB
:return: Create Pipeline request with tasks
"""
dag: SerializedDAG = pipeline_details.dag
dag: SerializedDAG = self._build_dag(pipeline_details.data)
yield CreatePipelineRequest(
name=pipeline_details.dag_id,
description=dag.description,
@@ -242,14 +347,14 @@ class AirflowSource(PipelineServiceSource):
return None
def yield_pipeline_lineage_details(
self, pipeline_details: SerializedDagModel
self, pipeline_details: OMSerializedDagDetails
) -> Optional[Iterable[AddLineageRequest]]:
"""
Parse xlets and add lineage between Pipelines and Tables
:param pipeline_details: SerializedDAG from airflow metadata DB
:return: Lineage from inlets and outlets
"""
dag: SerializedDAG = pipeline_details.dag
dag: SerializedDAG = self._build_dag(pipeline_details.data)
for task in dag.tasks:
for table_fqn in self.get_inlets(task) or []:
@@ -294,6 +399,3 @@ class AirflowSource(PipelineServiceSource):
def close(self):
self.session.close()
def test_connection(self) -> None:
test_connection(self.engine)


@@ -194,7 +194,7 @@ class PipelineServiceSource(TopologyRunnerMixin, Source, ABC):
def close(self):
"""
Method to implement any required logic after the ingesion process is completed
Method to implement any required logic after the ingestion process is completed
"""
def get_services(self) -> Iterable[WorkflowSource]:
@@ -223,5 +223,5 @@ class PipelineServiceSource(TopologyRunnerMixin, Source, ABC):
def prepare(self):
"""
Method to implement any required logic before starting the ingesion process
Method to implement any required logic before starting the ingestion process
"""


@@ -67,7 +67,7 @@ class TestWorkflowParse(TestCase):
def test_get_service_type(self):
"""
Test that we can get the service type of a source
Test that we can get the service type of source
"""
database_service = get_service_type("Mysql")