Mirror of https://github.com/OpenSPG/openspg.git (synced 2025-09-21 14:38:26 +00:00)

Commit 4e11000e2f ("fix"), parent e0009c1ecb
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from knext.chain.builder_chain import BuilderChain


__all__ = [
    "BuilderChain",
]

@@ -9,15 +9,18 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from knext.operator.op import BaseOp, LinkOp, ExtractOp, FuseOp, PromptOp

from knext.operator.op import LinkOp, ExtractOp, FuseOp, PromptOp, PredictOp
from knext.operator.spg_record import SPGRecord
from knext.operator.builtin.auto_prompt import SPOPrompt


__all__ = [
    "BaseOp",
    "ExtractOp",
    "LinkOp",
    "FuseOp",
    "PromptOp",
    "PredictOp",
    "SPOPrompt",
    "SPGRecord",
]

@@ -36,7 +36,7 @@ class BuilderClient(Client):
    def submit(self, job_name: str):
        """Submit an asynchronous builder job to the server by name."""
        job = BuilderJob.by_name(job_name)()
        builder_chain = job.build()
        builder_chain = BuilderChain.from_chain(job.build())
        dag_config = builder_chain.to_rest()

        params = {

@@ -70,3 +70,7 @@ class BuilderJob:
            return subclass
        else:
            raise ValueError(f"{name} is not a registered name for {cls.__name__}. ")

    @classmethod
    def has_registered(cls):
        return cls._has_registered
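For orientation, `BuilderJob.by_name(job_name)()` in the `submit` hunk above resolves a job class through the SDK's name-based subclass registry, the same mechanism the `BaseOp` registration hunk near the end of this commit touches. A minimal, self-contained sketch of that pattern; every class and method name here is invented for illustration and is not knext's API:

```python
# Illustrative sketch only; names are assumptions, not knext's implementation.
from typing import Dict, Type


class JobBase:
    _registry: Dict[str, Type["JobBase"]] = {}

    @classmethod
    def register(cls, subclass: Type["JobBase"]) -> Type["JobBase"]:
        # Key the subclass by its class name so it can be looked up later by string.
        cls._registry[subclass.__name__] = subclass
        return subclass

    @classmethod
    def by_name(cls, name: str) -> Type["JobBase"]:
        if name in cls._registry:
            return cls._registry[name]
        raise ValueError(f"{name} is not a registered name for {cls.__name__}.")


@JobBase.register
class Demo(JobBase):
    def build(self) -> str:
        return "demo-chain"


# Lookup by name mirrors `BuilderJob.by_name(job_name)()` in the submit() method above.
job = JobBase.by_name("Demo")()
print(job.build())  # -> demo-chain
```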
@@ -1,24 +1,35 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from abc import ABC
from typing import Type
from typing import Type, List

from knext import rest


class RESTable(ABC):
    @property
    def upstream_types(self) -> Type["RESTable"]:
        raise NotImplementedError("To be implemented in subclass")
    def upstream_types(self) -> List[Type["RESTable"]]:
        return []

    @property
    def downstream_types(self) -> Type["RESTable"]:
        raise NotImplementedError("To be implemented in subclass")
    def downstream_types(self) -> List[Type["RESTable"]]:
        return []

    def to_rest(self) -> rest.Node:
        raise NotImplementedError("To be implemented in subclass")
        raise NotImplementedError(f"`to_rest` is not currently supported for {self.__class__.__name__}.")

    @classmethod
    def from_rest(cls, node: rest.Node):
        raise NotImplementedError("To be implemented in subclass")
        raise NotImplementedError(f"`from_rest` is not currently supported for {cls.__name__}.")

    def submit(self):
        raise NotImplementedError("To be implemented in subclass")
        raise NotImplementedError(f"`submit` is not currently supported for {self.__class__.__name__}.")
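The change above turns `upstream_types`/`downstream_types` from hard `NotImplementedError` stubs into safe empty-list defaults, so only components that really constrain their neighbours need to override them, while the serialization hooks still fail loudly. A small self-contained sketch of that design choice, with toy names rather than the knext classes:

```python
# Illustrative sketch only: optional hooks default to an empty list instead of raising,
# while mandatory hooks still fail loudly until a subclass implements them.
from abc import ABC
from typing import List, Type


class Node(ABC):
    @property
    def upstream_types(self) -> List[Type["Node"]]:
        return []  # no constraint by default

    def to_rest(self) -> dict:
        raise NotImplementedError(f"`to_rest` is not currently supported for {self.__class__.__name__}.")


class Reader(Node):
    def to_rest(self) -> dict:
        return {"type": "reader"}


print(Reader().upstream_types)  # -> [] (default, no override needed)
print(Reader().to_rest())       # -> {'type': 'reader'}
```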
@@ -1,4 +1,16 @@
from typing import TypeVar, Sequence, Generic, Type
# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from typing import TypeVar, Sequence, Type

from pydantic import BaseConfig, BaseModel


@@ -21,7 +33,7 @@ class Runnable(BaseModel):
        return

    def invoke(self, input: Input) -> Sequence[Output]:
        raise NotImplementedError("To be implemented in subclass")
        raise NotImplementedError(f"`invoke` is not currently supported for {self.__class__.__name__}.")

    def __rshift__(self, other: Other):
        raise NotImplementedError("To be implemented in subclass")
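`Runnable.__rshift__` is the hook behind the `source >> extract >> mapping >> sink` job definitions later in this commit. A minimal, self-contained sketch of that chaining pattern; the toy classes below are illustrative only, not knext's implementation:

```python
# Illustrative sketch only: `>>` collects stages into an ordered chain.
from typing import Any, List


class Stage:
    def __init__(self, name: str):
        self.name = name

    def invoke(self, data: Any) -> Any:
        return f"{data}->{self.name}"

    def __rshift__(self, other: "Stage") -> "Chain":
        return Chain([self, other])


class Chain:
    def __init__(self, stages: List[Stage]):
        self.stages = stages

    def __rshift__(self, other: Stage) -> "Chain":
        return Chain(self.stages + [other])

    def invoke(self, data: Any) -> Any:
        for stage in self.stages:
            data = stage.invoke(data)
        return data


chain = Stage("reader") >> Stage("extractor") >> Stage("writer")
print(chain.invoke("row"))  # -> row->reader->extractor->writer
```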
@@ -9,7 +9,7 @@ from knext import rest
from knext.operator.op import PromptOp, ExtractOp

# try:
from nn4k.invoker.base import LLMInvoker, NNInvoker  # noqa: F403
from nn4k.invoker.base import NNInvoker  # noqa: F403

# except ImportError:
#     pass

@@ -54,7 +54,7 @@ class LLMBasedExtractor(SPGExtractor):

    def to_rest(self):
        """Transforms `LLMBasedExtractor` to REST model `ExtractNodeConfig`."""
        params = {}
        params = dict()
        params["model_config"] = json.dumps(self.llm._nn_config)
        api_client = OperatorClient()._rest_client.api_client
        params["prompt_config"] = json.dumps([api_client.sanitize_for_serialization(op.to_rest()) for op in self.prompt_ops], ensure_ascii=False)

@@ -117,7 +117,7 @@ class UserDefinedExtractor(SPGExtractor):
        """Transforms `UserDefinedExtractor` to REST model `ExtractNodeConfig`."""
        operator_config = self.extract_op.to_rest()
        config = rest.UserDefinedExtractNodeConfig(
            output_fields=self.output_fields, operator_config=operator_config
            operator_config=operator_config
        )

        return rest.Node(**super().to_dict(), node_config=config)
@@ -21,46 +21,44 @@ from knext.operator.op import LinkOp, FuseOp, PredictOp
from knext.operator.spg_record import SPGRecord


class MappingTypeEnum(str, Enum):
    SPGType = "SPG_TYPE"
    Relation = "RELATION"


class LinkStrategyEnum(str, Enum):
class LinkingStrategyEnum(str, Enum):
    IDEquals = "ID_EQUALS"


SPG_TYPE_BASE_FIELDS = ["id"]
class FusingStrategyEnum(str, Enum):
    pass

RELATION_BASE_FIELDS = ["src_id", "dst_id"]

class PredictingStrategyEnum(str, Enum):
    pass


class SPGTypeMapping(Mapping):
    """A Process Component that maps data to an entity/event/concept type.
    """A Process Component that maps data to an entity/event/concept/standard type.

    Args:
        spg_type_name: The SPG type name imported from SPGTypeHelper.
        spg_type_name: The SPG type name of the subject, imported from SPGTypeHelper.
    Examples:
        mapping = SPGTypeMapping(
            spg_type_name=DEFAULT.App
        ).add_field("id", DEFAULT.App.id) \
            .add_field("id", DEFAULT.App.name) \
            .add_field("riskMark", DEFAULT.App.riskMark) \
            .add_field("useCert", DEFAULT.App.useCert)

        ).add_mapping_field("id", DEFAULT.App.id) \
            .add_mapping_field("name", DEFAULT.App.name) \
            .add_mapping_field("riskMark", DEFAULT.App.riskMark) \
            .add_predicting_field(DEFAULT.App.useCert)
    """

    """The SPG type name of the subject, imported from SPGTypeHelper."""
    spg_type_name: Union[str, SPGTypeHelper]

    mapping: Dict[str, str] = dict()

    filters: List[Tuple[str, str]] = list()

    subject_fuse_strategy: Optional[FuseOp] = None
    subject_fusing_strategy: Optional[Union[FusingStrategyEnum, FuseOp]] = None

    object_link_strategies: Dict[str, Union[LinkStrategyEnum, LinkOp]] = dict()
    object_linking_strategies: Dict[str, Union[LinkingStrategyEnum, LinkOp]] = dict()

    predicate_predict_strategies: Dict[str, PredictOp] = dict()
    predicate_predicting_strategies: Dict[str, Union[PredictingStrategyEnum, PredictOp]] = dict()

    @property
    def input_types(self) -> Input:

@@ -78,32 +76,34 @@ class SPGTypeMapping(Mapping):
    def output_keys(self):
        return self.output_fields

    def set_fuse_strategy(self, fuse_strategy: FuseOp):
        self.subject_fuse_strategy = fuse_strategy
    def set_fusing_strategy(self, fusing_strategy: FuseOp):
        """Sets the fusing strategy used to merge subject records of this SPG type."""
        self.subject_fusing_strategy = fusing_strategy
        return self

    def add_mapping_field(
        self,
        source_field: str,
        target_field: Union[str, PropertyHelper],
        link_strategy: Union[LinkStrategyEnum, LinkOp] = None,
        linking_strategy: Union[LinkingStrategyEnum, LinkOp] = None,
    ):
        """Adds a field mapping from source data to a property of spg_type.

        :param source_field: The source field to be mapped.
        :param target_field: The target field to map the source field to.
        :param target_field: The target field (SPG property name) to map the source field to.
        :param linking_strategy: The strategy used to link the mapped value to an existing instance of the object type.
        :return: self
        """
        self.mapping[target_field] = source_field
        self.object_link_strategies[target_field] = link_strategy
        self.object_linking_strategies[target_field] = linking_strategy
        return self

    def add_predicting_field(
        self,
        field: Union[str, PropertyHelper],
        predict_strategy: PredictOp = None,
        predicting_strategy: PredictOp = None,
    ):
        self.predicate_predict_strategies[field] = predict_strategy
        self.predicate_predicting_strategies[field] = predicting_strategy
        return self

    def add_filter(self, column_name: str, column_value: str):
@@ -121,6 +121,9 @@ class SPGTypeMapping(Mapping):
        """
        Transforms `SPGTypeMapping` to REST model `SpgTypeMappingNodeConfig`.
        """
        from knext.client.schema import SchemaClient
        client = SchemaClient()
        spg_type = client.query_spg_type(self.spg_type_name)

        mapping_filters = [
            rest.MappingFilter(column_name=name, column_value=value)

@@ -128,17 +131,25 @@ class SPGTypeMapping(Mapping):
        ]
        mapping_configs = []
        for tgt_name, src_name in self.mapping.items():
            link_strategy = self.object_link_strategies.get(tgt_name, None)
            if isinstance(link_strategy, LinkOp):
            linking_strategy = self.object_linking_strategies.get(tgt_name, None)
            if isinstance(linking_strategy, LinkOp):
                strategy_config = rest.OperatorLinkingConfig(
                    operator_config=link_strategy.to_rest()
                    operator_config=linking_strategy.to_rest()
                )
            elif link_strategy == LinkStrategyEnum.IDEquals:
            elif linking_strategy == LinkingStrategyEnum.IDEquals:
                strategy_config = rest.IdEqualsLinkingConfig()
            elif not link_strategy:
                strategy_config = None
            elif not linking_strategy:
                object_type_name = spg_type.properties[tgt_name].object_type_name
                if object_type_name in LinkOp.bind_schemas:
                    op_name = LinkOp.bind_schemas[object_type_name]
                    op = LinkOp.by_name(op_name)()
                    strategy_config = rest.OperatorLinkingConfig(
                        operator_config=op.to_rest()
                    )
                else:
                    strategy_config = None
            else:
                raise ValueError(f"Invalid link_strategy [{link_strategy}].")
                raise ValueError(f"Invalid linking_strategy [{linking_strategy}].")
            mapping_configs.append(
                rest.MappingConfig(
                    source=src_name,
@@ -148,37 +159,42 @@ class SPGTypeMapping(Mapping):
        )

        predicting_configs = []
        for predict_strategy in self.predicate_predict_strategies:
            if isinstance(predict_strategy, PredictOp):
        for predicate_name, predicting_strategy in self.predicate_predicting_strategies.items():
            if isinstance(predicting_strategy, PredictOp):
                strategy_config = rest.OperatorPredictingConfig(
                    operator_config=predict_strategy.to_rest()
                    operator_config=predicting_strategy.to_rest()
                )
            elif not predict_strategy:
                # if self.spg_type_name in PredictOp._bind_schemas:
                #     op_name = PredictOp._bind_schemas[self.spg_type_name]
                #     op = PredictOp.by_name(op_name)()
                #     strategy_config = op.to_rest()
                # else:
                strategy_config = None
            elif not predicting_strategy:
                if (self.spg_type_name, predicate_name) in PredictOp.bind_schemas:
                    op_name = PredictOp.bind_schemas[(self.spg_type_name, predicate_name)]
                    op = PredictOp.by_name(op_name)()
                    strategy_config = rest.OperatorPredictingConfig(
                        operator_config=op.to_rest()
                    )
                else:
                    strategy_config = None
            else:
                raise ValueError(f"Invalid predict_strategy [{predict_strategy}].")
            predicting_configs.append(
                strategy_config
            )
                raise ValueError(f"Invalid predicting_strategy [{predicting_strategy}].")
            if strategy_config:
                predicting_configs.append(
                    strategy_config
                )

        if isinstance(self.subject_fuse_strategy, FuseOp):
        if isinstance(self.subject_fusing_strategy, FuseOp):
            fusing_config = rest.OperatorFusingConfig(
                operator_config=self.fuse_strategy.to_rest()
                operator_config=self.subject_fusing_strategy.to_rest()
            )
        elif not self.subject_fuse_strategy:
            if self.spg_type_name in FuseOp._bind_schemas:
                op_name = FuseOp._bind_schemas[self.spg_type_name]
        elif not self.subject_fusing_strategy:
            if self.spg_type_name in FuseOp.bind_schemas:
                op_name = FuseOp.bind_schemas[self.spg_type_name]
                op = FuseOp.by_name(op_name)()
                fusing_config = op.to_rest()
                fusing_config = rest.OperatorFusingConfig(
                    operator_config=op.to_rest()
                )
            else:
                fusing_config = None
        else:
            raise ValueError(f"Invalid fuse_strategy [{self.subject_fuse_strategy}].")
            raise ValueError(f"Invalid fusing_strategy [{self.subject_fusing_strategy}].")

        config = rest.SpgTypeMappingNodeConfig(
            spg_type=self.spg_type_name,
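The branches above settle on a common resolution order: an explicitly supplied operator wins, otherwise an operator bound to the schema element is looked up in the class registry, otherwise a default (or `None`) applies. A self-contained sketch of that fallback pattern, with toy names rather than the knext API:

```python
# Illustrative sketch only: explicit strategy > registry-bound strategy > default.
from typing import Callable, Dict, Optional

# Registry of strategies bound to a schema element name, mirroring FuseOp.bind_schemas.
bind_schemas: Dict[str, Callable[[str], str]] = {
    "Financial.Indicator": lambda record: f"fused({record})",
}


def resolve_strategy(type_name: str, explicit: Optional[Callable[[str], str]]) -> Optional[Callable[[str], str]]:
    if explicit is not None:
        return explicit                 # 1. explicitly supplied operator
    if type_name in bind_schemas:
        return bind_schemas[type_name]  # 2. operator bound to this schema element
    return None                         # 3. fall back to the default behaviour


print(resolve_strategy("Financial.Indicator", None)("r1"))  # -> fused(r1)
print(resolve_strategy("Financial.State", None))            # -> None
```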
@@ -190,11 +206,11 @@ class SPGTypeMapping(Mapping):
        return rest.Node(**super().to_dict(), node_config=config)

    def invoke(self, input: Input) -> Sequence[Output]:
        pass
        raise NotImplementedError(f"`invoke` method is not currently supported for {self.__class__.__name__}.")

    @classmethod
    def from_rest(cls, node: rest.Node):
        pass
        raise NotImplementedError(f"`from_rest` method is not currently supported for {cls.__name__}.")

    def submit(self):
        pass
@@ -208,15 +224,16 @@ class RelationMapping(Mapping):
        predicate_name: The predicate name.
        object_name: The object name imported from SPGTypeHelper.
    Examples:
        mapping = RelationMappingComponent(
        mapping = RelationMapping(
            subject_name=DEFAULT.App,
            predicate_name=DEFAULT.App.useCert,
            object_name=DEFAULT.Cert,
        ).add_field("src_id", "srcId") \
            .add_field("dst_id", "dstId")
        ).add_mapping_field("src_id", "srcId") \
            .add_mapping_field("dst_id", "dstId")

    """

    """The SPG type names of the (subject, predicate, object) triplet, imported from SPGTypeHelper and PropertyHelper."""
    subject_name: Union[str, SPGTypeHelper]
    predicate_name: Union[str, PropertyHelper]
    object_name: Union[str, SPGTypeHelper]

@@ -277,19 +294,33 @@ class RelationMapping(Mapping):


class SubGraphMapping(Mapping):
    """A Process Component that maps data to an SPG type and its related subgraph.

    Args:
        spg_type_name: The SPG type name imported from SPGTypeHelper.
    Examples:
        mapping = SubGraphMapping(
            spg_type_name=DEFAULT.App,
        ).add_mapping_field("id", DEFAULT.App.id) \
            .add_mapping_field("name", DEFAULT.App.name) \
            .add_mapping_field("useCert", DEFAULT.App.useCert)
            .add_predicting_field(

    """

    """The SPG type name of the subject, imported from SPGTypeHelper."""
    spg_type_name: Union[str, SPGTypeHelper]

    mapping: Dict[str, str] = dict()

    filters: List[Tuple[str, str]] = list()

    subject_fuse_strategy: Optional[FuseOp] = None
    subject_fusing_strategy: Optional[FuseOp] = None

    predicate_predict_strategies: Dict[str, PredictOp] = dict()
    predicate_predicting_strategies: Dict[str, PredictOp] = dict()

    object_fuse_strategies: Dict[str, FuseOp] = dict()


    @property
    def input_types(self) -> Input:
        return Union[Dict[str, str], SPGRecord]
@@ -306,15 +337,15 @@ class SubGraphMapping(Mapping):
    def output_keys(self):
        return self.output_fields

    def set_fuse_strategy(self, fuse_strategy: FuseOp):
        self.subject_fuse_strategy = fuse_strategy
    def set_fusing_strategy(self, fusing_strategy: FuseOp):
        self.subject_fusing_strategy = fusing_strategy
        return self

    def add_mapping_field(
        self,
        source_field: str,
        target_field: Union[str, PropertyHelper],
        fuse_strategy: FuseOp = None,
        fusing_strategy: Union[FusingStrategyEnum, FuseOp] = None,
    ):
        """Adds a field mapping from source data to a property of spg_type.


@@ -323,15 +354,15 @@ class SubGraphMapping(Mapping):
        :return: self
        """
        self.mapping[target_field] = source_field
        self.object_fuse_strategies[target_field] = fuse_strategy
        self.object_fuse_strategies[target_field] = fusing_strategy
        return self

    def add_predicting_field(
        self,
        target_field: Union[str, PropertyHelper],
        predict_strategy: PredictOp = None,
        predicting_strategy: PredictOp = None,
    ):
        self.predict_strategies[target_field] = predict_strategy
        self.predicate_predicting_strategies[target_field] = predicting_strategy
        return self

    def add_filter(self, column_name: str, column_value: str):

@@ -349,6 +380,9 @@ class SubGraphMapping(Mapping):
        """
        Transforms `SubGraphMapping` to REST model `SpgTypeMappingNodeConfig`.
        """
        from knext.client.schema import SchemaClient
        client = SchemaClient()
        spg_type = client.query_spg_type(self.spg_type_name)

        mapping_filters = [
            rest.MappingFilter(column_name=name, column_value=value)
@@ -356,16 +390,23 @@ class SubGraphMapping(Mapping):
        ]
        mapping_configs = []
        for tgt_name, src_name in self.mapping.items():
            fuse_strategy = self.object_fuse_strategies.get(tgt_name, None)
            if isinstance(fuse_strategy, FuseOp):
            fusing_strategy = self.object_fuse_strategies.get(tgt_name, None)
            if isinstance(fusing_strategy, FuseOp):
                strategy_config = rest.OperatorFusingConfig(
                    operator_config=fuse_strategy.to_rest()
                )
            elif not self.subject_fuse_strategy:
                strategy_config = rest.NewInstanceFusingConfig(
                    operator_config=fusing_strategy.to_rest()
                )
            elif not self.subject_fusing_strategy:
                object_type_name = spg_type.properties[tgt_name].object_type_name
                if object_type_name in FuseOp.bind_schemas:
                    op_name = FuseOp.bind_schemas[object_type_name]
                    op = FuseOp.by_name(op_name)()
                    strategy_config = rest.OperatorFusingConfig(
                        operator_config=op.to_rest()
                    )
                else:
                    strategy_config = rest.NewInstanceFusingConfig()
            else:
                raise ValueError(f"Invalid fuse_strategy [{fuse_strategy}].")
                raise ValueError(f"Invalid fusing_strategy [{fusing_strategy}].")
            mapping_configs.append(
                rest.MappingConfig(
                    source=src_name,
@@ -375,28 +416,42 @@ class SubGraphMapping(Mapping):
        )

        predicting_configs = []
        for predict_strategy in self.predicate_predict_strategies:
            if isinstance(predict_strategy, PredictOp):
        for predicate_name, predicting_strategy in self.predicate_predicting_strategies.items():
            if isinstance(predicting_strategy, PredictOp):
                strategy_config = rest.OperatorPredictingConfig(
                    operator_config=predict_strategy.to_rest()
                    operator_config=predicting_strategy.to_rest()
                )
            elif not predict_strategy:
                strategy_config = None
            elif not predicting_strategy:
                if (self.spg_type_name, predicate_name) in PredictOp.bind_schemas:
                    op_name = PredictOp.bind_schemas[(self.spg_type_name, predicate_name)]
                    op = PredictOp.by_name(op_name)()
                    strategy_config = rest.OperatorPredictingConfig(
                        operator_config=op.to_rest()
                    )
                else:
                    strategy_config = None
            else:
                raise ValueError(f"Invalid predict_strategy [{predict_strategy}].")
            predicting_configs.append(
                strategy_config
            )

        if isinstance(self.subject_fuse_strategy, FuseOp):
            fusing_config = rest.OperatorFusingConfig(
                operator_config=self.fuse_strategy.to_rest()
            )
        elif not self.subject_fuse_strategy:
            fusing_config = rest.NewInstanceFusingConfig(
                raise ValueError(f"Invalid predicting_strategy [{predicting_strategy}].")
            if strategy_config:
                predicting_configs.append(
                    strategy_config
                )

        if isinstance(self.subject_fusing_strategy, FuseOp):
            fusing_config = rest.OperatorFusingConfig(
                operator_config=self.subject_fusing_strategy.to_rest()
            )
        elif not self.subject_fusing_strategy:
            if self.spg_type_name in FuseOp.bind_schemas:
                op_name = FuseOp.bind_schemas[self.spg_type_name]
                op = FuseOp.by_name(op_name)()
                fusing_config = rest.OperatorFusingConfig(
                    operator_config=op.to_rest()
                )
            else:
                fusing_config = rest.NewInstanceFusingConfig()
        else:
            raise ValueError(f"Invalid fuse_strategy [{self.subject_fuse_strategy}].")
            raise ValueError(f"Invalid fusing_strategy [{self.subject_fusing_strategy}].")

        config = rest.SubGraphMappingNodeConfig(
            spg_type=self.spg_type_name,
python/knext/component/reasoner/__init__.py (new file, 0 lines)

python/knext/examples/financial/.knext.cfg (new file, 20 lines)
@@ -0,0 +1,20 @@
# This config file is auto generated by:
# knext project create --name 全风 --namespace Financial --desc 全风财政指标抽取
# Do not edit the config file manually.

[local]
project_name = 全风
description = 全风财政指标抽取
namespace = Financial
project_id = 4

project_dir = financial
schema_dir = schema
schema_file = financial.schema
builder_dir = builder
builder_operator_dir = builder/operator
builder_record_dir = builder/error_record
builder_job_dir = builder/job
builder_model_dir = builder/model
reasoner_dir = reasoner
reasoner_result_dir = reasoner/result
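`.knext.cfg` is a standard INI file, so its values are easy to inspect outside the CLI. A sketch using Python's stdlib `configparser` (illustration only; this is not necessarily how knext itself loads the file):

```python
# Illustrative sketch only: read the generated project config with the standard library.
import configparser

config = configparser.ConfigParser()
config.read("python/knext/examples/financial/.knext.cfg", encoding="utf-8")

local = config["local"]
print(local["namespace"])          # -> Financial
print(local["builder_job_dir"])    # -> builder/job
print(local.getint("project_id"))  # -> 4
```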
python/knext/examples/financial/README.md (new file, 20 lines)
@@ -0,0 +1,20 @@
```bash
# 1. Create the Financial example project
knext project create --name 全风 --namespace Financial --desc 全风财政指标抽取
```

```bash
# 2. Commit the schema defined in schema/financial.schema
knext schema commit
```

```bash
# 3. Publish the extract operator
knext operator publish DemoExtractOp
```

```bash
# 4. Submit the builder job
knext builder submit Demo
```

```bash
# 5. Run a reasoner query from a DSL file
knext reasoner query --file ./reasoner/demo.dsl
```
@@ -0,0 +1,2 @@
input
济南市财政收入质量及自给能力均较好,但土地出让收入大幅下降致综合财力明显下滑。济南市财政收入质量及自给能力均较好,但土地出让收入大幅下降致综合财力明显下滑。2022年济南市一般公共预算收入1,000.21亿元,扣除留抵退税因素后同比增长8%,规模在山东省下辖地市中排名第2位;其中税收收入690.31亿元,税收占比69.02%;一般公共预算支出1,260.23亿元,财政自给率79.37%。政府性基金收入547.29亿元,同比大幅下降48.38%,主要系土地出让收入同比由966.74亿元降至453.74亿元;转移性收入285.78亿元(上年同期为233.11亿元);综合财力约1,833.28亿元(上年同期为2,301.02亿元)。
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-

from knext.client.model.builder_job import BuilderJob
from knext.api.component import CSVReader, SPGTypeMapping, KGWriter
from knext.component.builder import LLMBasedExtractor, SubGraphMapping
from nn4k.invoker import LLMInvoker

try:
    from schema.financial_schema_helper import Financial
except:
    pass


class StateAndIndicator(BuilderJob):

    def build(self):
        # 1. Define the CSV input source.
        source = CSVReader(
            local_path="/Users/jier/openspg/python/knext/examples/financial/builder/job/data/document.csv",
            columns=["input"],
            start_row=2
        )

        # 2. Define the LLM-based extraction component with three chained prompt ops.
        from knext.examples.financial.builder.operator.IndicatorNER import IndicatorNER
        from knext.examples.financial.builder.operator.IndicatorREL import IndicatorREL
        from knext.examples.financial.builder.operator.IndicatorLOGIC import IndicatorLOGIC
        extract = LLMBasedExtractor(llm=LLMInvoker.from_config("/Users/jier/openspg/python/knext/examples/financial/builder/model/openai_infer.json"),
                                    prompt_ops=[IndicatorNER(), IndicatorREL(), IndicatorLOGIC()]
                                    )

        # 3. Define the subgraph mappings for Financial.State and Financial.Indicator.
        state_mapping = SubGraphMapping(spg_type_name="Financial.State")\
            .add_mapping_field("id", "id") \
            .add_mapping_field("name", "name") \
            .add_mapping_field("causeOf", "causeOf") \
            .add_predicting_field("derivedFrom")

        indicator_mapping = SubGraphMapping(spg_type_name="Financial.Indicator")\
            .add_mapping_field("id", "id") \
            .add_mapping_field("name", "name")
            # .add_predicting_field("isA")

        # 4. Define the sink that writes records to the graph, then assemble the chain.
        sink = KGWriter()

        return source >> extract >> [state_mapping, indicator_mapping] >> sink
@@ -14,6 +14,8 @@ class IndicatorFuse(FuseOp):
        self.search_client = SearchClient("Financial.Indicator")

    def invoke(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
        print("##########IndicatorFuse###########")
        print(subject_records)
        fused_records = []
        for record in subject_records:
            query = {"match": {"name": record.get_property("name", "")}}

@@ -43,10 +43,11 @@ ${rel}
        """
        response: "[{\"subject\": \"土地出让收入大幅下降\", \"predicate\": \"顺承\", \"object\": [\"综合财力明显下滑\"]}]"
        """
        print("##########IndicatorLOGIC###########")
        response = "[{\"subject\": \"土地出让收入大幅下降\", \"predicate\": \"顺承\", \"object\": [\"综合财力明显下滑\"]}]"
        output_list = json.loads(response)

        logic_result = []
        # IF hasA
        for output in output_list:
            properties = {}
            for k, v in output.items():

@@ -55,6 +56,6 @@ ${rel}
                    properties["name"] = k
                elif k == "object":
                    properties["causeOf"] = ','.join(v)
            logic_result.append(SPGRecord("FEL.State", properties=properties))
            logic_result.append(SPGRecord("Financial.State", properties=properties))

        return logic_result
@@ -22,29 +22,18 @@ class IndicatorNER(PromptOp):
    def parse_response(
        self, response: str
    ) -> List[SPGRecord]:
        # output_list = json.loads(response)
        #
        # ner_result = []
        # # IF hasA
        # for output in output_list:
        #     # {'财政': ['财政收入....}
        #     for k, v in output.items():
        #         # '财政', ['财政收入....]
        #         ner_result.append(SPGRecord("FEL.Indicator", properties={"id": k, "name": k, "hasA": ','.join(v)}))
        #
        # # ELSE isA
        # # TODO: support this via the isA property
        # for output in output_list:
        #     # {'财政': ['财政收入....}
        #     for k, v in output.items():
        #         # '财政', ['财政收入....]
        #         for _v in v:
        #             # '财政收入....'
        #             ner_result.append(SPGRecord("FEL.Indicator", properties={"id": f'{k}-{_v}', "name": _v}))
        print("##########IndicatorNER###########")
        ner_result = [SPGRecord(spg_type_name="Financial.Indicator", properties={"id": "土地出让收入", "name": "土地出让收入"})]
        print(ner_result)
        print("##########IndicatorNER###########")
        response = "[{'财政': ['财政收入质量', '财政自给能力', '土地出让收入', '一般公共预算收入', '留抵退税', '税收收入', '税收收入/一般公共预算收入', '一般公共预算支出', '财政自给率', '政府性基金收入', '转移性收入', '综合财力']}]"

        output_list = json.loads(response.replace("'", "\""))
        ner_result = []
        # IF hasA
        for output in output_list:
            # {'财政': ['财政收入....}
            for category, indicator_list in output.items():
                # '财政', ['财政收入....]
                for indicator in indicator_list:
                    ner_result.append(SPGRecord("Financial.Indicator", properties={"id": indicator, "name": indicator}))
        return ner_result

    def build_next_variables(
@@ -53,5 +42,6 @@ class IndicatorNER(PromptOp):
        """
        response: "[{'subject': '一般公共预算收入', 'predicate': '包含', 'object': ['税收收入']}, {'subject': '税收收入', 'predicate': '包含', 'object': ['留抵退税']}, {'subject': '政府性基金收入', 'predicate': '包含', 'object': ['土地出让收入', '转移性收入']}, {'subject': '综合财力', 'predicate': '包含', 'object': ['一般公共预算收入', '政府性基金收入']}]"
        """
        response = ""
        response = "[{'财政': ['财政收入质量', '财政自给能力', '土地出让收入', '一般公共预算收入', '留抵退税', '税收收入', '税收收入/一般公共预算收入', '一般公共预算支出', '财政自给率', '政府性基金收入', '转移性收入', '综合财力']}]"

        return [{"input": variables["input"], "ner": response}]

@@ -11,15 +11,13 @@ class IndicatorPredict(PredictOp):

    def __init__(self):
        super().__init__()
        # self.search_client = SearchClient("Financial.Indicator")
        self.search_client = SearchClient("Financial.Indicator")

    def invoke(self, subject_record: SPGRecord) -> List[SPGRecord]:
        # query = {"match": {"name": subject_record.get_property("name", "")}}
        # recall_records = self.search_client.search(query, start=0, size=10)
        # if recall_records is not None and len(recall_records) > 0:
        #     rerank_record = SPGRecord("Financial.Indicator", {"id": recall_records[0].doc_id, "name": recall_records[0].properties.get("name", "")})
        #     return [rerank_record]
        # return []
        print("##########IndicatorPredict###########")

        return [subject_record]
        query = {"match": {"name": subject_record.get_property("name", "")}}
        recall_records = self.search_client.search(query, start=0, size=10)
        if recall_records is not None and len(recall_records) > 0:
            rerank_record = SPGRecord("Financial.Indicator", {"id": recall_records[0].doc_id, "name": recall_records[0].properties.get("name", "")})
            return [rerank_record]
        return []

@@ -38,4 +38,5 @@ ${ner}
        """
        response: "[{'subject': '一般公共预算收入', 'predicate': '包含', 'object': ['税收收入']}, {'subject': '税收收入', 'predicate': '包含', 'object': ['留抵退税']}, {'subject': '政府性基金收入', 'predicate': '包含', 'object': ['土地出让收入', '转移性收入']}, {'subject': '综合财力', 'predicate': '包含', 'object': ['一般公共预算收入', '政府性基金收入']}]"
        """
        response = "[{'subject': '一般公共预算收入', 'predicate': '包含', 'object': ['税收收入']}, {'subject': '税收收入', 'predicate': '包含', 'object': ['留抵退税']}, {'subject': '政府性基金收入', 'predicate': '包含', 'object': ['土地出让收入', '转移性收入']}, {'subject': '综合财力', 'predicate': '包含', 'object': ['一般公共预算收入', '政府性基金收入']}]"
        return [{"input": variables["input"], "ner": variables["ner"], "rel": response}]

@@ -14,6 +14,7 @@ class StateFuse(FuseOp):
        self.search_client = SearchClient("Financial.State")

    def invoke(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
        print("##########StateFuse###########")
        fused_records = []
        for record in subject_records:
            query = {"match": {"name": record.get_property("name", "")}}
python/knext/examples/financial/schema/financial.schema (new file, 12 lines)
@@ -0,0 +1,12 @@
namespace Financial

Indicator(指标概念): ConceptType
    hypernymPredicate: isA

State(状态): ConceptType
    desc: 指标状态
    properties:
        causeOf(导致): State
            desc: 状态顺承关系
        derivedFrom(指标): Indicator
            desc: 状态的指标
@@ -0,0 +1,34 @@
# ATTENTION!
# This file is generated automatically from the schema; it will be refreshed after the schema has been committed.
# PLEASE DO NOT MODIFY THIS FILE!!!
#

class Financial:
    def __init__(self):
        self.Indicator = self.Indicator()
        self.State = self.State()

    class Indicator:
        __typename__ = "Financial.Indicator"
        description = "description"
        id = "id"
        name = "name"
        alias = "alias"
        stdId = "stdId"

        def __init__(self):
            pass

    class State:
        __typename__ = "Financial.State"
        description = "description"
        id = "id"
        name = "name"
        stdId = "stdId"
        derivedFrom = "derivedFrom"
        causeOf = "causeOf"
        alias = "alias"

        def __init__(self):
            pass
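The generated helper simply exposes type and property names as string constants, so job code can reference them instead of raw strings (the StateAndIndicator job above imports it inside a try block). A hedged usage sketch, assuming the module is importable from the project root exactly as in that job file:

```python
# Illustrative only: the helper's attributes are plain strings that mirror the schema.
from schema.financial_schema_helper import Financial  # same import as in the job above

assert Financial.Indicator.__typename__ == "Financial.Indicator"
assert Financial.State.causeOf == "causeOf"

# So a mapping could be written with helper constants instead of raw strings, e.g.
# SubGraphMapping(spg_type_name=Financial.State.__typename__).add_mapping_field("causeOf", Financial.State.causeOf)
```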
@@ -1,2 +1 @@
content
甲状腺结节是指在甲状腺内的肿块,可随吞咽动作随甲状腺而上下移动,是临床常见的病症,可由多种病因引起。临床上有多种甲状腺疾病,如甲状腺退行性变、炎症、自身免疫以及新生物等都可以表现为结节。甲状腺结节可以单发,也可以多发,多发结节比单发结节的发病率高,但单发结节甲状腺癌的发生率较高。患者通常可以选择在普外科,甲状腺外科,内分泌科,头颈外科挂号就诊。有些患者可以触摸到自己颈部前方的结节。在大多情况下,甲状腺结节没有任何症状,甲状腺功能也是正常的。甲状腺结节进展为其它甲状腺疾病的概率只有1%。有些人会感觉到颈部疼痛、咽喉部异物感,或者存在压迫感。当甲状腺结节发生囊内自发性出血时,疼痛感会更加强烈。治疗方面,一般情况下可以用放射性碘治疗,复方碘口服液(Lugol液)等,或者服用抗甲状腺药物来抑制甲状腺激素的分泌。目前常用的抗甲状腺药物是硫脲类化合物,包括硫氧嘧啶类的丙基硫氧嘧啶(PTU)和甲基硫氧嘧啶(MTU)及咪唑类的甲硫咪唑和卡比马唑。

@@ -13,10 +13,11 @@
from knext.client.model.builder_job import BuilderJob
from knext.api.component import (
    CSVReader,
    LLMBasedExtractor,
    SubGraphMapping,
    KGWriter
)
from knext.component.builder import LLMBasedExtractor, SubGraphMapping
from knext.operator.builtin.auto_prompt import SPOPrompt
from knext.api.operator import SPOPrompt
from nn4k.invoker import LLMInvoker

@@ -26,17 +27,26 @@ class Disease(BuilderJob):
        1. Define the input source: a CSV file.
        """
        source = CSVReader(
            local_path="job/data/Disease.csv",
            columns=["id", "input"],
            start_row=2,
            local_path="builder/job/data/Disease.csv",
            columns=["input"],
            start_row=1,
        )

        """
        2. Define the LLM-based extraction component, which extracts Medical.Disease entities from long text.
        """

        extract = LLMBasedExtractor(llm=LLMInvoker.from_config("openai_infer.json"),
                                    prompt_ops=[SPOPrompt("Medical.Disease", ["commonSymptom", "applicableDrug"])])
        extract = LLMBasedExtractor(
            llm=LLMInvoker.from_config("builder/model/openai_infer.json"),
            prompt_ops=[SPOPrompt(
                spg_type_name="Medical.Disease",
                property_names=[
                    "complication",
                    "commonSymptom",
                    "applicableDrug",
                    "department",
                    "diseaseSite",
                ])]
        )

        """
        3. Define the subgraph mapping component.

@@ -44,8 +54,11 @@ class Disease(BuilderJob):
        mapping = SubGraphMapping(spg_type_name="Medical.Disease") \
            .add_mapping_field("id", "id") \
            .add_mapping_field("name", "name") \
            .add_mapping_field("complication", "complication") \
            .add_mapping_field("commonSymptom", "commonSymptom") \
            .add_mapping_field("applicableDrug", "applicableDrug")
            .add_mapping_field("applicableDrug", "applicableDrug") \
            .add_mapping_field("department", "department") \
            .add_mapping_field("diseaseSite", "diseaseSite")

        """
        4. Define the sink that writes to the knowledge graph.

@@ -56,4 +69,3 @@ class Disease(BuilderJob):
        5. Define the builder chain.
        """
        return source >> extract >> mapping >> sink

@@ -78,7 +78,7 @@ class BaseOp(ABC):
            )
            cls._registry[name] = subclass
            if hasattr(subclass, "bind_to"):
                subclass.__bases__[0]._bind_schemas[subclass.bind_to] = name
                subclass.__bases__[0].bind_schemas[subclass.bind_to] = name
            return subclass

        return add_subclass_to_registry

@@ -104,3 +104,7 @@ class BaseOp(ABC):
            method="_handle",
            params=self.params,
        )

    @property
    def has_registered(self):
        return self._has_registered
@@ -48,23 +48,24 @@ input:${input}
        return self.template.replace("${input}", variables.get("input"))

    def parse_response(self, response: str) -> List[SPGRecord]:
        print(response)
        result = []
        subject = {}
        # re_obj = json.loads(response)
        re_obj = {
            "spo": [
                {
                    "subject": "甲状腺结节",
                    "predicate": "常见症状",
                    "object": "甲状腺结节"
                },
                {
                    "subject": "甲状腺结节",
                    "predicate": "适用药品",
                    "object": "放射性碘治疗,复方碘口服液(Lugol液),抗甲状腺药物,硫脲类化合物,丙基硫氧嘧啶(PTU),甲基硫氧嘧啶(MTU),咪唑类的甲硫咪唑和卡比马唑"
                }
            ]
        }
        re_obj = json.loads(response)
        # re_obj = {
        #     "spo": [
        #         {
        #             "subject": "甲状腺结节",
        #             "predicate": "常见症状",
        #             "object": "甲状腺结节"
        #         },
        #         {
        #             "subject": "甲状腺结节",
        #             "predicate": "适用药品",
        #             "object": "放射性碘治疗,复方碘口服液(Lugol液),抗甲状腺药物,硫脲类化合物,丙基硫氧嘧啶(PTU),甲基硫氧嘧啶(MTU),咪唑类的甲硫咪唑和卡比马唑"
        #         }
        #     ]
        # }
        if "spo" not in re_obj.keys():
            raise ValueError("SPO format error.")
        subject_properties = {}

@@ -95,21 +96,21 @@ input:${input}
        return result

    def build_variables(self, variables: Dict[str, str], response: str) -> List[Dict[str, str]]:
        # re_obj = json.loads(response)
        re_obj = {
            "spo": [
                {
                    "subject": "甲状腺结节",
                    "predicate": "常见症状",
                    "object": "甲状腺结节"
                },
                {
                    "subject": "甲状腺结节",
                    "predicate": "适用药品",
                    "object": "放射性碘治疗,复方碘口服液(Lugol液),抗甲状腺药物,硫脲类化合物,丙基硫氧嘧啶(PTU),甲基硫氧嘧啶(MTU),咪唑类的甲硫咪唑和卡比马唑"
                }
            ]
        }
        re_obj = json.loads(response)
        # re_obj = {
        #     "spo": [
        #         {
        #             "subject": "甲状腺结节",
        #             "predicate": "常见症状",
        #             "object": "甲状腺结节"
        #         },
        #         {
        #             "subject": "甲状腺结节",
        #             "predicate": "适用药品",
        #             "object": "放射性碘治疗,复方碘口服液(Lugol液),抗甲状腺药物,硫脲类化合物,丙基硫氧嘧啶(PTU),甲基硫氧嘧啶(MTU),咪唑类的甲硫咪唑和卡比马唑"
        #         }
        #     ]
        # }
        if "spo" not in re_obj.keys():
            raise ValueError("SPO format error.")
        re = re_obj.get("spo", [])
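The `{"spo": [...]}` payload parsed above has a simple shape that is easy to sanity-check in isolation. A self-contained sketch of parsing it into (subject, predicate, object) triples, for illustration only and not the SPOPrompt implementation:

```python
# Illustrative sketch only: parse the {"spo": [...]} response shape used above.
import json
from typing import List, Tuple

response = '{"spo": [{"subject": "甲状腺结节", "predicate": "常见症状", "object": "甲状腺结节"}]}'

re_obj = json.loads(response)
if "spo" not in re_obj:
    raise ValueError("SPO format error.")

triples: List[Tuple[str, str, str]] = [
    (spo["subject"], spo["predicate"], spo["object"]) for spo in re_obj["spo"]
]
print(triples)  # -> [('甲状腺结节', '常见症状', '甲状腺结节')]
```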
@@ -17,6 +17,7 @@ class _BuiltInOnlineExtractor(ExtractOp):
        super().__init__(params)
        self.model = self.load_model()
        self.prompt_ops = self.load_operator()
        self.max_retry_times = int(self.params.get("max_retry_times", "3"))

    def load_model(self):
        model_config = json.loads(self.params["model_config"])

@@ -27,63 +28,38 @@ class _BuiltInOnlineExtractor(ExtractOp):
        prompt_config = json.loads(self.params["prompt_config"])
        prompt_ops = []
        for op_config in prompt_config:
            # Create the module spec and module object.
            spec = importlib.util.spec_from_file_location(op_config["modulePath"], op_config["filePath"])
            module = importlib.util.module_from_spec(spec)

            # Load the module.
            spec.loader.exec_module(module)

            op_clazz = getattr(module, op_config["className"])
            op_obj = op_clazz(**op_config["params"])
            params = op_config.get("params", {})
            op_obj = op_clazz(**params)
            prompt_ops.append(op_obj)

        return prompt_ops

    def invoke(self, record: Dict[str, str]) -> List[SPGRecord]:

        # Run multi-layer extraction on a single input record.
        # Each layer runs the op.build_prompt() -> model.predict() -> op.parse_response() flow,
        # and each layer may produce multiple results, which the next layer then processes one by one.
        collector = []
        input_params = [record]
        # Loop over all prompt ops; their number determines how many extraction layers run per record.
        for op in self.prompt_ops:
            next_params = []
            # input_params may contain multiple items; each of them is extracted.
            for input_param in input_params:
                # Build the full query.
                query = op.build_prompt(input_param)
                # Model prediction: generate the model output.
                # response = self.model.remote_inference(query)
                response = "test"
                # response = '{"spo": [{"subject": "甲状腺结节", "predicate": "常见症状", "object": "头疼"}]}'
                # Post-process the model output; it may be split into multiple items (List[dict[str, str]]).
                if hasattr(op, "parse_response"):
                    collector.extend(op.parse_response(response))
                if hasattr(op, "build_variables"):
                    next_params.extend(op.build_variables(input_param, response))

                retry_times = 0
                while retry_times < self.max_retry_times:
                    try:
                        query = op.build_prompt(input_param)
                        # response = self.model.remote_inference(query)
                        response = "test"
                        if hasattr(op, "parse_response"):
                            collector.extend(op.parse_response(response))
                        if hasattr(op, "build_variables"):
                            next_params.extend(op.build_variables(input_param, response))
                        break
                    except Exception as e:
                        retry_times += 1
                        if retry_times >= self.max_retry_times:
                            raise e
            input_params = next_params
        print(collector)

        return collector


if __name__ == '__main__':
    config = {
        "invoker_type": "OpenAI",
        "openai_api_key": "EMPTY",
        "openai_api_base": "http://localhost:38000/v1",
        "openai_model_name": "vicuna-7b-v1.5",
        "openai_max_tokens": 1000
    }
    model = LLMInvoker.from_config(config)
    query = """
已知SPO关系包括:[录音室专辑(录音室专辑)-发行年份-文本]。从下列句子中提取定义的这些关系。最终抽取结果以json格式输出。
input:《范特西》是周杰伦的第二张音乐专辑,由周杰伦担任制作人,于2001年9月14日发行,共收录《爱在西元前》《威廉古堡》《双截棍》等10首歌曲 [1]。
输出格式为:{"spo":[{"subject":,"predicate":,"object":},]}
"output":
"""

    response = model.remote_inference(query)
    print(response)
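The `invoke` loop above implements a fan-out pipeline: every prompt op consumes each intermediate result of the previous op, with a bounded retry around each model call. A self-contained sketch of that control flow, using toy ops instead of the real model and operators:

```python
# Illustrative sketch only: multi-layer extraction with bounded retries per call.
from typing import Callable, Dict, List

MAX_RETRY_TIMES = 3


def run_pipeline(record: Dict[str, str], ops: List[Callable[[Dict[str, str]], List[Dict[str, str]]]]) -> List[Dict[str, str]]:
    collector: List[Dict[str, str]] = []
    inputs = [record]
    for op in ops:                      # each op is one extraction layer
        next_inputs: List[Dict[str, str]] = []
        for item in inputs:             # every intermediate result goes through the next layer
            for attempt in range(MAX_RETRY_TIMES):
                try:
                    outputs = op(item)
                    collector.extend(outputs)
                    next_inputs.extend(outputs)
                    break               # success: stop retrying
                except Exception:
                    if attempt == MAX_RETRY_TIMES - 1:
                        raise           # retries exhausted: surface the error
        inputs = next_inputs
    return collector


def ner(item: Dict[str, str]) -> List[Dict[str, str]]:
    return [{"entity": w} for w in item["input"].split()]


def rel(item: Dict[str, str]) -> List[Dict[str, str]]:
    return [{"relation": f"{item['entity']}->isA"}]


print(run_pipeline({"input": "tax revenue"}, [ner, rel]))
```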
@@ -52,7 +52,7 @@ class LinkOp(BaseOp, ABC):

    bind_to: SPGTypeName

    _bind_schemas: Dict[SPGTypeName, str] = {}
    bind_schemas: Dict[SPGTypeName, str] = {}

    def __init__(self, params: Dict[str, str] = None):
        super().__init__(params)

@@ -81,7 +81,7 @@ class FuseOp(BaseOp, ABC):

    bind_to: SPGTypeName

    _bind_schemas: Dict[SPGTypeName, str] = {}
    bind_schemas: Dict[SPGTypeName, str] = {}

    def __init__(self, params: Dict[str, str] = None):
        super().__init__(params)

@@ -95,7 +95,7 @@ class FuseOp(BaseOp, ABC):
    def _pre_process(*inputs):
        return [
            SPGRecord.from_dict(input) for input in inputs[0]
        ]
        ],

    @staticmethod
    def _post_process(output) -> Dict[str, Any]:

@@ -134,7 +134,7 @@ class PredictOp(BaseOp, ABC):

    bind_to: Tuple[SPGTypeName, PropertyName, SPGTypeName]

    _bind_schemas: Dict[Tuple[SPGTypeName, PropertyName], str] = {}
    bind_schemas: Dict[Tuple[SPGTypeName, PropertyName], str] = {}

    def invoke(self, subject_record: SPGRecord) -> List[SPGRecord]:
        raise NotImplementedError(

@@ -145,7 +145,7 @@ class PredictOp(BaseOp, ABC):
    def _pre_process(*inputs):
        return [
            SPGRecord.from_dict(input) for input in inputs[0]
        ]
        ],

    @staticmethod
    def _post_process(output) -> Dict[str, Any]:
@@ -9,10 +9,7 @@ elasticsearch==8.10.0
six==1.16.0
click==8.1.7
dateutils==0.6.12

numpy==1.24.4
scipy==1.10.1
scikit-learn==1.3.1
pemja==0.4.0
certifi==2023.11.17
urllib3==2.1.0
python-dateutil==2.8.2