109 lines
3.7 KiB
Python
Raw Normal View History

2023-12-06 17:26:39 +08:00
# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
2023-12-18 13:46:44 +08:00
import os
2023-12-06 17:26:39 +08:00
from abc import ABC
2023-12-18 13:46:44 +08:00
from typing import Dict, Any, Type, Union
2023-12-06 17:26:39 +08:00
2023-12-15 17:33:54 +08:00
from knext import rest
from knext.common.schema_helper import SPGTypeHelper
2023-12-06 17:26:39 +08:00
from knext.operator.eval_result import EvalResult
class BaseOp(ABC):
    """Base class for all user-defined operator functions.

    The execution logic of the operator needs to be implemented in the `eval` method.
    Subclasses can be registered with the `BaseOp.register` class decorator and looked
    up afterwards with `BaseOp.by_name`.
    """

    # Operator name. Set on the subclass by `register`; `to_rest` falls back to the
    # class name when it was never set.
    name: str
    # Operator description.
    desc: str = ""
    # SPG type the operator binds to.
    bind_to: Union[str, "SPGTypeHelper", None] = None
    # Operator parameters; stored by `__init__` and forwarded by `to_rest`.
    params: Union[Dict[str, str], None] = None

    # Operator name -> registered subclass. Defined once here and mutated in place by
    # `register`, so every subclass shares this mapping unless it defines its own.
    _registry = {}
    # Source file of the operator; set by `register`, or lazily by `to_rest`.
    _local_path: str
    # NOTE(review): `_type`, `_version` and `_has_registered` are not read or written
    # by this class; presumably managed by other knext code — verify before relying on them.
    _type: str
    _version: int
    _has_registered: bool = False

    def __init__(self, params: Union[Dict[str, str], None] = None):
        """Create an operator instance.

        Args:
            params: Optional operator parameters, stored as-is on `self.params`.
        """
        self.params = params

    def eval(self, *args):
        """Used to implement operator execution logic.

        Subclasses must override this. `_handle` calls it with the pre-processed
        inputs and hands its return value to `_post_process`, which calls
        `.to_dict()` on it.

        Raises:
            NotImplementedError: if the subclass did not override this method.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} need to implement `eval` method."
        )

    def _handle(self, *inputs) -> Dict[str, Any]:
        """Only available for Builder in OpenSPG to call through the pemja tool.

        Runs `_pre_process` -> `eval` -> `_post_process` and returns the dict
        produced by `_post_process`.
        """
        pre_input = self._pre_process(*inputs)
        output = self.eval(*pre_input)
        post_output = self._post_process(output)
        return post_output

    @staticmethod
    def _pre_process(*inputs):
        """Convert data structures in building job into structures in operator before `eval` method.

        The base implementation is the identity: it returns the positional inputs as a tuple.
        """
        return inputs

    @staticmethod
    def _post_process(output: "EvalResult") -> Dict[str, Any]:
        """Convert result structures in operator into structures in building job after `eval` method.

        The base implementation returns `output.to_dict()`.
        """
        return output.to_dict()

    @classmethod
    def register(cls, name: str, local_path: str):
        """Register a class as subclass of BaseOp with name and local_path.

        After registration, the subclass object can be inspected by `BaseOp.by_name(op_name)`.

        Args:
            name: Operator name; must not be registered yet.
            local_path: Path of the file that defines the subclass.

        Returns:
            A class decorator that sets `name` and `_local_path` on the subclass,
            records it in the registry and returns it.

        Raises:
            ValueError: raised by the decorator when `name` is already registered;
                the decorated class is left untouched in that case.
        """

        def add_subclass_to_registry(subclass: Type["BaseOp"]):
            # Check for a conflict before touching `subclass`, so a failed
            # registration does not leave it with an overwritten name/path.
            if name in cls._registry:
                raise ValueError(
                    f"Operator [{name}] conflict in {local_path} and {cls.by_name(name)._local_path}."
                )
            subclass.name = name
            subclass._local_path = local_path
            cls._registry[name] = subclass
            return subclass

        return add_subclass_to_registry

    @classmethod
    def by_name(cls, name: str):
        """Reflection from op name to subclass object of BaseOp.

        Raises:
            ValueError: if `name` has not been registered.
        """
        if name not in cls._registry:
            raise ValueError(f"{name} is not a registered name for {cls.__name__}. ")
        return cls._registry[name]

    def to_rest(self):
        """Build the `rest.OperatorConfig` describing this operator.

        If the operator was never registered, `_local_path` defaults to the file
        that defines the instance's class and `name` to the class name; both
        fallbacks are cached on the instance.
        """
        if not hasattr(self, "_local_path"):
            import inspect

            self._local_path = inspect.getfile(self.__class__)
        if not hasattr(self, "name"):
            self.name = self.__class__.__name__
        # NOTE(review): `class_name` receives the operator *name*, which equals the
        # Python class name only if the operator was unregistered or registered under
        # its own class name — confirm this is what the consumer of OperatorConfig expects.
        return rest.OperatorConfig(
            file_path=self._local_path,
            # Module name is the file's basename without its extension.
            module_path=os.path.splitext(os.path.basename(self._local_path))[0],
            class_name=self.name,
            method="_handle",
            params=self.params,
        )