feat(nn4k): implement HfLLMExecutor
This commit is contained in:
parent f58d82557f
commit c1367f284f
@@ -36,8 +36,14 @@ class LLMExecutor(NNExecutor):
 
         nn_config
         """
-        # TODO
-        pass
+        if "nn_name" in nn_config:
+            from nn4k.executor.hugging_face import HfLLMExecutor
+
+            return HfLLMExecutor.from_config(nn_config)
+        else:
+            o = cls.__new__(cls)
+            o._nn_config = nn_config
+            return o
 
     @abstractmethod
     def execute_sft(self, args=None, callbacks=None, **kwargs):
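With this change, LLMExecutor.from_config dispatches to the Hugging Face executor whenever the configuration carries an "nn_name" field; the executor class itself is added in nn4k.executor.hugging_face (next hunk). A minimal usage sketch, assuming LLMExecutor is importable from nn4k.executor and using a placeholder model id that is not part of this commit:

    from nn4k.executor import LLMExecutor

    nn_config = {
        "nn_name": "facebook/opt-125m",  # placeholder Hugging Face model id
        "nn_version": "main",            # revision passed through to from_pretrained
    }
    executor = LLMExecutor.from_config(nn_config)  # dispatches to HfLLMExecutor.from_config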
@@ -13,5 +13,94 @@ from nn4k.executor import NNExecutor
 
 
 class HfLLMExecutor(NNExecutor):
-    pass
+    @classmethod
+    def _parse_config(cls, nn_config: dict) -> dict:
+        from nn4k.utils.config_parsing import get_string_field
+
+        nn_name = get_string_field(
+            nn_config, "nn_name", "NN model name"
+        )
+        nn_version = get_string_field(
+            nn_config, "nn_version", "NN model version"
+        )
+        config = dict(
+            nn_name=nn_name,
+            nn_version=nn_version,
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, nn_config: dict):
+        config = cls._parse_config(nn_config)
+
+        o = cls.__new__(cls)
+        o._nn_config = nn_config
+        o._nn_name = config["nn_name"]
+        o._nn_version = config["nn_version"]
+        o._nn_device = None
+        o._nn_tokenizer = None
+        o._nn_model = None
+
+        return o
+
+    def _load_model(self):
+        import torch
+        from transformers import AutoTokenizer
+        from transformers import AutoModelForCausalLM
+
+        # Lazily load the tokenizer and model on first use.
+        if self._nn_model is None:
+            model_path = self._nn_name
+            revision = self._nn_version
+            use_fast_tokenizer = False
+            device = self._nn_config.get('nn_device')
+            if device is None:
+                device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_path,
+                use_fast=use_fast_tokenizer,
+                revision=revision
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                low_cpu_mem_usage=True,
+                torch_dtype=torch.float16,
+                revision=revision,
+            )
+            model.to(device)
+            self._nn_device = device
+            self._nn_tokenizer = tokenizer
+            self._nn_model = model
+
+    def _get_tokenizer(self):
+        if self._nn_model is None:
+            self._load_model()
+        return self._nn_tokenizer
+
+    def _get_model(self):
+        if self._nn_model is None:
+            self._load_model()
+        return self._nn_model
+
+    def inference(self, data, **kwargs):
+        nn_tokenizer = self._get_tokenizer()
+        nn_model = self._get_model()
+        input_ids = nn_tokenizer(data,
+                                 padding=True,
+                                 return_token_type_ids=False,
+                                 return_tensors="pt",
+                                 truncation=True,
+                                 max_length=64).to(self._nn_device)
+        output_ids = nn_model.generate(**input_ids,
+                                       max_new_tokens=1024,
+                                       do_sample=False,
+                                       eos_token_id=nn_tokenizer.eos_token_id,
+                                       pad_token_id=nn_tokenizer.pad_token_id)
+        # Decode only the newly generated tokens, skipping the echoed prompt.
+        outputs = [nn_tokenizer.decode(output_id[len(input_ids["input_ids"][idx]):],
+                                       skip_special_tokens=True)
+                   for idx, output_id in enumerate(output_ids)]
+
+        return outputs
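The executor loads lazily: from_config only records the configuration, and the tokenizer/model pair is instantiated on the configured device (or an auto-detected cuda/cpu device) the first time inference is called. Since the weights are loaded in float16, a CUDA GPU is the expected target. A hedged end-to-end sketch with a placeholder model id, not part of this commit; the chosen model's tokenizer must define a pad token because inference tokenizes with padding=True:

    from nn4k.executor.hugging_face import HfLLMExecutor

    nn_config = {
        "nn_name": "facebook/opt-125m",  # placeholder model id with a slow tokenizer and a pad token
        "nn_version": "main",
        # "nn_device": "cuda",           # optional; defaults to cuda when available, else cpu
    }
    executor = HfLLMExecutor.from_config(nn_config)  # no weights are downloaded or loaded yet
    outputs = executor.inference(["OpenSPG is"])     # first call triggers _load_model, then generates
    print(outputs[0])                                # decoded continuation, prompt tokens stripped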