mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-07-03 23:25:43 +00:00
218 lines
8.4 KiB
Python
218 lines
8.4 KiB
Python
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
|
||
import json
|
||
import re
|
||
from abc import ABC
|
||
from typing import List, Dict, Tuple
|
||
|
||
from knext.client.schema import SchemaClient
|
||
from knext.common.schema_helper import SPGTypeName, PropertyName, RelationName
|
||
from knext.operator.op import PromptOp
|
||
from knext.operator.spg_record import SPGRecord
|
||
|
||
|
||
class AutoPrompt(PromptOp, ABC):
    """Abstract marker base for prompt operators in this module.

    Adds no behaviour of its own on top of ``PromptOp``; it only gives the
    concrete prompt classes a common parent type.
    """
|
||
|
||
class REPrompt(AutoPrompt):
    """Relation-extraction (SPO) prompt rendered from the schema of one SPG type.

    On construction the schema of ``spg_type_name`` is read through a
    ``SchemaClient`` session and the ``${schema}`` / ``${predicate}``
    placeholders of :attr:`template` are filled in. ``build_prompt`` later
    substitutes ``${input}``, and ``parse_response`` converts the model's JSON
    answer into ``SPGRecord`` objects.
    """

    # Prompt skeleton (Chinese). Rough meaning: "The known SPO relations are
    # ${schema}. Extract these relations from the sentence below; output JSON
    # and keep every predicate within [${predicate}]. input: ${input}.
    # Output format: {"spo": [{"subject":, "predicate":, "object":}, ]}".
    template: str = """
已知SPO关系包括:${schema}
从下列句子中提取定义的这些关系。最终抽取结果以json格式输出,且predicate必须在[${predicate}]内。
input:${input}
输出格式为:{"spo":[{"subject":,"predicate":,"object":},]}
"output":
    """

    # Characters that separate several object values packed into one string.
    # The previous pattern listed the ASCII comma and semicolon twice; the
    # full-width forms (",", ";") are what a Chinese-language answer actually
    # contains, so they are included alongside the ideographic comma.
    _OBJECT_SEPARATOR_PATTERN = "[,,、;;]"

    def __init__(
        self,
        spg_type_name: SPGTypeName,
        property_names: List[PropertyName] = None,
        relation_names: List[Tuple[RelationName, SPGTypeName]] = None,
        custom_prompt: str = None,
    ):
        """Build the prompt for ``spg_type_name``.

        Args:
            spg_type_name: Subject type whose properties/relations are extracted.
            property_names: Property names (schema/English names) to extract.
                ``None`` or empty means no properties.
            relation_names: ``(relation name, object type name)`` pairs to
                extract. ``None`` or empty means no relations.
            custom_prompt: Optional replacement for :attr:`template`; it is
                rendered with the same ``${schema}`` / ``${predicate}`` /
                ``${input}`` placeholders.

        Raises:
            ValueError: If a requested property, relation or type is not found
                in the loaded schema information.
        """
        super().__init__()

        self.spg_type_name = spg_type_name
        if custom_prompt:
            # Instance attribute shadows the class-level default template.
            self.template = custom_prompt

        property_names = property_names or []
        relation_names = relation_names or []
        self.property_names = property_names
        self.relation_names = relation_names

        # Order matters: the lookup tables must exist before rendering.
        self._init_render_variables()
        self._render()

        # Constructor arguments as given (after None -> [] normalisation).
        self.params = {
            "spg_type_name": spg_type_name,
            "property_names": property_names,
            "relation_names": relation_names,
            "custom_prompt": custom_prompt,
        }

    def build_prompt(self, variables: Dict[str, str]) -> str:
        """Return the rendered template with ``${input}`` substituted.

        Args:
            variables: Mapping expected to carry the text under ``"input"``.
                A missing or ``None`` value is substituted as an empty string
                instead of raising ``TypeError`` from ``str.replace``.
        """
        text = variables.get("input")
        return self.template.replace("${input}", "" if text is None else str(text))

    def parse_response(self, response: str) -> List[SPGRecord]:
        """Convert the model's JSON answer into records.

        Args:
            response: JSON text of the form ``{"spo": [{"subject": ...,
                "predicate": ..., "object": ...}, ...]}``. A list is also
                accepted, in which case its first element is parsed.

        Returns:
            One record per extracted object value (in encounter order),
            followed by one record per distinct subject. Each subject record
            accumulates its property/relation values as comma-joined strings.

        Raises:
            ValueError: If the response is an empty list, is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError``), is not a JSON
                object, lacks ``"spo"``, or ``"spo"`` is not a list.
        """
        if isinstance(response, list):
            if not response:
                raise ValueError("SPO format error.")
            response = response[0]

        re_obj = json.loads(response)
        if not isinstance(re_obj, dict) or "spo" not in re_obj:
            raise ValueError("SPO format error.")
        spo_items = re_obj["spo"]
        if not isinstance(spo_items, list):
            raise ValueError("SPO format error.")

        subject_records = {}
        result = []
        for spo_item in spo_items:
            # Malformed triples are skipped rather than failing the whole answer.
            if not isinstance(spo_item, dict):
                continue
            if any(k not in spo_item for k in ("subject", "predicate", "object")):
                continue
            s, p_zh, o = spo_item["subject"], spo_item["predicate"], spo_item["object"]
            if not isinstance(s, str) or not isinstance(p_zh, str):
                continue

            # The subject record is created even when the predicate turns out
            # to be unknown, so every mentioned subject is returned.
            if s not in subject_records:
                subject_records[s] = (
                    SPGRecord(spg_type_name=self.spg_type_name)
                    .upsert_property("id", s)
                    .upsert_property("name", s)
                )

            # Predicates come back under their Chinese name; properties take
            # precedence over relations when a name exists in both tables.
            if p_zh in self.property_info_zh:
                p, _, o_type = self.property_info_zh[p_zh]
                is_relation = False
            elif p_zh in self.relation_info_zh:
                p, _, o_type = self.relation_info_zh[p_zh]
                is_relation = True
            else:
                continue

            o_list = self._split_objects(o)
            if not o_list:
                continue

            result.extend(
                SPGRecord(o_type).upsert_property("id", _o).upsert_property("name", _o)
                for _o in o_list
            )

            record = subject_records[s]
            if is_relation:
                existing = record.get_relation(p, o_type, "")
                record.upsert_relation(
                    p, o_type, self._merge_values(existing, o_list)
                )
            else:
                existing = record.get_property(p)
                record.upsert_property(p, self._merge_values(existing, o_list))

        result.extend(subject_records.values())
        return result

    @classmethod
    def _split_objects(cls, value) -> List[str]:
        """Split an ``object`` field into trimmed, non-empty value strings.

        Accepts a string or a list/tuple of strings; anything else yields ``[]``.
        """
        if isinstance(value, str):
            chunks = [value]
        elif isinstance(value, (list, tuple)):
            chunks = [item for item in value if isinstance(item, str)]
        else:
            return []
        pieces = []
        for chunk in chunks:
            pieces.extend(
                piece.strip()
                for piece in re.split(cls._OBJECT_SEPARATOR_PATTERN, chunk)
            )
        return [piece for piece in pieces if piece]

    @staticmethod
    def _merge_values(existing, new_values: List[str]) -> str:
        """Append ``new_values`` to an existing comma-joined value string."""
        return ",".join(([existing] if existing else []) + new_values)

    @staticmethod
    def _lookup(table: dict, key, kind: str):
        """Return ``table[key]`` or raise a descriptive ``ValueError``."""
        try:
            return table[key]
        except KeyError:
            raise ValueError(
                f"Unknown {kind} {key!r}: not found in the loaded schema."
            ) from None

    @staticmethod
    def _describe(name_zh, desc, first_mention: bool) -> str:
        """Format ``name(description)``; the description only on first mention.

        When no description is available the name itself is used as the
        description, matching the prompt format produced so far.
        """
        shown_desc = (desc or name_zh) if first_mention else None
        return (name_zh or "") + (f"({shown_desc})" if shown_desc else "")

    def _render(self):
        """Fill ``${schema}`` and ``${predicate}`` in the template.

        Emits one ``subject-predicate-object`` line per requested property,
        then per requested relation. A type or predicate is annotated with its
        description only the first time it appears.

        Raises:
            ValueError: If a requested property/relation or a referenced type
                is missing from the lookup tables.
        """
        spo_infos = []
        predicates = []
        described_types = set()
        described_predicates = set()

        # (predicate name, explicit object type or None, lookup table, label)
        triples = [
            (name, None, self.property_info_en, "property")
            for name in self.property_names
        ]
        triples.extend(
            (name, o_type, self.relation_info_en, "relation")
            for name, o_type in self.relation_names
        )

        for predicate, o_type, table, kind in triples:
            s_name_zh, s_desc = self._lookup(
                self.spg_type_schema_info_en, self.spg_type_name, "SPG type"
            )
            p_name_zh, p_desc, schema_o_type = self._lookup(table, predicate, kind)
            if kind == "property":
                # Properties take their object type from the schema; relations
                # use the object type supplied by the caller.
                o_type = schema_o_type
            o_name_zh, o_desc = self._lookup(
                self.spg_type_schema_info_en, o_type, "SPG type"
            )

            # All three parts are formatted before the "already described"
            # sets are updated, so subject and object of the same first triple
            # both keep their description.
            s_info = self._describe(
                s_name_zh, s_desc, self.spg_type_name not in described_types
            )
            p_info = self._describe(
                p_name_zh, p_desc, predicate not in described_predicates
            )
            o_info = self._describe(o_name_zh, o_desc, o_type not in described_types)

            spo_infos.append(f"{s_info}-{p_info}-{o_info}")
            described_predicates.add(predicate)
            described_types.update([self.spg_type_name, o_type])
            predicates.append(p_name_zh)

        schema_text = "\n[" + ",\n".join(spo_infos) + "]"
        predicate_text = ",".join(predicates)
        self.template = self.template.replace("${schema}", schema_text).replace(
            "${predicate}", predicate_text
        )

    def _init_render_variables(self):
        """Load the lookup tables used by ``_render`` and ``parse_response``.

        Populates, from a ``SchemaClient`` session:

        * ``property_info_en`` / ``relation_info_en``: schema name ->
          ``(name_zh, desc, object_type_name)``
        * ``property_info_zh`` / ``relation_info_zh``: Chinese name ->
          ``(name, desc, object_type_name)``
        * ``spg_type_schema_info_en`` / ``spg_type_schema_info_zh``: type name
          <-> Chinese type name, each with the type description.
        """
        schema_session = SchemaClient().create_session()
        spg_type = schema_session.get(spg_type_name=self.spg_type_name)

        self.property_info_en = {}
        self.property_info_zh = {}
        self.relation_info_en = {}
        self.relation_info_zh = {}
        # Basic types are seeded with fixed Chinese names and no description.
        self.spg_type_schema_info_en = {
            "Text": ("文本", None),
            "Integer": ("整型", None),
            "Float": ("浮点型", None),
        }
        self.spg_type_schema_info_zh = {
            "文本": ("Text", None),
            "整型": ("Integer", None),
            "浮点型": ("Float", None),
        }

        for _rel in spg_type.relations.values():
            # Relations flagged as dynamic are left out of both tables.
            if _rel.is_dynamic:
                continue
            self.relation_info_zh[_rel.name_zh] = (
                _rel.name,
                _rel.desc,
                _rel.object_type_name,
            )
            self.relation_info_en[_rel.name] = (
                _rel.name_zh,
                _rel.desc,
                _rel.object_type_name,
            )

        for _prop in spg_type.properties.values():
            self.property_info_zh[_prop.name_zh] = (
                _prop.name,
                _prop.desc,
                _prop.object_type_name,
            )
            self.property_info_en[_prop.name] = (
                _prop.name_zh,
                _prop.desc,
                _prop.object_type_name,
            )

        for _type in schema_session.spg_types.values():
            # Keep the hard-coded entries for the basic types.
            if _type.name in ("Text", "Integer", "Float"):
                continue
            self.spg_type_schema_info_zh[_type.name_zh] = (_type.name, _type.desc)
            self.spg_type_schema_info_en[_type.name] = (_type.name_zh, _type.desc)