mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-11-06 13:05:28 +00:00
feat(nn4k): implement openai invoker and local hf executor (#57)
Co-authored-by: 基尔 <qy266141@antgroup.com> Co-authored-by: didicout <julin.jl@antgroup.com>
This commit is contained in:
parent
22ea3ee395
commit
6c3f8584ec
@ -42,4 +42,4 @@ header:
|
|||||||
# If you don't want to check dependencies' license compatibility, remove the following part
|
# If you don't want to check dependencies' license compatibility, remove the following part
|
||||||
dependency:
|
dependency:
|
||||||
files:
|
files:
|
||||||
- pom.xml
|
- pom.xml
|
||||||
|
|||||||
3
python/nn4k/.gitignore
vendored
Normal file
3
python/nn4k/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
/*.whl
|
||||||
|
/*.egg-info/
|
||||||
|
/build/
|
||||||
10
python/nn4k/LICENSE
Normal file
10
python/nn4k/LICENSE
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
in compliance with the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
or implied.
|
||||||
2
python/nn4k/MANIFEST.in
Normal file
2
python/nn4k/MANIFEST.in
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
recursive-include nn4k *
|
||||||
|
recursive-exclude nn4k/examples *
|
||||||
1
python/nn4k/NN4K_VERSION
Normal file
1
python/nn4k/NN4K_VERSION
Normal file
@ -0,0 +1 @@
|
|||||||
|
0.0.2-beta1
|
||||||
14
python/nn4k/nn4k/__init__.py
Normal file
14
python/nn4k/nn4k/__init__.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
|
||||||
|
__package_name__ = "openspg-nn4k"
|
||||||
|
__version__ = "0.0.2-beta1"
|
||||||
43
python/nn4k/nn4k/consts/__init__.py
Normal file
43
python/nn4k/nn4k/consts/__init__.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
NN_NAME_KEY = "nn_name"
|
||||||
|
NN_NAME_TEXT = "NN model name"
|
||||||
|
|
||||||
|
NN_VERSION_KEY = "nn_version"
|
||||||
|
NN_VERSION_TEXT = "NN model version"
|
||||||
|
NN_VERSION_DEFAULT = "default"
|
||||||
|
|
||||||
|
NN_INVOKER_KEY = "nn_invoker"
|
||||||
|
NN_INVOKER_TEXT = "NN invoker"
|
||||||
|
|
||||||
|
NN_EXECUTOR_KEY = "nn_executor"
|
||||||
|
NN_EXECUTOR_TEXT = "NN executor"
|
||||||
|
|
||||||
|
NN_DEVICE_KEY = "device"
|
||||||
|
NN_TRUST_REMOTE_CODE_KEY = "trust_remote_code"
|
||||||
|
|
||||||
|
NN_OPENAI_MODEL_NAME_KEY = NN_NAME_KEY
|
||||||
|
NN_OPENAI_MODEL_NAME_TEXT = "openai model name"
|
||||||
|
|
||||||
|
NN_OPENAI_API_KEY_KEY = "openai_api_key"
|
||||||
|
NN_OPENAI_API_KEY_TEXT = "openai api key"
|
||||||
|
|
||||||
|
NN_OPENAI_API_BASE_KEY = "openai_api_base"
|
||||||
|
NN_OPENAI_API_BASE_TEXT = "openai api base"
|
||||||
|
|
||||||
|
NN_OPENAI_MAX_TOKENS_KEY = "openai_max_tokens"
|
||||||
|
NN_OPENAI_MAX_TOKENS_TEXT = "openai max tokens"
|
||||||
|
|
||||||
|
NN_OPENAI_GPT4_PREFIX = "gpt-4"
|
||||||
|
NN_OPENAI_GPT35_PREFIX = "gpt-3.5"
|
||||||
|
|
||||||
|
NN_LOCAL_HF_MODEL_CONFIG_FILE = "config.json"
|
||||||
12
python/nn4k/nn4k/executor/__init__.py
Normal file
12
python/nn4k/nn4k/executor/__init__.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from nn4k.executor.base import NNExecutor, LLMExecutor
|
||||||
161
python/nn4k/nn4k/executor/base.py
Normal file
161
python/nn4k/nn4k/executor/base.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
|
class NNExecutor(ABC):
|
||||||
|
"""
|
||||||
|
Entry point of model execution in a certain pod.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, init_args: dict, **kwargs):
|
||||||
|
self._init_args = init_args
|
||||||
|
self._kwargs = kwargs
|
||||||
|
self._model = None
|
||||||
|
self._tokenizer = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def init_args(self):
|
||||||
|
"""
|
||||||
|
Return the `init_args` passed to the executor constructor.
|
||||||
|
"""
|
||||||
|
return self._init_args
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kwargs(self):
|
||||||
|
"""
|
||||||
|
Return the `kwargs` passed to the executor constructor.
|
||||||
|
"""
|
||||||
|
return self._kwargs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self):
|
||||||
|
"""
|
||||||
|
Return the model object managed by this executor.
|
||||||
|
|
||||||
|
:raises RuntimeError: if the model is not initialized yet
|
||||||
|
"""
|
||||||
|
if self._model is None:
|
||||||
|
message = "model is not initialized yet"
|
||||||
|
raise RuntimeError(message)
|
||||||
|
return self._model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokenizer(self):
|
||||||
|
"""
|
||||||
|
Return the tokenizer object managed by this executor.
|
||||||
|
|
||||||
|
:raises RuntimeError: if the tokenizer is not initialized yet
|
||||||
|
"""
|
||||||
|
if self._tokenizer is None:
|
||||||
|
message = "tokenizer is not initialized yet"
|
||||||
|
raise RuntimeError(message)
|
||||||
|
return self._tokenizer
|
||||||
|
|
||||||
|
def execute_inference(self, args=None, **kwargs):
|
||||||
|
"""
|
||||||
|
The entry point of batch inference in a certain pod.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} does not support batch inference."
|
||||||
|
)
|
||||||
|
|
||||||
|
def inference(self, data, args=None, **kwargs):
|
||||||
|
"""
|
||||||
|
The entry point of inference. Usually for local invokers or model services.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_model(self, args=None, mode=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Implement model loading logic in derived executor classes.
|
||||||
|
|
||||||
|
Implementer should initialize `self._model` and `self._tokenizer`.
|
||||||
|
|
||||||
|
This method will be called by several entry methods in executors and invokers.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def warmup_inference(self, args=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Implement model warming up logic for inference in derived executor classes.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def from_config(cls, nn_config: Union[str, dict]) -> "NNExecutor":
|
||||||
|
"""
|
||||||
|
Create an NN executor instance from `nn_config`.
|
||||||
|
|
||||||
|
This method is abstract, derived class must override it by either
|
||||||
|
creating executor instances or implementating dispatch logic.
|
||||||
|
|
||||||
|
:param nn_config: config to use, can be dictionary or path to a JSON file
|
||||||
|
:type nn_config: str or dict
|
||||||
|
:rtype: NNExecutor
|
||||||
|
"""
|
||||||
|
from nn4k.nnhub import NNHub
|
||||||
|
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||||
|
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||||
|
from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
|
||||||
|
from nn4k.utils.config_parsing import preprocess_config
|
||||||
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
from nn4k.utils.class_importing import dynamic_import_class
|
||||||
|
|
||||||
|
nn_config = preprocess_config(nn_config)
|
||||||
|
nn_executor = nn_config.get(NN_EXECUTOR_KEY)
|
||||||
|
if nn_executor is not None:
|
||||||
|
nn_executor = get_string_field(nn_config, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT)
|
||||||
|
executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
|
||||||
|
if not issubclass(executor_class, NNExecutor):
|
||||||
|
message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
|
||||||
|
raise RuntimeError(message)
|
||||||
|
executor = executor_class.from_config(nn_config)
|
||||||
|
return executor
|
||||||
|
|
||||||
|
nn_name = nn_config.get(NN_NAME_KEY)
|
||||||
|
if nn_name is not None:
|
||||||
|
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||||
|
nn_version = nn_config.get(NN_VERSION_KEY)
|
||||||
|
if nn_version is not None:
|
||||||
|
nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
|
||||||
|
if nn_name is not None:
|
||||||
|
hub = NNHub.get_instance()
|
||||||
|
executor = hub.get_model_executor(nn_name, nn_version)
|
||||||
|
if executor is not None:
|
||||||
|
return executor
|
||||||
|
|
||||||
|
message = "can not create executor for NN config"
|
||||||
|
if nn_name is not None:
|
||||||
|
message += "; model: %r" % nn_name
|
||||||
|
if nn_version is not None:
|
||||||
|
message += ", version: %r" % nn_version
|
||||||
|
raise RuntimeError(message)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMExecutor(NNExecutor):
|
||||||
|
def execute_sft(self, args=None, callbacks=None, **kwargs):
|
||||||
|
"""
|
||||||
|
The entry point of SFT execution in a certain pod.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")
|
||||||
|
|
||||||
|
def execute_rl_tuning(self, args=None, callbacks=None, **kwargs):
|
||||||
|
"""
|
||||||
|
The entry point of SFT execution in a certain pod.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} does not support RL-Tuning."
|
||||||
|
)
|
||||||
17
python/nn4k/nn4k/executor/deepke.py
Normal file
17
python/nn4k/nn4k/executor/deepke.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from nn4k.executor import NNExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class DeepKeExecutor(NNExecutor):
|
||||||
|
|
||||||
|
pass
|
||||||
104
python/nn4k/nn4k/executor/hugging_face.py
Normal file
104
python/nn4k/nn4k/executor/hugging_face.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
from nn4k.executor import LLMExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class HfLLMExecutor(LLMExecutor):
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, nn_config: dict) -> "HfLLMExecutor":
|
||||||
|
"""
|
||||||
|
Create an HfLLMExecutor instance from `nn_config`.
|
||||||
|
"""
|
||||||
|
executor = cls(nn_config)
|
||||||
|
return executor
|
||||||
|
|
||||||
|
def execute_sft(self, args=None, callbacks=None, **kwargs):
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} will support SFT in the next version."
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_model(self, args=None, **kwargs):
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||||
|
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||||
|
from nn4k.consts import NN_DEVICE_KEY, NN_TRUST_REMOTE_CODE_KEY
|
||||||
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
|
||||||
|
nn_config: dict = args or self.init_args
|
||||||
|
if self._model is None:
|
||||||
|
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||||
|
nn_version = nn_config.get(NN_VERSION_KEY)
|
||||||
|
if nn_version is not None:
|
||||||
|
nn_version = get_string_field(
|
||||||
|
nn_config, NN_VERSION_KEY, NN_VERSION_TEXT
|
||||||
|
)
|
||||||
|
model_path = nn_name
|
||||||
|
revision = nn_version
|
||||||
|
use_fast_tokenizer = False
|
||||||
|
device = nn_config.get(NN_DEVICE_KEY)
|
||||||
|
trust_remote_code = nn_config.get(NN_TRUST_REMOTE_CODE_KEY, False)
|
||||||
|
if device is None:
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
use_fast=use_fast_tokenizer,
|
||||||
|
revision=revision,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
revision=revision,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
self._tokenizer = tokenizer
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def inference(
|
||||||
|
self,
|
||||||
|
data,
|
||||||
|
max_input_length: int = 1024,
|
||||||
|
max_output_length: int = 1024,
|
||||||
|
do_sample: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
model = self.model
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
input_ids = tokenizer(
|
||||||
|
data,
|
||||||
|
padding=True,
|
||||||
|
return_token_type_ids=False,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_input_length,
|
||||||
|
).to(model.device)
|
||||||
|
output_ids = model.generate(
|
||||||
|
**input_ids,
|
||||||
|
max_new_tokens=max_output_length,
|
||||||
|
do_sample=do_sample,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = [
|
||||||
|
tokenizer.decode(
|
||||||
|
output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
for idx, output_id in enumerate(output_ids)
|
||||||
|
]
|
||||||
|
return outputs
|
||||||
12
python/nn4k/nn4k/invoker/__init__.py
Normal file
12
python/nn4k/nn4k/invoker/__init__.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from nn4k.invoker.base import NNInvoker, LLMInvoker
|
||||||
186
python/nn4k/nn4k/invoker/base.py
Normal file
186
python/nn4k/nn4k/invoker/base.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from nn4k.executor import LLMExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class SubmitMode(Enum):
|
||||||
|
K8s = "k8s"
|
||||||
|
Docker = "docker"
|
||||||
|
|
||||||
|
|
||||||
|
class NNInvoker(ABC):
|
||||||
|
"""
|
||||||
|
Invoking Entry Interfaces for NN Models.
|
||||||
|
One NNInvoker object is for one NN Model.
|
||||||
|
- Interfaces starting with "submit_" means submitting a batch task to a remote execution engine.
|
||||||
|
- Interfaces starting with "remote_" means querying a remote service for some results.
|
||||||
|
- Interfaces starting with "local_" means running something locally.
|
||||||
|
Must call `warmup_local_model` before calling any local_xxx interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, init_args: dict, **kwargs):
|
||||||
|
self._init_args = init_args
|
||||||
|
self._kwargs = kwargs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def init_args(self):
|
||||||
|
"""
|
||||||
|
Return the `init_args` passed to the invoker constructor.
|
||||||
|
"""
|
||||||
|
return self._init_args
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kwargs(self):
|
||||||
|
"""
|
||||||
|
Return the `kwargs` passed to the invoker constructor.
|
||||||
|
"""
|
||||||
|
return self._kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def from_config(cls, nn_config: Union[str, dict]) -> "NNInvoker":
|
||||||
|
"""
|
||||||
|
Create an NN invoker instance from `nn_config`.
|
||||||
|
|
||||||
|
This method is abstract, derived class must override it by either
|
||||||
|
creating invoker instances or implementating dispatch logic.
|
||||||
|
|
||||||
|
:param nn_config: config to use, can be dictionary or path to a JSON file
|
||||||
|
:type nn_config: str or dict
|
||||||
|
:rtype: NNInvoker
|
||||||
|
:raises RuntimeError: if the NN config is not recognized
|
||||||
|
"""
|
||||||
|
from nn4k.nnhub import NNHub
|
||||||
|
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||||
|
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||||
|
from nn4k.consts import NN_INVOKER_KEY, NN_INVOKER_TEXT
|
||||||
|
from nn4k.utils.config_parsing import preprocess_config
|
||||||
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
from nn4k.utils.class_importing import dynamic_import_class
|
||||||
|
|
||||||
|
nn_config = preprocess_config(nn_config)
|
||||||
|
nn_invoker = nn_config.get(NN_INVOKER_KEY)
|
||||||
|
if nn_invoker is not None:
|
||||||
|
nn_invoker = get_string_field(nn_config, NN_INVOKER_KEY, NN_INVOKER_TEXT)
|
||||||
|
invoker_class = dynamic_import_class(nn_invoker, NN_INVOKER_TEXT)
|
||||||
|
if not issubclass(invoker_class, NNInvoker):
|
||||||
|
message = "%r is not an %s class" % (nn_invoker, NN_INVOKER_TEXT)
|
||||||
|
raise RuntimeError(message)
|
||||||
|
invoker = invoker_class.from_config(nn_config)
|
||||||
|
return invoker
|
||||||
|
|
||||||
|
hub = NNHub.get_instance()
|
||||||
|
invoker = hub.get_invoker(nn_config)
|
||||||
|
if invoker is not None:
|
||||||
|
return invoker
|
||||||
|
|
||||||
|
nn_name = nn_config.get(NN_NAME_KEY)
|
||||||
|
if nn_name is not None:
|
||||||
|
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||||
|
nn_version = nn_config.get(NN_VERSION_KEY)
|
||||||
|
if nn_version is not None:
|
||||||
|
nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
|
||||||
|
message = "can not create invoker for NN config"
|
||||||
|
if nn_name is not None:
|
||||||
|
message += "; model: %r" % nn_name
|
||||||
|
if nn_version is not None:
|
||||||
|
message += ", version: %r" % nn_version
|
||||||
|
raise RuntimeError(message)
|
||||||
|
|
||||||
|
def submit_inference(self, submit_mode: SubmitMode = SubmitMode.K8s):
|
||||||
|
"""
|
||||||
|
Submit remote batch inference execution.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} does not support batch inference."
|
||||||
|
)
|
||||||
|
|
||||||
|
def remote_inference(self, input, **kwargs):
|
||||||
|
"""
|
||||||
|
Inference via existing remote services.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} does not support remote inference."
|
||||||
|
)
|
||||||
|
|
||||||
|
def local_inference(self, data, **kwargs):
|
||||||
|
"""
|
||||||
|
Implement local inference in derived invoker classes.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} does not support local inference."
|
||||||
|
)
|
||||||
|
|
||||||
|
def warmup_local_model(self):
|
||||||
|
"""
|
||||||
|
Implement local model warming up logic in derived invoker classes.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LLMInvoker(NNInvoker):
|
||||||
|
def submit_sft(self, submit_mode: SubmitMode = SubmitMode.K8s):
|
||||||
|
"""
|
||||||
|
Submit remote SFT execution.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")
|
||||||
|
|
||||||
|
def submit_rl_tuning(self, submit_mode: SubmitMode = SubmitMode.K8s):
|
||||||
|
"""
|
||||||
|
Submit remote RL-Tuning execution.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} does not support RL-Tuning."
|
||||||
|
)
|
||||||
|
|
||||||
|
def local_inference(self, data, **kwargs):
|
||||||
|
"""
|
||||||
|
Implement local inference for local invoker.
|
||||||
|
"""
|
||||||
|
return self._nn_executor.inference(data, **kwargs)
|
||||||
|
|
||||||
|
def warmup_local_model(self):
|
||||||
|
"""
|
||||||
|
Implement local model warming up logic for local invoker.
|
||||||
|
"""
|
||||||
|
from nn4k.nnhub import NNHub
|
||||||
|
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||||
|
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||||
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
|
||||||
|
nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
|
||||||
|
nn_version = self.init_args.get(NN_VERSION_KEY)
|
||||||
|
if nn_version is not None:
|
||||||
|
nn_version = get_string_field(
|
||||||
|
self.init_args, NN_VERSION_KEY, NN_VERSION_TEXT
|
||||||
|
)
|
||||||
|
hub = NNHub.get_instance()
|
||||||
|
executor = hub.get_model_executor(nn_name, nn_version)
|
||||||
|
if executor is None:
|
||||||
|
message = "model %r version %r " % (nn_name, nn_version)
|
||||||
|
message += "is not found in the model hub"
|
||||||
|
raise RuntimeError(message)
|
||||||
|
self._nn_executor: LLMExecutor = executor
|
||||||
|
self._nn_executor.load_model()
|
||||||
|
self._nn_executor.warmup_inference()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, nn_config: dict) -> "LLMInvoker":
|
||||||
|
"""
|
||||||
|
Create an LLMInvoker instance from `nn_config`.
|
||||||
|
"""
|
||||||
|
invoker = cls(nn_config)
|
||||||
|
return invoker
|
||||||
75
python/nn4k/nn4k/invoker/openai.py
Normal file
75
python/nn4k/nn4k/invoker/openai.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from nn4k.invoker import NNInvoker
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIInvoker(NNInvoker):
|
||||||
|
def __init__(self, nn_config: dict):
|
||||||
|
super().__init__(nn_config)
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from nn4k.consts import NN_OPENAI_MODEL_NAME_KEY, NN_OPENAI_MODEL_NAME_TEXT
|
||||||
|
from nn4k.consts import NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_KEY_TEXT
|
||||||
|
from nn4k.consts import NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
|
||||||
|
from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
|
||||||
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
from nn4k.utils.config_parsing import get_positive_int_field
|
||||||
|
|
||||||
|
self.openai_model_name = get_string_field(
|
||||||
|
self.init_args, NN_OPENAI_MODEL_NAME_KEY, NN_OPENAI_MODEL_NAME_TEXT
|
||||||
|
)
|
||||||
|
self.openai_api_key = get_string_field(
|
||||||
|
self.init_args, NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_KEY_TEXT
|
||||||
|
)
|
||||||
|
self.openai_api_base = get_string_field(
|
||||||
|
self.init_args, NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
|
||||||
|
)
|
||||||
|
self.openai_max_tokens = get_positive_int_field(
|
||||||
|
self.init_args, NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
|
||||||
|
)
|
||||||
|
|
||||||
|
openai.api_key = self.openai_api_key
|
||||||
|
openai.api_base = self.openai_api_base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, nn_config: dict) -> "OpenAIInvoker":
|
||||||
|
invoker = cls(nn_config)
|
||||||
|
return invoker
|
||||||
|
|
||||||
|
def _create_prompt(self, input, **kwargs):
|
||||||
|
if isinstance(input, list):
|
||||||
|
prompt = input
|
||||||
|
else:
|
||||||
|
prompt = [input]
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def _create_output(self, input, prompt, completion, **kwargs):
|
||||||
|
output = [choice.text for choice in completion.choices]
|
||||||
|
return output
|
||||||
|
|
||||||
|
def remote_inference(
|
||||||
|
self, input, max_output_length: Optional[int] = None, **kwargs
|
||||||
|
):
|
||||||
|
import openai
|
||||||
|
|
||||||
|
if max_output_length is None:
|
||||||
|
max_output_length = self.openai_max_tokens
|
||||||
|
prompt = self._create_prompt(input, **kwargs)
|
||||||
|
completion = openai.Completion.create(
|
||||||
|
model=self.openai_model_name,
|
||||||
|
prompt=prompt,
|
||||||
|
max_tokens=max_output_length,
|
||||||
|
)
|
||||||
|
output = self._create_output(input, prompt, completion, **kwargs)
|
||||||
|
return output
|
||||||
165
python/nn4k/nn4k/nnhub/__init__.py
Normal file
165
python/nn4k/nn4k/nnhub/__init__.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional, Union, Tuple, Type
|
||||||
|
|
||||||
|
from nn4k.executor import NNExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class NNHub(ABC):
|
||||||
|
_hub_instance = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_instance() -> "NNHub":
|
||||||
|
"""
|
||||||
|
Get the NNHub instance. If the instance is not initialized, create a stub `SimpleNNHub`.
|
||||||
|
"""
|
||||||
|
if NNHub._hub_instance is None:
|
||||||
|
NNHub._hub_instance = SimpleNNHub()
|
||||||
|
return NNHub._hub_instance
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def publish(
|
||||||
|
self,
|
||||||
|
model_executor: Union[NNExecutor, Tuple[Type[NNExecutor], tuple, dict, tuple]],
|
||||||
|
name: str,
|
||||||
|
version: str = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Publish a model(executor) to hub.
|
||||||
|
|
||||||
|
:param model_executor: An NNExecutor object, which is pickleable.
|
||||||
|
Or a tuple of (class, init_args, kwargs, weight_ids) for creating an NNExecutor,
|
||||||
|
while all these 4 augments are pickleable.
|
||||||
|
|
||||||
|
:param str name: The name of a model, like `llama2`. We do not have a `namespace`.
|
||||||
|
Use a joined name like `alibaba/qwen` to support such features.
|
||||||
|
|
||||||
|
:param str version: Optional. Auto generate a version if this param is not given.
|
||||||
|
|
||||||
|
:return: The published model version.
|
||||||
|
:rtype: str
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_model_executor(
|
||||||
|
self, name: str, version: str = None
|
||||||
|
) -> Optional[NNExecutor]:
|
||||||
|
"""
|
||||||
|
Get an NNExecutor instance from Hub.
|
||||||
|
|
||||||
|
:param str name: The name of a model.
|
||||||
|
:param str version: The version of a model. Get default version of a model if this param is not given.
|
||||||
|
:return: The ModelExecutor Instance. None for NotFound.
|
||||||
|
:rtype: Optional[NNExecutor]
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_invoker(self, nn_config: dict) -> Optional["NNInvoker"]:
|
||||||
|
"""
|
||||||
|
Get an NNExecutor instance from Hub.
|
||||||
|
|
||||||
|
:param dict nn_config: The config dictionary.
|
||||||
|
:return: The NNExecutor Instance. None for NotFound.
|
||||||
|
:rtype: Optional[NNInvoker]
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def start_service(self, name: str, version: str, service_id: str = None, **kwargs):
|
||||||
|
raise NotImplementedError("This Hub does not support starting model service.")
|
||||||
|
|
||||||
|
def stop_service(self, name: str, version: str, service_id: str = None, **kwargs):
|
||||||
|
raise NotImplementedError("This Hub does not support stopping model service.")
|
||||||
|
|
||||||
|
def get_service(self, name: str, version: str, service_id: str = None):
|
||||||
|
raise NotImplementedError("This Hub does not support model services.")
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleNNHub(NNHub):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._model_executors = {}
|
||||||
|
|
||||||
|
def _add_executor(
|
||||||
|
self,
|
||||||
|
executor: Union[NNExecutor, Tuple[Type[NNExecutor], tuple, dict, tuple]],
|
||||||
|
name: str,
|
||||||
|
version: str = None,
|
||||||
|
):
|
||||||
|
from nn4k.consts import NN_VERSION_DEFAULT
|
||||||
|
|
||||||
|
if version is None:
|
||||||
|
version = NN_VERSION_DEFAULT
|
||||||
|
if self._model_executors.get(name) is None:
|
||||||
|
self._model_executors[name] = {version: executor}
|
||||||
|
else:
|
||||||
|
self._model_executors[name][version] = executor
|
||||||
|
|
||||||
|
def publish(
|
||||||
|
self, model_executor: NNExecutor, name: str, version: str = None
|
||||||
|
) -> str:
|
||||||
|
from nn4k.consts import NN_VERSION_DEFAULT
|
||||||
|
|
||||||
|
print(
|
||||||
|
"WARNING: You are using SimpleNNHub which can only maintain models in memory without data persistence!"
|
||||||
|
)
|
||||||
|
if version is None:
|
||||||
|
version = NN_VERSION_DEFAULT
|
||||||
|
self._add_executor(model_executor, name, version)
|
||||||
|
return version
|
||||||
|
|
||||||
|
def _create_model_executor(self, cls, init_args, kwargs, weights):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_model_executor(
|
||||||
|
self, name: str, version: str = None
|
||||||
|
) -> Optional[NNExecutor]:
|
||||||
|
if self._model_executors.get(name) is None:
|
||||||
|
return None
|
||||||
|
executor = self._model_executors.get(name).get(version)
|
||||||
|
if isinstance(executor, NNExecutor):
|
||||||
|
return executor
|
||||||
|
cls, init_args, kwargs, weights = executor
|
||||||
|
executor = self._create_model_executor(cls, init_args, kwargs, weights)
|
||||||
|
return executor
|
||||||
|
|
||||||
|
def _add_local_executor(self, nn_config):
|
||||||
|
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||||
|
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||||
|
from nn4k.executor.hugging_face import HfLLMExecutor
|
||||||
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
|
||||||
|
executor = HfLLMExecutor.from_config(nn_config)
|
||||||
|
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||||
|
nn_version = nn_config.get(NN_VERSION_KEY)
|
||||||
|
if nn_version is not None:
|
||||||
|
nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
|
||||||
|
self.publish(executor, nn_name, nn_version)
|
||||||
|
|
||||||
|
def get_invoker(self, nn_config: dict) -> Optional["NNInvoker"]:
|
||||||
|
from nn4k.invoker import LLMInvoker
|
||||||
|
from nn4k.invoker.openai import OpenAIInvoker
|
||||||
|
from nn4k.utils.invoker_checking import is_openai_invoker
|
||||||
|
from nn4k.utils.invoker_checking import is_local_invoker
|
||||||
|
|
||||||
|
if is_openai_invoker(nn_config):
|
||||||
|
invoker = OpenAIInvoker.from_config(nn_config)
|
||||||
|
return invoker
|
||||||
|
|
||||||
|
if is_local_invoker(nn_config):
|
||||||
|
invoker = LLMInvoker.from_config(nn_config)
|
||||||
|
self._add_local_executor(nn_config)
|
||||||
|
return invoker
|
||||||
|
|
||||||
|
return None
|
||||||
10
python/nn4k/nn4k/utils/__init__.py
Normal file
10
python/nn4k/nn4k/utils/__init__.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
57
python/nn4k/nn4k/utils/class_importing.py
Normal file
57
python/nn4k/nn4k/utils/class_importing.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def split_module_class_name(name: str, text: str) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Split `name` as module name and class name pair.
|
||||||
|
|
||||||
|
:param name: fully qualified class name, e.g. ``foo.bar.MyClass``
|
||||||
|
:type name: str
|
||||||
|
:param text: describe the kind of the class, used in the exception message
|
||||||
|
:type text: str
|
||||||
|
:rtype: Tuple[str, str]
|
||||||
|
:raises RuntimeError: if `name` is not a fully qualified class name
|
||||||
|
"""
|
||||||
|
i = name.rfind(".")
|
||||||
|
if i == -1:
|
||||||
|
message = "invalid %s class name: %s" % (text, name)
|
||||||
|
raise RuntimeError(message)
|
||||||
|
module_name = name[:i]
|
||||||
|
class_name = name[i + 1 :]
|
||||||
|
return module_name, class_name
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_import_class(name: str, text: str):
|
||||||
|
"""
|
||||||
|
Import the class specified by `name` dyanmically.
|
||||||
|
|
||||||
|
:param name: fully qualified class name, e.g. ``foo.bar.MyClass``
|
||||||
|
:type name: str
|
||||||
|
:param text: describe the kind of the class, use in the exception message
|
||||||
|
:type text: str
|
||||||
|
:raises RuntimeError: if `name` is not a fully qualified class name, or
|
||||||
|
the class is not in the module specified by `name`
|
||||||
|
:raises ModuleNotFoundError: the module specified by `name` is not found
|
||||||
|
"""
|
||||||
|
module_name, class_name = split_module_class_name(name, text)
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
class_ = getattr(module, class_name, None)
|
||||||
|
if class_ is None:
|
||||||
|
message = "class %r not found in module %r" % (class_name, module_name)
|
||||||
|
raise RuntimeError(message)
|
||||||
|
if not isinstance(class_, type):
|
||||||
|
message = "%r is not a class" % (name,)
|
||||||
|
raise RuntimeError(message)
|
||||||
|
return class_
|
||||||
118
python/nn4k/nn4k/utils/config_parsing.py
Normal file
118
python/nn4k/nn4k/utils/config_parsing.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_config(nn_config: Union[str, dict]) -> dict:
|
||||||
|
"""
|
||||||
|
Preprocess config `nn_config` into a dictionary.
|
||||||
|
|
||||||
|
* If `nn_config` is already a dictionary, return it as is.
|
||||||
|
|
||||||
|
* If `nn_config` is a string, decode it as a JSON file.
|
||||||
|
|
||||||
|
:param nn_config: config to be preprocessed
|
||||||
|
:type nn_config: str or dict
|
||||||
|
:return: `nn_config` or `nn_config` decoded as JSON
|
||||||
|
:rtype: dict
|
||||||
|
:raises ValueError: if cannot decode config file specified by
|
||||||
|
`nn_config` as JSON
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if isinstance(nn_config, str):
|
||||||
|
with open(nn_config, "r") as f:
|
||||||
|
nn_config = json.load(f)
|
||||||
|
except:
|
||||||
|
raise ValueError("cannot decode config file")
|
||||||
|
return nn_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_field(nn_config: dict, name: str, text: str) -> Any:
|
||||||
|
"""
|
||||||
|
Get the value of the field specified by `name` from the configuration
|
||||||
|
dictionary `nn_config`.
|
||||||
|
|
||||||
|
:param str name: name of the field
|
||||||
|
:param str name: descriptive text of the name of the field
|
||||||
|
:return: value of the field
|
||||||
|
:rtype: Any
|
||||||
|
:raises ValueError: if the field is not specified in `nn_config`
|
||||||
|
"""
|
||||||
|
value = nn_config.get(name)
|
||||||
|
if value is None:
|
||||||
|
message = "%s %r not found" % (text, name)
|
||||||
|
raise ValueError(message)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_string_field(nn_config: dict, name: str, text: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the value of the string field specified by `name` from the
|
||||||
|
configuration dictionary `nn_config`.
|
||||||
|
|
||||||
|
:param str name: name of the field
|
||||||
|
:param str name: descriptive text of the name of the field
|
||||||
|
:return: value of the field
|
||||||
|
:rtype: str
|
||||||
|
:raises ValueError: if the field is not specified in `nn_config`
|
||||||
|
:raises TypeError: if the value of the field is not a string
|
||||||
|
"""
|
||||||
|
value = get_field(nn_config, name, text)
|
||||||
|
if not isinstance(value, str):
|
||||||
|
message = "%s %r must be string; " % (text, name)
|
||||||
|
message += "%r is invalid" % (value,)
|
||||||
|
raise TypeError(message)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_int_field(nn_config: dict, name: str, text: str) -> int:
|
||||||
|
"""
|
||||||
|
Get the value of the integer field specified by `name` from the
|
||||||
|
configuration dictionary `nn_config`.
|
||||||
|
|
||||||
|
:param str name: name of the field
|
||||||
|
:param str name: descriptive text of the name of the field
|
||||||
|
:return: value of the field
|
||||||
|
:rtype: int
|
||||||
|
:raises ValueError: if the field is not specified in `nn_config`
|
||||||
|
:raises TypeError: if the value of the field is not an integer
|
||||||
|
"""
|
||||||
|
value = get_field(nn_config, name, text)
|
||||||
|
if not isinstance(value, int):
|
||||||
|
message = "%s %r must be integer; " % (text, name)
|
||||||
|
message += "%r is invalid" % (value,)
|
||||||
|
raise TypeError(message)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_positive_int_field(nn_config: dict, name: str, text: str) -> int:
|
||||||
|
"""
|
||||||
|
Get the value of the positive integer field specified by `name`
|
||||||
|
from the configuration dictionary `nn_config`.
|
||||||
|
|
||||||
|
:param str name: name of the field
|
||||||
|
:param str name: descriptive text of the name of the field
|
||||||
|
:return: value of the field
|
||||||
|
:rtype: int
|
||||||
|
:raises ValueError: if the field is not specified in `nn_config`, or the
|
||||||
|
value of the field is not a positive integer
|
||||||
|
:raises TypeError: if the value of the field is not an integer
|
||||||
|
"""
|
||||||
|
value = get_int_field(nn_config, name, text)
|
||||||
|
if value <= 0:
|
||||||
|
message = "%s %r must be positive integer; " % (text, name)
|
||||||
|
message += "%r is invalid" % (value,)
|
||||||
|
raise ValueError(message)
|
||||||
|
return value
|
||||||
62
python/nn4k/nn4k/utils/invoker_checking.py
Normal file
62
python/nn4k/nn4k/utils/invoker_checking.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def is_openai_invoker(nn_config: dict) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether `nn_config` specifies OpenAI invoker.
|
||||||
|
|
||||||
|
:type nn_config: dict
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||||
|
from nn4k.consts import NN_OPENAI_API_KEY_KEY
|
||||||
|
from nn4k.consts import NN_OPENAI_API_BASE_KEY
|
||||||
|
from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY
|
||||||
|
from nn4k.consts import NN_OPENAI_GPT4_PREFIX
|
||||||
|
from nn4k.consts import NN_OPENAI_GPT35_PREFIX
|
||||||
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
|
||||||
|
nn_name = nn_config.get(NN_NAME_KEY)
|
||||||
|
if nn_name is not None:
|
||||||
|
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||||
|
if nn_name.startswith(NN_OPENAI_GPT4_PREFIX) or nn_name.startswith(
|
||||||
|
NN_OPENAI_GPT35_PREFIX
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
keys = (NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_BASE_KEY, NN_OPENAI_MAX_TOKENS_KEY)
|
||||||
|
for key in keys:
|
||||||
|
if key in nn_config:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_local_invoker(nn_config: dict) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether `nn_config` specifies local invoker.
|
||||||
|
|
||||||
|
:type nn_config: dict
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||||
|
from nn4k.consts import NN_LOCAL_HF_MODEL_CONFIG_FILE
|
||||||
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
|
||||||
|
nn_name = nn_config.get(NN_NAME_KEY)
|
||||||
|
if nn_name is not None:
|
||||||
|
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||||
|
if os.path.isdir(nn_name):
|
||||||
|
file_path = os.path.join(nn_name, NN_LOCAL_HF_MODEL_CONFIG_FILE)
|
||||||
|
if os.path.isfile(file_path):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
1
python/nn4k/requirements.txt
Normal file
1
python/nn4k/requirements.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
openai<1
|
||||||
78
python/nn4k/setup.py
Normal file
78
python/nn4k/setup.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
package_name = "openspg-nn4k"
|
||||||
|
|
||||||
|
# version
|
||||||
|
cwd = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
with open(os.path.join(cwd, "NN4K_VERSION"), "r") as rf:
|
||||||
|
version = rf.readline().strip("\n").strip()
|
||||||
|
|
||||||
|
# license
|
||||||
|
license = ""
|
||||||
|
with open(os.path.join(cwd, "LICENSE"), "r") as rf:
|
||||||
|
line = rf.readline()
|
||||||
|
while line:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
license += "# " + line + "\n"
|
||||||
|
else:
|
||||||
|
license += "#\n"
|
||||||
|
line = rf.readline()
|
||||||
|
|
||||||
|
# Generate nn4k.__init__.py
|
||||||
|
with open(os.path.join(cwd, "nn4k/__init__.py"), "w") as wf:
|
||||||
|
content = f"""{license}
|
||||||
|
|
||||||
|
__package_name__ = "{package_name}"
|
||||||
|
__version__ = "{version}"
|
||||||
|
"""
|
||||||
|
wf.write(content)
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name=package_name,
|
||||||
|
version=version,
|
||||||
|
description="nn4k",
|
||||||
|
url="https://github.com/OpenSPG/openspg",
|
||||||
|
packages=find_packages(
|
||||||
|
where=".",
|
||||||
|
exclude=[
|
||||||
|
".*test.py",
|
||||||
|
"*_test.py",
|
||||||
|
"*_debug.py",
|
||||||
|
"*.txt",
|
||||||
|
"tests",
|
||||||
|
"tests.*",
|
||||||
|
"configs",
|
||||||
|
"configs.*",
|
||||||
|
"test",
|
||||||
|
"test.*",
|
||||||
|
"*.tests",
|
||||||
|
"*.tests.*",
|
||||||
|
"*.pyc",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
python_requires=">=3.8",
|
||||||
|
install_requires=[
|
||||||
|
r.strip()
|
||||||
|
for r in open("requirements.txt", "r")
|
||||||
|
if not r.strip().startswith("#")
|
||||||
|
],
|
||||||
|
include_package_data=True,
|
||||||
|
package_data={
|
||||||
|
"bin": ["*"],
|
||||||
|
},
|
||||||
|
)
|
||||||
10
python/nn4k/tests/__init__.py
Normal file
10
python/nn4k/tests/__init__.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
10
python/nn4k/tests/executor/__init__.py
Normal file
10
python/nn4k/tests/executor/__init__.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
69
python/nn4k/tests/executor/test_base_executor.py
Normal file
69
python/nn4k/tests/executor/test_base_executor.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseExecutor(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
NNExecutor and LLMExecutor unittest
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# for importing test_stub.py
|
||||||
|
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.insert(0, dir_path)
|
||||||
|
|
||||||
|
from nn4k.nnhub import NNHub
|
||||||
|
from test_stub import StubHub
|
||||||
|
|
||||||
|
NNHub._hub_instance = StubHub()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
from nn4k.nnhub import NNHub
|
||||||
|
|
||||||
|
sys.path.pop(0)
|
||||||
|
NNHub._hub_instance = None
|
||||||
|
|
||||||
|
def testCustomNNExecutor(self):
|
||||||
|
from nn4k.executor import NNExecutor
|
||||||
|
from test_stub import StubExecutor
|
||||||
|
|
||||||
|
nn_config = {"nn_executor": "test_stub.StubExecutor"}
|
||||||
|
executor = NNExecutor.from_config(nn_config)
|
||||||
|
self.assertTrue(isinstance(executor, StubExecutor))
|
||||||
|
self.assertEqual(executor.init_args, nn_config)
|
||||||
|
self.assertEqual(executor.kwargs, {})
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
executor = NNExecutor.from_config({"nn_executor": "test_stub.NotExecutor"})
|
||||||
|
|
||||||
|
def testHubExecutor(self):
|
||||||
|
from nn4k.executor import NNExecutor
|
||||||
|
from test_stub import StubExecutor
|
||||||
|
|
||||||
|
nn_config = {"nn_name": "test_stub", "nn_version": "default"}
|
||||||
|
executor = NNExecutor.from_config(nn_config)
|
||||||
|
self.assertTrue(isinstance(executor, StubExecutor))
|
||||||
|
self.assertEqual(executor.init_args, nn_config)
|
||||||
|
self.assertEqual(executor.kwargs, {"test_stub_executor": True})
|
||||||
|
|
||||||
|
def testExecutorNotExists(self):
|
||||||
|
from nn4k.executor import NNExecutor
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
executor = NNExecutor.from_config({"nn_name": "not_exists"})
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
57
python/nn4k/tests/executor/test_hf_llm_executor.py
Normal file
57
python/nn4k/tests/executor/test_hf_llm_executor.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
import unittest.mock
|
||||||
|
|
||||||
|
from nn4k.executor.hugging_face import HfLLMExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class TestHfLLMExecutor(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
HfLLMExecutor unittest
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._saved_torch = sys.modules.get("torch")
|
||||||
|
self._mocked_torch = unittest.mock.MagicMock()
|
||||||
|
sys.modules["torch"] = self._mocked_torch
|
||||||
|
|
||||||
|
self._saved_transformers = sys.modules.get("transformers")
|
||||||
|
self._mocked_transformers = unittest.mock.MagicMock()
|
||||||
|
sys.modules["transformers"] = self._mocked_transformers
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
del sys.modules["torch"]
|
||||||
|
if self._saved_torch is not None:
|
||||||
|
sys.modules["torch"] = self._saved_torch
|
||||||
|
|
||||||
|
del sys.modules["transformers"]
|
||||||
|
if self._saved_transformers is not None:
|
||||||
|
sys.modules["transformers"] = self._saved_transformers
|
||||||
|
|
||||||
|
def testHfLLMExecutor(self):
|
||||||
|
nn_config = {
|
||||||
|
"nn_name": "/opt/test_model_dir",
|
||||||
|
"nn_version": "default",
|
||||||
|
}
|
||||||
|
|
||||||
|
executor = HfLLMExecutor.from_config(nn_config)
|
||||||
|
executor.load_model()
|
||||||
|
executor.inference("input")
|
||||||
|
|
||||||
|
self._mocked_transformers.AutoTokenizer.from_pretrained.assert_called()
|
||||||
|
self._mocked_transformers.AutoModelForCausalLM.from_pretrained.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
52
python/nn4k/tests/executor/test_stub.py
Normal file
52
python/nn4k/tests/executor/test_stub.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from nn4k.executor import NNExecutor, LLMExecutor
|
||||||
|
from nn4k.nnhub import SimpleNNHub
|
||||||
|
|
||||||
|
|
||||||
|
class StubExecutor(LLMExecutor):
|
||||||
|
def load_model(self, args=None, mode=None, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def warmup_inference(self, args=None, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def inference(self, data, args=None, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, nn_config: dict) -> "StubExecutor":
|
||||||
|
"""
|
||||||
|
Create a StubExecutor instance from `nn_config`.
|
||||||
|
"""
|
||||||
|
executor = cls(nn_config)
|
||||||
|
return executor
|
||||||
|
|
||||||
|
|
||||||
|
class NotExecutor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StubHub(SimpleNNHub):
|
||||||
|
def get_model_executor(
|
||||||
|
self, name: str, version: str = None
|
||||||
|
) -> Optional[NNExecutor]:
|
||||||
|
if name == "test_stub":
|
||||||
|
if version is None:
|
||||||
|
version = "default"
|
||||||
|
executor = StubExecutor(
|
||||||
|
{"nn_name": name, "nn_version": version}, test_stub_executor=True
|
||||||
|
)
|
||||||
|
return executor
|
||||||
|
return super().get_model_executor(name, version)
|
||||||
10
python/nn4k/tests/invoker/__init__.py
Normal file
10
python/nn4k/tests/invoker/__init__.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
84
python/nn4k/tests/invoker/test_base_invoker.py
Normal file
84
python/nn4k/tests/invoker/test_base_invoker.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseInvoker(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
NNInvoker and LLMInvoker unittest
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# for importing test_stub.py
|
||||||
|
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.insert(0, dir_path)
|
||||||
|
|
||||||
|
from nn4k.nnhub import NNHub
|
||||||
|
from test_stub import StubHub
|
||||||
|
|
||||||
|
NNHub._hub_instance = StubHub()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
from nn4k.nnhub import NNHub
|
||||||
|
|
||||||
|
sys.path.pop(0)
|
||||||
|
NNHub._hub_instance = None
|
||||||
|
|
||||||
|
def testCustomNNInvoker(self):
|
||||||
|
from nn4k.invoker import NNInvoker
|
||||||
|
from test_stub import StubInvoker
|
||||||
|
|
||||||
|
nn_config = {"nn_invoker": "test_stub.StubInvoker"}
|
||||||
|
invoker = NNInvoker.from_config(nn_config)
|
||||||
|
self.assertTrue(isinstance(invoker, StubInvoker))
|
||||||
|
self.assertEqual(invoker.init_args, nn_config)
|
||||||
|
self.assertEqual(invoker.kwargs, {})
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
invoker = NNInvoker.from_config({"nn_invoker": "test_stub.NotInvoker"})
|
||||||
|
|
||||||
|
def testHubInvoker(self):
|
||||||
|
from nn4k.invoker import NNInvoker
|
||||||
|
from test_stub import StubInvoker
|
||||||
|
|
||||||
|
nn_config = {"nn_name": "test_stub"}
|
||||||
|
invoker = NNInvoker.from_config(nn_config)
|
||||||
|
self.assertTrue(isinstance(invoker, StubInvoker))
|
||||||
|
self.assertEqual(invoker.init_args, nn_config)
|
||||||
|
self.assertEqual(invoker.kwargs, {"test_stub_invoker": True})
|
||||||
|
|
||||||
|
def testInvokerNotExists(self):
|
||||||
|
from nn4k.invoker import NNInvoker
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
invoker = NNInvoker.from_config({"nn_name": "not_exists"})
|
||||||
|
|
||||||
|
def testLocalInvoker(self):
|
||||||
|
from nn4k.invoker import NNInvoker
|
||||||
|
from test_stub import StubInvoker
|
||||||
|
|
||||||
|
nn_config = {"nn_name": "test_stub"}
|
||||||
|
invoker = NNInvoker.from_config(nn_config)
|
||||||
|
self.assertTrue(isinstance(invoker, StubInvoker))
|
||||||
|
self.assertEqual(invoker.init_args, nn_config)
|
||||||
|
self.assertEqual(invoker.kwargs, {"test_stub_invoker": True})
|
||||||
|
|
||||||
|
invoker.warmup_local_model()
|
||||||
|
invoker._nn_executor.inference_result = "inference result"
|
||||||
|
result = invoker.local_inference("input")
|
||||||
|
self.assertEqual(result, invoker._nn_executor.inference_result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
70
python/nn4k/tests/invoker/test_openai_invoker.py
Normal file
70
python/nn4k/tests/invoker/test_openai_invoker.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
import unittest.mock
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from nn4k.invoker import NNInvoker
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockCompletion:
|
||||||
|
choices: list
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockChoice:
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIInvoker(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
OpenAIInvoker unittest
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._saved_openai = sys.modules.get("openai")
|
||||||
|
self._mocked_openai = unittest.mock.MagicMock()
|
||||||
|
sys.modules["openai"] = self._mocked_openai
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
del sys.modules["openai"]
|
||||||
|
if self._saved_openai is not None:
|
||||||
|
sys.modules["openai"] = self._saved_openai
|
||||||
|
|
||||||
|
def testOpenAIInvoker(self):
|
||||||
|
nn_config = {
|
||||||
|
"nn_name": "gpt-3.5-turbo",
|
||||||
|
"openai_api_key": "EMPTY",
|
||||||
|
"openai_api_base": "http://localhost:38080/v1",
|
||||||
|
"openai_max_tokens": 2000,
|
||||||
|
}
|
||||||
|
invoker = NNInvoker.from_config(nn_config)
|
||||||
|
self.assertEqual(invoker.init_args, nn_config)
|
||||||
|
self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
|
||||||
|
self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])
|
||||||
|
|
||||||
|
mock_completion = MockCompletion(choices=[MockChoice("a dog named Bolt ...")])
|
||||||
|
self._mocked_openai.Completion.create.return_value = mock_completion
|
||||||
|
|
||||||
|
result = invoker.remote_inference("Long long ago, ")
|
||||||
|
self._mocked_openai.Completion.create.assert_called_with(
|
||||||
|
prompt=["Long long ago, "],
|
||||||
|
model=nn_config["nn_name"],
|
||||||
|
max_tokens=nn_config["openai_max_tokens"],
|
||||||
|
)
|
||||||
|
self.assertEqual(result, [mock_completion.choices[0].text])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
70
python/nn4k/tests/invoker/test_stub.py
Normal file
70
python/nn4k/tests/invoker/test_stub.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from nn4k.invoker import NNInvoker, LLMInvoker
|
||||||
|
from nn4k.executor import NNExecutor
|
||||||
|
from nn4k.nnhub import SimpleNNHub
|
||||||
|
|
||||||
|
|
||||||
|
class StubInvoker(LLMInvoker):
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, nn_config: dict) -> "StubInvoker":
|
||||||
|
"""
|
||||||
|
Create a StubInvoker instance from `nn_config`.
|
||||||
|
"""
|
||||||
|
invoker = cls(nn_config)
|
||||||
|
return invoker
|
||||||
|
|
||||||
|
|
||||||
|
class NotInvoker:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StubExecutor(NNExecutor):
|
||||||
|
def load_model(self, args=None, mode=None, **kwargs):
|
||||||
|
self.load_model_called = True
|
||||||
|
|
||||||
|
def warmup_inference(self, args=None, **kwargs):
|
||||||
|
self.warmup_inference_called = True
|
||||||
|
|
||||||
|
def inference(self, data, args=None, **kwargs):
|
||||||
|
return self.inference_result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, nn_config: dict) -> "StubExecutor":
|
||||||
|
"""
|
||||||
|
Create a StubExecutor instance from `nn_config`.
|
||||||
|
"""
|
||||||
|
executor = cls(nn_config)
|
||||||
|
return executor
|
||||||
|
|
||||||
|
|
||||||
|
class StubHub(SimpleNNHub):
|
||||||
|
def get_invoker(self, nn_config: dict) -> Optional[NNInvoker]:
|
||||||
|
nn_name = nn_config.get("nn_name")
|
||||||
|
if nn_name is not None and nn_name == "test_stub":
|
||||||
|
invoker = StubInvoker(nn_config, test_stub_invoker=True)
|
||||||
|
return invoker
|
||||||
|
return super().get_invoker(nn_config)
|
||||||
|
|
||||||
|
def get_model_executor(
|
||||||
|
self, name: str, version: str = None
|
||||||
|
) -> Optional[NNExecutor]:
|
||||||
|
if name == "test_stub":
|
||||||
|
if version is None:
|
||||||
|
version = "default"
|
||||||
|
executor = StubExecutor(
|
||||||
|
{"nn_name": name, "nn_version": version}, test_stub_executor=True
|
||||||
|
)
|
||||||
|
return executor
|
||||||
|
return super().get_model_executor(name, version)
|
||||||
10
python/nn4k/tests/nnhub/__init__.py
Normal file
10
python/nn4k/tests/nnhub/__init__.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
34
python/nn4k/tests/nnhub/test_base_nnhub.py
Normal file
34
python/nn4k/tests/nnhub/test_base_nnhub.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseNNHub(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
NNHub and SimpleNNHub unittest
|
||||||
|
|
||||||
|
The interface and implementation of NNHub may be revised later,
|
||||||
|
then unittests will be added.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def testBaseNNHub(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
10
python/nn4k/tests/utils/__init__.py
Normal file
10
python/nn4k/tests/utils/__init__.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
58
python/nn4k/tests/utils/test_class_importing.py
Normal file
58
python/nn4k/tests/utils/test_class_importing.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestClassImporting(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
module nn4k.utils.class_importing unittest
|
||||||
|
"""
|
||||||
|
|
||||||
|
def testSplitModuleClassName(self):
|
||||||
|
from nn4k.utils.class_importing import split_module_class_name
|
||||||
|
|
||||||
|
pair = split_module_class_name("foo.bar.Baz", "test")
|
||||||
|
self.assertEqual(pair, ("foo.bar", "Baz"))
|
||||||
|
|
||||||
|
def testSplitModuleClassNameInvalid(self):
|
||||||
|
from nn4k.utils.class_importing import split_module_class_name
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
pair = split_module_class_name("foo", "test")
|
||||||
|
|
||||||
|
def testDynamicImportClass(self):
|
||||||
|
from nn4k.utils.class_importing import dynamic_import_class
|
||||||
|
|
||||||
|
class_ = dynamic_import_class("unittest.TestCase", "test")
|
||||||
|
self.assertEqual(class_, unittest.TestCase)
|
||||||
|
|
||||||
|
def testDynamicImportClassModuleNotFound(self):
|
||||||
|
from nn4k.utils.class_importing import dynamic_import_class
|
||||||
|
|
||||||
|
with self.assertRaises(ModuleNotFoundError):
|
||||||
|
class_ = dynamic_import_class("not_exists.ClassName", "test")
|
||||||
|
|
||||||
|
def testDynamicImportClassClassNotFound(self):
|
||||||
|
from nn4k.utils.class_importing import dynamic_import_class
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
class_ = dynamic_import_class("unittest.NotExists", "test")
|
||||||
|
|
||||||
|
def testDynamicImportClassNotClass(self):
|
||||||
|
from nn4k.utils.class_importing import dynamic_import_class
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
class_ = dynamic_import_class("unittest.mock", "test")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
3
python/nn4k/tests/utils/test_config.json
Normal file
3
python/nn4k/tests/utils/test_config.json
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"foo": "bar"
|
||||||
|
}
|
||||||
103
python/nn4k/tests/utils/test_config_parsing.py
Normal file
103
python/nn4k/tests/utils/test_config_parsing.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigParsing(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
module nn4k.utils.config_parsing unittest
|
||||||
|
"""
|
||||||
|
|
||||||
|
def testPreprocessConfigFile(self):
|
||||||
|
import os
|
||||||
|
from nn4k.utils.config_parsing import preprocess_config
|
||||||
|
|
||||||
|
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
file_path = os.path.join(dir_path, "test_config.json")
|
||||||
|
nn_config = preprocess_config(file_path)
|
||||||
|
self.assertEqual(nn_config, {"foo": "bar"})
|
||||||
|
|
||||||
|
def testPreprocessConfigFileNotExists(self):
|
||||||
|
import os
|
||||||
|
from nn4k.utils.config_parsing import preprocess_config
|
||||||
|
|
||||||
|
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
file_path = os.path.join(dir_path, "not_exists.json")
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
nn_config = preprocess_config(file_path)
|
||||||
|
|
||||||
|
def testPreprocessConfigDict(self):
|
||||||
|
from nn4k.utils.config_parsing import preprocess_config
|
||||||
|
|
||||||
|
conf = {"foo": "bar"}
|
||||||
|
nn_config = preprocess_config(conf)
|
||||||
|
self.assertEqual(nn_config, conf)
|
||||||
|
|
||||||
|
def testGetField(self):
|
||||||
|
from nn4k.utils.config_parsing import get_field
|
||||||
|
|
||||||
|
nn_config = {"foo": "bar"}
|
||||||
|
value = get_field(nn_config, "foo", "Foo")
|
||||||
|
self.assertEqual(value, "bar")
|
||||||
|
|
||||||
|
def testGetFieldNotExists(self):
|
||||||
|
from nn4k.utils.config_parsing import get_field
|
||||||
|
|
||||||
|
nn_config = {"foo": "bar"}
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
value = get_field(nn_config, "not_exists", "not exists")
|
||||||
|
|
||||||
|
def testGetStringField(self):
|
||||||
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
|
||||||
|
nn_config = {"foo": "bar"}
|
||||||
|
value = get_string_field(nn_config, "foo", "Foo")
|
||||||
|
self.assertEqual(value, "bar")
|
||||||
|
|
||||||
|
def testGetStringFieldNotString(self):
|
||||||
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
|
||||||
|
nn_config = {"foo": "bar", "baz": True}
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
value = get_string_field(nn_config, "baz", "Baz")
|
||||||
|
|
||||||
|
def testGetIntField(self):
|
||||||
|
from nn4k.utils.config_parsing import get_int_field
|
||||||
|
|
||||||
|
nn_config = {"foo": "bar", "baz": 1000}
|
||||||
|
value = get_int_field(nn_config, "baz", "Baz")
|
||||||
|
self.assertEqual(value, 1000)
|
||||||
|
|
||||||
|
def testGetIntFieldNotInteger(self):
|
||||||
|
from nn4k.utils.config_parsing import get_int_field
|
||||||
|
|
||||||
|
nn_config = {"foo": "bar", "baz": "quux"}
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
value = get_int_field(nn_config, "baz", "Baz")
|
||||||
|
|
||||||
|
def testGetPositiveIntField(self):
|
||||||
|
from nn4k.utils.config_parsing import get_positive_int_field
|
||||||
|
|
||||||
|
nn_config = {"foo": "bar", "baz": 1000}
|
||||||
|
value = get_positive_int_field(nn_config, "baz", "Baz")
|
||||||
|
self.assertEqual(value, 1000)
|
||||||
|
|
||||||
|
def testGetPositiveIntFieldNotPositive(self):
|
||||||
|
from nn4k.utils.config_parsing import get_positive_int_field
|
||||||
|
|
||||||
|
nn_config = {"foo": "bar", "baz": 0}
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
value = get_positive_int_field(nn_config, "baz", "Baz")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
49
python/nn4k/tests/utils/test_invoker_checking.py
Normal file
49
python/nn4k/tests/utils/test_invoker_checking.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
# Copyright 2023 Ant Group CO., Ltd.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvokerChecking(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
module nn4k.utils.invoker_checking unittest
|
||||||
|
"""
|
||||||
|
|
||||||
|
def testIsOpenAIInvoker(self):
|
||||||
|
from nn4k.utils.invoker_checking import is_openai_invoker
|
||||||
|
|
||||||
|
self.assertTrue(is_openai_invoker({"nn_name": "gpt-3.5-turbo"}))
|
||||||
|
self.assertTrue(is_openai_invoker({"nn_name": "gpt-4"}))
|
||||||
|
self.assertFalse(is_openai_invoker({"nn_name": "dummy"}))
|
||||||
|
|
||||||
|
self.assertTrue(is_openai_invoker({"openai_api_key": "EMPTY"}))
|
||||||
|
self.assertTrue(
|
||||||
|
is_openai_invoker({"openai_api_base": "http://localhost:38000/v1"})
|
||||||
|
)
|
||||||
|
self.assertTrue(is_openai_invoker({"openai_max_tokens": 1000}))
|
||||||
|
self.assertFalse(is_openai_invoker({"foo": "bar"}))
|
||||||
|
|
||||||
|
def testIsLocalInvoker(self):
|
||||||
|
import os
|
||||||
|
from nn4k.utils.invoker_checking import is_local_invoker
|
||||||
|
|
||||||
|
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
self.assertFalse(is_local_invoker({"nn_name": dir_path}))
|
||||||
|
|
||||||
|
model_dir_path = os.path.join(dir_path, "test_model_dir")
|
||||||
|
self.assertTrue(is_local_invoker({"nn_name": model_dir_path}))
|
||||||
|
|
||||||
|
self.assertFalse(is_local_invoker({"nn_name": "/not_exists"}))
|
||||||
|
self.assertFalse(is_local_invoker({"foo": "bar"}))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
1
python/nn4k/tests/utils/test_model_dir/config.json
Normal file
1
python/nn4k/tests/utils/test_model_dir/config.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{}
|
||||||
Loading…
x
Reference in New Issue
Block a user