feat(nn4k): not load model to cuda when using deepspeed (#207)
parent 5ca3f17718
commit ff948d5403
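The commit threads a new model_to_cuda flag from execute_sft down to the model loaders: when a DeepSpeed config is present, nn4k skips its explicit model.to(device) call and leaves device placement to the DeepSpeed engine, which places (and, under ZeRO stage 3, shards) parameters itself; moving the full model onto one GPU first wastes memory and can OOM large models. A minimal sketch of the pattern, outside nn4k's actual API; the names load_for_training and deepspeed_cfg are illustrative:

# Illustrative sketch of the gating this commit introduces; the names
# (load_for_training, deepspeed_cfg) are hypothetical, not nn4k's API.
import torch


def load_for_training(model: torch.nn.Module, device: str, deepspeed_cfg=None):
    # DeepSpeed manages device placement (and ZeRO sharding) itself,
    # so only move the model when nn4k owns placement.
    model_to_cuda = deepspeed_cfg is None
    if model_to_cuda:
        model.to(device)
    return model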
@@ -78,7 +78,7 @@ class NNExecutor(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def load_model(self, args=None, mode=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         """
         Implement model loading logic in derived executor classes.
 
@@ -44,8 +44,6 @@ class HFLLMExecutor(LLMExecutor):
     def execute_sft(self, args: dict = None, callbacks=None, **kwargs):
         args = args or self.init_args
 
-        self.load_model(args=args, mode="train")
-
         # parse args into HFSftArgs dataclass for more convenient features
         from transformers import HfArgumentParser
 
@@ -53,6 +51,11 @@ class HFLLMExecutor(LLMExecutor):
         hf_sft_args: HFSftArgs
         hf_sft_args, *_ = parser.parse_dict(args, allow_extra_keys=True)
 
+        model_to_cuda = (
+            False if hf_sft_args.deepspeed else True
+        )  # if using deepspeed, don't load model to cuda in nn4k
+        self.load_model(args=args, mode="train", model_to_cuda=model_to_cuda)
+
         # load checkpoint path if necessary.
         resume_from_checkpoint_path = self._get_last_checkpoint(hf_sft_args)
 
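Two small observations on the added block: False if hf_sft_args.deepspeed else True is just not hf_sft_args.deepspeed spelled out, and the flag flips on any truthy deepspeed value (a config path or dict, mirroring transformers.TrainingArguments.deepspeed). A toy illustration with assumed arg dicts, since the real hf_sft_args is a dataclass:

# Toy args dicts (keys assumed) showing when model_to_cuda flips.
with_ds = {"nn_name": "my-llm", "deepspeed": "ds_zero3.json"}
without_ds = {"nn_name": "my-llm"}

print(False if with_ds.get("deepspeed") else True)     # False
print(False if without_ds.get("deepspeed") else True)  # True
print(not with_ds.get("deepspeed"))                    # equivalent: False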
@@ -177,7 +180,7 @@ class HFLLMExecutor(LLMExecutor):
 
         return DatasetUtils.auto_dataset(data_path, split)
 
-    def load_model(self, args: dict = None, mode=None, **kwargs):
+    def load_model(self, args: dict = None, mode=None, model_to_cuda=True, **kwargs):
         """
         load model and tokenizer. If the model with the same mode is already loaded, will not load again.
         """
@@ -201,7 +204,10 @@ class HFLLMExecutor(LLMExecutor):
             self.model_mode = mode
             self._tokenizer = self._hf_tokenizer_loader(hf_model_args)
             self._model = self._hf_model_loader(
-                args=hf_model_args, mode=mode, device=hf_model_args.nn_device
+                args=hf_model_args,
+                mode=mode,
+                device=hf_model_args.nn_device,
+                model_to_cuda=model_to_cuda,
             )
 
         if self.tokenizer.eos_token_id is None:
@@ -276,6 +282,7 @@ class HFLLMExecutor(LLMExecutor):
         mode,
         resume_from_checkpoint=False,
         device=None,
+        model_to_cuda=True,
         **kwargs,
     ):
         """
@@ -29,6 +29,7 @@ class HFDecodeOnlyExecutor(HFLLMExecutor):
         mode,
         resume_from_checkpoint=False,
         device=None,
+        model_to_cuda=True,
         **kwargs,
     ):
         if device is None or "auto":
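An aside on the unchanged context line above: if device is None or "auto": is always true, because the non-empty string "auto" is truthy regardless of device. That predates this commit, but the likely intent is an explicit comparison, roughly as below; the fallback behavior shown is an assumption, not what the file necessarily does:

# Probable intent of the always-true check (an aside, not part of this diff).
import torch


def resolve_device(device=None):
    if device is None or device == "auto":
        # Assumed fallback: pick CUDA when available, else CPU.
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device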
@@ -104,6 +105,8 @@ class HFDecodeOnlyExecutor(HFLLMExecutor):
 
         if mode == "inference":
             model.eval()
 
-        model.to(device)
+        if model_to_cuda:
+            model.to(device)
 
         return model
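For context on why the loader can now leave the model where it is: when DeepSpeed takes over (as the Hugging Face Trainer does when TrainingArguments.deepspeed is set), deepspeed.initialize places and, under ZeRO-3, shards the parameters itself. A bare-DeepSpeed sketch; the config path is a placeholder, and in practice this runs under a DeepSpeed or distributed launcher:

# Sketch: DeepSpeed handles placement, so the model may still be on CPU here.
import deepspeed
import torch

model = torch.nn.Linear(16, 16)  # stand-in model for illustration

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_zero3.json",  # placeholder path to a user-supplied DS config
)
# engine.module is placed (and, under ZeRO-3, sharded) by the engine;
# no prior model.to(device) is needed, which is what model_to_cuda=False skips.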
@@ -20,7 +20,7 @@ class HFEmbeddingExecutor(LLMExecutor):
         executor = cls(nn_config)
         return executor
 
-    def load_model(self, args=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         import torch
         from sentence_transformers import SentenceTransformer
         from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
@@ -16,7 +16,7 @@ from nn4k.nnhub import SimpleNNHub
 
 
 class StubExecutor(LLMExecutor):
-    def load_model(self, args=None, mode=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         pass
 
     def warmup_inference(self, args=None, **kwargs):
@@ -31,7 +31,7 @@ class NotInvoker:
 
 
 class StubExecutor(NNExecutor):
-    def load_model(self, args=None, mode=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         self.load_model_called = True
 
     def warmup_inference(self, args=None, **kwargs):