Mirror of https://github.com/OpenSPG/openspg.git (synced 2025-11-06 04:52:49 +00:00)
feat(nn4k): implement openai invoker and local hf executor (#57)
Co-authored-by: 基尔 <qy266141@antgroup.com>
Co-authored-by: didicout <julin.jl@antgroup.com>

parent 22ea3ee395
commit 6c3f8584ec
python/nn4k/.gitignore (new file, 3 lines, vendored)
@@ -0,0 +1,3 @@
/*.whl
/*.egg-info/
/build/
python/nn4k/LICENSE (new file, 10 lines)
@@ -0,0 +1,10 @@
Copyright 2023 Ant Group CO., Ltd.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied.
python/nn4k/MANIFEST.in (new file, 2 lines)
@@ -0,0 +1,2 @@
recursive-include nn4k *
recursive-exclude nn4k/examples *
python/nn4k/NN4K_VERSION (new file, 1 line)
@@ -0,0 +1 @@
0.0.2-beta1
python/nn4k/nn4k/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.


__package_name__ = "openspg-nn4k"
__version__ = "0.0.2-beta1"
python/nn4k/nn4k/consts/__init__.py (new file, 43 lines)
@@ -0,0 +1,43 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

NN_NAME_KEY = "nn_name"
NN_NAME_TEXT = "NN model name"

NN_VERSION_KEY = "nn_version"
NN_VERSION_TEXT = "NN model version"
NN_VERSION_DEFAULT = "default"

NN_INVOKER_KEY = "nn_invoker"
NN_INVOKER_TEXT = "NN invoker"

NN_EXECUTOR_KEY = "nn_executor"
NN_EXECUTOR_TEXT = "NN executor"

NN_DEVICE_KEY = "device"
NN_TRUST_REMOTE_CODE_KEY = "trust_remote_code"

NN_OPENAI_MODEL_NAME_KEY = NN_NAME_KEY
NN_OPENAI_MODEL_NAME_TEXT = "openai model name"

NN_OPENAI_API_KEY_KEY = "openai_api_key"
NN_OPENAI_API_KEY_TEXT = "openai api key"

NN_OPENAI_API_BASE_KEY = "openai_api_base"
NN_OPENAI_API_BASE_TEXT = "openai api base"

NN_OPENAI_MAX_TOKENS_KEY = "openai_max_tokens"
NN_OPENAI_MAX_TOKENS_TEXT = "openai max tokens"

NN_OPENAI_GPT4_PREFIX = "gpt-4"
NN_OPENAI_GPT35_PREFIX = "gpt-3.5"

NN_LOCAL_HF_MODEL_CONFIG_FILE = "config.json"
python/nn4k/nn4k/executor/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from nn4k.executor.base import NNExecutor, LLMExecutor
python/nn4k/nn4k/executor/base.py (new file, 161 lines)
@@ -0,0 +1,161 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from abc import ABC, abstractmethod
from typing import Union


class NNExecutor(ABC):
    """
    Entry point of model execution in a certain pod.
    """

    def __init__(self, init_args: dict, **kwargs):
        self._init_args = init_args
        self._kwargs = kwargs
        self._model = None
        self._tokenizer = None

    @property
    def init_args(self):
        """
        Return the `init_args` passed to the executor constructor.
        """
        return self._init_args

    @property
    def kwargs(self):
        """
        Return the `kwargs` passed to the executor constructor.
        """
        return self._kwargs

    @property
    def model(self):
        """
        Return the model object managed by this executor.

        :raises RuntimeError: if the model is not initialized yet
        """
        if self._model is None:
            message = "model is not initialized yet"
            raise RuntimeError(message)
        return self._model

    @property
    def tokenizer(self):
        """
        Return the tokenizer object managed by this executor.

        :raises RuntimeError: if the tokenizer is not initialized yet
        """
        if self._tokenizer is None:
            message = "tokenizer is not initialized yet"
            raise RuntimeError(message)
        return self._tokenizer

    def execute_inference(self, args=None, **kwargs):
        """
        The entry point of batch inference in a certain pod.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support batch inference."
        )

    def inference(self, data, args=None, **kwargs):
        """
        The entry point of inference. Usually for local invokers or model services.
        """
        raise NotImplementedError()

    @abstractmethod
    def load_model(self, args=None, mode=None, **kwargs):
        """
        Implement model loading logic in derived executor classes.

        Implementers should initialize `self._model` and `self._tokenizer`.

        This method will be called by several entry methods in executors and invokers.
        """
        raise NotImplementedError()

    def warmup_inference(self, args=None, **kwargs):
        """
        Implement model warming up logic for inference in derived executor classes.
        """
        pass

    @classmethod
    @abstractmethod
    def from_config(cls, nn_config: Union[str, dict]) -> "NNExecutor":
        """
        Create an NN executor instance from `nn_config`.

        This method is abstract; derived classes must override it by either
        creating executor instances or implementing dispatch logic.

        :param nn_config: config to use, can be a dictionary or the path to a JSON file
        :type nn_config: str or dict
        :rtype: NNExecutor
        """
        from nn4k.nnhub import NNHub
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
        from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
        from nn4k.utils.config_parsing import preprocess_config
        from nn4k.utils.config_parsing import get_string_field
        from nn4k.utils.class_importing import dynamic_import_class

        nn_config = preprocess_config(nn_config)
        nn_executor = nn_config.get(NN_EXECUTOR_KEY)
        if nn_executor is not None:
            nn_executor = get_string_field(nn_config, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT)
            executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
            if not issubclass(executor_class, NNExecutor):
                message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
                raise RuntimeError(message)
            executor = executor_class.from_config(nn_config)
            return executor

        nn_name = nn_config.get(NN_NAME_KEY)
        if nn_name is not None:
            nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
        nn_version = nn_config.get(NN_VERSION_KEY)
        if nn_version is not None:
            nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
        if nn_name is not None:
            hub = NNHub.get_instance()
            executor = hub.get_model_executor(nn_name, nn_version)
            if executor is not None:
                return executor

        message = "can not create executor for NN config"
        if nn_name is not None:
            message += "; model: %r" % nn_name
        if nn_version is not None:
            message += ", version: %r" % nn_version
        raise RuntimeError(message)


class LLMExecutor(NNExecutor):
    def execute_sft(self, args=None, callbacks=None, **kwargs):
        """
        The entry point of SFT execution in a certain pod.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")

    def execute_rl_tuning(self, args=None, callbacks=None, **kwargs):
        """
        The entry point of RL-Tuning execution in a certain pod.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support RL-Tuning."
        )
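For orientation, here is a minimal sketch of the two config shapes that `NNExecutor.from_config` dispatches on. It is illustrative only and not part of this commit; `my_pkg.MyExecutor` and `llama2` are placeholder names.

from nn4k.executor import NNExecutor

# Shape 1: an explicit executor class path; it is imported dynamically and
# must subclass NNExecutor, otherwise a RuntimeError is raised.
executor = NNExecutor.from_config({"nn_executor": "my_pkg.MyExecutor"})

# Shape 2: a model name/version pair; the executor is looked up via
# NNHub.get_model_executor, and a RuntimeError is raised if nothing matches.
executor = NNExecutor.from_config({"nn_name": "llama2", "nn_version": "default"})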
python/nn4k/nn4k/executor/deepke.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from nn4k.executor import NNExecutor


class DeepKeExecutor(NNExecutor):

    pass
python/nn4k/nn4k/executor/hugging_face.py (new file, 104 lines)
@@ -0,0 +1,104 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from typing import Union
from nn4k.executor import LLMExecutor


class HfLLMExecutor(LLMExecutor):
    @classmethod
    def from_config(cls, nn_config: dict) -> "HfLLMExecutor":
        """
        Create an HfLLMExecutor instance from `nn_config`.
        """
        executor = cls(nn_config)
        return executor

    def execute_sft(self, args=None, callbacks=None, **kwargs):
        raise NotImplementedError(
            f"{self.__class__.__name__} will support SFT in the next version."
        )

    def load_model(self, args=None, **kwargs):
        import torch
        from transformers import AutoTokenizer
        from transformers import AutoModelForCausalLM
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
        from nn4k.consts import NN_DEVICE_KEY, NN_TRUST_REMOTE_CODE_KEY
        from nn4k.utils.config_parsing import get_string_field

        nn_config: dict = args or self.init_args
        if self._model is None:
            nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
            nn_version = nn_config.get(NN_VERSION_KEY)
            if nn_version is not None:
                nn_version = get_string_field(
                    nn_config, NN_VERSION_KEY, NN_VERSION_TEXT
                )
            model_path = nn_name
            revision = nn_version
            use_fast_tokenizer = False
            device = nn_config.get(NN_DEVICE_KEY)
            trust_remote_code = nn_config.get(NN_TRUST_REMOTE_CODE_KEY, False)
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                use_fast=use_fast_tokenizer,
                revision=revision,
                trust_remote_code=trust_remote_code,
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float16,
                revision=revision,
                trust_remote_code=trust_remote_code,
            )
            model.to(device)
            self._tokenizer = tokenizer
            self._model = model

    def inference(
        self,
        data,
        max_input_length: int = 1024,
        max_output_length: int = 1024,
        do_sample: bool = False,
        **kwargs,
    ):
        model = self.model
        tokenizer = self.tokenizer
        input_ids = tokenizer(
            data,
            padding=True,
            return_token_type_ids=False,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_length,
        ).to(model.device)
        output_ids = model.generate(
            **input_ids,
            max_new_tokens=max_output_length,
            do_sample=do_sample,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            **kwargs,
        )

        outputs = [
            tokenizer.decode(
                output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
            )
            for idx, output_id in enumerate(output_ids)
        ]
        return outputs
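A minimal usage sketch for the executor above, illustrative and not part of the commit; the model path is a placeholder and torch plus transformers must be installed for loading to succeed.

from nn4k.executor.hugging_face import HfLLMExecutor

nn_config = {
    "nn_name": "/opt/models/my-llm",  # local HF model directory or model id (placeholder)
    "device": "cuda",                 # optional; auto-detected when omitted
    "trust_remote_code": False,
}
executor = HfLLMExecutor.from_config(nn_config)
executor.load_model()
print(executor.inference(["Long long ago, "], max_output_length=128))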
python/nn4k/nn4k/invoker/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from nn4k.invoker.base import NNInvoker, LLMInvoker
python/nn4k/nn4k/invoker/base.py (new file, 186 lines)
@@ -0,0 +1,186 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from abc import ABC, abstractmethod
from enum import Enum
from typing import Union

from nn4k.executor import LLMExecutor


class SubmitMode(Enum):
    K8s = "k8s"
    Docker = "docker"


class NNInvoker(ABC):
    """
    Invoking entry interfaces for NN models.
    One NNInvoker object serves one NN model.

    - Interfaces starting with "submit_" submit a batch task to a remote execution engine.
    - Interfaces starting with "remote_" query a remote service for results.
    - Interfaces starting with "local_" run something locally.
      `warmup_local_model` must be called before any local_xxx interface.
    """

    def __init__(self, init_args: dict, **kwargs):
        self._init_args = init_args
        self._kwargs = kwargs

    @property
    def init_args(self):
        """
        Return the `init_args` passed to the invoker constructor.
        """
        return self._init_args

    @property
    def kwargs(self):
        """
        Return the `kwargs` passed to the invoker constructor.
        """
        return self._kwargs

    @classmethod
    @abstractmethod
    def from_config(cls, nn_config: Union[str, dict]) -> "NNInvoker":
        """
        Create an NN invoker instance from `nn_config`.

        This method is abstract; derived classes must override it by either
        creating invoker instances or implementing dispatch logic.

        :param nn_config: config to use, can be a dictionary or the path to a JSON file
        :type nn_config: str or dict
        :rtype: NNInvoker
        :raises RuntimeError: if the NN config is not recognized
        """
        from nn4k.nnhub import NNHub
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
        from nn4k.consts import NN_INVOKER_KEY, NN_INVOKER_TEXT
        from nn4k.utils.config_parsing import preprocess_config
        from nn4k.utils.config_parsing import get_string_field
        from nn4k.utils.class_importing import dynamic_import_class

        nn_config = preprocess_config(nn_config)
        nn_invoker = nn_config.get(NN_INVOKER_KEY)
        if nn_invoker is not None:
            nn_invoker = get_string_field(nn_config, NN_INVOKER_KEY, NN_INVOKER_TEXT)
            invoker_class = dynamic_import_class(nn_invoker, NN_INVOKER_TEXT)
            if not issubclass(invoker_class, NNInvoker):
                message = "%r is not an %s class" % (nn_invoker, NN_INVOKER_TEXT)
                raise RuntimeError(message)
            invoker = invoker_class.from_config(nn_config)
            return invoker

        hub = NNHub.get_instance()
        invoker = hub.get_invoker(nn_config)
        if invoker is not None:
            return invoker

        nn_name = nn_config.get(NN_NAME_KEY)
        if nn_name is not None:
            nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
        nn_version = nn_config.get(NN_VERSION_KEY)
        if nn_version is not None:
            nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
        message = "can not create invoker for NN config"
        if nn_name is not None:
            message += "; model: %r" % nn_name
        if nn_version is not None:
            message += ", version: %r" % nn_version
        raise RuntimeError(message)

    def submit_inference(self, submit_mode: SubmitMode = SubmitMode.K8s):
        """
        Submit remote batch inference execution.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support batch inference."
        )

    def remote_inference(self, input, **kwargs):
        """
        Inference via existing remote services.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support remote inference."
        )

    def local_inference(self, data, **kwargs):
        """
        Implement local inference in derived invoker classes.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support local inference."
        )

    def warmup_local_model(self):
        """
        Implement local model warming up logic in derived invoker classes.
        """
        pass


class LLMInvoker(NNInvoker):
    def submit_sft(self, submit_mode: SubmitMode = SubmitMode.K8s):
        """
        Submit remote SFT execution.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")

    def submit_rl_tuning(self, submit_mode: SubmitMode = SubmitMode.K8s):
        """
        Submit remote RL-Tuning execution.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support RL-Tuning."
        )

    def local_inference(self, data, **kwargs):
        """
        Implement local inference for local invoker.
        """
        return self._nn_executor.inference(data, **kwargs)

    def warmup_local_model(self):
        """
        Implement local model warming up logic for local invoker.
        """
        from nn4k.nnhub import NNHub
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
        from nn4k.utils.config_parsing import get_string_field

        nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
        nn_version = self.init_args.get(NN_VERSION_KEY)
        if nn_version is not None:
            nn_version = get_string_field(
                self.init_args, NN_VERSION_KEY, NN_VERSION_TEXT
            )
        hub = NNHub.get_instance()
        executor = hub.get_model_executor(nn_name, nn_version)
        if executor is None:
            message = "model %r version %r " % (nn_name, nn_version)
            message += "is not found in the model hub"
            raise RuntimeError(message)
        self._nn_executor: LLMExecutor = executor
        self._nn_executor.load_model()
        self._nn_executor.warmup_inference()

    @classmethod
    def from_config(cls, nn_config: dict) -> "LLMInvoker":
        """
        Create an LLMInvoker instance from `nn_config`.
        """
        invoker = cls(nn_config)
        return invoker
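As the NNInvoker docstring above notes, `warmup_local_model` must precede any local_* call. A hedged sketch, not part of the commit, with a placeholder model directory:

from nn4k.invoker import NNInvoker

invoker = NNInvoker.from_config({"nn_name": "/opt/models/my-llm"})  # resolved via the hub
invoker.warmup_local_model()  # loads and warms up the underlying executor first
print(invoker.local_inference("Long long ago, "))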
python/nn4k/nn4k/invoker/openai.py (new file, 75 lines)
@@ -0,0 +1,75 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from typing import Optional

from nn4k.invoker import NNInvoker


class OpenAIInvoker(NNInvoker):
    def __init__(self, nn_config: dict):
        super().__init__(nn_config)

        import openai
        from nn4k.consts import NN_OPENAI_MODEL_NAME_KEY, NN_OPENAI_MODEL_NAME_TEXT
        from nn4k.consts import NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_KEY_TEXT
        from nn4k.consts import NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
        from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
        from nn4k.utils.config_parsing import get_string_field
        from nn4k.utils.config_parsing import get_positive_int_field

        self.openai_model_name = get_string_field(
            self.init_args, NN_OPENAI_MODEL_NAME_KEY, NN_OPENAI_MODEL_NAME_TEXT
        )
        self.openai_api_key = get_string_field(
            self.init_args, NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_KEY_TEXT
        )
        self.openai_api_base = get_string_field(
            self.init_args, NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
        )
        self.openai_max_tokens = get_positive_int_field(
            self.init_args, NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
        )

        openai.api_key = self.openai_api_key
        openai.api_base = self.openai_api_base

    @classmethod
    def from_config(cls, nn_config: dict) -> "OpenAIInvoker":
        invoker = cls(nn_config)
        return invoker

    def _create_prompt(self, input, **kwargs):
        if isinstance(input, list):
            prompt = input
        else:
            prompt = [input]
        return prompt

    def _create_output(self, input, prompt, completion, **kwargs):
        output = [choice.text for choice in completion.choices]
        return output

    def remote_inference(
        self, input, max_output_length: Optional[int] = None, **kwargs
    ):
        import openai

        if max_output_length is None:
            max_output_length = self.openai_max_tokens
        prompt = self._create_prompt(input, **kwargs)
        completion = openai.Completion.create(
            model=self.openai_model_name,
            prompt=prompt,
            max_tokens=max_output_length,
        )
        output = self._create_output(input, prompt, completion, **kwargs)
        return output
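A configuration sketch for OpenAIInvoker, mirroring the unit test later in this commit; the endpoint and key are placeholders.

from nn4k.invoker import NNInvoker

nn_config = {
    "nn_name": "gpt-3.5-turbo",
    "openai_api_key": "EMPTY",
    "openai_api_base": "http://localhost:38080/v1",
    "openai_max_tokens": 2000,
}
invoker = NNInvoker.from_config(nn_config)  # dispatched to OpenAIInvoker by the hub
print(invoker.remote_inference("Long long ago, "))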
python/nn4k/nn4k/nnhub/__init__.py (new file, 165 lines)
@@ -0,0 +1,165 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from abc import ABC, abstractmethod
from typing import Optional, Union, Tuple, Type

from nn4k.executor import NNExecutor


class NNHub(ABC):
    _hub_instance = None

    @staticmethod
    def get_instance() -> "NNHub":
        """
        Get the NNHub instance. If the instance is not initialized, create a stub `SimpleNNHub`.
        """
        if NNHub._hub_instance is None:
            NNHub._hub_instance = SimpleNNHub()
        return NNHub._hub_instance

    @abstractmethod
    def publish(
        self,
        model_executor: Union[NNExecutor, Tuple[Type[NNExecutor], tuple, dict, tuple]],
        name: str,
        version: str = None,
    ) -> str:
        """
        Publish a model (executor) to the hub.

        :param model_executor: an NNExecutor object which is pickleable,
            or a tuple of (class, init_args, kwargs, weight_ids) for creating an NNExecutor,
            where all 4 arguments are pickleable.

        :param str name: the name of the model, like `llama2`. There is no separate `namespace`;
            use a joined name like `alibaba/qwen` to support such features.

        :param str version: optional. A version is generated automatically if this param is not given.

        :return: the published model version.
        :rtype: str
        """
        pass

    @abstractmethod
    def get_model_executor(
        self, name: str, version: str = None
    ) -> Optional[NNExecutor]:
        """
        Get an NNExecutor instance from the hub.

        :param str name: the name of the model.
        :param str version: the version of the model. The default version is used if this param is not given.
        :return: the NNExecutor instance, or None if not found.
        :rtype: Optional[NNExecutor]
        """
        pass

    @abstractmethod
    def get_invoker(self, nn_config: dict) -> Optional["NNInvoker"]:
        """
        Get an NNInvoker instance from the hub.

        :param dict nn_config: the config dictionary.
        :return: the NNInvoker instance, or None if not found.
        :rtype: Optional[NNInvoker]
        """
        pass

    def start_service(self, name: str, version: str, service_id: str = None, **kwargs):
        raise NotImplementedError("This Hub does not support starting model service.")

    def stop_service(self, name: str, version: str, service_id: str = None, **kwargs):
        raise NotImplementedError("This Hub does not support stopping model service.")

    def get_service(self, name: str, version: str, service_id: str = None):
        raise NotImplementedError("This Hub does not support model services.")


class SimpleNNHub(NNHub):
    def __init__(self) -> None:
        super().__init__()
        self._model_executors = {}

    def _add_executor(
        self,
        executor: Union[NNExecutor, Tuple[Type[NNExecutor], tuple, dict, tuple]],
        name: str,
        version: str = None,
    ):
        from nn4k.consts import NN_VERSION_DEFAULT

        if version is None:
            version = NN_VERSION_DEFAULT
        if self._model_executors.get(name) is None:
            self._model_executors[name] = {version: executor}
        else:
            self._model_executors[name][version] = executor

    def publish(
        self, model_executor: NNExecutor, name: str, version: str = None
    ) -> str:
        from nn4k.consts import NN_VERSION_DEFAULT

        print(
            "WARNING: You are using SimpleNNHub which can only maintain models in memory without data persistence!"
        )
        if version is None:
            version = NN_VERSION_DEFAULT
        self._add_executor(model_executor, name, version)
        return version

    def _create_model_executor(self, cls, init_args, kwargs, weights):
        raise NotImplementedError()

    def get_model_executor(
        self, name: str, version: str = None
    ) -> Optional[NNExecutor]:
        if self._model_executors.get(name) is None:
            return None
        executor = self._model_executors.get(name).get(version)
        if isinstance(executor, NNExecutor):
            return executor
        cls, init_args, kwargs, weights = executor
        executor = self._create_model_executor(cls, init_args, kwargs, weights)
        return executor

    def _add_local_executor(self, nn_config):
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
        from nn4k.executor.hugging_face import HfLLMExecutor
        from nn4k.utils.config_parsing import get_string_field

        executor = HfLLMExecutor.from_config(nn_config)
        nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
        nn_version = nn_config.get(NN_VERSION_KEY)
        if nn_version is not None:
            nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
        self.publish(executor, nn_name, nn_version)

    def get_invoker(self, nn_config: dict) -> Optional["NNInvoker"]:
        from nn4k.invoker import LLMInvoker
        from nn4k.invoker.openai import OpenAIInvoker
        from nn4k.utils.invoker_checking import is_openai_invoker
        from nn4k.utils.invoker_checking import is_local_invoker

        if is_openai_invoker(nn_config):
            invoker = OpenAIInvoker.from_config(nn_config)
            return invoker

        if is_local_invoker(nn_config):
            invoker = LLMInvoker.from_config(nn_config)
            self._add_local_executor(nn_config)
            return invoker

        return None
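A sketch of publishing and retrieving an executor through the default in-memory hub; illustrative only, and `my_executor` stands in for a hypothetical pickleable NNExecutor instance.

from nn4k.nnhub import NNHub

hub = NNHub.get_instance()                      # SimpleNNHub unless another hub was installed
version = hub.publish(my_executor, "my_model")  # returns "default" when no version is given
executor = hub.get_model_executor("my_model", version)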
python/nn4k/nn4k/utils/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
python/nn4k/nn4k/utils/class_importing.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import importlib
from typing import Tuple


def split_module_class_name(name: str, text: str) -> Tuple[str, str]:
    """
    Split `name` into a module name and class name pair.

    :param name: fully qualified class name, e.g. ``foo.bar.MyClass``
    :type name: str
    :param text: describes the kind of the class, used in the exception message
    :type text: str
    :rtype: Tuple[str, str]
    :raises RuntimeError: if `name` is not a fully qualified class name
    """
    i = name.rfind(".")
    if i == -1:
        message = "invalid %s class name: %s" % (text, name)
        raise RuntimeError(message)
    module_name = name[:i]
    class_name = name[i + 1 :]
    return module_name, class_name


def dynamic_import_class(name: str, text: str):
    """
    Import the class specified by `name` dynamically.

    :param name: fully qualified class name, e.g. ``foo.bar.MyClass``
    :type name: str
    :param text: describes the kind of the class, used in the exception message
    :type text: str
    :raises RuntimeError: if `name` is not a fully qualified class name, or
        the class is not in the module specified by `name`
    :raises ModuleNotFoundError: if the module specified by `name` is not found
    """
    module_name, class_name = split_module_class_name(name, text)
    module = importlib.import_module(module_name)
    class_ = getattr(module, class_name, None)
    if class_ is None:
        message = "class %r not found in module %r" % (class_name, module_name)
        raise RuntimeError(message)
    if not isinstance(class_, type):
        message = "%r is not a class" % (name,)
        raise RuntimeError(message)
    return class_
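Example use of the helper above; illustrative only, not part of the commit.

from nn4k.utils.class_importing import dynamic_import_class

executor_class = dynamic_import_class(
    "nn4k.executor.hugging_face.HfLLMExecutor", "NN executor"
)  # returns the HfLLMExecutor class object, or raises RuntimeError for a bad name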
python/nn4k/nn4k/utils/config_parsing.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import json

from typing import Any
from typing import Union


def preprocess_config(nn_config: Union[str, dict]) -> dict:
    """
    Preprocess config `nn_config` into a dictionary.

    * If `nn_config` is already a dictionary, return it as is.

    * If `nn_config` is a string, decode it as a JSON file.

    :param nn_config: config to be preprocessed
    :type nn_config: str or dict
    :return: `nn_config` or `nn_config` decoded as JSON
    :rtype: dict
    :raises ValueError: if the config file specified by `nn_config`
        cannot be decoded as JSON
    """
    try:
        if isinstance(nn_config, str):
            with open(nn_config, "r") as f:
                nn_config = json.load(f)
    except:
        raise ValueError("cannot decode config file")
    return nn_config


def get_field(nn_config: dict, name: str, text: str) -> Any:
    """
    Get the value of the field specified by `name` from the configuration
    dictionary `nn_config`.

    :param str name: name of the field
    :param str text: descriptive text of the name of the field
    :return: value of the field
    :rtype: Any
    :raises ValueError: if the field is not specified in `nn_config`
    """
    value = nn_config.get(name)
    if value is None:
        message = "%s %r not found" % (text, name)
        raise ValueError(message)
    return value


def get_string_field(nn_config: dict, name: str, text: str) -> str:
    """
    Get the value of the string field specified by `name` from the
    configuration dictionary `nn_config`.

    :param str name: name of the field
    :param str text: descriptive text of the name of the field
    :return: value of the field
    :rtype: str
    :raises ValueError: if the field is not specified in `nn_config`
    :raises TypeError: if the value of the field is not a string
    """
    value = get_field(nn_config, name, text)
    if not isinstance(value, str):
        message = "%s %r must be string; " % (text, name)
        message += "%r is invalid" % (value,)
        raise TypeError(message)
    return value


def get_int_field(nn_config: dict, name: str, text: str) -> int:
    """
    Get the value of the integer field specified by `name` from the
    configuration dictionary `nn_config`.

    :param str name: name of the field
    :param str text: descriptive text of the name of the field
    :return: value of the field
    :rtype: int
    :raises ValueError: if the field is not specified in `nn_config`
    :raises TypeError: if the value of the field is not an integer
    """
    value = get_field(nn_config, name, text)
    if not isinstance(value, int):
        message = "%s %r must be integer; " % (text, name)
        message += "%r is invalid" % (value,)
        raise TypeError(message)
    return value


def get_positive_int_field(nn_config: dict, name: str, text: str) -> int:
    """
    Get the value of the positive integer field specified by `name`
    from the configuration dictionary `nn_config`.

    :param str name: name of the field
    :param str text: descriptive text of the name of the field
    :return: value of the field
    :rtype: int
    :raises ValueError: if the field is not specified in `nn_config`, or the
        value of the field is not a positive integer
    :raises TypeError: if the value of the field is not an integer
    """
    value = get_int_field(nn_config, name, text)
    if value <= 0:
        message = "%s %r must be positive integer; " % (text, name)
        message += "%r is invalid" % (value,)
        raise ValueError(message)
    return value
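Illustrative calls to the field helpers above; not part of the commit.

from nn4k.utils.config_parsing import get_string_field, get_positive_int_field

nn_config = {"nn_name": "gpt-3.5-turbo", "openai_max_tokens": 2000}
name = get_string_field(nn_config, "nn_name", "NN model name")  # "gpt-3.5-turbo"
tokens = get_positive_int_field(nn_config, "openai_max_tokens", "openai max tokens")  # 2000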
python/nn4k/nn4k/utils/invoker_checking.py (new file, 62 lines)
@@ -0,0 +1,62 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import os


def is_openai_invoker(nn_config: dict) -> bool:
    """
    Check whether `nn_config` specifies OpenAI invoker.

    :type nn_config: dict
    :rtype: bool
    """
    from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
    from nn4k.consts import NN_OPENAI_API_KEY_KEY
    from nn4k.consts import NN_OPENAI_API_BASE_KEY
    from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY
    from nn4k.consts import NN_OPENAI_GPT4_PREFIX
    from nn4k.consts import NN_OPENAI_GPT35_PREFIX
    from nn4k.utils.config_parsing import get_string_field

    nn_name = nn_config.get(NN_NAME_KEY)
    if nn_name is not None:
        nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
        if nn_name.startswith(NN_OPENAI_GPT4_PREFIX) or nn_name.startswith(
            NN_OPENAI_GPT35_PREFIX
        ):
            return True
    keys = (NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_BASE_KEY, NN_OPENAI_MAX_TOKENS_KEY)
    for key in keys:
        if key in nn_config:
            return True
    return False


def is_local_invoker(nn_config: dict) -> bool:
    """
    Check whether `nn_config` specifies local invoker.

    :type nn_config: dict
    :rtype: bool
    """
    from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
    from nn4k.consts import NN_LOCAL_HF_MODEL_CONFIG_FILE
    from nn4k.utils.config_parsing import get_string_field

    nn_name = nn_config.get(NN_NAME_KEY)
    if nn_name is not None:
        nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
        if os.path.isdir(nn_name):
            file_path = os.path.join(nn_name, NN_LOCAL_HF_MODEL_CONFIG_FILE)
            if os.path.isfile(file_path):
                return True
    return False
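Illustrative behaviour of the two checks above; not part of the commit, and the local path is a placeholder.

from nn4k.utils.invoker_checking import is_openai_invoker, is_local_invoker

is_openai_invoker({"nn_name": "gpt-4"})            # True: the name carries an OpenAI prefix
is_openai_invoker({"openai_api_key": "EMPTY"})     # True: an OpenAI-specific key is present
is_local_invoker({"nn_name": "/opt/models/my-llm"})  # True only if that directory holds a config.json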
python/nn4k/requirements.txt (new file, 1 line)
@@ -0,0 +1 @@
openai<1
python/nn4k/setup.py (new file, 78 lines)
@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import os

from setuptools import setup, find_packages

package_name = "openspg-nn4k"

# version
cwd = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(cwd, "NN4K_VERSION"), "r") as rf:
    version = rf.readline().strip("\n").strip()

# license
license = ""
with open(os.path.join(cwd, "LICENSE"), "r") as rf:
    line = rf.readline()
    while line:
        line = line.strip()
        if line:
            license += "# " + line + "\n"
        else:
            license += "#\n"
        line = rf.readline()

# Generate nn4k.__init__.py
with open(os.path.join(cwd, "nn4k/__init__.py"), "w") as wf:
    content = f"""{license}

__package_name__ = "{package_name}"
__version__ = "{version}"
"""
    wf.write(content)

setup(
    name=package_name,
    version=version,
    description="nn4k",
    url="https://github.com/OpenSPG/openspg",
    packages=find_packages(
        where=".",
        exclude=[
            ".*test.py",
            "*_test.py",
            "*_debug.py",
            "*.txt",
            "tests",
            "tests.*",
            "configs",
            "configs.*",
            "test",
            "test.*",
            "*.tests",
            "*.tests.*",
            "*.pyc",
        ],
    ),
    python_requires=">=3.8",
    install_requires=[
        r.strip()
        for r in open("requirements.txt", "r")
        if not r.strip().startswith("#")
    ],
    include_package_data=True,
    package_data={
        "bin": ["*"],
    },
)
python/nn4k/tests/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
python/nn4k/tests/executor/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
python/nn4k/tests/executor/test_base_executor.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import os
import sys
import unittest


class TestBaseExecutor(unittest.TestCase):
    """
    NNExecutor and LLMExecutor unittest
    """

    def setUp(self):
        # for importing test_stub.py
        dir_path = os.path.dirname(os.path.abspath(__file__))
        sys.path.insert(0, dir_path)

        from nn4k.nnhub import NNHub
        from test_stub import StubHub

        NNHub._hub_instance = StubHub()

    def tearDown(self):
        from nn4k.nnhub import NNHub

        sys.path.pop(0)
        NNHub._hub_instance = None

    def testCustomNNExecutor(self):
        from nn4k.executor import NNExecutor
        from test_stub import StubExecutor

        nn_config = {"nn_executor": "test_stub.StubExecutor"}
        executor = NNExecutor.from_config(nn_config)
        self.assertTrue(isinstance(executor, StubExecutor))
        self.assertEqual(executor.init_args, nn_config)
        self.assertEqual(executor.kwargs, {})

        with self.assertRaises(RuntimeError):
            executor = NNExecutor.from_config({"nn_executor": "test_stub.NotExecutor"})

    def testHubExecutor(self):
        from nn4k.executor import NNExecutor
        from test_stub import StubExecutor

        nn_config = {"nn_name": "test_stub", "nn_version": "default"}
        executor = NNExecutor.from_config(nn_config)
        self.assertTrue(isinstance(executor, StubExecutor))
        self.assertEqual(executor.init_args, nn_config)
        self.assertEqual(executor.kwargs, {"test_stub_executor": True})

    def testExecutorNotExists(self):
        from nn4k.executor import NNExecutor

        with self.assertRaises(RuntimeError):
            executor = NNExecutor.from_config({"nn_name": "not_exists"})


if __name__ == "__main__":
    unittest.main()
python/nn4k/tests/executor/test_hf_llm_executor.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import sys
import unittest
import unittest.mock

from nn4k.executor.hugging_face import HfLLMExecutor


class TestHfLLMExecutor(unittest.TestCase):
    """
    HfLLMExecutor unittest
    """

    def setUp(self):
        self._saved_torch = sys.modules.get("torch")
        self._mocked_torch = unittest.mock.MagicMock()
        sys.modules["torch"] = self._mocked_torch

        self._saved_transformers = sys.modules.get("transformers")
        self._mocked_transformers = unittest.mock.MagicMock()
        sys.modules["transformers"] = self._mocked_transformers

    def tearDown(self):
        del sys.modules["torch"]
        if self._saved_torch is not None:
            sys.modules["torch"] = self._saved_torch

        del sys.modules["transformers"]
        if self._saved_transformers is not None:
            sys.modules["transformers"] = self._saved_transformers

    def testHfLLMExecutor(self):
        nn_config = {
            "nn_name": "/opt/test_model_dir",
            "nn_version": "default",
        }

        executor = HfLLMExecutor.from_config(nn_config)
        executor.load_model()
        executor.inference("input")

        self._mocked_transformers.AutoTokenizer.from_pretrained.assert_called()
        self._mocked_transformers.AutoModelForCausalLM.from_pretrained.assert_called()


if __name__ == "__main__":
    unittest.main()
python/nn4k/tests/executor/test_stub.py (new file, 52 lines)
@@ -0,0 +1,52 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from typing import Optional

from nn4k.executor import NNExecutor, LLMExecutor
from nn4k.nnhub import SimpleNNHub


class StubExecutor(LLMExecutor):
    def load_model(self, args=None, mode=None, **kwargs):
        pass

    def warmup_inference(self, args=None, **kwargs):
        pass

    def inference(self, data, args=None, **kwargs):
        pass

    @classmethod
    def from_config(cls, nn_config: dict) -> "StubExecutor":
        """
        Create a StubExecutor instance from `nn_config`.
        """
        executor = cls(nn_config)
        return executor


class NotExecutor:
    pass


class StubHub(SimpleNNHub):
    def get_model_executor(
        self, name: str, version: str = None
    ) -> Optional[NNExecutor]:
        if name == "test_stub":
            if version is None:
                version = "default"
            executor = StubExecutor(
                {"nn_name": name, "nn_version": version}, test_stub_executor=True
            )
            return executor
        return super().get_model_executor(name, version)
python/nn4k/tests/invoker/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
python/nn4k/tests/invoker/test_base_invoker.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import os
import sys
import unittest


class TestBaseInvoker(unittest.TestCase):
    """
    NNInvoker and LLMInvoker unittest
    """

    def setUp(self):
        # for importing test_stub.py
        dir_path = os.path.dirname(os.path.abspath(__file__))
        sys.path.insert(0, dir_path)

        from nn4k.nnhub import NNHub
        from test_stub import StubHub

        NNHub._hub_instance = StubHub()

    def tearDown(self):
        from nn4k.nnhub import NNHub

        sys.path.pop(0)
        NNHub._hub_instance = None

    def testCustomNNInvoker(self):
        from nn4k.invoker import NNInvoker
        from test_stub import StubInvoker

        nn_config = {"nn_invoker": "test_stub.StubInvoker"}
        invoker = NNInvoker.from_config(nn_config)
        self.assertTrue(isinstance(invoker, StubInvoker))
        self.assertEqual(invoker.init_args, nn_config)
        self.assertEqual(invoker.kwargs, {})

        with self.assertRaises(RuntimeError):
            invoker = NNInvoker.from_config({"nn_invoker": "test_stub.NotInvoker"})

    def testHubInvoker(self):
        from nn4k.invoker import NNInvoker
        from test_stub import StubInvoker

        nn_config = {"nn_name": "test_stub"}
        invoker = NNInvoker.from_config(nn_config)
        self.assertTrue(isinstance(invoker, StubInvoker))
        self.assertEqual(invoker.init_args, nn_config)
        self.assertEqual(invoker.kwargs, {"test_stub_invoker": True})

    def testInvokerNotExists(self):
        from nn4k.invoker import NNInvoker

        with self.assertRaises(RuntimeError):
            invoker = NNInvoker.from_config({"nn_name": "not_exists"})

    def testLocalInvoker(self):
        from nn4k.invoker import NNInvoker
        from test_stub import StubInvoker

        nn_config = {"nn_name": "test_stub"}
        invoker = NNInvoker.from_config(nn_config)
        self.assertTrue(isinstance(invoker, StubInvoker))
        self.assertEqual(invoker.init_args, nn_config)
        self.assertEqual(invoker.kwargs, {"test_stub_invoker": True})

        invoker.warmup_local_model()
        invoker._nn_executor.inference_result = "inference result"
        result = invoker.local_inference("input")
        self.assertEqual(result, invoker._nn_executor.inference_result)


if __name__ == "__main__":
    unittest.main()
70
python/nn4k/tests/invoker/test_openai_invoker.py
Normal file
70
python/nn4k/tests/invoker/test_openai_invoker.py
Normal file
@ -0,0 +1,70 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import sys
import unittest
import unittest.mock
from dataclasses import dataclass

from nn4k.invoker import NNInvoker


@dataclass
class MockCompletion:
    choices: list


@dataclass
class MockChoice:
    text: str


class TestOpenAIInvoker(unittest.TestCase):
    """
    OpenAIInvoker unittest
    """

    def setUp(self):
        self._saved_openai = sys.modules.get("openai")
        self._mocked_openai = unittest.mock.MagicMock()
        sys.modules["openai"] = self._mocked_openai

    def tearDown(self):
        del sys.modules["openai"]
        if self._saved_openai is not None:
            sys.modules["openai"] = self._saved_openai

    def testOpenAIInvoker(self):
        nn_config = {
            "nn_name": "gpt-3.5-turbo",
            "openai_api_key": "EMPTY",
            "openai_api_base": "http://localhost:38080/v1",
            "openai_max_tokens": 2000,
        }
        invoker = NNInvoker.from_config(nn_config)
        self.assertEqual(invoker.init_args, nn_config)
        self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
        self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])

        mock_completion = MockCompletion(choices=[MockChoice("a dog named Bolt ...")])
        self._mocked_openai.Completion.create.return_value = mock_completion

        result = invoker.remote_inference("Long long ago, ")
        self._mocked_openai.Completion.create.assert_called_with(
            prompt=["Long long ago, "],
            model=nn_config["nn_name"],
            max_tokens=nn_config["openai_max_tokens"],
        )
        self.assertEqual(result, [mock_completion.choices[0].text])


if __name__ == "__main__":
    unittest.main()
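The test above drives OpenAIInvoker entirely against a mocked openai module. For orientation, a minimal end-to-end sketch of the same call path follows; the config keys and the list-of-completions return shape are taken from the assertions above, while the model name, key, and endpoint are placeholder values (any OpenAI-compatible endpoint should work under this reading):

from nn4k.invoker import NNInvoker

# Placeholder values; openai_api_base may point at any OpenAI-compatible endpoint.
nn_config = {
    "nn_name": "gpt-3.5-turbo",
    "openai_api_key": "sk-placeholder",
    "openai_api_base": "https://api.openai.com/v1",
    "openai_max_tokens": 2000,
}

invoker = NNInvoker.from_config(nn_config)
# The prompt is wrapped into a single-element list and one completion text is
# returned per choice, mirroring the mocked Completion.create call above.
completions = invoker.remote_inference("Long long ago, ")
print(completions[0])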
70
python/nn4k/tests/invoker/test_stub.py
Normal file
@ -0,0 +1,70 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from typing import Optional

from nn4k.invoker import NNInvoker, LLMInvoker
from nn4k.executor import NNExecutor
from nn4k.nnhub import SimpleNNHub


class StubInvoker(LLMInvoker):
    @classmethod
    def from_config(cls, nn_config: dict) -> "StubInvoker":
        """
        Create a StubInvoker instance from `nn_config`.
        """
        invoker = cls(nn_config)
        return invoker


class NotInvoker:
    pass


class StubExecutor(NNExecutor):
    def load_model(self, args=None, mode=None, **kwargs):
        self.load_model_called = True

    def warmup_inference(self, args=None, **kwargs):
        self.warmup_inference_called = True

    def inference(self, data, args=None, **kwargs):
        return self.inference_result

    @classmethod
    def from_config(cls, nn_config: dict) -> "StubExecutor":
        """
        Create a StubExecutor instance from `nn_config`.
        """
        executor = cls(nn_config)
        return executor


class StubHub(SimpleNNHub):
    def get_invoker(self, nn_config: dict) -> Optional[NNInvoker]:
        nn_name = nn_config.get("nn_name")
        if nn_name is not None and nn_name == "test_stub":
            invoker = StubInvoker(nn_config, test_stub_invoker=True)
            return invoker
        return super().get_invoker(nn_config)

    def get_model_executor(
        self, name: str, version: str = None
    ) -> Optional[NNExecutor]:
        if name == "test_stub":
            if version is None:
                version = "default"
            executor = StubExecutor(
                {"nn_name": name, "nn_version": version}, test_stub_executor=True
            )
            return executor
        return super().get_model_executor(name, version)
10
python/nn4k/tests/nnhub/__init__.py
Normal file
@ -0,0 +1,10 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
34
python/nn4k/tests/nnhub/test_base_nnhub.py
Normal file
@ -0,0 +1,34 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import unittest


class TestBaseNNHub(unittest.TestCase):
    """
    NNHub and SimpleNNHub unittest

    The interface and implementation of NNHub may be revised later,
    then unittests will be added.
    """

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def testBaseNNHub(self):
        pass


if __name__ == "__main__":
    unittest.main()
10
python/nn4k/tests/utils/__init__.py
Normal file
@ -0,0 +1,10 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
58
python/nn4k/tests/utils/test_class_importing.py
Normal file
@ -0,0 +1,58 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import unittest


class TestClassImporting(unittest.TestCase):
    """
    module nn4k.utils.class_importing unittest
    """

    def testSplitModuleClassName(self):
        from nn4k.utils.class_importing import split_module_class_name

        pair = split_module_class_name("foo.bar.Baz", "test")
        self.assertEqual(pair, ("foo.bar", "Baz"))

    def testSplitModuleClassNameInvalid(self):
        from nn4k.utils.class_importing import split_module_class_name

        with self.assertRaises(RuntimeError):
            pair = split_module_class_name("foo", "test")

    def testDynamicImportClass(self):
        from nn4k.utils.class_importing import dynamic_import_class

        class_ = dynamic_import_class("unittest.TestCase", "test")
        self.assertEqual(class_, unittest.TestCase)

    def testDynamicImportClassModuleNotFound(self):
        from nn4k.utils.class_importing import dynamic_import_class

        with self.assertRaises(ModuleNotFoundError):
            class_ = dynamic_import_class("not_exists.ClassName", "test")

    def testDynamicImportClassClassNotFound(self):
        from nn4k.utils.class_importing import dynamic_import_class

        with self.assertRaises(RuntimeError):
            class_ = dynamic_import_class("unittest.NotExists", "test")

    def testDynamicImportClassNotClass(self):
        from nn4k.utils.class_importing import dynamic_import_class

        with self.assertRaises(RuntimeError):
            class_ = dynamic_import_class("unittest.mock", "test")


if __name__ == "__main__":
    unittest.main()
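A note on the helper exercised above: dynamic_import_class resolves a dotted "module.Class" string to the class object, raising ModuleNotFoundError when the module is missing and RuntimeError when the attribute is absent or not a class; the second argument appears to be a human-readable label used in error messages (an inference from these tests, not a documented contract). A one-line usage sketch under that reading:

from nn4k.utils.class_importing import dynamic_import_class

case_cls = dynamic_import_class("unittest.TestCase", "test")  # returns unittest.TestCase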
3
python/nn4k/tests/utils/test_config.json
Normal file
@ -0,0 +1,3 @@
{
  "foo": "bar"
}
103
python/nn4k/tests/utils/test_config_parsing.py
Normal file
@ -0,0 +1,103 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import unittest


class TestConfigParsing(unittest.TestCase):
    """
    module nn4k.utils.config_parsing unittest
    """

    def testPreprocessConfigFile(self):
        import os
        from nn4k.utils.config_parsing import preprocess_config

        dir_path = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(dir_path, "test_config.json")
        nn_config = preprocess_config(file_path)
        self.assertEqual(nn_config, {"foo": "bar"})

    def testPreprocessConfigFileNotExists(self):
        import os
        from nn4k.utils.config_parsing import preprocess_config

        dir_path = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(dir_path, "not_exists.json")
        with self.assertRaises(ValueError):
            nn_config = preprocess_config(file_path)

    def testPreprocessConfigDict(self):
        from nn4k.utils.config_parsing import preprocess_config

        conf = {"foo": "bar"}
        nn_config = preprocess_config(conf)
        self.assertEqual(nn_config, conf)

    def testGetField(self):
        from nn4k.utils.config_parsing import get_field

        nn_config = {"foo": "bar"}
        value = get_field(nn_config, "foo", "Foo")
        self.assertEqual(value, "bar")

    def testGetFieldNotExists(self):
        from nn4k.utils.config_parsing import get_field

        nn_config = {"foo": "bar"}
        with self.assertRaises(ValueError):
            value = get_field(nn_config, "not_exists", "not exists")

    def testGetStringField(self):
        from nn4k.utils.config_parsing import get_string_field

        nn_config = {"foo": "bar"}
        value = get_string_field(nn_config, "foo", "Foo")
        self.assertEqual(value, "bar")

    def testGetStringFieldNotString(self):
        from nn4k.utils.config_parsing import get_string_field

        nn_config = {"foo": "bar", "baz": True}
        with self.assertRaises(TypeError):
            value = get_string_field(nn_config, "baz", "Baz")

    def testGetIntField(self):
        from nn4k.utils.config_parsing import get_int_field

        nn_config = {"foo": "bar", "baz": 1000}
        value = get_int_field(nn_config, "baz", "Baz")
        self.assertEqual(value, 1000)

    def testGetIntFieldNotInteger(self):
        from nn4k.utils.config_parsing import get_int_field

        nn_config = {"foo": "bar", "baz": "quux"}
        with self.assertRaises(TypeError):
            value = get_int_field(nn_config, "baz", "Baz")

    def testGetPositiveIntField(self):
        from nn4k.utils.config_parsing import get_positive_int_field

        nn_config = {"foo": "bar", "baz": 1000}
        value = get_positive_int_field(nn_config, "baz", "Baz")
        self.assertEqual(value, 1000)

    def testGetPositiveIntFieldNotPositive(self):
        from nn4k.utils.config_parsing import get_positive_int_field

        nn_config = {"foo": "bar", "baz": 0}
        with self.assertRaises(ValueError):
            value = get_positive_int_field(nn_config, "baz", "Baz")


if __name__ == "__main__":
    unittest.main()
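Read together, these tests outline the contract of nn4k.utils.config_parsing: preprocess_config accepts either a JSON file path or an already-parsed dict, and the get_*_field helpers fetch a key while validating presence and type, with the third argument serving as the human-readable field name for error messages (an inference from the call patterns above). A brief usage sketch under those assumptions:

from nn4k.utils.config_parsing import preprocess_config, get_string_field

nn_config = preprocess_config({"nn_name": "test_stub"})  # a dict passes through unchanged
name = get_string_field(nn_config, "nn_name", "NN model name")  # returns "test_stub"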
49
python/nn4k/tests/utils/test_invoker_checking.py
Normal file
@ -0,0 +1,49 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import unittest


class TestInvokerChecking(unittest.TestCase):
    """
    module nn4k.utils.invoker_checking unittest
    """

    def testIsOpenAIInvoker(self):
        from nn4k.utils.invoker_checking import is_openai_invoker

        self.assertTrue(is_openai_invoker({"nn_name": "gpt-3.5-turbo"}))
        self.assertTrue(is_openai_invoker({"nn_name": "gpt-4"}))
        self.assertFalse(is_openai_invoker({"nn_name": "dummy"}))

        self.assertTrue(is_openai_invoker({"openai_api_key": "EMPTY"}))
        self.assertTrue(
            is_openai_invoker({"openai_api_base": "http://localhost:38000/v1"})
        )
        self.assertTrue(is_openai_invoker({"openai_max_tokens": 1000}))
        self.assertFalse(is_openai_invoker({"foo": "bar"}))

    def testIsLocalInvoker(self):
        import os
        from nn4k.utils.invoker_checking import is_local_invoker

        dir_path = os.path.dirname(os.path.abspath(__file__))
        self.assertFalse(is_local_invoker({"nn_name": dir_path}))

        model_dir_path = os.path.join(dir_path, "test_model_dir")
        self.assertTrue(is_local_invoker({"nn_name": model_dir_path}))

        self.assertFalse(is_local_invoker({"nn_name": "/not_exists"}))
        self.assertFalse(is_local_invoker({"foo": "bar"}))


if __name__ == "__main__":
    unittest.main()
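These checks suggest how invoker dispatch can work from the config alone: is_openai_invoker fires on gpt-* model names or OpenAI-style keys, while is_local_invoker fires when nn_name points at a directory that looks like a local model (in these tests, a directory containing only a config.json). A minimal sketch under that reading; the path below is a placeholder:

from nn4k.utils.invoker_checking import is_openai_invoker, is_local_invoker

is_openai_invoker({"nn_name": "gpt-4"})                # True
is_openai_invoker({"openai_api_key": "EMPTY"})         # True
is_local_invoker({"nn_name": "/path/to/local_model"})  # True only if the directory holds a model config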
1
python/nn4k/tests/utils/test_model_dir/config.json
Normal file
@ -0,0 +1 @@
{}