2023-12-06 17:26:39 +08:00
|
|
|
from abc import ABC
|
2023-12-11 10:44:37 +08:00
|
|
|
from typing import List, Dict, Any, Union
|
2023-12-06 17:26:39 +08:00
|
|
|
|
|
|
|
from knext.operator.base import BaseOp
|
|
|
|
from knext.operator.eval_result import EvalResult
|
|
|
|
from knext.operator.spg_record import SPGRecord
|
|
|
|
|
|
|
|
|
|
|
|
class ExtractOp(BaseOp, ABC):
    """Base class for all knowledge extract operators.

    Concrete extractors override :meth:`invoke`. The static
    ``_pre_process`` / ``_post_process`` hooks convert raw inputs and
    results (NOTE(review): presumably invoked by ``BaseOp`` around
    ``invoke`` — confirm).
    """

    def __init__(self, params: Union[Dict[str, str], None] = None):
        super().__init__(params)

    def invoke(self, record: Dict[str, str]) -> List[SPGRecord]:
        """Extract records from a single input ``record``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # The message names `invoke` (previously it wrongly said `eval`).
        raise NotImplementedError(
            f"{self.__class__.__name__} need to implement `invoke` method."
        )

    @staticmethod
    def _pre_process(*inputs):
        """Convert ``inputs[0]`` to an SPGRecord and return its properties.

        Returns a one-element tuple. The tuple shape is deliberate
        (NOTE(review): presumably the caller unpacks it into positional
        arguments — confirm against BaseOp).
        """
        return (SPGRecord.from_dict(inputs[0]).properties,)

    @staticmethod
    def _post_process(output) -> Dict[str, Any]:
        """Serialize an operator result to a dict via ``EvalResult.to_dict``.

        - An ``EvalResult`` is serialized as-is.
        - A tuple contributes at most its first three items as positional
          ``EvalResult`` arguments (presumably data / traces / errors —
          confirm against EvalResult).
        - Any other value becomes the sole ``EvalResult`` argument.
        """
        if isinstance(output, EvalResult):
            return output.to_dict()
        if isinstance(output, tuple):
            return EvalResult[List[SPGRecord]](*output[:3]).to_dict()
        return EvalResult[List[SPGRecord]](output).to_dict()
|
|
|
|
|
|
|
|
|
|
|
|
class LinkOp(BaseOp, ABC):
    """Abstract base for entity link operators.

    A concrete linker overrides :meth:`eval`. The static hooks below
    reshape raw inputs and serialize results (NOTE(review): presumably
    called by ``BaseOp`` around ``eval`` — confirm).
    """

    def __init__(self, params: Dict[str, str] = None):
        # No link-specific state; everything is delegated to BaseOp.
        super().__init__(params)

    def eval(self, property: str, record: SPGRecord) -> List[SPGRecord]:
        """Link ``record`` using ``property``; subclasses must override.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        message = f"{self.__class__.__name__} need to implement `eval` method."
        raise NotImplementedError(message)

    @staticmethod
    def _pre_process(*inputs):
        """Return ``(inputs[0], SPGRecord.from_dict(inputs[1]))``.

        The first input passes through untouched; any inputs beyond the
        second are ignored.
        """
        property_name = inputs[0]
        raw_record = inputs[1]
        return property_name, SPGRecord.from_dict(raw_record)

    @staticmethod
    def _post_process(output) -> Dict[str, Any]:
        """Serialize ``output`` to a dict through ``EvalResult.to_dict``.

        An ``EvalResult`` is serialized directly. A tuple supplies at most
        its first three items as positional ``EvalResult`` arguments. Any
        other value becomes the single ``EvalResult`` argument.
        """
        if not isinstance(output, EvalResult):
            ctor_args = output[:3] if isinstance(output, tuple) else (output,)
            output = EvalResult[List[SPGRecord]](*ctor_args)
        return output.to_dict()
|
|
|
|
|
|
|
|
|
|
|
|
class FuseOp(BaseOp, ABC):
    """Base class for all entity fuse operators.

    Concrete fusers override :meth:`eval`. The static ``_pre_process`` /
    ``_post_process`` hooks convert raw inputs and results
    (NOTE(review): presumably invoked by ``BaseOp`` around ``eval`` —
    confirm).
    """

    def __init__(self, params: Union[Dict[str, str], None] = None):
        super().__init__(params)

    def eval(
        self, source_record: SPGRecord, target_records: List[SPGRecord]
    ) -> List[SPGRecord]:
        """Fuse ``source_record`` with ``target_records``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} need to implement `eval` method."
        )

    @staticmethod
    def _pre_process(*inputs):
        """Convert raw inputs into ``(source_record, target_records)``.

        ``inputs[0]`` is converted with ``SPGRecord.from_dict``.
        ``inputs[1]`` is iterated, and each item is converted the same way.
        """
        source_record = SPGRecord.from_dict(inputs[0])
        # `raw` instead of `input`, so the builtin is not shadowed.
        target_records = [SPGRecord.from_dict(raw) for raw in inputs[1]]
        return source_record, target_records

    @staticmethod
    def _post_process(output) -> Dict[str, Any]:
        """Serialize an operator result to a dict via ``EvalResult.to_dict``.

        - An ``EvalResult`` is serialized as-is.
        - A tuple contributes at most its first three items as positional
          ``EvalResult`` arguments.
        - Any other value becomes the sole ``EvalResult`` argument.
        """
        if isinstance(output, EvalResult):
            return output.to_dict()
        if isinstance(output, tuple):
            return EvalResult[List[SPGRecord]](*output[:3]).to_dict()
        return EvalResult[List[SPGRecord]](output).to_dict()
|
|
|
|
|
|
|
|
|
2023-12-15 17:33:54 +08:00
|
|
|
class PromptOp(BaseOp, ABC):
    """Base class for all prompt operators.

    Subclasses override :meth:`build_prompt`, :meth:`parse_response` and
    :meth:`build_params`. Each of these raises ``NotImplementedError``
    here.
    """

    # Declared but not assigned on this base class.
    # NOTE(review): presumably each subclass supplies its prompt template — confirm.
    template: str

    def build_prompt(self, params: Dict[str, str]) -> str:
        """Build a prompt string from ``params``; subclasses must override.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} need to implement `build_prompt` method."
        )

    def parse_response(self, response: str) -> List[SPGRecord]:
        """Parse a ``response`` string into records; subclasses must override.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} need to implement `parse_response` method."
        )

    def build_params(
        self, record: Dict[str, str], response: str
    ) -> List[Dict[str, str]]:
        """Build a list of parameter dicts from ``record`` and ``response``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # The message names `build_params` (previously it wrongly said
        # `build_placeholder`, a method that does not exist here).
        raise NotImplementedError(
            f"{self.__class__.__name__} need to implement `build_params` method."
        )
|