openspg/python/knext/operator/builtin/online_runner.py

76 lines
3.3 KiB
Python
Raw Normal View History

2023-12-15 17:33:54 +08:00
import json
import sys
2023-12-06 17:26:39 +08:00
from typing import Dict, List
2023-12-15 17:33:54 +08:00
from knext.common.class_register import register_from_package
2023-12-06 17:26:39 +08:00
from knext.api.operator import ExtractOp
2023-12-15 17:33:54 +08:00
from knext.operator.base import BaseOp
from knext.operator.spg_record import SPGRecord
from nn4k.invoker import LLMInvoker
2023-12-06 17:26:39 +08:00
class BuiltInOnlineLLMBasedExtractOp(ExtractOp):
    """Extract operator that chains prompt operators against a remotely invoked LLM.

    Each configured prompt operator defines one extraction layer. Every record
    entering a layer is turned into a prompt, sent to the model, and the parsed
    response items are merged into copies of that record. The merged records
    are the layer's output and become the input of the next layer.
    """

    def __init__(self, params: Dict[str, str] = None):
        """Build the model invoker and the prompt operators from ``params``.

        Despite the ``None`` default, ``params`` is required: every key below is
        read unconditionally.

        Args:
            params: Operator parameters. Keys read here:
                - "model_config": JSON string handed to ``LLMInvoker.from_config``.
                - "prompt_config": JSON string holding a list of objects, each with
                  "className" (a name resolvable via ``BaseOp.by_name``) and
                  "params" (keyword arguments for that class's constructor).
                - "operator_dir": passed to ``register_from_package`` together
                  with ``BaseOp`` before the prompt operators are looked up.
        """
        super().__init__(params)
        model_config = json.loads(params["model_config"])
        prompt_config = json.loads(params["prompt_config"])
        # Operator classes must be registered before BaseOp.by_name can find them.
        register_from_package(params["operator_dir"], BaseOp)
        self.model = LLMInvoker.from_config(model_config)
        # One prompt operator per extraction layer, in configuration order.
        self.prompt_ops = [
            BaseOp.by_name(config["className"])(**config["params"])
            for config in prompt_config
        ]

    def eval(self, record: Dict[str, str]) -> List[SPGRecord]:
        """Run every prompt operator over ``record`` as successive extraction layers.

        For each layer and each record produced by the previous layer, the steps
        are ``op.build_prompt`` -> ``self.model.remote_inference`` ->
        ``op.parse_response``. Every parsed result item is merged into a shallow
        copy of the record it came from (result values win on key clashes), so a
        single record may fan out into several records per layer. A record whose
        response parses to no items produces no output.

        Args:
            record: Input record for the first layer; it is never mutated.

        Returns:
            The records produced by the last layer, ordered by the order in which
            their source records were processed. When no prompt operators are
            configured, ``[record]`` is returned as-is instead of ``None``.
        """
        record_list = [record]
        for op in self.prompt_ops:
            layer_results = []
            # Process records in input order so output order is predictable.
            for source_record in record_list:
                query = op.build_prompt(source_record)
                response = self.model.remote_inference(query)
                # Post-processing may split one response into several result dicts.
                for result in op.parse_response(response):
                    merged = source_record.copy()
                    merged.update(result)
                    layer_results.append(merged)
            # This layer's output feeds the next layer.
            record_list = layer_results
        return record_list
2023-12-15 17:33:54 +08:00
if __name__ == "__main__":
    # Manual smoke test: send one fixed extraction prompt through an
    # "OpenAI"-type invoker and print whatever remote_inference returns.
    # NOTE(review): expects an OpenAI-compatible server at localhost:38000
    # serving "vicuna-7b-v1.5"; adjust the settings below before running.
    demo_config = dict(
        invoker_type="OpenAI",
        openai_api_key="EMPTY",
        openai_api_base="http://localhost:38000/v1",
        openai_model_name="vicuna-7b-v1.5",
        openai_max_tokens=1000,
    )
    invoker = LLMInvoker.from_config(demo_config)
    # The prompt (in Chinese) asks the model to extract the "release year"
    # relation of a studio album from one sentence and answer in JSON.
    prompt = """
已知SPO关系包括:[录音室专辑(录音室专辑)-发行年份-文本]从下列句子中提取定义的这些关系最终抽取结果以json格式输出
input:范特西是周杰伦的第二张音乐专辑由周杰伦担任制作人于2001年9月14日发行共收录爱在西元前威廉古堡双截棍等10首歌曲 [1]
输出格式为:{"spo":[{"subject":,"predicate":,"object":},]}
"output":
"""
    print(invoker.remote_inference(prompt))