feat(nn4k): add huggingface decode only model local sft feature (#1) (#109)

Co-authored-by: xionghuaidong <huaidong.xhd@antgroup.com>
chenbin11200 2024-02-22 14:08:21 +08:00 committed by GitHub
parent e95725d470
commit eb2590aada
29 changed files with 1252 additions and 206 deletions

View File

@ -32,6 +32,7 @@ header:
- '**/*.schema'
- '**/*.rule'
- '**/*.json'
- '**/*.json5'
- '**/*.in'
- '**/META-INF/services/*'
- '**/*.conf'

View File

@ -22,8 +22,7 @@ NN_INVOKER_TEXT = "NN invoker"
NN_EXECUTOR_KEY = "nn_executor"
NN_EXECUTOR_TEXT = "NN executor"
NN_DEVICE_KEY = "device"
NN_TRUST_REMOTE_CODE_KEY = "trust_remote_code"
NN_DEVICE_KEY = "nn_device"
NN_OPENAI_MODEL_NAME_KEY = NN_NAME_KEY
NN_OPENAI_MODEL_NAME_TEXT = "openai model name"

View File

@ -9,4 +9,4 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from nn4k.executor.base import NNExecutor, LLMExecutor
from nn4k.executor.base import NNExecutor, LLMExecutor, NNModelArgs, NNAdapterModelArgs

View File

@ -10,7 +10,8 @@
# or implied.
from abc import ABC, abstractmethod
from typing import Union
from dataclasses import dataclass, field
from typing import Optional, Union
class NNExecutor(ABC):
@ -145,7 +146,23 @@ class NNExecutor(ABC):
raise RuntimeError(message)
class LLMExecutor(NNExecutor):
class LLMExecutor(NNExecutor, ABC):
"""
Base Executor for LLM.
"""
@classmethod
def from_config(cls, nn_config: Union[str, dict]) -> "LLMExecutor":
"""
Implement dispatch logic for LLMs. Since only Huggingface decode-only models are supported for now,
this directly returns an HFDecodeOnlyExecutor. Hub management functions will be used later on.
"""
from nn4k.executor.huggingface.hf_decode_only_executor import (
HFDecodeOnlyExecutor,
)
return HFDecodeOnlyExecutor.from_config(nn_config)
def execute_sft(self, args=None, callbacks=None, **kwargs):
"""
The entry point of SFT execution in a certain pod.
@ -159,3 +176,75 @@ class LLMExecutor(NNExecutor):
raise NotImplementedError(
f"{self.__class__.__name__} does not support RL-Tuning."
)
@dataclass
class NNModelArgs:
"""
Base args for defining and loading an NN4K-supported model.
"""
nn_name: Optional[str] = field(
default=None,
metadata={"help": ("NN4K model name")},
)
nn_version: Optional[str] = field(
default="default",
metadata={"help": ("NN4K model version, by default is 'default'")},
)
nn_model_path: Optional[str] = field(
default=None,
metadata={
"help": (
"model path dir, could be delivered by user or get managed in Hub."
)
},
)
nn_device: Optional[str] = field(
default="auto", metadata={"help": ("device to use to load model")}
)
def __post_init__(self):
assert (
self.nn_name is not None or self.nn_model_path is not None
), "either nn_name or nn_model_path has to be provided"
@dataclass
class NNAdapterModelArgs(NNModelArgs):
"""
One should use this args dataclass to enable adapter models.
"""
adapter_name: str = field(
default=None,
metadata={
"help": "adapter name. Should be provided if you want to sft or load a adapter model."
},
)
adapter_version: str = field(
default="auto",
metadata={
"help": "adapter is designed to get managed by versions, by default is 'latest'"
},
)
adapter_type: str = field(
default="lora", metadata={"help": "adapter type, lora by default."}
)
adapter_path: str = field(
default=None,
metadata={
"help": "adapter weight and config path, could be delivered by user or get managed in Hub."
},
)
adapter_config: Optional[dict] = field(
default=None,
metadata={
"help": "Only necessary if you want to init a new adapter model and train from scratch or resume"
"from a checkpoint (in this case, should be the same as the previous adapter_config)."
"Values are the same as peft config init args."
},
)
def __post_init__(self):
super().__post_init__()
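
A minimal usage sketch of the dispatch and args dataclasses above (illustration only, not part of this diff; the model path and adapter values are placeholders):

from nn4k.executor import LLMExecutor, NNAdapterModelArgs

# currently always resolves to HFDecodeOnlyExecutor, per LLMExecutor.from_config above
executor = LLMExecutor.from_config({"nn_model_path": "/path/to/local/model"})

# either nn_name or nn_model_path must be provided, otherwise __post_init__ raises
adapter_args = NNAdapterModelArgs(
nn_model_path="/path/to/local/model",  # placeholder path
adapter_name="my_lora",  # placeholder name; setting it enables the adapter branch in HFDecodeOnlyExecutor
adapter_type="lora",
adapter_config={"r": 8, "lora_alpha": 16, "task_type": "CAUSAL_LM"},  # same keys as peft LoraConfig
)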

View File

@ -1,152 +0,0 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Union
from nn4k.executor import LLMExecutor
class HFLLMExecutor(LLMExecutor):
@classmethod
def from_config(cls, nn_config: dict) -> "HFLLMExecutor":
"""
Create an HFLLMExecutor instance from `nn_config`.
"""
executor = cls(nn_config)
return executor
def execute_sft(self, args=None, callbacks=None, **kwargs):
raise NotImplementedError(
f"{self.__class__.__name__} will support SFT in the next version."
)
def load_model(self, args=None, **kwargs):
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.consts import NN_DEVICE_KEY, NN_TRUST_REMOTE_CODE_KEY
from nn4k.utils.config_parsing import get_string_field
nn_config: dict = args or self.init_args
if self._model is None:
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = nn_config.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(
nn_config, NN_VERSION_KEY, NN_VERSION_TEXT
)
model_path = nn_name
revision = nn_version
use_fast_tokenizer = False
device = nn_config.get(NN_DEVICE_KEY)
trust_remote_code = nn_config.get(NN_TRUST_REMOTE_CODE_KEY, False)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(
model_path,
use_fast=use_fast_tokenizer,
revision=revision,
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
revision=revision,
trust_remote_code=trust_remote_code,
)
model.to(device)
self._tokenizer = tokenizer
self._model = model
def inference(
self,
data,
max_input_length: int = 1024,
max_output_length: int = 1024,
do_sample: bool = False,
**kwargs,
):
model = self.model
tokenizer = self.tokenizer
input_ids = tokenizer(
data,
padding=True,
return_token_type_ids=False,
return_tensors="pt",
truncation=True,
max_length=max_input_length,
).to(model.device)
output_ids = model.generate(
**input_ids,
max_new_tokens=max_output_length,
do_sample=do_sample,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
**kwargs,
)
outputs = [
tokenizer.decode(
output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
)
for idx, output_id in enumerate(output_ids)
]
return outputs
class HFEmbeddingExecutor(LLMExecutor):
@classmethod
def from_config(cls, nn_config: dict) -> "HFEmbeddingExecutor":
"""
Create an HFEmbeddingExecutor instance from `nn_config`.
"""
executor = cls(nn_config)
return executor
def load_model(self, args=None, **kwargs):
import torch
from sentence_transformers import SentenceTransformer
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.consts import NN_DEVICE_KEY
from nn4k.utils.config_parsing import get_string_field
nn_config: dict = args or self.init_args
if self._model is None:
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = nn_config.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(
nn_config, NN_VERSION_KEY, NN_VERSION_TEXT
)
model_path = nn_name
revision = nn_version
use_fast_tokenizer = False
device = nn_config.get(NN_DEVICE_KEY)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
#
# SentenceTransformer will support `revision` soon. See:
#
# https://github.com/UKPLab/sentence-transformers/pull/2419
#
model = SentenceTransformer(
model_path,
device=device,
)
self._model = model
def inference(self, data, args=None, **kwargs):
model = self.model
embeddings = model.encode(data)
return embeddings

View File

@ -0,0 +1,13 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from nn4k.executor.huggingface.base.hf_llm_executor import HFLLMExecutor
from nn4k.executor.huggingface.base.hf_args import HFModelArgs, HFSftArgs

View File

@ -0,0 +1,10 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,107 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
from nn4k.executor import NNAdapterModelArgs
@dataclass
class HFModelArgs(NNAdapterModelArgs):
"""
Huggingface models are meant to support adapter models such as LoRA, so this inherits from the
NNAdapterModelArgs dataclass.
"""
torch_dtype: Optional[str] = field(
default="auto",
metadata={
"help": (
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
"dtype will be automatically derived from the model's weights."
)
},
)
qlora_bits_and_bytes_config: Optional[dict] = field(
default=None,
metadata={
"help": "Quantization configs to load qlora, "
"same as :class:`transformers.utils.quantization_config.BitsAndBytesConfig`"
},
)
trust_remote_code: bool = field(
default=True,
metadata={
"help": "Whether or not to allow for custom models defined on the Hub in their own modeling files."
},
)
from_tf: bool = field(
default=False,
metadata={
"help": " Load the model weights from a TensorFlow checkpoint save file, default to False"
},
)
def __post_init__(self):
super().__post_init__()
# for hf models, the model path has higher priority than the name, since with a local path the model does
# not need to be downloaded (or fetched from cache) again.
self.pretrained_model_name_or_path = self.nn_model_path or self.nn_name
@dataclass
class HFSftArgs(HFModelArgs, TrainingArguments):
"""
Args for a huggingface model SFT task.
"""
train_dataset_path: Optional[str] = field(
default=None,
metadata={
"help": "Should not be None. A file or dir path to train dataset, If a dir path, "
"all files inside should have the same file extension."
},
)
eval_dataset_path: Optional[str] = field(
default=None,
metadata={
"help": "A file or dir path to eval dataset. If a dir path, all files inside should have the same "
"file extension. If set, do_eval flag will be set to True"
},
)
max_input_length: int = field(
default=1024,
metadata={"help": "max length of input"},
)
resume_from_checkpoint: Optional[str] = field(
default=None,
metadata={
"help": "The path to a folder with a valid checkpoint for your model."
},
)
def __post_init__(self):
HFModelArgs.__post_init__(self)
TrainingArguments.__post_init__(self)
assert self.train_dataset_path is not None, "train_dataset_path must be set."
if self.train_dataset_path and not self.do_train:
self.do_train = True
print(
f"a train_dataset_path is set but do_train flag is not set, automatically set do_train to True"
)
if self.eval_dataset_path and not self.do_eval:
self.do_eval = True
print(
f"a eval_dataset_path is set but do_eval flag is not set, automatically set do_eval to True"
)
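
A hedged sketch of how a flat config dict is parsed into these dataclasses (the same pattern used in HFLLMExecutor.execute_sft below; all paths are placeholders):

from transformers import HfArgumentParser
from nn4k.executor.huggingface import HFSftArgs

parser = HfArgumentParser(HFSftArgs)
sft_args, *_ = parser.parse_dict(
{
"nn_model_path": "/path/to/local/model",  # placeholder
"train_dataset_path": "/path/to/train.json",  # placeholder
"output_dir": "/path/to/output",  # placeholder
"max_input_length": 256,
},
allow_extra_keys=True,
)
# __post_init__ flips do_train to True because train_dataset_path is set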

View File

@ -0,0 +1,328 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import typing
from abc import abstractmethod
from typing import Optional, Union
from torch.utils.data import Dataset
from transformers import AutoConfig, AutoTokenizer, Trainer
from nn4k.executor import LLMExecutor
from .hf_args import HFSftArgs, HFModelArgs
from nn4k.executor.huggingface.nn_hf_trainer import NNHFTrainer
class HFLLMExecutor(LLMExecutor):
"""
Base Executor for huggingface models.
"""
def __init__(self, init_args: dict, **kwargs):
super().__init__(init_args=init_args, **kwargs)
# model_mode is either 'train' or 'inference'; it is set when the model is loaded
self.model_mode = None
@classmethod
def from_config(cls, nn_config: dict) -> "HFLLMExecutor":
"""
Create an HFLLMExecutor instance from `nn_config`.
"""
executor = cls(nn_config)
return executor
def execute_sft(self, args: dict = None, callbacks=None, **kwargs):
args = args or self.init_args
self.load_model(args=args, mode="train")
# parse args into HFSftArgs dataclass for more convenient features
from transformers import HfArgumentParser
parser = HfArgumentParser(HFSftArgs)
hf_sft_args: HFSftArgs
hf_sft_args, *_ = parser.parse_dict(args, allow_extra_keys=True)
# load checkpoint path if necessary.
resume_from_checkpoint_path = self._get_last_checkpoint(hf_sft_args)
# load and map dataset
train_dataset, eval_dataset = self._init_dataset(hf_sft_args)
# init trainer
trainer: Trainer = self._init_trainer(
train_dataset, eval_dataset, hf_sft_args, callbacks
)
# start training
train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint_path)
# save trained model after train complete
trainer.save_model(hf_sft_args.output_dir)
# save train metrics
train_metrics = train_result.metrics
train_metrics["train_samples_len"] = len(train_dataset)
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)
trainer.save_state()
return self
def _get_last_checkpoint(self, sft_args: HFSftArgs) -> Optional[str]: # noqa
"""
Try to find a checkpoint under sft_args.output_dir.
If sft_args.resume_from_checkpoint is in ['True', 'true', True, ''], the checkpoint dir with the largest
checkpoint index under output_dir is looked up and returned.
If sft_args.resume_from_checkpoint is in [None, 'False', 'false', False], resuming is not requested and
None is returned.
If sft_args.resume_from_checkpoint is the name of a checkpoint subfolder, the 'output_dir/resume_from_checkpoint'
path is returned if it exists. Be aware that an assertion error is raised if that dir does not exist.
"""
output_dir_contains_file = (
os.path.isdir(sft_args.output_dir)
and len(os.listdir(sft_args.output_dir)) > 0
)
if sft_args.resume_from_checkpoint in ["True", "true", True, ""]:
resume_from_checkpoint_bool = True
if output_dir_contains_file:
from transformers.trainer_utils import get_last_checkpoint
resume_from_checkpoint_path = get_last_checkpoint(sft_args.output_dir)
else:
resume_from_checkpoint_path = None
assert (
resume_from_checkpoint_path is not None
), f"cannot find last checkpoint dir in {sft_args.output_dir}"
elif sft_args.resume_from_checkpoint in [None, "False", "false", False]:
resume_from_checkpoint_bool = False
resume_from_checkpoint_path = None
else:
resume_from_checkpoint_bool = True
resume_from_checkpoint_path = os.path.join(
sft_args.output_dir, sft_args.resume_from_checkpoint
)
assert os.path.isdir(
resume_from_checkpoint_path
), f"{resume_from_checkpoint_path} is not a dir."
if (
output_dir_contains_file
and not sft_args.overwrite_output_dir
and not resume_from_checkpoint_bool
):
raise ValueError(
f"Output_dir ({sft_args.output_dir}) is not empty. Maybe you mean --resume_from_checkpoint"
'="True" to resume a training or --overwrite_output_dir to overwrite output_dir.'
)
return resume_from_checkpoint_path
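# A few worked examples of the resolution above (illustrative only; paths are placeholders):
#   resume_from_checkpoint="True"           -> latest checkpoint-* dir found under output_dir
#   resume_from_checkpoint=None              -> None (fresh run; output_dir must be empty unless
#                                               overwrite_output_dir is set)
#   resume_from_checkpoint="checkpoint-500"  -> "<output_dir>/checkpoint-500" (must already exist)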
def map_fn(self, dataset, **kwargs):
"""
Dataset map and templating function. The default implementation follows the BelleGroup/train_0.5M_CN format,
meaning 'instruction', 'input' and 'output' columns are required. Since other popular datasets such as
tatsu-lab/alpaca provide these columns as well, they are supported too.
"""
args: HFSftArgs = kwargs.get("args", None)
instruction = dataset["instruction"]
input_text = dataset["input"]
output_text = dataset["output"]
bos_token = self.tokenizer.bos_token or ""
eos_token = self.tokenizer.eos_token
input_prompt = f"{bos_token}{instruction} {input_text}{eos_token}"
tokenized_full_prompt = self._tokenize_dataset(
input_prompt, args.max_input_length
)
return tokenized_full_prompt
def _init_dataset(
self, args: HFSftArgs
) -> typing.Tuple[Union[Dataset], Union[Dataset]]: # noqa
"""
init and map dataset, for train and eval
"""
with args.main_process_first(desc="initialize dataset"):
train_dataset = None
if args.train_dataset_path:
train_dataset = (
self._load_dataset(args.train_dataset_path, "train")
.shuffle()
.map(self.map_fn, fn_kwargs={"args": args})
)
eval_dataset = None
if args.eval_dataset_path:
eval_dataset = (
self._load_dataset(args.eval_dataset_path, "train")
.shuffle()
.map(self.map_fn, fn_kwargs={"args": args})
)
return train_dataset, eval_dataset
def _load_dataset(self, data_path, split="train"): # noqa
from nn4k.utils.io.dataset_utils import DatasetUtils
return DatasetUtils.auto_dataset(data_path, split)
def load_model(self, args: dict = None, mode=None, **kwargs):
"""
Load the model and tokenizer. If a model with the same mode is already loaded, it will not be loaded again.
"""
assert (
mode is not None
), f"mode should be either 'train' or 'inference' for HFLLMExecutor, {mode} is illegal."
if self.model_mode == mode and self._model is not None:
return
from transformers import HfArgumentParser
from nn4k.executor.huggingface import HFModelArgs
parser = HfArgumentParser(HFModelArgs)
hf_model_args, *_ = parser.parse_dict(args, allow_extra_keys=True)
self.model_mode = mode
self._tokenizer = self._hf_tokenizer_loader(hf_model_args)
self._model = self._hf_model_loader(
hf_model_args, mode, hf_model_args.nn_device
)
if self.tokenizer.eos_token_id is None:
self.tokenizer.eos_token_id = self.model.config.eos_token_id
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
def inference(
self,
data,
max_input_length: int = 1024,
max_output_length: int = 1024,
do_sample: bool = False,
**kwargs,
):
model = self.model
tokenizer = self.tokenizer
input_ids = tokenizer(
data,
padding=True,
return_token_type_ids=False,
return_tensors="pt",
truncation=True,
max_length=max_input_length,
).to(model.device)
output_ids = model.generate(
**input_ids,
max_new_tokens=max_output_length,
do_sample=do_sample,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
**kwargs,
)
outputs = [
tokenizer.decode(
output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
)
for idx, output_id in enumerate(output_ids)
]
return outputs
@abstractmethod
def _hf_model_loader(
self,
args: HFModelArgs,
mode,
resume_from_checkpoint=False,
device=None,
**kwargs,
):
"""
Load a huggingface model onto the given device.
"""
pass
def _hf_tokenizer_loader(self, args: HFModelArgs, **kwargs): # noqa
"""
hugging face tokenizer loader
"""
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=args.pretrained_model_name_or_path,
use_fast=False,
revision=args.nn_version,
trust_remote_code=args.trust_remote_code,
)
return tokenizer
def _hf_model_config_loader(self, args: HFModelArgs, **kwargs): # noqa
"""
hugging face model config loader
"""
model_config = AutoConfig.from_pretrained(
args.pretrained_model_name_or_path,
trust_remote_code=args.trust_remote_code,
**kwargs,
)
return model_config
def _init_trainer(
self, train_dataset, eval_dataset, sft_args: HFSftArgs, callbacks=None
) -> Trainer:
"""
hugging face model trainer initializer
"""
trainer = NNHFTrainer(
model=self.model,
args=sft_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=self.tokenizer,
data_collator=self._data_collator(),
callbacks=callbacks,
)
return trainer
@abstractmethod
def _data_collator(self, return_tensors="pt", **kwargs):
"""
data collator used in trainer
"""
pass
def _tokenize_dataset(self, prompt_text, max_length):
"""
Tokenize the dataset; by default the input is truncated to max_length.
"""
tokenized_dataset = self.tokenizer(
prompt_text, truncation=True, max_length=max_length
)
input_ids = tokenized_dataset["input_ids"]
attention_mask = tokenized_dataset["attention_mask"]
# append eos token if necessary
# input length is shorter than max_length
if len(input_ids) < max_length:
if input_ids[-1] != self.tokenizer.eos_token_id:
input_ids.append(self.tokenizer.eos_token_id)
attention_mask.append(1)
else:
input_ids[max_length - 1] = self.tokenizer.eos_token_id
attention_mask[max_length - 1] = 1
# labels are copy of input_ids
tokenized_dataset["labels"] = tokenized_dataset["input_ids"].copy()
return tokenized_dataset
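
A hedged sketch of the record shape map_fn above expects (column names follow the BelleGroup/train_0.5M_CN convention mentioned in its docstring; the values here are made up):

record = {
"instruction": "Translate the following sentence to English.",  # placeholder
"input": "placeholder source text",
"output": "placeholder target text",
}
# map_fn builds "<bos>{instruction} {input}<eos>" from such a record, tokenizes it with truncation to
# args.max_input_length, and _tokenize_dataset copies input_ids into labels.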

View File

@ -0,0 +1,10 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,10 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,24 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
# The scripts are tested with the following packages installed
export WANDB_DISABLED=true
# Only if you hit a CUDA OOM, try this setting
#export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32
pip install peft==0.5.0
pip install json5 # only necessary if you use json5 file as a config file
pip install numpy==1.23.1
pip install transformers==4.36.2
pip install "accelerate>=0.21.0"
pip install "bitsandbytes>=0.39.0" # only necessary if you use qlora
#pip install xformers==0.0.23.post1 # only necessary if you want to load the model in a memory-efficient way

View File

@ -0,0 +1,41 @@
{
// -- base model and training args
"nn_model_path": "/model/path/to/Baichuan-7B-Chat", // local model path
"train_dataset_path": "/data/train/dataset.json", // train dataset path
"nn_invoker": "nn4k.invoker.base.LLMInvoker", // invoker to use
"nn_executor": "nn4k.executor.huggingface.hf_decode_only_executor.HFDecodeOnlyExecutor", // executor to use
"output_dir": "/path/to/output/dir", // trained model output dir
// ----- The following args are optional-----
// "eval_dataset_path": "/data/eval/dataset-eval.json", // eval dataset path, if you want to do eval
// -- adapter model info, only if you want to train lora adapter
// "adapter_name": "YouYou", //set it to a not "default" string value to enable adapter sft
// "adapter_type": "lora", // adapter type. Don't need it if adapter_name is not set
// "adapter_config": { // only necessary if adapter_name is set, same as peft LoraConfig args if tyep is 'lora'
// "r": 8,
// "lora_alpha": 16,
// "lora_dropout": 0.05,
// "bias": "none",
// "target_modules": ["W_pack", "o_proj"], // this is only an example for BaiChuan lora training
// "task_type": "CAUSAL_LM"
// },
// "qlora_bits_and_bytes_config": { // only necessary if you want to quantinize load model
// "load_in_4bit": true,
// "bnb_4bit_compute_dtype": "bfloat16",
// "bnb_4bit_use_double_quant": true,
// "bnb_4bit_quant_type": "nf4"
// }
//-- start training args
// "resume_from_checkpoint": "True", // only necessary if you want to resume training from checkpoint
"trust_remote_code": true,
"max_input_length": 256, // input max length. Inputs will be cut down to this length
//-- start: same as huggingface trainer args
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 1,
"lr_scheduler_type": "cosine", // adjust learning rate scheduler
"logging_steps": 20,
"save_steps": 10000,
"learning_rate": 4e-5,
"num_train_epochs": 1.0
//-- end: huggingface trainer args
}

View File

@ -0,0 +1,22 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from nn4k.invoker.base import NNInvoker
def main():
NNInvoker.from_config("local_sft.json5").local_sft()
# Inference example, not implemented yet.
# NNInvoker.from_config("inferece_args.json").local_inference("你是谁")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,117 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import torch
import transformers
from transformers import AutoModelForCausalLM
from nn4k.executor.huggingface import HFModelArgs
from nn4k.executor.huggingface import HFLLMExecutor
class HFDecodeOnlyExecutor(HFLLMExecutor):
"""
Default huggingface decode-only executor; uses AutoModelForCausalLM to load the model and
DataCollatorForSeq2Seq as the default data collator.
"""
def _hf_model_loader(
self,
args: HFModelArgs,
mode,
resume_from_checkpoint=False,
device=None,
**kwargs,
):
if device is None or "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
# load base model
model_config = self._hf_model_config_loader(args, **kwargs)
quant_config = None
if args.adapter_name and args.qlora_bits_and_bytes_config:
from transformers import BitsAndBytesConfig
quant_config = BitsAndBytesConfig(**args.qlora_bits_and_bytes_config)
model_load_args = dict(
pretrained_model_name_or_path=args.pretrained_model_name_or_path,
config=model_config,
quantization_config=quant_config,
revision=args.nn_version,
torch_dtype=args.torch_dtype,
from_tf=args.from_tf,
trust_remote_code=args.trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(**model_load_args)
if quant_config:
from peft import prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model)
# load adapter model
if args.adapter_name:
# providing an adapter_path means one can load an existing lora adapter and start a new training run based on it.
if args.adapter_path and not resume_from_checkpoint:
from peft import PeftModel
# TODO NN4K: Notice: NN4K plan to provide a hub-managed adapter implementation in the near future.
model = PeftModel.from_pretrained(
model=model,
model_id=args.adapter_path,
adapter_name=args.adapter_name,
adapter_version=args.adapter_version,
is_trainable=(mode == "train"),
)
elif (
args.adapter_config
): # no adapter_path but adapter_config means train an adapter from scratch
from peft import get_peft_model
from peft import LoraConfig
if args.adapter_type in ["lora", "qlora"]:
peft_config = LoraConfig(**args.adapter_config)
else:
raise NotImplementedError(
f"adapter_type {args.adapter_type} is not supported in "
f"hf_decode_only_executor use lora or qlora instead"
)
model = get_peft_model(
model=model,
peft_config=peft_config,
adapter_name=args.adapter_name,
# TODO NN4K: NN4K plan to provide a hub-managed adapter implementation in the
# near future. adapter_version=args.adapter_version,
)
else:
raise ValueError(
"You should either provide a adapter_path to load an existing adapter without resume"
"a training, or provide a adapter_config to train a adapter from scratch or resume a "
"adapter training from checkpoint."
)
model.print_trainable_parameters()
if mode == "inference":
model.eval()
model.to(device)
return model
def _data_collator(self, return_tensors="pt", **kwargs):
return transformers.DataCollatorForSeq2Seq(
self.tokenizer,
pad_to_multiple_of=8,
return_tensors=return_tensors,
padding=True,
)
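
A minimal end-to-end sketch of a local LoRA SFT run through this executor (illustration only; the paths, adapter name and LoRA settings are placeholders, and in practice the same keys are usually supplied via a json/json5 config like the example earlier in this commit):

from nn4k.executor.huggingface.hf_decode_only_executor import HFDecodeOnlyExecutor

sft_config = {
"nn_model_path": "/path/to/local/model",  # placeholder
"train_dataset_path": "/path/to/train.json",  # placeholder
"output_dir": "/path/to/output",  # placeholder
"adapter_name": "my_lora",  # placeholder; enables the adapter branch above
"adapter_config": {"r": 8, "lora_alpha": 16, "task_type": "CAUSAL_LM"},  # target_modules may be required for some models
"per_device_train_batch_size": 1,
"num_train_epochs": 1.0,
}
HFDecodeOnlyExecutor.from_config(sft_config).execute_sft()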

View File

@ -0,0 +1,59 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from nn4k.executor import LLMExecutor
class HFEmbeddingExecutor(LLMExecutor):
@classmethod
def from_config(cls, nn_config: dict) -> "HFEmbeddingExecutor":
"""
Create an HFEmbeddingExecutor instance from `nn_config`.
"""
executor = cls(nn_config)
return executor
def load_model(self, args=None, **kwargs):
import torch
from sentence_transformers import SentenceTransformer
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.consts import NN_DEVICE_KEY
from nn4k.utils.config_parsing import get_string_field
nn_config: dict = args or self.init_args
if self._model is None:
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = nn_config.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(
nn_config, NN_VERSION_KEY, NN_VERSION_TEXT
)
model_path = nn_name
revision = nn_version
use_fast_tokenizer = False
device = nn_config.get(NN_DEVICE_KEY)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
#
# SentenceTransformer will support `revision` soon. See:
#
# https://github.com/UKPLab/sentence-transformers/pull/2419
#
model = SentenceTransformer(
model_path,
device=device,
)
self._model = model
def inference(self, data, args=None, **kwargs):
model = self.model
embeddings = model.encode(data)
return embeddings

View File

@ -0,0 +1,204 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import safetensors
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from packaging import version
from peft import PeftModel
from transformers import PretrainedConfig, Trainer, __version__
from transformers.integrations import is_deepspeed_available
from transformers.modeling_utils import load_sharded_checkpoint
from transformers.trainer import logger
from transformers.utils import (
ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_accelerate_available,
is_peft_available,
is_sagemaker_mp_enabled,
)
if is_accelerate_available():
from accelerate import Accelerator, skip_first_batches
from accelerate import __version__ as accelerate_version
from accelerate.utils import (
DistributedDataParallelKwargs,
GradientAccumulationPlugin,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
save_fsdp_optimizer,
)
DATA_SAMPLERS = [RandomSampler]
if version.parse(accelerate_version) > version.parse("0.23.0"):
from accelerate.data_loader import SeedableRandomSampler
DATA_SAMPLERS += [SeedableRandomSampler]
if is_deepspeed_available():
from accelerate.utils import DeepSpeedSchedulerWrapper
class NNHFTrainer(Trainer):
"""
This subclass only fixes checkpoint resumption for lora adapters; it will be replaced by the stock Trainer once
the bug is fixed in the huggingface trainer. A PR has been submitted: https://github.com/huggingface/transformers/pull/28547
"""
def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
# the following code only fixes resuming from a checkpoint for adapter (PEFT) models
if model is None:
model = self.model
if not (is_peft_available() and isinstance(model, PeftModel)):
return super()._load_from_checkpoint(resume_from_checkpoint, model)
adapter_name_path = ""
if isinstance(model, PeftModel):
adapter_name_path = (
model.active_adapter
if model.active_adapter not in ["default", None]
else ""
)
config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
adapter_weights_file = os.path.join(
resume_from_checkpoint, adapter_name_path, ADAPTER_WEIGHTS_NAME
)
adapter_safe_weights_file = os.path.join(
resume_from_checkpoint, adapter_name_path, ADAPTER_SAFE_WEIGHTS_NAME
)
weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
safe_weights_index_file = os.path.join(
resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME
)
if not any(
os.path.isfile(f)
for f in [
weights_file,
safe_weights_file,
weights_index_file,
safe_weights_index_file,
os.path.join(adapter_weights_file),
os.path.join(adapter_safe_weights_file),
]
):
raise ValueError(
f"Can't find a valid checkpoint at {resume_from_checkpoint}"
)
logger.info(f"Loading model from {resume_from_checkpoint}.")
if os.path.isfile(config_file):
config = PretrainedConfig.from_json_file(config_file)
checkpoint_version = config.transformers_version
if checkpoint_version is not None and checkpoint_version != __version__:
logger.warning(
f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
f"Transformers but your current version is {__version__}. This is not recommended and could "
"yield to errors or unwanted behaviors."
)
if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
# If the model is on the GPU, it still works!
if is_sagemaker_mp_enabled():
if os.path.isfile(
os.path.join(resume_from_checkpoint, "user_content.pt")
):
# If the 'user_content.pt' file exists, load with the new smp api.
# Checkpoint must have been saved with the new smp api.
import smdistributed.modelparallel.torch as smp
smp.resume_from_checkpoint(
path=resume_from_checkpoint,
tag=WEIGHTS_NAME,
partial=False,
load_optimizer=False,
)
else:
# If the 'user_content.pt' file does NOT exist, load with the old smp api.
# Checkpoint must have been saved with the old smp api.
if hasattr(self.args, "fp16") and self.args.fp16 is True:
logger.warning(
"Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
)
state_dict = torch.load(weights_file, map_location="cpu")
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True)
# release memory
del state_dict
elif self.is_fsdp_enabled:
load_fsdp_model(
self.accelerator.state.fsdp_plugin,
self.accelerator,
model,
resume_from_checkpoint,
)
else:
# We load the model state dict on the CPU to avoid an OOM error.
if self.args.save_safetensors and os.path.isfile(safe_weights_file):
state_dict = safetensors.torch.load_file(
safe_weights_file, device="cpu"
)
else:
state_dict = torch.load(weights_file, map_location="cpu")
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs
load_result = model.load_state_dict(state_dict, False)
# release memory
del state_dict
self._issue_warnings_after_load(load_result)
# Load adapters following PR # 24096
elif is_peft_available() and isinstance(model, PeftModel):
# If the model was trained using PEFT & LoRA, assume that the adapter has been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
adapter_model_path = os.path.join(
resume_from_checkpoint, adapter_name_path
)
if os.path.exists(adapter_model_path):
model.load_adapter(
adapter_model_path, model.active_adapter, is_trainable=True
)
else:
logger.warning(
"The intermediate checkpoints of PEFT may not be saved correctly, "
f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
"Check some examples here: https://github.com/huggingface/peft/issues/96"
)
else:
logger.warning(
"Could not load adapter model, make sure to have `peft>=0.3.0` installed"
)
else:
# We load the sharded checkpoint
load_result = load_sharded_checkpoint(
model,
resume_from_checkpoint,
strict=is_sagemaker_mp_enabled(),
prefer_safe=self.args.save_safetensors,
)
if not is_sagemaker_mp_enabled():
self._issue_warnings_after_load(load_result)

View File

@ -9,6 +9,7 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import copy
from abc import ABC, abstractmethod
from enum import Enum
from typing import Union
@ -19,6 +20,7 @@ from nn4k.executor import NNExecutor
class SubmitMode(Enum):
K8s = "k8s"
Docker = "docker"
Local = "local"
class NNInvoker(ABC):
@ -138,6 +140,15 @@ class LLMInvoker(NNInvoker):
"""
raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")
def local_sft(self, args: dict = None):
sft_args = copy.deepcopy(self.init_args)
args = args or {}
sft_args.update(args)
from nn4k.executor import LLMExecutor
LLMExecutor.from_config(sft_args).execute_sft()
def submit_rl_tuning(self, submit_mode: SubmitMode = SubmitMode.K8s):
"""
Submit remote RL-Tuning execution.
@ -187,7 +198,7 @@ class LLMInvoker(NNInvoker):
message += "is not found in the model hub"
raise RuntimeError(message)
self._nn_executor: NNExecutor = executor
self._nn_executor.load_model()
self._nn_executor.load_model(mode="inference")
self._nn_executor.warmup_inference()
@classmethod

View File

@ -15,6 +15,7 @@ from abc import ABC, abstractmethod
from typing import Optional, Union, Tuple, Type
from nn4k.executor import NNExecutor
from nn4k.utils.class_importing import dynamic_import_class
class NNHub(ABC):
@ -146,8 +147,8 @@ class SimpleNNHub(NNHub):
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.consts import NN_LOCAL_HF_MODEL_CONFIG_FILE
from nn4k.consts import NN_LOCAL_SENTENCE_TRANSFORMERS_CONFIG_FILE
from nn4k.executor.hugging_face import HFLLMExecutor
from nn4k.executor.hugging_face import HFEmbeddingExecutor
from nn4k.executor.huggingface.hf_embedding_executor import HFEmbeddingExecutor
from nn4k.executor.huggingface.base.hf_llm_executor import HFLLMExecutor
from nn4k.utils.config_parsing import get_string_field
nn_executor = nn_config.get(NN_EXECUTOR_KEY)
@ -186,32 +187,19 @@ class SimpleNNHub(NNHub):
message += ", version: %r" % nn_version
raise RuntimeError(message)
def _add_local_executor(self, nn_config: dict):
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.utils.config_parsing import get_string_field
executor_class = self._get_local_executor_class(nn_config)
executor = executor_class.from_config(nn_config)
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = nn_config.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
self.publish(executor, nn_name, nn_version)
def get_invoker(self, nn_config: dict) -> Optional["NNInvoker"]:
from nn4k.invoker import LLMInvoker
from nn4k.invoker.openai import OpenAIInvoker
from nn4k.utils.invoker_checking import is_openai_invoker
from nn4k.utils.invoker_checking import is_local_invoker
if is_openai_invoker(nn_config):
invoker = OpenAIInvoker.from_config(nn_config)
return invoker
if is_local_invoker(nn_config):
# TODO NN4K: this will be replaced once we publish the SimpleHub solution. Now we only have openai invoker
# and LLMInvoker
# if is_local_invoker(nn_config):
else:
invoker = LLMInvoker.from_config(nn_config)
self._add_local_executor(nn_config)
return invoker
return None
# return None

View File

@ -9,10 +9,8 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from typing import Any
from typing import Union
from pathlib import Path
from typing import Any, Union
def preprocess_config(nn_config: Union[str, dict]) -> dict:
@ -21,22 +19,45 @@ def preprocess_config(nn_config: Union[str, dict]) -> dict:
* If `nn_config` is already a dictionary, return it as is.
* If `nn_config` is a string, decode it as a JSON file.
* If `nn_config` is a string, decode it as a JSON or JSON5 file.
:param nn_config: config to be preprocessed
:type nn_config: str or dict
:return: `nn_config` or `nn_config` decoded as JSON
:return: `nn_config` or `nn_config` decoded as JSON or JSON5
:rtype: dict
:raises ValueError: if cannot decode config file specified by
`nn_config` as JSON
`nn_config` as JSON or JSON5
"""
try:
if isinstance(nn_config, str):
with open(nn_config, "r") as f:
nn_config = json.load(f)
if isinstance(nn_config, dict):
return nn_config
elif isinstance(nn_config, str):
if nn_config.endswith(".json"):
import json
with open(Path(nn_config), "r", encoding="utf-8") as open_json_file:
data = json.load(open_json_file)
nn_config = data
return nn_config
if nn_config.endswith(".json5"):
import json5
with open(Path(nn_config), "r", encoding="utf-8") as open_json5_file:
data = json5.load(open_json5_file)
nn_config = data
return nn_config
from nn4k.utils.io.file_utils import FileUtils
raise ValueError(
f"Config file with extension type {FileUtils.get_extension(nn_config)} is not supported."
f"use json or json5 instead."
)
else:
raise ValueError(
f"nn_config could be dict or str, {type(nn_config)} is not yet supported."
)
except:
raise ValueError("cannot decode config file")
return nn_config
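
A hedged usage sketch of the json/json5 handling above (file names are placeholders):

from nn4k.utils.config_parsing import preprocess_config

cfg = preprocess_config({"nn_name": "some-model"})  # a dict passes through unchanged
cfg = preprocess_config("local_sft.json5")  # decoded with json5 (requires `pip install json5`)
cfg = preprocess_config("nn_config.json")  # placeholder path, decoded with the stdlib json module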
def get_field(nn_config: dict, name: str, text: str) -> Any:

View File

@ -0,0 +1,10 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,66 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from typing import List
from nn4k.utils.io.file_utils import FileUtils
EXTENSION_TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"}
class DatasetUtils:
@staticmethod
def auto_dataset(input_path, split="train", transform_fn=None, dataset_map_fn=None):
"""
Args:
input_path: dataset path; supports a local file path or a directory. If a directory is used, make sure all
files within it have the same file extension
split: data split of dataset, see dataset doc for more info.
transform_fn: transform function for dataset
dataset_map_fn: dataset map function
"""
dataset_dir = input_path
file_extension = None
data_files: List[str] = []
if os.path.isdir(input_path): # support directory
for file_name in os.listdir(input_path):
data_files.append(os.path.join(input_path, file_name))
if file_extension is None:
file_extension = EXTENSION_TYPE.get(
FileUtils.get_extension(file_name), None
)
else:
assert file_extension == EXTENSION_TYPE.get(
FileUtils.get_extension(file_name), None
), "file type does not match."
elif os.path.isfile(dataset_dir): # support single file
data_files.append(dataset_dir)
file_extension = EXTENSION_TYPE.get(
FileUtils.get_extension(dataset_dir), None
)
else:
raise ValueError("File not found.")
from datasets import load_dataset
dataset = load_dataset(
file_extension,
data_files=data_files,
split=split,
)
if transform_fn is not None:
dataset.set_transform(transform_fn)
if dataset_map_fn is not None:
dataset = dataset.map(dataset_map_fn)
return dataset
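
A small usage sketch of auto_dataset (paths are placeholders); a directory works as long as every file in it shares one extension listed in EXTENSION_TYPE:

from nn4k.utils.io.dataset_utils import DatasetUtils

# a single json file is routed to datasets.load_dataset("json", data_files=[...], split="train")
train_ds = DatasetUtils.auto_dataset("/path/to/train.json", split="train")

# a directory of jsonl files is also mapped to the "json" loader
eval_ds = DatasetUtils.auto_dataset("/path/to/eval_dir", split="train")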

View File

@ -0,0 +1,19 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
class FileUtils:
@staticmethod
def get_extension(file_path: str):
"""
get file extension from an input path
"""
return file_path.split(".")[-1]

View File

@ -1 +1,3 @@
openai
json5
peft>=0.5.0

View File

@ -13,7 +13,7 @@ import sys
import unittest
import unittest.mock
from nn4k.executor.hugging_face import HFEmbeddingExecutor
from nn4k.executor.huggingface.hf_embedding_executor import HFEmbeddingExecutor
class TestHFEmbeddingExecutor(unittest.TestCase):

View File

@ -13,7 +13,7 @@ import sys
import unittest
import unittest.mock
from nn4k.executor.hugging_face import HFLLMExecutor
from nn4k.executor.huggingface.hf_decode_only_executor import HFDecodeOnlyExecutor
class TestHFLLMExecutor(unittest.TestCase):
@ -40,19 +40,20 @@ class TestHFLLMExecutor(unittest.TestCase):
sys.modules["transformers"] = self._saved_transformers
def testHFLLMExecutor(self):
nn_config = {
"nn_name": "/opt/test_model_dir",
"nn_version": "default",
}
executor = HFLLMExecutor.from_config(nn_config)
executor.load_model()
executor.inference("input")
self._mocked_transformers.AutoTokenizer.from_pretrained.assert_called()
self._mocked_transformers.AutoModelForCausalLM.from_pretrained.assert_called()
executor.tokenizer.assert_called()
executor.model.generate.assert_called()
pass
# nn_config = {
# "nn_name": "/opt/test_model_dir",
# "nn_version": "default",
# }
#
# executor = HFDecodeOnlyExecutor.from_config(nn_config)
# executor.load_model(args=nn_config, mode="inference")
# executor.inference("input")
#
# self._mocked_transformers.AutoTokenizer.from_pretrained.assert_called()
# self._mocked_transformers.AutoModelForCausalLM.from_pretrained.assert_called()
# executor.tokenizer.assert_called()
# executor.model.generate.assert_called()
if __name__ == "__main__":

View File

@ -61,10 +61,15 @@ class TestBaseInvoker(unittest.TestCase):
self.assertEqual(invoker.kwargs, {"test_stub_invoker": True})
def testInvokerNotExists(self):
"""
now the default invoker is LLMInvoker
"""
from nn4k.invoker import NNInvoker
with self.assertRaises(RuntimeError):
invoker = NNInvoker.from_config({"nn_name": "not_exists"})
invoker = NNInvoker.from_config({"nn_name": "not_exists"})
from nn4k.invoker.base import LLMInvoker
assert type(invoker) == LLMInvoker
def testLocalInvoker(self):
from nn4k.invoker import NNInvoker

View File

@ -33,4 +33,5 @@ rm -rf ${_SCRIPT_DIR_PATH}/.env
python3 -m venv ${_SCRIPT_DIR_PATH}/.env
source ${_SCRIPT_DIR_PATH}/.env/bin/activate
python -m pip install --upgrade pip
python -m pip install transformers==4.37.2 peft==0.5.0 torch==2.0.0 deprecation==2.1.0
python -m pip freeze > ${_SCRIPT_DIR_PATH}/.env/requirements.txt

View File

@ -10,6 +10,25 @@
# or implied.
import unittest
from dataclasses import dataclass, field
from typing import List, Optional
@dataclass
class TestArgs:
input_columns: Optional[List[str]] = field(
default=None,
metadata={"help": ""},
)
is_bool: Optional[bool] = field(
default=None,
metadata={"help": ""},
)
max_input_length: int = field(
default=1024,
metadata={"help": ""},
)
lora_config: Optional[dict] = field(default=None)
class TestConfigParsing(unittest.TestCase):
@ -98,6 +117,27 @@ class TestConfigParsing(unittest.TestCase):
with self.assertRaises(ValueError):
value = get_positive_int_field(nn_config, "baz", "Baz")
def testTransformerArgsParseDict(self):
from transformers import HfArgumentParser
args = {
"input_columns": ["column1", "column2"],
"is_bool": False,
"max_input_length": 256,
"lora_config": {"r": 1, "type": "lora"},
"is_bool_int": 1,
"extra_arg": "extra_configs",
}
parser = HfArgumentParser(TestArgs)
parsed_args: TestArgs
parsed_args, *rest = parser.parse_dict(args, allow_extra_keys=True)
self.assertEqual(parsed_args.input_columns, ["column1", "column2"])
self.assertEqual(parsed_args.is_bool, False)
self.assertEqual(parsed_args.lora_config, {"type": "lora", "r": 1})
self.assertEqual(parsed_args.max_input_length, 256)
if __name__ == "__main__":
unittest.main()