2023-12-06 17:26:39 +08:00
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright 2023 Ant Group CO., Ltd.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
|
|
|
# in compliance with the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
|
|
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
|
|
|
# or implied.
|
|
|
|
|
|
|
|
from abc import ABC
|
|
|
|
from enum import Enum
|
2023-12-15 17:33:54 +08:00
|
|
|
from typing import List, Dict, Any, Type, Sequence, Union
|
2023-12-06 17:26:39 +08:00
|
|
|
|
2023-12-15 17:33:54 +08:00
|
|
|
from knext import rest
|
|
|
|
|
|
|
|
from knext.common.restable import RESTable
|
|
|
|
from knext.common.runnable import Runnable, Other, Input, Output
|
|
|
|
from knext.common.schema_helper import SPGTypeHelper
|
2023-12-06 17:26:39 +08:00
|
|
|
from knext.operator.eval_result import EvalResult
|
2023-12-08 11:25:26 +08:00
|
|
|
from knext.operator.spg_record import SPGRecord
|
2023-12-06 17:26:39 +08:00
|
|
|
|
|
|
|
|
|
|
|
class OperatorTypeEnum(str, Enum):
    """Closed set of operator kinds a user-defined operator can belong to.

    Because the enum mixes in ``str``, every member compares equal to its
    value string (e.g. ``OperatorTypeEnum.EntityLinkOp == "ENTITY_LINK"``).
    Member *names* are spelled like operator base-class names, since
    ``BaseOp.register`` resolves a subclass's type via ``OperatorTypeEnum[<class name>]``.
    """

    # Declaration order is significant: it is the enum's iteration order.
    EntityLinkOp = "ENTITY_LINK"
    EntityFuseOp = "ENTITY_FUSE"
    PropertyNormalizeOp = "PROPERTY_NORMALIZE"
    KnowledgeExtractOp = "KNOWLEDGE_EXTRACT"
|
|
|
|
|
|
|
|
|
|
|
|
class BaseOp(ABC):
    """Base class for all user-defined operator functions.

    The execution logic of the operator needs to be implemented in the `eval` method.

    Concrete operator classes are recorded in a class-level registry through
    the `BaseOp.register(name, local_path)` decorator and can afterwards be
    retrieved with `BaseOp.by_name(name)`.
    """

    # Operator name; assigned to the subclass by `register`.
    name: str
    # Operator description.
    desc: str = ""
    # SPG type the operator binds to.
    bind_to: Union[str, SPGTypeHelper] = None
    # Operator parameters; the class-level default is replaced per instance by `__init__`.
    params: Dict[str, str] = None

    # Single registry mapping operator name -> registered subclass. Subclasses
    # that do not redefine `_registry` share this same dict.
    _registry = {}
    # The following are declared but not assigned here; `register` assigns
    # `_local_path` and `_type` on each registered subclass.
    _local_path: str
    _type: str
    _version: int

    def __init__(self, params: Dict[str, str] = None):
        """Store the operator's parameters (may be None)."""
        self.params = params

    def eval(self, *args):
        """Used to implement operator execution logic.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} need to implement `eval` method."
        )

    def _handle(self, *inputs) -> Dict[str, Any]:
        """Only available for Builder in OpenSPG to call through the pemja tool.

        Runs the pipeline `_pre_process` -> `eval` -> `_post_process` and
        returns the post-processed result.
        """
        # `_pre_process` must return an iterable: its result is unpacked into
        # `eval`. The defaults in this class return None, so a concrete
        # subclass is expected to override both hooks.
        pre_input = self._pre_process(*inputs)
        output = self.eval(*pre_input)
        post_output = self._post_process(output)
        return post_output

    @staticmethod
    def _pre_process(*inputs):
        """Convert data structures in building job into structures in operator before `eval` method.

        This default implementation returns None; subclasses are expected to
        override it (NOTE(review): `_handle` would fail unpacking None).
        """
        pass

    @staticmethod
    def _post_process(output: EvalResult) -> Dict[str, Any]:
        """Convert result structures in operator into structures in building job after `eval` method.

        This default implementation returns None; subclasses are expected to override it.
        """
        pass

    @classmethod
    def register(cls, name: str, local_path: str):
        """
        Register a class as subclass of BaseOp with name and local_path.

        After registration, the subclass object can be inspected by `BaseOp.by_name(op_name)`.

        Args:
            name: Registry key for the operator; also assigned to `subclass.name`.
            local_path: Path of the operator's source file; stored as
                `subclass._local_path` and reported by `to_rest` as `file_path`.

        Returns:
            A class decorator that registers the decorated subclass and returns it unchanged
            apart from the `name`, `_local_path` and `_type` attributes it sets.

        Raises:
            ValueError: if `name` is already registered.
            KeyError: if no ancestor of the subclass is named after an
                `OperatorTypeEnum` member.
        """

        def add_subclass_to_registry(subclass: Type["BaseOp"]):
            # Detect name conflicts before mutating `subclass`, so a failed
            # registration leaves the class untouched.
            if name in cls._registry:
                raise ValueError(
                    f"Operator [{name}] conflict in {local_path} and {cls.by_name(name)._local_path}."
                )

            # The operator type comes from the nearest ancestor whose class
            # name is an OperatorTypeEnum member (e.g. a subclass of
            # `EntityLinkOp`). Walking the MRO, rather than only the first
            # base, also supports mixins and deeper inheritance chains.
            op_type = next(
                (
                    OperatorTypeEnum[base.__name__]
                    for base in subclass.__mro__[1:]
                    if base.__name__ in OperatorTypeEnum.__members__
                ),
                None,
            )
            if op_type is None:
                raise KeyError(
                    f"Cannot determine operator type of {subclass.__name__}: "
                    f"none of its base classes is one of {list(OperatorTypeEnum.__members__)}."
                )

            subclass.name = name
            subclass._local_path = local_path
            subclass._type = op_type
            cls._registry[name] = subclass
            return subclass

        return add_subclass_to_registry

    @classmethod
    def by_name(cls, name: str):
        """Reflection from op name to subclass object of BaseOp.

        Raises:
            ValueError: if `name` has not been registered.
        """
        try:
            return cls._registry[name]
        except KeyError:
            raise ValueError(
                f"{name} is not a registered name for {cls.__name__}. "
            ) from None

    def to_rest(self):
        """Describe this operator as a `rest.OperatorConfig`.

        The config points at the registered source file, uses the registered
        operator name as `class_name` and `_handle` as the entry method.
        Requires the class to have been registered: `_local_path` and `name`
        are only assigned by `register`.
        """
        return rest.OperatorConfig(
            file_path=self._local_path,
            module_path="",
            class_name=self.name,
            method="_handle",
            params=self.params,
        )

    @classmethod
    def from_rest(cls, node):
        """Counterpart of `to_rest`; not implemented yet and returns None."""
        pass
|