KAG/kag/interface/common/llm_client.py

# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import logging
import traceback
from typing import Union, Dict, List, Any

from tenacity import retry, stop_after_attempt

from kag.interface import PromptABC
from kag.common.registry import Registrable

logger = logging.getLogger(__name__)


class LLMClient(Registrable):
    """
    A class that provides methods for performing inference with a large language model.

    This class includes methods to call the model with a prompt, parse the response,
    and handle batch processing of prompts.
    """

    @retry(stop=stop_after_attempt(3))
def __call__(self, prompt: Union[str, dict, list]) -> str:
"""
Perform inference on the given prompt and return the result.
Args:
prompt (Union[str, dict, list]): Input prompt for inference.
Returns:
str: Inference result.
Raises:
NotImplementedError: If the subclass has not implemented this method.
"""
        raise NotImplementedError

    @retry(stop=stop_after_attempt(3))
    def call_with_json_parse(self, prompt: Union[str, dict, list]):
"""
Perform inference on the given prompt and attempt to parse the result as JSON.
Args:
prompt (Union[str, dict, list]): Input prompt for inference.
Returns:
Any: Parsed result.
Raises:
NotImplementedError: If the subclass has not implemented this method.
"""
        res = self(prompt)
        # Strip a markdown ```json ... ``` fence if the model wrapped its output in one.
        _start = res.find("```json")
        _end = res.rfind("```")
        if _start != -1 and _end != -1 and _end > _start:
            json_str = res[_start + len("```json") : _end].strip()
        else:
            json_str = res
        try:
            json_result = json.loads(json_str)
        except json.JSONDecodeError:
            # Fall back to the raw response if it is not valid JSON.
            return res
        return json_result
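
    # Illustrative behavior of `call_with_json_parse` (sketch): a response such as
    #   'Sure!\n```json\n{"answer": 42}\n```'
    # yields the dict {"answer": 42}, while a response that is not valid JSON
    # is returned unchanged as a string.
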
def invoke(
self,
variables: Dict[str, Any],
prompt_op: PromptABC,
with_json_parse: bool = True,
with_except: bool = False,
):
"""
Call the model and process the result.
Args:
variables (Dict[str, Any]): Variables used to build the prompt.
prompt_op (PromptABC): Prompt operation object for building and parsing prompts.
with_json_parse (bool, optional): Whether to attempt parsing the response as JSON. Defaults to True.
with_except (bool, optional): Whether to raise an exception if an error occurs. Defaults to False.
Returns:
List: Processed result list.
"""
result = []
prompt = prompt_op.build_prompt(variables)
logger.debug(f"Prompt: {prompt}")
if not prompt:
return result
response = ""
try:
response = (
self.call_with_json_parse(prompt=prompt)
if with_json_parse
else self(prompt)
)
logger.debug(f"Response: {response}")
result = prompt_op.parse_response(response, model=self.model, **variables)
logger.debug(f"Result: {result}")
        except Exception as e:
            logger.error(f"Error during invocation: {e}")
            logger.debug(traceback.format_exc())
            if with_except:
                raise RuntimeError(
                    f"LLM invoke exception, info: {e}\nllm input: \n{prompt}\nllm output: \n{response}"
                ) from e
return result
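
    # Usage sketch for `invoke` (illustrative): `client.invoke({"input": text}, prompt_op)`,
    # where `prompt_op` is a concrete PromptABC implementation; the return value
    # is whatever `prompt_op.parse_response` produces for the model response.
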
def batch(
self,
variables: Dict[str, Any],
prompt_op: PromptABC,
with_json_parse: bool = True,
) -> List:
"""
Batch process prompts.
Args:
variables (Dict[str, Any]): Variables used to build the prompts.
prompt_op (PromptABC): Prompt operation object for building and parsing prompts.
with_json_parse (bool, optional): Whether to attempt parsing the response as JSON. Defaults to True.
Returns:
List: List of all processed results.
"""
results = []
prompts = prompt_op.build_prompt(variables)
# If there is only one prompt, call the `invoke` method directly
if isinstance(prompts, str):
return self.invoke(variables, prompt_op, with_json_parse=with_json_parse)
        for idx, prompt in enumerate(prompts):
logger.debug(f"Prompt_{idx}: {prompt}")
try:
response = (
self.call_with_json_parse(prompt=prompt)
if with_json_parse
else self(prompt)
)
logger.debug(f"Response_{idx}: {response}")
result = prompt_op.parse_response(
response, idx=idx, model=self.model, **variables
)
logger.debug(f"Result_{idx}: {result}")
results.extend(result)
            except Exception as e:
                # Log the failure and continue with the remaining prompts.
                logger.error(f"Error processing prompt {idx}: {e}")
                logger.debug(traceback.format_exc())
                continue
return results
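
    # Illustrative behavior of `batch` (sketch): if `prompt_op.build_prompt`
    # returns ["p0", "p1"], the model is queried once per prompt and the parsed
    # results are flattened into a single list, skipping prompts that raise.
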
    def check(self):
        """
        Sanity-check the LLM configuration by sending a trivial prompt,
        when `llm_config_check` is enabled in the project configuration.
        """
        from kag.common.conf import KAG_PROJECT_CONF

        if (
            hasattr(KAG_PROJECT_CONF, "llm_config_check")
            and KAG_PROJECT_CONF.llm_config_check
        ):
            try:
                self("Are you OK?")
            except Exception:
                logger.error("LLM config check failed!")
                raise
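

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the KAG API).
    # `_EchoClient` is a hypothetical stand-in showing what a concrete
    # subclass must provide: an implementation of `__call__`.
    class _EchoClient(LLMClient):
        model = "echo"  # subclasses are assumed to expose the model name used by `invoke`

        def __call__(self, prompt: Union[str, dict, list]) -> str:
            # Pretend the model wrapped its answer in a markdown JSON fence.
            return "```json\n" + json.dumps({"echo": str(prompt)}) + "\n```"

    client = _EchoClient()
    print(client.call_with_json_parse("hello"))  # -> {'echo': 'hello'}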