style(nn4k): format nn4k .py files

This commit is contained in:
xionghuaidong 2023-12-19 20:31:58 +08:00
parent 8bbbd352c7
commit cfa714f952
3 changed files with 35 additions and 31 deletions

View File

@ -17,12 +17,8 @@ class HfLLMExecutor(NNExecutor):
def _parse_config(cls, nn_config: dict) -> dict: def _parse_config(cls, nn_config: dict) -> dict:
from nn4k.utils.config_parsing import get_string_field from nn4k.utils.config_parsing import get_string_field
nn_name = get_string_field( nn_name = get_string_field(nn_config, "nn_name", "NN model name")
nn_config, "nn_name", "NN model name" nn_version = get_string_field(nn_config, "nn_version", "NN model version")
)
nn_version= get_string_field(
nn_config, "nn_version", "NN model version"
)
config = dict( config = dict(
nn_name=nn_name, nn_name=nn_name,
nn_version=nn_version, nn_version=nn_version,
@ -52,13 +48,11 @@ class HfLLMExecutor(NNExecutor):
model_path = self._nn_name model_path = self._nn_name
revision = self._nn_version revision = self._nn_version
use_fast_tokenizer = False use_fast_tokenizer = False
device = self._nn_config.get('nn_device') device = self._nn_config.get("nn_device")
if device is None: if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_path, model_path, use_fast=use_fast_tokenizer, revision=revision
use_fast=use_fast_tokenizer,
revision=revision
) )
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_path, model_path,
@ -84,23 +78,31 @@ class HfLLMExecutor(NNExecutor):
def inference(self, data, **kwargs): def inference(self, data, **kwargs):
nn_tokenizer = self._get_tokenizer() nn_tokenizer = self._get_tokenizer()
nn_model = self._get_model() nn_model = self._get_model()
input_ids = nn_tokenizer(data, input_ids = nn_tokenizer(
data,
padding=True, padding=True,
return_token_type_ids=False, return_token_type_ids=False,
return_tensors="pt", return_tensors="pt",
truncation=True, truncation=True,
max_length=64).to(self._nn_device) max_length=64,
output_ids = nn_model.generate(**input_ids, ).to(self._nn_device)
output_ids = nn_model.generate(
**input_ids,
max_new_tokens=1024, max_new_tokens=1024,
do_sample=False, do_sample=False,
eos_token_id=nn_tokenizer.eos_token_id, eos_token_id=nn_tokenizer.eos_token_id,
pad_token_id=nn_tokenizer.pad_token_id) pad_token_id=nn_tokenizer.pad_token_id
outputs = [nn_tokenizer.decode(output_id[len(input_ids["input_ids"][idx]):], )
skip_special_tokens=True) outputs = [
for idx, output_id in enumerate(output_ids)] nn_tokenizer.decode(
output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
)
for idx, output_id in enumerate(output_ids)
]
outputs = [nn_tokenizer.decode(output_id[:], outputs = [
skip_special_tokens=True) nn_tokenizer.decode(output_id[:], skip_special_tokens=True)
for idx, output_id in enumerate(output_ids)] for idx, output_id in enumerate(output_ids)
]
return outputs return outputs

View File

@ -20,9 +20,7 @@ class OpenAIInvoker(NNInvoker):
from nn4k.utils.config_parsing import get_string_field from nn4k.utils.config_parsing import get_string_field
from nn4k.utils.config_parsing import get_positive_int_field from nn4k.utils.config_parsing import get_positive_int_field
openai_api_key = get_string_field( openai_api_key = get_string_field(nn_config, "openai_api_key", "openai api key")
nn_config, "openai_api_key", "openai api key"
)
openai_api_base = get_string_field( openai_api_base = get_string_field(
nn_config, "openai_api_base", "openai api base" nn_config, "openai_api_base", "openai api base"
) )

View File

@ -24,6 +24,7 @@ def preprocess_config(nn_config: Union[str, dict]) -> dict:
raise ValueError("cannot decode config file") raise ValueError("cannot decode config file")
return nn_config return nn_config
def get_field(nn_config: dict, name: str, text: str) -> Any: def get_field(nn_config: dict, name: str, text: str) -> Any:
value = nn_config.get(name) value = nn_config.get(name)
if value is None: if value is None:
@ -31,6 +32,7 @@ def get_field(nn_config: dict, name: str, text: str) -> Any:
raise ValueError(message) raise ValueError(message)
return value return value
def get_string_field(nn_config: dict, name: str, text: str) -> str: def get_string_field(nn_config: dict, name: str, text: str) -> str:
value = get_field(nn_config, name, text) value = get_field(nn_config, name, text)
if not isinstance(value, str): if not isinstance(value, str):
@ -39,6 +41,7 @@ def get_string_field(nn_config: dict, name: str, text: str) -> str:
raise TypeError(message) raise TypeError(message)
return value return value
def get_int_field(nn_config: dict, name: str, text: str) -> int: def get_int_field(nn_config: dict, name: str, text: str) -> int:
value = get_field(nn_config, name, text) value = get_field(nn_config, name, text)
if not isinstance(value, int): if not isinstance(value, int):
@ -47,6 +50,7 @@ def get_int_field(nn_config: dict, name: str, text: str) -> int:
raise TypeError(message) raise TypeError(message)
return value return value
def get_positive_int_field(nn_config: dict, name: str, text: str) -> int: def get_positive_int_field(nn_config: dict, name: str, text: str) -> int:
value = get_int_field(nn_config, name, text) value = get_int_field(nn_config, name, text)
if value <= 0: if value <= 0: