2023-12-22 14:11:31 +08:00

419 lines
14 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from enum import Enum
from typing import Union, Dict, List, Tuple, Sequence, Optional
from knext import rest
from knext.common.runnable import Input, Output
from knext.common.schema_helper import SPGTypeHelper, PropertyHelper
from knext.component.builder.base import Mapping
from knext.operator.op import LinkOp, FuseOp, PredictOp
from knext.operator.spg_record import SPGRecord
class MappingTypeEnum(str, Enum):
    """String-valued enumeration of mapping types.

    The `str` mixin makes every member compare equal to its raw string value.
    """

    # Mapping whose type is an SPG type.
    SPGType = "SPG_TYPE"
    # Mapping whose type is a relation.
    Relation = "RELATION"
class LinkStrategyEnum(str, Enum):
    """String-valued enumeration of the built-in link strategies.

    The `str` mixin makes every member compare equal to its raw string value.
    """

    # The only built-in strategy declared here.
    IDEquals = "ID_EQUALS"
# Field names declared as the base fields of an SPG type mapping.
# NOTE(review): not referenced in the visible part of this file — presumably consumed elsewhere.
SPG_TYPE_BASE_FIELDS = ["id"]
# Field names declared as the base fields of a relation mapping (source and destination ids).
# NOTE(review): not referenced in the visible part of this file — presumably consumed elsewhere.
RELATION_BASE_FIELDS = ["src_id", "dst_id"]
class SPGTypeMapping(Mapping):
    """A Process Component that maps data to an entity/event/concept type.

    Args:
        spg_type_name: The SPG type name imported from SPGTypeHelper.
    Examples:
        mapping = (
            SPGTypeMapping(spg_type_name=DEFAULT.App)
            .add_mapping_field("id", DEFAULT.App.id)
            .add_mapping_field("id", DEFAULT.App.name)
            .add_mapping_field("riskMark", DEFAULT.App.riskMark)
            .add_mapping_field("useCert", DEFAULT.App.useCert)
        )
    """

    # NOTE(review): the container defaults below are class-level objects; whether every
    # instance receives its own copy depends on the `Mapping` base class, which is not
    # visible here — confirm before relying on per-instance isolation.

    # Name of the SPG type the source data is mapped to.
    spg_type_name: Union[str, SPGTypeHelper]
    # target property -> source field.
    mapping: Dict[str, str] = dict()
    # (column_name, column_value) pairs; see `add_filter`.
    filters: List[Tuple[str, str]] = list()
    # Operator used to fuse the subject; None means "not set explicitly".
    subject_fuse_strategy: Optional[FuseOp] = None
    # target property -> link strategy (None is stored when no strategy was given).
    object_link_strategies: Dict[str, Union[LinkStrategyEnum, LinkOp]] = dict()
    # property -> predict operator (None is stored when no operator was given).
    predicate_predict_strategies: Dict[str, PredictOp] = dict()

    @property
    def input_types(self) -> Input:
        """Type accepted as input: a dict of string keys to string values."""
        return Dict[str, str]

    @property
    def output_types(self) -> Output:
        """Type produced as output: `SPGRecord`."""
        return SPGRecord

    @property
    def input_keys(self):
        """No input keys are declared for this component."""
        return None

    @property
    def output_keys(self):
        """Output keys, taken from `self.output_fields`."""
        # `output_fields` is not defined in this class; presumably provided by a base class.
        return self.output_fields

    def set_fuse_strategy(self, fuse_strategy: FuseOp):
        """Sets the operator used to fuse the subject.

        :param fuse_strategy: The fuse operator to use for the subject.
        :return: self
        """
        self.subject_fuse_strategy = fuse_strategy
        return self

    def add_mapping_field(
        self,
        source_field: str,
        target_field: Union[str, PropertyHelper],
        link_strategy: Optional[Union[LinkStrategyEnum, LinkOp]] = None,
    ):
        """Adds a field mapping from source data to property of spg_type.

        :param source_field: The source field to be mapped.
        :param target_field: The target field to map the source field to.
        :param link_strategy: Optional link strategy for the target field, either a
            `LinkStrategyEnum` member or a `LinkOp`. None means no strategy.
        :return: self
        """
        self.mapping[target_field] = source_field
        self.object_link_strategies[target_field] = link_strategy
        return self

    def add_predicting_field(
        self,
        field: Union[str, PropertyHelper],
        predict_strategy: Optional[PredictOp] = None,
    ):
        """Registers a field to be predicted.

        :param field: The field to be predicted.
        :param predict_strategy: Optional predict operator for the field.
        :return: self
        """
        self.predicate_predict_strategies[field] = predict_strategy
        return self

    def add_filter(self, column_name: str, column_value: str):
        """Adds data filtering rule.

        Only the column that meets `column_name=column_value` will execute the mapping.

        :param column_name: The column name to be filtered.
        :param column_value: The column value to be filtered.
        :return: self
        """
        self.filters.append((column_name, column_value))
        return self

    def to_rest(self):
        """Transforms `SPGTypeMapping` to REST model `SpgTypeMappingNodeConfig`.

        :return: A `rest.Node` whose `node_config` is the built `SpgTypeMappingNodeConfig`.
        :raises ValueError: If a link, predict or fuse strategy has an unsupported type.
        """
        mapping_filters = [
            rest.MappingFilter(column_name=name, column_value=value)
            for name, value in self.filters
        ]

        mapping_configs = []
        for tgt_name, src_name in self.mapping.items():
            link_strategy = self.object_link_strategies.get(tgt_name)
            if isinstance(link_strategy, LinkOp):
                strategy_config = rest.OperatorLinkingConfig(
                    operator_config=link_strategy.to_rest()
                )
            elif link_strategy == LinkStrategyEnum.IDEquals:
                strategy_config = rest.IdEqualsLinkingConfig()
            elif not link_strategy:
                strategy_config = None
            else:
                raise ValueError(f"Invalid link_strategy [{link_strategy}].")
            mapping_configs.append(
                rest.MappingConfig(
                    source=src_name,
                    target=tgt_name,
                    strategy_config=strategy_config,
                )
            )

        predicting_configs = []
        # Iterate the strategies (dict values); iterating the dict itself yields the
        # field names, which would never be a PredictOp and would always be rejected.
        for predict_strategy in self.predicate_predict_strategies.values():
            if isinstance(predict_strategy, PredictOp):
                strategy_config = rest.OperatorPredictingConfig(
                    operator_config=predict_strategy.to_rest()
                )
            elif not predict_strategy:
                # TODO: no default PredictOp is resolved from the spg type here, unlike
                # the FuseOp lookup below.
                strategy_config = None
            else:
                raise ValueError(f"Invalid predict_strategy [{predict_strategy}].")
            predicting_configs.append(strategy_config)

        if isinstance(self.subject_fuse_strategy, FuseOp):
            fusing_config = rest.OperatorFusingConfig(
                operator_config=self.subject_fuse_strategy.to_rest()
            )
        elif not self.subject_fuse_strategy:
            # No explicit strategy: fall back to an operator bound to this spg type, if any.
            if self.spg_type_name in FuseOp._bind_schemas:
                op_name = FuseOp._bind_schemas[self.spg_type_name]
                op = FuseOp.by_name(op_name)()
                # NOTE(review): unlike the explicit-strategy branch, the operator config is
                # not wrapped in `rest.OperatorFusingConfig` here — confirm this is intended.
                fusing_config = op.to_rest()
            else:
                fusing_config = None
        else:
            raise ValueError(f"Invalid fuse_strategy [{self.subject_fuse_strategy}].")

        config = rest.SpgTypeMappingNodeConfig(
            spg_type=self.spg_type_name,
            mapping_filters=mapping_filters,
            mapping_configs=mapping_configs,
            subject_fusing_config=fusing_config,
            predicting_configs=predicting_configs,
        )
        return rest.Node(**super().to_dict(), node_config=config)

    def invoke(self, input: Input) -> Sequence[Output]:
        """Not implemented; returns None."""
        pass

    @classmethod
    def from_rest(cls, node: rest.Node):
        """Not implemented; returns None."""
        pass

    def submit(self):
        """Not implemented; returns None."""
        pass
class RelationMapping(Mapping):
    """A Process Component that maps data to a relation type.

    Args:
        subject_name: The subject name imported from SPGTypeHelper.
        predicate_name: The predicate name.
        object_name: The object name imported from SPGTypeHelper.
    Examples:
        mapping = (
            RelationMapping(
                subject_name=DEFAULT.App,
                predicate_name=DEFAULT.App.useCert,
                object_name=DEFAULT.Cert,
            )
            .add_mapping_field("src_id", "srcId")
            .add_mapping_field("dst_id", "dstId")
        )
    """

    subject_name: Union[str, SPGTypeHelper]
    predicate_name: Union[str, PropertyHelper]
    object_name: Union[str, SPGTypeHelper]
    # target field -> source field.
    mapping: Dict[str, str] = {}
    # (column_name, column_value) pairs; see `add_filter`.
    filters: List[Tuple[str, str]] = []

    def add_mapping_field(self, source_field: str, target_field: str):
        """Records that `source_field` of the source data feeds `target_field`.

        :param source_field: The source field to be mapped.
        :param target_field: The target field to map the source field to.
        :return: self
        """
        # The mapping is keyed by target, so a later call for the same target overwrites it.
        self.mapping[target_field] = source_field
        return self

    def add_filter(self, column_name: str, column_value: str):
        """Appends a data filtering rule.

        Only the column that meets `column_name=column_value` will execute the mapping.

        :param column_name: The column name to be filtered.
        :param column_value: The column value to be filtered.
        :return: self
        """
        rule = (column_name, column_value)
        self.filters.append(rule)
        return self

    def to_rest(self):
        """Transforms `RelationMapping` to REST model `RelationMappingNodeConfig`."""
        filter_models = []
        for column, expected in self.filters:
            filter_models.append(
                rest.MappingFilter(column_name=column, column_value=expected)
            )

        field_models = []
        for target, source in self.mapping.items():
            field_models.append(rest.MappingConfig(source=source, target=target))

        # The relation is identified as "<subject>_<predicate>_<object>".
        relation_name = f"{self.subject_name}_{self.predicate_name}_{self.object_name}"
        node_config = rest.RelationMappingNodeConfig(
            relation=relation_name,
            mapping_filters=filter_models,
            mapping_configs=field_models,
        )
        return rest.Node(**super().to_dict(), node_config=node_config)

    @classmethod
    def from_rest(cls, node: rest.Node):
        """Not implemented; returns None."""
        return None

    def invoke(self, input: Input) -> Sequence[Output]:
        """Not implemented; returns None."""
        return None

    def submit(self):
        """Not implemented; returns None."""
        return None
class SubGraphMapping(Mapping):
    """A Process Component whose REST form is a `SubGraphMappingNodeConfig`.

    Args:
        spg_type_name: The SPG type name, given as a string or an SPGTypeHelper.
    """

    # NOTE(review): the container defaults below are class-level objects; whether every
    # instance receives its own copy depends on the `Mapping` base class, which is not
    # visible here — confirm before relying on per-instance isolation.

    # Name of the SPG type the source data is mapped to.
    spg_type_name: Union[str, SPGTypeHelper]
    # target property -> source field.
    mapping: Dict[str, str] = dict()
    # (column_name, column_value) pairs; see `add_filter`.
    filters: List[Tuple[str, str]] = list()
    # Operator used to fuse the subject; None means "not set explicitly".
    subject_fuse_strategy: Optional[FuseOp] = None
    # property -> predict operator (None is stored when no operator was given).
    predicate_predict_strategies: Dict[str, PredictOp] = dict()
    # target property -> fuse operator (None is stored when no operator was given).
    object_fuse_strategies: Dict[str, FuseOp] = dict()

    @property
    def input_types(self) -> Input:
        """Types accepted as input: a str-to-str dict or an `SPGRecord`."""
        return Union[Dict[str, str], SPGRecord]

    @property
    def output_types(self) -> Output:
        """Type produced as output: `SPGRecord`."""
        return SPGRecord

    @property
    def input_keys(self):
        """No input keys are declared for this component."""
        return None

    @property
    def output_keys(self):
        """Output keys, taken from `self.output_fields`."""
        # `output_fields` is not defined in this class; presumably provided by a base class.
        return self.output_fields

    def set_fuse_strategy(self, fuse_strategy: FuseOp):
        """Sets the operator used to fuse the subject.

        :param fuse_strategy: The fuse operator to use for the subject.
        :return: self
        """
        self.subject_fuse_strategy = fuse_strategy
        return self

    def add_mapping_field(
        self,
        source_field: str,
        target_field: Union[str, PropertyHelper],
        fuse_strategy: Optional[FuseOp] = None,
    ):
        """Adds a field mapping from source data to property of spg_type.

        :param source_field: The source field to be mapped.
        :param target_field: The target field to map the source field to.
        :param fuse_strategy: Optional fuse operator for the target field. When omitted,
            `to_rest` emits a `NewInstanceFusingConfig` for this field.
        :return: self
        """
        self.mapping[target_field] = source_field
        self.object_fuse_strategies[target_field] = fuse_strategy
        return self

    def add_predicting_field(
        self,
        target_field: Union[str, PropertyHelper],
        predict_strategy: Optional[PredictOp] = None,
    ):
        """Registers a field to be predicted.

        :param target_field: The field to be predicted.
        :param predict_strategy: Optional predict operator for the field.
        :return: self
        """
        # The declared field is `predicate_predict_strategies`; it is also what `to_rest` reads.
        self.predicate_predict_strategies[target_field] = predict_strategy
        return self

    def add_filter(self, column_name: str, column_value: str):
        """Adds data filtering rule.

        Only the column that meets `column_name=column_value` will execute the mapping.

        :param column_name: The column name to be filtered.
        :param column_value: The column value to be filtered.
        :return: self
        """
        self.filters.append((column_name, column_value))
        return self

    def to_rest(self):
        """Transforms `SubGraphMapping` to REST model `SubGraphMappingNodeConfig`.

        :return: A `rest.Node` whose `node_config` is the built `SubGraphMappingNodeConfig`.
        :raises ValueError: If a fuse or predict strategy has an unsupported type.
        """
        mapping_filters = [
            rest.MappingFilter(column_name=name, column_value=value)
            for name, value in self.filters
        ]

        mapping_configs = []
        for tgt_name, src_name in self.mapping.items():
            fuse_strategy = self.object_fuse_strategies.get(tgt_name)
            if isinstance(fuse_strategy, FuseOp):
                strategy_config = rest.OperatorFusingConfig(
                    operator_config=fuse_strategy.to_rest()
                )
            elif not fuse_strategy:
                # Decide on the per-field strategy, not the subject one: a field without
                # its own operator defaults to creating a new instance.
                strategy_config = rest.NewInstanceFusingConfig()
            else:
                raise ValueError(f"Invalid fuse_strategy [{fuse_strategy}].")
            mapping_configs.append(
                rest.MappingConfig(
                    source=src_name,
                    target=tgt_name,
                    strategy_config=strategy_config,
                )
            )

        predicting_configs = []
        # Iterate the strategies (dict values); iterating the dict itself yields the
        # field names, which would never be a PredictOp and would always be rejected.
        for predict_strategy in self.predicate_predict_strategies.values():
            if isinstance(predict_strategy, PredictOp):
                strategy_config = rest.OperatorPredictingConfig(
                    operator_config=predict_strategy.to_rest()
                )
            elif not predict_strategy:
                strategy_config = None
            else:
                raise ValueError(f"Invalid predict_strategy [{predict_strategy}].")
            predicting_configs.append(strategy_config)

        if isinstance(self.subject_fuse_strategy, FuseOp):
            fusing_config = rest.OperatorFusingConfig(
                operator_config=self.subject_fuse_strategy.to_rest()
            )
        elif not self.subject_fuse_strategy:
            fusing_config = rest.NewInstanceFusingConfig()
        else:
            raise ValueError(f"Invalid fuse_strategy [{self.subject_fuse_strategy}].")

        config = rest.SubGraphMappingNodeConfig(
            spg_type=self.spg_type_name,
            mapping_filters=mapping_filters,
            mapping_configs=mapping_configs,
            subject_fusing_config=fusing_config,
            predicting_configs=predicting_configs,
        )
        return rest.Node(**super().to_dict(), node_config=config)

    @classmethod
    def from_rest(cls, node: rest.Node):
        """Not implemented; returns None."""
        pass

    def invoke(self, input: Input) -> Sequence[Output]:
        """Not implemented; returns None."""
        pass

    def submit(self):
        """Not implemented; returns None."""
        pass