mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-07-30 20:41:22 +00:00
style(nn4k): format nn4k .py files
This commit is contained in:
parent
8bbbd352c7
commit
cfa714f952
@ -17,12 +17,8 @@ class HfLLMExecutor(NNExecutor):
|
||||
def _parse_config(cls, nn_config: dict) -> dict:
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
|
||||
nn_name = get_string_field(
|
||||
nn_config, "nn_name", "NN model name"
|
||||
)
|
||||
nn_version= get_string_field(
|
||||
nn_config, "nn_version", "NN model version"
|
||||
)
|
||||
nn_name = get_string_field(nn_config, "nn_name", "NN model name")
|
||||
nn_version = get_string_field(nn_config, "nn_version", "NN model version")
|
||||
config = dict(
|
||||
nn_name=nn_name,
|
||||
nn_version=nn_version,
|
||||
@ -52,13 +48,11 @@ class HfLLMExecutor(NNExecutor):
|
||||
model_path = self._nn_name
|
||||
revision = self._nn_version
|
||||
use_fast_tokenizer = False
|
||||
device = self._nn_config.get('nn_device')
|
||||
device = self._nn_config.get("nn_device")
|
||||
if device is None:
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
use_fast=use_fast_tokenizer,
|
||||
revision=revision
|
||||
model_path, use_fast=use_fast_tokenizer, revision=revision
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
@ -84,23 +78,31 @@ class HfLLMExecutor(NNExecutor):
|
||||
def inference(self, data, **kwargs):
|
||||
nn_tokenizer = self._get_tokenizer()
|
||||
nn_model = self._get_model()
|
||||
input_ids = nn_tokenizer(data,
|
||||
padding=True,
|
||||
return_token_type_ids=False,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=64).to(self._nn_device)
|
||||
output_ids = nn_model.generate(**input_ids,
|
||||
max_new_tokens=1024,
|
||||
do_sample=False,
|
||||
eos_token_id=nn_tokenizer.eos_token_id,
|
||||
pad_token_id=nn_tokenizer.pad_token_id)
|
||||
outputs = [nn_tokenizer.decode(output_id[len(input_ids["input_ids"][idx]):],
|
||||
skip_special_tokens=True)
|
||||
for idx, output_id in enumerate(output_ids)]
|
||||
input_ids = nn_tokenizer(
|
||||
data,
|
||||
padding=True,
|
||||
return_token_type_ids=False,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=64,
|
||||
).to(self._nn_device)
|
||||
output_ids = nn_model.generate(
|
||||
**input_ids,
|
||||
max_new_tokens=1024,
|
||||
do_sample=False,
|
||||
eos_token_id=nn_tokenizer.eos_token_id,
|
||||
pad_token_id=nn_tokenizer.pad_token_id
|
||||
)
|
||||
outputs = [
|
||||
nn_tokenizer.decode(
|
||||
output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
|
||||
)
|
||||
for idx, output_id in enumerate(output_ids)
|
||||
]
|
||||
|
||||
outputs = [nn_tokenizer.decode(output_id[:],
|
||||
skip_special_tokens=True)
|
||||
for idx, output_id in enumerate(output_ids)]
|
||||
outputs = [
|
||||
nn_tokenizer.decode(output_id[:], skip_special_tokens=True)
|
||||
for idx, output_id in enumerate(output_ids)
|
||||
]
|
||||
|
||||
return outputs
|
||||
|
@ -20,9 +20,7 @@ class OpenAIInvoker(NNInvoker):
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
from nn4k.utils.config_parsing import get_positive_int_field
|
||||
|
||||
openai_api_key = get_string_field(
|
||||
nn_config, "openai_api_key", "openai api key"
|
||||
)
|
||||
openai_api_key = get_string_field(nn_config, "openai_api_key", "openai api key")
|
||||
openai_api_base = get_string_field(
|
||||
nn_config, "openai_api_base", "openai api base"
|
||||
)
|
||||
|
@ -24,6 +24,7 @@ def preprocess_config(nn_config: Union[str, dict]) -> dict:
|
||||
raise ValueError("cannot decode config file")
|
||||
return nn_config
|
||||
|
||||
|
||||
def get_field(nn_config: dict, name: str, text: str) -> Any:
|
||||
value = nn_config.get(name)
|
||||
if value is None:
|
||||
@ -31,6 +32,7 @@ def get_field(nn_config: dict, name: str, text: str) -> Any:
|
||||
raise ValueError(message)
|
||||
return value
|
||||
|
||||
|
||||
def get_string_field(nn_config: dict, name: str, text: str) -> str:
|
||||
value = get_field(nn_config, name, text)
|
||||
if not isinstance(value, str):
|
||||
@ -39,6 +41,7 @@ def get_string_field(nn_config: dict, name: str, text: str) -> str:
|
||||
raise TypeError(message)
|
||||
return value
|
||||
|
||||
|
||||
def get_int_field(nn_config: dict, name: str, text: str) -> int:
|
||||
value = get_field(nn_config, name, text)
|
||||
if not isinstance(value, int):
|
||||
@ -47,6 +50,7 @@ def get_int_field(nn_config: dict, name: str, text: str) -> int:
|
||||
raise TypeError(message)
|
||||
return value
|
||||
|
||||
|
||||
def get_positive_int_field(nn_config: dict, name: str, text: str) -> int:
|
||||
value = get_int_field(nn_config, name, text)
|
||||
if value <= 0:
|
||||
|
Loading…
x
Reference in New Issue
Block a user