style(nn4k): format nn4k .py files

xionghuaidong 2023-12-19 20:31:58 +08:00
parent 8bbbd352c7
commit cfa714f952
3 changed files with 35 additions and 31 deletions


@@ -17,12 +17,8 @@ class HfLLMExecutor(NNExecutor):
     def _parse_config(cls, nn_config: dict) -> dict:
         from nn4k.utils.config_parsing import get_string_field
 
-        nn_name = get_string_field(
-            nn_config, "nn_name", "NN model name"
-        )
-        nn_version= get_string_field(
-            nn_config, "nn_version", "NN model version"
-        )
+        nn_name = get_string_field(nn_config, "nn_name", "NN model name")
+        nn_version = get_string_field(nn_config, "nn_version", "NN model version")
         config = dict(
             nn_name=nn_name,
             nn_version=nn_version,
@@ -52,13 +48,11 @@ class HfLLMExecutor(NNExecutor):
         model_path = self._nn_name
         revision = self._nn_version
         use_fast_tokenizer = False
-        device = self._nn_config.get('nn_device')
+        device = self._nn_config.get("nn_device")
         if device is None:
-            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            device = "cuda" if torch.cuda.is_available() else "cpu"
         tokenizer = AutoTokenizer.from_pretrained(
-            model_path,
-            use_fast=use_fast_tokenizer,
-            revision=revision
+            model_path, use_fast=use_fast_tokenizer, revision=revision
         )
         model = AutoModelForCausalLM.from_pretrained(
             model_path,
@@ -84,23 +78,31 @@ class HfLLMExecutor(NNExecutor):
     def inference(self, data, **kwargs):
         nn_tokenizer = self._get_tokenizer()
         nn_model = self._get_model()
-        input_ids = nn_tokenizer(data,
-                                 padding=True,
-                                 return_token_type_ids=False,
-                                 return_tensors="pt",
-                                 truncation=True,
-                                 max_length=64).to(self._nn_device)
-        output_ids = nn_model.generate(**input_ids,
-                                       max_new_tokens=1024,
-                                       do_sample=False,
-                                       eos_token_id=nn_tokenizer.eos_token_id,
-                                       pad_token_id=nn_tokenizer.pad_token_id)
-        outputs = [nn_tokenizer.decode(output_id[len(input_ids["input_ids"][idx]):],
-                                       skip_special_tokens=True)
-                   for idx, output_id in enumerate(output_ids)]
+        input_ids = nn_tokenizer(
+            data,
+            padding=True,
+            return_token_type_ids=False,
+            return_tensors="pt",
+            truncation=True,
+            max_length=64,
+        ).to(self._nn_device)
+        output_ids = nn_model.generate(
+            **input_ids,
+            max_new_tokens=1024,
+            do_sample=False,
+            eos_token_id=nn_tokenizer.eos_token_id,
+            pad_token_id=nn_tokenizer.pad_token_id
+        )
+        outputs = [
+            nn_tokenizer.decode(
+                output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
+            )
+            for idx, output_id in enumerate(output_ids)
+        ]
 
-        outputs = [nn_tokenizer.decode(output_id[:],
-                                       skip_special_tokens=True)
-                   for idx, output_id in enumerate(output_ids)]
+        outputs = [
+            nn_tokenizer.decode(output_id[:], skip_special_tokens=True)
+            for idx, output_id in enumerate(output_ids)
+        ]
         return outputs
 
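For orientation, a minimal sketch of how the reformatted executor might be driven. Only _parse_config, the loading code, and inference appear in this diff, so the construction and loading entry points below are assumptions; note also that the second outputs assignment overwrites the first, so inference returns the full decoded sequences rather than only the newly generated text.

    # Hypothetical driver for HfLLMExecutor; from_config and load_model are
    # assumed entry points, not shown in this commit.
    config = {
        "nn_name": "gpt2",     # any HF causal-LM checkpoint; becomes model_path
        "nn_version": "main",  # becomes revision
        "nn_device": "cpu",    # optional; defaults to cuda when available
    }
    executor = HfLLMExecutor.from_config(config)
    executor.load_model()
    print(executor.inference(["The quick brown fox"]))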


@@ -20,9 +20,7 @@ class OpenAIInvoker(NNInvoker):
         from nn4k.utils.config_parsing import get_string_field
         from nn4k.utils.config_parsing import get_positive_int_field
 
-        openai_api_key = get_string_field(
-            nn_config, "openai_api_key", "openai api key"
-        )
+        openai_api_key = get_string_field(nn_config, "openai_api_key", "openai api key")
         openai_api_base = get_string_field(
             nn_config, "openai_api_base", "openai api base"
         )
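For illustration, the validation behaviour these helpers give the invoker, using a made-up configuration (the key and base URL below are placeholders, not real values):

    nn_config = {
        "openai_api_key": "sk-placeholder",              # not a real key
        "openai_api_base": "https://api.openai.com/v1",  # assumed typical value
    }
    # Returns the value when present and a str; per config_parsing below,
    # a missing field raises ValueError and a non-string raises TypeError.
    openai_api_key = get_string_field(nn_config, "openai_api_key", "openai api key")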


@@ -24,6 +24,7 @@ def preprocess_config(nn_config: Union[str, dict]) -> dict:
         raise ValueError("cannot decode config file")
     return nn_config
 
+
 def get_field(nn_config: dict, name: str, text: str) -> Any:
     value = nn_config.get(name)
     if value is None:
@@ -31,6 +32,7 @@ def get_field(nn_config: dict, name: str, text: str) -> Any:
         raise ValueError(message)
     return value
 
+
 def get_string_field(nn_config: dict, name: str, text: str) -> str:
     value = get_field(nn_config, name, text)
     if not isinstance(value, str):
@@ -39,6 +41,7 @@ def get_string_field(nn_config: dict, name: str, text: str) -> str:
         raise TypeError(message)
     return value
 
+
 def get_int_field(nn_config: dict, name: str, text: str) -> int:
     value = get_field(nn_config, name, text)
     if not isinstance(value, int):
@@ -47,6 +50,7 @@ def get_int_field(nn_config: dict, name: str, text: str) -> int:
         raise TypeError(message)
     return value
 
+
 def get_positive_int_field(nn_config: dict, name: str, text: str) -> int:
     value = get_int_field(nn_config, name, text)
     if value <= 0:
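Pieced together from the context lines, the last helper reads roughly as follows after this commit. The hunk cuts off inside the value <= 0 branch, so the exception type and message wording below are assumptions modeled on the sibling helpers:

    def get_positive_int_field(nn_config: dict, name: str, text: str) -> int:
        # get_int_field raises ValueError when the field is missing and
        # TypeError when it is not an int (see the hunks above).
        value = get_int_field(nn_config, name, text)
        if value <= 0:
            # Assumed wording; the diff ends before this branch's body.
            message = "%s must be a positive integer; got: %r" % (text, value)
            raise ValueError(message)
        return value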