mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-11-02 19:13:43 +00:00
fix(deepkeprompt): add deepkeprompt (#149)
Co-authored-by: Qu <qy266141@antgroup.com>
This commit is contained in:
parent
1982db26ad
commit
5f46f3b658
@ -13,7 +13,13 @@
|
||||
from knext.builder.operator.op import LinkOp, ExtractOp, FuseOp, PromptOp, PredictOp
|
||||
from knext.builder.operator.spg_record import SPGRecord
|
||||
from knext.builder.operator.builtin.auto_prompt import REPrompt, EEPrompt
|
||||
|
||||
from knext.builder.operator.builtin.deepke_prompt import (
|
||||
OneKE_NERPrompt,
|
||||
OneKE_REPrompt,
|
||||
OneKE_SPOPrompt,
|
||||
OneKE_KGPrompt,
|
||||
OneKE_EEPrompt,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ExtractOp",
|
||||
@ -24,4 +30,9 @@ __all__ = [
|
||||
"SPGRecord",
|
||||
"REPrompt",
|
||||
"EEPrompt",
|
||||
"OneKE_NERPrompt",
|
||||
"OneKE_REPrompt",
|
||||
"OneKE_SPOPrompt",
|
||||
"OneKE_KGPrompt",
|
||||
"OneKE_EEPrompt",
|
||||
]
|
||||
|
||||
@ -15,7 +15,9 @@ import re
|
||||
from abc import ABC
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
|
||||
from knext.schema.client import SchemaClient
|
||||
from knext.schema.model.base import BaseSpgType
|
||||
from knext.schema.model.schema_helper import SPGTypeName, PropertyName, RelationName
|
||||
from knext.builder.operator.op import PromptOp
|
||||
from knext.builder.operator.spg_record import SPGRecord
|
||||
@ -23,15 +25,21 @@ import uuid
|
||||
|
||||
|
||||
class AutoPrompt(PromptOp, ABC):
|
||||
spg_type_name: SPGTypeName
|
||||
spg_types: List[BaseSpgType]
|
||||
|
||||
def __init__(self, spg_type_names: List[SPGTypeName]):
|
||||
super().__init__()
|
||||
self.spg_types = []
|
||||
schema_session = SchemaClient().create_session()
|
||||
for spg_type_name in spg_type_names:
|
||||
spg_type = schema_session.get(spg_type_name=spg_type_name)
|
||||
self.spg_types.append(spg_type)
|
||||
|
||||
def _init_render_variables(self):
|
||||
schema_session = SchemaClient().create_session()
|
||||
spg_type = schema_session.get(spg_type_name=self.spg_type_name)
|
||||
self.property_info_en = {}
|
||||
self.property_info_zh = {}
|
||||
self.relation_info_en = {}
|
||||
self.property_info_en = {}
|
||||
self.relation_info_zh = {}
|
||||
self.relation_info_en = {}
|
||||
self.spg_type_schema_info_en = {
|
||||
"Text": ("文本", None),
|
||||
"Integer": ("整型", None),
|
||||
@ -42,33 +50,36 @@ class AutoPrompt(PromptOp, ABC):
|
||||
"整型": ("Integer", None),
|
||||
"浮点型": ("Float", None),
|
||||
}
|
||||
for _rel in spg_type.relations.values():
|
||||
if _rel.is_dynamic:
|
||||
continue
|
||||
self.relation_info_zh[_rel.name_zh] = (
|
||||
_rel.name,
|
||||
_rel.desc,
|
||||
_rel.object_type_name,
|
||||
)
|
||||
self.relation_info_en[_rel.name] = (
|
||||
_rel.name_zh,
|
||||
_rel.desc,
|
||||
_rel.object_type_name,
|
||||
)
|
||||
for _prop in spg_type.properties.values():
|
||||
self.property_info_zh[_prop.name_zh] = (
|
||||
_prop.name,
|
||||
_prop.desc,
|
||||
_prop.object_type_name,
|
||||
)
|
||||
self.property_info_en[_prop.name] = (
|
||||
_prop.name_zh,
|
||||
_prop.desc,
|
||||
_prop.object_type_name,
|
||||
)
|
||||
for _type in schema_session.spg_types.values():
|
||||
if _type.name in ["Text", "Integer", "Float"]:
|
||||
continue
|
||||
for spg_type in self.spg_types:
|
||||
self.property_info_zh[spg_type.name_zh] = {}
|
||||
self.relation_info_zh[spg_type.name_zh] = {}
|
||||
self.property_info_en[spg_type.name] = {}
|
||||
self.relation_info_en[spg_type.name] = {}
|
||||
for _rel in spg_type.relations.values():
|
||||
if _rel.is_dynamic:
|
||||
continue
|
||||
self.relation_info_zh[spg_type.name_zh][_rel.name_zh] = (
|
||||
_rel.name,
|
||||
_rel.desc,
|
||||
_rel.object_type_name,
|
||||
)
|
||||
self.relation_info_en[spg_type.name][_rel.name] = (
|
||||
_rel.name_zh,
|
||||
_rel.desc,
|
||||
_rel.object_type_name,
|
||||
)
|
||||
for _prop in spg_type.properties.values():
|
||||
self.property_info_zh[spg_type.name_zh][_prop.name_zh] = (
|
||||
_prop.name,
|
||||
_prop.desc,
|
||||
_prop.object_type_name,
|
||||
)
|
||||
self.property_info_en[spg_type.name][_prop.name] = (
|
||||
_prop.name_zh,
|
||||
_prop.desc,
|
||||
_prop.object_type_name,
|
||||
)
|
||||
for _type in self.spg_types:
|
||||
self.spg_type_schema_info_zh[_type.name_zh] = (_type.name, _type.desc)
|
||||
self.spg_type_schema_info_en[_type.name] = (_type.name_zh, _type.desc)
|
||||
|
||||
@ -89,7 +100,7 @@ input:${input}
|
||||
relation_names: List[Tuple[RelationName, SPGTypeName]] = None,
|
||||
custom_prompt: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__([spg_type_name])
|
||||
|
||||
self.spg_type_name = spg_type_name
|
||||
if custom_prompt:
|
||||
@ -105,13 +116,6 @@ input:${input}
|
||||
self._init_render_variables()
|
||||
self._render()
|
||||
|
||||
self.params = {
|
||||
"spg_type_name": spg_type_name,
|
||||
"property_names": property_names,
|
||||
"relation_names": relation_names,
|
||||
"custom_prompt": custom_prompt,
|
||||
}
|
||||
|
||||
def build_prompt(self, variables: Dict[str, str]) -> str:
|
||||
return self.template.replace("${input}", variables.get("input"))
|
||||
|
||||
@ -238,7 +242,7 @@ class EEPrompt(AutoPrompt):
|
||||
relation_names: List[Tuple[RelationName, SPGTypeName]] = None,
|
||||
custom_prompt: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__([event_type_name])
|
||||
|
||||
self.spg_type_name = event_type_name
|
||||
if custom_prompt:
|
||||
@ -254,13 +258,6 @@ class EEPrompt(AutoPrompt):
|
||||
self._init_render_variables()
|
||||
self._render()
|
||||
|
||||
self.params = {
|
||||
"event_type_name": event_type_name,
|
||||
"property_names": property_names,
|
||||
"relation_names": relation_names,
|
||||
"custom_prompt": custom_prompt,
|
||||
}
|
||||
|
||||
def build_prompt(self, variables: Dict[str, str]) -> str:
|
||||
return self.template.replace("${input}", variables.get("input"))
|
||||
|
||||
|
||||
457
python/knext/knext/builder/operator/builtin/deepke_prompt.py
Normal file
457
python/knext/knext/builder/operator/builtin/deepke_prompt.py
Normal file
@ -0,0 +1,457 @@
|
||||
#
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Dict, Any
|
||||
from collections import defaultdict
|
||||
|
||||
from knext.schema.model.schema_helper import SPGTypeName
|
||||
from knext.builder.operator.spg_record import SPGRecord
|
||||
from knext.builder.operator.builtin.auto_prompt import AutoPrompt
|
||||
|
||||
|
||||
class OneKEPrompt(AutoPrompt):
|
||||
template_zh: str = ""
|
||||
template_en: str = ""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
types_list = kwargs.get("types_list", [])
|
||||
language = kwargs.get("language", "zh")
|
||||
with_description = kwargs.get("language", False)
|
||||
split_num = kwargs.get("split_num", 4)
|
||||
super().__init__(types_list)
|
||||
if language == "zh":
|
||||
self.template = self.template_zh
|
||||
else:
|
||||
self.template = self.template_en
|
||||
self.with_description = with_description
|
||||
self.split_num = split_num
|
||||
|
||||
self._init_render_variables()
|
||||
self._render()
|
||||
|
||||
self.params = kwargs
|
||||
|
||||
def build_prompt(self, variables: Dict[str, str]) -> List[str]:
|
||||
instructions = []
|
||||
for schema in self.schema_list:
|
||||
instructions.append(
|
||||
json.dumps(
|
||||
{
|
||||
"instruction": self.template,
|
||||
"schema": schema,
|
||||
"input": variables.get("input"),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
return instructions
|
||||
|
||||
def parse_response(self, response: str) -> List[SPGRecord]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _render(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def multischema_split_by_num(self, split_num, schemas: List[Any]):
|
||||
negative_length = max(len(schemas) // split_num, 1) * split_num
|
||||
total_schemas = []
|
||||
for i in range(0, negative_length, split_num):
|
||||
total_schemas.append(schemas[i : i + split_num])
|
||||
|
||||
remain_len = max(1, split_num // 2)
|
||||
tmp_schemas = schemas[negative_length:]
|
||||
if len(schemas) - negative_length >= remain_len and len(tmp_schemas) > 0:
|
||||
total_schemas.append(tmp_schemas)
|
||||
elif len(tmp_schemas) > 0:
|
||||
total_schemas[-1].extend(tmp_schemas)
|
||||
return total_schemas
|
||||
|
||||
|
||||
class OneKE_NERPrompt(OneKEPrompt):
|
||||
template_zh: str = (
|
||||
"你是专门进行实体抽取的专家。请从input中抽取出符合schema定义的实体,不存在的实体类型返回空列表。请按照JSON字符串的格式回答。"
|
||||
)
|
||||
template_en: str = "You are an expert in named entity recognition. Please extract entities that match the schema definition from the input. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity_types: List[SPGTypeName],
|
||||
language: str = "zh",
|
||||
with_description: bool = False,
|
||||
split_num: int = 4,
|
||||
):
|
||||
super().__init__(
|
||||
types_list=entity_types,
|
||||
language=language,
|
||||
with_description=with_description,
|
||||
split_num=split_num,
|
||||
)
|
||||
|
||||
def parse_response(self, response: str) -> List[SPGRecord]:
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
response = response[0]
|
||||
try:
|
||||
ent_obj = json.loads(response)
|
||||
except json.decoder.JSONDecodeError:
|
||||
print("DeepKE_NERPrompt response JSONDecodeError error.")
|
||||
return []
|
||||
if type(ent_obj) != dict:
|
||||
print("DeepKE_NERPrompt response type error.")
|
||||
return []
|
||||
|
||||
spg_records = []
|
||||
for type_zh, values in ent_obj.items():
|
||||
if type_zh not in self.spg_type_schema_info_zh:
|
||||
print(f"Unrecognized entity_type: {type_zh}")
|
||||
continue
|
||||
type_en, _ = self.spg_type_schema_info_zh[type_zh]
|
||||
spg_record = SPGRecord(type_en)
|
||||
for value in values:
|
||||
spg_record.upsert_properties({"id": value, "name": value})
|
||||
spg_records.append(spg_record)
|
||||
|
||||
def _render(self):
|
||||
entity_list = []
|
||||
for spg_type in self.spg_types:
|
||||
entity_list.append(spg_type.name_zh)
|
||||
self.schema_list = self.multischema_split_by_num(self.split_num, entity_list)
|
||||
|
||||
|
||||
class OneKE_SPOPrompt(OneKEPrompt):
|
||||
template_zh: str = (
|
||||
"你是专门进行SPO三元组抽取的专家。请从input中抽取出符合schema定义的spo关系三元组,不存在的关系返回空列表。请按照JSON字符串的格式回答。"
|
||||
)
|
||||
template_en: str = "You are an expert in spo(subject, predicate, object) triples extraction. Please extract SPO relationship triples that match the schema definition from the input. Return an empty list for relationships that do not exist. Please respond in the format of a JSON string."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spo_types: List[SPGTypeName],
|
||||
language: str = "zh",
|
||||
with_description: bool = False,
|
||||
split_num: int = 4,
|
||||
):
|
||||
super().__init__(
|
||||
types_list=spo_types,
|
||||
language=language,
|
||||
with_description=with_description,
|
||||
split_num=split_num,
|
||||
)
|
||||
self.properties_mapper = {}
|
||||
self.relations_mapper = {}
|
||||
|
||||
def parse_response(self, response: str) -> List[SPGRecord]:
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
response = response[0]
|
||||
try:
|
||||
re_obj = json.loads(response)
|
||||
except json.decoder.JSONDecodeError:
|
||||
print("DeepKE_REPrompt response JSONDecodeError error.")
|
||||
return []
|
||||
if type(re_obj) != dict:
|
||||
print("DeepKE_REPrompt response type error.")
|
||||
return []
|
||||
|
||||
relation_dcir = defaultdict(list)
|
||||
for relation_zh, values in re_obj.items():
|
||||
if relation_zh not in self.property_info_zh[relation_zh]:
|
||||
print(f"Unrecognized relation: {relation_zh}")
|
||||
continue
|
||||
if values and isinstance(values, list):
|
||||
for value in values:
|
||||
if (
|
||||
type(value) != dict
|
||||
or "subject" not in value
|
||||
or "object" not in value
|
||||
):
|
||||
print("DeepKE_REPrompt response type error.")
|
||||
continue
|
||||
s_zh, o_zh = value.get("subject", ""), value.get("object", "")
|
||||
relation_dcir[relation_zh].append((s_zh, o_zh))
|
||||
|
||||
spg_records = []
|
||||
for relation_zh, sub_obj_list in relation_dcir.items():
|
||||
sub_dict = defaultdict(list)
|
||||
for s_zh, o_zh in sub_obj_list:
|
||||
sub_dict[s_zh].append(o_zh)
|
||||
for s_zh, o_list in sub_dict.items():
|
||||
if s_zh in self.spg_type_schema_info_zh:
|
||||
print(f"Unrecognized subject_type: {s_zh}")
|
||||
continue
|
||||
object_value = ",".join(o_list)
|
||||
s_type_zh = self.properties_mapper.get(relation_zh, None)
|
||||
if s_type_zh is not None:
|
||||
s_type_en, _ = self.spg_type_schema_info_zh[s_type_zh]
|
||||
relation_en, _ = self.property_info_zh[relation_zh]
|
||||
spg_record = SPGRecord(s_type_en).upsert_properties(
|
||||
{"id": s_zh, "name": s_zh}
|
||||
)
|
||||
spg_record.upsert_property(relation_en, object_value)
|
||||
else:
|
||||
s_type_zh, o_type_zh = self.relations_mapper.get(
|
||||
relation_zh, [None, None]
|
||||
)
|
||||
if s_type_zh is None or o_type_zh is None:
|
||||
print(f"Unrecognized relation: {relation_zh}")
|
||||
continue
|
||||
s_type_en, _ = self.spg_type_schema_info_zh[s_type_zh]
|
||||
spg_record = SPGRecord(s_type_en).upsert_properties(
|
||||
{"id": s_zh, "name": s_zh}
|
||||
)
|
||||
relation_en, _, object_type = self.relation_info_zh[s_type_zh][
|
||||
relation_zh
|
||||
]
|
||||
spg_record.upsert_relation(relation_en, object_type, object_value)
|
||||
spg_records.append(spg_record)
|
||||
return spg_records
|
||||
|
||||
def _render(self):
|
||||
spo_list = []
|
||||
for spg_type in self.spg_types:
|
||||
type_en, _ = self.spg_type_schema_info_zh[spg_type]
|
||||
for v in spg_type.properties.values():
|
||||
spo_list.append(
|
||||
{
|
||||
"subject_type": spg_type.name_zh,
|
||||
"predicate": v.name_zh,
|
||||
"object_type": "文本",
|
||||
}
|
||||
)
|
||||
self.properties_mapper[v.name_zh] = spg_type
|
||||
for v in spg_type.relations.values():
|
||||
_, _, object_type = self.relation_info_en[type_en][v.name]
|
||||
spo_list.append(
|
||||
{
|
||||
"subject_type": spg_type.name_zh,
|
||||
"predicate": v.name_zh,
|
||||
"object_type": object_type,
|
||||
}
|
||||
)
|
||||
self.relations_mapper[v.name_zh] = [spg_type, object_type]
|
||||
self.schema_list = self.multischema_split_by_num(self.split_num, spo_list)
|
||||
|
||||
|
||||
class OneKE_REPrompt(OneKE_SPOPrompt):
|
||||
template_zh: str = (
|
||||
"你是专门进行关系抽取的专家。请从input中抽取出符合schema定义的关系三元组,不存在的关系返回空列表。请按照JSON字符串的格式回答。"
|
||||
)
|
||||
template_en: str = "You are an expert in relationship extraction. Please extract relationship triples that match the schema definition from the input. Return an empty list for relationships that do not exist. Please respond in the format of a JSON string."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
relation_types: List[SPGTypeName],
|
||||
language: str = "zh",
|
||||
with_description: bool = False,
|
||||
split_num: int = 4,
|
||||
):
|
||||
super().__init__(relation_types, language, with_description, split_num)
|
||||
|
||||
def _render(self):
|
||||
re_list = []
|
||||
for spg_type in self.spg_types:
|
||||
type_en, _ = self.spg_type_schema_info_zh[spg_type]
|
||||
for v in spg_type.properties.values():
|
||||
re_list.append(v.name_zh)
|
||||
self.properties_mapper[v.name_zh] = spg_type
|
||||
for v in spg_type.relations.values():
|
||||
v_zh, _, object_type = self.relation_info_en[type_en][v.name]
|
||||
re_list.append(v.name_zh)
|
||||
self.relations_mapper[v.name_zh] = [spg_type, object_type]
|
||||
self.schema_list = self.multischema_split_by_num(self.split_num, re_list)
|
||||
|
||||
|
||||
class OneKE_KGPrompt(OneKEPrompt):
|
||||
template_zh: str = "你是一个图谱实体知识结构化专家。根据输入实体类型(entity type)的schema描述,从文本中抽取出相应的实体实例和其属性信息,不存在的属性不输出, 属性存在多值就返回列表,并输出为可解析的json格式。"
|
||||
template_en: str = "You are an expert in structured knowledge systems for graph entities. Based on the schema description of the input entity type, you extract the corresponding entity instances and their attribute information from the text. Attributes that do not exist should not be output. If an attribute has multiple values, a list should be returned. The results should be output in a parsable JSON format."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity_types: List[SPGTypeName],
|
||||
language: str = "zh",
|
||||
with_description: bool = False,
|
||||
split_num: int = 4,
|
||||
):
|
||||
super().__init__(
|
||||
types_list=entity_types,
|
||||
language=language,
|
||||
with_description=with_description,
|
||||
split_num=split_num,
|
||||
)
|
||||
|
||||
def parse_response(self, response: str) -> List[SPGRecord]:
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
response = response[0]
|
||||
try:
|
||||
re_obj = json.loads(response)
|
||||
except json.decoder.JSONDecodeError:
|
||||
print("DeepKE_KGPrompt response JSONDecodeError error.")
|
||||
return []
|
||||
if type(re_obj) != dict:
|
||||
print("DeepKE_KGPrompt response type error.")
|
||||
return []
|
||||
|
||||
spg_records = []
|
||||
for type_zh, type_value in re_obj.items():
|
||||
if type_zh not in self.spg_type_schema_info_zh:
|
||||
print(f"Unrecognized entity_type: {type_zh}")
|
||||
continue
|
||||
type_en, _ = self.spg_type_schema_info_zh[type_zh]
|
||||
if type_value and isinstance(type_value, dict):
|
||||
for name, attrs in type_value.items():
|
||||
spg_record = SPGRecord(type_en).upsert_properties(
|
||||
{"id": name, "name": name}
|
||||
)
|
||||
for attr_zh, attr_value in attrs.items():
|
||||
if isinstance(attr_value, list):
|
||||
attr_value = ",".join(attr_value)
|
||||
if attr_zh in self.property_info_zh[type_zh]:
|
||||
attr_en, _, object_type = self.property_info_zh[type_zh][
|
||||
attr_zh
|
||||
]
|
||||
spg_record.upsert_property(attr_en, attr_value)
|
||||
elif attr_zh in self.relation_info_zh[type_zh]:
|
||||
attr_en, _, object_type = self.relation_info_zh[type_zh][
|
||||
attr_zh
|
||||
]
|
||||
spg_record.upsert_relation(attr_en, object_type, attr_value)
|
||||
else:
|
||||
print(f"Unrecognized attribute: {attr_zh}")
|
||||
continue
|
||||
if object_type == "Integer":
|
||||
matches = re.findall(r"\d+", attr_value)
|
||||
if matches:
|
||||
spg_record.upsert_property(attr_en, matches[0])
|
||||
elif object_type == "Float":
|
||||
matches = re.findall(r"\d+(?:\.\d+)?", attr_value)
|
||||
if matches:
|
||||
spg_record.upsert_property(attr_en, matches[0])
|
||||
spg_records.append(spg_record)
|
||||
return spg_records
|
||||
|
||||
def _render(self):
|
||||
spo_list = []
|
||||
for spg_type in self.spg_types:
|
||||
attributes = []
|
||||
attributes.extend(
|
||||
[
|
||||
v.name_zh
|
||||
for k, v in spg_type.properties.items()
|
||||
if k not in ["id", "description", "stdId"]
|
||||
]
|
||||
)
|
||||
attributes.extend(
|
||||
[
|
||||
v.name_zh
|
||||
for k, v in spg_type.relations.items()
|
||||
if v not in attributes and k not in ["isA"]
|
||||
]
|
||||
)
|
||||
entity_type = spg_type.name_zh
|
||||
spo_list.append({"entity_type": entity_type, "attributes": attributes})
|
||||
|
||||
self.schema_list = self.multischema_split_by_num(self.split_num, spo_list)
|
||||
|
||||
|
||||
class OneKE_EEPrompt(OneKEPrompt):
|
||||
template_zh: str = "你是专门进行事件提取的专家。请从input中抽取出符合schema定义的事件,不存在的事件返回空列表,不存在的论元返回NAN,如果论元存在多值请返回列表。请按照JSON字符串的格式回答。"
|
||||
template_en: str = "You are an expert in event extraction. Please extract events from the input that conform to the schema definition. Return an empty list for events that do not exist, and return NAN for arguments that do not exist. If an argument has multiple values, please return a list. Respond in the format of a JSON string."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_types: List[SPGTypeName],
|
||||
language: str = "zh",
|
||||
with_description: bool = False,
|
||||
split_num: int = 4,
|
||||
):
|
||||
super().__init__(
|
||||
types_list=event_types,
|
||||
language=language,
|
||||
with_description=with_description,
|
||||
split_num=split_num,
|
||||
)
|
||||
|
||||
def parse_response(self, response: str) -> List[SPGRecord]:
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
response = response[0]
|
||||
try:
|
||||
ee_obj = json.loads(response)
|
||||
except json.decoder.JSONDecodeError:
|
||||
print("DeepKE_EEPrompt response JSONDecodeError error.")
|
||||
return []
|
||||
if type(ee_obj) != dict:
|
||||
print("DeepKE_EEPrompt response type error.")
|
||||
return []
|
||||
|
||||
spg_records = []
|
||||
for type_zh, type_values in ee_obj.items():
|
||||
if type_zh not in self.spg_type_schema_info_zh:
|
||||
print(f"Unrecognized event_type: {type_zh}")
|
||||
continue
|
||||
type_en, _ = self.spg_type_schema_info_zh[type_zh]
|
||||
if type_values and isinstance(type_values, list):
|
||||
for type_value in type_values:
|
||||
spg_record = SPGRecord(type_en).upsert_property("name", type_zh)
|
||||
arguments = type_value.get("arguments")
|
||||
if arguments and isinstance(arguments, dict):
|
||||
for attr_zh, attr_value in arguments.items():
|
||||
if isinstance(attr_value, list):
|
||||
attr_value = ",".join(attr_value)
|
||||
if attr_zh in self.property_info_zh[type_zh]:
|
||||
attr_en, _, object_type = self.property_info_zh[
|
||||
type_zh
|
||||
][attr_zh]
|
||||
spg_record.upsert_property(attr_en, attr_value)
|
||||
elif attr_zh in self.relation_info_zh[type_zh]:
|
||||
attr_en, _, object_type = self.relation_info_zh[
|
||||
type_zh
|
||||
][attr_zh]
|
||||
spg_record.upsert_relation(
|
||||
attr_en, object_type, attr_value
|
||||
)
|
||||
else:
|
||||
print(f"Unrecognized attribute: {attr_zh}")
|
||||
continue
|
||||
if object_type == "Integer":
|
||||
matches = re.findall(r"\d+", attr_value)
|
||||
if matches:
|
||||
spg_record.upsert_property(attr_en, matches[0])
|
||||
elif object_type == "Float":
|
||||
matches = re.findall(r"\d+(?:\.\d+)?", attr_value)
|
||||
if matches:
|
||||
spg_record.upsert_property(attr_en, matches[0])
|
||||
spg_records.append(spg_record)
|
||||
return spg_records
|
||||
|
||||
def _render(self):
|
||||
event_list = []
|
||||
for spg_type in self.spg_types:
|
||||
arguments = []
|
||||
arguments.extend(
|
||||
[
|
||||
v.name_zh
|
||||
for k, v in spg_type.properties.items()
|
||||
if k not in ["id", "name", "description"]
|
||||
]
|
||||
)
|
||||
arguments.extend(
|
||||
[
|
||||
v.name_zh
|
||||
for k, v in spg_type.relations.items()
|
||||
if v.name_zh not in arguments
|
||||
]
|
||||
)
|
||||
event_type = spg_type.name_zh
|
||||
event_list.append(
|
||||
{"event_type": event_type, "trigger": True, "arguments": arguments}
|
||||
)
|
||||
self.schema_list = self.multischema_split_by_num(self.split_num, event_list)
|
||||
@ -43,7 +43,7 @@ class _BuiltInOnlineExtractor(ExtractOp):
|
||||
|
||||
op_clazz = getattr(module, op_config["className"])
|
||||
params = op_config.get("params", {})
|
||||
op_obj = op_clazz(**params)
|
||||
op_obj = op_clazz(*params.values())
|
||||
if self.debug:
|
||||
print(f'{op_config["className"]}.template: {op_obj.template}')
|
||||
prompt_ops.append(op_obj)
|
||||
@ -60,12 +60,15 @@ class _BuiltInOnlineExtractor(ExtractOp):
|
||||
retry_times = 0
|
||||
while retry_times < self.max_retry_times:
|
||||
try:
|
||||
query = op.build_prompt(input_param)
|
||||
response = self.model.remote_inference(query)
|
||||
collector.extend(op.parse_response(response))
|
||||
next_params.extend(
|
||||
op._build_next_variables(input_param, response)
|
||||
)
|
||||
querys = op.build_prompt(input_param)
|
||||
if isinstance(querys, str):
|
||||
querys = [querys]
|
||||
for query in querys:
|
||||
response = self.model.remote_inference(query)
|
||||
collector.extend(op.parse_response(response))
|
||||
next_params.extend(
|
||||
op._build_next_variables(input_param, response)
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
retry_times += 1
|
||||
|
||||
16
python/knext/knext/examples/deepke/.knext.cfg
Normal file
16
python/knext/knext/examples/deepke/.knext.cfg
Normal file
@ -0,0 +1,16 @@
|
||||
[local]
|
||||
project_name = DeepKE
|
||||
description = DeepKE
|
||||
project_id = 2
|
||||
namespace = DeepKE
|
||||
project_dir = deepke
|
||||
schema_dir = schema
|
||||
schema_file = re.schema
|
||||
builder_dir = builder
|
||||
builder_operator_dir = builder/operator
|
||||
builder_record_dir = builder/error_record
|
||||
builder_job_dir = builder/job
|
||||
builder_model_dir = builder/model
|
||||
reasoner_dir = reasoner
|
||||
reasoner_result_dir = reasoner/result
|
||||
|
||||
@ -0,0 +1 @@
|
||||
甲状腺结节是指在甲状腺内的肿块,可随吞咽动作随甲状腺而上下移动,是临床常见的病症,可由多种病因引起。临床上有多种甲状腺疾病,如甲状腺退行性变、炎症、自身免疫以及新生物等都可以表现为结节。甲状腺结节可以单发,也可以多发,多发结节比单发结节的发病率高,但单发结节甲状腺癌的发生率较高。患者通常可以选择在普外科,甲状腺外科,内分泌科,头颈外科挂号就诊。有些患者可以触摸到自己颈部前方的结节。在大多情况下,甲状腺结节没有任何症状,甲状腺功能也是正常的。甲状腺结节进展为其它甲状腺疾病的概率只有1%。有些人会感觉到颈部疼痛、咽喉部异物感,或者存在压迫感。当甲状腺结节发生囊内自发性出血时,疼痛感会更加强烈。治疗方面,一般情况下可以用放射性碘治疗,复方碘口服液(Lugol液)等,或者服用抗甲状腺药物来抑制甲状腺激素的分泌。目前常用的抗甲状腺药物是硫脲类化合物,包括硫氧嘧啶类的丙基硫氧嘧啶(PTU)和甲基硫氧嘧啶(MTU)及咪唑类的甲硫咪唑和卡比马唑。
|
||||
|
61
python/knext/knext/examples/deepke/builder/job/task_entry.py
Normal file
61
python/knext/knext/examples/deepke/builder/job/task_entry.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
from nn4k.invoker import NNInvoker
|
||||
from knext.builder.component import (
|
||||
CSVReader,
|
||||
LLMBasedExtractor,
|
||||
SPGTypeMapping,
|
||||
KGWriter,
|
||||
)
|
||||
from knext.builder.operator import OneKE_KGPrompt
|
||||
from knext.builder.model.builder_job import BuilderJob
|
||||
|
||||
from schema.deepke_schema_helper import DeepKE
|
||||
|
||||
|
||||
class Disease(BuilderJob):
|
||||
def build(self):
|
||||
|
||||
source = CSVReader(
|
||||
local_path="builder/job/data/Disease.csv",
|
||||
columns=["input"],
|
||||
start_row=1,
|
||||
)
|
||||
|
||||
extract = LLMBasedExtractor(
|
||||
llm=NNInvoker.from_config("builder/model/remote_infer.json"),
|
||||
prompt_ops=[
|
||||
OneKE_KGPrompt(
|
||||
entity_types=[
|
||||
DeepKE.Disease,
|
||||
DeepKE.BodyPart,
|
||||
DeepKE.Drug,
|
||||
DeepKE.HospitalDepartment,
|
||||
DeepKE.Symptom,
|
||||
DeepKE.Indicator,
|
||||
]
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
mappings = [
|
||||
SPGTypeMapping(spg_type_name=DeepKE.Disease),
|
||||
SPGTypeMapping(spg_type_name=DeepKE.BodyPart),
|
||||
SPGTypeMapping(spg_type_name=DeepKE.Drug),
|
||||
SPGTypeMapping(spg_type_name=DeepKE.HospitalDepartment),
|
||||
SPGTypeMapping(spg_type_name=DeepKE.Symptom),
|
||||
SPGTypeMapping(spg_type_name=DeepKE.Indicator),
|
||||
]
|
||||
|
||||
sink = KGWriter()
|
||||
|
||||
return source >> extract >> mappings >> sink
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
{
|
||||
// -- base model info
|
||||
"nn_model_path": "/gruntdata/event_graph/llm_models/Baichuan2-13B-Chat", // local model path
|
||||
"nn_invoker": "nn4k.invoker.base.LLMInvoker", // invoker to use
|
||||
"nn_executor": "nn4k.executor.huggingface.hf_decode_only_executor.HFDecodeOnlyExecutor", // executor to use
|
||||
// the following are optional
|
||||
"adapter_name": "baichuan2", // adapter_name must be given to enable adapter; with adapter_path along has no effect!
|
||||
"adapter_path": "/gruntdata/event_graph/zhongjin.ghh/lora_results/ie-v2/checkpoint-7830",
|
||||
"generate_config":{
|
||||
"temperature": 1.0,
|
||||
"do_sample": false,
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,8 @@
|
||||
{
|
||||
"hub_infer_url": "http://121.40.228.11:8000/v2/models/vllm_model/generate",
|
||||
"generate_config": {
|
||||
"stream": false,
|
||||
"temperature": 0.6,
|
||||
"max_tokens": 4096
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,84 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
# ATTENTION!
|
||||
# This file is generated by Schema automatically, it will be refreshed after schema has been committed
|
||||
# PLEASE DO NOT MODIFY THIS FILE!!!
|
||||
#
|
||||
|
||||
from knext.schema.model.schema_helper import (
|
||||
SPGTypeHelper,
|
||||
PropertyHelper,
|
||||
RelationHelper,
|
||||
)
|
||||
|
||||
|
||||
class DeepKE:
|
||||
class BodyPart(SPGTypeHelper):
|
||||
|
||||
description = PropertyHelper("description")
|
||||
name = PropertyHelper("name")
|
||||
stdId = PropertyHelper("stdId")
|
||||
alias = PropertyHelper("alias")
|
||||
id = PropertyHelper("id")
|
||||
|
||||
class Disease(SPGTypeHelper):
|
||||
class abnormal(RelationHelper):
|
||||
shape = PropertyHelper("shape")
|
||||
range = PropertyHelper("range")
|
||||
color = PropertyHelper("color")
|
||||
|
||||
description = PropertyHelper("description")
|
||||
applicableDrug = PropertyHelper("applicableDrug")
|
||||
name = PropertyHelper("name")
|
||||
commonSymptom = PropertyHelper("commonSymptom")
|
||||
complication = PropertyHelper("complication")
|
||||
id = PropertyHelper("id")
|
||||
department = PropertyHelper("department")
|
||||
diseaseSite = PropertyHelper("diseaseSite")
|
||||
|
||||
abnormal = abnormal("abnormal")
|
||||
|
||||
class Drug(SPGTypeHelper):
|
||||
|
||||
description = PropertyHelper("description")
|
||||
id = PropertyHelper("id")
|
||||
name = PropertyHelper("name")
|
||||
|
||||
class HospitalDepartment(SPGTypeHelper):
|
||||
|
||||
description = PropertyHelper("description")
|
||||
name = PropertyHelper("name")
|
||||
stdId = PropertyHelper("stdId")
|
||||
alias = PropertyHelper("alias")
|
||||
id = PropertyHelper("id")
|
||||
|
||||
class Indicator(SPGTypeHelper):
|
||||
|
||||
description = PropertyHelper("description")
|
||||
id = PropertyHelper("id")
|
||||
name = PropertyHelper("name")
|
||||
|
||||
class Symptom(SPGTypeHelper):
|
||||
|
||||
description = PropertyHelper("description")
|
||||
id = PropertyHelper("id")
|
||||
name = PropertyHelper("name")
|
||||
|
||||
BodyPart = BodyPart("DeepKE.BodyPart")
|
||||
Disease = Disease("DeepKE.Disease")
|
||||
Drug = Drug("DeepKE.Drug")
|
||||
HospitalDepartment = HospitalDepartment("DeepKE.HospitalDepartment")
|
||||
Indicator = Indicator("DeepKE.Indicator")
|
||||
Symptom = Symptom("DeepKE.Symptom")
|
||||
|
||||
pass
|
||||
33
python/knext/knext/examples/deepke/schema/re.schema
Normal file
33
python/knext/knext/examples/deepke/schema/re.schema
Normal file
@ -0,0 +1,33 @@
|
||||
namespace DeepKE
|
||||
|
||||
Symptom(症状): EntityType
|
||||
desc: 这是一个症状
|
||||
|
||||
Drug(药品): EntityType
|
||||
|
||||
Indicator(医学指征): EntityType
|
||||
|
||||
BodyPart(人体部位): ConceptType
|
||||
hypernymPredicate: isA
|
||||
|
||||
HospitalDepartment(医院科室): ConceptType
|
||||
hypernymPredicate: isA
|
||||
|
||||
Disease(疾病): EntityType
|
||||
properties:
|
||||
complication(并发症): Disease
|
||||
constraint: MultiValue
|
||||
commonSymptom(常见症状): Symptom
|
||||
constraint: MultiValue
|
||||
applicableDrug(适用药品): Drug
|
||||
constraint: MultiValue
|
||||
department(就诊科室): HospitalDepartment
|
||||
constraint: MultiValue
|
||||
diseaseSite(发病部位): BodyPart
|
||||
constraint: MultiValue
|
||||
relations:
|
||||
abnormal(异常指征): Indicator
|
||||
properties:
|
||||
range(指标范围): Text
|
||||
color(颜色): Text
|
||||
shape(性状): Text
|
||||
@ -140,6 +140,7 @@ class SPGSchemaMarkLang:
|
||||
defined_types = {}
|
||||
|
||||
def __init__(self, filename):
|
||||
self.reset()
|
||||
self.schema_file = filename
|
||||
self.current_line_num = 0
|
||||
self.schema = SchemaClient()
|
||||
@ -158,6 +159,27 @@ class SPGSchemaMarkLang:
|
||||
self.internal_type.add(spg_type.name)
|
||||
self.load_script()
|
||||
|
||||
def reset(self):
|
||||
self.internal_type = set()
|
||||
self.entity_internal_property = set()
|
||||
self.event_internal_property = {"eventTime"}
|
||||
self.concept_internal_property = {"stdId", "alias"}
|
||||
self.keyword_type = {"EntityType", "ConceptType", "EventType", "StandardType"}
|
||||
|
||||
self.parsing_register = {
|
||||
RegisterUnit.Type: None,
|
||||
RegisterUnit.Property: None,
|
||||
RegisterUnit.Relation: None,
|
||||
RegisterUnit.SubProperty: None,
|
||||
}
|
||||
self.indent_level_pos = [None, None, None, None, None, None]
|
||||
self.rule_quote_predicate = None
|
||||
self.rule_quote_open = False
|
||||
self.current_parsing_level = 0
|
||||
self.last_indent_level = 0
|
||||
self.namespace = None
|
||||
self.types = {}
|
||||
|
||||
def save_register(self, element: RegisterUnit, value):
|
||||
"""
|
||||
maintain the session for parsing
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user