This commit is contained in:
Qu 2023-12-28 21:01:04 +08:00
parent 6a8e597b26
commit d205c894d3
17 changed files with 227 additions and 184 deletions

View File

@ -10,10 +10,9 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from knext.operator.builtin.auto_prompt import REPrompt, EEPrompt
from knext.operator.builtin.auto_prompt import REPrompt
__all__ = [
"REPrompt",
"EEPrompt",
]

View File

@ -9,7 +9,7 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from enum import Enum
from knext import rest
@ -47,6 +47,38 @@ class ReasonerClient(Client):
reasoner_job_submit_request=request
)
def execute(self, dsl_content: str, output_file: str = None):
"""
--projectId 2 \ --query "MATCH (s:`RiskMining.TaxOfRiskUser`/`赌博App开发者`) RETURN s.id,s.name " \ --output ./reasoner.csv \ --schemaUrl "http://localhost:8887" \ --graphStateClass "com.antgroup.openspg.reasoner.warehouse.cloudext.CloudExtGraphState" \ --graphStoreUrl "tugraph://127.0.0.1:9090?graphName=default&timeout=60000&accessId=admin&accessKey=73@TuGraph" \
"""
import subprocess
import datetime
from knext import lib
jar_path = os.path.join(lib.__path__[0], lib.LOCAL_REASONER_JAR)
default_output_file = f"./{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv"
java_cmd = ['java', '-jar',
jar_path,
"--projectId", self._project_id,
"--query", dsl_content,
"--output", output_file or default_output_file,
"--schemaUrl", os.environ.get("KNEXT_HOST_ADDR") or lib.LOCAL_SCHEMA_URL,
"--graphStateClass", os.environ.get("KNEXT_GRAPH_STATE_CLASS") or lib.LOCAL_GRAPH_STATE_CLASS,
"--graphStoreUrl", os.environ.get("KNEXT_GRAPH_STORE_URL") or lib.LOCAL_GRAPH_STORE_URL,
]
# Quote arguments that contain "{" or ";" so the echoed command can be copy-pasted into a shell
print_java_cmd = [
cmd if not cmd.startswith("{") else f"'{cmd}'" for cmd in java_cmd
]
print_java_cmd = [
cmd if not cmd.count(";") > 0 else f"'{cmd}'" for cmd in print_java_cmd
]
import json
print(json.dumps(" ".join(print_java_cmd))[1:-1].replace("'", '"'))
subprocess.call(java_cmd)
def run_dsl(self, dsl_content: str):
"""Submit a synchronization reasoner job by providing DSL content."""
content = rest.KgdslReasonerContent(kgdsl=dsl_content)
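For reference, a minimal sketch of calling the new local execution path from Python. The ReasonerClient import path is an assumption (this diff does not show the module it lives in); the query, output path, and endpoint values are taken from the docstring and the knext.lib defaults referenced above.
import os

from knext.client.reasoner import ReasonerClient  # assumed module path

# execute() consults these environment variables before falling back to the
# LOCAL_* defaults defined in knext.lib.
os.environ.setdefault("KNEXT_HOST_ADDR", "http://localhost:8887")
os.environ.setdefault("KNEXT_GRAPH_STATE_CLASS", "com.antgroup.openspg.reasoner.warehouse.cloudext.CloudExtGraphState")
os.environ.setdefault("KNEXT_GRAPH_STORE_URL", "tugraph://127.0.0.1:9090?graphName=default&timeout=60000&accessId=admin&accessKey=73@TuGraph")

client = ReasonerClient()
client.execute(
    "MATCH (s:`RiskMining.TaxOfRiskUser`/`赌博App开发者`) RETURN s.id,s.name",
    output_file="./reasoner.csv",
)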

View File

@ -26,6 +26,7 @@ from knext.command.sub_command.project import list_project
from knext.command.sub_command.reasoner import query_reasoner_job
from knext.command.sub_command.reasoner import run_dsl
from knext.command.sub_command.reasoner import submit_reasoner_job
from knext.command.sub_command.reasoner import execute_reasoner_job
from knext.command.sub_command.schema import commit_schema
from knext.command.sub_command.schema import diff_schema
from knext.command.sub_command.schema import list_schema
@ -110,6 +111,7 @@ def reasoner() -> None:
reasoner.command("submit")(submit_reasoner_job)
reasoner.command("query")(run_dsl)
reasoner.command("get")(query_reasoner_job)
reasoner.command("execute")(execute_reasoner_job)
if __name__ == "__main__":
_main()

View File

@ -121,3 +121,22 @@ def query_reasoner_job(id):
)
else:
sys.exit()
@click.option("--file", help="Path of DSL file.")
@click.option("--dsl", help="DSL string enclosed in double quotes.")
@click.option("--output", help="Output file.")
def execute_reasoner_job(file, dsl, output=None):
"""
Submit asynchronous reasoner jobs to server by providing DSL file or string.
"""
client = ReasonerClient()
if file and not dsl:
with open(file, "r") as f:
dsl_content = f.read()
elif not file and dsl:
dsl_content = dsl
else:
click.secho("ERROR: Please choose either --file or --dsl.", fg="bright_red")
sys.exit()
client.execute(dsl_content, output)
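The new subcommand can also be driven from scripts; a sketch via subprocess, assuming the console script is installed as `knext` (the entry-point name is not shown in this diff) and using an illustrative DSL file path.
import subprocess

# Equivalent to: knext reasoner execute --file ./query.dsl --output ./reasoner.csv
subprocess.run(
    ["knext", "reasoner", "execute", "--file", "./query.dsl", "--output", "./reasoner.csv"],
    check=True,
)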

View File

@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
from knext.examples.financial.schema.financial_schema_helper import Financial
from knext.api.component import CSVReader, LLMBasedExtractor, KGWriter, SubGraphMapping
from knext.client.model.builder_job import BuilderJob
from nn4k.invoker import LLMInvoker
class Company(BuilderJob):
def build(self):
source = CSVReader(
local_path="builder/job/data/company.csv", columns=["input"], start_row=2
)
from knext.api.auto_prompt import REPrompt
prompt = REPrompt(
spg_type_name=Financial.Company,
property_names=[
Financial.Company.name,
Financial.Company.orgCertNo,
Financial.Company.regArea,
Financial.Company.businessScope,
Financial.Company.establishDate,
Financial.Company.legalPerson,
Financial.Company.regCapital
],
)
extract = LLMBasedExtractor(
llm=LLMInvoker.from_config("builder/model/openai_infer.json"),
prompt_ops=[prompt],
)
mapping = (
SubGraphMapping(spg_type_name=Financial.Company)
.add_mapping_field("name", Financial.Company.id)
.add_mapping_field("name", Financial.Company.name)
.add_mapping_field("regArea", Financial.Company.regArea)
.add_mapping_field("businessScope", Financial.Company.businessScope)
.add_mapping_field("establishDate", Financial.Company.establishDate)
.add_mapping_field("legalPerson", Financial.Company.legalPerson)
.add_mapping_field("regCapital", Financial.Company.regCapital)
)
sink = KGWriter()
return source >> extract >> mapping >> sink
if __name__ == '__main__':
from knext.api.auto_prompt import REPrompt
prompt = REPrompt(
spg_type_name=Financial.Company,
property_names=[
Financial.Company.orgCertNo,
Financial.Company.regArea,
Financial.Company.businessScope,
Financial.Company.establishDate,
Financial.Company.legalPerson,
Financial.Company.regCapital
],
)
print(prompt.template)

View File

@ -0,0 +1,2 @@
input
阿里巴巴(中国)有限公司是一家从事企业管理,计算机系统服务,电脑动画设计等业务的公司,成立于2007年03月26日,公司坐落在浙江省;经营有阿里邮箱、浙烟邮箱,师生家校、点淘-淘宝直播官方平台、云上会展等产品,经国家企业信用信息公示系统查询得知,阿里巴巴(中国)有限公司的信用代码/税号为91330100799655058B,法人是蒋芳,注册资本为15412.764910万美元,企业的经营范围为:服务:企业管理,计算机系统服务,电脑动画设计,经济信息咨询服务(除商品中介),成年人的非证书劳动职业技能培训和成年人的非文化教育培训(涉及前置审批的项目除外);生产:计算机软件;销售自产产品。(国家禁止和限制的除外,凡涉及许可证制度的凭证经营)

View File

@ -1,7 +1,7 @@
{
"invoker_type": "OpenAI",
"openai_api_key": "EMPTY",
"openai_api_base": "http://localhost:38000/v1",
"openai_model_name": "vicuna-7b-v1.5",
"openai_max_tokens": 1000
"openai_api_base": "http://127.0.0.1:38080/v1",
"openai_model_name": "gpt-3.5-turbo",
"openai_max_tokens": 2000
}
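The financial example now targets an OpenAI-compatible endpoint serving gpt-3.5-turbo on 127.0.0.1:38080. A quick sanity-check sketch, assuming the invoker returned by LLMInvoker.from_config exposes the same remote_inference call used by the built-in extractor later in this commit; the prompt text is illustrative.
from nn4k.invoker import LLMInvoker

# Load the endpoint configured above and send a single test prompt.
invoker = LLMInvoker.from_config("builder/model/openai_infer.json")
print(invoker.remote_inference("List the legal person and registered capital of Alibaba (China) Co., Ltd."))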

View File

@ -13,31 +13,28 @@ class IndicatorFuse(FuseOp):
super().__init__()
self.search_client = SearchClient("Financial.Indicator")
def link(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
print("####################IndicatorFuse#####################")
def link(self, subject_record: SPGRecord) -> List[SPGRecord]:
print("####################IndicatorFuse(指标融合)#####################")
print("IndicatorFuse(Input): ")
print("----------------------")
[print(r) for r in subject_records]
print(subject_record)
linked_records = []
for record in subject_records:
query = {"match": {"name": record.get_property("name", "")}}
recall_records = self.search_client.search(query, start=0, size=10)
if recall_records is not None and len(recall_records) > 0:
linked_records.append(SPGRecord(
"Financial.Indicator",
{
"id": recall_records[0].doc_id,
"name": recall_records[0].properties.get("name", ""),
},
))
query = {"match": {"name": subject_record.get_property("name", "")}}
recall_records = self.search_client.search(query, start=0, size=10)
if recall_records is not None and len(recall_records) > 0:
linked_records.append(SPGRecord(
"Financial.Indicator",
{
"id": recall_records[0].doc_id,
"name": recall_records[0].properties.get("name", ""),
},
))
return linked_records
def merge(self, subject_records: List[SPGRecord], target_records: List[SPGRecord]) -> List[SPGRecord]:
def merge(self, subject_record: SPGRecord, linked_records: List[SPGRecord]) -> List[SPGRecord]:
merged_records = []
for s in subject_records:
if s in target_records:
continue
merged_records.append(s)
if not linked_records:
merged_records.append(subject_record)
print("IndicatorFuse(Output): ")
print("----------------------")
[print(r) for r in merged_records]

View File

@ -13,37 +13,29 @@ class StateFuse(FuseOp):
super().__init__()
self.search_client = SearchClient("Financial.State")
def link(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
def link(self, subject_record: SPGRecord) -> List[SPGRecord]:
print("####################StateFuse(状态融合)#####################")
print("StateFuse(Input): ")
print("----------------------")
[print(r) for r in subject_records]
print(subject_record)
linked_records = []
for record in subject_records:
query = {"match": {"name": record.get_property("name", "")}}
recall_records = self.search_client.search(query, start=0, size=10)
if recall_records is not None and len(recall_records) > 0:
linked_records.append(SPGRecord(
"Financial.State",
{
"id": recall_records[0].doc_id,
"name": recall_records[0].properties.get("name", ""),
},
)
)
return linked_records
def merge(self, subject_records: List[SPGRecord], target_records: List[SPGRecord]) -> List[SPGRecord]:
merged_records = []
for s in subject_records:
# for t in target_records:
merged_records.append(SPGRecord(
query = {"match": {"name": subject_record.get_property("name", "")}}
recall_records = self.search_client.search(query, start=0, size=10)
if recall_records is not None and len(recall_records) > 0:
linked_records.append(SPGRecord(
"Financial.State",
{
"id": s.get_property("id"),
"name": s.get_property("name", ""),
})
"id": recall_records[0].doc_id,
"name": recall_records[0].properties.get("name", ""),
},
)
)
return linked_records
def merge(self, subject_record: SPGRecord, linked_records: List[SPGRecord]) -> List[SPGRecord]:
merged_records = []
if not linked_records:
merged_records.append(subject_record)
print("StateFuse(Output): ")
print("----------------------")
[print(r) for r in merged_records]

View File

@ -20,6 +20,42 @@ from knext.common.schema_helper import SPGTypeHelper, PropertyHelper
class Financial:
class AdministrativeArea(SPGTypeHelper):
description = PropertyHelper("description")
id = PropertyHelper("id")
name = PropertyHelper("name")
stdId = PropertyHelper("stdId")
alias = PropertyHelper("alias")
class AreaRiskEvent(SPGTypeHelper):
description = PropertyHelper("description")
id = PropertyHelper("id")
name = PropertyHelper("name")
eventTime = PropertyHelper("eventTime")
object = PropertyHelper("object")
subject = PropertyHelper("subject")
class Company(SPGTypeHelper):
description = PropertyHelper("description")
id = PropertyHelper("id")
name = PropertyHelper("name")
orgCertNo = PropertyHelper("orgCertNo")
establishDate = PropertyHelper("establishDate")
regArea = PropertyHelper("regArea")
regCapital = PropertyHelper("regCapital")
businessScope = PropertyHelper("businessScope")
legalPerson = PropertyHelper("legalPerson")
class CompanyEvent(SPGTypeHelper):
description = PropertyHelper("description")
id = PropertyHelper("id")
name = PropertyHelper("name")
location = PropertyHelper("location")
eventTime = PropertyHelper("eventTime")
happenedTime = PropertyHelper("happenedTime")
subject = PropertyHelper("subject")
object = PropertyHelper("object")
class Indicator(SPGTypeHelper):
description = PropertyHelper("description")
id = PropertyHelper("id")
@ -36,6 +72,10 @@ class Financial:
derivedFrom = PropertyHelper("derivedFrom")
stdId = PropertyHelper("stdId")
AdministrativeArea = AdministrativeArea("Financial.AdministrativeArea")
AreaRiskEvent = AreaRiskEvent("Financial.AreaRiskEvent")
Company = Company("Financial.Company")
CompanyEvent = CompanyEvent("Financial.CompanyEvent")
Indicator = Indicator("Financial.Indicator")
State = State("Financial.State")

View File

@ -37,10 +37,10 @@ class Disease(BuilderJob):
spg_type_name="Medical.Disease",
property_names=[
"complication",
# "commonSymptom",
# "applicableDrug",
# "department",
# "diseaseSite",
"commonSymptom",
"applicableDrug",
"department",
"diseaseSite",
],
)
],
@ -54,10 +54,10 @@ class Disease(BuilderJob):
.add_mapping_field("id", "id")
.add_mapping_field("name", "name")
.add_mapping_field("complication", "complication")
# .add_mapping_field("commonSymptom", "commonSymptom")
# .add_mapping_field("applicableDrug", "applicableDrug")
# .add_mapping_field("department", "department")
# .add_mapping_field("diseaseSite", "diseaseSite")
.add_mapping_field("commonSymptom", "commonSymptom")
.add_mapping_field("applicableDrug", "applicableDrug")
.add_mapping_field("department", "department")
.add_mapping_field("diseaseSite", "diseaseSite")
)
"""

View File

@ -1,7 +1,7 @@
{
"invoker_type": "OpenAI",
"openai_api_key": "EMPTY",
"openai_api_base": "http://localhost:38080/v1",
"openai_api_base": "http://127.0.0.1:38080/v1",
"openai_model_name": "gpt-3.5-turbo",
"openai_max_tokens": 1000
}

View File

@ -24,18 +24,18 @@ class Medical:
description = PropertyHelper("description")
id = PropertyHelper("id")
name = PropertyHelper("name")
alias = PropertyHelper("alias")
stdId = PropertyHelper("stdId")
alias = PropertyHelper("alias")
class Disease(SPGTypeHelper):
description = PropertyHelper("description")
id = PropertyHelper("id")
name = PropertyHelper("name")
department = PropertyHelper("department")
complication = PropertyHelper("complication")
applicableDrug = PropertyHelper("applicableDrug")
diseaseSite = PropertyHelper("diseaseSite")
department = PropertyHelper("department")
commonSymptom = PropertyHelper("commonSymptom")
diseaseSite = PropertyHelper("diseaseSite")
complication = PropertyHelper("complication")
class Drug(SPGTypeHelper):
description = PropertyHelper("description")
@ -46,8 +46,8 @@ class Medical:
description = PropertyHelper("description")
id = PropertyHelper("id")
name = PropertyHelper("name")
alias = PropertyHelper("alias")
stdId = PropertyHelper("stdId")
alias = PropertyHelper("alias")
class Indicator(SPGTypeHelper):
description = PropertyHelper("description")

View File

@ -1,15 +1,12 @@
GRAPH_STORE_PARAM = "-Dcloudext.graphstore.drivers=com.antgroup.openspg.cloudext.impl.graphstore.tugraph" \
".TuGraphStoreClientDriver"
SEARCH_CLIENT_PARAM = "-Dcloudext.searchengine.drivers=com.antgroup.openspg.cloudext.impl.searchengine.elasticsearch" \
".ElasticSearchEngineClientDriver"
LOCAL_BUILDER_JAR = "builder-runner-local-0.0.1-SNAPSHOT-jar-with-dependencies.jar"
LOCAL_REASONER_JAR = ""
LOCAL_REASONER_JAR = "reasoner-local-runner-0.0.1-SNAPSHOT-jar-with-dependencies.jar"
LOCAL_SCHEMA_URL = "http://localhost:8887"
LOCAL_GRAPH_STORE_URL = "tugraph://127.0.0.1:9090?graphName=default&timeout=50000&accessId=admin&accessKey=73@TuGraph"
LOCAL_SEARCH_ENGINE_URL = "elasticsearch://127.0.0.1:9200?scheme=http"
LOCAL_GRAPH_STATE_CLASS = "com.antgroup.openspg.reasoner.warehouse.cloudext.CloudExtGraphState"

View File

@ -54,6 +54,8 @@ input:${input}
def parse_response(self, response: str) -> List[SPGRecord]:
result = []
subject = {}
if isinstance(response, list) and len(response) > 0:
response = response[0]
re_obj = json.loads(response)
if "spo" not in re_obj.keys():
raise ValueError("SPO format error.")
@ -89,21 +91,10 @@ input:${input}
result.append(subject_entity)
return result
def build_next_variables(
self, variables: Dict[str, str], response: str
) -> List[Dict[str, str]]:
re_obj = json.loads(response)
if "spo" not in re_obj.keys():
raise ValueError("SPO format error.")
re = re_obj.get("spo", [])
return [{"input": variables.get("input"), "spo": str(i)} for i in re]
def _render(self, spg_type: BaseSpgType, property_names: List[str]):
spos = []
repeat_desc = []
for property_name in property_names:
if property_name in ["id", "name", "description"]:
continue
prop = spg_type.properties.get(property_name)
object_desc = ""
object_type = self.schema_client.query_spg_type(prop.object_type_name)
@ -117,101 +108,3 @@ input:${input}
repeat_desc.extend([spg_type.name_zh, prop.name_zh, prop.object_type_name_zh])
schema_text = "\n[" + ",\n".join(spos) + "]"
self.template = self.template.replace("${schema}", schema_text)
class EEPrompt(AutoPrompt):
template: str = """
已知如下的事件schema定义:${schema}从下列句子中抽取所定义的事件如果存在以JSON格式返回如果不存在返回空字符串
input:${input}
${example}
输出格式为:{"event":[{"event_type":,"arguments":[{},]}]}
"output":
"""
def __init__(
self,
event_type_name: Union[str, SPGTypeHelper],
property_names: List[Union[str, PropertyHelper]],
custom_prompt: str = None,
):
super().__init__()
if custom_prompt:
self.template = custom_prompt
schema_client = SchemaClient()
spg_type = schema_client.query_spg_type(spg_type_name=event_type_name)
self.spg_type_name = event_type_name
self.predicate_zh_to_en_name = {}
self.predicate_type_zh_to_en_name = {}
for k, v in spg_type.properties.items():
self.predicate_zh_to_en_name[v.name_zh] = k
self.predicate_type_zh_to_en_name[v.name_zh] = v.object_type_name
self._render(spg_type, property_names)
self.params = {
"spg_type_name": event_type_name,
"property_names": property_names,
"custom_prompt": custom_prompt,
}
def build_prompt(self, variables: Dict[str, str]) -> str:
return self.template.replace("${input}", variables.get("input"))
def parse_response(self, response: str) -> List[SPGRecord]:
response = "{\"event\":[{\"event_type\":\"区域经济指标事件(区域指标)\",\"arguments\":[{\"日期\":\"2022年\",\"区域\":\"济南市\",\"来源\":\"政府公告\",\"主体\":\"山东省财政局\"}]}]}"
result = []
subject = {}
re_obj = json.loads(response)
if "event" not in re_obj.keys():
raise ValueError("Event format error.")
subject_properties = {}
for spo_item in re_obj.get("event", []):
if spo_item["predicate"] not in self.predicate_zh_to_en_name:
continue
subject_properties = {
"id": spo_item["subject"],
"name": spo_item["subject"],
}
if spo_item["subject"] not in subject:
subject[spo_item["subject"]] = subject_properties
else:
subject_properties = subject[spo_item["subject"]]
spo_en_name = self.predicate_zh_to_en_name[spo_item["predicate"]]
if spo_en_name in subject_properties and len(
subject_properties[spo_en_name]
):
subject_properties[spo_en_name] = (
subject_properties[spo_en_name] + "," + spo_item["object"]
)
else:
subject_properties[spo_en_name] = spo_item["object"]
# for k, val in subject.items():
subject_entity = SPGRecord(
spg_type_name=self.spg_type_name, properties=subject_properties
)
result.append(subject_entity)
return result
def build_next_variables(
self, variables: Dict[str, str], response: str
) -> List[Dict[str, str]]:
re_obj = json.loads(response)
if "event" not in re_obj.keys():
raise ValueError("Event format error.")
re = re_obj.get("event", [])
return [{"input": variables.get("input"), "event": str(i)} for i in re]
def _render(self, spg_type: BaseSpgType, property_names: List[str]):
arguments = []
for property_name in property_names:
if property_name in ["id", "name", "description"]:
continue
prop = spg_type.properties.get(property_name)
arguments.append(
f"{prop.name_zh}({prop.object_type_name_zh})"
)
schema_text = f"{{event_type:{spg_type.name_zh}({spg_type.desc or spg_type.name_zh}),arguments:[{','.join(arguments)}]"
self.template = self.template.replace("${schema}", schema_text)

View File

@ -60,8 +60,9 @@ class _BuiltInOnlineExtractor(ExtractOp):
elif op_name == "IndicatorLOGIC":
response = '[{"subject": "土地出让收入大幅下降", "predicate": "顺承", "object": ["综合财力明显下滑"]}]'
else:
print(query)
print(repr(query))
response = self.model.remote_inference(query)
print(response)
collector.extend(op.parse_response(response))
next_params.extend(
op.build_next_variables(input_param, response)

View File

@ -89,19 +89,22 @@ class FuseOp(BaseOp, ABC):
def __init__(self, params: Dict[str, str] = None):
super().__init__(params)
def link(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
def link(self, subject_record: SPGRecord) -> List[SPGRecord]:
raise NotImplementedError(
f"{self.__class__.__name__} need to implement `link` method."
)
def merge(self, subject_records: List[SPGRecord], target_records: List[SPGRecord]) -> List[SPGRecord]:
def merge(self, subject_record: SPGRecord, linked_records: List[SPGRecord]) -> List[SPGRecord]:
raise NotImplementedError(
f"{self.__class__.__name__} need to implement `merge` method."
)
def invoke(self, records: List[SPGRecord]) -> List[SPGRecord]:
linked_records = self.link(records)
return self.merge(records, linked_records)
def invoke(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
merged_records = []
for record in subject_records:
linked_records = self.link(record)
merged_records.extend(self.merge(record, linked_records))
return merged_records
@staticmethod
def _pre_process(*inputs):
@ -136,6 +139,8 @@ class PromptOp(BaseOp, ABC):
def build_next_variables(
self, variables: Dict[str, str], response: str
) -> List[Dict[str, str]]:
if isinstance(response, list) and len(response) > 0:
response = response[0]
variables.update({f"{self.__class__.__name__}": response})
print("LLM(Output): ")
print("----------------------")