diff --git a/python/tests/chain_test.py b/python/tests/chain_test.py index b402e16c..c028c201 100644 --- a/python/tests/chain_test.py +++ b/python/tests/chain_test.py @@ -1,33 +1,8 @@ -import networkx as nx -import matplotlib.pyplot as plt +import json -from knext.api.component import SPGTypeMapping -from knext.api.component import KGSinkWriter -from knext.api.component import CsvSourceReader +response_str = "[{'财政': ['财政收入质量', '财政自给能力', '土地出让收入', '一般公共预算收入', '留抵退税', '税收收入', '税收收入/一般公共预算收入', '一般公共预算支出', '财政自给率', '政府性基金收入', '转移性收入', '综合财力']}]" -if __name__ == '__main__': - source = CsvSourceReader( - local_path="./builder/job/data/BodyPart.csv", columns=["id"], start_row=1 - ) +response_str = response_str.replace("'", "\"") - mapping1 = SPGTypeMapping(spg_type_name="Medical.BodyPart").add_field( - "id", "Medical.BodyPart.id" - ) +output_list = json.loads(response_str) - mapping2 = SPGTypeMapping(spg_type_name="Medical.BodyPart").add_field( - "id", "Medical.BodyPart.id1" - ) - - sink = KGSinkWriter() - sink2 = KGSinkWriter() - - builder_chain = source >> mapping1 >> sink2 - - print(builder_chain.dag) - - # G = builder_chain.dag - # # 绘制图形 - # # nx.draw(G, with_labels=True, arrows=True) - # - # # 显示图形 - # plt.show() diff --git a/python/tests/disease_builder_job.py b/python/tests/disease_builder_job.py index 56e42501..9144392d 100644 --- a/python/tests/disease_builder_job.py +++ b/python/tests/disease_builder_job.py @@ -1,9 +1,9 @@ from knext.client.model.builder_job import BuilderJob from knext.component.builder import ( - CsvSourceReader, + CSVReader, SPGTypeMapping, LLMBasedExtractor, - KGSinkWriter, + KGWriter, ) @@ -13,7 +13,7 @@ class Disease(BuilderJob): """ 1. 定义输入源,CSV文件,其中CSV文件每一行为一段文本 """ - source = CsvSourceReader( + source = CSVReader( local_path="Disease.csv", columns=["content"], start_row=2, @@ -29,7 +29,7 @@ class Disease(BuilderJob): """ 3. 定义输出到图谱 """ - sink = SinkToKgComponent() + sink = KGWriter() """ 4. 完整Pipeline定义 diff --git a/python/tests/medical_case.py b/python/tests/medical_case.py index f6021b1b..78666cdd 100644 --- a/python/tests/medical_case.py +++ b/python/tests/medical_case.py @@ -1,9 +1,10 @@ from typing import Dict, List from knext.client.model.builder_job import BuilderJob -from knext.component.builder import CsvSourceReader, LLMBasedExtractor, KGSinkWriter -from knext.component.builder.mapping import SubGraphMapping +from knext.component.builder import CSVReader, LLMBasedExtractor, KGWriter +from knext.component.builder.mapping import SubGraphMapping, SPGTypeMapping from knext.examples.medical.schema.medical_schema_helper import Medical +from knext.operator.builtin.auto_prompt import SPOPrompt from knext.operator.op import PromptOp from knext.operator.spg_record import SPGRecord from nn4k.invoker import NNInvoker, LLMInvoker @@ -90,32 +91,21 @@ class Disease(BuilderJob): """ 1. 定义输入源,CSV文件,其中CSV文件每一行为一段文本 """ - source = CsvSourceReader( - local_path="./builder/job/data/Disease.csv", + source = CSVReader( + local_path="Disease.csv", columns=["content"], start_row=2, ) - - """ - [ - SPGRecord(spg_type_name="Medical.Disease", properties={"id": "甲状腺结节", "name": "甲状腺结节", "description": "这是病", "body_part": "甲状腺,123"}), - SPGRecord(spg_type_name="Medical.BodyPart", properties={"id": "甲状腺", "name": "甲状腺"}) - ] - """ - - extract = LLMBasedExtractor(llm=OpenAIInvoker.from_config("./config.json"), prompt_ops=[DiseaseREPromptOp]) + spo_prompt = SPOPrompt( + spg_type_name=Medical.Disease, + property_names=[Medical.Disease.bodyPart, Medical.Disease.commonSymptom]) + extract = LLMBasedExtractor(llm=OpenAIInvoker.from_config("./config.json"), prompt_ops=[]) """ 2. 指定SPG知识映射组件,设置抽取算子,从长文本中抽取多种实体类型 """ - mapping = SubGraphMapping()\ - .add_pattern() - .subject_type('Medical.Disease')\ - .add_field('body_part', Medical.Disease.bodyPart, link_strategy="ID_EQUAL")\ - .object_type('Medical.BodyPart')\ - .sub # mapping_schema = [ # { @@ -141,10 +131,18 @@ class Disease(BuilderJob): """ 3. 定义输出到图谱 """ - sink = KGSinkWriter() + sink = KGWriter() """ 4. 完整Pipeline定义 """ - return source >> extract >> mapping >> sink + return source >> mapping >> sink + + +d = Disease() +chain = d.build() + +print(chain) + +chain.invoke()