feat(nn4k): support huggingface decode only model local inference (#128)

Co-authored-by: xionghuaidong <huaidong.xhd@antgroup.com>
chenbin11200 2024-03-08 13:54:15 +08:00 committed by FishJoy
parent ee5089ef54
commit 210cf7be9a
12 changed files with 351 additions and 73 deletions

View File

@ -71,7 +71,7 @@ class NNExecutor(ABC):
f"{self.__class__.__name__} does not support batch inference."
)
- def inference(self, data, args=None, **kwargs):
def inference(self, inputs, **kwargs):
"""
The entry point of inference. Usually for local invokers or model services.
"""
@ -248,3 +248,103 @@ class NNAdapterModelArgs(NNModelArgs):
def __post_init__(self):
super().__post_init__()
@dataclass
class NNInferenceArgs:
max_input_length: Optional[int] = field(
default=None,
metadata={
"help": "Controls the maximum length to use by one of the truncation/padding parameters. "
"In HuggingFace executors, known as max_length in tokenize callable function config."
},
)
max_output_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum numbers of tokens to generate. In HuggingFace executors, this arg will be tread as "
"max_new_tokens."
},
)
return_input_text: Optional[bool] = field(
default=False,
metadata={"help": "Whether return input texts together with output texts."},
)
stop_sequence: Optional[str] = field(
default=None,
metadata={
"help": "Generation will stop when stop sequence encountered in the output."
},
)
do_sample: bool = field(
default=False,
metadata={
"help": "If false, generation will be in greedy search mode, otherwise will sampling the probable tokens."
},
)
temperature: float = field(
default=1.0,
metadata={"help": "The creativity and diversity of the text generated."},
)
top_k: Optional[int] = field(
default=50,
metadata={
"help": "In top-k sampling, the model only samples from the top_k (count) most probable tokens."
},
)
top_p: Optional[float] = field(
default=1.0,
metadata={
"help": "In nucleus sampling, the model only samples from the most probable tokens whose cumulative "
"probability reaches top_p (percentage)."
},
)
)
repetition_penalty: Optional[float] = field(
default=1.0,
metadata={"help": "By default 1.0 means no penalty."},
)
generate_config: dict = field(
default_factory=lambda: {},
metadata={"help": "Config dict that will be use in model generation"},
)
tokenize_return_tensors: str = field(
default="pt",
metadata={
"help": "Tokenizer return type, will be merged into tokenize_config and pass into tokenize function"
},
)
tokenize_config: dict = field(
default_factory=lambda: {},
metadata={
"help": "Tokenize function config, will be pass into tokenize function"
},
)
decode_config: dict = field(
default_factory=lambda: {},
metadata={"help": "Configs to be pass into tokenizer.decode fucntion"},
)
def update_if_not_none(self, from_key, to_dict, to_key=None):
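"""
Copy the value of field `from_key` into the dict field named `to_dict` under key
`to_key` (defaults to `from_key`), unless that key already holds a value.
"""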
to_key = to_key or from_key
from_value = self.__getattribute__(from_key)
value_in_to_dict = self.__getattribute__(to_dict).get(to_key, None)
if value_in_to_dict is None and from_value is not None:
self.__getattribute__(to_dict)[to_key] = from_value
def __post_init__(self):
# merging generation args
self.update_if_not_none("max_output_length", "generate_config")
self.update_if_not_none("do_sample", "generate_config")
self.update_if_not_none("temperature", "generate_config")
self.update_if_not_none("top_k", "generate_config")
self.update_if_not_none("top_p", "generate_config")
self.update_if_not_none("repetition_penalty", "generate_config")
# merging tokenize args
self.update_if_not_none("max_input_length", "tokenize_config", "max_length")
self.update_if_not_none(
"tokenize_return_tensors", "tokenize_config", "return_tensors"
)
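A minimal usage sketch (not part of this commit, values are illustrative) of how __post_init__ merges the scalar fields into the config dicts:

# Fields are copied into generate_config / tokenize_config unless already present there.
args = NNInferenceArgs(max_output_length=128, do_sample=True, temperature=0.7)
assert args.generate_config["max_output_length"] == 128
assert args.generate_config["temperature"] == 0.7
assert args.tokenize_config["return_tensors"] == "pt"  # from tokenize_return_tensors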

View File

@ -15,6 +15,7 @@ from typing import Optional
from transformers import TrainingArguments
from nn4k.executor import NNAdapterModelArgs
from nn4k.executor.base import NNInferenceArgs

@dataclass
@ -52,6 +53,13 @@ class HFModelArgs(NNAdapterModelArgs):
"help": " Load the model weights from a TensorFlow checkpoint save file, default to False"
},
)
padding_side: Optional[str] = field(
default=None,
metadata={
"help": "Padding side of the tokenizer when padding batch inputs",
"choices": [None, "left", "right"],
},
)
def __post_init__(self):
super().__post_init__()
@ -105,3 +113,45 @@ class HFSftArgs(HFModelArgs, TrainingArguments):
print(
f"a eval_dataset_path is set but do_eval flag is not set, automatically set do_eval to True"
)
@dataclass
class HFInferArgs(NNInferenceArgs):
delete_heading_new_lines: bool = field(
default=False,
metadata={
"help": "an additional question mark or new line marks sometimes occurs at the beginning of output."
"Try to get rid of these marks by setting this parameter to True. Different model may have different "
"behavior, please check the result carefully."
},
)
tokenize_config: dict = field(
default_factory=lambda: {
"add_special_tokens": False,
"padding": False,
"truncation": False,
},
metadata={
"help": "padding: https://huggingface.co/docs/transformers/pad_truncation#padding-and-truncation"
},
)
decode_config: dict = field(
default_factory=lambda: {
"skip_special_tokens": True,
"clean_up_tokenization_spaces": True,
},
metadata={
"help": "check https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__"
},
)
def __post_init__(self):
super().__post_init__()
# HF specific map
self.update_if_not_none(
"max_output_length", "generate_config", "max_new_tokens"
)
self.update_if_not_none("max_input_length", "tokenize_config", "max_length")

View File

@ -18,8 +18,9 @@ from torch.utils.data import Dataset
from transformers import AutoConfig, AutoTokenizer, Trainer

from nn4k.executor import LLMExecutor
-from .hf_args import HFSftArgs, HFModelArgs
from .hf_args import HFInferArgs, HFSftArgs, HFModelArgs
from nn4k.executor.huggingface.nn_hf_trainer import NNHFTrainer
from nn4k.utils.args_utils import ArgsUtils


class HFLLMExecutor(LLMExecutor):
@ -188,57 +189,85 @@ class HFLLMExecutor(LLMExecutor):
        if self.model_mode == mode and self._model is not None:
            return
        args = args or self._init_args

        from transformers import HfArgumentParser
        from nn4k.executor.huggingface import HFModelArgs

        parser = HfArgumentParser(HFModelArgs)
        hf_model_args: HFModelArgs
        hf_model_args, *_ = parser.parse_dict(args, allow_extra_keys=True)

        self.model_mode = mode
        self._tokenizer = self._hf_tokenizer_loader(hf_model_args)
        self._model = self._hf_model_loader(
-           hf_model_args, mode, hf_model_args.nn_device
            args=hf_model_args, mode=mode, device=hf_model_args.nn_device
        )
        if self.tokenizer.eos_token_id is None:
            self.tokenizer.eos_token_id = self.model.config.eos_token_id
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        if hf_model_args.padding_side is not None:
            self.tokenizer.padding_side = hf_model_args.padding_side
    def inference(self, inputs, **kwargs):
        infer_args = ArgsUtils.update_args(self.init_args, kwargs)

        from transformers import HfArgumentParser

        parser = HfArgumentParser(HFInferArgs)
        hf_infer_args: HFInferArgs
        hf_infer_args, *_ = parser.parse_dict(infer_args, allow_extra_keys=True)

-   def inference(
-       self,
-       data,
-       max_input_length: int = 1024,
-       max_output_length: int = 1024,
-       do_sample: bool = False,
-       **kwargs,
-   ):
        model = self.model
        tokenizer = self.tokenizer
        input_ids = tokenizer(
-           data,
            inputs,
-           padding=True,
            **hf_infer_args.tokenize_config,
-           return_token_type_ids=False,
-           return_tensors="pt",
-           truncation=True,
-           max_length=max_input_length,
        ).to(model.device)

        if hf_infer_args.stop_sequence is not None:
            stop_sequence = hf_infer_args.stop_sequence
            stop_sequence_ids = self.tokenizer.encode(
                stop_sequence, add_special_tokens=False
            )
            if len(stop_sequence_ids) > 1:
                print(  # TODO: use logger instead
                    "Warning: Stopping on a multiple token sequence is not yet supported on transformers. "
                    "The first token of the stop sequence will be used as the stop sequence string in the interim."
                )
            hf_infer_args.generate_config["eos_token_id"] = stop_sequence_ids[0]

        output_ids = model.generate(
            **input_ids,
-           max_new_tokens=max_output_length,
            **hf_infer_args.generate_config,
-           do_sample=do_sample,
-           eos_token_id=tokenizer.eos_token_id,
-           pad_token_id=tokenizer.pad_token_id,
-           **kwargs,
        )

-       outputs = [
        output_texts = []
-           tokenizer.decode(
        for idx, output_id in enumerate(output_ids):
-               output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
            if not hf_infer_args.return_input_text:
                output_id = output_id[len(input_ids["input_ids"][idx]) :]
            output_text = self.tokenizer.decode(
                output_id, **hf_infer_args.decode_config
            )
-           for idx, output_id in enumerate(output_ids)
-       ]
-       return outputs

            if (
                not hf_infer_args.return_input_text
                and hf_infer_args.delete_heading_new_lines
            ):
                import re

                match = re.search("(\\n)+", output_text)
                if match is not None:
                    start_index = match.end()
                    if start_index < len(output_text) - 1:
                        output_text = output_text[start_index:]
            output_texts.append(output_text)
        return output_texts

    @abstractmethod
    def _hf_model_loader(

View File

@ -0,0 +1,14 @@
{
// -- base model info
"nn_model_path": "/Path/to/model_dir", // local model path
"nn_invoker": "nn4k.invoker.base.LLMInvoker", // invoker to use
"nn_executor": "nn4k.executor.huggingface.hf_decode_only_executor.HFDecodeOnlyExecutor", // executor to use
// the following are optional
"adapter_name": "adapter_name", // adapter_name must be given to enable adapter; with adapter_path along has no effect!
"adapter_path": "/path/to/adapter",
"generate_config":{
"temperature": 0.2,
"do_sample": true
}
}

View File

@ -13,9 +13,22 @@ from nn4k.invoker.base import NNInvoker
def main():
-   NNInvoker.from_config("local_sft.json5").local_sft()
-   # Inference example, not implemented yet.
-   # NNInvoker.from_config("inferece_args.json").local_inference("你是谁")
    # example for local sft
    # NNInvoker.from_config("local_sft.json5").local_sft()

    # example for local inference
    invoker = NNInvoker.from_config("local_infer.json5")
    answer = invoker.local_inference(
        "What could LLM do for human?",
        tokenize_config={"padding": True},
        delete_heading_new_lines=True,
    )
    # Hold on to the invoker to avoid loading the model every time; it has already
    # loaded the model during the first call.
    answer2 = invoker.local_inference(
        "What could LLM do for a programmer",
        tokenize_config={"padding": True},
        delete_heading_new_lines=True,
    )


if __name__ == "__main__":

View File

@ -53,7 +53,7 @@ class HFEmbeddingExecutor(LLMExecutor):
        )
        self._model = model

-   def inference(self, data, args=None, **kwargs):
    def inference(self, inputs, **kwargs):
        model = self.model
-       embeddings = model.encode(data)
        embeddings = model.encode(inputs)
        return embeddings

View File

@ -14,7 +14,10 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Union

-from nn4k.executor import NNExecutor
from nn4k.utils.class_importing import dynamic_import_class
from nn4k.executor import LLMExecutor
from nn4k.utils.args_utils import ArgsUtils


class SubmitMode(Enum):
@ -36,6 +39,7 @@ class NNInvoker(ABC):
    def __init__(self, init_args: dict, **kwargs):
        self._init_args = init_args
        self._kwargs = kwargs
        self.inference_warmed_up = False

    @property
    def init_args(self):
@ -71,7 +75,6 @@ class NNInvoker(ABC):
        from nn4k.consts import NN_INVOKER_KEY, NN_INVOKER_TEXT
        from nn4k.utils.config_parsing import preprocess_config
        from nn4k.utils.config_parsing import get_string_field
-       from nn4k.utils.class_importing import dynamic_import_class

        nn_config = preprocess_config(nn_config)
        nn_invoker = nn_config.get(NN_INVOKER_KEY)
@ -141,9 +144,7 @@ class LLMInvoker(NNInvoker):
raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.") raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")
def local_sft(self, args: dict = None): def local_sft(self, args: dict = None):
sft_args = copy.deepcopy(self.init_args) sft_args = ArgsUtils.update_args(self.init_args, args)
args = args or {}
sft_args.update(args)
from nn4k.executor import LLMExecutor from nn4k.executor import LLMExecutor
@ -161,29 +162,46 @@ class LLMInvoker(NNInvoker):
""" """
Implement local inference for local invoker. Implement local inference for local invoker.
""" """
return self._nn_executor.inference(data, **kwargs) args = ArgsUtils.handle_dict_config(kwargs)
if not self.inference_warmed_up:
print(
"warming up the model for inference, only happen for the first time..."
)
self.warmup_local_model()
self.inference_warmed_up = True
print("inference model is warmed up")
return self._nn_executor.inference(inputs=data, **args)
    def warmup_local_model(self):
        """
        Implement local model warming up logic for local invoker.
        """
        nn_config = self.init_args

        from nn4k.nnhub import NNHub
-       from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
        from nn4k.utils.config_parsing import get_string_field
-       from nn4k.utils.class_importing import dynamic_import_class
        from transformers import HfArgumentParser
        from nn4k.executor import NNModelArgs

-       nn_executor = self.init_args.get(NN_EXECUTOR_KEY)
        parser = HfArgumentParser(NNModelArgs)
        model_args: NNModelArgs
        model_args, *_ = parser.parse_dict(self.init_args, allow_extra_keys=True)

        from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT

        nn_executor = nn_config.get(NN_EXECUTOR_KEY)
        if nn_executor is not None:
-           nn_executor = get_string_field(
-               self.init_args, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
-           )
            from nn4k.executor import NNExecutor

            executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
            if not issubclass(executor_class, NNExecutor):
                message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
                raise RuntimeError(message)
-           executor = executor_class.from_config(self.init_args)
            executor = executor_class.from_config(nn_config)
        else:
            nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
            nn_version = self.init_args.get(NN_VERSION_KEY)
@ -193,13 +211,18 @@ class LLMInvoker(NNInvoker):
            )
            hub = NNHub.get_instance()
            executor = hub.get_model_executor(nn_name, nn_version)

-           if executor is None:
-               message = "model %r version %r " % (nn_name, nn_version)
-               message += "is not found in the model hub"
-               raise RuntimeError(message)
-       self._nn_executor: NNExecutor = executor
            if executor is None:
                message = "model %r version %r " % (
                    model_args.nn_name,
                    model_args.nn_version,
                )
                message += "is not found in the model hub, you should provide a valid nn_executor class path"
                raise RuntimeError(message)

        self._nn_executor: LLMExecutor = executor
        self._nn_executor.load_model(mode="inference")
        self._nn_executor.warmup_inference()
        self.inference_warmed_up = True
    @classmethod
    def from_config(cls, nn_config: dict) -> "LLMInvoker":
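For illustration, a minimal sketch of the new warm-up-once flow (the file names are placeholders, not part of this commit):

invoker = LLMInvoker.from_config("local_infer.json5")
# The first call triggers warmup_local_model() and loads the model;
# subsequent calls on the same invoker reuse the loaded executor.
answer = invoker.local_inference("What could LLM do for human?")
# Inference kwargs may instead be read from a file, since handle_dict_config
# swaps kwargs for the file contents when config_file is given (hypothetical file name):
# answer = invoker.local_inference("What could LLM do for human?", config_file="infer_args.json5")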

View File

@ -0,0 +1,63 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
class ArgsUtils:
CONFIG_FILE_KEY = "config_file"
@staticmethod
def update_args(base_args: dict, new_args: dict) -> dict:
"""
Update an existing args dict with a new set of args.
:param base_args: args to be updated; it is copied before being updated.
:param new_args: args used to update the base args.
:rtype: dict
"""
import copy
copy_base_args = copy.deepcopy(base_args)
new_args = new_args or {}
copy_base_args.update(new_args)
return copy_base_args
@staticmethod
def handle_dict_config(kwargs: dict) -> dict:
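"""
Return kwargs unchanged, unless a config_file entry is present, in which case the
whole config dict is loaded from that file instead.
"""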
if "config_file" in kwargs:
configs = ArgsUtils.load_config_dict_from_file(kwargs.get("config_file"))
else:
configs = kwargs
return configs
@staticmethod
def load_config_dict_from_file(file_path: str) -> dict:
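"""
Load a config dict from a .json or .json5 file.
"""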
from pathlib import Path
if file_path.endswith(".json"):
import json
with open(Path(file_path), "r", encoding="utf-8") as open_json_file:
data = json.load(open_json_file)
nn_config = data
return nn_config
if file_path.endswith(".json5"):
import json5
with open(Path(file_path), "r", encoding="utf-8") as open_json5_file:
data = json5.load(open_json5_file)
nn_config = data
return nn_config
from nn4k.utils.io.file_utils import FileUtils
raise ValueError(
f"Config file with extension type {FileUtils.get_extension(file_path)} is not supported. "
f"Use json or json5 instead."
)
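A brief usage sketch of these helpers (values and the file name are illustrative):

base = {"nn_model_path": "/path/to/model", "do_sample": False}
merged = ArgsUtils.update_args(base, {"do_sample": True})
assert merged["do_sample"] is True and base["do_sample"] is False  # base is deep-copied, not mutated

cfg = ArgsUtils.handle_dict_config({"temperature": 0.7})  # plain kwargs pass through unchanged
# cfg = ArgsUtils.handle_dict_config({"config_file": "infer_args.json5"})  # or load the whole dict from a file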

View File

@ -32,26 +32,9 @@ def preprocess_config(nn_config: Union[str, dict]) -> dict:
    if isinstance(nn_config, dict):
        return nn_config
    elif isinstance(nn_config, str):
-       if nn_config.endswith(".json"):
-           import json
-           with open(Path(nn_config), "r", encoding="utf-8") as open_json_file:
-               data = json.load(open_json_file)
-           nn_config = data
-           return nn_config
-       if nn_config.endswith(".json5"):
-           import json5
-           with open(Path(nn_config), "r", encoding="utf-8") as open_json5_file:
-               data = json5.load(open_json5_file)
-           nn_config = data
-           return nn_config
-       from nn4k.utils.io.file_utils import FileUtils
-       raise ValueError(
-           f"Config file with extension type {FileUtils.get_extension(nn_config)} is not supported."
-           f"use json or json5 instead."
-       )
        from nn4k.utils.args_utils import ArgsUtils

        return ArgsUtils.load_config_dict_from_file(nn_config)
    else:
        raise ValueError(
            f"nn_config could be dict or str, {type(nn_config)} is not yet supported."

View File

@ -22,7 +22,7 @@ class StubExecutor(LLMExecutor):
    def warmup_inference(self, args=None, **kwargs):
        pass

-   def inference(self, data, args=None, **kwargs):
    def inference(self, inputs, args=None, **kwargs):
        pass

    @classmethod

View File

@ -37,7 +37,7 @@ class StubExecutor(NNExecutor):
    def warmup_inference(self, args=None, **kwargs):
        self.warmup_inference_called = True

-   def inference(self, data, args=None, **kwargs):
    def inference(self, inputs, args=None, **kwargs):
        return self.inference_result

    @classmethod

View File

@ -89,7 +89,10 @@ class TestBaseInvoker(unittest.TestCase):
    def testLocalLLMInvokerWithCustomExecutor(self):
        from nn4k.invoker import LLMInvoker

-       nn_config = {"nn_executor": "invoker_test_stub.StubExecutor"}
        nn_config = {
            "nn_model_path": "/path/to/model",
            "nn_executor": "invoker_test_stub.StubExecutor",
        }
        invoker = LLMInvoker.from_config(nn_config)
        self.assertTrue(isinstance(invoker, LLMInvoker))
        self.assertEqual(invoker.init_args, nn_config)