mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-09-16 20:18:44 +00:00
fix
This commit is contained in:
parent
6a8e597b26
commit
d205c894d3
@ -10,10 +10,9 @@
|
|||||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
# or implied.
|
# or implied.
|
||||||
|
|
||||||
from knext.operator.builtin.auto_prompt import REPrompt, EEPrompt
|
from knext.operator.builtin.auto_prompt import REPrompt
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"REPrompt",
|
"REPrompt",
|
||||||
"EEPrompt",
|
|
||||||
]
|
]
|
||||||
|
@ -9,7 +9,7 @@
|
|||||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
# or implied.
|
# or implied.
|
||||||
|
import os
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from knext import rest
|
from knext import rest
|
||||||
@ -47,6 +47,38 @@ class ReasonerClient(Client):
|
|||||||
reasoner_job_submit_request=request
|
reasoner_job_submit_request=request
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def execute(self, dsl_content: str, output_file: str = None):
    """Run a KGDSL query with the local reasoner jar via ``java -jar``.

    Args:
        dsl_content: KGDSL query text, passed to the runner as ``--query``.
        output_file: Path passed as ``--output``. When empty, a timestamped
            ``./YYYY-mm-dd_HH-MM-SS.csv`` in the working directory is used.

    Raises:
        ValueError: If no project id is set on this client.

    The schema URL, graph state class and graph store URL are read from the
    ``KNEXT_HOST_ADDR``, ``KNEXT_GRAPH_STATE_CLASS`` and
    ``KNEXT_GRAPH_STORE_URL`` environment variables, falling back to the
    ``LOCAL_*`` defaults in ``knext.lib``. The assembled command is echoed to
    stdout before it is run.
    """
    import datetime
    import json
    import subprocess

    from knext import lib

    if self._project_id is None:
        raise ValueError("A project id is required to execute a reasoner job.")

    jar_path = os.path.join(lib.__path__[0], lib.LOCAL_REASONER_JAR)
    default_output_file = (
        f"./{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv"
    )

    java_cmd = [
        "java",
        "-jar",
        jar_path,
        "--projectId",
        # argv entries must be strings; the project id may be stored as an int.
        str(self._project_id),
        "--query",
        dsl_content,
        "--output",
        output_file or default_output_file,
        "--schemaUrl",
        os.environ.get("KNEXT_HOST_ADDR") or lib.LOCAL_SCHEMA_URL,
        "--graphStateClass",
        os.environ.get("KNEXT_GRAPH_STATE_CLASS") or lib.LOCAL_GRAPH_STATE_CLASS,
        "--graphStoreUrl",
        os.environ.get("KNEXT_GRAPH_STORE_URL") or lib.LOCAL_GRAPH_STORE_URL,
    ]

    # Build a copy-pasteable rendering of the command: arguments that start
    # with "{" or contain ";" are wrapped in quotes (each rule applied in turn).
    display_args = []
    for arg in java_cmd:
        if arg.startswith("{"):
            arg = f"'{arg}'"
        if ";" in arg:
            arg = f"'{arg}'"
        display_args.append(arg)

    # NOTE(review): the echoed command contains the graph store URL verbatim,
    # including any access key embedded in it.
    print(json.dumps(" ".join(display_args))[1:-1].replace("'", '"'))

    subprocess.call(java_cmd)
|
||||||
|
|
||||||
def run_dsl(self, dsl_content: str):
|
def run_dsl(self, dsl_content: str):
|
||||||
"""Submit a synchronization reasoner job by providing DSL content."""
|
"""Submit a synchronization reasoner job by providing DSL content."""
|
||||||
content = rest.KgdslReasonerContent(kgdsl=dsl_content)
|
content = rest.KgdslReasonerContent(kgdsl=dsl_content)
|
||||||
|
@ -26,6 +26,7 @@ from knext.command.sub_command.project import list_project
|
|||||||
from knext.command.sub_command.reasoner import query_reasoner_job
|
from knext.command.sub_command.reasoner import query_reasoner_job
|
||||||
from knext.command.sub_command.reasoner import run_dsl
|
from knext.command.sub_command.reasoner import run_dsl
|
||||||
from knext.command.sub_command.reasoner import submit_reasoner_job
|
from knext.command.sub_command.reasoner import submit_reasoner_job
|
||||||
|
from knext.command.sub_command.reasoner import execute_reasoner_job
|
||||||
from knext.command.sub_command.schema import commit_schema
|
from knext.command.sub_command.schema import commit_schema
|
||||||
from knext.command.sub_command.schema import diff_schema
|
from knext.command.sub_command.schema import diff_schema
|
||||||
from knext.command.sub_command.schema import list_schema
|
from knext.command.sub_command.schema import list_schema
|
||||||
@ -110,6 +111,7 @@ def reasoner() -> None:
|
|||||||
reasoner.command("submit")(submit_reasoner_job)
|
reasoner.command("submit")(submit_reasoner_job)
|
||||||
reasoner.command("query")(run_dsl)
|
reasoner.command("query")(run_dsl)
|
||||||
reasoner.command("get")(query_reasoner_job)
|
reasoner.command("get")(query_reasoner_job)
|
||||||
|
reasoner.command("execute")(execute_reasoner_job)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
_main()
|
_main()
|
||||||
|
@ -121,3 +121,22 @@ def query_reasoner_job(id):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
|
|
||||||
|
@click.option("--file", help="Path of DSL file.")
@click.option("--dsl", help="DSL string enclosed in double quotes.")
@click.option("--output", help="Output file.")
def execute_reasoner_job(file, dsl, output=None):
    """
    Execute a reasoner job locally by providing DSL file or string.
    """
    # Exactly one of --file / --dsl must be given; validate before creating
    # the client so a usage error never depends on client construction.
    if file and not dsl:
        # DSL may contain non-ASCII identifiers, so do not rely on the
        # platform default encoding.
        with open(file, "r", encoding="utf-8") as f:
            dsl_content = f.read()
    elif dsl and not file:
        dsl_content = dsl
    else:
        click.secho("ERROR: Please choose either --file or --dsl.", fg="bright_red")
        sys.exit()

    client = ReasonerClient()
    client.execute(dsl_content, output)
|
||||||
|
64
python/knext/knext/examples/financial/builder/job/company.py
Normal file
64
python/knext/knext/examples/financial/builder/job/company.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from knext.examples.financial.schema.financial_schema_helper import Financial
|
||||||
|
|
||||||
|
from knext.api.component import CSVReader, LLMBasedExtractor, KGWriter, SubGraphMapping
|
||||||
|
from knext.client.model.builder_job import BuilderJob
|
||||||
|
from nn4k.invoker import LLMInvoker
|
||||||
|
|
||||||
|
|
||||||
|
class Company(BuilderJob):
    """Builder job that extracts `Financial.Company` records from text with an LLM."""

    def build(self):
        """Chain CSV reader -> LLM extractor -> subgraph mapping -> KG writer."""
        from knext.api.auto_prompt import REPrompt

        company = Financial.Company

        # Reads the single "input" column; start_row=2 presumably skips the
        # header line of the CSV -- confirm against CSVReader.
        reader = CSVReader(
            local_path="builder/job/data/company.csv",
            columns=["input"],
            start_row=2,
        )

        extraction_prompt = REPrompt(
            spg_type_name=company,
            property_names=[
                company.name,
                company.orgCertNo,
                company.regArea,
                company.businessScope,
                company.establishDate,
                company.legalPerson,
                company.regCapital,
            ],
        )
        extractor = LLMBasedExtractor(
            llm=LLMInvoker.from_config("builder/model/openai_infer.json"),
            prompt_ops=[extraction_prompt],
        )

        # (extracted field, schema property) pairs; the extracted name is
        # mapped to both the id and the name of the company.
        field_pairs = [
            ("name", company.id),
            ("name", company.name),
            ("regArea", company.regArea),
            ("businessScope", company.businessScope),
            ("establishDate", company.establishDate),
            ("legalPerson", company.legalPerson),
            ("regCapital", company.regCapital),
        ]
        mapping = SubGraphMapping(spg_type_name=company)
        for source_field, target_property in field_pairs:
            mapping = mapping.add_mapping_field(source_field, target_property)

        writer = KGWriter()

        return reader >> extractor >> mapping >> writer
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Debug entry point: build the relation-extraction prompt for
    # Financial.Company and print the rendered template.
    from knext.api.auto_prompt import REPrompt

    prompt = REPrompt(
        spg_type_name=Financial.Company,
        property_names=[
            getattr(Financial.Company, property_name)
            for property_name in (
                "orgCertNo",
                "regArea",
                "businessScope",
                "establishDate",
                "legalPerson",
                "regCapital",
            )
        ],
    )
    print(prompt.template)
|
@ -0,0 +1,2 @@
|
|||||||
|
input
|
||||||
|
阿里巴巴(中国)有限公司是一家从事企业管理,计算机系统服务,电脑动画设计等业务的公司,成立于2007年03月26日,公司坐落在浙江省;经营有阿里邮箱、浙烟邮箱,师生家校、点淘-淘宝直播官方平台、云上会展等产品,经国家企业信用信息公示系统查询得知,阿里巴巴(中国)有限公司的信用代码/税号为91330100799655058B,法人是蒋芳,注册资本为15412.764910万美元,企业的经营范围为:服务:企业管理,计算机系统服务,电脑动画设计,经济信息咨询服务(除商品中介),成年人的非证书劳动职业技能培训和成年人的非文化教育培训(涉及前置审批的项目除外);生产:计算机软件;销售自产产品。(国家禁止和限制的除外,凡涉及许可证制度的凭证经营)
|
|
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"invoker_type": "OpenAI",
|
"invoker_type": "OpenAI",
|
||||||
"openai_api_key": "EMPTY",
|
"openai_api_key": "EMPTY",
|
||||||
"openai_api_base": "http://localhost:38000/v1",
|
"openai_api_base": "http://127.0.0.1:38080/v1",
|
||||||
"openai_model_name": "vicuna-7b-v1.5",
|
"openai_model_name": "gpt-3.5-turbo",
|
||||||
"openai_max_tokens": 1000
|
"openai_max_tokens": 2000
|
||||||
}
|
}
|
@ -13,14 +13,13 @@ class IndicatorFuse(FuseOp):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.search_client = SearchClient("Financial.Indicator")
|
self.search_client = SearchClient("Financial.Indicator")
|
||||||
|
|
||||||
def link(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
|
def link(self, subject_record: SPGRecord) -> List[SPGRecord]:
|
||||||
print("####################IndicatorFuse#####################")
|
print("####################IndicatorFuse(指标融合)#####################")
|
||||||
print("IndicatorFuse(Input): ")
|
print("IndicatorFuse(Input): ")
|
||||||
print("----------------------")
|
print("----------------------")
|
||||||
[print(r) for r in subject_records]
|
print(subject_record)
|
||||||
linked_records = []
|
linked_records = []
|
||||||
for record in subject_records:
|
query = {"match": {"name": subject_record.get_property("name", "")}}
|
||||||
query = {"match": {"name": record.get_property("name", "")}}
|
|
||||||
recall_records = self.search_client.search(query, start=0, size=10)
|
recall_records = self.search_client.search(query, start=0, size=10)
|
||||||
if recall_records is not None and len(recall_records) > 0:
|
if recall_records is not None and len(recall_records) > 0:
|
||||||
linked_records.append(SPGRecord(
|
linked_records.append(SPGRecord(
|
||||||
@ -32,12 +31,10 @@ class IndicatorFuse(FuseOp):
|
|||||||
))
|
))
|
||||||
return linked_records
|
return linked_records
|
||||||
|
|
||||||
def merge(self, subject_records: List[SPGRecord], target_records: List[SPGRecord]) -> List[SPGRecord]:
|
def merge(self, subject_record: SPGRecord, linked_records: List[SPGRecord]) -> List[SPGRecord]:
|
||||||
merged_records = []
|
merged_records = []
|
||||||
for s in subject_records:
|
if not linked_records:
|
||||||
if s in target_records:
|
merged_records.append(subject_record)
|
||||||
continue
|
|
||||||
merged_records.append(s)
|
|
||||||
print("IndicatorFuse(Output): ")
|
print("IndicatorFuse(Output): ")
|
||||||
print("----------------------")
|
print("----------------------")
|
||||||
[print(r) for r in merged_records]
|
[print(r) for r in merged_records]
|
||||||
|
@ -13,14 +13,13 @@ class StateFuse(FuseOp):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.search_client = SearchClient("Financial.State")
|
self.search_client = SearchClient("Financial.State")
|
||||||
|
|
||||||
def link(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
|
def link(self, subject_record: SPGRecord) -> List[SPGRecord]:
|
||||||
print("####################StateFuse(状态融合)#####################")
|
print("####################StateFuse(状态融合)#####################")
|
||||||
print("StateFuse(Input): ")
|
print("StateFuse(Input): ")
|
||||||
print("----------------------")
|
print("----------------------")
|
||||||
[print(r) for r in subject_records]
|
print(subject_record)
|
||||||
linked_records = []
|
linked_records = []
|
||||||
for record in subject_records:
|
query = {"match": {"name": subject_record.get_property("name", "")}}
|
||||||
query = {"match": {"name": record.get_property("name", "")}}
|
|
||||||
recall_records = self.search_client.search(query, start=0, size=10)
|
recall_records = self.search_client.search(query, start=0, size=10)
|
||||||
if recall_records is not None and len(recall_records) > 0:
|
if recall_records is not None and len(recall_records) > 0:
|
||||||
linked_records.append(SPGRecord(
|
linked_records.append(SPGRecord(
|
||||||
@ -33,17 +32,10 @@ class StateFuse(FuseOp):
|
|||||||
)
|
)
|
||||||
return linked_records
|
return linked_records
|
||||||
|
|
||||||
def merge(self, subject_records: List[SPGRecord], target_records: List[SPGRecord]) -> List[SPGRecord]:
|
def merge(self, subject_record: SPGRecord, linked_records: List[SPGRecord]) -> List[SPGRecord]:
|
||||||
merged_records = []
|
merged_records = []
|
||||||
for s in subject_records:
|
if not linked_records:
|
||||||
# for t in target_records:
|
merged_records.append(subject_record)
|
||||||
merged_records.append(SPGRecord(
|
|
||||||
"Financial.State",
|
|
||||||
{
|
|
||||||
"id": s.get_property("id"),
|
|
||||||
"name": s.get_property("name", ""),
|
|
||||||
})
|
|
||||||
)
|
|
||||||
print("StateFuse(Output): ")
|
print("StateFuse(Output): ")
|
||||||
print("----------------------")
|
print("----------------------")
|
||||||
[print(r) for r in merged_records]
|
[print(r) for r in merged_records]
|
||||||
|
@ -20,6 +20,42 @@ from knext.common.schema_helper import SPGTypeHelper, PropertyHelper
|
|||||||
|
|
||||||
class Financial:
|
class Financial:
|
||||||
|
|
||||||
|
class AdministrativeArea(SPGTypeHelper):
|
||||||
|
description = PropertyHelper("description")
|
||||||
|
id = PropertyHelper("id")
|
||||||
|
name = PropertyHelper("name")
|
||||||
|
stdId = PropertyHelper("stdId")
|
||||||
|
alias = PropertyHelper("alias")
|
||||||
|
|
||||||
|
class AreaRiskEvent(SPGTypeHelper):
|
||||||
|
description = PropertyHelper("description")
|
||||||
|
id = PropertyHelper("id")
|
||||||
|
name = PropertyHelper("name")
|
||||||
|
eventTime = PropertyHelper("eventTime")
|
||||||
|
object = PropertyHelper("object")
|
||||||
|
subject = PropertyHelper("subject")
|
||||||
|
|
||||||
|
class Company(SPGTypeHelper):
|
||||||
|
description = PropertyHelper("description")
|
||||||
|
id = PropertyHelper("id")
|
||||||
|
name = PropertyHelper("name")
|
||||||
|
orgCertNo = PropertyHelper("orgCertNo")
|
||||||
|
establishDate = PropertyHelper("establishDate")
|
||||||
|
regArea = PropertyHelper("regArea")
|
||||||
|
regCapital = PropertyHelper("regCapital")
|
||||||
|
businessScope = PropertyHelper("businessScope")
|
||||||
|
legalPerson = PropertyHelper("legalPerson")
|
||||||
|
|
||||||
|
class CompanyEvent(SPGTypeHelper):
|
||||||
|
description = PropertyHelper("description")
|
||||||
|
id = PropertyHelper("id")
|
||||||
|
name = PropertyHelper("name")
|
||||||
|
location = PropertyHelper("location")
|
||||||
|
eventTime = PropertyHelper("eventTime")
|
||||||
|
happenedTime = PropertyHelper("happenedTime")
|
||||||
|
subject = PropertyHelper("subject")
|
||||||
|
object = PropertyHelper("object")
|
||||||
|
|
||||||
class Indicator(SPGTypeHelper):
|
class Indicator(SPGTypeHelper):
|
||||||
description = PropertyHelper("description")
|
description = PropertyHelper("description")
|
||||||
id = PropertyHelper("id")
|
id = PropertyHelper("id")
|
||||||
@ -36,6 +72,10 @@ class Financial:
|
|||||||
derivedFrom = PropertyHelper("derivedFrom")
|
derivedFrom = PropertyHelper("derivedFrom")
|
||||||
stdId = PropertyHelper("stdId")
|
stdId = PropertyHelper("stdId")
|
||||||
|
|
||||||
|
AdministrativeArea = AdministrativeArea("Financial.AdministrativeArea")
|
||||||
|
AreaRiskEvent = AreaRiskEvent("Financial.AreaRiskEvent")
|
||||||
|
Company = Company("Financial.Company")
|
||||||
|
CompanyEvent = CompanyEvent("Financial.CompanyEvent")
|
||||||
Indicator = Indicator("Financial.Indicator")
|
Indicator = Indicator("Financial.Indicator")
|
||||||
State = State("Financial.State")
|
State = State("Financial.State")
|
||||||
|
|
@ -37,10 +37,10 @@ class Disease(BuilderJob):
|
|||||||
spg_type_name="Medical.Disease",
|
spg_type_name="Medical.Disease",
|
||||||
property_names=[
|
property_names=[
|
||||||
"complication",
|
"complication",
|
||||||
# "commonSymptom",
|
"commonSymptom",
|
||||||
# "applicableDrug",
|
"applicableDrug",
|
||||||
# "department",
|
"department",
|
||||||
# "diseaseSite",
|
"diseaseSite",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
@ -54,10 +54,10 @@ class Disease(BuilderJob):
|
|||||||
.add_mapping_field("id", "id")
|
.add_mapping_field("id", "id")
|
||||||
.add_mapping_field("name", "name")
|
.add_mapping_field("name", "name")
|
||||||
.add_mapping_field("complication", "complication")
|
.add_mapping_field("complication", "complication")
|
||||||
# .add_mapping_field("commonSymptom", "commonSymptom")
|
.add_mapping_field("commonSymptom", "commonSymptom")
|
||||||
# .add_mapping_field("applicableDrug", "applicableDrug")
|
.add_mapping_field("applicableDrug", "applicableDrug")
|
||||||
# .add_mapping_field("department", "department")
|
.add_mapping_field("department", "department")
|
||||||
# .add_mapping_field("diseaseSite", "diseaseSite")
|
.add_mapping_field("diseaseSite", "diseaseSite")
|
||||||
)
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"invoker_type": "OpenAI",
|
"invoker_type": "OpenAI",
|
||||||
"openai_api_key": "EMPTY",
|
"openai_api_key": "EMPTY",
|
||||||
"openai_api_base": "http://localhost:38080/v1",
|
"openai_api_base": "http://127.0.0.1:38080/v1",
|
||||||
"openai_model_name": "gpt-3.5-turbo",
|
"openai_model_name": "gpt-3.5-turbo",
|
||||||
"openai_max_tokens": 1000
|
"openai_max_tokens": 1000
|
||||||
}
|
}
|
@ -24,18 +24,18 @@ class Medical:
|
|||||||
description = PropertyHelper("description")
|
description = PropertyHelper("description")
|
||||||
id = PropertyHelper("id")
|
id = PropertyHelper("id")
|
||||||
name = PropertyHelper("name")
|
name = PropertyHelper("name")
|
||||||
alias = PropertyHelper("alias")
|
|
||||||
stdId = PropertyHelper("stdId")
|
stdId = PropertyHelper("stdId")
|
||||||
|
alias = PropertyHelper("alias")
|
||||||
|
|
||||||
class Disease(SPGTypeHelper):
|
class Disease(SPGTypeHelper):
|
||||||
description = PropertyHelper("description")
|
description = PropertyHelper("description")
|
||||||
id = PropertyHelper("id")
|
id = PropertyHelper("id")
|
||||||
name = PropertyHelper("name")
|
name = PropertyHelper("name")
|
||||||
department = PropertyHelper("department")
|
|
||||||
complication = PropertyHelper("complication")
|
|
||||||
applicableDrug = PropertyHelper("applicableDrug")
|
applicableDrug = PropertyHelper("applicableDrug")
|
||||||
diseaseSite = PropertyHelper("diseaseSite")
|
department = PropertyHelper("department")
|
||||||
commonSymptom = PropertyHelper("commonSymptom")
|
commonSymptom = PropertyHelper("commonSymptom")
|
||||||
|
diseaseSite = PropertyHelper("diseaseSite")
|
||||||
|
complication = PropertyHelper("complication")
|
||||||
|
|
||||||
class Drug(SPGTypeHelper):
|
class Drug(SPGTypeHelper):
|
||||||
description = PropertyHelper("description")
|
description = PropertyHelper("description")
|
||||||
@ -46,8 +46,8 @@ class Medical:
|
|||||||
description = PropertyHelper("description")
|
description = PropertyHelper("description")
|
||||||
id = PropertyHelper("id")
|
id = PropertyHelper("id")
|
||||||
name = PropertyHelper("name")
|
name = PropertyHelper("name")
|
||||||
alias = PropertyHelper("alias")
|
|
||||||
stdId = PropertyHelper("stdId")
|
stdId = PropertyHelper("stdId")
|
||||||
|
alias = PropertyHelper("alias")
|
||||||
|
|
||||||
class Indicator(SPGTypeHelper):
|
class Indicator(SPGTypeHelper):
|
||||||
description = PropertyHelper("description")
|
description = PropertyHelper("description")
|
||||||
|
@ -1,15 +1,12 @@
|
|||||||
GRAPH_STORE_PARAM = "-Dcloudext.graphstore.drivers=com.antgroup.openspg.cloudext.impl.graphstore.tugraph" \
|
|
||||||
".TuGraphStoreClientDriver"
|
|
||||||
|
|
||||||
SEARCH_CLIENT_PARAM = "-Dcloudext.searchengine.drivers=com.antgroup.openspg.cloudext.impl.searchengine.elasticsearch" \
|
|
||||||
".ElasticSearchEngineClientDriver"
|
|
||||||
|
|
||||||
LOCAL_BUILDER_JAR = "builder-runner-local-0.0.1-SNAPSHOT-jar-with-dependencies.jar"
|
LOCAL_BUILDER_JAR = "builder-runner-local-0.0.1-SNAPSHOT-jar-with-dependencies.jar"
|
||||||
|
|
||||||
LOCAL_REASONER_JAR = ""
|
LOCAL_REASONER_JAR = "reasoner-local-runner-0.0.1-SNAPSHOT-jar-with-dependencies.jar"
|
||||||
|
|
||||||
LOCAL_SCHEMA_URL = "http://localhost:8887"
|
LOCAL_SCHEMA_URL = "http://localhost:8887"
|
||||||
|
|
||||||
LOCAL_GRAPH_STORE_URL = "tugraph://127.0.0.1:9090?graphName=default&timeout=50000&accessId=admin&accessKey=73@TuGraph"
|
LOCAL_GRAPH_STORE_URL = "tugraph://127.0.0.1:9090?graphName=default&timeout=50000&accessId=admin&accessKey=73@TuGraph"
|
||||||
|
|
||||||
LOCAL_SEARCH_ENGINE_URL = "elasticsearch://127.0.0.1:9200?scheme=http"
|
LOCAL_SEARCH_ENGINE_URL = "elasticsearch://127.0.0.1:9200?scheme=http"
|
||||||
|
|
||||||
|
LOCAL_GRAPH_STATE_CLASS = "com.antgroup.openspg.reasoner.warehouse.cloudext.CloudExtGraphState"
|
||||||
|
@ -54,6 +54,8 @@ input:${input}
|
|||||||
def parse_response(self, response: str) -> List[SPGRecord]:
|
def parse_response(self, response: str) -> List[SPGRecord]:
|
||||||
result = []
|
result = []
|
||||||
subject = {}
|
subject = {}
|
||||||
|
if isinstance(response, list) and len(response) > 0:
|
||||||
|
response = response[0]
|
||||||
re_obj = json.loads(response)
|
re_obj = json.loads(response)
|
||||||
if "spo" not in re_obj.keys():
|
if "spo" not in re_obj.keys():
|
||||||
raise ValueError("SPO format error.")
|
raise ValueError("SPO format error.")
|
||||||
@ -89,21 +91,10 @@ input:${input}
|
|||||||
result.append(subject_entity)
|
result.append(subject_entity)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def build_next_variables(
|
|
||||||
self, variables: Dict[str, str], response: str
|
|
||||||
) -> List[Dict[str, str]]:
|
|
||||||
re_obj = json.loads(response)
|
|
||||||
if "spo" not in re_obj.keys():
|
|
||||||
raise ValueError("SPO format error.")
|
|
||||||
re = re_obj.get("spo", [])
|
|
||||||
return [{"input": variables.get("input"), "spo": str(i)} for i in re]
|
|
||||||
|
|
||||||
def _render(self, spg_type: BaseSpgType, property_names: List[str]):
|
def _render(self, spg_type: BaseSpgType, property_names: List[str]):
|
||||||
spos = []
|
spos = []
|
||||||
repeat_desc = []
|
repeat_desc = []
|
||||||
for property_name in property_names:
|
for property_name in property_names:
|
||||||
if property_name in ["id", "name", "description"]:
|
|
||||||
continue
|
|
||||||
prop = spg_type.properties.get(property_name)
|
prop = spg_type.properties.get(property_name)
|
||||||
object_desc = ""
|
object_desc = ""
|
||||||
object_type = self.schema_client.query_spg_type(prop.object_type_name)
|
object_type = self.schema_client.query_spg_type(prop.object_type_name)
|
||||||
@ -117,101 +108,3 @@ input:${input}
|
|||||||
repeat_desc.extend([spg_type.name_zh, prop.name_zh, prop.object_type_name_zh])
|
repeat_desc.extend([spg_type.name_zh, prop.name_zh, prop.object_type_name_zh])
|
||||||
schema_text = "\n[" + ",\n".join(spos) + "]"
|
schema_text = "\n[" + ",\n".join(spos) + "]"
|
||||||
self.template = self.template.replace("${schema}", schema_text)
|
self.template = self.template.replace("${schema}", schema_text)
|
||||||
|
|
||||||
|
|
||||||
class EEPrompt(AutoPrompt):
|
|
||||||
template: str = """
|
|
||||||
已知如下的事件schema定义:${schema}。从下列句子中抽取所定义的事件,如果存在以JSON格式返回,如果不存在返回空字符串。
|
|
||||||
input:${input}
|
|
||||||
${example}
|
|
||||||
输出格式为:{"event":[{"event_type":,"arguments":[{},]}]}
|
|
||||||
"output":
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
event_type_name: Union[str, SPGTypeHelper],
|
|
||||||
property_names: List[Union[str, PropertyHelper]],
|
|
||||||
custom_prompt: str = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if custom_prompt:
|
|
||||||
self.template = custom_prompt
|
|
||||||
schema_client = SchemaClient()
|
|
||||||
spg_type = schema_client.query_spg_type(spg_type_name=event_type_name)
|
|
||||||
self.spg_type_name = event_type_name
|
|
||||||
self.predicate_zh_to_en_name = {}
|
|
||||||
self.predicate_type_zh_to_en_name = {}
|
|
||||||
for k, v in spg_type.properties.items():
|
|
||||||
self.predicate_zh_to_en_name[v.name_zh] = k
|
|
||||||
self.predicate_type_zh_to_en_name[v.name_zh] = v.object_type_name
|
|
||||||
self._render(spg_type, property_names)
|
|
||||||
self.params = {
|
|
||||||
"spg_type_name": event_type_name,
|
|
||||||
"property_names": property_names,
|
|
||||||
"custom_prompt": custom_prompt,
|
|
||||||
}
|
|
||||||
|
|
||||||
def build_prompt(self, variables: Dict[str, str]) -> str:
|
|
||||||
return self.template.replace("${input}", variables.get("input"))
|
|
||||||
|
|
||||||
def parse_response(self, response: str) -> List[SPGRecord]:
|
|
||||||
response = "{\"event\":[{\"event_type\":\"区域经济指标事件(区域指标)\",\"arguments\":[{\"日期\":\"2022年\",\"区域\":\"济南市\",\"来源\":\"政府公告\",\"主体\":\"山东省财政局\"}]}]}"
|
|
||||||
result = []
|
|
||||||
subject = {}
|
|
||||||
re_obj = json.loads(response)
|
|
||||||
if "event" not in re_obj.keys():
|
|
||||||
raise ValueError("Event format error.")
|
|
||||||
subject_properties = {}
|
|
||||||
for spo_item in re_obj.get("event", []):
|
|
||||||
if spo_item["predicate"] not in self.predicate_zh_to_en_name:
|
|
||||||
continue
|
|
||||||
subject_properties = {
|
|
||||||
"id": spo_item["subject"],
|
|
||||||
"name": spo_item["subject"],
|
|
||||||
}
|
|
||||||
if spo_item["subject"] not in subject:
|
|
||||||
subject[spo_item["subject"]] = subject_properties
|
|
||||||
else:
|
|
||||||
subject_properties = subject[spo_item["subject"]]
|
|
||||||
|
|
||||||
spo_en_name = self.predicate_zh_to_en_name[spo_item["predicate"]]
|
|
||||||
|
|
||||||
if spo_en_name in subject_properties and len(
|
|
||||||
subject_properties[spo_en_name]
|
|
||||||
):
|
|
||||||
subject_properties[spo_en_name] = (
|
|
||||||
subject_properties[spo_en_name] + "," + spo_item["object"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
subject_properties[spo_en_name] = spo_item["object"]
|
|
||||||
|
|
||||||
# for k, val in subject.items():
|
|
||||||
subject_entity = SPGRecord(
|
|
||||||
spg_type_name=self.spg_type_name, properties=subject_properties
|
|
||||||
)
|
|
||||||
result.append(subject_entity)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def build_next_variables(
|
|
||||||
self, variables: Dict[str, str], response: str
|
|
||||||
) -> List[Dict[str, str]]:
|
|
||||||
re_obj = json.loads(response)
|
|
||||||
if "event" not in re_obj.keys():
|
|
||||||
raise ValueError("Event format error.")
|
|
||||||
re = re_obj.get("event", [])
|
|
||||||
return [{"input": variables.get("input"), "event": str(i)} for i in re]
|
|
||||||
|
|
||||||
def _render(self, spg_type: BaseSpgType, property_names: List[str]):
|
|
||||||
arguments = []
|
|
||||||
for property_name in property_names:
|
|
||||||
if property_name in ["id", "name", "description"]:
|
|
||||||
continue
|
|
||||||
prop = spg_type.properties.get(property_name)
|
|
||||||
arguments.append(
|
|
||||||
f"{prop.name_zh}({prop.object_type_name_zh})"
|
|
||||||
)
|
|
||||||
|
|
||||||
schema_text = f"{{event_type:{spg_type.name_zh}({spg_type.desc or spg_type.name_zh}),arguments:[{','.join(arguments)}]"
|
|
||||||
self.template = self.template.replace("${schema}", schema_text)
|
|
||||||
|
@ -60,8 +60,9 @@ class _BuiltInOnlineExtractor(ExtractOp):
|
|||||||
elif op_name == "IndicatorLOGIC":
|
elif op_name == "IndicatorLOGIC":
|
||||||
response = '[{"subject": "土地出让收入大幅下降", "predicate": "顺承", "object": ["综合财力明显下滑"]}]'
|
response = '[{"subject": "土地出让收入大幅下降", "predicate": "顺承", "object": ["综合财力明显下滑"]}]'
|
||||||
else:
|
else:
|
||||||
print(query)
|
print(repr(query))
|
||||||
response = self.model.remote_inference(query)
|
response = self.model.remote_inference(query)
|
||||||
|
print(response)
|
||||||
collector.extend(op.parse_response(response))
|
collector.extend(op.parse_response(response))
|
||||||
next_params.extend(
|
next_params.extend(
|
||||||
op.build_next_variables(input_param, response)
|
op.build_next_variables(input_param, response)
|
||||||
|
@ -89,19 +89,22 @@ class FuseOp(BaseOp, ABC):
|
|||||||
def __init__(self, params: Dict[str, str] = None):
|
def __init__(self, params: Dict[str, str] = None):
|
||||||
super().__init__(params)
|
super().__init__(params)
|
||||||
|
|
||||||
def link(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
|
def link(self, subject_record: SPGRecord) -> List[SPGRecord]:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"{self.__class__.__name__} need to implement `link` method."
|
f"{self.__class__.__name__} need to implement `link` method."
|
||||||
)
|
)
|
||||||
|
|
||||||
def merge(self, subject_records: List[SPGRecord], target_records: List[SPGRecord]) -> List[SPGRecord]:
|
def merge(self, subject_record: SPGRecord, linked_records: List[SPGRecord]) -> List[SPGRecord]:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"{self.__class__.__name__} need to implement `merge` method."
|
f"{self.__class__.__name__} need to implement `merge` method."
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, subject_records: "List[SPGRecord]") -> "List[SPGRecord]":
    """Fuse every subject record and return the combined merge results.

    Each record is passed to `link` on its own, then the record and its
    linked records are passed to `merge`. The per-record results are
    concatenated in input order; an empty input yields an empty list.
    """
    fused_records = []
    for record in subject_records:
        linked_records = self.link(record)
        merged_records = self.merge(record, linked_records)
        # Accumulate instead of returning here: returning inside the loop
        # would drop every subject record after the first one.
        if merged_records:
            fused_records.extend(merged_records)
    return fused_records
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _pre_process(*inputs):
|
def _pre_process(*inputs):
|
||||||
@ -136,6 +139,8 @@ class PromptOp(BaseOp, ABC):
|
|||||||
def build_next_variables(
|
def build_next_variables(
|
||||||
self, variables: Dict[str, str], response: str
|
self, variables: Dict[str, str], response: str
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
|
if isinstance(response, list) and len(response) > 0:
|
||||||
|
response = response[0]
|
||||||
variables.update({f"{self.__class__.__name__}": response})
|
variables.update({f"{self.__class__.__name__}": response})
|
||||||
print("LLM(Output): ")
|
print("LLM(Output): ")
|
||||||
print("----------------------")
|
print("----------------------")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user