feat(nn4k): do not load the model to CUDA when using DeepSpeed (#207)

chenbin11200 authored 2024-04-19 11:20:25 +08:00 (committed by GitHub)
parent 5ca3f17718
commit ff948d5403
6 changed files with 19 additions and 9 deletions
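
Context for the change: when a DeepSpeed config is supplied, HF's Trainer hands device placement (and, under ZeRO-3, parameter sharding) over to the DeepSpeed engine, so eagerly calling model.to("cuda") inside nn4k at load time is redundant at best and, for large models, can exhaust a single GPU before deepspeed.initialize() ever runs. A minimal sketch of the placement rule this commit encodes (place_model is an illustrative name, not nn4k API):

import torch

def place_model(model: torch.nn.Module, device: str, model_to_cuda: bool = True):
    # DeepSpeed path: leave the model where it is; the engine moves and,
    # with ZeRO-3, shards parameters itself during initialization.
    if model_to_cuda:
        model = model.to(device)
    return model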


@@ -78,7 +78,7 @@ class NNExecutor(ABC):
         raise NotImplementedError()

     @abstractmethod
-    def load_model(self, args=None, mode=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         """
         Implement model loading logic in derived executor classes.
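
Every concrete executor must now accept the widened signature. A minimal sketch of a conforming subclass (the class name, body, and import path are illustrative assumptions):

from nn4k.executor import NNExecutor  # assumed import path

class MyExecutor(NNExecutor):  # hypothetical subclass for illustration
    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
        # Defaulting model_to_cuda to True keeps existing call sites
        # working; only the DeepSpeed training path passes False.
        ...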


@@ -44,8 +44,6 @@ class HFLLMExecutor(LLMExecutor):
     def execute_sft(self, args: dict = None, callbacks=None, **kwargs):
         args = args or self.init_args
-        self.load_model(args=args, mode="train")
-
         # parse args into HFSftArgs dataclass for more convenient features
         from transformers import HfArgumentParser
@@ -53,6 +51,11 @@ class HFLLMExecutor(LLMExecutor):
         hf_sft_args: HFSftArgs
         hf_sft_args, *_ = parser.parse_dict(args, allow_extra_keys=True)
+        model_to_cuda = (
+            False if hf_sft_args.deepspeed else True
+        )  # if using deepspeed, don't load model to cuda in nn4k
+        self.load_model(args=args, mode="train", model_to_cuda=model_to_cuda)
+
         # load checkpoint path if necessary.
         resume_from_checkpoint_path = self._get_last_checkpoint(hf_sft_args)
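
The expression False if hf_sft_args.deepspeed else True is just negated truthiness: the field is assumed to hold a DeepSpeed config path (or dict) when set and None otherwise. A runnable illustration with a hypothetical stand-in for HFSftArgs:

from dataclasses import dataclass
from typing import Optional

@dataclass
class SftArgs:  # hypothetical stand-in for HFSftArgs
    deepspeed: Optional[str] = None  # path to a DeepSpeed JSON config, or None

for cfg in (None, "ds_zero3_config.json"):
    model_to_cuda = False if SftArgs(deepspeed=cfg).deepspeed else True
    print(cfg, "->", model_to_cuda)  # None -> True; a config path -> False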
@@ -177,7 +180,7 @@ class HFLLMExecutor(LLMExecutor):
         return DatasetUtils.auto_dataset(data_path, split)

-    def load_model(self, args: dict = None, mode=None, **kwargs):
+    def load_model(self, args: dict = None, mode=None, model_to_cuda=True, **kwargs):
         """
         load model and tokenizer. If the model with the same mode is already loaded, will not load again.
         """
@@ -201,7 +204,10 @@ class HFLLMExecutor(LLMExecutor):
         self.model_mode = mode
         self._tokenizer = self._hf_tokenizer_loader(hf_model_args)
         self._model = self._hf_model_loader(
-            args=hf_model_args, mode=mode, device=hf_model_args.nn_device
+            args=hf_model_args,
+            mode=mode,
+            device=hf_model_args.nn_device,
+            model_to_cuda=model_to_cuda,
         )

         if self.tokenizer.eos_token_id is None:
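
The flag travels one hop per layer: execute_sft computes it, load_model accepts and forwards it, and only _hf_model_loader consumes it. A compressed sketch of that call chain (skeleton only, not nn4k code):

class Executor:  # illustrative skeleton of the call chain
    def execute_sft(self, deepspeed_cfg=None):
        self.load_model(mode="train", model_to_cuda=not deepspeed_cfg)

    def load_model(self, mode=None, model_to_cuda=True):
        self._model = self._hf_model_loader(mode, model_to_cuda=model_to_cuda)

    def _hf_model_loader(self, mode, model_to_cuda=True):
        ...  # only this layer decides whether to call model.to(device)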
@ -276,6 +282,7 @@ class HFLLMExecutor(LLMExecutor):
mode,
resume_from_checkpoint=False,
device=None,
model_to_cuda=True,
**kwargs,
):
"""


@@ -29,6 +29,7 @@ class HFDecodeOnlyExecutor(HFLLMExecutor):
         mode,
         resume_from_checkpoint=False,
         device=None,
+        model_to_cuda=True,
         **kwargs,
     ):
         if device is None or "auto":
@@ -104,7 +105,9 @@ class HFDecodeOnlyExecutor(HFLLMExecutor):
         if mode == "inference":
             model.eval()
-        model.to(device)
+
+        if model_to_cuda:
+            model.to(device)

         return model
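
Gating the .to(device) call is what actually enables the DeepSpeed path: with a ZeRO-3 config, materializing the full model on one GPU before deepspeed.initialize() runs defeats sharding and can OOM. A sketch of the pattern under stated assumptions (the helper name and model loading call are illustrative, not nn4k code):

import torch
from transformers import AutoModelForCausalLM

def load_for_mode(name: str, mode: str, device: str, model_to_cuda: bool = True):
    model = AutoModelForCausalLM.from_pretrained(name)
    if mode == "inference":
        model.eval()
    if model_to_cuda:          # plain single-device path
        model.to(device)
    return model               # DeepSpeed path: left on CPU for the engine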


@@ -20,7 +20,7 @@ class HFEmbeddingExecutor(LLMExecutor):
         executor = cls(nn_config)
         return executor

-    def load_model(self, args=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         import torch
         from sentence_transformers import SentenceTransformer
         from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
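
HFEmbeddingExecutor gains mode and model_to_cuda mainly to match the shared signature; SentenceTransformer accepts a device argument of its own, so the flag does not need to drive a separate .to() call here. An illustrative load (the model name is a placeholder, not what nn4k ships):

import torch
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer("all-MiniLM-L6-v2", device=device)  # placeholder name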


@@ -16,7 +16,7 @@ from nn4k.nnhub import SimpleNNHub

 class StubExecutor(LLMExecutor):
-    def load_model(self, args=None, mode=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         pass

     def warmup_inference(self, args=None, **kwargs):


@@ -31,7 +31,7 @@ class NotInvoker:

 class StubExecutor(NNExecutor):
-    def load_model(self, args=None, mode=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         self.load_model_called = True

     def warmup_inference(self, args=None, **kwargs):
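
Widening the stub signatures keeps the test doubles call-compatible with both old and new call sites. A self-contained check in the spirit of these fixtures (a standalone mirror, not the actual test code):

class StubExecutor:  # standalone mirror of the test stub above
    def __init__(self):
        self.load_model_called = False

    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
        self.load_model_called = True

stub = StubExecutor()
stub.load_model(args={}, mode="train")                        # pre-change call site
stub.load_model(args={}, mode="train", model_to_cuda=False)   # new call site
assert stub.load_model_called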