feat(nn4k): not load model to cuda when using deepspeed (#207)

chenbin11200 authored 2024-04-19 11:20:25 +08:00; committed by GitHub
parent 5ca3f17718
commit ff948d5403
6 changed files with 19 additions and 9 deletions
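Background: with DeepSpeed enabled, the DeepSpeed engine owns device placement (and, under ZeRO stage 3, partitions parameters across ranks), so eagerly calling model.to("cuda") in nn4k would waste GPU memory or fight the engine. A minimal standalone sketch of the rule this commit introduces (the dataclass and helper names below are illustrative, not nn4k code):

from dataclasses import dataclass
from typing import Optional

@dataclass
class SftArgs:
    # Mirrors the HF TrainingArguments convention: `deepspeed` holds the path
    # to a DeepSpeed config file (or a config dict); None means DeepSpeed is off.
    deepspeed: Optional[str] = None

def should_move_model_to_cuda(args: SftArgs) -> bool:
    # Skip the eager CUDA move whenever DeepSpeed will manage placement itself.
    return args.deepspeed is None

assert should_move_model_to_cuda(SftArgs()) is True
assert should_move_model_to_cuda(SftArgs(deepspeed="conf/ds_zero3.json")) is False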

@@ -78,7 +78,7 @@ class NNExecutor(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def load_model(self, args=None, mode=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         """
         Implement model loading logic in derived executor classes.

@@ -44,8 +44,6 @@ class HFLLMExecutor(LLMExecutor):
     def execute_sft(self, args: dict = None, callbacks=None, **kwargs):
         args = args or self.init_args
-        self.load_model(args=args, mode="train")
-
         # parse args into HFSftArgs dataclass for more convenient features
         from transformers import HfArgumentParser
@@ -53,6 +51,11 @@ class HFLLMExecutor(LLMExecutor):
         hf_sft_args: HFSftArgs
         hf_sft_args, *_ = parser.parse_dict(args, allow_extra_keys=True)
+        model_to_cuda = (
+            False if hf_sft_args.deepspeed else True
+        )  # if using deepspeed, don't load model to cuda in nn4k
+        self.load_model(args=args, mode="train", model_to_cuda=model_to_cuda)
 
         # load checkpoint path if necessary.
         resume_from_checkpoint_path = self._get_last_checkpoint(hf_sft_args)
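The flag derivation above is equivalent to not bool(hf_sft_args.deepspeed): `deepspeed` is the standard HF TrainingArguments field carrying a DeepSpeed config path or dict, so any non-empty value disables the CUDA pre-move. A runnable check of that equivalence (standalone; the loop variable stands in for hf_sft_args.deepspeed):

for deepspeed in (None, "", "conf/ds_zero2.json", {"zero_optimization": {"stage": 2}}):
    original = False if deepspeed else True   # expression used in the commit
    simplified = not bool(deepspeed)          # equivalent, more direct form
    assert original == simplified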
@@ -177,7 +180,7 @@ class HFLLMExecutor(LLMExecutor):
         return DatasetUtils.auto_dataset(data_path, split)
 
-    def load_model(self, args: dict = None, mode=None, **kwargs):
+    def load_model(self, args: dict = None, mode=None, model_to_cuda=True, **kwargs):
         """
         load model and tokenizer. If the model with the same mode is already loaded, will not load again.
         """
@@ -201,7 +204,10 @@ class HFLLMExecutor(LLMExecutor):
             self.model_mode = mode
             self._tokenizer = self._hf_tokenizer_loader(hf_model_args)
             self._model = self._hf_model_loader(
-                args=hf_model_args, mode=mode, device=hf_model_args.nn_device
+                args=hf_model_args,
+                mode=mode,
+                device=hf_model_args.nn_device,
+                model_to_cuda=model_to_cuda,
             )
 
         if self.tokenizer.eos_token_id is None:
@@ -276,6 +282,7 @@ class HFLLMExecutor(LLMExecutor):
         mode,
         resume_from_checkpoint=False,
         device=None,
+        model_to_cuda=True,
         **kwargs,
     ):
         """

@@ -29,6 +29,7 @@ class HFDecodeOnlyExecutor(HFLLMExecutor):
         mode,
         resume_from_checkpoint=False,
         device=None,
+        model_to_cuda=True,
         **kwargs,
     ):
         if device is None or "auto":
@@ -104,6 +105,8 @@ class HFDecodeOnlyExecutor(HFLLMExecutor):
         if mode == "inference":
             model.eval()
-            model.to(device)
+
+        if model_to_cuda:
+            model.to(device)
 
         return model
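To put the gated move in context, a simplified, self-contained sketch of the loader behavior after this change (hypothetical function; the real _hf_model_loader also takes resume_from_checkpoint and other options):

import torch
from transformers import AutoModelForCausalLM

def load_decode_only_model(pretrained_name, mode, device="cuda", model_to_cuda=True):
    model = AutoModelForCausalLM.from_pretrained(pretrained_name)
    if mode == "inference":
        model.eval()  # disable dropout etc. for generation
    # Under DeepSpeed training, model_to_cuda is False: the engine moves (and,
    # with ZeRO-3, shards) parameters itself, so an eager .to(device) would
    # duplicate memory or conflict with its placement.
    if model_to_cuda and torch.cuda.is_available():
        model = model.to(device)
    return model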

@@ -20,7 +20,7 @@ class HFEmbeddingExecutor(LLMExecutor):
         executor = cls(nn_config)
         return executor
 
-    def load_model(self, args=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         import torch
         from sentence_transformers import SentenceTransformer
         from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT

@@ -16,7 +16,7 @@ from nn4k.nnhub import SimpleNNHub
 
 class StubExecutor(LLMExecutor):
-    def load_model(self, args=None, mode=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         pass
 
     def warmup_inference(self, args=None, **kwargs):

@@ -31,7 +31,7 @@ class NotInvoker:
 
 class StubExecutor(NNExecutor):
-    def load_model(self, args=None, mode=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         self.load_model_called = True
 
     def warmup_inference(self, args=None, **kwargs):