# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import logging
import traceback
from typing import Union, Dict, List, Any

from tenacity import retry, stop_after_attempt

from kag.interface import PromptABC
from kag.common.registry import Registrable

logger = logging.getLogger(__name__)

class LLMClient(Registrable):
    """
    A class that provides methods for performing inference with a large language model.

    This class includes methods to call the model with a prompt, parse the response,
    and handle batch processing of prompts.
    """

    @retry(stop=stop_after_attempt(3))
    def __call__(self, prompt: Union[str, dict, list]) -> str:
        """
        Perform inference on the given prompt and return the result.

        Args:
            prompt (Union[str, dict, list]): Input prompt for inference.

        Returns:
            str: Inference result.

        Raises:
            NotImplementedError: If the subclass has not implemented this method.
        """
        raise NotImplementedError

    @retry(stop=stop_after_attempt(3))
    def call_with_json_parse(self, prompt: Union[str, dict, list]):
        """
        Perform inference on the given prompt and attempt to parse the result as JSON.

        Args:
            prompt (Union[str, dict, list]): Input prompt for inference.

        Returns:
            Any: The parsed JSON object, or the raw response string if no valid
                JSON can be extracted.
        """
        res = self(prompt)
        # If the model wrapped its answer in a fenced ```json ... ``` block,
        # extract only the content between the fences before parsing.
        _end = res.rfind("```")
        _start = res.find("```json")
        if _end != -1 and _start != -1:
            json_str = res[_start + len("```json") : _end].strip()
        else:
            json_str = res
        try:
            json_result = json.loads(json_str)
        except json.JSONDecodeError:
            # Fall back to the raw response so callers can handle non-JSON replies.
            return res
        return json_result
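
    # Illustrative examples of the parsing behavior above (hypothetical responses):
    #
    #     # model returns '```json\n["red", "blue"]\n```'  ->  ["red", "blue"]
    #     # model returns 'plain text answer'              ->  'plain text answer'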

    def invoke(
        self,
        variables: Dict[str, Any],
        prompt_op: PromptABC,
        with_json_parse: bool = True,
        with_except: bool = False,
    ):
        """
        Call the model and process the result.

        Args:
            variables (Dict[str, Any]): Variables used to build the prompt.
            prompt_op (PromptABC): Prompt operation object for building and parsing prompts.
            with_json_parse (bool, optional): Whether to attempt parsing the response as JSON. Defaults to True.
            with_except (bool, optional): Whether to raise an exception if an error occurs. Defaults to False.

        Returns:
            List: Processed result list.
        """
        result = []
        prompt = prompt_op.build_prompt(variables)
        logger.debug(f"Prompt: {prompt}")
        if not prompt:
            return result
        response = ""
        try:
            response = (
                self.call_with_json_parse(prompt=prompt)
                if with_json_parse
                else self(prompt)
            )
            logger.debug(f"Response: {response}")
            # `model` is expected to be set by the concrete subclass.
            result = prompt_op.parse_response(response, model=self.model, **variables)
            logger.debug(f"Result: {result}")
        except Exception as e:
            logger.debug(f"Error {e} during invocation: {traceback.format_exc()}")
            if with_except:
                raise RuntimeError(
                    f"LLM invoke exception, info: {e}\nllm input: \n{prompt}\nllm output: \n{response}"
                )
        return result
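
    # Usage sketch (names are placeholders, not part of this file): `client` is a
    # concrete LLMClient subclass instance and `my_prompt_op` any PromptABC
    # implementation. `with_except=True` surfaces failures instead of returning [].
    #
    #     result = client.invoke({"input": text}, my_prompt_op,
    #                            with_json_parse=True, with_except=True)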

    def batch(
        self,
        variables: Dict[str, Any],
        prompt_op: PromptABC,
        with_json_parse: bool = True,
    ) -> List:
        """
        Batch process prompts.

        Args:
            variables (Dict[str, Any]): Variables used to build the prompts.
            prompt_op (PromptABC): Prompt operation object for building and parsing prompts.
            with_json_parse (bool, optional): Whether to attempt parsing the response as JSON. Defaults to True.

        Returns:
            List: List of all processed results.
        """
        results = []
        prompts = prompt_op.build_prompt(variables)
        # If there is only one prompt, call the `invoke` method directly.
        if isinstance(prompts, str):
            return self.invoke(variables, prompt_op, with_json_parse=with_json_parse)

        for idx, prompt in enumerate(prompts):
            logger.debug(f"Prompt_{idx}: {prompt}")
            try:
                response = (
                    self.call_with_json_parse(prompt=prompt)
                    if with_json_parse
                    else self(prompt)
                )
                logger.debug(f"Response_{idx}: {response}")
                result = prompt_op.parse_response(
                    response, idx=idx, model=self.model, **variables
                )
                logger.debug(f"Result_{idx}: {result}")
                results.extend(result)
            except Exception as e:
                # Log and skip failed prompts so one bad response does not abort the batch.
                logger.error(f"Error processing prompt {idx}: {e}")
                logger.debug(traceback.format_exc())
                continue
        return results
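
    # Note: unlike `invoke`, `batch` expects `prompt_op.build_prompt` to return a
    # list of prompts; each parsed result is flattened into a single list via
    # `results.extend(result)`, and failed prompts are skipped rather than fatal.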

    def check(self):
        """Run a trivial round-trip call to verify the LLM configuration works."""
        from kag.common.conf import KAG_PROJECT_CONF

        if (
            hasattr(KAG_PROJECT_CONF, "llm_config_check")
            and KAG_PROJECT_CONF.llm_config_check
        ):
            try:
                self.__call__("Are you OK?")
            except Exception as e:
                logger.error("LLM config check failed!")
                raise e
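
# Usage sketch (assumption, not confirmed by this file: as a `Registrable`,
# `LLMClient` is typically instantiated from a config dict via a registry
# factory such as `from_config`; the keys below are illustrative placeholders):
#
#     llm = LLMClient.from_config({"type": "...", "model": "..."})
#     llm.check()  # optional round-trip sanity check of the LLM config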