This commit is contained in:
Qu 2023-12-24 17:54:27 +08:00
parent 950b8327a5
commit 84a2266e52
3 changed files with 14 additions and 17 deletions

View File

@ -422,8 +422,9 @@ class SubGraphMapping(Mapping):
operator_config=predicting_strategy.to_rest() operator_config=predicting_strategy.to_rest()
) )
elif not predicting_strategy: elif not predicting_strategy:
if (self.spg_type_name, predicate_name) in PredictOp.bind_schemas: object_type_name = spg_type.properties[predicate_name].object_type_name
op_name = PredictOp.bind_schemas[(self.spg_type_name, predicate_name)] if (self.spg_type_name, predicate_name, object_type_name) in PredictOp.bind_schemas:
op_name = PredictOp.bind_schemas[(self.spg_type_name, predicate_name, object_type_name)]
op = PredictOp.by_name(op_name)() op = PredictOp.by_name(op_name)()
strategy_config = rest.OperatorPredictingConfig( strategy_config = rest.OperatorPredictingConfig(
operator_config=op.to_rest() operator_config=op.to_rest()
@ -434,7 +435,7 @@ class SubGraphMapping(Mapping):
raise ValueError(f"Invalid predicting_strategy [{predicting_strategy}].") raise ValueError(f"Invalid predicting_strategy [{predicting_strategy}].")
if strategy_config: if strategy_config:
predicting_configs.append( predicting_configs.append(
strategy_config rest.PredictingConfig(target=predicate_name,predicting_config=strategy_config)
) )
if isinstance(self.subject_fusing_strategy, FuseOp): if isinstance(self.subject_fusing_strategy, FuseOp):

View File

@ -1,21 +1,20 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from knext.client.model.builder_job import BuilderJob from knext.client.model.builder_job import BuilderJob
from knext.api.component import CSVReader, SPGTypeMapping, KGWriter from knext.api.component import (
from knext.component.builder import LLMBasedExtractor, SubGraphMapping CSVReader,
LLMBasedExtractor,
KGWriter,
SubGraphMapping
)
from nn4k.invoker import LLMInvoker from nn4k.invoker import LLMInvoker
try:
from schema.financial_schema_helper import Financial
except:
pass
class StateAndIndicator(BuilderJob): class StateAndIndicator(BuilderJob):
def build(self): def build(self):
source = CSVReader( source = CSVReader(
local_path="/Users/jier/openspg/python/knext/knext/examples/financial/builder/job/data/document.csv", local_path="builder/job/data/document.csv",
columns=["input"], columns=["input"],
start_row=2 start_row=2
) )
@ -23,7 +22,7 @@ class StateAndIndicator(BuilderJob):
from knext.examples.financial.builder.operator.IndicatorNER import IndicatorNER from knext.examples.financial.builder.operator.IndicatorNER import IndicatorNER
from knext.examples.financial.builder.operator.IndicatorREL import IndicatorREL from knext.examples.financial.builder.operator.IndicatorREL import IndicatorREL
from knext.examples.financial.builder.operator.IndicatorLOGIC import IndicatorLOGIC from knext.examples.financial.builder.operator.IndicatorLOGIC import IndicatorLOGIC
extract = LLMBasedExtractor(llm=LLMInvoker.from_config("/Users/jier/openspg/python/knext/knext/examples/financial/builder/model/openai_infer.json"), extract = LLMBasedExtractor(llm=LLMInvoker.from_config("builder/model/openai_infer.json"),
prompt_ops=[IndicatorNER(), IndicatorREL(), IndicatorLOGIC()] prompt_ops=[IndicatorNER(), IndicatorREL(), IndicatorLOGIC()]
) )
@ -36,7 +35,6 @@ class StateAndIndicator(BuilderJob):
indicator_mapping = SubGraphMapping(spg_type_name="Financial.Indicator")\ indicator_mapping = SubGraphMapping(spg_type_name="Financial.Indicator")\
.add_mapping_field("id", "id") \ .add_mapping_field("id", "id") \
.add_mapping_field("name", "name") .add_mapping_field("name", "name")
# .add_predicting_field("isA")
sink = KGWriter() sink = KGWriter()

View File

@ -134,7 +134,7 @@ class PredictOp(BaseOp, ABC):
bind_to: Tuple[SPGTypeName, PropertyName, SPGTypeName] bind_to: Tuple[SPGTypeName, PropertyName, SPGTypeName]
bind_schemas: Dict[Tuple[SPGTypeName, PropertyName], str] = {} bind_schemas: Dict[Tuple[SPGTypeName, PropertyName, SPGTypeName], str] = {}
def invoke(self, subject_record: SPGRecord) -> List[SPGRecord]: def invoke(self, subject_record: SPGRecord) -> List[SPGRecord]:
raise NotImplementedError( raise NotImplementedError(
@ -143,9 +143,7 @@ class PredictOp(BaseOp, ABC):
@staticmethod @staticmethod
def _pre_process(*inputs): def _pre_process(*inputs):
return [ return SPGRecord.from_dict(inputs[0]),
SPGRecord.from_dict(input) for input in inputs[0]
],
@staticmethod @staticmethod
def _post_process(output) -> Dict[str, Any]: def _post_process(output) -> Dict[str, Any]: