mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-06-27 03:20:10 +00:00
feat(nn4k): not load model to cuda when using deepspeed (#207)
This commit is contained in:
parent
5ca3f17718
commit
ff948d5403
@ -78,7 +78,7 @@ class NNExecutor(ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, args=None, mode=None, **kwargs):
|
||||
def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
|
||||
"""
|
||||
Implement model loading logic in derived executor classes.
|
||||
|
||||
|
@ -44,8 +44,6 @@ class HFLLMExecutor(LLMExecutor):
|
||||
def execute_sft(self, args: dict = None, callbacks=None, **kwargs):
|
||||
args = args or self.init_args
|
||||
|
||||
self.load_model(args=args, mode="train")
|
||||
|
||||
# parse args into HFSftArgs dataclass for more convenient features
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
@ -53,6 +51,11 @@ class HFLLMExecutor(LLMExecutor):
|
||||
hf_sft_args: HFSftArgs
|
||||
hf_sft_args, *_ = parser.parse_dict(args, allow_extra_keys=True)
|
||||
|
||||
model_to_cuda = (
|
||||
False if hf_sft_args.deepspeed else True
|
||||
) # if using deepspeed, don't load model to cuda in nn4k
|
||||
self.load_model(args=args, mode="train", model_to_cuda=model_to_cuda)
|
||||
|
||||
# load checkpoint path if necessary.
|
||||
resume_from_checkpoint_path = self._get_last_checkpoint(hf_sft_args)
|
||||
|
||||
@ -177,7 +180,7 @@ class HFLLMExecutor(LLMExecutor):
|
||||
|
||||
return DatasetUtils.auto_dataset(data_path, split)
|
||||
|
||||
def load_model(self, args: dict = None, mode=None, **kwargs):
|
||||
def load_model(self, args: dict = None, mode=None, model_to_cuda=True, **kwargs):
|
||||
"""
|
||||
load model and tokenizer. If the model with the same mode is already loaded, will not load again.
|
||||
"""
|
||||
@ -201,7 +204,10 @@ class HFLLMExecutor(LLMExecutor):
|
||||
self.model_mode = mode
|
||||
self._tokenizer = self._hf_tokenizer_loader(hf_model_args)
|
||||
self._model = self._hf_model_loader(
|
||||
args=hf_model_args, mode=mode, device=hf_model_args.nn_device
|
||||
args=hf_model_args,
|
||||
mode=mode,
|
||||
device=hf_model_args.nn_device,
|
||||
model_to_cuda=model_to_cuda,
|
||||
)
|
||||
|
||||
if self.tokenizer.eos_token_id is None:
|
||||
@ -276,6 +282,7 @@ class HFLLMExecutor(LLMExecutor):
|
||||
mode,
|
||||
resume_from_checkpoint=False,
|
||||
device=None,
|
||||
model_to_cuda=True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -29,6 +29,7 @@ class HFDecodeOnlyExecutor(HFLLMExecutor):
|
||||
mode,
|
||||
resume_from_checkpoint=False,
|
||||
device=None,
|
||||
model_to_cuda=True,
|
||||
**kwargs,
|
||||
):
|
||||
if device is None or "auto":
|
||||
@ -104,7 +105,9 @@ class HFDecodeOnlyExecutor(HFLLMExecutor):
|
||||
|
||||
if mode == "inference":
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
if model_to_cuda:
|
||||
model.to(device)
|
||||
|
||||
return model
|
||||
|
||||
|
@ -20,7 +20,7 @@ class HFEmbeddingExecutor(LLMExecutor):
|
||||
executor = cls(nn_config)
|
||||
return executor
|
||||
|
||||
def load_model(self, args=None, **kwargs):
|
||||
def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||
|
@ -16,7 +16,7 @@ from nn4k.nnhub import SimpleNNHub
|
||||
|
||||
|
||||
class StubExecutor(LLMExecutor):
|
||||
def load_model(self, args=None, mode=None, **kwargs):
|
||||
def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
|
||||
pass
|
||||
|
||||
def warmup_inference(self, args=None, **kwargs):
|
||||
|
@ -31,7 +31,7 @@ class NotInvoker:
|
||||
|
||||
|
||||
class StubExecutor(NNExecutor):
|
||||
def load_model(self, args=None, mode=None, **kwargs):
|
||||
def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
|
||||
self.load_model_called = True
|
||||
|
||||
def warmup_inference(self, args=None, **kwargs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user