feat(nn4k): support huggingface decode only model local inference (#128)

Co-authored-by: xionghuaidong <huaidong.xhd@antgroup.com>
This commit is contained in:
chenbin11200 2024-03-08 13:54:15 +08:00 committed by FishJoy
parent ee5089ef54
commit 210cf7be9a
12 changed files with 351 additions and 73 deletions

View File

@ -71,7 +71,7 @@ class NNExecutor(ABC):
f"{self.__class__.__name__} does not support batch inference."
)
def inference(self, data, args=None, **kwargs):
def inference(self, inputs, **kwargs):
"""
The entry point of inference. Usually for local invokers or model services.
"""
@ -248,3 +248,103 @@ class NNAdapterModelArgs(NNModelArgs):
def __post_init__(self):
super().__post_init__()
@dataclass
class NNInferenceArgs:
max_input_length: Optional[int] = field(
default=None,
metadata={
"help": "Controls the maximum length to use by one of the truncation/padding parameters. "
"In HuggingFace executors, known as max_length in tokenize callable function config."
},
)
max_output_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum numbers of tokens to generate. In HuggingFace executors, this arg will be tread as "
"max_new_tokens."
},
)
return_input_text: Optional[bool] = field(
default=False,
metadata={"help": "Whether return input texts together with output texts."},
)
stop_sequence: Optional[str] = field(
default=None,
metadata={
"help": "Generation will stop when stop sequence encountered in the output."
},
)
do_sample: bool = field(
default=False,
metadata={
"help": "If false, generation will be in greedy search mode, otherwise will sampling the probable tokens."
},
)
temperature: float = field(
default=1.0,
metadata={"help": "The creativity and diversity of the text generated."},
)
top_k: Optional[int] = field(
default=50,
metadata={
"help": "In nucleus sampling, model will only sampling the tokens with the highest top_p(percentage) "
"probability"
},
)
top_p: Optional[float] = field(
default=1.0,
metadata={
"help": "In nucleus sampling, model will only sampling the tokens with the highest top_k(count) probability"
},
)
repetition_penalty: Optional[float] = field(
default=1.0,
metadata={"help": "By default 1.0 means no penalty."},
)
generate_config: dict = field(
default_factory=lambda: {},
metadata={"help": "Config dict that will be use in model generation"},
)
tokenize_return_tensors: str = field(
default="pt",
metadata={
"help": "Tokenizer return type, will be merged into tokenize_config and pass into tokenize function"
},
)
tokenize_config: dict = field(
default_factory=lambda: {},
metadata={
"help": "Tokenize function config, will be pass into tokenize function"
},
)
decode_config: dict = field(
default_factory=lambda: {},
metadata={"help": "Configs to be pass into tokenizer.decode fucntion"},
)
def update_if_not_none(self, from_key, to_dict, to_key=None):
to_key = to_key or from_key
from_value = self.__getattribute__(from_key)
value_in_to_dict = self.__getattribute__(to_dict).get(to_key, None)
if value_in_to_dict is None and from_value is not None:
self.__getattribute__(to_dict)[to_key] = from_value
def __post_init__(self):
# merging generation args
self.update_if_not_none("max_output_length", "generate_config")
self.update_if_not_none("do_sample", "generate_config")
self.update_if_not_none("temperature", "generate_config")
self.update_if_not_none("top_k", "generate_config")
self.update_if_not_none("top_p", "generate_config")
self.update_if_not_none("repetition_penalty", "generate_config")
# merging tokenize args
self.update_if_not_none("max_input_length", "tokenize_config", "max_length")
self.update_if_not_none(
"tokenize_return_tensors", "tokenize_config", "return_tensors"
)
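
For orientation, here is a minimal usage sketch of how these fields are intended to merge into the config dicts, based on the dataclass shown above; the concrete values are placeholders and behavior may differ in other versions.

```python
# Minimal sketch (placeholder values): __post_init__ copies top-level fields into the
# config dicts only when the dicts do not already set the corresponding key.
from nn4k.executor.base import NNInferenceArgs

args = NNInferenceArgs(
    max_input_length=512,                   # copied into tokenize_config as "max_length"
    max_output_length=256,                  # copied into generate_config (the HF subclass remaps it)
    do_sample=True,
    generate_config={"temperature": 0.2},   # explicit dict values take precedence
)

assert args.generate_config["temperature"] == 0.2   # kept from the explicit dict
assert args.generate_config["do_sample"] is True    # copied from the top-level field
assert args.tokenize_config["max_length"] == 512
assert args.tokenize_config["return_tensors"] == "pt"
```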

View File

@ -15,6 +15,7 @@ from typing import Optional
from transformers import TrainingArguments
from nn4k.executor import NNAdapterModelArgs
from nn4k.executor.base import NNInferenceArgs
@dataclass
@ -52,6 +53,13 @@ class HFModelArgs(NNAdapterModelArgs):
"help": " Load the model weights from a TensorFlow checkpoint save file, default to False"
},
)
padding_side: Optional[str] = field(
default=None,
metadata={
"help": "Padding side of the tokenizer when padding batch inputs",
"choices": [None, "left", "right"],
},
)
def __post_init__(self):
super().__post_init__()
@ -105,3 +113,45 @@ class HFSftArgs(HFModelArgs, TrainingArguments):
print(
f"a eval_dataset_path is set but do_eval flag is not set, automatically set do_eval to True"
)
@dataclass
class HFInferArgs(NNInferenceArgs):
delete_heading_new_lines: bool = field(
default=False,
metadata={
"help": "an additional question mark or new line marks sometimes occurs at the beginning of output."
"Try to get rid of these marks by setting this parameter to True. Different model may have different "
"behavior, please check the result carefully."
},
)
tokenize_config: dict = field(
default_factory=lambda: {
"add_special_tokens": False,
"padding": False,
"truncation": False,
},
metadata={
"help": "padding: https://huggingface.co/docs/transformers/pad_truncation#padding-and-truncation"
},
)
decode_config: dict = field(
default_factory=lambda: {
"skip_special_tokens": True,
"clean_up_tokenization_spaces": True,
},
metadata={
"help": "check https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__"
},
)
def __post_init__(self):
super().__post_init__()
# HF specific map
self.update_if_not_none(
"max_output_length", "generate_config", "max_new_tokens"
)
self.update_if_not_none("max_input_length", "tokenize_config", "max_length")

View File

@ -18,8 +18,9 @@ from torch.utils.data import Dataset
from transformers import AutoConfig, AutoTokenizer, Trainer
from nn4k.executor import LLMExecutor
from .hf_args import HFSftArgs, HFModelArgs
from .hf_args import HFInferArgs, HFSftArgs, HFModelArgs
from nn4k.executor.huggingface.nn_hf_trainer import NNHFTrainer
from nn4k.utils.args_utils import ArgsUtils
class HFLLMExecutor(LLMExecutor):
@ -188,57 +189,85 @@ class HFLLMExecutor(LLMExecutor):
if self.model_mode == mode and self._model is not None:
return
args = args or self._init_args
from transformers import HfArgumentParser
from nn4k.executor.huggingface import HFModelArgs
parser = HfArgumentParser(HFModelArgs)
hf_model_args: HFModelArgs
hf_model_args, *_ = parser.parse_dict(args, allow_extra_keys=True)
self.model_mode = mode
self._tokenizer = self._hf_tokenizer_loader(hf_model_args)
self._model = self._hf_model_loader(
hf_model_args, mode, hf_model_args.nn_device
args=hf_model_args, mode=mode, device=hf_model_args.nn_device
)
if self.tokenizer.eos_token_id is None:
self.tokenizer.eos_token_id = self.model.config.eos_token_id
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
if hf_model_args.padding_side is not None:
self.tokenizer.padding_side = hf_model_args.padding_side
def inference(self, inputs, **kwargs):
infer_args = ArgsUtils.update_args(self.init_args, kwargs)
from transformers import HfArgumentParser
parser = HfArgumentParser(HFInferArgs)
hf_infer_args: HFInferArgs
hf_infer_args, *_ = parser.parse_dict(infer_args, allow_extra_keys=True)
def inference(
self,
data,
max_input_length: int = 1024,
max_output_length: int = 1024,
do_sample: bool = False,
**kwargs,
):
model = self.model
tokenizer = self.tokenizer
input_ids = tokenizer(
data,
padding=True,
return_token_type_ids=False,
return_tensors="pt",
truncation=True,
max_length=max_input_length,
inputs,
**hf_infer_args.tokenize_config,
).to(model.device)
if hf_infer_args.stop_sequence is not None:
stop_sequence = hf_infer_args.stop_sequence
stop_sequence_ids = self.tokenizer.encode(
stop_sequence, add_special_tokens=False
)
if len(stop_sequence_ids) > 1:
print( # TODO: use logger instead
"Warning: Stopping on a multiple token sequence is not yet supported on transformers. "
"The first token of the stop sequence will be used as the stop sequence string in the interim."
)
hf_infer_args.generate_config["eos_token_id"] = stop_sequence_ids[0]
output_ids = model.generate(
**input_ids,
max_new_tokens=max_output_length,
do_sample=do_sample,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
**kwargs,
**hf_infer_args.generate_config,
)
outputs = [
tokenizer.decode(
output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
output_texts = []
for idx, output_id in enumerate(output_ids):
if not hf_infer_args.return_input_text:
output_id = output_id[len(input_ids["input_ids"][idx]) :]
output_text = self.tokenizer.decode(
output_id, **hf_infer_args.decode_config
)
for idx, output_id in enumerate(output_ids)
]
return outputs
if (
not hf_infer_args.return_input_text
and hf_infer_args.delete_heading_new_lines
):
import re
match = re.search("(\\n)+", output_text)
if match is not None:
start_index = match.end()
if start_index < len(output_text) - 1:
output_text = output_text[start_index:]
output_texts.append(output_text)
return output_texts
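
To make the new entry point concrete, here is a hypothetical direct call against a decode-only executor; from_config and load_model are assumed to behave as the invoker below uses them, and the model path and prompt are placeholders.

```python
# Hypothetical usage sketch; the model path is a placeholder and the executor class
# path is taken from the sample config below.
from nn4k.executor.huggingface.hf_decode_only_executor import HFDecodeOnlyExecutor

executor = HFDecodeOnlyExecutor.from_config({"nn_model_path": "/path/to/model"})
executor.load_model(mode="inference")

outputs = executor.inference(
    "Explain what a decode-only LLM is in one sentence.",
    max_output_length=128,
    do_sample=True,
    top_p=0.9,
    stop_sequence="\n\n",  # multi-token stop sequences fall back to their first token, per the warning above
)
print(outputs[0])
```

Keyword arguments are merged with the executor's init args and parsed into HFInferArgs, so anything accepted by the dataclass above can be overridden per call.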
@abstractmethod
def _hf_model_loader(

View File

@ -0,0 +1,14 @@
{
// -- base model info
"nn_model_path": "/Path/to/model_dir", // local model path
"nn_invoker": "nn4k.invoker.base.LLMInvoker", // invoker to use
"nn_executor": "nn4k.executor.huggingface.hf_decode_only_executor.HFDecodeOnlyExecutor", // executor to use
// the following are optional
"adapter_name": "adapter_name", // adapter_name must be given to enable adapter; with adapter_path along has no effect!
"adapter_path": "/path/to/adapter",
"generate_config":{
"temperature": 0.2,
"do_sample": true
}
}

View File

@ -13,9 +13,22 @@ from nn4k.invoker.base import NNInvoker
def main():
NNInvoker.from_config("local_sft.json5").local_sft()
# Inference example, not implemented yet.
# NNInvoker.from_config("inferece_args.json").local_inference("你是谁")
# example for local sft
# NNInvoker.from_config("local_sft.json5").local_sft()
# example for local inference
invoker = NNInvoker.from_config("local_infer.json5")
answer = invoker.local_inference(
"What could LLM do for human?",
tokenize_config={"padding": True},
delete_heading_new_lines=True,
)
# Reuse the invoker to avoid loading the model every time; the invoker keeps the model loaded after the first call.
answer2 = invoker.local_inference(
"What could LLM do for a programmer",
tokenize_config={"padding": True},
delete_heading_new_lines=True,
)
if __name__ == "__main__":

View File

@ -53,7 +53,7 @@ class HFEmbeddingExecutor(LLMExecutor):
)
self._model = model
def inference(self, data, args=None, **kwargs):
def inference(self, inputs, **kwargs):
model = self.model
embeddings = model.encode(data)
embeddings = model.encode(inputs)
return embeddings

View File

@ -14,7 +14,10 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Union
from nn4k.executor import NNExecutor
from nn4k.utils.class_importing import dynamic_import_class
from nn4k.executor import LLMExecutor
from nn4k.utils.args_utils import ArgsUtils
class SubmitMode(Enum):
@ -36,6 +39,7 @@ class NNInvoker(ABC):
def __init__(self, init_args: dict, **kwargs):
self._init_args = init_args
self._kwargs = kwargs
self.inference_warmed_up = False
@property
def init_args(self):
@ -71,7 +75,6 @@ class NNInvoker(ABC):
from nn4k.consts import NN_INVOKER_KEY, NN_INVOKER_TEXT
from nn4k.utils.config_parsing import preprocess_config
from nn4k.utils.config_parsing import get_string_field
from nn4k.utils.class_importing import dynamic_import_class
nn_config = preprocess_config(nn_config)
nn_invoker = nn_config.get(NN_INVOKER_KEY)
@ -141,9 +144,7 @@ class LLMInvoker(NNInvoker):
raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")
def local_sft(self, args: dict = None):
sft_args = copy.deepcopy(self.init_args)
args = args or {}
sft_args.update(args)
sft_args = ArgsUtils.update_args(self.init_args, args)
from nn4k.executor import LLMExecutor
@ -161,29 +162,46 @@ class LLMInvoker(NNInvoker):
"""
Implement local inference for local invoker.
"""
return self._nn_executor.inference(data, **kwargs)
args = ArgsUtils.handle_dict_config(kwargs)
if not self.inference_warmed_up:
print(
"warming up the model for inference, only happen for the first time..."
)
self.warmup_local_model()
self.inference_warmed_up = True
print("inference model is warmed up")
return self._nn_executor.inference(inputs=data, **args)
def warmup_local_model(self):
"""
Implement local model warming up logic for local invoker.
"""
nn_config = self.init_args
from nn4k.nnhub import NNHub
from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.utils.config_parsing import get_string_field
from nn4k.utils.class_importing import dynamic_import_class
from transformers import HfArgumentParser
from nn4k.executor import NNModelArgs
nn_executor = self.init_args.get(NN_EXECUTOR_KEY)
parser = HfArgumentParser(NNModelArgs)
model_args: NNModelArgs
model_args, *_ = parser.parse_dict(self.init_args, allow_extra_keys=True)
from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
nn_executor = nn_config.get(NN_EXECUTOR_KEY)
if nn_executor is not None:
nn_executor = get_string_field(
self.init_args, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
)
from nn4k.executor import NNExecutor
executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
if not issubclass(executor_class, NNExecutor):
message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
raise RuntimeError(message)
executor = executor_class.from_config(self.init_args)
executor = executor_class.from_config(nn_config)
else:
nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = self.init_args.get(NN_VERSION_KEY)
@ -193,13 +211,18 @@ class LLMInvoker(NNInvoker):
)
hub = NNHub.get_instance()
executor = hub.get_model_executor(nn_name, nn_version)
if executor is None:
message = "model %r version %r " % (nn_name, nn_version)
message += "is not found in the model hub"
raise RuntimeError(message)
self._nn_executor: NNExecutor = executor
if executor is None:
message = "model %r version %r " % (
model_args.nn_name,
model_args.nn_version,
)
message += "is not found in the model hub, you should provide a valid nn_executor class path"
raise RuntimeError(message)
self._nn_executor: LLMExecutor = executor
self._nn_executor.load_model(mode="inference")
self._nn_executor.warmup_inference()
self.inference_warmed_up = True
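
To illustrate the executor-resolution path above, here is a hedged sketch of warming up an invoker from an explicit nn_executor class path; all paths are placeholders, and in practice local_inference triggers this warm-up automatically on the first call.

```python
# Hypothetical config; nn_model_path is a placeholder and the executor class path
# matches the sample json5 config earlier in this change.
from nn4k.invoker import LLMInvoker

invoker = LLMInvoker.from_config({
    "nn_model_path": "/path/to/model",
    "nn_executor": "nn4k.executor.huggingface.hf_decode_only_executor.HFDecodeOnlyExecutor",
})
invoker.warmup_local_model()  # dynamically imports the executor, then calls load_model(mode="inference")
answer = invoker.local_inference("Explain what a decode-only LLM is in one sentence.")
```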
@classmethod
def from_config(cls, nn_config: dict) -> "LLMInvoker":

View File

@ -0,0 +1,63 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
class ArgsUtils:
CONFIG_FILE_KEY = "config_file"
@staticmethod
def update_args(base_args: dict, new_args: dict) -> dict:
"""
Update existing args with a new set of args.
:param base_args: args to be updated; copied before being updated.
:param new_args: args used to update the base args.
:rtype: dict
"""
import copy
copy_base_args = copy.deepcopy(base_args)
new_args = new_args or {}
copy_base_args.update(new_args)
return copy_base_args
@staticmethod
def handle_dict_config(kwargs: dict) -> dict:
if "config_file" in kwargs:
configs = ArgsUtils.load_config_dict_from_file(kwargs.get("config_file"))
else:
configs = kwargs
return configs
@staticmethod
def load_config_dict_from_file(file_path: str) -> dict:
from pathlib import Path
if file_path.endswith(".json"):
import json
with open(Path(file_path), "r", encoding="utf-8") as open_json_file:
data = json.load(open_json_file)
nn_config = data
return nn_config
if file_path.endswith(".json5"):
import json5
with open(Path(file_path), "r", encoding="utf-8") as open_json5_file:
data = json5.load(open_json5_file)
nn_config = data
return nn_config
from nn4k.utils.io.file_utils import FileUtils
raise ValueError(
f"Config file with extension type {FileUtils.get_extension(file_path)} is not supported."
f"use json or json5 instead."
)
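
For reference, a small sketch of the intended ArgsUtils call patterns; the file path is a placeholder and .json5 support assumes the json5 package is installed.

```python
from nn4k.utils.args_utils import ArgsUtils

# update_args returns a merged copy and leaves the base dict untouched.
base = {"nn_model_path": "/path/to/model", "do_sample": False}
merged = ArgsUtils.update_args(base, {"do_sample": True, "top_p": 0.9})
assert merged["do_sample"] is True and base["do_sample"] is False

# handle_dict_config passes plain kwargs through unchanged...
assert ArgsUtils.handle_dict_config({"do_sample": True}) == {"do_sample": True}

# ...and redirects to json/json5 loading when a "config_file" key is present.
# configs = ArgsUtils.handle_dict_config({"config_file": "/path/to/local_infer.json5"})
```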

View File

@ -32,26 +32,9 @@ def preprocess_config(nn_config: Union[str, dict]) -> dict:
if isinstance(nn_config, dict):
return nn_config
elif isinstance(nn_config, str):
if nn_config.endswith(".json"):
import json
from nn4k.utils.args_utils import ArgsUtils
with open(Path(nn_config), "r", encoding="utf-8") as open_json_file:
data = json.load(open_json_file)
nn_config = data
return nn_config
if nn_config.endswith(".json5"):
import json5
with open(Path(nn_config), "r", encoding="utf-8") as open_json5_file:
data = json5.load(open_json5_file)
nn_config = data
return nn_config
from nn4k.utils.io.file_utils import FileUtils
raise ValueError(
f"Config file with extension type {FileUtils.get_extension(nn_config)} is not supported."
f"use json or json5 instead."
)
return ArgsUtils.load_config_dict_from_file(nn_config)
else:
raise ValueError(
f"nn_config could be dict or str, {type(nn_config)} is not yet supported."

View File

@ -22,7 +22,7 @@ class StubExecutor(LLMExecutor):
def warmup_inference(self, args=None, **kwargs):
pass
def inference(self, data, args=None, **kwargs):
def inference(self, inputs, args=None, **kwargs):
pass
@classmethod

View File

@ -37,7 +37,7 @@ class StubExecutor(NNExecutor):
def warmup_inference(self, args=None, **kwargs):
self.warmup_inference_called = True
def inference(self, data, args=None, **kwargs):
def inference(self, inputs, args=None, **kwargs):
return self.inference_result
@classmethod

View File

@ -89,7 +89,10 @@ class TestBaseInvoker(unittest.TestCase):
def testLocalLLMInvokerWithCustomExecutor(self):
from nn4k.invoker import LLMInvoker
nn_config = {"nn_executor": "invoker_test_stub.StubExecutor"}
nn_config = {
"nn_model_path": "/path/to/model",
"nn_executor": "invoker_test_stub.StubExecutor",
}
invoker = LLMInvoker.from_config(nn_config)
self.assertTrue(isinstance(invoker, LLMInvoker))
self.assertEqual(invoker.init_args, nn_config)