openspg/python/tests/medical_case.py

149 lines
4.9 KiB
Python
Raw Normal View History

2023-12-11 10:44:37 +08:00
from typing import Dict, List
from knext.client.model.builder_job import BuilderJob
2023-12-22 22:09:46 +08:00
from knext.component.builder import CSVReader, LLMBasedExtractor, KGWriter
from knext.component.builder.mapping import SubGraphMapping, SPGTypeMapping
2023-12-11 10:44:37 +08:00
from knext.examples.medical.schema.medical_schema_helper import Medical
2023-12-22 22:09:46 +08:00
from knext.operator.builtin.auto_prompt import SPOPrompt
2023-12-11 10:44:37 +08:00
from knext.operator.op import PromptOp
from knext.operator.spg_record import SPGRecord
from nn4k.invoker import NNInvoker, LLMInvoker
from nn4k.invoker.openai_invoker import OpenAIInvoker
class DiseaseREPromptOp(PromptOp):
template = """
假设你是一个专业的医学专家请从文本中抽取关系我们会首先提供文本然后会提供知识图谱schema再提供回答的具体要求最后是一个举例
----文本----
{re_input}
----知识图谱schema----
${schema}
----回答要求----
1. 答案格式为json格式[{"subject":,"predicate":,"object":},]
2. object要求简洁必须是中文如果object包含多个值请用英文逗号分隔
3. 每一条关系必须属于知识图谱schema
----举例----
文本为急性扁桃体炎通常伴有咽痛声嘶发热等症状回答为{"subject":"急性扁桃体炎","predicate":"症状","object":"咽痛,声嘶,发热"}
"""
def build_prompt(self, record: Dict[str, str]) -> str:
"""
record: {"input": "甲状腺结节是指在甲状腺内的肿块,可随吞咽动作随甲状腺而上下移动,是临床常见的病症......."}
"""
return self.template.format(input=record.get("input"))
def parse_response(self, response: str) -> List[SPGRecord]:
"""
默认解析逻辑
response: [{"subject":"甲状腺结节","predicate":"发病位置","object":"甲状腺"},
{"subject":"急性扁桃体炎","predicate":"症状","object":"咽痛,声嘶,发热"}
]
->
[{"id": "甲状腺结节", "name": "甲状腺结节", "bodyPart": "甲状腺"},
{"id": "急性扁桃体炎", "name": "急性扁桃体炎", "commonSymptom": "咽痛,声嘶,发热"}
]
"""
pass
class DiseaseNERPromptOp:
template = """
已知实体类型(entity_type)包括:${schema}
假设你是一个专业的医学专家请从下列文本中抽取所有实体(entity)
----文本----
{input}
----回答要求----
1. 答案格式为[{"entity": ,"entity_type": },]
"""
def build_prompt(self, record: Dict[str, str]) -> str:
"""
record: {"id": "急性扁桃体炎", "name": "急性扁桃体炎", "commonSymptom": "咽痛,声嘶,发热" "ner_input": "咽痛,声嘶,发热", "input": "..."}
"""
return self.template.format(input=record.get("ner_input"))
def parse_response(self, response: str) -> List[Dict[str, str]]:
"""
response: [{"entity": "咽痛", "entity_type": "症状"},
{"entity": "声嘶", "entity_type": "症状"},
{"entity": "发热", "entity_type": "症状"}
]
->
[{"id": "咽痛", "name": "咽痛", "bodyPart": "甲状腺", "ner_input": "甲状腺"}),
SPGRecord("spg_type_name": "Medical.Disease", "properties": {"id": "急性扁桃体炎", "name": "急性扁桃体炎", "commonSymptom": "咽痛,声嘶,发热" "ner_input": "咽痛,声嘶,发热"})
]
"""
pass
class BodyPartLinkOp:
pass
class Disease(BuilderJob):
def build(self):
"""
1. 定义输入源CSV文件其中CSV文件每一行为一段文本
"""
2023-12-22 22:09:46 +08:00
source = CSVReader(
local_path="Disease.csv",
2023-12-11 10:44:37 +08:00
columns=["content"],
start_row=2,
)
2023-12-22 22:09:46 +08:00
spo_prompt = SPOPrompt(
spg_type_name=Medical.Disease,
property_names=[Medical.Disease.bodyPart, Medical.Disease.commonSymptom])
extract = LLMBasedExtractor(llm=OpenAIInvoker.from_config("./config.json"), prompt_ops=[])
2023-12-11 10:44:37 +08:00
"""
2. 指定SPG知识映射组件设置抽取算子从长文本中抽取多种实体类型
"""
# mapping_schema = [
# {
# "identifier": "Medical.Disease",
# "property_name": "bodyPart",
# "link_strategy": "id_equal",
# },
# {
# "identifier": "Medical.Disease",
# "property_name": "description",
# }
# ]
#
# mapping_config = [
# {
# "identifier": "Medical.Disease",
# "source": "bodyPart",
# "target": "bodyPart"
# }
# ]
"""
3. 定义输出到图谱
"""
2023-12-22 22:09:46 +08:00
sink = KGWriter()
2023-12-11 10:44:37 +08:00
"""
4. 完整Pipeline定义
"""
2023-12-22 22:09:46 +08:00
return source >> mapping >> sink
d = Disease()
chain = d.build()
print(chain)
chain.invoke()