mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-12-27 07:04:54 +00:00
Co-authored-by: xionghuaidong <huaidong.xhd@antgroup.com>
This commit is contained in:
parent
e95725d470
commit
eb2590aada
@ -32,6 +32,7 @@ header:
|
||||
- '**/*.schema'
|
||||
- '**/*.rule'
|
||||
- '**/*.json'
|
||||
- '**/*.json5'
|
||||
- '**/*.in'
|
||||
- '**/META-INF/services/*'
|
||||
- '**/*.conf'
|
||||
|
||||
@ -22,8 +22,7 @@ NN_INVOKER_TEXT = "NN invoker"
|
||||
NN_EXECUTOR_KEY = "nn_executor"
|
||||
NN_EXECUTOR_TEXT = "NN executor"
|
||||
|
||||
NN_DEVICE_KEY = "device"
|
||||
NN_TRUST_REMOTE_CODE_KEY = "trust_remote_code"
|
||||
NN_DEVICE_KEY = "nn_device"
|
||||
|
||||
NN_OPENAI_MODEL_NAME_KEY = NN_NAME_KEY
|
||||
NN_OPENAI_MODEL_NAME_TEXT = "openai model name"
|
||||
|
||||
@ -9,4 +9,4 @@
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
from nn4k.executor.base import NNExecutor, LLMExecutor
|
||||
from nn4k.executor.base import NNExecutor, LLMExecutor, NNModelArgs, NNAdapterModelArgs
|
||||
|
||||
@ -10,7 +10,8 @@
|
||||
# or implied.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class NNExecutor(ABC):
|
||||
@ -145,7 +146,23 @@ class NNExecutor(ABC):
|
||||
raise RuntimeError(message)
|
||||
|
||||
|
||||
class LLMExecutor(NNExecutor):
|
||||
class LLMExecutor(NNExecutor, ABC):
|
||||
"""
|
||||
Base Executor for LLM.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, nn_config: Union[str, dict]) -> "LLMExecutor":
|
||||
"""
|
||||
Implement distribution logic for LLM, since we only support Huggingface Decode Only models for now,
|
||||
it is directly point to HFDecodeOnlyExecutor. Will use the hub management functions later on.
|
||||
"""
|
||||
from nn4k.executor.huggingface.hf_decode_only_executor import (
|
||||
HFDecodeOnlyExecutor,
|
||||
)
|
||||
|
||||
return HFDecodeOnlyExecutor.from_config(nn_config)
|
||||
|
||||
def execute_sft(self, args=None, callbacks=None, **kwargs):
|
||||
"""
|
||||
The entry point of SFT execution in a certain pod.
|
||||
@ -159,3 +176,75 @@ class LLMExecutor(NNExecutor):
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not support RL-Tuning."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NNModelArgs:
|
||||
"""
|
||||
Base NN4K-supported model definition and load related args.
|
||||
"""
|
||||
|
||||
nn_name: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": ("NN4K model name")},
|
||||
)
|
||||
nn_version: Optional[str] = field(
|
||||
default="default",
|
||||
metadata={"help": ("NN4K model version, by default is 'default'")},
|
||||
)
|
||||
nn_model_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"model path dir, could be delivered by user or get managed in Hub."
|
||||
)
|
||||
},
|
||||
)
|
||||
nn_device: Optional[str] = field(
|
||||
default="auto", metadata={"help": ("device to use to load model")}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
assert (
|
||||
self.nn_name is not None or self.nn_model_path is not None
|
||||
), "either nn_name or nn_model_path has to be provided"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NNAdapterModelArgs(NNModelArgs):
|
||||
"""
|
||||
One should use this args dataclass to enable adapter models.
|
||||
"""
|
||||
|
||||
adapter_name: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "adapter name. Should be provided if you want to sft or load a adapter model."
|
||||
},
|
||||
)
|
||||
adapter_version: str = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "adapter is designed to get managed by versions, by default is 'latest'"
|
||||
},
|
||||
)
|
||||
adapter_type: str = field(
|
||||
default="lora", metadata={"help": "adapter type, lora by default."}
|
||||
)
|
||||
adapter_path: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "adapter weight and config path, could be delivered by user or get managed in Hub."
|
||||
},
|
||||
)
|
||||
adapter_config: Optional[dict] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Only necessary if you want to init a new adapter model and train from scratch or resume"
|
||||
"from a checkpoint (in this case, should be the same as the previous adapter_config)."
|
||||
"Values are the same as peft config init args."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
@ -1,152 +0,0 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
from typing import Union
|
||||
from nn4k.executor import LLMExecutor
|
||||
|
||||
|
||||
class HFLLMExecutor(LLMExecutor):
|
||||
@classmethod
|
||||
def from_config(cls, nn_config: dict) -> "HFLLMExecutor":
|
||||
"""
|
||||
Create an HFLLMExecutor instance from `nn_config`.
|
||||
"""
|
||||
executor = cls(nn_config)
|
||||
return executor
|
||||
|
||||
def execute_sft(self, args=None, callbacks=None, **kwargs):
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} will support SFT in the next version."
|
||||
)
|
||||
|
||||
def load_model(self, args=None, **kwargs):
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModelForCausalLM
|
||||
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
from nn4k.consts import NN_DEVICE_KEY, NN_TRUST_REMOTE_CODE_KEY
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
|
||||
nn_config: dict = args or self.init_args
|
||||
if self._model is None:
|
||||
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||
nn_version = nn_config.get(NN_VERSION_KEY)
|
||||
if nn_version is not None:
|
||||
nn_version = get_string_field(
|
||||
nn_config, NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
)
|
||||
model_path = nn_name
|
||||
revision = nn_version
|
||||
use_fast_tokenizer = False
|
||||
device = nn_config.get(NN_DEVICE_KEY)
|
||||
trust_remote_code = nn_config.get(NN_TRUST_REMOTE_CODE_KEY, False)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
use_fast=use_fast_tokenizer,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
model.to(device)
|
||||
self._tokenizer = tokenizer
|
||||
self._model = model
|
||||
|
||||
def inference(
|
||||
self,
|
||||
data,
|
||||
max_input_length: int = 1024,
|
||||
max_output_length: int = 1024,
|
||||
do_sample: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
model = self.model
|
||||
tokenizer = self.tokenizer
|
||||
input_ids = tokenizer(
|
||||
data,
|
||||
padding=True,
|
||||
return_token_type_ids=False,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=max_input_length,
|
||||
).to(model.device)
|
||||
output_ids = model.generate(
|
||||
**input_ids,
|
||||
max_new_tokens=max_output_length,
|
||||
do_sample=do_sample,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
outputs = [
|
||||
tokenizer.decode(
|
||||
output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
|
||||
)
|
||||
for idx, output_id in enumerate(output_ids)
|
||||
]
|
||||
return outputs
|
||||
|
||||
|
||||
class HFEmbeddingExecutor(LLMExecutor):
|
||||
@classmethod
|
||||
def from_config(cls, nn_config: dict) -> "HFEmbeddingExecutor":
|
||||
"""
|
||||
Create an HFEmbeddingExecutor instance from `nn_config`.
|
||||
"""
|
||||
executor = cls(nn_config)
|
||||
return executor
|
||||
|
||||
def load_model(self, args=None, **kwargs):
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
from nn4k.consts import NN_DEVICE_KEY
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
|
||||
nn_config: dict = args or self.init_args
|
||||
if self._model is None:
|
||||
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||
nn_version = nn_config.get(NN_VERSION_KEY)
|
||||
if nn_version is not None:
|
||||
nn_version = get_string_field(
|
||||
nn_config, NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
)
|
||||
model_path = nn_name
|
||||
revision = nn_version
|
||||
use_fast_tokenizer = False
|
||||
device = nn_config.get(NN_DEVICE_KEY)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
#
|
||||
# SentenceTransformer will support `revision` soon. See:
|
||||
#
|
||||
# https://github.com/UKPLab/sentence-transformers/pull/2419
|
||||
#
|
||||
model = SentenceTransformer(
|
||||
model_path,
|
||||
device=device,
|
||||
)
|
||||
self._model = model
|
||||
|
||||
def inference(self, data, args=None, **kwargs):
|
||||
model = self.model
|
||||
embeddings = model.encode(data)
|
||||
return embeddings
|
||||
13
python/nn4k/nn4k/executor/huggingface/__init__.py
Normal file
13
python/nn4k/nn4k/executor/huggingface/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
from nn4k.executor.huggingface.base.hf_llm_executor import HFLLMExecutor
|
||||
from nn4k.executor.huggingface.base.hf_args import HFModelArgs, HFSftArgs
|
||||
10
python/nn4k/nn4k/executor/huggingface/base/__init__.py
Normal file
10
python/nn4k/nn4k/executor/huggingface/base/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
107
python/nn4k/nn4k/executor/huggingface/base/hf_args.py
Normal file
107
python/nn4k/nn4k/executor/huggingface/base/hf_args.py
Normal file
@ -0,0 +1,107 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
from nn4k.executor import NNAdapterModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class HFModelArgs(NNAdapterModelArgs):
|
||||
"""
|
||||
Huggingface Model is designed to support adapter models such as lora, therefore should inherit from
|
||||
NNAdapterModelArgs dataclass
|
||||
"""
|
||||
|
||||
torch_dtype: Optional[str] = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": (
|
||||
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
||||
"dtype will be automatically derived from the model's weights."
|
||||
)
|
||||
},
|
||||
)
|
||||
qlora_bits_and_bytes_config: Optional[dict] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Quantization configs to load qlora, "
|
||||
"same as :class:`transformers.utils.quantization_config.BitsAndBytesConfig`"
|
||||
},
|
||||
)
|
||||
trust_remote_code: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether or not to allow for custom models defined on the Hub in their own modeling files."
|
||||
},
|
||||
)
|
||||
from_tf: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": " Load the model weights from a TensorFlow checkpoint save file, default to False"
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# for hf models, if model path has higher priority then name, since you don't need to download the model(or
|
||||
# from cache) again.
|
||||
self.pretrained_model_name_or_path = self.nn_model_path or self.nn_name
|
||||
|
||||
|
||||
@dataclass
|
||||
class HFSftArgs(HFModelArgs, TrainingArguments):
|
||||
"""
|
||||
args to use for huggingface model sft task
|
||||
"""
|
||||
|
||||
train_dataset_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Should not be None. A file or dir path to train dataset, If a dir path, "
|
||||
"all files inside should have the same file extension."
|
||||
},
|
||||
)
|
||||
eval_dataset_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "A file or dir path to eval dataset. If a dir path, all files inside should have the same "
|
||||
"file extension. If set, do_eval flag will be set to True"
|
||||
},
|
||||
)
|
||||
max_input_length: int = field(
|
||||
default=1024,
|
||||
metadata={"help": "max length of input"},
|
||||
)
|
||||
resume_from_checkpoint: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The path to a folder with a valid checkpoint for your model."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
HFModelArgs.__post_init__(self)
|
||||
TrainingArguments.__post_init__(self)
|
||||
assert self.train_dataset_path is not None, "train_dataset_path must be set."
|
||||
if self.train_dataset_path and not self.do_train:
|
||||
self.do_train = True
|
||||
print(
|
||||
f"a train_dataset_path is set but do_train flag is not set, automatically set do_train to True"
|
||||
)
|
||||
if self.eval_dataset_path and not self.do_eval:
|
||||
self.do_eval = True
|
||||
print(
|
||||
f"a eval_dataset_path is set but do_eval flag is not set, automatically set do_eval to True"
|
||||
)
|
||||
328
python/nn4k/nn4k/executor/huggingface/base/hf_llm_executor.py
Normal file
328
python/nn4k/nn4k/executor/huggingface/base/hf_llm_executor.py
Normal file
@ -0,0 +1,328 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
import os
|
||||
import typing
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoConfig, AutoTokenizer, Trainer
|
||||
|
||||
from nn4k.executor import LLMExecutor
|
||||
from .hf_args import HFSftArgs, HFModelArgs
|
||||
from nn4k.executor.huggingface.nn_hf_trainer import NNHFTrainer
|
||||
|
||||
|
||||
class HFLLMExecutor(LLMExecutor):
|
||||
"""
|
||||
Base Executor for huggingface models.
|
||||
"""
|
||||
|
||||
def __init__(self, init_args: dict, **kwargs):
|
||||
super().__init__(init_args=init_args, **kwargs)
|
||||
# model_model could be either 'train' or 'inference' or model load
|
||||
self.model_mode = None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, nn_config: Union[dict]) -> "HFLLMExecutor":
|
||||
"""
|
||||
Create an HFLLMExecutor instance from `nn_config`.
|
||||
"""
|
||||
executor = cls(nn_config)
|
||||
return executor
|
||||
|
||||
def execute_sft(self, args: dict = None, callbacks=None, **kwargs):
|
||||
args = args or self.init_args
|
||||
|
||||
self.load_model(args=args, mode="train")
|
||||
|
||||
# parse args into HFSftArgs dataclass for more convenient features
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
parser = HfArgumentParser(HFSftArgs)
|
||||
hf_sft_args: HFSftArgs
|
||||
hf_sft_args, *_ = parser.parse_dict(args, allow_extra_keys=True)
|
||||
|
||||
# load checkpoint path if necessary.
|
||||
resume_from_checkpoint_path = self._get_last_checkpoint(hf_sft_args)
|
||||
|
||||
# load and map dataset
|
||||
train_dataset, eval_dataset = self._init_dataset(hf_sft_args)
|
||||
|
||||
# init trainer
|
||||
trainer: Trainer = self._init_trainer(
|
||||
train_dataset, eval_dataset, hf_sft_args, callbacks
|
||||
)
|
||||
|
||||
# start training
|
||||
train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint_path)
|
||||
|
||||
# save trained model after train complete
|
||||
trainer.save_model(hf_sft_args.output_dir)
|
||||
|
||||
# save train metrics
|
||||
train_metrics = train_result.metrics
|
||||
train_metrics["train_samples_len"] = len(train_dataset)
|
||||
trainer.log_metrics("train", train_metrics)
|
||||
trainer.save_metrics("train", train_metrics)
|
||||
trainer.save_state()
|
||||
|
||||
return self
|
||||
|
||||
def _get_last_checkpoint(self, sft_args: HFSftArgs) -> Optional[str]: # noqa
|
||||
"""
|
||||
try to find checkpoint in sft_args.output_dir.
|
||||
If sft_args.resume_from_checkpoint in ['True', 'true', True, ''], try to return the checkpoint dir with the
|
||||
largest checkpoint index. The largest checkpoint dir path will be returned.
|
||||
If sft_args.resume_from_checkpoint in [None, 'False', 'false', False], means not necessary to resume from
|
||||
checkpoint, None will be returned.
|
||||
If sft_args.resume_from_checkpoint is the checkpoint subfolder dir name, the 'output_dir/resume_from_checkpoint'
|
||||
path will be returned if exists. Be aware, if the dir does not exist, ValueError will be raised.
|
||||
"""
|
||||
output_dir_contains_file = (
|
||||
os.path.isdir(sft_args.output_dir)
|
||||
and len(os.listdir(sft_args.output_dir)) > 0
|
||||
)
|
||||
|
||||
if sft_args.resume_from_checkpoint in ["True", "true", True, ""]:
|
||||
resume_from_checkpoint_bool = True
|
||||
if output_dir_contains_file:
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
resume_from_checkpoint_path = get_last_checkpoint(sft_args.output_dir)
|
||||
else:
|
||||
resume_from_checkpoint_path = None
|
||||
assert (
|
||||
resume_from_checkpoint_path is not None
|
||||
), f"cannot find last checkpoint dir in {sft_args.output_dir}"
|
||||
elif sft_args.resume_from_checkpoint in [None, "False", "false", False]:
|
||||
resume_from_checkpoint_bool = False
|
||||
resume_from_checkpoint_path = None
|
||||
else:
|
||||
resume_from_checkpoint_bool = True
|
||||
resume_from_checkpoint_path = os.path.join(
|
||||
sft_args.output_dir, sft_args.resume_from_checkpoint
|
||||
)
|
||||
assert os.path.isdir(
|
||||
resume_from_checkpoint_path
|
||||
), f"{resume_from_checkpoint_path} is not a dir."
|
||||
|
||||
if (
|
||||
output_dir_contains_file
|
||||
and not sft_args.overwrite_output_dir
|
||||
and not resume_from_checkpoint_bool
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output_dir ({sft_args.output_dir}) is not empty. Maybe you mean --resume_from_checkpoint"
|
||||
'="True" to resume a training or --overwrite_output_dir to overwrite output_dir.'
|
||||
)
|
||||
|
||||
return resume_from_checkpoint_path
|
||||
|
||||
def map_fn(self, dataset, **kwargs):
|
||||
"""
|
||||
dataset map and template function. The default implement follows the BelleGroup/train_0.5M_CN format, means
|
||||
'instruction', 'input' and 'output' are necessary. Since some other popular dataset like tatsu-lab/alpaca
|
||||
provides these columns as well, it is also supported.
|
||||
"""
|
||||
args: HFSftArgs = kwargs.get("args", None)
|
||||
instruction = dataset["instruction"]
|
||||
input_text = dataset["input"]
|
||||
output_text = dataset["output"]
|
||||
bos_token = self.tokenizer.bos_token or ""
|
||||
eos_token = self.tokenizer.eos_token
|
||||
input_prompt = f"{bos_token}{instruction} {input_text}{eos_token}"
|
||||
tokenized_full_prompt = self._tokenize_dataset(
|
||||
input_prompt, args.max_input_length
|
||||
)
|
||||
return tokenized_full_prompt
|
||||
|
||||
def _init_dataset(
|
||||
self, args: HFSftArgs
|
||||
) -> typing.Tuple[Union[Dataset], Union[Dataset]]: # noqa
|
||||
"""
|
||||
init and map dataset, for train and eval
|
||||
"""
|
||||
with args.main_process_first(desc="initialize dataset"):
|
||||
train_dataset = None
|
||||
if args.train_dataset_path:
|
||||
train_dataset = (
|
||||
self._load_dataset(args.train_dataset_path, "train")
|
||||
.shuffle()
|
||||
.map(self.map_fn, fn_kwargs={"args": args})
|
||||
)
|
||||
|
||||
eval_dataset = None
|
||||
if args.eval_dataset_path:
|
||||
eval_dataset = (
|
||||
self._load_dataset(args.eval_dataset_path, "train")
|
||||
.shuffle()
|
||||
.map(self.map_fn, fn_kwargs={"args": args})
|
||||
)
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
def _load_dataset(self, data_path, split="train"): # noqa
|
||||
from nn4k.utils.io.dataset_utils import DatasetUtils
|
||||
|
||||
return DatasetUtils.auto_dataset(data_path, split)
|
||||
|
||||
def load_model(self, args: dict = None, mode=None, **kwargs):
|
||||
"""
|
||||
load model and tokenizer. If the model with the same mode is already loaded, will not load again.
|
||||
"""
|
||||
|
||||
assert (
|
||||
mode is not None
|
||||
), f"mode should be either 'train' or 'inference' for HFLLMExecutor, {mode} is illegal."
|
||||
|
||||
if self.model_mode == mode and self._model is not None:
|
||||
return
|
||||
|
||||
from transformers import HfArgumentParser
|
||||
from nn4k.executor.huggingface import HFModelArgs
|
||||
|
||||
parser = HfArgumentParser(HFModelArgs)
|
||||
hf_model_args, *_ = parser.parse_dict(args, allow_extra_keys=True)
|
||||
|
||||
self.model_mode = mode
|
||||
self._tokenizer = self._hf_tokenizer_loader(hf_model_args)
|
||||
self._model = self._hf_model_loader(
|
||||
hf_model_args, mode, hf_model_args.nn_device
|
||||
)
|
||||
|
||||
if self.tokenizer.eos_token_id is None:
|
||||
self.tokenizer.eos_token_id = self.model.config.eos_token_id
|
||||
if self.tokenizer.pad_token_id is None:
|
||||
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
||||
|
||||
def inference(
|
||||
self,
|
||||
data,
|
||||
max_input_length: int = 1024,
|
||||
max_output_length: int = 1024,
|
||||
do_sample: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
model = self.model
|
||||
tokenizer = self.tokenizer
|
||||
input_ids = tokenizer(
|
||||
data,
|
||||
padding=True,
|
||||
return_token_type_ids=False,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=max_input_length,
|
||||
).to(model.device)
|
||||
output_ids = model.generate(
|
||||
**input_ids,
|
||||
max_new_tokens=max_output_length,
|
||||
do_sample=do_sample,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
outputs = [
|
||||
tokenizer.decode(
|
||||
output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
|
||||
)
|
||||
for idx, output_id in enumerate(output_ids)
|
||||
]
|
||||
return outputs
|
||||
|
||||
@abstractmethod
|
||||
def _hf_model_loader(
|
||||
self,
|
||||
args: HFModelArgs,
|
||||
mode,
|
||||
resume_from_checkpoint=False,
|
||||
device=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
load model into given device for hugging face.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _hf_tokenizer_loader(self, args: HFModelArgs, **kwargs): # noqa
|
||||
"""
|
||||
hugging face tokenizer loader
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path=args.pretrained_model_name_or_path,
|
||||
use_fast=False,
|
||||
revision=args.nn_version,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
def _hf_model_config_loader(self, args: HFModelArgs, **kwargs): # noqa
|
||||
"""
|
||||
hugging face model config loader
|
||||
"""
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
return model_config
|
||||
|
||||
def _init_trainer(
|
||||
self, train_dataset, eval_dataset, sft_args: HFSftArgs, callbacks=None
|
||||
) -> Trainer:
|
||||
"""
|
||||
hugging face model trainer initializer
|
||||
"""
|
||||
trainer = NNHFTrainer(
|
||||
model=self.model,
|
||||
args=sft_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=self.tokenizer,
|
||||
data_collator=self._data_collator(),
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
@abstractmethod
|
||||
def _data_collator(self, return_tensors="pt", **kwargs):
|
||||
"""
|
||||
data collator used in trainer
|
||||
"""
|
||||
pass
|
||||
|
||||
def _tokenize_dataset(self, prompt_text, max_length):
|
||||
"""
|
||||
tokenize dataset, by default will cut the input to the max_length
|
||||
"""
|
||||
tokenized_dataset = self.tokenizer(
|
||||
prompt_text, truncation=True, max_length=max_length
|
||||
)
|
||||
input_ids = tokenized_dataset["input_ids"]
|
||||
attention_mask = tokenized_dataset["attention_mask"]
|
||||
|
||||
# append eos token if necessary
|
||||
# input length is shorter than max_length
|
||||
if len(input_ids) < max_length:
|
||||
if input_ids[-1] != self.tokenizer.eos_token_id:
|
||||
input_ids.append(self.tokenizer.eos_token_id)
|
||||
attention_mask.append(1)
|
||||
else:
|
||||
input_ids[max_length - 1] = self.tokenizer.eos_token_id
|
||||
attention_mask[max_length - 1] = 1
|
||||
|
||||
# labels are copy of input_ids
|
||||
tokenized_dataset["labels"] = tokenized_dataset["input_ids"].copy()
|
||||
|
||||
return tokenized_dataset
|
||||
@ -0,0 +1,10 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
@ -0,0 +1,10 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
@ -0,0 +1,24 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
# The scripts are tested by the following package installed
|
||||
export WANDB_DISABLED=true
|
||||
|
||||
#Only if you have a cuda OOM, try this setting
|
||||
#export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32
|
||||
|
||||
pip install peft==0.5.0
|
||||
pip install json5 # only necessary if you use json5 file as a config file
|
||||
pip install numpy==1.23.1
|
||||
pip install transformers==4.36.2
|
||||
pip install accelerate>=0.21.0
|
||||
pip install bitsandbytes>=0.39.0 #only necessary if you use qlora
|
||||
#pip install xformers==0.0.23.post1 # only necessary if you want to accelerate loading model in memery efficient way
|
||||
@ -0,0 +1,41 @@
|
||||
{
|
||||
// -- base model and training args
|
||||
"nn_model_path": "/model/path/to/Baichuan-7B-Chat", // local model path
|
||||
"train_dataset_path": "/data/train/dataset.json", // train dataset path
|
||||
"nn_invoker": "nn4k.invoker.base.LLMInvoker", // invoker to use
|
||||
"nn_executor": "nn4k.executor.huggingface.hf_decode_only_executor.HFDecodeOnlyExecutor", // executor to use
|
||||
"output_dir": "/path/to/output/dir", // trained model output dir
|
||||
// ----- The following args are optional-----
|
||||
|
||||
// "eval_dataset_path": "/data/eval/dataset-eval.json", // eval dataset path, if you want to do eval
|
||||
// -- adapter model info, only if you want to train lora adapter
|
||||
// "adapter_name": "YouYou", //set it to a not "default" string value to enable adapter sft
|
||||
// "adapter_type": "lora", // adapter type. Don't need it if adapter_name is not set
|
||||
// "adapter_config": { // only necessary if adapter_name is set, same as peft LoraConfig args if tyep is 'lora'
|
||||
// "r": 8,
|
||||
// "lora_alpha": 16,
|
||||
// "lora_dropout": 0.05,
|
||||
// "bias": "none",
|
||||
// "target_modules": ["W_pack", "o_proj"], // this is only an example for BaiChuan lora training
|
||||
// "task_type": "CAUSAL_LM"
|
||||
// },
|
||||
// "qlora_bits_and_bytes_config": { // only necessary if you want to quantinize load model
|
||||
// "load_in_4bit": true,
|
||||
// "bnb_4bit_compute_dtype": "bfloat16",
|
||||
// "bnb_4bit_use_double_quant": true,
|
||||
// "bnb_4bit_quant_type": "nf4"
|
||||
// }
|
||||
//-- start training args
|
||||
// "resume_from_checkpoint": "True", // only necessary if you want to resume training from checkpoint
|
||||
"trust_remote_code": true,
|
||||
"max_input_length": 256, // input max length. Inputs will be cut down to this length
|
||||
//-- start: same as huggingface trainer args
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"lr_scheduler_type": "cosine", // adjust learning rate scheduler
|
||||
"logging_steps": 20,
|
||||
"save_steps": 10000,
|
||||
"learning_rate": 4e-5,
|
||||
"num_train_epochs": 1.0
|
||||
//-- end: huggingface trainer args
|
||||
}
|
||||
@ -0,0 +1,22 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
from nn4k.invoker.base import NNInvoker
|
||||
|
||||
|
||||
def main():
|
||||
NNInvoker.from_config("local_sft.json5").local_sft()
|
||||
# Inference example, not implemented yet.
|
||||
# NNInvoker.from_config("inferece_args.json").local_inference("你是谁")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
117
python/nn4k/nn4k/executor/huggingface/hf_decode_only_executor.py
Normal file
117
python/nn4k/nn4k/executor/huggingface/hf_decode_only_executor.py
Normal file
@ -0,0 +1,117 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from nn4k.executor.huggingface import HFModelArgs
|
||||
from nn4k.executor.huggingface import HFLLMExecutor
|
||||
|
||||
|
||||
class HFDecodeOnlyExecutor(HFLLMExecutor):
|
||||
"""
|
||||
Huggingface decode only default executor, will use AutoModelForCausalLM to load model and
|
||||
DataCollatorForSeq2Seq as a default data collator.
|
||||
"""
|
||||
|
||||
def _hf_model_loader(
|
||||
self,
|
||||
args: HFModelArgs,
|
||||
mode,
|
||||
resume_from_checkpoint=False,
|
||||
device=None,
|
||||
**kwargs,
|
||||
):
|
||||
if device is None or "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# load base model
|
||||
model_config = self._hf_model_config_loader(args, **kwargs)
|
||||
|
||||
quant_config = None
|
||||
if args.adapter_name and args.qlora_bits_and_bytes_config:
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
quant_config = BitsAndBytesConfig(**args.qlora_bits_and_bytes_config)
|
||||
|
||||
model_load_args = dict(
|
||||
pretrained_model_name_or_path=args.pretrained_model_name_or_path,
|
||||
config=model_config,
|
||||
quantization_config=quant_config,
|
||||
revision=args.nn_version,
|
||||
torch_dtype=args.torch_dtype,
|
||||
from_tf=args.from_tf,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(**model_load_args)
|
||||
|
||||
if quant_config:
|
||||
from peft import prepare_model_for_kbit_training
|
||||
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
# load adapter model
|
||||
if args.adapter_name:
|
||||
# provide an adapter_path, means one can load an exist lora adapter and start a new train based on that.
|
||||
if args.adapter_path and not resume_from_checkpoint:
|
||||
from peft import PeftModel
|
||||
|
||||
# TODO NN4K: Notice: NN4K plan to provide a hub-managed adapter implementation in the near future.
|
||||
model = PeftModel.from_pretrained(
|
||||
model=model,
|
||||
model_id=args.adapter_path,
|
||||
adapter_name=args.adapter_name,
|
||||
adapter_version=args.adapter_version,
|
||||
is_trainable=(mode == "train"),
|
||||
)
|
||||
elif (
|
||||
args.adapter_config
|
||||
): # no adapter_path but adapter_config means train an adapter from scratch
|
||||
from peft import get_peft_model
|
||||
from peft import LoraConfig
|
||||
|
||||
if args.adapter_type in ["lora", "qlora"]:
|
||||
peft_config = LoraConfig(**args.adapter_config)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"adapter_type {args.adapter_type} is not supported in "
|
||||
f"hf_decode_only_executor use lora or qlora instead"
|
||||
)
|
||||
model = get_peft_model(
|
||||
model=model,
|
||||
peft_config=peft_config,
|
||||
adapter_name=args.adapter_name,
|
||||
# TODO NN4K: NN4K plan to provide a hub-managed adapter implementation in the
|
||||
# near future. adapter_version=args.adapter_version,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You should either provide a adapter_path to load an existing adapter without resume"
|
||||
"a training, or provide a adapter_config to train a adapter from scratch or resume a "
|
||||
"adapter training from checkpoint."
|
||||
)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
if mode == "inference":
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
return model
|
||||
|
||||
def _data_collator(self, return_tensors="pt", **kwargs):
|
||||
return transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
pad_to_multiple_of=8,
|
||||
return_tensors=return_tensors,
|
||||
padding=True,
|
||||
)
|
||||
@ -0,0 +1,59 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
from nn4k.executor import LLMExecutor
|
||||
|
||||
|
||||
class HFEmbeddingExecutor(LLMExecutor):
|
||||
@classmethod
|
||||
def from_config(cls, nn_config: dict) -> "HFEmbeddingExecutor":
|
||||
"""
|
||||
Create an HFEmbeddingExecutor instance from `nn_config`.
|
||||
"""
|
||||
executor = cls(nn_config)
|
||||
return executor
|
||||
|
||||
def load_model(self, args=None, **kwargs):
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
from nn4k.consts import NN_DEVICE_KEY
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
|
||||
nn_config: dict = args or self.init_args
|
||||
if self._model is None:
|
||||
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||
nn_version = nn_config.get(NN_VERSION_KEY)
|
||||
if nn_version is not None:
|
||||
nn_version = get_string_field(
|
||||
nn_config, NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
)
|
||||
model_path = nn_name
|
||||
revision = nn_version
|
||||
use_fast_tokenizer = False
|
||||
device = nn_config.get(NN_DEVICE_KEY)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
#
|
||||
# SentenceTransformer will support `revision` soon. See:
|
||||
#
|
||||
# https://github.com/UKPLab/sentence-transformers/pull/2419
|
||||
#
|
||||
model = SentenceTransformer(
|
||||
model_path,
|
||||
device=device,
|
||||
)
|
||||
self._model = model
|
||||
|
||||
def inference(self, data, args=None, **kwargs):
|
||||
model = self.model
|
||||
embeddings = model.encode(data)
|
||||
return embeddings
|
||||
204
python/nn4k/nn4k/executor/huggingface/nn_hf_trainer.py
Normal file
204
python/nn4k/nn4k/executor/huggingface/nn_hf_trainer.py
Normal file
@ -0,0 +1,204 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
import os
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
||||
from packaging import version
|
||||
|
||||
from peft import PeftModel
|
||||
from transformers import PretrainedConfig, Trainer, __version__
|
||||
from transformers.integrations import is_deepspeed_available
|
||||
from transformers.modeling_utils import load_sharded_checkpoint
|
||||
from transformers.trainer import logger
|
||||
from transformers.utils import (
|
||||
ADAPTER_SAFE_WEIGHTS_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
CONFIG_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_sagemaker_mp_enabled,
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import Accelerator, skip_first_batches
|
||||
from accelerate import __version__ as accelerate_version
|
||||
from accelerate.utils import (
|
||||
DistributedDataParallelKwargs,
|
||||
GradientAccumulationPlugin,
|
||||
load_fsdp_model,
|
||||
load_fsdp_optimizer,
|
||||
save_fsdp_model,
|
||||
save_fsdp_optimizer,
|
||||
)
|
||||
|
||||
DATA_SAMPLERS = [RandomSampler]
|
||||
if version.parse(accelerate_version) > version.parse("0.23.0"):
|
||||
from accelerate.data_loader import SeedableRandomSampler
|
||||
|
||||
DATA_SAMPLERS += [SeedableRandomSampler]
|
||||
|
||||
if is_deepspeed_available():
|
||||
from accelerate.utils import DeepSpeedSchedulerWrapper
|
||||
|
||||
|
||||
class NNHFTrainer(Trainer):
|
||||
"""
|
||||
only trying to fix resume checkpoint for lora adapter, will be replaced by using Trainer when the bug is
|
||||
fixed in huggingface trainer. The PR is offered: https://github.com/huggingface/transformers/pull/28547
|
||||
"""
|
||||
|
||||
def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
|
||||
# the following code only trying to fix resuming checkpoint for adapter model(Peft)
|
||||
if model is None:
|
||||
model = self.model
|
||||
|
||||
if not (is_peft_available() and isinstance(model, PeftModel)):
|
||||
return super()._load_from_checkpoint(resume_from_checkpoint, model)
|
||||
|
||||
adapter_name_path = ""
|
||||
if isinstance(model, PeftModel):
|
||||
adapter_name_path = (
|
||||
model.active_adapter
|
||||
if model.active_adapter not in ["default", None]
|
||||
else ""
|
||||
)
|
||||
|
||||
config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
|
||||
adapter_weights_file = os.path.join(
|
||||
resume_from_checkpoint, adapter_name_path, ADAPTER_WEIGHTS_NAME
|
||||
)
|
||||
adapter_safe_weights_file = os.path.join(
|
||||
resume_from_checkpoint, adapter_name_path, ADAPTER_SAFE_WEIGHTS_NAME
|
||||
)
|
||||
weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
|
||||
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
|
||||
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
|
||||
safe_weights_index_file = os.path.join(
|
||||
resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME
|
||||
)
|
||||
|
||||
if not any(
|
||||
os.path.isfile(f)
|
||||
for f in [
|
||||
weights_file,
|
||||
safe_weights_file,
|
||||
weights_index_file,
|
||||
safe_weights_index_file,
|
||||
os.path.join(adapter_weights_file),
|
||||
os.path.join(adapter_safe_weights_file),
|
||||
]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Can't find a valid checkpoint at {resume_from_checkpoint}"
|
||||
)
|
||||
|
||||
logger.info(f"Loading model from {resume_from_checkpoint}.")
|
||||
|
||||
if os.path.isfile(config_file):
|
||||
config = PretrainedConfig.from_json_file(config_file)
|
||||
checkpoint_version = config.transformers_version
|
||||
if checkpoint_version is not None and checkpoint_version != __version__:
|
||||
logger.warning(
|
||||
f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
|
||||
f"Transformers but your current version is {__version__}. This is not recommended and could "
|
||||
"yield to errors or unwanted behaviors."
|
||||
)
|
||||
|
||||
if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
|
||||
# If the model is on the GPU, it still works!
|
||||
if is_sagemaker_mp_enabled():
|
||||
if os.path.isfile(
|
||||
os.path.join(resume_from_checkpoint, "user_content.pt")
|
||||
):
|
||||
# If the 'user_content.pt' file exists, load with the new smp api.
|
||||
# Checkpoint must have been saved with the new smp api.
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
smp.resume_from_checkpoint(
|
||||
path=resume_from_checkpoint,
|
||||
tag=WEIGHTS_NAME,
|
||||
partial=False,
|
||||
load_optimizer=False,
|
||||
)
|
||||
else:
|
||||
# If the 'user_content.pt' file does NOT exist, load with the old smp api.
|
||||
# Checkpoint must have been saved with the old smp api.
|
||||
if hasattr(self.args, "fp16") and self.args.fp16 is True:
|
||||
logger.warning(
|
||||
"Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
|
||||
)
|
||||
state_dict = torch.load(weights_file, map_location="cpu")
|
||||
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
|
||||
state_dict["_smp_is_partial"] = False
|
||||
load_result = model.load_state_dict(state_dict, strict=True)
|
||||
# release memory
|
||||
del state_dict
|
||||
elif self.is_fsdp_enabled:
|
||||
load_fsdp_model(
|
||||
self.accelerator.state.fsdp_plugin,
|
||||
self.accelerator,
|
||||
model,
|
||||
resume_from_checkpoint,
|
||||
)
|
||||
else:
|
||||
# We load the model state dict on the CPU to avoid an OOM error.
|
||||
if self.args.save_safetensors and os.path.isfile(safe_weights_file):
|
||||
state_dict = safetensors.torch.load_file(
|
||||
safe_weights_file, device="cpu"
|
||||
)
|
||||
else:
|
||||
state_dict = torch.load(weights_file, map_location="cpu")
|
||||
|
||||
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
|
||||
# which takes *args instead of **kwargs
|
||||
load_result = model.load_state_dict(state_dict, False)
|
||||
# release memory
|
||||
del state_dict
|
||||
self._issue_warnings_after_load(load_result)
|
||||
|
||||
# Load adapters following PR # 24096
|
||||
elif is_peft_available() and isinstance(model, PeftModel):
|
||||
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
|
||||
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
|
||||
adapter_model_path = os.path.join(
|
||||
resume_from_checkpoint, adapter_name_path
|
||||
)
|
||||
if os.path.exists(adapter_model_path):
|
||||
model.load_adapter(
|
||||
adapter_model_path, model.active_adapter, is_trainable=True
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"The intermediate checkpoints of PEFT may not be saved correctly, "
|
||||
f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
|
||||
"Check some examples here: https://github.com/huggingface/peft/issues/96"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Could not load adapter model, make sure to have `peft>=0.3.0` installed"
|
||||
)
|
||||
else:
|
||||
# We load the sharded checkpoint
|
||||
load_result = load_sharded_checkpoint(
|
||||
model,
|
||||
resume_from_checkpoint,
|
||||
strict=is_sagemaker_mp_enabled(),
|
||||
prefer_safe=self.args.save_safetensors,
|
||||
)
|
||||
if not is_sagemaker_mp_enabled():
|
||||
self._issue_warnings_after_load(load_result)
|
||||
@ -9,6 +9,7 @@
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
import copy
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
@ -19,6 +20,7 @@ from nn4k.executor import NNExecutor
|
||||
class SubmitMode(Enum):
|
||||
K8s = "k8s"
|
||||
Docker = "docker"
|
||||
Local = "local"
|
||||
|
||||
|
||||
class NNInvoker(ABC):
|
||||
@ -138,6 +140,15 @@ class LLMInvoker(NNInvoker):
|
||||
"""
|
||||
raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")
|
||||
|
||||
def local_sft(self, args: dict = None):
|
||||
sft_args = copy.deepcopy(self.init_args)
|
||||
args = args or {}
|
||||
sft_args.update(args)
|
||||
|
||||
from nn4k.executor import LLMExecutor
|
||||
|
||||
LLMExecutor.from_config(sft_args).execute_sft()
|
||||
|
||||
def submit_rl_tuning(self, submit_mode: SubmitMode = SubmitMode.K8s):
|
||||
"""
|
||||
Submit remote RL-Tuning execution.
|
||||
@ -187,7 +198,7 @@ class LLMInvoker(NNInvoker):
|
||||
message += "is not found in the model hub"
|
||||
raise RuntimeError(message)
|
||||
self._nn_executor: NNExecutor = executor
|
||||
self._nn_executor.load_model()
|
||||
self._nn_executor.load_model(mode="inference")
|
||||
self._nn_executor.warmup_inference()
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -15,6 +15,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union, Tuple, Type
|
||||
|
||||
from nn4k.executor import NNExecutor
|
||||
from nn4k.utils.class_importing import dynamic_import_class
|
||||
|
||||
|
||||
class NNHub(ABC):
|
||||
@ -146,8 +147,8 @@ class SimpleNNHub(NNHub):
|
||||
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
from nn4k.consts import NN_LOCAL_HF_MODEL_CONFIG_FILE
|
||||
from nn4k.consts import NN_LOCAL_SENTENCE_TRANSFORMERS_CONFIG_FILE
|
||||
from nn4k.executor.hugging_face import HFLLMExecutor
|
||||
from nn4k.executor.hugging_face import HFEmbeddingExecutor
|
||||
from nn4k.executor.huggingface.hf_embedding_executor import HFEmbeddingExecutor
|
||||
from nn4k.executor.huggingface.base.hf_llm_executor import HFLLMExecutor
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
|
||||
nn_executor = nn_config.get(NN_EXECUTOR_KEY)
|
||||
@ -186,32 +187,19 @@ class SimpleNNHub(NNHub):
|
||||
message += ", version: %r" % nn_version
|
||||
raise RuntimeError(message)
|
||||
|
||||
def _add_local_executor(self, nn_config: dict):
|
||||
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
|
||||
executor_class = self._get_local_executor_class(nn_config)
|
||||
executor = executor_class.from_config(nn_config)
|
||||
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||
nn_version = nn_config.get(NN_VERSION_KEY)
|
||||
if nn_version is not None:
|
||||
nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
|
||||
self.publish(executor, nn_name, nn_version)
|
||||
|
||||
def get_invoker(self, nn_config: dict) -> Optional["NNInvoker"]:
|
||||
from nn4k.invoker import LLMInvoker
|
||||
from nn4k.invoker.openai import OpenAIInvoker
|
||||
from nn4k.utils.invoker_checking import is_openai_invoker
|
||||
from nn4k.utils.invoker_checking import is_local_invoker
|
||||
|
||||
if is_openai_invoker(nn_config):
|
||||
invoker = OpenAIInvoker.from_config(nn_config)
|
||||
return invoker
|
||||
|
||||
if is_local_invoker(nn_config):
|
||||
# TODO NN4K: this will be replaced once we publish the SimpleHub solution. Now we only have openai invoker
|
||||
# and LLMInvoker
|
||||
# if is_local_invoker(nn_config):
|
||||
else:
|
||||
invoker = LLMInvoker.from_config(nn_config)
|
||||
self._add_local_executor(nn_config)
|
||||
return invoker
|
||||
|
||||
return None
|
||||
# return None
|
||||
|
||||
@ -9,10 +9,8 @@
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
import json
|
||||
|
||||
from typing import Any
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
|
||||
def preprocess_config(nn_config: Union[str, dict]) -> dict:
|
||||
@ -21,22 +19,45 @@ def preprocess_config(nn_config: Union[str, dict]) -> dict:
|
||||
|
||||
* If `nn_config` is already a dictionary, return it as is.
|
||||
|
||||
* If `nn_config` is a string, decode it as a JSON file.
|
||||
* If `nn_config` is a string, decode it as a JSON or JSON5 file.
|
||||
|
||||
:param nn_config: config to be preprocessed
|
||||
:type nn_config: str or dict
|
||||
:return: `nn_config` or `nn_config` decoded as JSON
|
||||
:return: `nn_config` or `nn_config` decoded as JSON or JSON5
|
||||
:rtype: dict
|
||||
:raises ValueError: if cannot decode config file specified by
|
||||
`nn_config` as JSON
|
||||
`nn_config` as JSON or JSON5
|
||||
"""
|
||||
try:
|
||||
if isinstance(nn_config, str):
|
||||
with open(nn_config, "r") as f:
|
||||
nn_config = json.load(f)
|
||||
if isinstance(nn_config, dict):
|
||||
return nn_config
|
||||
elif isinstance(nn_config, str):
|
||||
if nn_config.endswith(".json"):
|
||||
import json
|
||||
|
||||
with open(Path(nn_config), "r", encoding="utf-8") as open_json_file:
|
||||
data = json.load(open_json_file)
|
||||
nn_config = data
|
||||
return nn_config
|
||||
if nn_config.endswith(".json5"):
|
||||
import json5
|
||||
|
||||
with open(Path(nn_config), "r", encoding="utf-8") as open_json5_file:
|
||||
data = json5.load(open_json5_file)
|
||||
nn_config = data
|
||||
return nn_config
|
||||
from nn4k.utils.io.file_utils import FileUtils
|
||||
|
||||
raise ValueError(
|
||||
f"Config file with extension type {FileUtils.get_extension(nn_config)} is not supported."
|
||||
f"use json or json5 instead."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"nn_config could be dict or str, {type(nn_config)} is not yet supported."
|
||||
)
|
||||
except:
|
||||
raise ValueError("cannot decode config file")
|
||||
return nn_config
|
||||
|
||||
|
||||
def get_field(nn_config: dict, name: str, text: str) -> Any:
|
||||
|
||||
10
python/nn4k/nn4k/utils/io/__init__.py
Normal file
10
python/nn4k/nn4k/utils/io/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
66
python/nn4k/nn4k/utils/io/dataset_utils.py
Normal file
66
python/nn4k/nn4k/utils/io/dataset_utils.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from nn4k.utils.io.file_utils import FileUtils
|
||||
|
||||
EXTENSION_TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"}
|
||||
|
||||
|
||||
class DatasetUtils:
|
||||
@staticmethod
|
||||
def auto_dataset(input_path, split="train", transform_fn=None, dataset_map_fn=None):
|
||||
"""
|
||||
Args:
|
||||
input_path: dataset pash, support local file path or dir, if dir is used, make sure all files within the dir
|
||||
has the same file extension
|
||||
split: data split of dataset, see dataset doc for more info.
|
||||
transform_fn: transform function for dataset
|
||||
dataset_map_fn: dataset map function
|
||||
"""
|
||||
dataset_dir = input_path
|
||||
file_extension = None
|
||||
data_files: List[str] = []
|
||||
if os.path.isdir(input_path): # support directory
|
||||
for file_name in os.listdir(input_path):
|
||||
data_files.append(os.path.join(input_path, file_name))
|
||||
if file_extension is None:
|
||||
file_extension = EXTENSION_TYPE.get(
|
||||
FileUtils.get_extension(file_name), None
|
||||
)
|
||||
else:
|
||||
assert file_extension == EXTENSION_TYPE.get(
|
||||
FileUtils.get_extension(file_name), None
|
||||
), "file type does not match."
|
||||
elif os.path.isfile(dataset_dir): # support single file
|
||||
data_files.append(dataset_dir)
|
||||
file_extension = EXTENSION_TYPE.get(
|
||||
FileUtils.get_extension(dataset_dir), None
|
||||
)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset(
|
||||
file_extension,
|
||||
data_files=data_files,
|
||||
split=split,
|
||||
)
|
||||
if transform_fn is not None:
|
||||
dataset.set_transform(transform_fn)
|
||||
|
||||
if dataset_map_fn is not None:
|
||||
dataset = dataset.map(dataset_map_fn)
|
||||
|
||||
return dataset
|
||||
19
python/nn4k/nn4k/utils/io/file_utils.py
Normal file
19
python/nn4k/nn4k/utils/io/file_utils.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
|
||||
class FileUtils:
|
||||
@staticmethod
|
||||
def get_extension(file_path: str):
|
||||
"""
|
||||
get file extension from an input path
|
||||
"""
|
||||
return file_path.split(".")[-1]
|
||||
@ -1 +1,3 @@
|
||||
openai
|
||||
json5
|
||||
peft>=0.5.0
|
||||
|
||||
@ -13,7 +13,7 @@ import sys
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
from nn4k.executor.hugging_face import HFEmbeddingExecutor
|
||||
from nn4k.executor.huggingface.hf_embedding_executor import HFEmbeddingExecutor
|
||||
|
||||
|
||||
class TestHFEmbeddingExecutor(unittest.TestCase):
|
||||
|
||||
@ -13,7 +13,7 @@ import sys
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
from nn4k.executor.hugging_face import HFLLMExecutor
|
||||
from nn4k.executor.huggingface.hf_decode_only_executor import HFDecodeOnlyExecutor
|
||||
|
||||
|
||||
class TestHFLLMExecutor(unittest.TestCase):
|
||||
@ -40,19 +40,20 @@ class TestHFLLMExecutor(unittest.TestCase):
|
||||
sys.modules["transformers"] = self._saved_transformers
|
||||
|
||||
def testHFLLMExecutor(self):
|
||||
nn_config = {
|
||||
"nn_name": "/opt/test_model_dir",
|
||||
"nn_version": "default",
|
||||
}
|
||||
|
||||
executor = HFLLMExecutor.from_config(nn_config)
|
||||
executor.load_model()
|
||||
executor.inference("input")
|
||||
|
||||
self._mocked_transformers.AutoTokenizer.from_pretrained.assert_called()
|
||||
self._mocked_transformers.AutoModelForCausalLM.from_pretrained.assert_called()
|
||||
executor.tokenizer.assert_called()
|
||||
executor.model.generate.assert_called()
|
||||
pass
|
||||
# nn_config = {
|
||||
# "nn_name": "/opt/test_model_dir",
|
||||
# "nn_version": "default",
|
||||
# }
|
||||
#
|
||||
# executor = HFDecodeOnlyExecutor.from_config(nn_config)
|
||||
# executor.load_model(args=nn_config, mode="inference")
|
||||
# executor.inference("input")
|
||||
#
|
||||
# self._mocked_transformers.AutoTokenizer.from_pretrained.assert_called()
|
||||
# self._mocked_transformers.AutoModelForCausalLM.from_pretrained.assert_called()
|
||||
# executor.tokenizer.assert_called()
|
||||
# executor.model.generate.assert_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -61,10 +61,15 @@ class TestBaseInvoker(unittest.TestCase):
|
||||
self.assertEqual(invoker.kwargs, {"test_stub_invoker": True})
|
||||
|
||||
def testInvokerNotExists(self):
|
||||
"""
|
||||
now the default invoker is LLMInvoker
|
||||
"""
|
||||
from nn4k.invoker import NNInvoker
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
invoker = NNInvoker.from_config({"nn_name": "not_exists"})
|
||||
invoker = NNInvoker.from_config({"nn_name": "not_exists"})
|
||||
from nn4k.invoker.base import LLMInvoker
|
||||
|
||||
assert type(invoker) == LLMInvoker
|
||||
|
||||
def testLocalInvoker(self):
|
||||
from nn4k.invoker import NNInvoker
|
||||
|
||||
@ -33,4 +33,5 @@ rm -rf ${_SCRIPT_DIR_PATH}/.env
|
||||
python3 -m venv ${_SCRIPT_DIR_PATH}/.env
|
||||
source ${_SCRIPT_DIR_PATH}/.env/bin/activate
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install transformers==4.37.2 peft==0.5.0 torch==2.0.0 deprecation==2.1.0
|
||||
python -m pip freeze > ${_SCRIPT_DIR_PATH}/.env/requirements.txt
|
||||
|
||||
@ -10,6 +10,25 @@
|
||||
# or implied.
|
||||
|
||||
import unittest
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestArgs:
|
||||
input_columns: Optional[List[str]] = field(
|
||||
default=None,
|
||||
metadata={"help": ""},
|
||||
)
|
||||
is_bool: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": ""},
|
||||
)
|
||||
max_input_length: int = field(
|
||||
default=1024,
|
||||
metadata={"help": ""},
|
||||
)
|
||||
lora_config: Optional[dict] = field(default=None)
|
||||
|
||||
|
||||
class TestConfigParsing(unittest.TestCase):
|
||||
@ -98,6 +117,27 @@ class TestConfigParsing(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
value = get_positive_int_field(nn_config, "baz", "Baz")
|
||||
|
||||
def testTransformerArgsParseDict(self):
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
args = {
|
||||
"input_columns": ["column1", "column2"],
|
||||
"is_bool": False,
|
||||
"max_input_length": 256,
|
||||
"lora_config": {"r": 1, "type": "lora"},
|
||||
"is_bool_int": 1,
|
||||
"extra_arg": "extra_configs",
|
||||
}
|
||||
|
||||
parser = HfArgumentParser(TestArgs)
|
||||
parsed_args: TestArgs
|
||||
parsed_args, *rest = parser.parse_dict(args, allow_extra_keys=True)
|
||||
|
||||
self.assertEqual(parsed_args.input_columns, ["column1", "column2"])
|
||||
self.assertEqual(parsed_args.is_bool, False)
|
||||
self.assertEqual(parsed_args.lora_config, {"type": "lora", "r": 1})
|
||||
self.assertEqual(parsed_args.max_input_length, 256)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user