feat(nn4k): support huggingface decode only model local inference (#128)
Co-authored-by: xionghuaidong <huaidong.xhd@antgroup.com>
parent ee5089ef54
commit 210cf7be9a
@ -71,7 +71,7 @@ class NNExecutor(ABC):
                f"{self.__class__.__name__} does not support batch inference."
            )

-    def inference(self, data, args=None, **kwargs):
+    def inference(self, inputs, **kwargs):
        """
        The entry point of inference. Usually for local invokers or model services.
        """
@ -248,3 +248,103 @@ class NNAdapterModelArgs(NNModelArgs):
    def __post_init__(self):
        super().__post_init__()

+
+@dataclass
+class NNInferenceArgs:
+    max_input_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Controls the maximum length to use by one of the truncation/padding parameters. "
+            "In HuggingFace executors, known as max_length in the tokenize callable config."
+        },
+    )
+    max_output_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The maximum number of tokens to generate. In HuggingFace executors, this arg will be treated as "
+            "max_new_tokens."
+        },
+    )
+    return_input_text: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to return input texts together with output texts."},
+    )
+    stop_sequence: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Generation will stop when the stop sequence is encountered in the output."
+        },
+    )
+    do_sample: bool = field(
+        default=False,
+        metadata={
+            "help": "If False, generation runs in greedy search mode; otherwise it samples from the probable tokens."
+        },
+    )
+    temperature: float = field(
+        default=1.0,
+        metadata={"help": "Controls the creativity and diversity of the generated text."},
+    )
+    top_k: Optional[int] = field(
+        default=50,
+        metadata={
+            "help": "The model will only sample from the top_k (count) tokens with the highest "
+            "probability."
+        },
+    )
+    top_p: Optional[float] = field(
+        default=1.0,
+        metadata={
+            "help": "In nucleus sampling, the model will only sample from the smallest set of tokens whose cumulative probability reaches top_p (percentage)."
+        },
+    )
+    repetition_penalty: Optional[float] = field(
+        default=1.0,
+        metadata={"help": "By default 1.0 means no penalty."},
+    )
+
+    generate_config: dict = field(
+        default_factory=lambda: {},
+        metadata={"help": "Config dict that will be used in model generation."},
+    )
+
+    tokenize_return_tensors: str = field(
+        default="pt",
+        metadata={
+            "help": "Tokenizer return type; will be merged into tokenize_config and passed into the tokenize function."
+        },
+    )
+    tokenize_config: dict = field(
+        default_factory=lambda: {},
+        metadata={
+            "help": "Tokenize function config; will be passed into the tokenize function."
+        },
+    )
+
+    decode_config: dict = field(
+        default_factory=lambda: {},
+        metadata={"help": "Configs to be passed into the tokenizer.decode function."},
+    )
+
+    def update_if_not_none(self, from_key, to_dict, to_key=None):
+        to_key = to_key or from_key
+        from_value = self.__getattribute__(from_key)
+        value_in_to_dict = self.__getattribute__(to_dict).get(to_key, None)
+        if value_in_to_dict is None and from_value is not None:
+            self.__getattribute__(to_dict)[to_key] = from_value
+
+    def __post_init__(self):
+        # merging generation args
+        self.update_if_not_none("max_output_length", "generate_config")
+        self.update_if_not_none("do_sample", "generate_config")
+        self.update_if_not_none("temperature", "generate_config")
+        self.update_if_not_none("top_k", "generate_config")
+        self.update_if_not_none("top_p", "generate_config")
+        self.update_if_not_none("repetition_penalty", "generate_config")
+
+        # merging tokenize args
+        self.update_if_not_none("max_input_length", "tokenize_config", "max_length")
+        self.update_if_not_none(
+            "tokenize_return_tensors", "tokenize_config", "return_tensors"
+        )
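For illustration, a minimal sketch of how NNInferenceArgs.__post_init__ above folds the scalar fields into generate_config and tokenize_config; the values below are placeholders, not part of this commit:

from nn4k.executor.base import NNInferenceArgs

# update_if_not_none only fills keys that are not already present in the target dict.
args = NNInferenceArgs(
    max_output_length=64,
    temperature=0.7,
    max_input_length=512,
    generate_config={"top_k": 10},  # an explicit dict value wins over the scalar default of 50
)
print(args.generate_config)  # contains top_k=10, max_output_length=64, temperature=0.7, do_sample=False, ...
print(args.tokenize_config)  # contains max_length=512 and return_tensors="pt"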
@ -15,6 +15,7 @@ from typing import Optional
 from transformers import TrainingArguments

 from nn4k.executor import NNAdapterModelArgs
+from nn4k.executor.base import NNInferenceArgs


 @dataclass
@ -52,6 +53,13 @@ class HFModelArgs(NNAdapterModelArgs):
            "help": " Load the model weights from a TensorFlow checkpoint save file, default to False"
        },
    )
+    padding_side: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Padding side of the tokenizer when padding batch inputs",
+            "choices": [None, "left", "right"],
+        },
+    )

    def __post_init__(self):
        super().__post_init__()
@ -105,3 +113,45 @@ class HFSftArgs(HFModelArgs, TrainingArguments):
            print(
                f"a eval_dataset_path is set but do_eval flag is not set, automatically set do_eval to True"
            )
+
+
+@dataclass
+class HFInferArgs(NNInferenceArgs):
+    delete_heading_new_lines: bool = field(
+        default=False,
+        metadata={
+            "help": "An additional question mark or newline mark sometimes occurs at the beginning of the output. "
+            "Try to get rid of these marks by setting this parameter to True. Different models may have different "
+            "behavior, so please check the result carefully."
+        },
+    )
+
+    tokenize_config: dict = field(
+        default_factory=lambda: {
+            "add_special_tokens": False,
+            "padding": False,
+            "truncation": False,
+        },
+        metadata={
+            "help": "padding: https://huggingface.co/docs/transformers/pad_truncation#padding-and-truncation"
+        },
+    )
+
+    decode_config: dict = field(
+        default_factory=lambda: {
+            "skip_special_tokens": True,
+            "clean_up_tokenization_spaces": True,
+        },
+        metadata={
+            "help": "check https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__"
+        },
+    )
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        # HF specific map
+        self.update_if_not_none(
+            "max_output_length", "generate_config", "max_new_tokens"
+        )
+        self.update_if_not_none("max_input_length", "tokenize_config", "max_length")
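As a rough usage sketch, HFInferArgs can be parsed the same way the executor does it, with the HF-specific key mapping applied in __post_init__; the module path below is assumed from the imports elsewhere in this commit:

from transformers import HfArgumentParser
from nn4k.executor.huggingface.hf_args import HFInferArgs  # path assumed from this diff

hf_infer_args, *_ = HfArgumentParser(HFInferArgs).parse_dict(
    {"max_output_length": 128, "max_input_length": 512},
    allow_extra_keys=True,
)
print("max_new_tokens" in hf_infer_args.generate_config)  # True, mapped from max_output_length
print(hf_infer_args.tokenize_config.get("max_length"))    # 512, mapped from max_input_length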
@ -18,8 +18,9 @@ from torch.utils.data import Dataset
 from transformers import AutoConfig, AutoTokenizer, Trainer

 from nn4k.executor import LLMExecutor
-from .hf_args import HFSftArgs, HFModelArgs
+from .hf_args import HFInferArgs, HFSftArgs, HFModelArgs
 from nn4k.executor.huggingface.nn_hf_trainer import NNHFTrainer
+from nn4k.utils.args_utils import ArgsUtils


 class HFLLMExecutor(LLMExecutor):
@ -188,57 +189,85 @@ class HFLLMExecutor(LLMExecutor):
        if self.model_mode == mode and self._model is not None:
            return

+        args = args or self._init_args
+
        from transformers import HfArgumentParser
        from nn4k.executor.huggingface import HFModelArgs

        parser = HfArgumentParser(HFModelArgs)
+        hf_model_args: HFModelArgs
        hf_model_args, *_ = parser.parse_dict(args, allow_extra_keys=True)

        self.model_mode = mode
        self._tokenizer = self._hf_tokenizer_loader(hf_model_args)
        self._model = self._hf_model_loader(
-            hf_model_args, mode, hf_model_args.nn_device
+            args=hf_model_args, mode=mode, device=hf_model_args.nn_device
        )

        if self.tokenizer.eos_token_id is None:
            self.tokenizer.eos_token_id = self.model.config.eos_token_id
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        if hf_model_args.padding_side is not None:
+            self.tokenizer.padding_side = hf_model_args.padding_side
+
+    def inference(self, inputs, **kwargs):
+        infer_args = ArgsUtils.update_args(self.init_args, kwargs)
+
+        from transformers import HfArgumentParser
+
+        parser = HfArgumentParser(HFInferArgs)
+        hf_infer_args: HFInferArgs
+        hf_infer_args, *_ = parser.parse_dict(infer_args, allow_extra_keys=True)
+
-    def inference(
-        self,
-        data,
-        max_input_length: int = 1024,
-        max_output_length: int = 1024,
-        do_sample: bool = False,
-        **kwargs,
-    ):
        model = self.model
        tokenizer = self.tokenizer

        input_ids = tokenizer(
-            data,
+            inputs,
-            padding=True,
+            **hf_infer_args.tokenize_config,
-            return_token_type_ids=False,
-            return_tensors="pt",
-            truncation=True,
-            max_length=max_input_length,
        ).to(model.device)

+        if hf_infer_args.stop_sequence is not None:
+            stop_sequence = hf_infer_args.stop_sequence
+            stop_sequence_ids = self.tokenizer.encode(
+                stop_sequence, add_special_tokens=False
+            )
+            if len(stop_sequence_ids) > 1:
+                print(  # TODO: use logger instead
+                    "Warning: Stopping on a multiple token sequence is not yet supported on transformers. "
+                    "The first token of the stop sequence will be used as the stop sequence string in the interim."
+                )
+            hf_infer_args.generate_config["eos_token_id"] = stop_sequence_ids[0]
+
        output_ids = model.generate(
            **input_ids,
-            max_new_tokens=max_output_length,
+            **hf_infer_args.generate_config,
-            do_sample=do_sample,
-            eos_token_id=tokenizer.eos_token_id,
-            pad_token_id=tokenizer.pad_token_id,
-            **kwargs,
        )

-        outputs = [
+        output_texts = []
-            tokenizer.decode(
+        for idx, output_id in enumerate(output_ids):
-                output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
+            if not hf_infer_args.return_input_text:
+                output_id = output_id[len(input_ids["input_ids"][idx]) :]
+            output_text = self.tokenizer.decode(
+                output_id, **hf_infer_args.decode_config
            )
-            for idx, output_id in enumerate(output_ids)
-        ]
+            if (
-        return outputs
+                not hf_infer_args.return_input_text
+                and hf_infer_args.delete_heading_new_lines
+            ):
+                import re
+
+                match = re.search("(\\n)+", output_text)
+                if match is not None:
+                    start_index = match.end()
+                    if start_index < len(output_text) - 1:
+                        output_text = output_text[start_index:]
+
+            output_texts.append(output_text)
+
+        return output_texts
+
    @abstractmethod
    def _hf_model_loader(
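A hedged usage sketch of the new entry point: per-call kwargs are merged over init_args and parsed into HFInferArgs, so a call could look like the following. Here executor is assumed to be an already loaded HFLLMExecutor subclass instance (for example the decode-only executor), and the prompts and values are placeholders:

texts = executor.inference(
    ["Who are you?", "What is OpenSPG?"],   # a single string or a list of strings
    max_output_length=64,                   # mapped to generate_config["max_new_tokens"]
    do_sample=True,
    temperature=0.7,
    tokenize_config={"padding": True},      # padding is needed when batching several prompts
    delete_heading_new_lines=True,
)
print(texts[0])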
@ -0,0 +1,14 @@
+{
+    // -- base model info
+    "nn_model_path": "/Path/to/model_dir", // local model path
+    "nn_invoker": "nn4k.invoker.base.LLMInvoker", // invoker to use
+    "nn_executor": "nn4k.executor.huggingface.hf_decode_only_executor.HFDecodeOnlyExecutor", // executor to use
+    // the following are optional
+    "adapter_name": "adapter_name", // adapter_name must be given to enable the adapter; adapter_path alone has no effect!
+    "adapter_path": "/path/to/adapter",
+    "generate_config": {
+        "temperature": 0.2,
+        "do_sample": true
+    }
+}
@ -13,9 +13,22 @@ from nn4k.invoker.base import NNInvoker


 def main():
-    NNInvoker.from_config("local_sft.json5").local_sft()
-    # Inference example, not implemented yet.
-    # NNInvoker.from_config("inferece_args.json").local_inference("你是谁")
+    # example for local sft
+    # NNInvoker.from_config("local_sft.json5").local_sft()
+
+    # example for local inference
+    invoker = NNInvoker.from_config("local_infer.json5")
+    answer = invoker.local_inference(
+        "What could LLM do for human?",
+        tokenize_config={"padding": True},
+        delete_heading_new_lines=True,
+    )
+    # Hold on to the invoker to avoid loading the model every time; it already loaded the model on the first call.
+    answer2 = invoker.local_inference(
+        "What could LLM do for a programmer",
+        tokenize_config={"padding": True},
+        delete_heading_new_lines=True,
+    )


 if __name__ == "__main__":
@ -53,7 +53,7 @@ class HFEmbeddingExecutor(LLMExecutor):
        )
        self._model = model

-    def inference(self, data, args=None, **kwargs):
+    def inference(self, inputs, **kwargs):
        model = self.model
-        embeddings = model.encode(data)
+        embeddings = model.encode(inputs)
        return embeddings
@ -14,7 +14,10 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from typing import Union

-from nn4k.executor import NNExecutor
+from nn4k.utils.class_importing import dynamic_import_class

+from nn4k.executor import LLMExecutor
+from nn4k.utils.args_utils import ArgsUtils
+

 class SubmitMode(Enum):
@ -36,6 +39,7 @@ class NNInvoker(ABC):
    def __init__(self, init_args: dict, **kwargs):
        self._init_args = init_args
        self._kwargs = kwargs
+        self.inference_warmed_up = False

    @property
    def init_args(self):
@ -71,7 +75,6 @@ class NNInvoker(ABC):
        from nn4k.consts import NN_INVOKER_KEY, NN_INVOKER_TEXT
        from nn4k.utils.config_parsing import preprocess_config
        from nn4k.utils.config_parsing import get_string_field
-        from nn4k.utils.class_importing import dynamic_import_class

        nn_config = preprocess_config(nn_config)
        nn_invoker = nn_config.get(NN_INVOKER_KEY)
@ -141,9 +144,7 @@ class LLMInvoker(NNInvoker):
        raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")

    def local_sft(self, args: dict = None):
-        sft_args = copy.deepcopy(self.init_args)
-        args = args or {}
-        sft_args.update(args)
+        sft_args = ArgsUtils.update_args(self.init_args, args)

        from nn4k.executor import LLMExecutor

@ -161,29 +162,46 @@ class LLMInvoker(NNInvoker):
        """
        Implement local inference for local invoker.
        """
-        return self._nn_executor.inference(data, **kwargs)
+        args = ArgsUtils.handle_dict_config(kwargs)
+
+        if not self.inference_warmed_up:
+            print(
+                "warming up the model for inference, this only happens the first time..."
+            )
+            self.warmup_local_model()
+            self.inference_warmed_up = True
+            print("inference model is warmed up")
+
+        return self._nn_executor.inference(inputs=data, **args)

    def warmup_local_model(self):
        """
        Implement local model warming up logic for local invoker.
        """
+        nn_config = self.init_args
+
        from nn4k.nnhub import NNHub
-        from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
        from nn4k.utils.config_parsing import get_string_field
-        from nn4k.utils.class_importing import dynamic_import_class
+        from transformers import HfArgumentParser
+        from nn4k.executor import NNModelArgs

-        nn_executor = self.init_args.get(NN_EXECUTOR_KEY)
+        parser = HfArgumentParser(NNModelArgs)
+        model_args: NNModelArgs
+        model_args, *_ = parser.parse_dict(self.init_args, allow_extra_keys=True)
+
+        from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
+
+        nn_executor = nn_config.get(NN_EXECUTOR_KEY)
        if nn_executor is not None:
-            nn_executor = get_string_field(
-                self.init_args, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
-            )
+            from nn4k.executor import NNExecutor
+
            executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
            if not issubclass(executor_class, NNExecutor):
                message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
                raise RuntimeError(message)
-            executor = executor_class.from_config(self.init_args)
+            executor = executor_class.from_config(nn_config)
        else:
            nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
            nn_version = self.init_args.get(NN_VERSION_KEY)
@ -193,13 +211,18 @@ class LLMInvoker(NNInvoker):
            )
            hub = NNHub.get_instance()
            executor = hub.get_model_executor(nn_name, nn_version)
-            if executor is None:
-                message = "model %r version %r " % (nn_name, nn_version)
-                message += "is not found in the model hub"
-                raise RuntimeError(message)
-        self._nn_executor: NNExecutor = executor
+            if executor is None:
+                message = "model %r version %r " % (
+                    model_args.nn_name,
+                    model_args.nn_version,
+                )
+                message += "is not found in the model hub, you should provide a valid nn_executor class path"
+                raise RuntimeError(message)
+        self._nn_executor: LLMExecutor = executor
        self._nn_executor.load_model(mode="inference")
        self._nn_executor.warmup_inference()
+        self.inference_warmed_up = True

    @classmethod
    def from_config(cls, nn_config: dict) -> "LLMInvoker":
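A small sketch of the lazy warm-up added to LLMInvoker.local_inference: the first call loads the executor, later calls reuse it. The config path and prompts are placeholders:

from nn4k.invoker.base import LLMInvoker

invoker = LLMInvoker.from_config("local_infer.json5")  # placeholder config path
first = invoker.local_inference("hello")               # triggers warmup_local_model() once
second = invoker.local_inference("hello again", max_output_length=32)  # reuses the loaded model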
python/nn4k/nn4k/utils/args_utils.py  (new file, 63 lines)
@ -0,0 +1,63 @@
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
+
+
+class ArgsUtils:
+    CONFIG_FILE_KEY = "config_file"
+
+    @staticmethod
+    def update_args(base_args: dict, new_args: dict) -> dict:
+        """
+        Update existing args with a new set of args.
+        :param base_args: args to be updated. Will be copied before being updated.
+        :param new_args: args to update the base args with.
+        :rtype: dict
+        """
+        import copy
+
+        copy_base_args = copy.deepcopy(base_args)
+        new_args = new_args or {}
+        copy_base_args.update(new_args)
+        return copy_base_args
+
+    @staticmethod
+    def handle_dict_config(kwargs: dict) -> dict:
+        if "config_file" in kwargs:
+            configs = ArgsUtils.load_config_dict_from_file(kwargs.get("config_file"))
+        else:
+            configs = kwargs
+
+        return configs
+
+    @staticmethod
+    def load_config_dict_from_file(file_path: str) -> dict:
+        from pathlib import Path
+
+        if file_path.endswith(".json"):
+            import json
+
+            with open(Path(file_path), "r", encoding="utf-8") as open_json_file:
+                data = json.load(open_json_file)
+                nn_config = data
+            return nn_config
+        if file_path.endswith(".json5"):
+            import json5
+
+            with open(Path(file_path), "r", encoding="utf-8") as open_json5_file:
+                data = json5.load(open_json5_file)
+                nn_config = data
+            return nn_config
+        from nn4k.utils.io.file_utils import FileUtils
+
+        raise ValueError(
+            f"Config file with extension type {FileUtils.get_extension(file_path)} is not supported."
+            f" Use json or json5 instead."
+        )
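A minimal sketch of the new helper in use; the dicts and values are placeholders:

from nn4k.utils.args_utils import ArgsUtils

base = {"nn_model_path": "/path/to/model", "temperature": 1.0}
merged = ArgsUtils.update_args(base, {"temperature": 0.2})
print(merged["temperature"])  # 0.2; base is deep-copied and left untouched

# handle_dict_config forwards the kwargs unless a config_file key points at a json/json5 file
cfg = ArgsUtils.handle_dict_config({"do_sample": True})
print(cfg)  # {'do_sample': True}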
@ -32,26 +32,9 @@ def preprocess_config(nn_config: Union[str, dict]) -> dict:
    if isinstance(nn_config, dict):
        return nn_config
    elif isinstance(nn_config, str):
-        if nn_config.endswith(".json"):
-            import json
-
-            with open(Path(nn_config), "r", encoding="utf-8") as open_json_file:
-                data = json.load(open_json_file)
-                nn_config = data
-            return nn_config
-        if nn_config.endswith(".json5"):
-            import json5
-
-            with open(Path(nn_config), "r", encoding="utf-8") as open_json5_file:
-                data = json5.load(open_json5_file)
-                nn_config = data
-            return nn_config
-        from nn4k.utils.io.file_utils import FileUtils
-
-        raise ValueError(
-            f"Config file with extension type {FileUtils.get_extension(nn_config)} is not supported."
-            f"use json or json5 instead."
-        )
+        from nn4k.utils.args_utils import ArgsUtils
+
+        return ArgsUtils.load_config_dict_from_file(nn_config)
    else:
        raise ValueError(
            f"nn_config could be dict or str, {type(nn_config)} is not yet supported."
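After this change preprocess_config simply delegates file loading to ArgsUtils; a quick sketch with placeholder paths:

from nn4k.utils.config_parsing import preprocess_config

cfg_from_dict = preprocess_config({"nn_model_path": "/path/to/model"})  # dicts pass through unchanged
cfg_from_file = preprocess_config("local_infer.json5")  # loaded via ArgsUtils.load_config_dict_from_file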
@ -22,7 +22,7 @@ class StubExecutor(LLMExecutor):
    def warmup_inference(self, args=None, **kwargs):
        pass

-    def inference(self, data, args=None, **kwargs):
+    def inference(self, inputs, args=None, **kwargs):
        pass

    @classmethod
@ -37,7 +37,7 @@ class StubExecutor(NNExecutor):
    def warmup_inference(self, args=None, **kwargs):
        self.warmup_inference_called = True

-    def inference(self, data, args=None, **kwargs):
+    def inference(self, inputs, args=None, **kwargs):
        return self.inference_result

    @classmethod
@ -89,7 +89,10 @@ class TestBaseInvoker(unittest.TestCase):
    def testLocalLLMInvokerWithCustomExecutor(self):
        from nn4k.invoker import LLMInvoker

-        nn_config = {"nn_executor": "invoker_test_stub.StubExecutor"}
+        nn_config = {
+            "nn_model_path": "/path/to/model",
+            "nn_executor": "invoker_test_stub.StubExecutor",
+        }
        invoker = LLMInvoker.from_config(nn_config)
        self.assertTrue(isinstance(invoker, LLMInvoker))
        self.assertEqual(invoker.init_args, nn_config)