diff --git a/python/nn4k/nn4k/executor/base.py b/python/nn4k/nn4k/executor/base.py
index 2abf643c..b61d2372 100644
--- a/python/nn4k/nn4k/executor/base.py
+++ b/python/nn4k/nn4k/executor/base.py
@@ -78,7 +78,7 @@ class NNExecutor(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def load_model(self, args=None, mode=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         """
         Implement model loading logic in derived executor classes.
 
diff --git a/python/nn4k/nn4k/executor/huggingface/base/hf_llm_executor.py b/python/nn4k/nn4k/executor/huggingface/base/hf_llm_executor.py
index 5be7ec4b..724f0409 100644
--- a/python/nn4k/nn4k/executor/huggingface/base/hf_llm_executor.py
+++ b/python/nn4k/nn4k/executor/huggingface/base/hf_llm_executor.py
@@ -44,8 +44,6 @@ class HFLLMExecutor(LLMExecutor):
     def execute_sft(self, args: dict = None, callbacks=None, **kwargs):
         args = args or self.init_args
 
-        self.load_model(args=args, mode="train")
-
         # parse args into HFSftArgs dataclass for more convenient features
         from transformers import HfArgumentParser
 
@@ -53,6 +51,11 @@ class HFLLMExecutor(LLMExecutor):
         hf_sft_args: HFSftArgs
         hf_sft_args, *_ = parser.parse_dict(args, allow_extra_keys=True)
 
+        model_to_cuda = (
+            False if hf_sft_args.deepspeed else True
+        )  # if using deepspeed, don't load model to cuda in nn4k
+        self.load_model(args=args, mode="train", model_to_cuda=model_to_cuda)
+
         # load checkpoint path if necessary.
         resume_from_checkpoint_path = self._get_last_checkpoint(hf_sft_args)
 
@@ -177,7 +180,7 @@ class HFLLMExecutor(LLMExecutor):
 
         return DatasetUtils.auto_dataset(data_path, split)
 
-    def load_model(self, args: dict = None, mode=None, **kwargs):
+    def load_model(self, args: dict = None, mode=None, model_to_cuda=True, **kwargs):
         """
         load model and tokenizer. If the model with the same mode is already loaded, will not load again.
         """
@@ -201,7 +204,10 @@ class HFLLMExecutor(LLMExecutor):
             self.model_mode = mode
             self._tokenizer = self._hf_tokenizer_loader(hf_model_args)
             self._model = self._hf_model_loader(
-                args=hf_model_args, mode=mode, device=hf_model_args.nn_device
+                args=hf_model_args,
+                mode=mode,
+                device=hf_model_args.nn_device,
+                model_to_cuda=model_to_cuda,
             )
 
             if self.tokenizer.eos_token_id is None:
@@ -276,6 +282,7 @@ class HFLLMExecutor(LLMExecutor):
         mode,
         resume_from_checkpoint=False,
         device=None,
+        model_to_cuda=True,
         **kwargs,
     ):
         """
diff --git a/python/nn4k/nn4k/executor/huggingface/hf_decode_only_executor.py b/python/nn4k/nn4k/executor/huggingface/hf_decode_only_executor.py
index 4368aa2b..7ff68e24 100644
--- a/python/nn4k/nn4k/executor/huggingface/hf_decode_only_executor.py
+++ b/python/nn4k/nn4k/executor/huggingface/hf_decode_only_executor.py
@@ -29,6 +29,7 @@ class HFDecodeOnlyExecutor(HFLLMExecutor):
         mode,
         resume_from_checkpoint=False,
         device=None,
+        model_to_cuda=True,
         **kwargs,
     ):
         if device is None or "auto":
@@ -104,7 +105,9 @@ class HFDecodeOnlyExecutor(HFLLMExecutor):
 
         if mode == "inference":
             model.eval()
-        model.to(device)
+
+        if model_to_cuda:
+            model.to(device)
 
         return model
 
diff --git a/python/nn4k/nn4k/executor/huggingface/hf_embedding_executor.py b/python/nn4k/nn4k/executor/huggingface/hf_embedding_executor.py
index 96d19097..f619d573 100644
--- a/python/nn4k/nn4k/executor/huggingface/hf_embedding_executor.py
+++ b/python/nn4k/nn4k/executor/huggingface/hf_embedding_executor.py
@@ -20,7 +20,7 @@ class HFEmbeddingExecutor(LLMExecutor):
         executor = cls(nn_config)
         return executor
 
-    def load_model(self, args=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         import torch
         from sentence_transformers import SentenceTransformer
         from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
diff --git a/python/nn4k/tests/executor/executor_test_stub.py b/python/nn4k/tests/executor/executor_test_stub.py
index e1820b4c..36f32633 100644
--- a/python/nn4k/tests/executor/executor_test_stub.py
+++ b/python/nn4k/tests/executor/executor_test_stub.py
@@ -16,7 +16,7 @@
 from nn4k.nnhub import SimpleNNHub
 
 class StubExecutor(LLMExecutor):
-    def load_model(self, args=None, mode=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         pass
 
     def warmup_inference(self, args=None, **kwargs):
diff --git a/python/nn4k/tests/invoker/invoker_test_stub.py b/python/nn4k/tests/invoker/invoker_test_stub.py
index b96475fd..ae7ea5aa 100644
--- a/python/nn4k/tests/invoker/invoker_test_stub.py
+++ b/python/nn4k/tests/invoker/invoker_test_stub.py
@@ -31,7 +31,7 @@ class NotInvoker:
 
 
 class StubExecutor(NNExecutor):
-    def load_model(self, args=None, mode=None, **kwargs):
+    def load_model(self, args=None, mode=None, model_to_cuda=True, **kwargs):
         self.load_model_called = True
 
     def warmup_inference(self, args=None, **kwargs):