diff --git a/python/nn4k/executor/hugging_face.py b/python/nn4k/executor/hugging_face.py index 504a5932..5957f119 100644 --- a/python/nn4k/executor/hugging_face.py +++ b/python/nn4k/executor/hugging_face.py @@ -17,12 +17,8 @@ class HfLLMExecutor(NNExecutor): def _parse_config(cls, nn_config: dict) -> dict: from nn4k.utils.config_parsing import get_string_field - nn_name = get_string_field( - nn_config, "nn_name", "NN model name" - ) - nn_version= get_string_field( - nn_config, "nn_version", "NN model version" - ) + nn_name = get_string_field(nn_config, "nn_name", "NN model name") + nn_version = get_string_field(nn_config, "nn_version", "NN model version") config = dict( nn_name=nn_name, nn_version=nn_version, @@ -52,13 +48,11 @@ class HfLLMExecutor(NNExecutor): model_path = self._nn_name revision = self._nn_version use_fast_tokenizer = False - device = self._nn_config.get('nn_device') + device = self._nn_config.get("nn_device") if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained( - model_path, - use_fast=use_fast_tokenizer, - revision=revision + model_path, use_fast=use_fast_tokenizer, revision=revision ) model = AutoModelForCausalLM.from_pretrained( model_path, @@ -84,23 +78,31 @@ class HfLLMExecutor(NNExecutor): def inference(self, data, **kwargs): nn_tokenizer = self._get_tokenizer() nn_model = self._get_model() - input_ids = nn_tokenizer(data, - padding=True, - return_token_type_ids=False, - return_tensors="pt", - truncation=True, - max_length=64).to(self._nn_device) - output_ids = nn_model.generate(**input_ids, - max_new_tokens=1024, - do_sample=False, - eos_token_id=nn_tokenizer.eos_token_id, - pad_token_id=nn_tokenizer.pad_token_id) - outputs = [nn_tokenizer.decode(output_id[len(input_ids["input_ids"][idx]):], - skip_special_tokens=True) - for idx, output_id in enumerate(output_ids)] + input_ids = nn_tokenizer( + data, + padding=True, + return_token_type_ids=False, + return_tensors="pt", + truncation=True, + max_length=64, + ).to(self._nn_device) + output_ids = nn_model.generate( + **input_ids, + max_new_tokens=1024, + do_sample=False, + eos_token_id=nn_tokenizer.eos_token_id, + pad_token_id=nn_tokenizer.pad_token_id + ) + outputs = [ + nn_tokenizer.decode( + output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True + ) + for idx, output_id in enumerate(output_ids) + ] - outputs = [nn_tokenizer.decode(output_id[:], - skip_special_tokens=True) - for idx, output_id in enumerate(output_ids)] + outputs = [ + nn_tokenizer.decode(output_id[:], skip_special_tokens=True) + for idx, output_id in enumerate(output_ids) + ] return outputs diff --git a/python/nn4k/invoker/openai_invoker.py b/python/nn4k/invoker/openai_invoker.py index 034da765..170e181a 100644 --- a/python/nn4k/invoker/openai_invoker.py +++ b/python/nn4k/invoker/openai_invoker.py @@ -20,9 +20,7 @@ class OpenAIInvoker(NNInvoker): from nn4k.utils.config_parsing import get_string_field from nn4k.utils.config_parsing import get_positive_int_field - openai_api_key = get_string_field( - nn_config, "openai_api_key", "openai api key" - ) + openai_api_key = get_string_field(nn_config, "openai_api_key", "openai api key") openai_api_base = get_string_field( nn_config, "openai_api_base", "openai api base" ) diff --git a/python/nn4k/utils/config_parsing.py b/python/nn4k/utils/config_parsing.py index abc53849..98b23a69 100644 --- a/python/nn4k/utils/config_parsing.py +++ b/python/nn4k/utils/config_parsing.py @@ -24,6 +24,7 @@ def preprocess_config(nn_config: Union[str, dict]) -> dict: raise ValueError("cannot decode config file") return nn_config + def get_field(nn_config: dict, name: str, text: str) -> Any: value = nn_config.get(name) if value is None: @@ -31,6 +32,7 @@ def get_field(nn_config: dict, name: str, text: str) -> Any: raise ValueError(message) return value + def get_string_field(nn_config: dict, name: str, text: str) -> str: value = get_field(nn_config, name, text) if not isinstance(value, str): @@ -39,6 +41,7 @@ def get_string_field(nn_config: dict, name: str, text: str) -> str: raise TypeError(message) return value + def get_int_field(nn_config: dict, name: str, text: str) -> int: value = get_field(nn_config, name, text) if not isinstance(value, int): @@ -47,6 +50,7 @@ def get_int_field(nn_config: dict, name: str, text: str) -> int: raise TypeError(message) return value + def get_positive_int_field(nn_config: dict, name: str, text: str) -> int: value = get_int_field(nn_config, name, text) if value <= 0: