mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-09-16 12:07:58 +00:00
fix
This commit is contained in:
parent
6a8e597b26
commit
d205c894d3
@ -10,10 +10,9 @@
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
from knext.operator.builtin.auto_prompt import REPrompt, EEPrompt
|
||||
from knext.operator.builtin.auto_prompt import REPrompt
|
||||
|
||||
|
||||
__all__ = [
|
||||
"REPrompt",
|
||||
"EEPrompt",
|
||||
]
|
||||
|
@ -9,7 +9,7 @@
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
from knext import rest
|
||||
@ -47,6 +47,38 @@ class ReasonerClient(Client):
|
||||
reasoner_job_submit_request=request
|
||||
)
|
||||
|
||||
def execute(self, dsl_content: str, output_file: str = None):
    """
    Run a DSL query with the local reasoner jar and wait for it to finish.

    The command line that is built has the form::

        java -jar <reasoner jar>
            --projectId <project id>
            --query <dsl_content>
            --output <output file>
            --schemaUrl <schema url>
            --graphStateClass <graph state class>
            --graphStoreUrl <graph store url>

    ``--schemaUrl``, ``--graphStateClass`` and ``--graphStoreUrl`` are read
    from the environment variables ``KNEXT_HOST_ADDR``,
    ``KNEXT_GRAPH_STATE_CLASS`` and ``KNEXT_GRAPH_STORE_URL``; when a variable
    is unset or empty the matching ``knext.lib.LOCAL_*`` default is used.

    :param dsl_content: DSL text handed to the jar through ``--query``.
    :param output_file: Path handed to the jar through ``--output``. When it
        is ``None`` or empty, ``./<YYYY-mm-dd_HH-MM-SS>.csv`` (current local
        time) is used instead.
    :return: Exit status of the java process, as returned by
        ``subprocess.call``.
    """
    import datetime
    import json
    import subprocess

    from knext import lib

    jar_path = os.path.join(lib.__path__[0], lib.LOCAL_REASONER_JAR)
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    default_output_file = f"./{timestamp}.csv"

    java_cmd = [
        "java",
        "-jar",
        jar_path,
        "--projectId",
        # Every argv entry must be a string: both the display code below
        # (str.startswith) and subprocess reject e.g. an int project id.
        str(self._project_id),
        "--query",
        dsl_content,
        "--output",
        output_file or default_output_file,
        "--schemaUrl",
        os.environ.get("KNEXT_HOST_ADDR") or lib.LOCAL_SCHEMA_URL,
        "--graphStateClass",
        os.environ.get("KNEXT_GRAPH_STATE_CLASS") or lib.LOCAL_GRAPH_STATE_CLASS,
        "--graphStoreUrl",
        os.environ.get("KNEXT_GRAPH_STORE_URL") or lib.LOCAL_GRAPH_STORE_URL,
    ]

    # Display-only copy of the command; java_cmd itself is executed unquoted.
    # Pass 1 wraps arguments starting with "{" in single quotes, pass 2 wraps
    # arguments containing ";" (pass 2 sees the output of pass 1, so an
    # argument matching both rules is wrapped twice).
    printable = [f"'{arg}'" if arg.startswith("{") else arg for arg in java_cmd]
    printable = [f"'{arg}'" if ";" in arg else arg for arg in printable]
    # json.dumps escapes the joined line, [1:-1] drops the surrounding JSON
    # quotes, and every single quote is finally shown as a double quote.
    print(json.dumps(" ".join(printable))[1:-1].replace("'", '"'))

    # Argument-list form (no shell), so the DSL text is never shell-interpreted.
    return subprocess.call(java_cmd)
|
||||
|
||||
def run_dsl(self, dsl_content: str):
|
||||
"""Submit a synchronization reasoner job by providing DSL content."""
|
||||
content = rest.KgdslReasonerContent(kgdsl=dsl_content)
|
||||
|
@ -26,6 +26,7 @@ from knext.command.sub_command.project import list_project
|
||||
from knext.command.sub_command.reasoner import query_reasoner_job
|
||||
from knext.command.sub_command.reasoner import run_dsl
|
||||
from knext.command.sub_command.reasoner import submit_reasoner_job
|
||||
from knext.command.sub_command.reasoner import execute_reasoner_job
|
||||
from knext.command.sub_command.schema import commit_schema
|
||||
from knext.command.sub_command.schema import diff_schema
|
||||
from knext.command.sub_command.schema import list_schema
|
||||
@ -110,6 +111,7 @@ def reasoner() -> None:
|
||||
reasoner.command("submit")(submit_reasoner_job)
|
||||
reasoner.command("query")(run_dsl)
|
||||
reasoner.command("get")(query_reasoner_job)
|
||||
reasoner.command("execute")(execute_reasoner_job)
|
||||
|
||||
if __name__ == "__main__":
|
||||
_main()
|
||||
|
@ -121,3 +121,22 @@ def query_reasoner_job(id):
|
||||
)
|
||||
else:
|
||||
sys.exit()
|
||||
|
||||
|
||||
@click.option("--file", help="Path of DSL file.")
@click.option("--dsl", help="DSL string enclosed in double quotes.")
@click.option("--output", help="Output file.")
def execute_reasoner_job(file, dsl, output=None):
    """
    Execute a reasoner job by providing DSL file or string.
    """
    # The docstring above is kept short on purpose: click shows it as the
    # command's help text.
    client = ReasonerClient()

    # Exactly one of --file / --dsl must be supplied; both or neither is a
    # usage error.
    if bool(file) == bool(dsl):
        click.secho("ERROR: Please choose either --file or --dsl.", fg="bright_red")
        # Non-zero status so shells and scripts can detect the failure
        # (a bare sys.exit() would report success).
        sys.exit(1)

    if file:
        with open(file, "r") as f:
            dsl_content = f.read()
    else:
        dsl_content = dsl

    # `output` may be None; it is forwarded unchanged to ReasonerClient.execute.
    client.execute(dsl_content, output)
|
||||
|
64
python/knext/knext/examples/financial/builder/job/company.py
Normal file
64
python/knext/knext/examples/financial/builder/job/company.py
Normal file
@ -0,0 +1,64 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from knext.examples.financial.schema.financial_schema_helper import Financial
|
||||
|
||||
from knext.api.component import CSVReader, LLMBasedExtractor, KGWriter, SubGraphMapping
|
||||
from knext.client.model.builder_job import BuilderJob
|
||||
from nn4k.invoker import LLMInvoker
|
||||
|
||||
|
||||
class Company(BuilderJob):
    """Builder job that turns free-text rows of ``company.csv`` into ``Financial.Company`` records."""

    def build(self):
        """
        Assemble and return the pipeline: CSV reader >> LLM extractor >> subgraph mapping >> writer.
        """
        company = Financial.Company

        # Single "input" column read from the CSV, starting at row 2.
        reader = CSVReader(
            local_path="builder/job/data/company.csv",
            columns=["input"],
            start_row=2,
        )

        from knext.api.auto_prompt import REPrompt

        # Properties the prompt asks the LLM to extract for a company.
        extracted_properties = [
            company.name,
            company.orgCertNo,
            company.regArea,
            company.businessScope,
            company.establishDate,
            company.legalPerson,
            company.regCapital,
        ]
        company_prompt = REPrompt(
            spg_type_name=company,
            property_names=extracted_properties,
        )

        llm = LLMInvoker.from_config("builder/model/openai_infer.json")
        extractor = LLMBasedExtractor(llm=llm, prompt_ops=[company_prompt])

        # (extracted field, schema property) pairs. The extracted "name" is
        # mapped onto both the id and the name property.
        # NOTE(review): orgCertNo is extracted above but has no mapping entry
        # here - confirm whether that omission is intentional.
        field_pairs = [
            ("name", company.id),
            ("name", company.name),
            ("regArea", company.regArea),
            ("businessScope", company.businessScope),
            ("establishDate", company.establishDate),
            ("legalPerson", company.legalPerson),
            ("regCapital", company.regCapital),
        ]
        subgraph_mapping = SubGraphMapping(spg_type_name=company)
        for source_field, target_property in field_pairs:
            subgraph_mapping = subgraph_mapping.add_mapping_field(
                source_field, target_property
            )

        writer = KGWriter()

        pipeline = reader >> extractor
        pipeline = pipeline >> subgraph_mapping
        return pipeline >> writer
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: build the company extraction prompt and print the
    # template it renders.
    from knext.api.auto_prompt import REPrompt

    prompt = REPrompt(
        property_names=[
            Financial.Company.orgCertNo,
            Financial.Company.regArea,
            Financial.Company.businessScope,
            Financial.Company.establishDate,
            Financial.Company.legalPerson,
            Financial.Company.regCapital,
        ],
        spg_type_name=Financial.Company,
    )
    rendered_template = prompt.template
    print(rendered_template)
|
@ -0,0 +1,2 @@
|
||||
input
|
||||
阿里巴巴(中国)有限公司是一家从事企业管理,计算机系统服务,电脑动画设计等业务的公司,成立于2007年03月26日,公司坐落在浙江省;经营有阿里邮箱、浙烟邮箱,师生家校、点淘-淘宝直播官方平台、云上会展等产品,经国家企业信用信息公示系统查询得知,阿里巴巴(中国)有限公司的信用代码/税号为91330100799655058B,法人是蒋芳,注册资本为15412.764910万美元,企业的经营范围为:服务:企业管理,计算机系统服务,电脑动画设计,经济信息咨询服务(除商品中介),成年人的非证书劳动职业技能培训和成年人的非文化教育培训(涉及前置审批的项目除外);生产:计算机软件;销售自产产品。(国家禁止和限制的除外,凡涉及许可证制度的凭证经营)
|
|
@ -1,7 +1,7 @@
|
||||
{
|
||||
"invoker_type": "OpenAI",
|
||||
"openai_api_key": "EMPTY",
|
||||
"openai_api_base": "http://localhost:38000/v1",
|
||||
"openai_model_name": "vicuna-7b-v1.5",
|
||||
"openai_max_tokens": 1000
|
||||
"openai_api_base": "http://127.0.0.1:38080/v1",
|
||||
"openai_model_name": "gpt-3.5-turbo",
|
||||
"openai_max_tokens": 2000
|
||||
}
|
@ -13,31 +13,28 @@ class IndicatorFuse(FuseOp):
|
||||
super().__init__()
|
||||
self.search_client = SearchClient("Financial.Indicator")
|
||||
|
||||
def link(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
|
||||
print("####################IndicatorFuse#####################")
|
||||
def link(self, subject_record: SPGRecord) -> List[SPGRecord]:
|
||||
print("####################IndicatorFuse(指标融合)#####################")
|
||||
print("IndicatorFuse(Input): ")
|
||||
print("----------------------")
|
||||
[print(r) for r in subject_records]
|
||||
print(subject_record)
|
||||
linked_records = []
|
||||
for record in subject_records:
|
||||
query = {"match": {"name": record.get_property("name", "")}}
|
||||
recall_records = self.search_client.search(query, start=0, size=10)
|
||||
if recall_records is not None and len(recall_records) > 0:
|
||||
linked_records.append(SPGRecord(
|
||||
"Financial.Indicator",
|
||||
{
|
||||
"id": recall_records[0].doc_id,
|
||||
"name": recall_records[0].properties.get("name", ""),
|
||||
},
|
||||
))
|
||||
query = {"match": {"name": subject_record.get_property("name", "")}}
|
||||
recall_records = self.search_client.search(query, start=0, size=10)
|
||||
if recall_records is not None and len(recall_records) > 0:
|
||||
linked_records.append(SPGRecord(
|
||||
"Financial.Indicator",
|
||||
{
|
||||
"id": recall_records[0].doc_id,
|
||||
"name": recall_records[0].properties.get("name", ""),
|
||||
},
|
||||
))
|
||||
return linked_records
|
||||
|
||||
def merge(self, subject_records: List[SPGRecord], target_records: List[SPGRecord]) -> List[SPGRecord]:
|
||||
def merge(self, subject_record: SPGRecord, linked_records: List[SPGRecord]) -> List[SPGRecord]:
|
||||
merged_records = []
|
||||
for s in subject_records:
|
||||
if s in target_records:
|
||||
continue
|
||||
merged_records.append(s)
|
||||
if not linked_records:
|
||||
merged_records.append(subject_record)
|
||||
print("IndicatorFuse(Output): ")
|
||||
print("----------------------")
|
||||
[print(r) for r in merged_records]
|
||||
|
@ -13,37 +13,29 @@ class StateFuse(FuseOp):
|
||||
super().__init__()
|
||||
self.search_client = SearchClient("Financial.State")
|
||||
|
||||
def link(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
|
||||
def link(self, subject_record: SPGRecord) -> List[SPGRecord]:
|
||||
print("####################StateFuse(状态融合)#####################")
|
||||
print("StateFuse(Input): ")
|
||||
print("----------------------")
|
||||
[print(r) for r in subject_records]
|
||||
print(subject_record)
|
||||
linked_records = []
|
||||
for record in subject_records:
|
||||
query = {"match": {"name": record.get_property("name", "")}}
|
||||
recall_records = self.search_client.search(query, start=0, size=10)
|
||||
if recall_records is not None and len(recall_records) > 0:
|
||||
linked_records.append(SPGRecord(
|
||||
"Financial.State",
|
||||
{
|
||||
"id": recall_records[0].doc_id,
|
||||
"name": recall_records[0].properties.get("name", ""),
|
||||
},
|
||||
)
|
||||
)
|
||||
return linked_records
|
||||
|
||||
def merge(self, subject_records: List[SPGRecord], target_records: List[SPGRecord]) -> List[SPGRecord]:
|
||||
merged_records = []
|
||||
for s in subject_records:
|
||||
# for t in target_records:
|
||||
merged_records.append(SPGRecord(
|
||||
query = {"match": {"name": subject_record.get_property("name", "")}}
|
||||
recall_records = self.search_client.search(query, start=0, size=10)
|
||||
if recall_records is not None and len(recall_records) > 0:
|
||||
linked_records.append(SPGRecord(
|
||||
"Financial.State",
|
||||
{
|
||||
"id": s.get_property("id"),
|
||||
"name": s.get_property("name", ""),
|
||||
})
|
||||
"id": recall_records[0].doc_id,
|
||||
"name": recall_records[0].properties.get("name", ""),
|
||||
},
|
||||
)
|
||||
)
|
||||
return linked_records
|
||||
|
||||
def merge(self, subject_record: SPGRecord, linked_records: List[SPGRecord]) -> List[SPGRecord]:
|
||||
merged_records = []
|
||||
if not linked_records:
|
||||
merged_records.append(subject_record)
|
||||
print("StateFuse(Output): ")
|
||||
print("----------------------")
|
||||
[print(r) for r in merged_records]
|
||||
|
@ -20,6 +20,42 @@ from knext.common.schema_helper import SPGTypeHelper, PropertyHelper
|
||||
|
||||
class Financial:
|
||||
|
||||
class AdministrativeArea(SPGTypeHelper):
|
||||
description = PropertyHelper("description")
|
||||
id = PropertyHelper("id")
|
||||
name = PropertyHelper("name")
|
||||
stdId = PropertyHelper("stdId")
|
||||
alias = PropertyHelper("alias")
|
||||
|
||||
class AreaRiskEvent(SPGTypeHelper):
|
||||
description = PropertyHelper("description")
|
||||
id = PropertyHelper("id")
|
||||
name = PropertyHelper("name")
|
||||
eventTime = PropertyHelper("eventTime")
|
||||
object = PropertyHelper("object")
|
||||
subject = PropertyHelper("subject")
|
||||
|
||||
class Company(SPGTypeHelper):
|
||||
description = PropertyHelper("description")
|
||||
id = PropertyHelper("id")
|
||||
name = PropertyHelper("name")
|
||||
orgCertNo = PropertyHelper("orgCertNo")
|
||||
establishDate = PropertyHelper("establishDate")
|
||||
regArea = PropertyHelper("regArea")
|
||||
regCapital = PropertyHelper("regCapital")
|
||||
businessScope = PropertyHelper("businessScope")
|
||||
legalPerson = PropertyHelper("legalPerson")
|
||||
|
||||
class CompanyEvent(SPGTypeHelper):
|
||||
description = PropertyHelper("description")
|
||||
id = PropertyHelper("id")
|
||||
name = PropertyHelper("name")
|
||||
location = PropertyHelper("location")
|
||||
eventTime = PropertyHelper("eventTime")
|
||||
happenedTime = PropertyHelper("happenedTime")
|
||||
subject = PropertyHelper("subject")
|
||||
object = PropertyHelper("object")
|
||||
|
||||
class Indicator(SPGTypeHelper):
|
||||
description = PropertyHelper("description")
|
||||
id = PropertyHelper("id")
|
||||
@ -36,6 +72,10 @@ class Financial:
|
||||
derivedFrom = PropertyHelper("derivedFrom")
|
||||
stdId = PropertyHelper("stdId")
|
||||
|
||||
AdministrativeArea = AdministrativeArea("Financial.AdministrativeArea")
|
||||
AreaRiskEvent = AreaRiskEvent("Financial.AreaRiskEvent")
|
||||
Company = Company("Financial.Company")
|
||||
CompanyEvent = CompanyEvent("Financial.CompanyEvent")
|
||||
Indicator = Indicator("Financial.Indicator")
|
||||
State = State("Financial.State")
|
||||
|
@ -37,10 +37,10 @@ class Disease(BuilderJob):
|
||||
spg_type_name="Medical.Disease",
|
||||
property_names=[
|
||||
"complication",
|
||||
# "commonSymptom",
|
||||
# "applicableDrug",
|
||||
# "department",
|
||||
# "diseaseSite",
|
||||
"commonSymptom",
|
||||
"applicableDrug",
|
||||
"department",
|
||||
"diseaseSite",
|
||||
],
|
||||
)
|
||||
],
|
||||
@ -54,10 +54,10 @@ class Disease(BuilderJob):
|
||||
.add_mapping_field("id", "id")
|
||||
.add_mapping_field("name", "name")
|
||||
.add_mapping_field("complication", "complication")
|
||||
# .add_mapping_field("commonSymptom", "commonSymptom")
|
||||
# .add_mapping_field("applicableDrug", "applicableDrug")
|
||||
# .add_mapping_field("department", "department")
|
||||
# .add_mapping_field("diseaseSite", "diseaseSite")
|
||||
.add_mapping_field("commonSymptom", "commonSymptom")
|
||||
.add_mapping_field("applicableDrug", "applicableDrug")
|
||||
.add_mapping_field("department", "department")
|
||||
.add_mapping_field("diseaseSite", "diseaseSite")
|
||||
)
|
||||
|
||||
"""
|
||||
|
@ -1,7 +1,7 @@
|
||||
{
|
||||
"invoker_type": "OpenAI",
|
||||
"openai_api_key": "EMPTY",
|
||||
"openai_api_base": "http://localhost:38080/v1",
|
||||
"openai_api_base": "http://127.0.0.1:38080/v1",
|
||||
"openai_model_name": "gpt-3.5-turbo",
|
||||
"openai_max_tokens": 1000
|
||||
}
|
@ -24,18 +24,18 @@ class Medical:
|
||||
description = PropertyHelper("description")
|
||||
id = PropertyHelper("id")
|
||||
name = PropertyHelper("name")
|
||||
alias = PropertyHelper("alias")
|
||||
stdId = PropertyHelper("stdId")
|
||||
alias = PropertyHelper("alias")
|
||||
|
||||
class Disease(SPGTypeHelper):
|
||||
description = PropertyHelper("description")
|
||||
id = PropertyHelper("id")
|
||||
name = PropertyHelper("name")
|
||||
department = PropertyHelper("department")
|
||||
complication = PropertyHelper("complication")
|
||||
applicableDrug = PropertyHelper("applicableDrug")
|
||||
diseaseSite = PropertyHelper("diseaseSite")
|
||||
department = PropertyHelper("department")
|
||||
commonSymptom = PropertyHelper("commonSymptom")
|
||||
diseaseSite = PropertyHelper("diseaseSite")
|
||||
complication = PropertyHelper("complication")
|
||||
|
||||
class Drug(SPGTypeHelper):
|
||||
description = PropertyHelper("description")
|
||||
@ -46,8 +46,8 @@ class Medical:
|
||||
description = PropertyHelper("description")
|
||||
id = PropertyHelper("id")
|
||||
name = PropertyHelper("name")
|
||||
alias = PropertyHelper("alias")
|
||||
stdId = PropertyHelper("stdId")
|
||||
alias = PropertyHelper("alias")
|
||||
|
||||
class Indicator(SPGTypeHelper):
|
||||
description = PropertyHelper("description")
|
||||
|
@ -1,15 +1,12 @@
|
||||
GRAPH_STORE_PARAM = "-Dcloudext.graphstore.drivers=com.antgroup.openspg.cloudext.impl.graphstore.tugraph" \
|
||||
".TuGraphStoreClientDriver"
|
||||
|
||||
SEARCH_CLIENT_PARAM = "-Dcloudext.searchengine.drivers=com.antgroup.openspg.cloudext.impl.searchengine.elasticsearch" \
|
||||
".ElasticSearchEngineClientDriver"
|
||||
|
||||
LOCAL_BUILDER_JAR = "builder-runner-local-0.0.1-SNAPSHOT-jar-with-dependencies.jar"
|
||||
|
||||
LOCAL_REASONER_JAR = ""
|
||||
LOCAL_REASONER_JAR = "reasoner-local-runner-0.0.1-SNAPSHOT-jar-with-dependencies.jar"
|
||||
|
||||
LOCAL_SCHEMA_URL = "http://localhost:8887"
|
||||
|
||||
LOCAL_GRAPH_STORE_URL = "tugraph://127.0.0.1:9090?graphName=default&timeout=50000&accessId=admin&accessKey=73@TuGraph"
|
||||
|
||||
LOCAL_SEARCH_ENGINE_URL = "elasticsearch://127.0.0.1:9200?scheme=http"
|
||||
|
||||
LOCAL_GRAPH_STATE_CLASS = "com.antgroup.openspg.reasoner.warehouse.cloudext.CloudExtGraphState"
|
||||
|
@ -54,6 +54,8 @@ input:${input}
|
||||
def parse_response(self, response: str) -> List[SPGRecord]:
|
||||
result = []
|
||||
subject = {}
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
response = response[0]
|
||||
re_obj = json.loads(response)
|
||||
if "spo" not in re_obj.keys():
|
||||
raise ValueError("SPO format error.")
|
||||
@ -89,21 +91,10 @@ input:${input}
|
||||
result.append(subject_entity)
|
||||
return result
|
||||
|
||||
def build_next_variables(
|
||||
self, variables: Dict[str, str], response: str
|
||||
) -> List[Dict[str, str]]:
|
||||
re_obj = json.loads(response)
|
||||
if "spo" not in re_obj.keys():
|
||||
raise ValueError("SPO format error.")
|
||||
re = re_obj.get("spo", [])
|
||||
return [{"input": variables.get("input"), "spo": str(i)} for i in re]
|
||||
|
||||
def _render(self, spg_type: BaseSpgType, property_names: List[str]):
|
||||
spos = []
|
||||
repeat_desc = []
|
||||
for property_name in property_names:
|
||||
if property_name in ["id", "name", "description"]:
|
||||
continue
|
||||
prop = spg_type.properties.get(property_name)
|
||||
object_desc = ""
|
||||
object_type = self.schema_client.query_spg_type(prop.object_type_name)
|
||||
@ -117,101 +108,3 @@ input:${input}
|
||||
repeat_desc.extend([spg_type.name_zh, prop.name_zh, prop.object_type_name_zh])
|
||||
schema_text = "\n[" + ",\n".join(spos) + "]"
|
||||
self.template = self.template.replace("${schema}", schema_text)
|
||||
|
||||
|
||||
class EEPrompt(AutoPrompt):
|
||||
template: str = """
|
||||
已知如下的事件schema定义:${schema}。从下列句子中抽取所定义的事件,如果存在以JSON格式返回,如果不存在返回空字符串。
|
||||
input:${input}
|
||||
${example}
|
||||
输出格式为:{"event":[{"event_type":,"arguments":[{},]}]}
|
||||
"output":
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_type_name: Union[str, SPGTypeHelper],
|
||||
property_names: List[Union[str, PropertyHelper]],
|
||||
custom_prompt: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if custom_prompt:
|
||||
self.template = custom_prompt
|
||||
schema_client = SchemaClient()
|
||||
spg_type = schema_client.query_spg_type(spg_type_name=event_type_name)
|
||||
self.spg_type_name = event_type_name
|
||||
self.predicate_zh_to_en_name = {}
|
||||
self.predicate_type_zh_to_en_name = {}
|
||||
for k, v in spg_type.properties.items():
|
||||
self.predicate_zh_to_en_name[v.name_zh] = k
|
||||
self.predicate_type_zh_to_en_name[v.name_zh] = v.object_type_name
|
||||
self._render(spg_type, property_names)
|
||||
self.params = {
|
||||
"spg_type_name": event_type_name,
|
||||
"property_names": property_names,
|
||||
"custom_prompt": custom_prompt,
|
||||
}
|
||||
|
||||
def build_prompt(self, variables: Dict[str, str]) -> str:
|
||||
return self.template.replace("${input}", variables.get("input"))
|
||||
|
||||
def parse_response(self, response: str) -> List[SPGRecord]:
|
||||
response = "{\"event\":[{\"event_type\":\"区域经济指标事件(区域指标)\",\"arguments\":[{\"日期\":\"2022年\",\"区域\":\"济南市\",\"来源\":\"政府公告\",\"主体\":\"山东省财政局\"}]}]}"
|
||||
result = []
|
||||
subject = {}
|
||||
re_obj = json.loads(response)
|
||||
if "event" not in re_obj.keys():
|
||||
raise ValueError("Event format error.")
|
||||
subject_properties = {}
|
||||
for spo_item in re_obj.get("event", []):
|
||||
if spo_item["predicate"] not in self.predicate_zh_to_en_name:
|
||||
continue
|
||||
subject_properties = {
|
||||
"id": spo_item["subject"],
|
||||
"name": spo_item["subject"],
|
||||
}
|
||||
if spo_item["subject"] not in subject:
|
||||
subject[spo_item["subject"]] = subject_properties
|
||||
else:
|
||||
subject_properties = subject[spo_item["subject"]]
|
||||
|
||||
spo_en_name = self.predicate_zh_to_en_name[spo_item["predicate"]]
|
||||
|
||||
if spo_en_name in subject_properties and len(
|
||||
subject_properties[spo_en_name]
|
||||
):
|
||||
subject_properties[spo_en_name] = (
|
||||
subject_properties[spo_en_name] + "," + spo_item["object"]
|
||||
)
|
||||
else:
|
||||
subject_properties[spo_en_name] = spo_item["object"]
|
||||
|
||||
# for k, val in subject.items():
|
||||
subject_entity = SPGRecord(
|
||||
spg_type_name=self.spg_type_name, properties=subject_properties
|
||||
)
|
||||
result.append(subject_entity)
|
||||
return result
|
||||
|
||||
def build_next_variables(
|
||||
self, variables: Dict[str, str], response: str
|
||||
) -> List[Dict[str, str]]:
|
||||
re_obj = json.loads(response)
|
||||
if "event" not in re_obj.keys():
|
||||
raise ValueError("Event format error.")
|
||||
re = re_obj.get("event", [])
|
||||
return [{"input": variables.get("input"), "event": str(i)} for i in re]
|
||||
|
||||
def _render(self, spg_type: BaseSpgType, property_names: List[str]):
|
||||
arguments = []
|
||||
for property_name in property_names:
|
||||
if property_name in ["id", "name", "description"]:
|
||||
continue
|
||||
prop = spg_type.properties.get(property_name)
|
||||
arguments.append(
|
||||
f"{prop.name_zh}({prop.object_type_name_zh})"
|
||||
)
|
||||
|
||||
schema_text = f"{{event_type:{spg_type.name_zh}({spg_type.desc or spg_type.name_zh}),arguments:[{','.join(arguments)}]"
|
||||
self.template = self.template.replace("${schema}", schema_text)
|
||||
|
@ -60,8 +60,9 @@ class _BuiltInOnlineExtractor(ExtractOp):
|
||||
elif op_name == "IndicatorLOGIC":
|
||||
response = '[{"subject": "土地出让收入大幅下降", "predicate": "顺承", "object": ["综合财力明显下滑"]}]'
|
||||
else:
|
||||
print(query)
|
||||
print(repr(query))
|
||||
response = self.model.remote_inference(query)
|
||||
print(response)
|
||||
collector.extend(op.parse_response(response))
|
||||
next_params.extend(
|
||||
op.build_next_variables(input_param, response)
|
||||
|
@ -89,19 +89,22 @@ class FuseOp(BaseOp, ABC):
|
||||
def __init__(self, params: Dict[str, str] = None):
|
||||
super().__init__(params)
|
||||
|
||||
def link(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
|
||||
def link(self, subject_record: SPGRecord) -> List[SPGRecord]:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} need to implement `link` method."
|
||||
)
|
||||
|
||||
def merge(self, subject_records: List[SPGRecord], target_records: List[SPGRecord]) -> List[SPGRecord]:
|
||||
def merge(self, subject_record: SPGRecord, linked_records: List[SPGRecord]) -> List[SPGRecord]:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} need to implement `merge` method."
|
||||
)
|
||||
|
||||
def invoke(self, records: List[SPGRecord]) -> List[SPGRecord]:
|
||||
linked_records = self.link(records)
|
||||
return self.merge(records, linked_records)
|
||||
def invoke(self, subject_records: List[SPGRecord]) -> List[SPGRecord]:
    """
    Link and merge every subject record, returning all merged records.

    For each record, ``link`` is called with that record and ``merge`` is
    called with the record and whatever ``link`` returned. The per-record
    merge results are concatenated in input order.

    :param subject_records: Records to fuse; may be empty.
    :return: Concatenation of the ``merge`` results of all records, or an
        empty list when there is nothing to fuse.
    """
    merged_records = []
    for record in subject_records:
        linked_records = self.link(record)
        merged = self.merge(record, linked_records)
        # Accumulate instead of returning inside the loop, so records after
        # the first one are processed too. A merge result of None or [] adds
        # nothing.
        if merged:
            merged_records.extend(merged)
    return merged_records
|
||||
|
||||
@staticmethod
|
||||
def _pre_process(*inputs):
|
||||
@ -136,6 +139,8 @@ class PromptOp(BaseOp, ABC):
|
||||
def build_next_variables(
|
||||
self, variables: Dict[str, str], response: str
|
||||
) -> List[Dict[str, str]]:
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
response = response[0]
|
||||
variables.update({f"{self.__class__.__name__}": response})
|
||||
print("LLM(Output): ")
|
||||
print("----------------------")
|
||||
|
Loading…
x
Reference in New Issue
Block a user