diff --git a/python/nn4k/executor/base.py b/python/nn4k/executor/base.py
index 3d29e95d..799c5266 100644
--- a/python/nn4k/executor/base.py
+++ b/python/nn4k/executor/base.py
@@ -36,8 +36,14 @@ class LLMExecutor(NNExecutor):
 
             nn_config
         """
-        # TODO
-        pass
+        if "nn_name" in nn_config:
+            from nn4k.executor.hugging_face import HfLLMExecutor
+
+            return HfLLMExecutor.from_config(nn_config)
+        else:
+            o = cls.__new__(cls)
+            o._nn_config = nn_config
+            return o
 
     @abstractmethod
     def execute_sft(self, args=None, callbacks=None, **kwargs):
diff --git a/python/nn4k/executor/hugging_face.py b/python/nn4k/executor/hugging_face.py
index 1e4dbdb7..504a5932 100644
--- a/python/nn4k/executor/hugging_face.py
+++ b/python/nn4k/executor/hugging_face.py
@@ -13,5 +13,91 @@
 from nn4k.executor import NNExecutor
 
 
 class HfLLMExecutor(NNExecutor):
-    pass
+    @classmethod
+    def _parse_config(cls, nn_config: dict) -> dict:
+        from nn4k.utils.config_parsing import get_string_field
+
+        nn_name = get_string_field(
+            nn_config, "nn_name", "NN model name"
+        )
+        nn_version = get_string_field(
+            nn_config, "nn_version", "NN model version"
+        )
+        config = dict(
+            nn_name=nn_name,
+            nn_version=nn_version,
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, nn_config: dict):
+        config = cls._parse_config(nn_config)
+
+        o = cls.__new__(cls)
+        o._nn_config = nn_config
+        o._nn_name = config["nn_name"]
+        o._nn_version = config["nn_version"]
+        o._nn_device = None
+        o._nn_tokenizer = None
+        o._nn_model = None
+
+        return o
+
+    def _load_model(self):
+        import torch
+        from transformers import AutoTokenizer
+        from transformers import AutoModelForCausalLM
+
+        if self._nn_model is None:
+            model_path = self._nn_name
+            revision = self._nn_version
+            use_fast_tokenizer = False
+            device = self._nn_config.get("nn_device")
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_path,
+                use_fast=use_fast_tokenizer,
+                revision=revision,
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                low_cpu_mem_usage=True,
+                torch_dtype=torch.float16,
+                revision=revision,
+            )
+            model.to(device)
+            self._nn_device = device
+            self._nn_tokenizer = tokenizer
+            self._nn_model = model
+
+    def _get_tokenizer(self):
+        if self._nn_model is None:
+            self._load_model()
+        return self._nn_tokenizer
+
+    def _get_model(self):
+        if self._nn_model is None:
+            self._load_model()
+        return self._nn_model
+
+    def inference(self, data, **kwargs):
+        nn_tokenizer = self._get_tokenizer()
+        nn_model = self._get_model()
+        input_ids = nn_tokenizer(data,
+                                 padding=True,
+                                 return_token_type_ids=False,
+                                 return_tensors="pt",
+                                 truncation=True,
+                                 max_length=64).to(self._nn_device)
+        output_ids = nn_model.generate(**input_ids,
+                                       max_new_tokens=1024,
+                                       do_sample=False,
+                                       eos_token_id=nn_tokenizer.eos_token_id,
+                                       pad_token_id=nn_tokenizer.pad_token_id)
+        # Decode only the newly generated tokens, skipping the prompt ids.
+        outputs = [nn_tokenizer.decode(output_id[len(input_ids["input_ids"][idx]):],
+                                       skip_special_tokens=True)
+                   for idx, output_id in enumerate(output_ids)]
+        return outputs
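For context, a minimal usage sketch of the new code path. The model id, revision, and the list-of-prompts input are assumptions for illustration, not something this diff confirms:

```python
# Minimal usage sketch (assumption: "nn_name" may be a Hugging Face model id
# or a local checkpoint path, and "nn_version" a revision such as "main").
from nn4k.executor.base import LLMExecutor

nn_config = {
    "nn_name": "baichuan-inc/Baichuan2-7B-Chat",  # hypothetical model id
    "nn_version": "main",
    "nn_device": "cuda",  # optional; auto-detected when omitted
}

# "nn_name" is present, so LLMExecutor.from_config dispatches to HfLLMExecutor.
executor = LLMExecutor.from_config(nn_config)

# Model and tokenizer are loaded lazily on the first inference call.
outputs = executor.inference(["What is a knowledge graph?"])
print(outputs[0])
```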
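The patch also depends on `nn4k.utils.config_parsing.get_string_field`, which is not shown in this diff. Judging only from the call sites (config dict, key, human-readable field name), its contract plausibly looks like the sketch below; this is a presumed reading, not the actual implementation:

```python
def get_string_field(config: dict, name: str, text: str) -> str:
    # Presumed behavior: require `name` in `config` with a string value,
    # using the human-readable description `text` in error messages.
    if name not in config:
        raise ValueError(f"{text} ({name!r}) is required in the NN config")
    value = config[name]
    if not isinstance(value, str):
        raise TypeError(f"{text} ({name!r}) must be a string")
    return value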
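One detail worth noting in `inference`: for decoder-only causal LMs, `generate` returns each prompt followed by its continuation, so the patch slices every output row at the (padded) prompt length before decoding. A toy illustration of that slicing, separate from the diff and using made-up token ids:

```python
import torch

# generate() output rows have shape [prompt_len + new_tokens] for causal LMs,
# so dropping the first prompt_len ids leaves only the generated continuation.
prompt = torch.tensor([101, 7592, 2088])              # 3 hypothetical prompt ids
generated = torch.tensor([101, 7592, 2088, 1996, 3437, 102])
continuation = generated[len(prompt):]                # -> tensor([1996, 3437, 102])
assert continuation.tolist() == [1996, 3437, 102]
```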