mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-08-01 13:32:15 +00:00
style(nn4k): format nn4k .py files
This commit is contained in:
parent
8bbbd352c7
commit
cfa714f952
@ -17,12 +17,8 @@ class HfLLMExecutor(NNExecutor):
|
|||||||
def _parse_config(cls, nn_config: dict) -> dict:
|
def _parse_config(cls, nn_config: dict) -> dict:
|
||||||
from nn4k.utils.config_parsing import get_string_field
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
|
||||||
nn_name = get_string_field(
|
nn_name = get_string_field(nn_config, "nn_name", "NN model name")
|
||||||
nn_config, "nn_name", "NN model name"
|
nn_version = get_string_field(nn_config, "nn_version", "NN model version")
|
||||||
)
|
|
||||||
nn_version= get_string_field(
|
|
||||||
nn_config, "nn_version", "NN model version"
|
|
||||||
)
|
|
||||||
config = dict(
|
config = dict(
|
||||||
nn_name=nn_name,
|
nn_name=nn_name,
|
||||||
nn_version=nn_version,
|
nn_version=nn_version,
|
||||||
@ -52,13 +48,11 @@ class HfLLMExecutor(NNExecutor):
|
|||||||
model_path = self._nn_name
|
model_path = self._nn_name
|
||||||
revision = self._nn_version
|
revision = self._nn_version
|
||||||
use_fast_tokenizer = False
|
use_fast_tokenizer = False
|
||||||
device = self._nn_config.get('nn_device')
|
device = self._nn_config.get("nn_device")
|
||||||
if device is None:
|
if device is None:
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_path,
|
model_path, use_fast=use_fast_tokenizer, revision=revision
|
||||||
use_fast=use_fast_tokenizer,
|
|
||||||
revision=revision
|
|
||||||
)
|
)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
@ -84,23 +78,31 @@ class HfLLMExecutor(NNExecutor):
|
|||||||
def inference(self, data, **kwargs):
|
def inference(self, data, **kwargs):
|
||||||
nn_tokenizer = self._get_tokenizer()
|
nn_tokenizer = self._get_tokenizer()
|
||||||
nn_model = self._get_model()
|
nn_model = self._get_model()
|
||||||
input_ids = nn_tokenizer(data,
|
input_ids = nn_tokenizer(
|
||||||
padding=True,
|
data,
|
||||||
return_token_type_ids=False,
|
padding=True,
|
||||||
return_tensors="pt",
|
return_token_type_ids=False,
|
||||||
truncation=True,
|
return_tensors="pt",
|
||||||
max_length=64).to(self._nn_device)
|
truncation=True,
|
||||||
output_ids = nn_model.generate(**input_ids,
|
max_length=64,
|
||||||
max_new_tokens=1024,
|
).to(self._nn_device)
|
||||||
do_sample=False,
|
output_ids = nn_model.generate(
|
||||||
eos_token_id=nn_tokenizer.eos_token_id,
|
**input_ids,
|
||||||
pad_token_id=nn_tokenizer.pad_token_id)
|
max_new_tokens=1024,
|
||||||
outputs = [nn_tokenizer.decode(output_id[len(input_ids["input_ids"][idx]):],
|
do_sample=False,
|
||||||
skip_special_tokens=True)
|
eos_token_id=nn_tokenizer.eos_token_id,
|
||||||
for idx, output_id in enumerate(output_ids)]
|
pad_token_id=nn_tokenizer.pad_token_id
|
||||||
|
)
|
||||||
|
outputs = [
|
||||||
|
nn_tokenizer.decode(
|
||||||
|
output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
for idx, output_id in enumerate(output_ids)
|
||||||
|
]
|
||||||
|
|
||||||
outputs = [nn_tokenizer.decode(output_id[:],
|
outputs = [
|
||||||
skip_special_tokens=True)
|
nn_tokenizer.decode(output_id[:], skip_special_tokens=True)
|
||||||
for idx, output_id in enumerate(output_ids)]
|
for idx, output_id in enumerate(output_ids)
|
||||||
|
]
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
@ -20,9 +20,7 @@ class OpenAIInvoker(NNInvoker):
|
|||||||
from nn4k.utils.config_parsing import get_string_field
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
from nn4k.utils.config_parsing import get_positive_int_field
|
from nn4k.utils.config_parsing import get_positive_int_field
|
||||||
|
|
||||||
openai_api_key = get_string_field(
|
openai_api_key = get_string_field(nn_config, "openai_api_key", "openai api key")
|
||||||
nn_config, "openai_api_key", "openai api key"
|
|
||||||
)
|
|
||||||
openai_api_base = get_string_field(
|
openai_api_base = get_string_field(
|
||||||
nn_config, "openai_api_base", "openai api base"
|
nn_config, "openai_api_base", "openai api base"
|
||||||
)
|
)
|
||||||
|
@ -24,6 +24,7 @@ def preprocess_config(nn_config: Union[str, dict]) -> dict:
|
|||||||
raise ValueError("cannot decode config file")
|
raise ValueError("cannot decode config file")
|
||||||
return nn_config
|
return nn_config
|
||||||
|
|
||||||
|
|
||||||
def get_field(nn_config: dict, name: str, text: str) -> Any:
|
def get_field(nn_config: dict, name: str, text: str) -> Any:
|
||||||
value = nn_config.get(name)
|
value = nn_config.get(name)
|
||||||
if value is None:
|
if value is None:
|
||||||
@ -31,6 +32,7 @@ def get_field(nn_config: dict, name: str, text: str) -> Any:
|
|||||||
raise ValueError(message)
|
raise ValueError(message)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def get_string_field(nn_config: dict, name: str, text: str) -> str:
|
def get_string_field(nn_config: dict, name: str, text: str) -> str:
|
||||||
value = get_field(nn_config, name, text)
|
value = get_field(nn_config, name, text)
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
@ -39,6 +41,7 @@ def get_string_field(nn_config: dict, name: str, text: str) -> str:
|
|||||||
raise TypeError(message)
|
raise TypeError(message)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def get_int_field(nn_config: dict, name: str, text: str) -> int:
|
def get_int_field(nn_config: dict, name: str, text: str) -> int:
|
||||||
value = get_field(nn_config, name, text)
|
value = get_field(nn_config, name, text)
|
||||||
if not isinstance(value, int):
|
if not isinstance(value, int):
|
||||||
@ -47,6 +50,7 @@ def get_int_field(nn_config: dict, name: str, text: str) -> int:
|
|||||||
raise TypeError(message)
|
raise TypeError(message)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def get_positive_int_field(nn_config: dict, name: str, text: str) -> int:
|
def get_positive_int_field(nn_config: dict, name: str, text: str) -> int:
|
||||||
value = get_int_field(nn_config, name, text)
|
value = get_int_field(nn_config, name, text)
|
||||||
if value <= 0:
|
if value <= 0:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user