diff --git a/python/knext/knext/api/auto_prompt.py b/python/knext/knext/api/auto_prompt.py index 246d2020..cfb644fd 100644 --- a/python/knext/knext/api/auto_prompt.py +++ b/python/knext/knext/api/auto_prompt.py @@ -10,10 +10,9 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. -from knext.operator.builtin.auto_prompt import REPrompt, EEPrompt +from knext.operator.builtin.auto_prompt import REPrompt __all__ = [ "REPrompt", - "EEPrompt", ] diff --git a/python/knext/knext/client/reasoner.py b/python/knext/knext/client/reasoner.py index f0477cd6..d5f97202 100644 --- a/python/knext/knext/client/reasoner.py +++ b/python/knext/knext/client/reasoner.py @@ -9,7 +9,7 @@ # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. - +import os from enum import Enum from knext import rest @@ -47,6 +47,38 @@ class ReasonerClient(Client): reasoner_job_submit_request=request ) + def execute(self, dsl_content: str, output_file: str = None): + """ + --projectId 2 \ --query "MATCH (s:`RiskMining.TaxOfRiskUser`/`赌博App开发者`) RETURN s.id,s.name " \ --output ./reasoner.csv \ --schemaUrl "http://localhost:8887" \ --graphStateClass "com.antgroup.openspg.reasoner.warehouse.cloudext.CloudExtGraphState" \ --graphStoreUrl "tugraph://127.0.0.1:9090?graphName=default&timeout=60000&accessId=admin&accessKey=73@TuGraph" \ + """ + + import subprocess + import datetime + from knext import lib + jar_path = os.path.join(lib.__path__[0], lib.LOCAL_REASONER_JAR) + default_output_file = f"./{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv" + + java_cmd = ['java', '-jar', + jar_path, + "--projectId", self._project_id, + "--query", dsl_content, + "--output", output_file or default_output_file, + "--schemaUrl", os.environ.get("KNEXT_HOST_ADDR") or lib.LOCAL_SCHEMA_URL, + "--graphStateClass", os.environ.get("KNEXT_GRAPH_STATE_CLASS") or lib.LOCAL_GRAPH_STATE_CLASS, + "--graphStoreUrl", os.environ.get("KNEXT_GRAPH_STORE_URL") or lib.LOCAL_GRAPH_STORE_URL, + ] + + print_java_cmd = [ + cmd if not cmd.startswith("{") else f"'{cmd}'" for cmd in java_cmd + ] + print_java_cmd = [ + cmd if not cmd.count(";") > 0 else f"'{cmd}'" for cmd in print_java_cmd + ] + import json + print(json.dumps(" ".join(print_java_cmd))[1:-1].replace("'", '"')) + + subprocess.call(java_cmd) + def run_dsl(self, dsl_content: str): """Submit a synchronization reasoner job by providing DSL content.""" content = rest.KgdslReasonerContent(kgdsl=dsl_content) diff --git a/python/knext/knext/command/knext_cli.py b/python/knext/knext/command/knext_cli.py index 5816d9f3..a7bdfbf5 100644 --- a/python/knext/knext/command/knext_cli.py +++ b/python/knext/knext/command/knext_cli.py @@ -26,6 +26,7 @@ from knext.command.sub_command.project import list_project from knext.command.sub_command.reasoner import query_reasoner_job from knext.command.sub_command.reasoner import run_dsl from knext.command.sub_command.reasoner import submit_reasoner_job +from knext.command.sub_command.reasoner import execute_reasoner_job from knext.command.sub_command.schema import commit_schema from knext.command.sub_command.schema import diff_schema from knext.command.sub_command.schema import list_schema @@ -110,6 +111,7 @@ def reasoner() -> None: reasoner.command("submit")(submit_reasoner_job) reasoner.command("query")(run_dsl) reasoner.command("get")(query_reasoner_job) +reasoner.command("execute")(execute_reasoner_job) if __name__ == "__main__": _main() diff --git a/python/knext/knext/command/sub_command/reasoner.py b/python/knext/knext/command/sub_command/reasoner.py index 5264e8ab..ba208d01 100644 --- a/python/knext/knext/command/sub_command/reasoner.py +++ b/python/knext/knext/command/sub_command/reasoner.py @@ -121,3 +121,22 @@ def query_reasoner_job(id): ) else: sys.exit() + + +@click.option("--file", help="Path of DSL file.") +@click.option("--dsl", help="DSL string enclosed in double quotes.") +@click.option("--output", help="Output file.") +def execute_reasoner_job(file, dsl, output=None): + """ + Submit asynchronous reasoner jobs to server by providing DSL file or string. + """ + client = ReasonerClient() + if file and not dsl: + with open(file, "r") as f: + dsl_content = f.read() + elif not file and dsl: + dsl_content = dsl + else: + click.secho("ERROR: Please choose either --file or --dsl.", fg="bright_red") + sys.exit() + client.execute(dsl_content, output) diff --git a/python/knext/knext/examples/financial/builder/job/company.py b/python/knext/knext/examples/financial/builder/job/company.py new file mode 100644 index 00000000..afa71fd7 --- /dev/null +++ b/python/knext/knext/examples/financial/builder/job/company.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +from knext.examples.financial.schema.financial_schema_helper import Financial + +from knext.api.component import CSVReader, LLMBasedExtractor, KGWriter, SubGraphMapping +from knext.client.model.builder_job import BuilderJob +from nn4k.invoker import LLMInvoker + + +class Company(BuilderJob): + def build(self): + source = CSVReader( + local_path="builder/job/data/company.csv", columns=["input"], start_row=2 + ) + + from knext.api.auto_prompt import REPrompt + prompt = REPrompt( + spg_type_name=Financial.Company, + property_names=[ + Financial.Company.name, + Financial.Company.orgCertNo, + Financial.Company.regArea, + Financial.Company.businessScope, + Financial.Company.establishDate, + Financial.Company.legalPerson, + Financial.Company.regCapital + ], + ) + + extract = LLMBasedExtractor( + llm=LLMInvoker.from_config("builder/model/openai_infer.json"), + prompt_ops=[prompt], + ) + + mapping = ( + SubGraphMapping(spg_type_name=Financial.Company) + .add_mapping_field("name", Financial.Company.id) + .add_mapping_field("name", Financial.Company.name) + .add_mapping_field("regArea", Financial.Company.regArea) + .add_mapping_field("businessScope", Financial.Company.businessScope) + .add_mapping_field("establishDate", Financial.Company.establishDate) + .add_mapping_field("legalPerson", Financial.Company.legalPerson) + .add_mapping_field("regCapital", Financial.Company.regCapital) + ) + + sink = KGWriter() + + return source >> extract >> mapping >> sink + + +if __name__ == '__main__': + from knext.api.auto_prompt import REPrompt + prompt = REPrompt( + spg_type_name=Financial.Company, + property_names=[ + Financial.Company.orgCertNo, + Financial.Company.regArea, + Financial.Company.businessScope, + Financial.Company.establishDate, + Financial.Company.legalPerson, + Financial.Company.regCapital + ], + ) + print(prompt.template) diff --git a/python/knext/knext/examples/financial/builder/job/data/company.csv b/python/knext/knext/examples/financial/builder/job/data/company.csv new file mode 100644 index 00000000..a5c58a39 --- /dev/null +++ b/python/knext/knext/examples/financial/builder/job/data/company.csv @@ -0,0 +1,2 @@ +input +阿里巴巴(中国)有限公司是一家从事企业管理,计算机系统服务,电脑动画设计等业务的公司,成立于2007年03月26日,公司坐落在浙江省;经营有阿里邮箱、浙烟邮箱,师生家校、点淘-淘宝直播官方平台、云上会展等产品,经国家企业信用信息公示系统查询得知,阿里巴巴(中国)有限公司的信用代码/税号为91330100799655058B,法人是蒋芳,注册资本为15412.764910万美元,企业的经营范围为:服务:企业管理,计算机系统服务,电脑动画设计,经济信息咨询服务(除商品中介),成年人的非证书劳动职业技能培训和成年人的非文化教育培训(涉及前置审批的项目除外);生产:计算机软件;销售自产产品。(国家禁止和限制的除外,凡涉及许可证制度的凭证经营) \ No newline at end of file diff --git a/python/knext/knext/examples/financial/builder/model/openai_infer.json b/python/knext/knext/examples/financial/builder/model/openai_infer.json index 19dbf2cf..4eee4404 100644 --- a/python/knext/knext/examples/financial/builder/model/openai_infer.json +++ b/python/knext/knext/examples/financial/builder/model/openai_infer.json @@ -1,7 +1,7 @@ { "invoker_type": "OpenAI", "openai_api_key": "EMPTY", - "openai_api_base": "http://localhost:38000/v1", - "openai_model_name": "vicuna-7b-v1.5", - "openai_max_tokens": 1000 + "openai_api_base": "http://127.0.0.1:38080/v1", + "openai_model_name": "gpt-3.5-turbo", + "openai_max_tokens": 2000 } \ No newline at end of file diff --git a/python/knext/knext/examples/financial/builder/operator/IndicatorFuse.py b/python/knext/knext/examples/financial/builder/operator/IndicatorFuse.py index c7ff27be..f6e7fe2a 100644 --- a/python/knext/knext/examples/financial/builder/operator/IndicatorFuse.py +++ b/python/knext/knext/examples/financial/builder/operator/IndicatorFuse.py @@ -13,31 +13,28 @@ class IndicatorFuse(FuseOp): super().__init__() self.search_client = SearchClient("Financial.Indicator") - def link(self, subject_records: List[SPGRecord]) -> List[SPGRecord]: - print("####################IndicatorFuse#####################") + def link(self, subject_record: SPGRecord) -> List[SPGRecord]: + print("####################IndicatorFuse(指标融合)#####################") print("IndicatorFuse(Input): ") print("----------------------") - [print(r) for r in subject_records] + print(subject_record) linked_records = [] - for record in subject_records: - query = {"match": {"name": record.get_property("name", "")}} - recall_records = self.search_client.search(query, start=0, size=10) - if recall_records is not None and len(recall_records) > 0: - linked_records.append(SPGRecord( - "Financial.Indicator", - { - "id": recall_records[0].doc_id, - "name": recall_records[0].properties.get("name", ""), - }, - )) + query = {"match": {"name": subject_record.get_property("name", "")}} + recall_records = self.search_client.search(query, start=0, size=10) + if recall_records is not None and len(recall_records) > 0: + linked_records.append(SPGRecord( + "Financial.Indicator", + { + "id": recall_records[0].doc_id, + "name": recall_records[0].properties.get("name", ""), + }, + )) return linked_records - def merge(self, subject_records: List[SPGRecord], target_records: List[SPGRecord]) -> List[SPGRecord]: + def merge(self, subject_record: SPGRecord, linked_records: List[SPGRecord]) -> List[SPGRecord]: merged_records = [] - for s in subject_records: - if s in target_records: - continue - merged_records.append(s) + if not linked_records: + merged_records.append(subject_record) print("IndicatorFuse(Output): ") print("----------------------") [print(r) for r in merged_records] diff --git a/python/knext/knext/examples/financial/builder/operator/StateFuse.py b/python/knext/knext/examples/financial/builder/operator/StateFuse.py index 6752f69b..22e83c5c 100644 --- a/python/knext/knext/examples/financial/builder/operator/StateFuse.py +++ b/python/knext/knext/examples/financial/builder/operator/StateFuse.py @@ -13,37 +13,29 @@ class StateFuse(FuseOp): super().__init__() self.search_client = SearchClient("Financial.State") - def link(self, subject_records: List[SPGRecord]) -> List[SPGRecord]: + def link(self, subject_record: SPGRecord) -> List[SPGRecord]: print("####################StateFuse(状态融合)#####################") print("StateFuse(Input): ") print("----------------------") - [print(r) for r in subject_records] + print(subject_record) linked_records = [] - for record in subject_records: - query = {"match": {"name": record.get_property("name", "")}} - recall_records = self.search_client.search(query, start=0, size=10) - if recall_records is not None and len(recall_records) > 0: - linked_records.append(SPGRecord( - "Financial.State", - { - "id": recall_records[0].doc_id, - "name": recall_records[0].properties.get("name", ""), - }, - ) - ) - return linked_records - - def merge(self, subject_records: List[SPGRecord], target_records: List[SPGRecord]) -> List[SPGRecord]: - merged_records = [] - for s in subject_records: - # for t in target_records: - merged_records.append(SPGRecord( + query = {"match": {"name": subject_record.get_property("name", "")}} + recall_records = self.search_client.search(query, start=0, size=10) + if recall_records is not None and len(recall_records) > 0: + linked_records.append(SPGRecord( "Financial.State", { - "id": s.get_property("id"), - "name": s.get_property("name", ""), - }) + "id": recall_records[0].doc_id, + "name": recall_records[0].properties.get("name", ""), + }, ) + ) + return linked_records + + def merge(self, subject_record: SPGRecord, linked_records: List[SPGRecord]) -> List[SPGRecord]: + merged_records = [] + if not linked_records: + merged_records.append(subject_record) print("StateFuse(Output): ") print("----------------------") [print(r) for r in merged_records] diff --git a/python/knext/knext/examples/financial/schema/financial_schema_helper.py b/python/knext/knext/examples/financial/schema/financial_schema_helper.py index 0c48b903..3617ce49 100644 --- a/python/knext/knext/examples/financial/schema/financial_schema_helper.py +++ b/python/knext/knext/examples/financial/schema/financial_schema_helper.py @@ -20,6 +20,42 @@ from knext.common.schema_helper import SPGTypeHelper, PropertyHelper class Financial: + class AdministrativeArea(SPGTypeHelper): + description = PropertyHelper("description") + id = PropertyHelper("id") + name = PropertyHelper("name") + stdId = PropertyHelper("stdId") + alias = PropertyHelper("alias") + + class AreaRiskEvent(SPGTypeHelper): + description = PropertyHelper("description") + id = PropertyHelper("id") + name = PropertyHelper("name") + eventTime = PropertyHelper("eventTime") + object = PropertyHelper("object") + subject = PropertyHelper("subject") + + class Company(SPGTypeHelper): + description = PropertyHelper("description") + id = PropertyHelper("id") + name = PropertyHelper("name") + orgCertNo = PropertyHelper("orgCertNo") + establishDate = PropertyHelper("establishDate") + regArea = PropertyHelper("regArea") + regCapital = PropertyHelper("regCapital") + businessScope = PropertyHelper("businessScope") + legalPerson = PropertyHelper("legalPerson") + + class CompanyEvent(SPGTypeHelper): + description = PropertyHelper("description") + id = PropertyHelper("id") + name = PropertyHelper("name") + location = PropertyHelper("location") + eventTime = PropertyHelper("eventTime") + happenedTime = PropertyHelper("happenedTime") + subject = PropertyHelper("subject") + object = PropertyHelper("object") + class Indicator(SPGTypeHelper): description = PropertyHelper("description") id = PropertyHelper("id") @@ -36,6 +72,10 @@ class Financial: derivedFrom = PropertyHelper("derivedFrom") stdId = PropertyHelper("stdId") + AdministrativeArea = AdministrativeArea("Financial.AdministrativeArea") + AreaRiskEvent = AreaRiskEvent("Financial.AreaRiskEvent") + Company = Company("Financial.Company") + CompanyEvent = CompanyEvent("Financial.CompanyEvent") Indicator = Indicator("Financial.Indicator") State = State("Financial.State") \ No newline at end of file diff --git a/python/knext/knext/examples/medical/builder/job/disease.py b/python/knext/knext/examples/medical/builder/job/disease.py index f8f9d1d9..a6a47652 100644 --- a/python/knext/knext/examples/medical/builder/job/disease.py +++ b/python/knext/knext/examples/medical/builder/job/disease.py @@ -37,10 +37,10 @@ class Disease(BuilderJob): spg_type_name="Medical.Disease", property_names=[ "complication", - # "commonSymptom", - # "applicableDrug", - # "department", - # "diseaseSite", + "commonSymptom", + "applicableDrug", + "department", + "diseaseSite", ], ) ], @@ -54,10 +54,10 @@ class Disease(BuilderJob): .add_mapping_field("id", "id") .add_mapping_field("name", "name") .add_mapping_field("complication", "complication") - # .add_mapping_field("commonSymptom", "commonSymptom") - # .add_mapping_field("applicableDrug", "applicableDrug") - # .add_mapping_field("department", "department") - # .add_mapping_field("diseaseSite", "diseaseSite") + .add_mapping_field("commonSymptom", "commonSymptom") + .add_mapping_field("applicableDrug", "applicableDrug") + .add_mapping_field("department", "department") + .add_mapping_field("diseaseSite", "diseaseSite") ) """ diff --git a/python/knext/knext/examples/medical/builder/model/openai_infer.json b/python/knext/knext/examples/medical/builder/model/openai_infer.json index 76620cf9..021b44fa 100644 --- a/python/knext/knext/examples/medical/builder/model/openai_infer.json +++ b/python/knext/knext/examples/medical/builder/model/openai_infer.json @@ -1,7 +1,7 @@ { "invoker_type": "OpenAI", "openai_api_key": "EMPTY", - "openai_api_base": "http://localhost:38080/v1", + "openai_api_base": "http://127.0.0.1:38080/v1", "openai_model_name": "gpt-3.5-turbo", "openai_max_tokens": 1000 } \ No newline at end of file diff --git a/python/knext/knext/examples/medical/schema/medical_schema_helper.py b/python/knext/knext/examples/medical/schema/medical_schema_helper.py index f71f1b4b..8df78b31 100644 --- a/python/knext/knext/examples/medical/schema/medical_schema_helper.py +++ b/python/knext/knext/examples/medical/schema/medical_schema_helper.py @@ -24,18 +24,18 @@ class Medical: description = PropertyHelper("description") id = PropertyHelper("id") name = PropertyHelper("name") - alias = PropertyHelper("alias") stdId = PropertyHelper("stdId") + alias = PropertyHelper("alias") class Disease(SPGTypeHelper): description = PropertyHelper("description") id = PropertyHelper("id") name = PropertyHelper("name") - department = PropertyHelper("department") - complication = PropertyHelper("complication") applicableDrug = PropertyHelper("applicableDrug") - diseaseSite = PropertyHelper("diseaseSite") + department = PropertyHelper("department") commonSymptom = PropertyHelper("commonSymptom") + diseaseSite = PropertyHelper("diseaseSite") + complication = PropertyHelper("complication") class Drug(SPGTypeHelper): description = PropertyHelper("description") @@ -46,8 +46,8 @@ class Medical: description = PropertyHelper("description") id = PropertyHelper("id") name = PropertyHelper("name") - alias = PropertyHelper("alias") stdId = PropertyHelper("stdId") + alias = PropertyHelper("alias") class Indicator(SPGTypeHelper): description = PropertyHelper("description") diff --git a/python/knext/knext/lib/__init__.py b/python/knext/knext/lib/__init__.py index f15b67da..78100a44 100644 --- a/python/knext/knext/lib/__init__.py +++ b/python/knext/knext/lib/__init__.py @@ -1,15 +1,12 @@ -GRAPH_STORE_PARAM = "-Dcloudext.graphstore.drivers=com.antgroup.openspg.cloudext.impl.graphstore.tugraph" \ - ".TuGraphStoreClientDriver" - -SEARCH_CLIENT_PARAM = "-Dcloudext.searchengine.drivers=com.antgroup.openspg.cloudext.impl.searchengine.elasticsearch" \ - ".ElasticSearchEngineClientDriver" LOCAL_BUILDER_JAR = "builder-runner-local-0.0.1-SNAPSHOT-jar-with-dependencies.jar" -LOCAL_REASONER_JAR = "" +LOCAL_REASONER_JAR = "reasoner-local-runner-0.0.1-SNAPSHOT-jar-with-dependencies.jar" LOCAL_SCHEMA_URL = "http://localhost:8887" LOCAL_GRAPH_STORE_URL = "tugraph://127.0.0.1:9090?graphName=default&timeout=50000&accessId=admin&accessKey=73@TuGraph" LOCAL_SEARCH_ENGINE_URL = "elasticsearch://127.0.0.1:9200?scheme=http" + +LOCAL_GRAPH_STATE_CLASS = "com.antgroup.openspg.reasoner.warehouse.cloudext.CloudExtGraphState" diff --git a/python/knext/knext/operator/builtin/auto_prompt.py b/python/knext/knext/operator/builtin/auto_prompt.py index 4353ba2d..c4a3d02b 100644 --- a/python/knext/knext/operator/builtin/auto_prompt.py +++ b/python/knext/knext/operator/builtin/auto_prompt.py @@ -54,6 +54,8 @@ input:${input} def parse_response(self, response: str) -> List[SPGRecord]: result = [] subject = {} + if isinstance(response, list) and len(response) > 0: + response = response[0] re_obj = json.loads(response) if "spo" not in re_obj.keys(): raise ValueError("SPO format error.") @@ -89,21 +91,10 @@ input:${input} result.append(subject_entity) return result - def build_next_variables( - self, variables: Dict[str, str], response: str - ) -> List[Dict[str, str]]: - re_obj = json.loads(response) - if "spo" not in re_obj.keys(): - raise ValueError("SPO format error.") - re = re_obj.get("spo", []) - return [{"input": variables.get("input"), "spo": str(i)} for i in re] - def _render(self, spg_type: BaseSpgType, property_names: List[str]): spos = [] repeat_desc = [] for property_name in property_names: - if property_name in ["id", "name", "description"]: - continue prop = spg_type.properties.get(property_name) object_desc = "" object_type = self.schema_client.query_spg_type(prop.object_type_name) @@ -117,101 +108,3 @@ input:${input} repeat_desc.extend([spg_type.name_zh, prop.name_zh, prop.object_type_name_zh]) schema_text = "\n[" + ",\n".join(spos) + "]" self.template = self.template.replace("${schema}", schema_text) - - -class EEPrompt(AutoPrompt): - template: str = """ -已知如下的事件schema定义:${schema}。从下列句子中抽取所定义的事件,如果存在以JSON格式返回,如果不存在返回空字符串。 -input:${input} -${example} -输出格式为:{"event":[{"event_type":,"arguments":[{},]}]} -"output": - """ - - def __init__( - self, - event_type_name: Union[str, SPGTypeHelper], - property_names: List[Union[str, PropertyHelper]], - custom_prompt: str = None, - ): - super().__init__() - - if custom_prompt: - self.template = custom_prompt - schema_client = SchemaClient() - spg_type = schema_client.query_spg_type(spg_type_name=event_type_name) - self.spg_type_name = event_type_name - self.predicate_zh_to_en_name = {} - self.predicate_type_zh_to_en_name = {} - for k, v in spg_type.properties.items(): - self.predicate_zh_to_en_name[v.name_zh] = k - self.predicate_type_zh_to_en_name[v.name_zh] = v.object_type_name - self._render(spg_type, property_names) - self.params = { - "spg_type_name": event_type_name, - "property_names": property_names, - "custom_prompt": custom_prompt, - } - - def build_prompt(self, variables: Dict[str, str]) -> str: - return self.template.replace("${input}", variables.get("input")) - - def parse_response(self, response: str) -> List[SPGRecord]: - response = "{\"event\":[{\"event_type\":\"区域经济指标事件(区域指标)\",\"arguments\":[{\"日期\":\"2022年\",\"区域\":\"济南市\",\"来源\":\"政府公告\",\"主体\":\"山东省财政局\"}]}]}" - result = [] - subject = {} - re_obj = json.loads(response) - if "event" not in re_obj.keys(): - raise ValueError("Event format error.") - subject_properties = {} - for spo_item in re_obj.get("event", []): - if spo_item["predicate"] not in self.predicate_zh_to_en_name: - continue - subject_properties = { - "id": spo_item["subject"], - "name": spo_item["subject"], - } - if spo_item["subject"] not in subject: - subject[spo_item["subject"]] = subject_properties - else: - subject_properties = subject[spo_item["subject"]] - - spo_en_name = self.predicate_zh_to_en_name[spo_item["predicate"]] - - if spo_en_name in subject_properties and len( - subject_properties[spo_en_name] - ): - subject_properties[spo_en_name] = ( - subject_properties[spo_en_name] + "," + spo_item["object"] - ) - else: - subject_properties[spo_en_name] = spo_item["object"] - - # for k, val in subject.items(): - subject_entity = SPGRecord( - spg_type_name=self.spg_type_name, properties=subject_properties - ) - result.append(subject_entity) - return result - - def build_next_variables( - self, variables: Dict[str, str], response: str - ) -> List[Dict[str, str]]: - re_obj = json.loads(response) - if "event" not in re_obj.keys(): - raise ValueError("Event format error.") - re = re_obj.get("event", []) - return [{"input": variables.get("input"), "event": str(i)} for i in re] - - def _render(self, spg_type: BaseSpgType, property_names: List[str]): - arguments = [] - for property_name in property_names: - if property_name in ["id", "name", "description"]: - continue - prop = spg_type.properties.get(property_name) - arguments.append( - f"{prop.name_zh}({prop.object_type_name_zh})" - ) - - schema_text = f"{{event_type:{spg_type.name_zh}({spg_type.desc or spg_type.name_zh}),arguments:[{','.join(arguments)}]" - self.template = self.template.replace("${schema}", schema_text) diff --git a/python/knext/knext/operator/builtin/online_runner.py b/python/knext/knext/operator/builtin/online_runner.py index 3bfd6e58..d575562b 100644 --- a/python/knext/knext/operator/builtin/online_runner.py +++ b/python/knext/knext/operator/builtin/online_runner.py @@ -60,8 +60,9 @@ class _BuiltInOnlineExtractor(ExtractOp): elif op_name == "IndicatorLOGIC": response = '[{"subject": "土地出让收入大幅下降", "predicate": "顺承", "object": ["综合财力明显下滑"]}]' else: - print(query) + print(repr(query)) response = self.model.remote_inference(query) + print(response) collector.extend(op.parse_response(response)) next_params.extend( op.build_next_variables(input_param, response) diff --git a/python/knext/knext/operator/op.py b/python/knext/knext/operator/op.py index 653f59d6..b1417b64 100644 --- a/python/knext/knext/operator/op.py +++ b/python/knext/knext/operator/op.py @@ -89,19 +89,22 @@ class FuseOp(BaseOp, ABC): def __init__(self, params: Dict[str, str] = None): super().__init__(params) - def link(self, subject_records: List[SPGRecord]) -> List[SPGRecord]: + def link(self, subject_record: SPGRecord) -> List[SPGRecord]: raise NotImplementedError( f"{self.__class__.__name__} need to implement `link` method." ) - def merge(self, subject_records: List[SPGRecord], target_records: List[SPGRecord]) -> List[SPGRecord]: + def merge(self, subject_record: SPGRecord, linked_records: List[SPGRecord]) -> List[SPGRecord]: raise NotImplementedError( f"{self.__class__.__name__} need to implement `merge` method." ) - def invoke(self, records: List[SPGRecord]) -> List[SPGRecord]: - linked_records = self.link(records) - return self.merge(records, linked_records) + def invoke(self, subject_records: List[SPGRecord]) -> List[SPGRecord]: + for record in subject_records: + linked_records = self.link(record) + merged_records = self.merge(record, linked_records) + return merged_records + return [] @staticmethod def _pre_process(*inputs): @@ -136,6 +139,8 @@ class PromptOp(BaseOp, ABC): def build_next_variables( self, variables: Dict[str, str], response: str ) -> List[Dict[str, str]]: + if isinstance(response, list) and len(response) > 0: + response = response[0] variables.update({f"{self.__class__.__name__}": response}) print("LLM(Output): ") print("----------------------")