2024-01-08 11:13:49 +08:00

218 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import re
from abc import ABC
from typing import List, Dict, Tuple
from knext.client.schema import SchemaClient
from knext.common.schema_helper import SPGTypeName, PropertyName, RelationName
from knext.operator.op import PromptOp
from knext.operator.spg_record import SPGRecord
class AutoPrompt(PromptOp, ABC):
pass
class REPrompt(AutoPrompt):
template: str = """
已知SPO关系包括:${schema}
从下列句子中提取定义的这些关系。最终抽取结果以json格式输出且predicate必须在[${predicate}]内。
input:${input}
输出格式为:{"spo":[{"subject":,"predicate":,"object":},]}
"output":
"""
def __init__(
self,
spg_type_name: SPGTypeName,
property_names: List[PropertyName] = None,
relation_names: List[Tuple[RelationName, SPGTypeName]] = None,
custom_prompt: str = None,
):
super().__init__()
self.spg_type_name = spg_type_name
if custom_prompt:
self.template = custom_prompt
if not property_names:
property_names = []
if not relation_names:
relation_names = []
self.property_names = property_names
self.relation_names = relation_names
self._init_render_variables()
self._render()
self.params = {
"spg_type_name": spg_type_name,
"property_names": property_names,
"relation_names": relation_names,
"custom_prompt": custom_prompt,
}
def build_prompt(self, variables: Dict[str, str]) -> str:
return self.template.replace("${input}", variables.get("input"))
def parse_response(self, response: str) -> List[SPGRecord]:
if isinstance(response, list) and len(response) > 0:
response = response[0]
re_obj = json.loads(response)
if "spo" not in re_obj.keys():
raise ValueError("SPO format error.")
subject_records = {}
result = []
for spo_item in re_obj.get("spo", []):
if any(k not in spo_item for k in ["subject", "predicate", "object"]):
continue
s, p_zh, o = spo_item["subject"], spo_item["predicate"], spo_item["object"]
if s not in subject_records:
subject_records[s] = (
SPGRecord(spg_type_name=self.spg_type_name)
.upsert_property("id", s)
.upsert_property("name", s)
)
if p_zh in self.property_info_zh:
p, _, o_type = self.property_info_zh[p_zh]
o_list = re.split("[,,、;]", o)
result.extend(
[
SPGRecord(o_type)
.upsert_property("id", _o)
.upsert_property("name", _o)
for _o in o_list
]
)
o = subject_records[s].get_property(p)
o = ",".join([o] + o_list) if o else ",".join(o_list)
subject_records[s].upsert_property(p, o)
elif p_zh in self.relation_info_zh:
p, _, o_type = self.relation_info_zh[p_zh]
o_list = re.split("[,,、;]", o)
result.extend(
[
SPGRecord(o_type)
.upsert_property("id", _o)
.upsert_property("name", _o)
for _o in o_list
]
)
o = subject_records[s].get_relation(p, o_type, "")
o = ",".join([o] + o_list) if o else ",".join(o_list)
subject_records[s].upsert_relation(p, o_type, o)
else:
continue
for subject_record in subject_records.values():
result.append(subject_record)
return result
def _render(self):
spo_infos = []
predicates = []
duplicate_types = set()
duplicate_predicates = set()
for _prop in self.property_names:
s_name_zh, s_desc = self.spg_type_schema_info_en.get(self.spg_type_name)
s_desc = (
(s_desc or s_name_zh)
if self.spg_type_name not in duplicate_types
else None
)
s_info = (s_name_zh or "") + (f"({s_desc})" if s_desc else "")
p_name_zh, p_desc, o_type = self.property_info_en.get(_prop)
p_desc = (
(p_desc or p_name_zh) if _prop not in duplicate_predicates else None
)
p_info = (p_name_zh or "") + (f"({p_desc})" if p_desc else "")
o_name_zh, o_desc = self.spg_type_schema_info_en.get(o_type)
o_desc = (o_desc or o_name_zh) if o_type not in duplicate_types else None
o_info = (o_name_zh or "") + (f"({o_desc})" if o_desc else "")
spo_infos.append(f"{s_info}-{p_info}-{o_info}")
duplicate_predicates.add(_prop)
duplicate_types.update([self.spg_type_name, o_type])
predicates.append(p_name_zh)
for _rel, o_type in self.relation_names:
s_name_zh, s_desc = self.spg_type_schema_info_en.get(self.spg_type_name)
s_desc = (
(s_desc or s_name_zh)
if self.spg_type_name not in duplicate_types
else None
)
s_info = (s_name_zh or "") + (f"({s_desc})" if s_desc else "")
p_name_zh, p_desc, _ = self.relation_info_en.get(_rel)
p_desc = (p_desc or p_name_zh) if _rel not in duplicate_predicates else None
p_info = (p_name_zh or "") + (f"({p_desc})" if p_desc else "")
o_name_zh, o_desc = self.spg_type_schema_info_en.get(o_type)
o_desc = (o_desc or o_name_zh) if o_type not in duplicate_types else None
o_info = (o_name_zh or "") + (f"({o_desc})" if o_desc else "")
spo_infos.append(f"{s_info}-{p_info}-{o_info}")
duplicate_predicates.add(_rel)
duplicate_types.update([self.spg_type_name, o_type])
predicates.append(p_name_zh)
schema_text = "\n[" + ",\n".join(spo_infos) + "]"
predicate_text = ",".join(predicates)
self.template = self.template.replace("${schema}", schema_text).replace(
"${predicate}", predicate_text
)
def _init_render_variables(self):
schema_session = SchemaClient().create_session()
spg_type = schema_session.get(spg_type_name=self.spg_type_name)
self.property_info_en = {}
self.property_info_zh = {}
self.relation_info_en = {}
self.relation_info_zh = {}
self.spg_type_schema_info_en = {
"Text": ("文本", None),
"Integer": ("整型", None),
"Float": ("浮点型", None),
}
self.spg_type_schema_info_zh = {
"文本": ("Text", None),
"整型": ("Integer", None),
"浮点型": ("Float", None),
}
for _rel in spg_type.relations.values():
if _rel.is_dynamic:
continue
self.relation_info_zh[_rel.name_zh] = (
_rel.name,
_rel.desc,
_rel.object_type_name,
)
self.relation_info_en[_rel.name] = (
_rel.name_zh,
_rel.desc,
_rel.object_type_name,
)
for _prop in spg_type.properties.values():
self.property_info_zh[_prop.name_zh] = (
_prop.name,
_prop.desc,
_prop.object_type_name,
)
self.property_info_en[_prop.name] = (
_prop.name_zh,
_prop.desc,
_prop.object_type_name,
)
for _type in schema_session.spg_types.values():
if _type.name in ["Text", "Integer", "Float"]:
continue
self.spg_type_schema_info_zh[_type.name_zh] = (_type.name, _type.desc)
self.spg_type_schema_info_en[_type.name] = (_type.name_zh, _type.desc)