openspg/python/tests/medical_case.py
2023-12-11 11:28:28 +08:00

151 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Dict, List
from knext.client.model.builder_job import BuilderJob
from knext.component.builder import CsvSourceReader, LLMBasedExtractor, KGSinkWriter
from knext.component.builder.mapping import SubGraphMapping
from knext.examples.medical.schema.medical_schema_helper import Medical
from knext.operator.op import PromptOp
from knext.operator.spg_record import SPGRecord
from nn4k.invoker import NNInvoker, LLMInvoker
from nn4k.invoker.openai_invoker import OpenAIInvoker
class DiseaseREPromptOp(PromptOp):
template = """
假设你是一个专业的医学专家请从文本中抽取关系。我们会首先提供文本然后会提供知识图谱schema再提供回答的具体要求最后是一个举例。
----文本----
{re_input}
----知识图谱schema----
${schema}
----回答要求----
1. 答案格式为json格式[{"subject":,"predicate":,"object":},]
2. object要求简洁必须是中文如果object包含多个值请用英文逗号分隔
3. 每一条关系必须属于知识图谱schema。
----举例----
文本为:急性扁桃体炎通常伴有咽痛,声嘶,发热等症状。回答为:{"subject":"急性扁桃体炎","predicate":"症状","object":"咽痛,声嘶,发热"}
"""
def build_prompt(self, record: Dict[str, str]) -> str:
"""
record: {"input": "甲状腺结节是指在甲状腺内的肿块,可随吞咽动作随甲状腺而上下移动,是临床常见的病症......."}
"""
return self.template.format(input=record.get("input"))
def parse_response(self, response: str) -> List[SPGRecord]:
"""
默认解析逻辑:
response: [{"subject":"甲状腺结节","predicate":"发病位置","object":"甲状腺"},
{"subject":"急性扁桃体炎","predicate":"症状","object":"咽痛,声嘶,发热"}
]
->
[{"id": "甲状腺结节", "name": "甲状腺结节", "bodyPart": "甲状腺"},
{"id": "急性扁桃体炎", "name": "急性扁桃体炎", "commonSymptom": "咽痛,声嘶,发热"}
]
"""
pass
class DiseaseNERPromptOp:
template = """
已知实体类型(entity_type)包括:${schema}
假设你是一个专业的医学专家,请从下列文本中抽取所有实体(entity)。
----文本----
{input}
----回答要求----
1. 答案格式为:[{"entity": ,"entity_type": },]
"""
def build_prompt(self, record: Dict[str, str]) -> str:
"""
record: {"id": "急性扁桃体炎", "name": "急性扁桃体炎", "commonSymptom": "咽痛,声嘶,发热" "ner_input": "咽痛,声嘶,发热", "input": "..."}
"""
return self.template.format(input=record.get("ner_input"))
def parse_response(self, response: str) -> List[Dict[str, str]]:
"""
response: [{"entity": "咽痛", "entity_type": "症状"},
{"entity": "声嘶", "entity_type": "症状"},
{"entity": "发热", "entity_type": "症状"}
]
->
[{"id": "咽痛", "name": "咽痛", "bodyPart": "甲状腺", "ner_input": "甲状腺"}),
SPGRecord("spg_type_name": "Medical.Disease", "properties": {"id": "急性扁桃体炎", "name": "急性扁桃体炎", "commonSymptom": "咽痛,声嘶,发热" "ner_input": "咽痛,声嘶,发热"})
]
"""
pass
class BodyPartLinkOp:
pass
class Disease(BuilderJob):
def build(self):
"""
1. 定义输入源CSV文件其中CSV文件每一行为一段文本
"""
source = CsvSourceReader(
local_path="./builder/job/data/Disease.csv",
columns=["content"],
start_row=2,
)
"""
[
SPGRecord(spg_type_name="Medical.Disease", properties={"id": "甲状腺结节", "name": "甲状腺结节", "description": "这是病", "body_part": "甲状腺,123"}),
SPGRecord(spg_type_name="Medical.BodyPart", properties={"id": "甲状腺", "name": "甲状腺"})
]
"""
extract = LLMBasedExtractor(llm=OpenAIInvoker.from_config("./config.json"), prompt_ops=[DiseaseREPromptOp])
"""
2. 指定SPG知识映射组件设置抽取算子从长文本中抽取多种实体类型
"""
mapping = SubGraphMapping()\
.add_pattern()
.subject_type('Medical.Disease')\
.add_field('body_part', Medical.Disease.bodyPart, link_strategy="ID_EQUAL")\
.object_type('Medical.BodyPart')\
.sub
# mapping_schema = [
# {
# "identifier": "Medical.Disease",
# "property_name": "bodyPart",
# "link_strategy": "id_equal",
# },
# {
# "identifier": "Medical.Disease",
# "property_name": "description",
# }
# ]
#
# mapping_config = [
# {
# "identifier": "Medical.Disease",
# "source": "bodyPart",
# "target": "bodyPart"
# }
# ]
"""
3. 定义输出到图谱
"""
sink = KGSinkWriter()
"""
4. 完整Pipeline定义
"""
return source >> extract >> mapping >> sink