mirror of https://github.com/OpenSPG/openspg.git, synced 2025-12-20 19:56:08 +00:00
feat(nn4k): support huggingface decode only model local inference (#128)
Co-authored-by: xionghuaidong <huaidong.xhd@antgroup.com>
parent ee5089ef54
commit 210cf7be9a
@@ -71,7 +71,7 @@ class NNExecutor(ABC):
                 f"{self.__class__.__name__} does not support batch inference."
             )

-    def inference(self, data, args=None, **kwargs):
+    def inference(self, inputs, **kwargs):
         """
         The entry point of inference. Usually for local invokers or model services.
         """
@@ -248,3 +248,103 @@ class NNAdapterModelArgs(NNModelArgs):

     def __post_init__(self):
         super().__post_init__()
+
+
+@dataclass
+class NNInferenceArgs:
+    max_input_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Controls the maximum length to use by one of the truncation/padding parameters. "
+            "In HuggingFace executors, known as max_length in the tokenize callable function config."
+        },
+    )
+    max_output_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The maximum number of tokens to generate. In HuggingFace executors, this arg will be treated as "
+            "max_new_tokens."
+        },
+    )
+    return_input_text: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to return input texts together with output texts."},
+    )
+    stop_sequence: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Generation will stop when the stop sequence is encountered in the output."
+        },
+    )
+    do_sample: bool = field(
+        default=False,
+        metadata={
+            "help": "If False, generation runs in greedy search mode; otherwise tokens are sampled from the probability distribution."
+        },
+    )
+    temperature: float = field(
+        default=1.0,
+        metadata={"help": "Controls the creativity and diversity of the generated text."},
+    )
+    top_k: Optional[int] = field(
+        default=50,
+        metadata={
+            "help": "The model will only sample from the top_k (count) tokens with the highest "
+            "probability."
+        },
+    )
+    top_p: Optional[float] = field(
+        default=1.0,
+        metadata={
+            "help": "In nucleus sampling, the model only samples from the smallest token set whose cumulative probability reaches top_p (percentage)."
+        },
+    )
+    repetition_penalty: Optional[float] = field(
+        default=1.0,
+        metadata={"help": "The default of 1.0 means no penalty."},
+    )
+
+    generate_config: dict = field(
+        default_factory=lambda: {},
+        metadata={"help": "Config dict that will be used in model generation."},
+    )
+
+    tokenize_return_tensors: str = field(
+        default="pt",
+        metadata={
+            "help": "Tokenizer return type; merged into tokenize_config and passed to the tokenize function."
+        },
+    )
+    tokenize_config: dict = field(
+        default_factory=lambda: {},
+        metadata={
+            "help": "Tokenize function config; passed to the tokenize function."
+        },
+    )
+
+    decode_config: dict = field(
+        default_factory=lambda: {},
+        metadata={"help": "Configs to be passed to the tokenizer.decode function."},
+    )
+
+    def update_if_not_none(self, from_key, to_dict, to_key=None):
+        to_key = to_key or from_key
+        from_value = self.__getattribute__(from_key)
+        value_in_to_dict = self.__getattribute__(to_dict).get(to_key, None)
+        if value_in_to_dict is None and from_value is not None:
+            self.__getattribute__(to_dict)[to_key] = from_value
+
+    def __post_init__(self):
+        # merging generation args
+        self.update_if_not_none("max_output_length", "generate_config")
+        self.update_if_not_none("do_sample", "generate_config")
+        self.update_if_not_none("temperature", "generate_config")
+        self.update_if_not_none("top_k", "generate_config")
+        self.update_if_not_none("top_p", "generate_config")
+        self.update_if_not_none("repetition_penalty", "generate_config")
+
+        # merging tokenize args
+        self.update_if_not_none("max_input_length", "tokenize_config", "max_length")
+        self.update_if_not_none(
+            "tokenize_return_tensors", "tokenize_config", "return_tensors"
+        )
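As a quick illustration of the merge performed in __post_init__ above: explicit fields are copied into generate_config and tokenize_config only when the target key is not already set. A minimal sketch, assuming nn4k is importable and the defaults above are unchanged (the field values are illustrative):

from nn4k.executor.base import NNInferenceArgs

# Hypothetical values, chosen only to show the merge behavior.
args = NNInferenceArgs(max_output_length=128, temperature=0.7, do_sample=True)
print(args.generate_config)
# -> {'max_output_length': 128, 'do_sample': True, 'temperature': 0.7,
#     'top_k': 50, 'top_p': 1.0, 'repetition_penalty': 1.0}
print(args.tokenize_config)
# -> {'return_tensors': 'pt'}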
@@ -15,6 +15,7 @@ from typing import Optional
 from transformers import TrainingArguments

 from nn4k.executor import NNAdapterModelArgs
+from nn4k.executor.base import NNInferenceArgs


 @dataclass
@@ -52,6 +53,13 @@ class HFModelArgs(NNAdapterModelArgs):
             "help": " Load the model weights from a TensorFlow checkpoint save file, default to False"
         },
     )
+    padding_side: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Padding side of the tokenizer when padding batch inputs",
+            "choices": [None, "left", "right"],
+        },
+    )

     def __post_init__(self):
         super().__post_init__()
@@ -105,3 +113,45 @@ class HFSftArgs(HFModelArgs, TrainingArguments):
             print(
                 f"a eval_dataset_path is set but do_eval flag is not set, automatically set do_eval to True"
             )
+
+
+@dataclass
+class HFInferArgs(NNInferenceArgs):
+    delete_heading_new_lines: bool = field(
+        default=False,
+        metadata={
+            "help": "Additional question marks or new line marks sometimes occur at the beginning of the output. "
+            "Try to get rid of these marks by setting this parameter to True. Different models may have different "
+            "behaviors, so please check the result carefully."
+        },
+    )
+
+    tokenize_config: dict = field(
+        default_factory=lambda: {
+            "add_special_tokens": False,
+            "padding": False,
+            "truncation": False,
+        },
+        metadata={
+            "help": "padding: https://huggingface.co/docs/transformers/pad_truncation#padding-and-truncation"
+        },
+    )
+
+    decode_config: dict = field(
+        default_factory=lambda: {
+            "skip_special_tokens": True,
+            "clean_up_tokenization_spaces": True,
+        },
+        metadata={
+            "help": "check https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__"
+        },
+    )
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        # HF specific mapping
+        self.update_if_not_none(
+            "max_output_length", "generate_config", "max_new_tokens"
+        )
+        self.update_if_not_none("max_input_length", "tokenize_config", "max_length")
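A rough sketch of how these HF-specific mappings resolve when the args are parsed from a plain dict, mirroring how the executor below uses HfArgumentParser (the dict values are illustrative and the hf_args module path is assumed):

from transformers import HfArgumentParser
from nn4k.executor.huggingface.hf_args import HFInferArgs  # module path assumed

parser = HfArgumentParser(HFInferArgs)
hf_infer_args, *_ = parser.parse_dict(
    {"max_output_length": 64, "max_input_length": 512, "do_sample": True},
    allow_extra_keys=True,
)
print(hf_infer_args.generate_config["max_new_tokens"])  # 64
print(hf_infer_args.tokenize_config["max_length"])      # 512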
@@ -18,8 +18,9 @@ from torch.utils.data import Dataset
 from transformers import AutoConfig, AutoTokenizer, Trainer

 from nn4k.executor import LLMExecutor
-from .hf_args import HFSftArgs, HFModelArgs
+from .hf_args import HFInferArgs, HFSftArgs, HFModelArgs
 from nn4k.executor.huggingface.nn_hf_trainer import NNHFTrainer
+from nn4k.utils.args_utils import ArgsUtils


 class HFLLMExecutor(LLMExecutor):
@@ -188,57 +189,85 @@ class HFLLMExecutor(LLMExecutor):
         if self.model_mode == mode and self._model is not None:
             return

         args = args or self._init_args

         from transformers import HfArgumentParser
         from nn4k.executor.huggingface import HFModelArgs

         parser = HfArgumentParser(HFModelArgs)
         hf_model_args: HFModelArgs
         hf_model_args, *_ = parser.parse_dict(args, allow_extra_keys=True)

         self.model_mode = mode
         self._tokenizer = self._hf_tokenizer_loader(hf_model_args)
         self._model = self._hf_model_loader(
-            hf_model_args, mode, hf_model_args.nn_device
+            args=hf_model_args, mode=mode, device=hf_model_args.nn_device
         )

         if self.tokenizer.eos_token_id is None:
             self.tokenizer.eos_token_id = self.model.config.eos_token_id
         if self.tokenizer.pad_token_id is None:
             self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        if hf_model_args.padding_side is not None:
+            self.tokenizer.padding_side = hf_model_args.padding_side

-    def inference(
-        self,
-        data,
-        max_input_length: int = 1024,
-        max_output_length: int = 1024,
-        do_sample: bool = False,
-        **kwargs,
-    ):
+    def inference(self, inputs, **kwargs):
+        infer_args = ArgsUtils.update_args(self.init_args, kwargs)
+
+        from transformers import HfArgumentParser
+
+        parser = HfArgumentParser(HFInferArgs)
+        hf_infer_args: HFInferArgs
+        hf_infer_args, *_ = parser.parse_dict(infer_args, allow_extra_keys=True)
+
         model = self.model
         tokenizer = self.tokenizer

         input_ids = tokenizer(
-            data,
-            padding=True,
-            return_token_type_ids=False,
-            return_tensors="pt",
-            truncation=True,
-            max_length=max_input_length,
+            inputs,
+            **hf_infer_args.tokenize_config,
         ).to(model.device)

+        if hf_infer_args.stop_sequence is not None:
+            stop_sequence = hf_infer_args.stop_sequence
+            stop_sequence_ids = self.tokenizer.encode(
+                stop_sequence, add_special_tokens=False
+            )
+            if len(stop_sequence_ids) > 1:
+                print(  # TODO: use logger instead
+                    "Warning: Stopping on a multiple token sequence is not yet supported on transformers. "
+                    "The first token of the stop sequence will be used as the stop sequence string in the interim."
+                )
+            hf_infer_args.generate_config["eos_token_id"] = stop_sequence_ids[0]
+
         output_ids = model.generate(
             **input_ids,
-            max_new_tokens=max_output_length,
-            do_sample=do_sample,
             eos_token_id=tokenizer.eos_token_id,
             pad_token_id=tokenizer.pad_token_id,
-            **kwargs,
+            **hf_infer_args.generate_config,
         )

-        outputs = [
-            tokenizer.decode(
-                output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
-            )
-            for idx, output_id in enumerate(output_ids)
-        ]
-        return outputs
+        output_texts = []
+        for idx, output_id in enumerate(output_ids):
+            if not hf_infer_args.return_input_text:
+                output_id = output_id[len(input_ids["input_ids"][idx]) :]
+            output_text = self.tokenizer.decode(
+                output_id, **hf_infer_args.decode_config
+            )
+
+            if (
+                not hf_infer_args.return_input_text
+                and hf_infer_args.delete_heading_new_lines
+            ):
+                import re
+
+                match = re.search("(\\n)+", output_text)
+                if match is not None:
+                    start_index = match.end()
+                    if start_index < len(output_text) - 1:
+                        output_text = output_text[start_index:]
+
+            output_texts.append(output_text)
+
+        return output_texts

     @abstractmethod
     def _hf_model_loader(
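A hedged usage sketch of the reworked inference path: any keyword understood by HFInferArgs is merged over init_args and forwarded into the tokenize and generate calls. The executor instance and the prompt below are hypothetical.

# `executor` is assumed to be an HFLLMExecutor subclass instance whose
# load_model(mode="inference") has already been called.
outputs = executor.inference(
    ["Tell me about knowledge graphs."],
    max_output_length=64,
    do_sample=True,
    temperature=0.7,
    tokenize_config={"padding": True},
)
print(outputs[0])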
@@ -0,0 +1,14 @@
+{
+  // -- base model info
+  "nn_model_path": "/Path/to/model_dir", // local model path
+  "nn_invoker": "nn4k.invoker.base.LLMInvoker", // invoker to use
+  "nn_executor": "nn4k.executor.huggingface.hf_decode_only_executor.HFDecodeOnlyExecutor", // executor to use
+  // the following are optional
+  "adapter_name": "adapter_name", // adapter_name must be given to enable the adapter; adapter_path alone has no effect!
+  "adapter_path": "/path/to/adapter",
+  "generate_config": {
+    "temperature": 0.2,
+    "do_sample": true
+  }
+}
@@ -13,9 +13,22 @@ from nn4k.invoker.base import NNInvoker


 def main():
-    NNInvoker.from_config("local_sft.json5").local_sft()
-    # Inference example, not implemented yet.
-    # NNInvoker.from_config("inferece_args.json").local_inference("你是谁")
+    # example for local sft
+    # NNInvoker.from_config("local_sft.json5").local_sft()
+
+    # example for local inference
+    invoker = NNInvoker.from_config("local_infer.json5")
+    answer = invoker.local_inference(
+        "What could LLM do for human?",
+        tokenize_config={"padding": True},
+        delete_heading_new_lines=True,
+    )
+    # Hold on to the invoker to avoid loading the model every time: it already loaded the model during the first call.
+    answer2 = invoker.local_inference(
+        "What could LLM do for a programmer",
+        tokenize_config={"padding": True},
+        delete_heading_new_lines=True,
+    )


 if __name__ == "__main__":
@@ -53,7 +53,7 @@ class HFEmbeddingExecutor(LLMExecutor):
         )
         self._model = model

-    def inference(self, data, args=None, **kwargs):
+    def inference(self, inputs, **kwargs):
         model = self.model
-        embeddings = model.encode(data)
+        embeddings = model.encode(inputs)
         return embeddings
@@ -14,7 +14,10 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from typing import Union

 from nn4k.executor import NNExecutor
 from nn4k.utils.class_importing import dynamic_import_class
+
+from nn4k.executor import LLMExecutor
+from nn4k.utils.args_utils import ArgsUtils


 class SubmitMode(Enum):
@@ -36,6 +39,7 @@ class NNInvoker(ABC):
     def __init__(self, init_args: dict, **kwargs):
         self._init_args = init_args
         self._kwargs = kwargs
+        self.inference_warmed_up = False

     @property
     def init_args(self):
@@ -71,7 +75,6 @@ class NNInvoker(ABC):
         from nn4k.consts import NN_INVOKER_KEY, NN_INVOKER_TEXT
         from nn4k.utils.config_parsing import preprocess_config
         from nn4k.utils.config_parsing import get_string_field
-        from nn4k.utils.class_importing import dynamic_import_class

         nn_config = preprocess_config(nn_config)
         nn_invoker = nn_config.get(NN_INVOKER_KEY)
@@ -141,9 +144,7 @@ class LLMInvoker(NNInvoker):
         raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")

     def local_sft(self, args: dict = None):
-        sft_args = copy.deepcopy(self.init_args)
-        args = args or {}
-        sft_args.update(args)
+        sft_args = ArgsUtils.update_args(self.init_args, args)

         from nn4k.executor import LLMExecutor

@@ -161,29 +162,46 @@ class LLMInvoker(NNInvoker):
         """
         Implement local inference for local invoker.
         """
-        return self._nn_executor.inference(data, **kwargs)
+        args = ArgsUtils.handle_dict_config(kwargs)
+
+        if not self.inference_warmed_up:
+            print(
+                "warming up the model for inference; this only happens the first time..."
+            )
+            self.warmup_local_model()
+            self.inference_warmed_up = True
+            print("inference model is warmed up")
+
+        return self._nn_executor.inference(inputs=data, **args)

     def warmup_local_model(self):
         """
         Implement local model warming up logic for local invoker.
         """
+        nn_config = self.init_args
+
         from nn4k.nnhub import NNHub
-        from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
         from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
         from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
         from nn4k.utils.config_parsing import get_string_field
         from nn4k.utils.class_importing import dynamic_import_class
+        from transformers import HfArgumentParser
+        from nn4k.executor import NNModelArgs

-        nn_executor = self.init_args.get(NN_EXECUTOR_KEY)
+        parser = HfArgumentParser(NNModelArgs)
+        model_args: NNModelArgs
+        model_args, *_ = parser.parse_dict(self.init_args, allow_extra_keys=True)
+
+        from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
+
+        nn_executor = nn_config.get(NN_EXECUTOR_KEY)
         if nn_executor is not None:
             nn_executor = get_string_field(
                 self.init_args, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
             )
             from nn4k.executor import NNExecutor

             executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
             if not issubclass(executor_class, NNExecutor):
                 message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
                 raise RuntimeError(message)
-            executor = executor_class.from_config(self.init_args)
+            executor = executor_class.from_config(nn_config)
         else:
             nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
             nn_version = self.init_args.get(NN_VERSION_KEY)
@@ -193,13 +211,18 @@ class LLMInvoker(NNInvoker):
             )
             hub = NNHub.get_instance()
             executor = hub.get_model_executor(nn_name, nn_version)
-            if executor is None:
-                message = "model %r version %r " % (nn_name, nn_version)
-                message += "is not found in the model hub"
-                raise RuntimeError(message)
-        self._nn_executor: NNExecutor = executor
+
+        if executor is None:
+            message = "model %r version %r " % (
+                model_args.nn_name,
+                model_args.nn_version,
+            )
+            message += "is not found in the model hub, you should provide a valid nn_executor class path"
+            raise RuntimeError(message)
+        self._nn_executor: LLMExecutor = executor
         self._nn_executor.load_model(mode="inference")
         self._nn_executor.warmup_inference()
+        self.inference_warmed_up = True

     @classmethod
     def from_config(cls, nn_config: dict) -> "LLMInvoker":
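Taken together, the invoker changes above make warm-up lazy: the first local_inference call triggers warmup_local_model, and later calls reuse the loaded executor. A minimal sketch of the resulting call pattern, assuming a valid config file; the file name infer_args.json5 in the second call is hypothetical:

from nn4k.invoker.base import NNInvoker

invoker = NNInvoker.from_config("local_infer.json5")   # nothing loaded yet

# The first call warms the model up (load_model + warmup_inference); later calls reuse it.
answer = invoker.local_inference("Hello", max_output_length=32)

# Inference kwargs can also be read from a file via ArgsUtils.handle_dict_config.
answer2 = invoker.local_inference("Hello again", config_file="infer_args.json5")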
python/nn4k/nn4k/utils/args_utils.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
+
+
+class ArgsUtils:
+    CONFIG_FILE_KEY = "config_file"
+
+    @staticmethod
+    def update_args(base_args: dict, new_args: dict) -> dict:
+        """
+        Update an existing args dict with a new set of args.
+        :param base_args: args to be updated; copied before the update.
+        :param new_args: args to update the base args with.
+        :rtype: dict
+        """
+        import copy
+
+        copy_base_args = copy.deepcopy(base_args)
+        new_args = new_args or {}
+        copy_base_args.update(new_args)
+        return copy_base_args
+
+    @staticmethod
+    def handle_dict_config(kwargs: dict) -> dict:
+        if "config_file" in kwargs:
+            configs = ArgsUtils.load_config_dict_from_file(kwargs.get("config_file"))
+        else:
+            configs = kwargs
+
+        return configs
+
+    @staticmethod
+    def load_config_dict_from_file(file_path: str) -> dict:
+        from pathlib import Path
+
+        if file_path.endswith(".json"):
+            import json
+
+            with open(Path(file_path), "r", encoding="utf-8") as open_json_file:
+                data = json.load(open_json_file)
+            nn_config = data
+            return nn_config
+        if file_path.endswith(".json5"):
+            import json5
+
+            with open(Path(file_path), "r", encoding="utf-8") as open_json5_file:
+                data = json5.load(open_json5_file)
+            nn_config = data
+            return nn_config
+        from nn4k.utils.io.file_utils import FileUtils
+
+        raise ValueError(
+            f"Config file with extension type {FileUtils.get_extension(file_path)} is not supported. "
+            f"Use json or json5 instead."
+        )
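A short sketch of the new helper in isolation (the dict values and file name are illustrative):

from nn4k.utils.args_utils import ArgsUtils

base = {"nn_model_path": "/path/to/model", "temperature": 1.0}
merged = ArgsUtils.update_args(base, {"temperature": 0.2, "do_sample": True})
# base is deep-copied, so the original dict is untouched
assert base["temperature"] == 1.0 and merged["temperature"] == 0.2

# handle_dict_config prefers an explicit config file over inline kwargs
configs = ArgsUtils.handle_dict_config({"config_file": "local_infer.json5"})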
@@ -32,26 +32,9 @@ def preprocess_config(nn_config: Union[str, dict]) -> dict:
     if isinstance(nn_config, dict):
         return nn_config
     elif isinstance(nn_config, str):
-        if nn_config.endswith(".json"):
-            import json
-
-            with open(Path(nn_config), "r", encoding="utf-8") as open_json_file:
-                data = json.load(open_json_file)
-            nn_config = data
-            return nn_config
-        if nn_config.endswith(".json5"):
-            import json5
-
-            with open(Path(nn_config), "r", encoding="utf-8") as open_json5_file:
-                data = json5.load(open_json5_file)
-            nn_config = data
-            return nn_config
-        from nn4k.utils.io.file_utils import FileUtils
-
-        raise ValueError(
-            f"Config file with extension type {FileUtils.get_extension(nn_config)} is not supported."
-            f"use json or json5 instead."
-        )
+        from nn4k.utils.args_utils import ArgsUtils
+
+        return ArgsUtils.load_config_dict_from_file(nn_config)
     else:
         raise ValueError(
             f"nn_config could be dict or str, {type(nn_config)} is not yet supported."
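With this change, both config forms funnel through ArgsUtils.load_config_dict_from_file. A tiny hedged sketch (the dict and file name are illustrative):

from nn4k.utils.config_parsing import preprocess_config

cfg_from_dict = preprocess_config({"nn_model_path": "/path/to/model"})  # returned as-is
cfg_from_file = preprocess_config("local_infer.json5")                  # loaded via ArgsUtils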
@@ -22,7 +22,7 @@ class StubExecutor(LLMExecutor):
     def warmup_inference(self, args=None, **kwargs):
         pass

-    def inference(self, data, args=None, **kwargs):
+    def inference(self, inputs, args=None, **kwargs):
         pass

     @classmethod

@@ -37,7 +37,7 @@ class StubExecutor(NNExecutor):
     def warmup_inference(self, args=None, **kwargs):
         self.warmup_inference_called = True

-    def inference(self, data, args=None, **kwargs):
+    def inference(self, inputs, args=None, **kwargs):
         return self.inference_result

     @classmethod
@@ -89,7 +89,10 @@ class TestBaseInvoker(unittest.TestCase):
     def testLocalLLMInvokerWithCustomExecutor(self):
         from nn4k.invoker import LLMInvoker

-        nn_config = {"nn_executor": "invoker_test_stub.StubExecutor"}
+        nn_config = {
+            "nn_model_path": "/path/to/model",
+            "nn_executor": "invoker_test_stub.StubExecutor",
+        }
         invoker = LLMInvoker.from_config(nn_config)
         self.assertTrue(isinstance(invoker, LLMInvoker))
         self.assertEqual(invoker.init_args, nn_config)