111 lines
3.3 KiB
Python
Raw Normal View History

2023-12-06 17:26:39 +08:00
from abc import ABC
from typing import Any, Dict, List, Optional

from knext.operator.base import BaseOp
from knext.operator.eval_result import EvalResult
from knext.operator.spg_record import SPGRecord
class ExtractOp(BaseOp, ABC):
    """Base class for all knowledge extract operators.

    Subclasses override :meth:`eval`, which receives a single record
    (a ``Dict[str, str]``) and returns a list of :class:`SPGRecord` results.
    """

    def __init__(self, params: Optional[Dict[str, str]] = None):
        """Create the operator.

        Args:
            params: Optional operator parameters, forwarded unchanged to
                ``BaseOp.__init__``.
        """
        super().__init__(params)

    def eval(self, record: Dict[str, str]) -> List[SPGRecord]:
        """Extract records from ``record``; must be overridden.

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} need to implement `eval` method."
        )

    @staticmethod
    def _pre_process(*inputs):
        """Return the tuple of positional arguments for :meth:`eval`.

        Only the first input is forwarded, unchanged, as ``record``; any
        further inputs are ignored.
        """
        return (inputs[0],)

    @staticmethod
    def _post_process(output) -> Dict[str, Any]:
        """Normalize the value returned by :meth:`eval` into a dict.

        - An ``EvalResult`` is serialized directly with ``to_dict()``.
        - A tuple supplies up to three positional arguments to the
          ``EvalResult`` constructor; extra elements are ignored.
        - Any other value is passed as the single constructor argument.
        """
        if isinstance(output, EvalResult):
            return output.to_dict()
        if isinstance(output, tuple):
            return EvalResult[List[SPGRecord]](*output[:3]).to_dict()
        return EvalResult[List[SPGRecord]](output).to_dict()
class LinkOp(BaseOp, ABC):
    """Base class for all entity link operators.

    Subclasses override :meth:`eval`, which receives a property name and an
    :class:`SPGRecord` and returns a list of :class:`SPGRecord` results.
    """

    def __init__(self, params: Optional[Dict[str, str]] = None):
        """Create the operator.

        Args:
            params: Optional operator parameters, forwarded unchanged to
                ``BaseOp.__init__``.
        """
        super().__init__(params)

    # NOTE: ``property`` shadows the builtin of the same name. It is kept
    # because it is part of the public signature that subclasses override.
    def eval(self, property: str, record: SPGRecord) -> List[SPGRecord]:
        """Link ``record`` via ``property``; must be overridden.

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} need to implement `eval` method."
        )

    @staticmethod
    def _pre_process(*inputs):
        """Return the positional arguments ``(property, record)`` for eval.

        ``inputs[0]`` is forwarded unchanged as the property name, and
        ``inputs[1]`` is converted with ``SPGRecord.from_dict``. Any further
        inputs are ignored.
        """
        property_name = inputs[0]
        record = SPGRecord.from_dict(inputs[1])
        return property_name, record

    @staticmethod
    def _post_process(output) -> Dict[str, Any]:
        """Normalize the value returned by :meth:`eval` into a dict.

        - An ``EvalResult`` is serialized directly with ``to_dict()``.
        - A tuple supplies up to three positional arguments to the
          ``EvalResult`` constructor; extra elements are ignored.
        - Any other value is passed as the single constructor argument.
        """
        if isinstance(output, EvalResult):
            return output.to_dict()
        if isinstance(output, tuple):
            return EvalResult[List[SPGRecord]](*output[:3]).to_dict()
        return EvalResult[List[SPGRecord]](output).to_dict()
class FuseOp(BaseOp, ABC):
    """Base class for all entity fuse operators.

    Subclasses override :meth:`eval`, which receives the whole batch of
    records to fuse as a single ``List[SPGRecord]`` argument and returns a
    list of :class:`SPGRecord` results.
    """

    def __init__(self, params: Optional[Dict[str, str]] = None):
        """Create the operator.

        Args:
            params: Optional operator parameters, forwarded unchanged to
                ``BaseOp.__init__``.
        """
        super().__init__(params)

    def eval(self, records: List[SPGRecord]) -> List[SPGRecord]:
        """Fuse ``records``; must be overridden.

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} need to implement `eval` method."
        )

    @staticmethod
    def _pre_process(*inputs):
        """Return the tuple of positional arguments for :meth:`eval`.

        Each item of ``inputs[0]`` is converted with ``SPGRecord.from_dict``.
        The resulting list is wrapped in a 1-tuple so that it reaches
        ``eval`` as the single ``records`` argument.
        """
        records = [SPGRecord.from_dict(item) for item in inputs[0]]
        # Sibling operators return a tuple of eval's positional arguments
        # (ExtractOp: ``(record,)``, LinkOp: ``(property, record)``), which
        # implies the caller in BaseOp (not shown) unpacks it as
        # ``eval(*args)`` -- TODO confirm. Returning the bare list would then
        # spread each record into its own positional argument.
        return (records,)

    @staticmethod
    def _post_process(output) -> Dict[str, Any]:
        """Normalize the value returned by :meth:`eval` into a dict.

        - An ``EvalResult`` is serialized directly with ``to_dict()``.
        - A tuple supplies up to three positional arguments to the
          ``EvalResult`` constructor; extra elements are ignored.
        - Any other value is passed as the single constructor argument.
        """
        if isinstance(output, EvalResult):
            return output.to_dict()
        if isinstance(output, tuple):
            return EvalResult[List[SPGRecord]](*output[:3]).to_dict()
        return EvalResult[List[SPGRecord]](output).to_dict()
class PromptOp(BaseOp, ABC):
    """Abstract base for prompt operators.

    Concrete subclasses override three hooks:

    * ``build_prompt`` -- render a prompt string from a dict of variables;
    * ``parse_response`` -- turn a response string into ``SPGRecord`` objects;
    * ``build_variables`` -- produce a list of variable dicts from the
      current variables and a response string.

    Each hook raises ``NotImplementedError`` until it is overridden, while
    ``eval`` is a no-op on this class.
    """

    # Prompt template text. It is only annotated here (no value is
    # assigned), so presumably each concrete subclass supplies it.
    template: str

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but not forwarded:
        # BaseOp.__init__ is called without any arguments.
        super().__init__()

    def build_prompt(self, variables: Dict[str, str]) -> str:
        """Return the prompt text for ``variables``; must be overridden."""
        message = f"{self.__class__.__name__} need to implement `build_prompt` method."
        raise NotImplementedError(message)

    def parse_response(self, response: str) -> List[SPGRecord]:
        """Convert ``response`` into records; must be overridden."""
        message = f"{self.__class__.__name__} need to implement `parse_response` method."
        raise NotImplementedError(message)

    def build_variables(
        self, variables: Dict[str, str], response: str
    ) -> List[Dict[str, str]]:
        """Build variable dicts from ``variables`` and ``response``; must be overridden."""
        message = f"{self.__class__.__name__} need to implement `build_variables` method."
        raise NotImplementedError(message)

    def eval(self, *args):
        """Accept any positional arguments and do nothing (returns ``None``)."""
        return None