This commit is contained in:
Qu 2023-12-22 19:49:39 +08:00
parent e0009c1ecb
commit 4e11000e2f
28 changed files with 450 additions and 234 deletions

View File

@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from knext.chain.builder_chain import BuilderChain
__all__ = [
"BuilderChain",
]

View File

@ -9,15 +9,18 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License # Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. # or implied.
from knext.operator.op import BaseOp, LinkOp, ExtractOp, FuseOp, PromptOp
from knext.operator.op import LinkOp, ExtractOp, FuseOp, PromptOp, PredictOp
from knext.operator.spg_record import SPGRecord
from knext.operator.builtin.auto_prompt import SPOPrompt from knext.operator.builtin.auto_prompt import SPOPrompt
__all__ = [ __all__ = [
"BaseOp",
"ExtractOp", "ExtractOp",
"LinkOp", "LinkOp",
"FuseOp", "FuseOp",
"PromptOp", "PromptOp",
"PredictOp",
"SPOPrompt", "SPOPrompt",
"SPGRecord",
] ]

View File

@ -36,7 +36,7 @@ class BuilderClient(Client):
def submit(self, job_name: str): def submit(self, job_name: str):
"""Submit an asynchronous builder job to the server by name.""" """Submit an asynchronous builder job to the server by name."""
job = BuilderJob.by_name(job_name)() job = BuilderJob.by_name(job_name)()
builder_chain = job.build() builder_chain = BuilderChain.from_chain(job.build())
dag_config = builder_chain.to_rest() dag_config = builder_chain.to_rest()
params = { params = {

View File

@ -70,3 +70,7 @@ class BuilderJob:
return subclass return subclass
else: else:
raise ValueError(f"{name} is not a registered name for {cls.__name__}. ") raise ValueError(f"{name} is not a registered name for {cls.__name__}. ")
@classmethod
def has_registered(cls):
    """Return the class-level `_has_registered` flag as stored on `cls`."""
    registered = cls._has_registered
    return registered

View File

@ -1,24 +1,35 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from abc import ABC from abc import ABC
from typing import Type from typing import Type, List
from knext import rest from knext import rest
class RESTable(ABC): class RESTable(ABC):
@property @property
def upstream_types(self) -> Type["RESTable"]: def upstream_types(self) -> List[Type["RESTable"]]:
raise NotImplementedError("To be implemented in subclass") return []
@property @property
def downstream_types(self) -> Type["RESTable"]: def downstream_types(self) -> List[Type["RESTable"]]:
raise NotImplementedError("To be implemented in subclass") return []
def to_rest(self) -> rest.Node: def to_rest(self) -> rest.Node:
raise NotImplementedError("To be implemented in subclass") raise NotImplementedError(f"`to_rest` is not currently supported for {self.__class__.__name__}.")
@classmethod @classmethod
def from_rest(cls, node: rest.Node): def from_rest(cls, node: rest.Node):
raise NotImplementedError("To be implemented in subclass") raise NotImplementedError(f"`from_rest` is not currently supported for {cls.__name__}.")
def submit(self): def submit(self):
raise NotImplementedError("To be implemented in subclass") raise NotImplementedError(f"`submit` is not currently supported for {self.__class__.__name__}.")

View File

@ -1,4 +1,16 @@
from typing import TypeVar, Sequence, Generic, Type # -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import TypeVar, Sequence, Type
from pydantic import BaseConfig, BaseModel from pydantic import BaseConfig, BaseModel
@ -21,7 +33,7 @@ class Runnable(BaseModel):
return return
def invoke(self, input: Input) -> Sequence[Output]: def invoke(self, input: Input) -> Sequence[Output]:
raise NotImplementedError("To be implemented in subclass") raise NotImplementedError(f"`invoke` is not currently supported for {self.__class__.__name__}.")
def __rshift__(self, other: Other): def __rshift__(self, other: Other):
raise NotImplementedError("To be implemented in subclass") raise NotImplementedError("To be implemented in subclass")

View File

@ -9,7 +9,7 @@ from knext import rest
from knext.operator.op import PromptOp, ExtractOp from knext.operator.op import PromptOp, ExtractOp
# try: # try:
from nn4k.invoker.base import LLMInvoker, NNInvoker # noqa: F403 from nn4k.invoker.base import NNInvoker # noqa: F403
# except ImportError: # except ImportError:
# pass # pass
@ -54,7 +54,7 @@ class LLMBasedExtractor(SPGExtractor):
def to_rest(self): def to_rest(self):
"""Transforms `LLMBasedExtractor` to REST model `ExtractNodeConfig`.""" """Transforms `LLMBasedExtractor` to REST model `ExtractNodeConfig`."""
params = {} params = dict()
params["model_config"] = json.dumps(self.llm._nn_config) params["model_config"] = json.dumps(self.llm._nn_config)
api_client = OperatorClient()._rest_client.api_client api_client = OperatorClient()._rest_client.api_client
params["prompt_config"] = json.dumps([api_client.sanitize_for_serialization(op.to_rest()) for op in self.prompt_ops], ensure_ascii=False) params["prompt_config"] = json.dumps([api_client.sanitize_for_serialization(op.to_rest()) for op in self.prompt_ops], ensure_ascii=False)
@ -117,7 +117,7 @@ class UserDefinedExtractor(SPGExtractor):
"""Transforms `UserDefinedExtractor` to REST model `ExtractNodeConfig`.""" """Transforms `UserDefinedExtractor` to REST model `ExtractNodeConfig`."""
operator_config = self.extract_op.to_rest() operator_config = self.extract_op.to_rest()
config = rest.UserDefinedExtractNodeConfig( config = rest.UserDefinedExtractNodeConfig(
output_fields=self.output_fields, operator_config=operator_config operator_config=operator_config
) )
return rest.Node(**super().to_dict(), node_config=config) return rest.Node(**super().to_dict(), node_config=config)

View File

@ -21,46 +21,44 @@ from knext.operator.op import LinkOp, FuseOp, PredictOp
from knext.operator.spg_record import SPGRecord from knext.operator.spg_record import SPGRecord
class MappingTypeEnum(str, Enum): class LinkingStrategyEnum(str, Enum):
SPGType = "SPG_TYPE"
Relation = "RELATION"
class LinkStrategyEnum(str, Enum):
IDEquals = "ID_EQUALS" IDEquals = "ID_EQUALS"
SPG_TYPE_BASE_FIELDS = ["id"] class FusingStrategyEnum(str, Enum):
pass
RELATION_BASE_FIELDS = ["src_id", "dst_id"]
class PredictingStrategyEnum(str, Enum):
pass
class SPGTypeMapping(Mapping): class SPGTypeMapping(Mapping):
"""A Process Component that mapping data to entity/event/concept type. """A Process Component that mapping data to entity/event/concept/standard type.
Args: Args:
spg_type_name: The SPG type name import from SPGTypeHelper. spg_type_name: The SPG type name of subject import from SPGTypeHelper.
Examples: Examples:
mapping = SPGTypeMapping( mapping = SPGTypeMapping(
spg_type_name=DEFAULT.App spg_type_name=DEFAULT.App
).add_field("id", DEFAULT.App.id) \ ).add_mapping_field("id", DEFAULT.App.id) \
.add_field("id", DEFAULT.App.name) \ .add_mapping_field("name", DEFAULT.App.name) \
.add_field("riskMark", DEFAULT.App.riskMark) \ .add_mapping_field("riskMark", DEFAULT.App.riskMark) \
.add_field("useCert", DEFAULT.App.useCert) .add_predicting_field(DEFAULT.App.useCert)
""" """
"""The SPG type name of subject import from SPGTypeHelper."""
spg_type_name: Union[str, SPGTypeHelper] spg_type_name: Union[str, SPGTypeHelper]
mapping: Dict[str, str] = dict() mapping: Dict[str, str] = dict()
filters: List[Tuple[str, str]] = list() filters: List[Tuple[str, str]] = list()
subject_fuse_strategy: Optional[FuseOp] = None subject_fusing_strategy: Optional[Union[FusingStrategyEnum, FuseOp]] = None
object_link_strategies: Dict[str, Union[LinkStrategyEnum, LinkOp]] = dict() object_linking_strategies: Dict[str, Union[LinkingStrategyEnum, LinkOp]] = dict()
predicate_predict_strategies: Dict[str, PredictOp] = dict() predicate_predicting_strategies: Dict[str, Union[PredictingStrategyEnum, PredictOp]] = dict()
@property @property
def input_types(self) -> Input: def input_types(self) -> Input:
@ -78,32 +76,34 @@ class SPGTypeMapping(Mapping):
def output_keys(self): def output_keys(self):
return self.output_fields return self.output_fields
def set_fuse_strategy(self, fuse_strategy: FuseOp): def set_fusing_strategy(self, fusing_strategy: FuseOp):
self.subject_fuse_strategy = fuse_strategy """"""
self.subject_fusing_strategy = fusing_strategy
return self return self
def add_mapping_field(
    self,
    source_field: str,
    target_field: Union[str, PropertyHelper],
    linking_strategy: Optional[Union[LinkingStrategyEnum, LinkOp]] = None,
):
    """Adds a field mapping from source data to property of spg_type.

    :param source_field: The source field to be mapped.
    :param target_field: The target field (SPG property name) to map the source field to.
    :param linking_strategy: The linking strategy recorded for `target_field`, either a
        `LinkingStrategyEnum` member or a `LinkOp` instance. Defaults to None, in which
        case None is stored for the field.
    :return: self
    """
    # Both dicts are keyed by the target property, so re-adding a field overwrites it.
    self.mapping[target_field] = source_field
    self.object_linking_strategies[target_field] = linking_strategy
    return self
def add_predicting_field( def add_predicting_field(
self, self,
field: Union[str, PropertyHelper], field: Union[str, PropertyHelper],
predict_strategy: PredictOp = None, predicting_strategy: PredictOp = None,
): ):
self.predicate_predict_strategies[field] = predict_strategy self.predicate_predicting_strategies[field] = predicting_strategy
return self return self
def add_filter(self, column_name: str, column_value: str): def add_filter(self, column_name: str, column_value: str):
@ -121,6 +121,9 @@ class SPGTypeMapping(Mapping):
""" """
Transforms `SPGTypeMapping` to REST model `SpgTypeMappingNodeConfig`. Transforms `SPGTypeMapping` to REST model `SpgTypeMappingNodeConfig`.
""" """
from knext.client.schema import SchemaClient
client = SchemaClient()
spg_type = client.query_spg_type(self.spg_type_name)
mapping_filters = [ mapping_filters = [
rest.MappingFilter(column_name=name, column_value=value) rest.MappingFilter(column_name=name, column_value=value)
@ -128,17 +131,25 @@ class SPGTypeMapping(Mapping):
] ]
mapping_configs = [] mapping_configs = []
for tgt_name, src_name in self.mapping.items(): for tgt_name, src_name in self.mapping.items():
link_strategy = self.object_link_strategies.get(tgt_name, None) linking_strategy = self.object_linking_strategies.get(tgt_name, None)
if isinstance(link_strategy, LinkOp): if isinstance(linking_strategy, LinkOp):
strategy_config = rest.OperatorLinkingConfig( strategy_config = rest.OperatorLinkingConfig(
operator_config=link_strategy.to_rest() operator_config=linking_strategy.to_rest()
) )
elif link_strategy == LinkStrategyEnum.IDEquals: elif linking_strategy == LinkingStrategyEnum.IDEquals:
strategy_config = rest.IdEqualsLinkingConfig() strategy_config = rest.IdEqualsLinkingConfig()
elif not link_strategy: elif not linking_strategy:
strategy_config = None object_type_name = spg_type.properties[tgt_name].object_type_name
if object_type_name in LinkOp.bind_schemas:
op_name = LinkOp.bind_schemas[object_type_name]
op = LinkOp.by_name(op_name)()
strategy_config = rest.OperatorLinkingConfig(
operator_config=op.to_rest()
)
else:
strategy_config = None
else: else:
raise ValueError(f"Invalid link_strategy [{link_strategy}].") raise ValueError(f"Invalid linking_strategy [{linking_strategy}].")
mapping_configs.append( mapping_configs.append(
rest.MappingConfig( rest.MappingConfig(
source=src_name, source=src_name,
@ -148,37 +159,42 @@ class SPGTypeMapping(Mapping):
) )
predicting_configs = [] predicting_configs = []
for predict_strategy in self.predicate_predict_strategies: for predicate_name, predicting_strategy in self.predicate_predicting_strategies.items():
if isinstance(predict_strategy, PredictOp): if isinstance(predicting_strategy, PredictOp):
strategy_config = rest.OperatorPredictingConfig( strategy_config = rest.OperatorPredictingConfig(
operator_config=predict_strategy.to_rest() operator_config=predicting_strategy.to_rest()
) )
elif not predict_strategy: elif not predicting_strategy:
# if self.spg_type_name in PredictOp._bind_schemas: if (self.spg_type_name, predicate_name) in PredictOp.bind_schemas:
# op_name = PredictOp._bind_schemas[self.spg_type_name] op_name = PredictOp.bind_schemas[(self.spg_type_name, predicate_name)]
# op = PredictOp.by_name(op_name)() op = PredictOp.by_name(op_name)()
# strategy_config = op.to_rest() strategy_config = rest.OperatorPredictingConfig(
# else: operator_config=op.to_rest()
strategy_config = None )
else:
strategy_config = None
else: else:
raise ValueError(f"Invalid predict_strategy [{predict_strategy}].") raise ValueError(f"Invalid predicting_strategy [{predicting_strategy}].")
predicting_configs.append( if strategy_config:
strategy_config predicting_configs.append(
) strategy_config
)
if isinstance(self.subject_fuse_strategy, FuseOp): if isinstance(self.subject_fusing_strategy, FuseOp):
fusing_config = rest.OperatorFusingConfig( fusing_config = rest.OperatorFusingConfig(
operator_config=self.fuse_strategy.to_rest() operator_config=self.fusing_strategy.to_rest()
) )
elif not self.subject_fuse_strategy: elif not self.subject_fusing_strategy:
if self.spg_type_name in FuseOp._bind_schemas: if self.spg_type_name in FuseOp.bind_schemas:
op_name = FuseOp._bind_schemas[self.spg_type_name] op_name = FuseOp.bind_schemas[self.spg_type_name]
op = FuseOp.by_name(op_name)() op = FuseOp.by_name(op_name)()
fusing_config = op.to_rest() fusing_config = rest.OperatorFusingConfig(
operator_config=op.to_rest()
)
else: else:
fusing_config = None fusing_config = None
else: else:
raise ValueError(f"Invalid fuse_strategy [{self.subject_fuse_strategy}].") raise ValueError(f"Invalid fusing_strategy [{self.subject_fusing_strategy}].")
config = rest.SpgTypeMappingNodeConfig( config = rest.SpgTypeMappingNodeConfig(
spg_type=self.spg_type_name, spg_type=self.spg_type_name,
@ -190,11 +206,11 @@ class SPGTypeMapping(Mapping):
return rest.Node(**super().to_dict(), node_config=config) return rest.Node(**super().to_dict(), node_config=config)
def invoke(self, input: Input) -> Sequence[Output]: def invoke(self, input: Input) -> Sequence[Output]:
pass raise NotImplementedError(f"`invoke` method is not currently supported for {self.__class__.__name__}.")
@classmethod
def from_rest(cls, node: rest.Node):
    """Build the component from a REST node; not supported for this class.

    :param node: The REST node that would be converted.
    :raises NotImplementedError: Always.
    """
    # The message must name `from_rest`; it previously named `invoke`, which
    # pointed readers of the traceback at the wrong method.
    raise NotImplementedError(f"`from_rest` method is not currently supported for {cls.__name__}.")
def submit(self): def submit(self):
pass pass
@ -208,15 +224,16 @@ class RelationMapping(Mapping):
predicate_name: The predicate name. predicate_name: The predicate name.
object_name: The object name import from SPGTypeHelper. object_name: The object name import from SPGTypeHelper.
Examples: Examples:
mapping = RelationMappingComponent( mapping = RelationMapping(
subject_name=DEFAULT.App, subject_name=DEFAULT.App,
predicate_name=DEFAULT.App.useCert, predicate_name=DEFAULT.App.useCert,
object_name=DEFAULT.Cert, object_name=DEFAULT.Cert,
).add_field("src_id", "srcId") \ ).add_mapping_field("src_id", "srcId") \
.add_field("dst_id", "dstId") .add_mapping_field("dst_id", "dstId")
""" """
"""The SPG type names of (subject, predicate, object) triplet imported from SPGTypeHelper and PropertyHelper."""
subject_name: Union[str, SPGTypeHelper] subject_name: Union[str, SPGTypeHelper]
predicate_name: Union[str, PropertyHelper] predicate_name: Union[str, PropertyHelper]
object_name: Union[str, SPGTypeHelper] object_name: Union[str, SPGTypeHelper]
@ -277,19 +294,33 @@ class RelationMapping(Mapping):
class SubGraphMapping(Mapping): class SubGraphMapping(Mapping):
"""A Process Component that mapping data to relation type.
Args:
spg_type_name: The SPG type name import from SPGTypeHelper.
Examples:
mapping = SubGraphMapping(
spg_type_name=DEFAULT.App,
).add_mapping_field("id", DEFAULT.App.id) \
.add_mapping_field("name", DEFAULT.App.name) \
.add_mapping_field("useCert", DEFAULT.App.useCert)
.add_predicting_field(
"""
""""""
spg_type_name: Union[str, SPGTypeHelper] spg_type_name: Union[str, SPGTypeHelper]
mapping: Dict[str, str] = dict() mapping: Dict[str, str] = dict()
filters: List[Tuple[str, str]] = list() filters: List[Tuple[str, str]] = list()
subject_fuse_strategy: Optional[FuseOp] = None subject_fusing_strategy: Optional[FuseOp] = None
predicate_predict_strategies: Dict[str, PredictOp] = dict() predicate_predicting_strategies: Dict[str, PredictOp] = dict()
object_fuse_strategies: Dict[str, FuseOp] = dict() object_fuse_strategies: Dict[str, FuseOp] = dict()
@property @property
def input_types(self) -> Input: def input_types(self) -> Input:
return Union[Dict[str, str], SPGRecord] return Union[Dict[str, str], SPGRecord]
@ -306,15 +337,15 @@ class SubGraphMapping(Mapping):
def output_keys(self): def output_keys(self):
return self.output_fields return self.output_fields
def set_fuse_strategy(self, fuse_strategy: FuseOp): def set_fusing_strategy(self, fusing_strategy: FuseOp):
self.subject_fuse_strategy = fuse_strategy self.subject_fusing_strategy = fusing_strategy
return self return self
def add_mapping_field( def add_mapping_field(
self, self,
source_field: str, source_field: str,
target_field: Union[str, PropertyHelper], target_field: Union[str, PropertyHelper],
fuse_strategy: FuseOp = None, fusing_strategy: Union[FusingStrategyEnum, FuseOp] = None,
): ):
"""Adds a field mapping from source data to property of spg_type. """Adds a field mapping from source data to property of spg_type.
@ -323,15 +354,15 @@ class SubGraphMapping(Mapping):
:return: self :return: self
""" """
self.mapping[target_field] = source_field self.mapping[target_field] = source_field
self.object_fuse_strategies[target_field] = fuse_strategy self.object_fuse_strategies[target_field] = fusing_strategy
return self return self
def add_predicting_field( def add_predicting_field(
self, self,
target_field: Union[str, PropertyHelper], target_field: Union[str, PropertyHelper],
predict_strategy: PredictOp = None, predicting_strategy: PredictOp = None,
): ):
self.predict_strategies[target_field] = predict_strategy self.predicate_predicting_strategies[target_field] = predicting_strategy
return self return self
def add_filter(self, column_name: str, column_value: str): def add_filter(self, column_name: str, column_value: str):
@ -349,6 +380,9 @@ class SubGraphMapping(Mapping):
""" """
Transforms `SubGraphMapping` to REST model `SpgTypeMappingNodeConfig`. Transforms `SubGraphMapping` to REST model `SpgTypeMappingNodeConfig`.
""" """
from knext.client.schema import SchemaClient
client = SchemaClient()
spg_type = client.query_spg_type(self.spg_type_name)
mapping_filters = [ mapping_filters = [
rest.MappingFilter(column_name=name, column_value=value) rest.MappingFilter(column_name=name, column_value=value)
@ -356,16 +390,23 @@ class SubGraphMapping(Mapping):
] ]
mapping_configs = [] mapping_configs = []
for tgt_name, src_name in self.mapping.items(): for tgt_name, src_name in self.mapping.items():
fuse_strategy = self.object_fuse_strategies.get(tgt_name, None) fusing_strategy = self.object_fuse_strategies.get(tgt_name, None)
if isinstance(fuse_strategy, FuseOp): if isinstance(fusing_strategy, FuseOp):
strategy_config = rest.OperatorFusingConfig( strategy_config = rest.OperatorFusingConfig(
operator_config=fuse_strategy.to_rest() operator_config=fusing_strategy.to_rest()
)
elif not self.subject_fuse_strategy:
strategy_config = rest.NewInstanceFusingConfig(
) )
elif not self.subject_fusing_strategy:
object_type_name = spg_type.properties[tgt_name].object_type_name
if object_type_name in FuseOp.bind_schemas:
op_name = FuseOp.bind_schemas[object_type_name]
op = FuseOp.by_name(op_name)()
strategy_config = rest.OperatorFusingConfig(
operator_config=op.to_rest()
)
else:
strategy_config = rest.NewInstanceFusingConfig()
else: else:
raise ValueError(f"Invalid fuse_strategy [{fuse_strategy}].") raise ValueError(f"Invalid fusing_strategy [{fusing_strategy}].")
mapping_configs.append( mapping_configs.append(
rest.MappingConfig( rest.MappingConfig(
source=src_name, source=src_name,
@ -375,28 +416,42 @@ class SubGraphMapping(Mapping):
) )
predicting_configs = [] predicting_configs = []
for predict_strategy in self.predicate_predict_strategies: for predicate_name, predicting_strategy in self.predicate_predicting_strategies.items():
if isinstance(predict_strategy, PredictOp): if isinstance(predicting_strategy, PredictOp):
strategy_config = rest.OperatorPredictingConfig( strategy_config = rest.OperatorPredictingConfig(
operator_config=predict_strategy.to_rest() operator_config=predicting_strategy.to_rest()
) )
elif not predict_strategy: elif not predicting_strategy:
strategy_config = None if (self.spg_type_name, predicate_name) in PredictOp.bind_schemas:
op_name = PredictOp.bind_schemas[(self.spg_type_name, predicate_name)]
op = PredictOp.by_name(op_name)()
strategy_config = rest.OperatorPredictingConfig(
operator_config=op.to_rest()
)
else:
strategy_config = None
else: else:
raise ValueError(f"Invalid predict_strategy [{predict_strategy}].") raise ValueError(f"Invalid predicting_strategy [{predicting_strategy}].")
predicting_configs.append( if strategy_config:
strategy_config predicting_configs.append(
) strategy_config
if isinstance(self.subject_fuse_strategy, FuseOp):
fusing_config = rest.OperatorFusingConfig(
operator_config=self.fuse_strategy.to_rest()
)
elif not self.subject_fuse_strategy:
fusing_config = rest.NewInstanceFusingConfig(
) )
if isinstance(self.subject_fusing_strategy, FuseOp):
fusing_config = rest.OperatorFusingConfig(
operator_config=self.fusing_strategy.to_rest()
)
elif not self.subject_fusing_strategy:
if self.spg_type_name in FuseOp.bind_schemas:
op_name = FuseOp.bind_schemas[self.spg_type_name]
op = FuseOp.by_name(op_name)()
fusing_config = rest.OperatorFusingConfig(
operator_config=op.to_rest()
)
else:
fusing_config = rest.NewInstanceFusingConfig()
else: else:
raise ValueError(f"Invalid fuse_strategy [{self.subject_fuse_strategy}].") raise ValueError(f"Invalid fusing_strategy [{self.subject_fusing_strategy}].")
config = rest.SubGraphMappingNodeConfig( config = rest.SubGraphMappingNodeConfig(
spg_type=self.spg_type_name, spg_type=self.spg_type_name,

View File

@ -0,0 +1,20 @@
# This config file is auto generated by:
# knext project create --name 全风 --namespace Financial --desc 全风财政指标抽取
# Do not edit the config file manually.
[local]
project_name = 全风
description = 全风财政指标抽取
namespace = Financial
project_id = 4
project_dir = financial
schema_dir = schema
schema_file = financial.schema
builder_dir = builder
builder_operator_dir = builder/operator
builder_record_dir = builder/error_record
builder_job_dir = builder/job
builder_model_dir = builder/model
reasoner_dir = reasoner
reasoner_result_dir = reasoner/result

View File

@ -0,0 +1,20 @@
```bash
knext project create --name 全风 --namespace Financial --desc 全风财政指标抽取
```
```bash
knext schema commit
```
```bash
knext operator publish DemoExtractOp
```
```bash
knext builder submit Demo
```
```bash
knext reasoner query --file ./reasoner/demo.dsl
```

View File

@ -0,0 +1,2 @@
input
济南市财政收入质量及自给能力均较好但土地出让收入大幅下降致综合财力明显下滑。济南市财政收入质量及自给能力均较好但土地出让收入大幅下降致综合财力明显下滑。2022年济南市一般公共预算收入1,000.21亿元,扣除留 抵退税因素后同比增长8%规模在山东省下辖地市中排名第2位其中税收收入690.31亿元税收占比69.02%;一般公共 预算支出1,260.23亿元财政自给率79.37%。政府性基金收入547.29亿元同比大幅下降48.38%,主要系土地出让收入 同比由966.74亿元降至453.74亿元转移性收入285.78亿元上年同期为233.11亿元综合财力约1,833.28亿元(上年 同期为2,301.02亿元)。
1 input
2 济南市财政收入质量及自给能力均较好,但土地出让收入大幅下降致综合财力明显下滑。济南市财政收入质量及自给能力均较好,但土地出让收入大幅下降致综合财力明显下滑。2022年济南市一般公共预算收入1,000.21亿元,扣除留 抵退税因素后同比增长8%,规模在山东省下辖地市中排名第2位;其中税收收入690.31亿元,税收占比69.02%;一般公共 预算支出1,260.23亿元,财政自给率79.37%。政府性基金收入547.29亿元,同比大幅下降48.38%,主要系土地出让收入 同比由966.74亿元降至453.74亿元;转移性收入285.78亿元(上年同期为233.11亿元);综合财力约1,833.28亿元(上年 同期为2,301.02亿元)。

View File

@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
from knext.client.model.builder_job import BuilderJob
from knext.api.component import CSVReader, SPGTypeMapping, KGWriter
from knext.component.builder import LLMBasedExtractor, SubGraphMapping
from nn4k.invoker import LLMInvoker
try:
    # The helper module may be absent (it appears to be generated from the
    # schema), so only a failed import is tolerated here. Any other error
    # raised while importing it is allowed to surface.
    from schema.financial_schema_helper import Financial
except ImportError:
    pass
class StateAndIndicator(BuilderJob):
    """Builder job that extracts states and indicators from a CSV document.

    The chain is: CSV reader -> LLM-based extractor (NER, REL and LOGIC
    prompt operators) -> sub-graph mappings for `Financial.State` and
    `Financial.Indicator` -> KG writer.
    """

    def build(self):
        """Assemble and return the builder chain for this job."""
        import os

        # Resolve resources relative to this module instead of a
        # developer-specific absolute path, so the example works from any
        # checkout location.
        # NOTE(review): assumes this module lives in `builder/job/`, with the
        # CSV under `builder/job/data/` and the model config under
        # `builder/model/`, mirroring the previous absolute paths — confirm.
        job_dir = os.path.dirname(os.path.abspath(__file__))
        document_path = os.path.join(job_dir, "data", "document.csv")
        model_config_path = os.path.join(
            os.path.dirname(job_dir), "model", "openai_infer.json"
        )

        source = CSVReader(
            local_path=document_path,
            columns=["input"],
            start_row=2,
        )

        # Kept as function-scope imports, as in the original.
        from knext.examples.financial.builder.operator.IndicatorNER import IndicatorNER
        from knext.examples.financial.builder.operator.IndicatorREL import IndicatorREL
        from knext.examples.financial.builder.operator.IndicatorLOGIC import IndicatorLOGIC

        extract = LLMBasedExtractor(
            llm=LLMInvoker.from_config(model_config_path),
            prompt_ops=[IndicatorNER(), IndicatorREL(), IndicatorLOGIC()],
        )

        state_mapping = (
            SubGraphMapping(spg_type_name="Financial.State")
            .add_mapping_field("id", "id")
            .add_mapping_field("name", "name")
            .add_mapping_field("causeOf", "causeOf")
            .add_predicting_field("derivedFrom")
        )

        indicator_mapping = (
            SubGraphMapping(spg_type_name="Financial.Indicator")
            .add_mapping_field("id", "id")
            .add_mapping_field("name", "name")
            # .add_predicting_field("isA")
        )

        sink = KGWriter()

        return source >> extract >> [state_mapping, indicator_mapping] >> sink

View File

@ -14,6 +14,8 @@ class IndicatorFuse(FuseOp):
self.search_client = SearchClient("Financial.Indicator") self.search_client = SearchClient("Financial.Indicator")
def invoke(self, subject_records: List[SPGRecord]) -> List[SPGRecord]: def invoke(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
print("##########IndicatorFuse###########")
print(subject_records)
fused_records = [] fused_records = []
for record in subject_records: for record in subject_records:
query = {"match": {"name": record.get_property("name", "")}} query = {"match": {"name": record.get_property("name", "")}}

View File

@ -43,10 +43,11 @@ ${rel}
""" """
response: "[{\"subject\": \"土地出让收入大幅下降\", \"predicate\": \"顺承\", \"object\": [\"综合财力明显下滑\"]}]" response: "[{\"subject\": \"土地出让收入大幅下降\", \"predicate\": \"顺承\", \"object\": [\"综合财力明显下滑\"]}]"
""" """
print("##########IndicatorLOGIC###########")
response = "[{\"subject\": \"土地出让收入大幅下降\", \"predicate\": \"顺承\", \"object\": [\"综合财力明显下滑\"]}]"
output_list = json.loads(response) output_list = json.loads(response)
logic_result = [] logic_result = []
# IF hasA
for output in output_list: for output in output_list:
properties = {} properties = {}
for k, v in output.items(): for k, v in output.items():
@ -55,6 +56,6 @@ ${rel}
properties["name"] = k properties["name"] = k
elif k == "object": elif k == "object":
properties["causeOf"] = ','.join(v) properties["causeOf"] = ','.join(v)
logic_result.append(SPGRecord("FEL.State", properties=properties)) logic_result.append(SPGRecord("Financial.State", properties=properties))
return logic_result return logic_result

View File

@ -22,29 +22,18 @@ class IndicatorNER(PromptOp):
def parse_response( def parse_response(
self, response: str self, response: str
) -> List[SPGRecord]: ) -> List[SPGRecord]:
# output_list = json.loads(response)
#
# ner_result = []
# # IF hasA
# for output in output_list:
# # {'财政': ['财政收入....}
# for k, v in output.items():
# # '财政', ['财政收入....]
# ner_result.append(SPGRecord("FEL.Indicator", properties={"id": k, "name": k, "hasA": ','.join(v)}))
#
# # ELSE isA
# # TODO 通过属性isA支持
# for output in output_list:
# # {'财政': ['财政收入....}
# for k, v in output.items():
# # '财政', ['财政收入....]
# for _v in v:
# # '财政收入....'
# ner_result.append(SPGRecord("FEL.Indicator", properties={"id": f'{k}-{_v}', "name": _v}))
print("##########IndicatorNER###########")
ner_result = [SPGRecord(spg_type_name="Financial.Indicator", properties={"id": "土地出让收入", "name": "土地出让收入"})]
print(ner_result)
print("##########IndicatorNER###########") print("##########IndicatorNER###########")
response = "[{'财政': ['财政收入质量', '财政自给能力', '土地出让收入', '一般公共预算收入', '留抵退税', '税收收入', '税收收入/一般公共预算收入', '一般公共预算支出', '财政自给率', '政府性基金收入', '转移性收入', '综合财力']}]"
output_list = json.loads(response.replace("'", "\""))
ner_result = []
# IF hasA
for output in output_list:
# {'财政': ['财政收入....}
for category, indicator_list in output.items():
# '财政', ['财政收入....]
for indicator in indicator_list:
ner_result.append(SPGRecord("Financial.Indicator", properties={"id": indicator, "name": indicator}))
return ner_result return ner_result
def build_next_variables( def build_next_variables(
@ -53,5 +42,6 @@ class IndicatorNER(PromptOp):
""" """
response: "[{'subject': '一般公共预算收入', 'predicate': '包含', 'object': ['税收收入']}, {'subject': '税收收入', 'predicate': '包含', 'object': ['留抵退税']}, {'subject': '政府性基金收入', 'predicate': '包含', 'object': ['土地出让收入', '转移性收入']}, {'subject': '综合财力', 'predicate': '包含', 'object': ['一般公共预算收入', '政府性基金收入']}]" response: "[{'subject': '一般公共预算收入', 'predicate': '包含', 'object': ['税收收入']}, {'subject': '税收收入', 'predicate': '包含', 'object': ['留抵退税']}, {'subject': '政府性基金收入', 'predicate': '包含', 'object': ['土地出让收入', '转移性收入']}, {'subject': '综合财力', 'predicate': '包含', 'object': ['一般公共预算收入', '政府性基金收入']}]"
""" """
response = "" response = "[{'财政': ['财政收入质量', '财政自给能力', '土地出让收入', '一般公共预算收入', '留抵退税', '税收收入', '税收收入/一般公共预算收入', '一般公共预算支出', '财政自给率', '政府性基金收入', '转移性收入', '综合财力']}]"
return [{"input": variables["input"], "ner": response}] return [{"input": variables["input"], "ner": response}]

View File

@ -11,15 +11,13 @@ class IndicatorPredict(PredictOp):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
# self.search_client = SearchClient("Financial.Indicator") self.search_client = SearchClient("Financial.Indicator")
def invoke(self, subject_record: SPGRecord) -> List[SPGRecord]: def invoke(self, subject_record: SPGRecord) -> List[SPGRecord]:
# query = {"match": {"name": subject_record.get_property("name", "")}}
# recall_records = self.search_client.search(query, start=0, size=10)
# if recall_records is not None and len(recall_records) > 0:
# rerank_record = SPGRecord("Financial.Indicator", {"id": recall_records[0].doc_id, "name": recall_records[0].properties.get("name", "")})
# return [rerank_record]
# return []
print("##########IndicatorPredict###########") print("##########IndicatorPredict###########")
query = {"match": {"name": subject_record.get_property("name", "")}}
return [subject_record] recall_records = self.search_client.search(query, start=0, size=10)
if recall_records is not None and len(recall_records) > 0:
rerank_record = SPGRecord("Financial.Indicator", {"id": recall_records[0].doc_id, "name": recall_records[0].properties.get("name", "")})
return [rerank_record]
return []

View File

@ -38,4 +38,5 @@ ${ner}
""" """
response: "[{'subject': '一般公共预算收入', 'predicate': '包含', 'object': ['税收收入']}, {'subject': '税收收入', 'predicate': '包含', 'object': ['留抵退税']}, {'subject': '政府性基金收入', 'predicate': '包含', 'object': ['土地出让收入', '转移性收入']}, {'subject': '综合财力', 'predicate': '包含', 'object': ['一般公共预算收入', '政府性基金收入']}]" response: "[{'subject': '一般公共预算收入', 'predicate': '包含', 'object': ['税收收入']}, {'subject': '税收收入', 'predicate': '包含', 'object': ['留抵退税']}, {'subject': '政府性基金收入', 'predicate': '包含', 'object': ['土地出让收入', '转移性收入']}, {'subject': '综合财力', 'predicate': '包含', 'object': ['一般公共预算收入', '政府性基金收入']}]"
""" """
response = "[{'subject': '一般公共预算收入', 'predicate': '包含', 'object': ['税收收入']}, {'subject': '税收收入', 'predicate': '包含', 'object': ['留抵退税']}, {'subject': '政府性基金收入', 'predicate': '包含', 'object': ['土地出让收入', '转移性收入']}, {'subject': '综合财力', 'predicate': '包含', 'object': ['一般公共预算收入', '政府性基金收入']}]"
return [{"input": variables["input"], "ner": variables["ner"], "rel": response}] return [{"input": variables["input"], "ner": variables["ner"], "rel": response}]

View File

@ -14,6 +14,7 @@ class StateFuse(FuseOp):
self.search_client = SearchClient("Financial.State") self.search_client = SearchClient("Financial.State")
def invoke(self, subject_records: List[SPGRecord]) -> List[SPGRecord]: def invoke(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
print("##########StateFuse###########")
fused_records = [] fused_records = []
for record in subject_records: for record in subject_records:
query = {"match": {"name": record.get_property("name", "")}} query = {"match": {"name": record.get_property("name", "")}}

View File

@ -0,0 +1,12 @@
namespace Financial
Indicator(指标概念): ConceptType
hypernymPredicate: isA
State(状态): ConceptType
desc: 指标状态
properties:
causeOf(导致): State
desc: 状态顺承关系
derivedFrom(指标): Indicator
desc: 状态的指标

View File

@ -0,0 +1,34 @@
# ATTENTION!
# This file is generated by Schema automatically, it will be refreshed after schema has been committed
# PLEASE DO NOT MODIFY THIS FILE!!!
#
class Financial:
    """Name constants for the ``Financial`` schema namespace.

    Each nested class describes one schema type: ``__typename__`` holds the
    qualified type name and every other attribute maps a property name to
    itself, so callers can reference properties without string literals.
    """

    class Indicator:
        """Property names of the ``Financial.Indicator`` type."""

        __typename__ = "Financial.Indicator"
        id = "id"
        name = "name"
        alias = "alias"
        stdId = "stdId"
        description = "description"

        def __init__(self):
            """No per-instance state; instances only expose the constants."""

    class State:
        """Property names of the ``Financial.State`` type."""

        __typename__ = "Financial.State"
        id = "id"
        name = "name"
        alias = "alias"
        stdId = "stdId"
        description = "description"
        causeOf = "causeOf"
        derivedFrom = "derivedFrom"

        def __init__(self):
            """No per-instance state; instances only expose the constants."""

    def __init__(self):
        # Instance attributes deliberately shadow the nested classes with
        # ready-made instances of them.
        self.Indicator = self.Indicator()
        self.State = self.State()

View File

@ -1,2 +1 @@
content
甲状腺结节是指在甲状腺内的肿块可随吞咽动作随甲状腺而上下移动是临床常见的病症可由多种病因引起。临床上有多种甲状腺疾病如甲状腺退行性变、炎症、自身免疫以及新生物等都可以表现为结节。甲状腺结节可以单发也可以多发多发结节比单发结节的发病率高但单发结节甲状腺癌的发生率较高。患者通常可以选择在普外科甲状腺外科内分泌科头颈外科挂号就诊。有些患者可以触摸到自己颈部前方的结节。在大多情况下甲状腺结节没有任何症状甲状腺功能也是正常的。甲状腺结节进展为其它甲状腺疾病的概率只有1%。有些人会感觉到颈部疼痛、咽喉部异物感,或者存在压迫感。当甲状腺结节发生囊内自发性出血时,疼痛感会更加强烈。治疗方面,一般情况下可以用放射性碘治疗,复方碘口服液(Lugol液)等,或者服用抗甲状腺药物来抑制甲状腺激素的分泌。目前常用的抗甲状腺药物是硫脲类化合物,包括硫氧嘧啶类的丙基硫氧嘧啶(PTU)和甲基硫氧嘧啶(MTU)及咪唑类的甲硫咪唑和卡比马唑。 甲状腺结节是指在甲状腺内的肿块可随吞咽动作随甲状腺而上下移动是临床常见的病症可由多种病因引起。临床上有多种甲状腺疾病如甲状腺退行性变、炎症、自身免疫以及新生物等都可以表现为结节。甲状腺结节可以单发也可以多发多发结节比单发结节的发病率高但单发结节甲状腺癌的发生率较高。患者通常可以选择在普外科甲状腺外科内分泌科头颈外科挂号就诊。有些患者可以触摸到自己颈部前方的结节。在大多情况下甲状腺结节没有任何症状甲状腺功能也是正常的。甲状腺结节进展为其它甲状腺疾病的概率只有1%。有些人会感觉到颈部疼痛、咽喉部异物感,或者存在压迫感。当甲状腺结节发生囊内自发性出血时,疼痛感会更加强烈。治疗方面,一般情况下可以用放射性碘治疗,复方碘口服液(Lugol液)等,或者服用抗甲状腺药物来抑制甲状腺激素的分泌。目前常用的抗甲状腺药物是硫脲类化合物,包括硫氧嘧啶类的丙基硫氧嘧啶(PTU)和甲基硫氧嘧啶(MTU)及咪唑类的甲硫咪唑和卡比马唑。
1 content 甲状腺结节是指在甲状腺内的肿块,可随吞咽动作随甲状腺而上下移动,是临床常见的病症,可由多种病因引起。临床上有多种甲状腺疾病,如甲状腺退行性变、炎症、自身免疫以及新生物等都可以表现为结节。甲状腺结节可以单发,也可以多发,多发结节比单发结节的发病率高,但单发结节甲状腺癌的发生率较高。患者通常可以选择在普外科,甲状腺外科,内分泌科,头颈外科挂号就诊。有些患者可以触摸到自己颈部前方的结节。在大多情况下,甲状腺结节没有任何症状,甲状腺功能也是正常的。甲状腺结节进展为其它甲状腺疾病的概率只有1%。有些人会感觉到颈部疼痛、咽喉部异物感,或者存在压迫感。当甲状腺结节发生囊内自发性出血时,疼痛感会更加强烈。治疗方面,一般情况下可以用放射性碘治疗,复方碘口服液(Lugol液)等,或者服用抗甲状腺药物来抑制甲状腺激素的分泌。目前常用的抗甲状腺药物是硫脲类化合物,包括硫氧嘧啶类的丙基硫氧嘧啶(PTU)和甲基硫氧嘧啶(MTU)及咪唑类的甲硫咪唑和卡比马唑。
content
1 甲状腺结节是指在甲状腺内的肿块,可随吞咽动作随甲状腺而上下移动,是临床常见的病症,可由多种病因引起。临床上有多种甲状腺疾病,如甲状腺退行性变、炎症、自身免疫以及新生物等都可以表现为结节。甲状腺结节可以单发,也可以多发,多发结节比单发结节的发病率高,但单发结节甲状腺癌的发生率较高。患者通常可以选择在普外科,甲状腺外科,内分泌科,头颈外科挂号就诊。有些患者可以触摸到自己颈部前方的结节。在大多情况下,甲状腺结节没有任何症状,甲状腺功能也是正常的。甲状腺结节进展为其它甲状腺疾病的概率只有1%。有些人会感觉到颈部疼痛、咽喉部异物感,或者存在压迫感。当甲状腺结节发生囊内自发性出血时,疼痛感会更加强烈。治疗方面,一般情况下可以用放射性碘治疗,复方碘口服液(Lugol液)等,或者服用抗甲状腺药物来抑制甲状腺激素的分泌。目前常用的抗甲状腺药物是硫脲类化合物,包括硫氧嘧啶类的丙基硫氧嘧啶(PTU)和甲基硫氧嘧啶(MTU)及咪唑类的甲硫咪唑和卡比马唑。 甲状腺结节是指在甲状腺内的肿块,可随吞咽动作随甲状腺而上下移动,是临床常见的病症,可由多种病因引起。临床上有多种甲状腺疾病,如甲状腺退行性变、炎症、自身免疫以及新生物等都可以表现为结节。甲状腺结节可以单发,也可以多发,多发结节比单发结节的发病率高,但单发结节甲状腺癌的发生率较高。患者通常可以选择在普外科,甲状腺外科,内分泌科,头颈外科挂号就诊。有些患者可以触摸到自己颈部前方的结节。在大多情况下,甲状腺结节没有任何症状,甲状腺功能也是正常的。甲状腺结节进展为其它甲状腺疾病的概率只有1%。有些人会感觉到颈部疼痛、咽喉部异物感,或者存在压迫感。当甲状腺结节发生囊内自发性出血时,疼痛感会更加强烈。治疗方面,一般情况下可以用放射性碘治疗,复方碘口服液(Lugol液)等,或者服用抗甲状腺药物来抑制甲状腺激素的分泌。目前常用的抗甲状腺药物是硫脲类化合物,包括硫氧嘧啶类的丙基硫氧嘧啶(PTU)和甲基硫氧嘧啶(MTU)及咪唑类的甲硫咪唑和卡比马唑。

View File

@ -13,10 +13,11 @@
from knext.client.model.builder_job import BuilderJob from knext.client.model.builder_job import BuilderJob
from knext.api.component import ( from knext.api.component import (
CSVReader, CSVReader,
LLMBasedExtractor,
SubGraphMapping,
KGWriter KGWriter
) )
from knext.component.builder import LLMBasedExtractor, SubGraphMapping from knext.api.operator import SPOPrompt
from knext.operator.builtin.auto_prompt import SPOPrompt
from nn4k.invoker import LLMInvoker from nn4k.invoker import LLMInvoker
@ -26,17 +27,26 @@ class Disease(BuilderJob):
1. 定义输入源CSV文件 1. 定义输入源CSV文件
""" """
source = CSVReader( source = CSVReader(
local_path="job/data/Disease.csv", local_path="builder/job/data/Disease.csv",
columns=["id", "input"], columns=["input"],
start_row=2, start_row=1,
) )
""" """
2. 定义大模型抽取组件从长文本中抽取Medical.Disease类型实体 2. 定义大模型抽取组件从长文本中抽取Medical.Disease类型实体
""" """
extract = LLMBasedExtractor(
extract = LLMBasedExtractor(llm=LLMInvoker.from_config("openai_infer.json"), llm=LLMInvoker.from_config("builder/model/openai_infer.json"),
prompt_ops=[SPOPrompt("Medical.Disease", ["commonSymptom", "applicableDrug"])]) prompt_ops=[SPOPrompt(
spg_type_name="Medical.Disease",
property_names=[
"complication",
"commonSymptom",
"applicableDrug",
"department",
"diseaseSite",
])]
)
""" """
2. 定义子图映射组件 2. 定义子图映射组件
@ -44,8 +54,11 @@ class Disease(BuilderJob):
mapping = SubGraphMapping(spg_type_name="Medical.Disease") \ mapping = SubGraphMapping(spg_type_name="Medical.Disease") \
.add_mapping_field("id", "id") \ .add_mapping_field("id", "id") \
.add_mapping_field("name", "name") \ .add_mapping_field("name", "name") \
.add_mapping_field("complication", "complication") \
.add_mapping_field("commonSymptom", "commonSymptom") \ .add_mapping_field("commonSymptom", "commonSymptom") \
.add_mapping_field("applicableDrug", "applicableDrug") .add_mapping_field("applicableDrug", "applicableDrug") \
.add_mapping_field("department", "department") \
.add_mapping_field("diseaseSite", "diseaseSite")
""" """
4. 定义输出到图谱 4. 定义输出到图谱
@ -56,4 +69,3 @@ class Disease(BuilderJob):
5. 定义builder_chain 5. 定义builder_chain
""" """
return source >> extract >> mapping >> sink return source >> extract >> mapping >> sink

View File

@ -78,7 +78,7 @@ class BaseOp(ABC):
) )
cls._registry[name] = subclass cls._registry[name] = subclass
if hasattr(subclass, "bind_to"): if hasattr(subclass, "bind_to"):
subclass.__bases__[0]._bind_schemas[subclass.bind_to] = name subclass.__bases__[0].bind_schemas[subclass.bind_to] = name
return subclass return subclass
return add_subclass_to_registry return add_subclass_to_registry
@ -104,3 +104,7 @@ class BaseOp(ABC):
method="_handle", method="_handle",
params=self.params, params=self.params,
) )
@property
def has_registered(self):
    """Read-only view of the ``_has_registered`` attribute."""
    registered = self._has_registered
    return registered

View File

@ -48,23 +48,24 @@ input:${input}
return self.template.replace("${input}", variables.get("input")) return self.template.replace("${input}", variables.get("input"))
def parse_response(self, response: str) -> List[SPGRecord]: def parse_response(self, response: str) -> List[SPGRecord]:
print(response)
result = [] result = []
subject = {} subject = {}
# re_obj = json.loads(response) re_obj = json.loads(response)
re_obj = { # re_obj = {
"spo": [ # "spo": [
{ # {
"subject": "甲状腺结节", # "subject": "甲状腺结节",
"predicate": "常见症状", # "predicate": "常见症状",
"object": "甲状腺结节" # "object": "甲状腺结节"
}, # },
{ # {
"subject": "甲状腺结节", # "subject": "甲状腺结节",
"predicate": "适用药品", # "predicate": "适用药品",
"object": "放射性碘治疗,复方碘口服液(Lugol液),抗甲状腺药物,硫脲类化合物,丙基硫氧嘧啶(PTU),甲基硫氧嘧啶(MTU),咪唑类的甲硫咪唑和卡比马唑" # "object": "放射性碘治疗,复方碘口服液(Lugol液),抗甲状腺药物,硫脲类化合物,丙基硫氧嘧啶(PTU),甲基硫氧嘧啶(MTU),咪唑类的甲硫咪唑和卡比马唑"
} # }
] # ]
} # }
if "spo" not in re_obj.keys(): if "spo" not in re_obj.keys():
raise ValueError("SPO format error.") raise ValueError("SPO format error.")
subject_properties = {} subject_properties = {}
@ -95,21 +96,21 @@ input:${input}
return result return result
def build_variables(self, variables: Dict[str, str], response: str) -> List[Dict[str, str]]: def build_variables(self, variables: Dict[str, str], response: str) -> List[Dict[str, str]]:
# re_obj = json.loads(response) re_obj = json.loads(response)
re_obj = { # re_obj = {
"spo": [ # "spo": [
{ # {
"subject": "甲状腺结节", # "subject": "甲状腺结节",
"predicate": "常见症状", # "predicate": "常见症状",
"object": "甲状腺结节" # "object": "甲状腺结节"
}, # },
{ # {
"subject": "甲状腺结节", # "subject": "甲状腺结节",
"predicate": "适用药品", # "predicate": "适用药品",
"object": "放射性碘治疗,复方碘口服液(Lugol液),抗甲状腺药物,硫脲类化合物,丙基硫氧嘧啶(PTU),甲基硫氧嘧啶(MTU),咪唑类的甲硫咪唑和卡比马唑" # "object": "放射性碘治疗,复方碘口服液(Lugol液),抗甲状腺药物,硫脲类化合物,丙基硫氧嘧啶(PTU),甲基硫氧嘧啶(MTU),咪唑类的甲硫咪唑和卡比马唑"
} # }
] # ]
} # }
if "spo" not in re_obj.keys(): if "spo" not in re_obj.keys():
raise ValueError("SPO format error.") raise ValueError("SPO format error.")
re = re_obj.get("spo", []) re = re_obj.get("spo", [])

View File

@ -17,6 +17,7 @@ class _BuiltInOnlineExtractor(ExtractOp):
super().__init__(params) super().__init__(params)
self.model = self.load_model() self.model = self.load_model()
self.prompt_ops = self.load_operator() self.prompt_ops = self.load_operator()
self.max_retry_times = int(self.params.get("max_retry_times", "3"))
def load_model(self): def load_model(self):
model_config = json.loads(self.params["model_config"]) model_config = json.loads(self.params["model_config"])
@ -27,63 +28,38 @@ class _BuiltInOnlineExtractor(ExtractOp):
prompt_config = json.loads(self.params["prompt_config"]) prompt_config = json.loads(self.params["prompt_config"])
prompt_ops = [] prompt_ops = []
for op_config in prompt_config: for op_config in prompt_config:
# 创建模块规范和模块对象
spec = importlib.util.spec_from_file_location(op_config["modulePath"], op_config["filePath"]) spec = importlib.util.spec_from_file_location(op_config["modulePath"], op_config["filePath"])
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
# 加载模块
spec.loader.exec_module(module) spec.loader.exec_module(module)
op_clazz = getattr(module, op_config["className"]) op_clazz = getattr(module, op_config["className"])
op_obj = op_clazz(**op_config["params"]) params = op_config.get("params", {})
op_obj = op_clazz(**params)
prompt_ops.append(op_obj) prompt_ops.append(op_obj)
return prompt_ops return prompt_ops
def invoke(self, record: Dict[str, str]) -> List[SPGRecord]:
    """Run a multi-stage LLM extraction over a single input record.

    Each prompt operator in ``self.prompt_ops`` is one extraction stage. For
    every input of a stage the flow is ``op.build_prompt`` ->
    ``self.model.remote_inference`` -> ``op.parse_response`` /
    ``op.build_variables``. One input may yield several variable sets, and
    each of them is fed separately into the next stage.

    A failing attempt is retried until ``self.max_retry_times`` attempts
    have been made; the exception of the last attempt is then re-raised.
    Results are committed only after an attempt fully succeeds, so a retry
    never leaves duplicated records behind. If ``self.max_retry_times`` is
    not positive, no attempt is made and the stage produces nothing.

    Args:
        record: The source row, passed as the variables of the first stage.

    Returns:
        All records produced by ``parse_response`` across all stages, in
        the order they were extracted.

    Raises:
        Exception: Whatever the operator or the model raised on the final
            attempt for an input.
    """
    collector = []
    input_params = [record]
    for op in self.prompt_ops:
        next_params = []
        for input_param in input_params:
            for attempt in range(1, self.max_retry_times + 1):
                try:
                    query = op.build_prompt(input_param)
                    response = self.model.remote_inference(query)
                    # Both hooks are optional on a prompt operator.
                    records = (
                        op.parse_response(response)
                        if hasattr(op, "parse_response")
                        else []
                    )
                    variables = (
                        op.build_variables(input_param, response)
                        if hasattr(op, "build_variables")
                        else []
                    )
                except Exception:
                    # Give up only once the retry budget is exhausted.
                    if attempt >= self.max_retry_times:
                        raise
                    continue
                collector.extend(records)
                next_params.extend(variables)
                break
        input_params = next_params
    return collector
if __name__ == '__main__':
config = {
"invoker_type": "OpenAI",
"openai_api_key": "EMPTY",
"openai_api_base": "http://localhost:38000/v1",
"openai_model_name": "vicuna-7b-v1.5",
"openai_max_tokens": 1000
}
model = LLMInvoker.from_config(config)
query = """
已知SPO关系包括:[录音室专辑(录音室专辑)-发行年份-文本]从下列句子中提取定义的这些关系最终抽取结果以json格式输出
input:范特西是周杰伦的第二张音乐专辑由周杰伦担任制作人于2001年9月14日发行共收录爱在西元前威廉古堡双截棍等10首歌曲 [1]
输出格式为:{"spo":[{"subject":,"predicate":,"object":},]}
"output":
"""
response = model.remote_inference(query)
print(response)

View File

@ -52,7 +52,7 @@ class LinkOp(BaseOp, ABC):
bind_to: SPGTypeName bind_to: SPGTypeName
_bind_schemas: Dict[SPGTypeName, str] = {} bind_schemas: Dict[SPGTypeName, str] = {}
def __init__(self, params: Dict[str, str] = None): def __init__(self, params: Dict[str, str] = None):
super().__init__(params) super().__init__(params)
@ -81,7 +81,7 @@ class FuseOp(BaseOp, ABC):
bind_to: SPGTypeName bind_to: SPGTypeName
_bind_schemas: Dict[SPGTypeName, str] = {} bind_schemas: Dict[SPGTypeName, str] = {}
def __init__(self, params: Dict[str, str] = None): def __init__(self, params: Dict[str, str] = None):
super().__init__(params) super().__init__(params)
@ -95,7 +95,7 @@ class FuseOp(BaseOp, ABC):
def _pre_process(*inputs): def _pre_process(*inputs):
return [ return [
SPGRecord.from_dict(input) for input in inputs[0] SPGRecord.from_dict(input) for input in inputs[0]
] ],
@staticmethod @staticmethod
def _post_process(output) -> Dict[str, Any]: def _post_process(output) -> Dict[str, Any]:
@ -134,7 +134,7 @@ class PredictOp(BaseOp, ABC):
bind_to: Tuple[SPGTypeName, PropertyName, SPGTypeName] bind_to: Tuple[SPGTypeName, PropertyName, SPGTypeName]
_bind_schemas: Dict[Tuple[SPGTypeName, PropertyName], str] = {} bind_schemas: Dict[Tuple[SPGTypeName, PropertyName], str] = {}
def invoke(self, subject_record: SPGRecord) -> List[SPGRecord]: def invoke(self, subject_record: SPGRecord) -> List[SPGRecord]:
raise NotImplementedError( raise NotImplementedError(
@ -145,7 +145,7 @@ class PredictOp(BaseOp, ABC):
def _pre_process(*inputs): def _pre_process(*inputs):
return [ return [
SPGRecord.from_dict(input) for input in inputs[0] SPGRecord.from_dict(input) for input in inputs[0]
] ],
@staticmethod @staticmethod
def _post_process(output) -> Dict[str, Any]: def _post_process(output) -> Dict[str, Any]:

View File

@ -9,10 +9,7 @@ elasticsearch==8.10.0
six==1.16.0 six==1.16.0
click==8.1.7 click==8.1.7
dateutils==0.6.12 dateutils==0.6.12
pemja==0.4.0
numpy==1.24.4
scipy==1.10.1
scikit-learn==1.3.1
certifi==2023.11.17 certifi==2023.11.17
urllib3==2.1.0 urllib3==2.1.0
python-dateutil==2.8.2 python-dateutil==2.8.2