# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from abc import ABC
from enum import Enum
from typing import List, Dict, Any, Type, Sequence, Union

from knext import rest
from knext.common.restable import RESTable
from knext.common.runnable import Runnable, Other, Input, Output
from knext.common.schema_helper import SPGTypeHelper
from knext.operator.eval_result import EvalResult
from knext.operator.spg_record import SPGRecord


class OperatorTypeEnum(str, Enum):
    """Operator categories.

    Member names match the class names that `BaseOp.register` looks up via
    `subclass.__base__.__name__`; member values are the string identifiers.
    """

    EntityLinkOp = "ENTITY_LINK"
    EntityFuseOp = "ENTITY_FUSE"
    PropertyNormalizeOp = "PROPERTY_NORMALIZE"
    KnowledgeExtractOp = "KNOWLEDGE_EXTRACT"


class BaseOp(ABC):
    """Base class for all user-defined operator functions.

    The execution logic of the operator needs to be implemented in the `eval` method.
    """

    # Operator name. Assigned by the `register` decorator.
    name: str
    # Operator description.
    desc: str = ""
    # SPG type the operator bind to.
    bind_to: Union[str, SPGTypeHelper] = None

    # Operator parameters; overwritten per instance in `__init__`.
    params: Dict[str, str] = None

    # Maps operator name -> registered subclass. Defined once on BaseOp, so it
    # is shared by every subclass that does not shadow it.
    _registry = {}
    # Assigned by the `register` decorator.
    _local_path: str
    # Assigned by the `register` decorator (an `OperatorTypeEnum` member).
    _type: str
    _version: int

    def __init__(self, params: Dict[str, str] = None):
        """Store the operator parameters on the instance.

        Args:
            params: Operator parameters; may be None.
        """
        self.params = params

    def eval(self, *args):
        """Used to implement operator execution logic.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} need to implement `eval` method."
        )

    def _handle(self, *inputs) -> Dict[str, Any]:
        """Only available for Builder in OpenSPG to call through the pemja tool.

        Runs `_pre_process`, then `eval`, then `_post_process`, and returns the
        post-processed result.
        """
        # NOTE: the base `_pre_process` returns None, which cannot be unpacked
        # here; subclasses are expected to override the pre/post hooks.
        pre_input = self._pre_process(*inputs)
        output = self.eval(*pre_input)
        post_output = self._post_process(output)
        return post_output

    @staticmethod
    def _pre_process(*inputs):
        """Convert data structures in building job into structures in operator before `eval` method."""
        pass

    @staticmethod
    def _post_process(output: EvalResult) -> Dict[str, Any]:
        """Convert result structures in operator into structures in building job after `eval` method."""
        pass

    @classmethod
    def register(cls, name: str, local_path: str):
        """
        Register a class as subclass of BaseOp with name and local_path.
        After registration, the subclass object can be inspected by `BaseOp.by_name(op_name)`.

        Args:
            name: Registry key; also stored on the subclass as `name`.
            local_path: Stored on the subclass as `_local_path`.

        Returns:
            A class decorator that registers the decorated class and returns it.

        Raises:
            KeyError: (from the decorator) if the name of the decorated class's
                direct base class is not an `OperatorTypeEnum` member.
            ValueError: (from the decorator) if `name` is already registered.
        """

        def add_subclass_to_registry(subclass: Type["BaseOp"]):
            # Validate everything before touching the subclass, so a rejected
            # registration does not leave it partially mutated.
            op_type = OperatorTypeEnum[subclass.__base__.__name__]
            if name in cls._registry:
                raise ValueError(
                    f"Operator [{name}] conflict in {local_path} and {cls._registry[name]._local_path}."
                )

            subclass.name = name
            subclass._local_path = local_path
            subclass._type = op_type
            cls._registry[name] = subclass
            return subclass

        return add_subclass_to_registry

    @classmethod
    def by_name(cls, name: str):
        """Reflection from op name to subclass object of BaseOp.

        Raises:
            ValueError: If `name` has not been registered.
        """
        try:
            return cls._registry[name]
        except KeyError:
            raise ValueError(
                f"{name} is not a registered name for {cls.__name__}. "
            ) from None

    def to_rest(self):
        """Build a `rest.OperatorConfig` from this operator.

        Uses `_local_path` as the file path, `name` as the class name, an empty
        module path, `_handle` as the method, and the instance's `params`.
        """
        return rest.OperatorConfig(
            file_path=self._local_path,
            module_path="",
            class_name=self.name,
            method="_handle",
            params=self.params,
        )

    @classmethod
    def from_rest(cls, node):
        """Counterpart of `to_rest`. Not implemented here: returns None."""
        pass