mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-09-27 09:27:40 +00:00
fix
This commit is contained in:
parent
950b8327a5
commit
84a2266e52
@ -422,8 +422,9 @@ class SubGraphMapping(Mapping):
|
|||||||
operator_config=predicting_strategy.to_rest()
|
operator_config=predicting_strategy.to_rest()
|
||||||
)
|
)
|
||||||
elif not predicting_strategy:
|
elif not predicting_strategy:
|
||||||
if (self.spg_type_name, predicate_name) in PredictOp.bind_schemas:
|
object_type_name = spg_type.properties[predicate_name].object_type_name
|
||||||
op_name = PredictOp.bind_schemas[(self.spg_type_name, predicate_name)]
|
if (self.spg_type_name, predicate_name, object_type_name) in PredictOp.bind_schemas:
|
||||||
|
op_name = PredictOp.bind_schemas[(self.spg_type_name, predicate_name, object_type_name)]
|
||||||
op = PredictOp.by_name(op_name)()
|
op = PredictOp.by_name(op_name)()
|
||||||
strategy_config = rest.OperatorPredictingConfig(
|
strategy_config = rest.OperatorPredictingConfig(
|
||||||
operator_config=op.to_rest()
|
operator_config=op.to_rest()
|
||||||
@ -434,7 +435,7 @@ class SubGraphMapping(Mapping):
|
|||||||
raise ValueError(f"Invalid predicting_strategy [{predicting_strategy}].")
|
raise ValueError(f"Invalid predicting_strategy [{predicting_strategy}].")
|
||||||
if strategy_config:
|
if strategy_config:
|
||||||
predicting_configs.append(
|
predicting_configs.append(
|
||||||
strategy_config
|
rest.PredictingConfig(target=predicate_name,predicting_config=strategy_config)
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(self.subject_fusing_strategy, FuseOp):
|
if isinstance(self.subject_fusing_strategy, FuseOp):
|
||||||
|
@ -1,21 +1,20 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
from knext.client.model.builder_job import BuilderJob
|
from knext.client.model.builder_job import BuilderJob
|
||||||
from knext.api.component import CSVReader, SPGTypeMapping, KGWriter
|
from knext.api.component import (
|
||||||
from knext.component.builder import LLMBasedExtractor, SubGraphMapping
|
CSVReader,
|
||||||
|
LLMBasedExtractor,
|
||||||
|
KGWriter,
|
||||||
|
SubGraphMapping
|
||||||
|
)
|
||||||
from nn4k.invoker import LLMInvoker
|
from nn4k.invoker import LLMInvoker
|
||||||
|
|
||||||
try:
|
|
||||||
from schema.financial_schema_helper import Financial
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class StateAndIndicator(BuilderJob):
|
class StateAndIndicator(BuilderJob):
|
||||||
|
|
||||||
def build(self):
|
def build(self):
|
||||||
source = CSVReader(
|
source = CSVReader(
|
||||||
local_path="/Users/jier/openspg/python/knext/knext/examples/financial/builder/job/data/document.csv",
|
local_path="builder/job/data/document.csv",
|
||||||
columns=["input"],
|
columns=["input"],
|
||||||
start_row=2
|
start_row=2
|
||||||
)
|
)
|
||||||
@ -23,7 +22,7 @@ class StateAndIndicator(BuilderJob):
|
|||||||
from knext.examples.financial.builder.operator.IndicatorNER import IndicatorNER
|
from knext.examples.financial.builder.operator.IndicatorNER import IndicatorNER
|
||||||
from knext.examples.financial.builder.operator.IndicatorREL import IndicatorREL
|
from knext.examples.financial.builder.operator.IndicatorREL import IndicatorREL
|
||||||
from knext.examples.financial.builder.operator.IndicatorLOGIC import IndicatorLOGIC
|
from knext.examples.financial.builder.operator.IndicatorLOGIC import IndicatorLOGIC
|
||||||
extract = LLMBasedExtractor(llm=LLMInvoker.from_config("/Users/jier/openspg/python/knext/knext/examples/financial/builder/model/openai_infer.json"),
|
extract = LLMBasedExtractor(llm=LLMInvoker.from_config("builder/model/openai_infer.json"),
|
||||||
prompt_ops=[IndicatorNER(), IndicatorREL(), IndicatorLOGIC()]
|
prompt_ops=[IndicatorNER(), IndicatorREL(), IndicatorLOGIC()]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -36,7 +35,6 @@ class StateAndIndicator(BuilderJob):
|
|||||||
indicator_mapping = SubGraphMapping(spg_type_name="Financial.Indicator")\
|
indicator_mapping = SubGraphMapping(spg_type_name="Financial.Indicator")\
|
||||||
.add_mapping_field("id", "id") \
|
.add_mapping_field("id", "id") \
|
||||||
.add_mapping_field("name", "name")
|
.add_mapping_field("name", "name")
|
||||||
# .add_predicting_field("isA")
|
|
||||||
|
|
||||||
sink = KGWriter()
|
sink = KGWriter()
|
||||||
|
|
||||||
|
@ -134,7 +134,7 @@ class PredictOp(BaseOp, ABC):
|
|||||||
|
|
||||||
bind_to: Tuple[SPGTypeName, PropertyName, SPGTypeName]
|
bind_to: Tuple[SPGTypeName, PropertyName, SPGTypeName]
|
||||||
|
|
||||||
bind_schemas: Dict[Tuple[SPGTypeName, PropertyName], str] = {}
|
bind_schemas: Dict[Tuple[SPGTypeName, PropertyName, SPGTypeName], str] = {}
|
||||||
|
|
||||||
def invoke(self, subject_record: SPGRecord) -> List[SPGRecord]:
|
def invoke(self, subject_record: SPGRecord) -> List[SPGRecord]:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@ -143,9 +143,7 @@ class PredictOp(BaseOp, ABC):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _pre_process(*inputs):
|
def _pre_process(*inputs):
|
||||||
return [
|
return SPGRecord.from_dict(inputs[0]),
|
||||||
SPGRecord.from_dict(input) for input in inputs[0]
|
|
||||||
],
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _post_process(output) -> Dict[str, Any]:
|
def _post_process(output) -> Dict[str, Any]:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user