feat(nn4k): implement openai invoker and local hf executor (#57)

Co-authored-by: 基尔 <qy266141@antgroup.com>
Co-authored-by: didicout <julin.jl@antgroup.com>
This commit is contained in:
xionghuaidong 2024-01-06 12:12:12 +08:00 committed by GitHub
parent 22ea3ee395
commit 6c3f8584ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 1832 additions and 1 deletions

View File

@ -42,4 +42,4 @@ header:
# If you don't want to check dependencies' license compatibility, remove the following part # If you don't want to check dependencies' license compatibility, remove the following part
dependency: dependency:
files: files:
- pom.xml - pom.xml

3
python/nn4k/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
/*.whl
/*.egg-info/
/build/

10
python/nn4k/LICENSE Normal file
View File

@ -0,0 +1,10 @@
Copyright 2023 Ant Group CO., Ltd.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied.

2
python/nn4k/MANIFEST.in Normal file
View File

@ -0,0 +1,2 @@
recursive-include nn4k *
recursive-exclude nn4k/examples *

1
python/nn4k/NN4K_VERSION Normal file
View File

@ -0,0 +1 @@
0.0.2-beta1

View File

@ -0,0 +1,14 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
__package_name__ = "openspg-nn4k"
__version__ = "0.0.2-beta1"

View File

@ -0,0 +1,43 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
NN_NAME_KEY = "nn_name"
NN_NAME_TEXT = "NN model name"
NN_VERSION_KEY = "nn_version"
NN_VERSION_TEXT = "NN model version"
NN_VERSION_DEFAULT = "default"
NN_INVOKER_KEY = "nn_invoker"
NN_INVOKER_TEXT = "NN invoker"
NN_EXECUTOR_KEY = "nn_executor"
NN_EXECUTOR_TEXT = "NN executor"
NN_DEVICE_KEY = "device"
NN_TRUST_REMOTE_CODE_KEY = "trust_remote_code"
NN_OPENAI_MODEL_NAME_KEY = NN_NAME_KEY
NN_OPENAI_MODEL_NAME_TEXT = "openai model name"
NN_OPENAI_API_KEY_KEY = "openai_api_key"
NN_OPENAI_API_KEY_TEXT = "openai api key"
NN_OPENAI_API_BASE_KEY = "openai_api_base"
NN_OPENAI_API_BASE_TEXT = "openai api base"
NN_OPENAI_MAX_TOKENS_KEY = "openai_max_tokens"
NN_OPENAI_MAX_TOKENS_TEXT = "openai max tokens"
NN_OPENAI_GPT4_PREFIX = "gpt-4"
NN_OPENAI_GPT35_PREFIX = "gpt-3.5"
NN_LOCAL_HF_MODEL_CONFIG_FILE = "config.json"

View File

@ -0,0 +1,12 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from nn4k.executor.base import NNExecutor, LLMExecutor

View File

@ -0,0 +1,161 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from abc import ABC, abstractmethod
from typing import Union
class NNExecutor(ABC):
"""
Entry point of model execution in a certain pod.
"""
def __init__(self, init_args: dict, **kwargs):
self._init_args = init_args
self._kwargs = kwargs
self._model = None
self._tokenizer = None
@property
def init_args(self):
"""
Return the `init_args` passed to the executor constructor.
"""
return self._init_args
@property
def kwargs(self):
"""
Return the `kwargs` passed to the executor constructor.
"""
return self._kwargs
@property
def model(self):
"""
Return the model object managed by this executor.
:raises RuntimeError: if the model is not initialized yet
"""
if self._model is None:
message = "model is not initialized yet"
raise RuntimeError(message)
return self._model
@property
def tokenizer(self):
"""
Return the tokenizer object managed by this executor.
:raises RuntimeError: if the tokenizer is not initialized yet
"""
if self._tokenizer is None:
message = "tokenizer is not initialized yet"
raise RuntimeError(message)
return self._tokenizer
def execute_inference(self, args=None, **kwargs):
"""
The entry point of batch inference in a certain pod.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support batch inference."
)
def inference(self, data, args=None, **kwargs):
"""
The entry point of inference. Usually for local invokers or model services.
"""
raise NotImplementedError()
@abstractmethod
def load_model(self, args=None, mode=None, **kwargs):
"""
Implement model loading logic in derived executor classes.
Implementer should initialize `self._model` and `self._tokenizer`.
This method will be called by several entry methods in executors and invokers.
"""
raise NotImplementedError()
def warmup_inference(self, args=None, **kwargs):
"""
Implement model warming up logic for inference in derived executor classes.
"""
pass
@classmethod
@abstractmethod
def from_config(cls, nn_config: Union[str, dict]) -> "NNExecutor":
"""
Create an NN executor instance from `nn_config`.
This method is abstract, derived class must override it by either
creating executor instances or implementating dispatch logic.
:param nn_config: config to use, can be dictionary or path to a JSON file
:type nn_config: str or dict
:rtype: NNExecutor
"""
from nn4k.nnhub import NNHub
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
from nn4k.utils.config_parsing import preprocess_config
from nn4k.utils.config_parsing import get_string_field
from nn4k.utils.class_importing import dynamic_import_class
nn_config = preprocess_config(nn_config)
nn_executor = nn_config.get(NN_EXECUTOR_KEY)
if nn_executor is not None:
nn_executor = get_string_field(nn_config, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT)
executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
if not issubclass(executor_class, NNExecutor):
message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
raise RuntimeError(message)
executor = executor_class.from_config(nn_config)
return executor
nn_name = nn_config.get(NN_NAME_KEY)
if nn_name is not None:
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = nn_config.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
if nn_name is not None:
hub = NNHub.get_instance()
executor = hub.get_model_executor(nn_name, nn_version)
if executor is not None:
return executor
message = "can not create executor for NN config"
if nn_name is not None:
message += "; model: %r" % nn_name
if nn_version is not None:
message += ", version: %r" % nn_version
raise RuntimeError(message)
class LLMExecutor(NNExecutor):
def execute_sft(self, args=None, callbacks=None, **kwargs):
"""
The entry point of SFT execution in a certain pod.
"""
raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")
def execute_rl_tuning(self, args=None, callbacks=None, **kwargs):
"""
The entry point of SFT execution in a certain pod.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support RL-Tuning."
)

View File

@ -0,0 +1,17 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from nn4k.executor import NNExecutor
class DeepKeExecutor(NNExecutor):
pass

View File

@ -0,0 +1,104 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Union
from nn4k.executor import LLMExecutor
class HfLLMExecutor(LLMExecutor):
@classmethod
def from_config(cls, nn_config: dict) -> "HfLLMExecutor":
"""
Create an HfLLMExecutor instance from `nn_config`.
"""
executor = cls(nn_config)
return executor
def execute_sft(self, args=None, callbacks=None, **kwargs):
raise NotImplementedError(
f"{self.__class__.__name__} will support SFT in the next version."
)
def load_model(self, args=None, **kwargs):
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.consts import NN_DEVICE_KEY, NN_TRUST_REMOTE_CODE_KEY
from nn4k.utils.config_parsing import get_string_field
nn_config: dict = args or self.init_args
if self._model is None:
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = nn_config.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(
nn_config, NN_VERSION_KEY, NN_VERSION_TEXT
)
model_path = nn_name
revision = nn_version
use_fast_tokenizer = False
device = nn_config.get(NN_DEVICE_KEY)
trust_remote_code = nn_config.get(NN_TRUST_REMOTE_CODE_KEY, False)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(
model_path,
use_fast=use_fast_tokenizer,
revision=revision,
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
revision=revision,
trust_remote_code=trust_remote_code,
)
model.to(device)
self._tokenizer = tokenizer
self._model = model
def inference(
self,
data,
max_input_length: int = 1024,
max_output_length: int = 1024,
do_sample: bool = False,
**kwargs,
):
model = self.model
tokenizer = self.tokenizer
input_ids = tokenizer(
data,
padding=True,
return_token_type_ids=False,
return_tensors="pt",
truncation=True,
max_length=max_input_length,
).to(model.device)
output_ids = model.generate(
**input_ids,
max_new_tokens=max_output_length,
do_sample=do_sample,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
**kwargs,
)
outputs = [
tokenizer.decode(
output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
)
for idx, output_id in enumerate(output_ids)
]
return outputs

View File

@ -0,0 +1,12 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from nn4k.invoker.base import NNInvoker, LLMInvoker

View File

@ -0,0 +1,186 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from abc import ABC, abstractmethod
from enum import Enum
from typing import Union
from nn4k.executor import LLMExecutor
class SubmitMode(Enum):
K8s = "k8s"
Docker = "docker"
class NNInvoker(ABC):
"""
Invoking Entry Interfaces for NN Models.
One NNInvoker object is for one NN Model.
- Interfaces starting with "submit_" means submitting a batch task to a remote execution engine.
- Interfaces starting with "remote_" means querying a remote service for some results.
- Interfaces starting with "local_" means running something locally.
Must call `warmup_local_model` before calling any local_xxx interface.
"""
def __init__(self, init_args: dict, **kwargs):
self._init_args = init_args
self._kwargs = kwargs
@property
def init_args(self):
"""
Return the `init_args` passed to the invoker constructor.
"""
return self._init_args
@property
def kwargs(self):
"""
Return the `kwargs` passed to the invoker constructor.
"""
return self._kwargs
@classmethod
@abstractmethod
def from_config(cls, nn_config: Union[str, dict]) -> "NNInvoker":
"""
Create an NN invoker instance from `nn_config`.
This method is abstract, derived class must override it by either
creating invoker instances or implementating dispatch logic.
:param nn_config: config to use, can be dictionary or path to a JSON file
:type nn_config: str or dict
:rtype: NNInvoker
:raises RuntimeError: if the NN config is not recognized
"""
from nn4k.nnhub import NNHub
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.consts import NN_INVOKER_KEY, NN_INVOKER_TEXT
from nn4k.utils.config_parsing import preprocess_config
from nn4k.utils.config_parsing import get_string_field
from nn4k.utils.class_importing import dynamic_import_class
nn_config = preprocess_config(nn_config)
nn_invoker = nn_config.get(NN_INVOKER_KEY)
if nn_invoker is not None:
nn_invoker = get_string_field(nn_config, NN_INVOKER_KEY, NN_INVOKER_TEXT)
invoker_class = dynamic_import_class(nn_invoker, NN_INVOKER_TEXT)
if not issubclass(invoker_class, NNInvoker):
message = "%r is not an %s class" % (nn_invoker, NN_INVOKER_TEXT)
raise RuntimeError(message)
invoker = invoker_class.from_config(nn_config)
return invoker
hub = NNHub.get_instance()
invoker = hub.get_invoker(nn_config)
if invoker is not None:
return invoker
nn_name = nn_config.get(NN_NAME_KEY)
if nn_name is not None:
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = nn_config.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
message = "can not create invoker for NN config"
if nn_name is not None:
message += "; model: %r" % nn_name
if nn_version is not None:
message += ", version: %r" % nn_version
raise RuntimeError(message)
def submit_inference(self, submit_mode: SubmitMode = SubmitMode.K8s):
"""
Submit remote batch inference execution.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support batch inference."
)
def remote_inference(self, input, **kwargs):
"""
Inference via existing remote services.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support remote inference."
)
def local_inference(self, data, **kwargs):
"""
Implement local inference in derived invoker classes.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support local inference."
)
def warmup_local_model(self):
"""
Implement local model warming up logic in derived invoker classes.
"""
pass
class LLMInvoker(NNInvoker):
def submit_sft(self, submit_mode: SubmitMode = SubmitMode.K8s):
"""
Submit remote SFT execution.
"""
raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")
def submit_rl_tuning(self, submit_mode: SubmitMode = SubmitMode.K8s):
"""
Submit remote RL-Tuning execution.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support RL-Tuning."
)
def local_inference(self, data, **kwargs):
"""
Implement local inference for local invoker.
"""
return self._nn_executor.inference(data, **kwargs)
def warmup_local_model(self):
"""
Implement local model warming up logic for local invoker.
"""
from nn4k.nnhub import NNHub
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.utils.config_parsing import get_string_field
nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = self.init_args.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(
self.init_args, NN_VERSION_KEY, NN_VERSION_TEXT
)
hub = NNHub.get_instance()
executor = hub.get_model_executor(nn_name, nn_version)
if executor is None:
message = "model %r version %r " % (nn_name, nn_version)
message += "is not found in the model hub"
raise RuntimeError(message)
self._nn_executor: LLMExecutor = executor
self._nn_executor.load_model()
self._nn_executor.warmup_inference()
@classmethod
def from_config(cls, nn_config: dict) -> "LLMInvoker":
"""
Create an LLMInvoker instance from `nn_config`.
"""
invoker = cls(nn_config)
return invoker

View File

@ -0,0 +1,75 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Optional
from nn4k.invoker import NNInvoker
class OpenAIInvoker(NNInvoker):
def __init__(self, nn_config: dict):
super().__init__(nn_config)
import openai
from nn4k.consts import NN_OPENAI_MODEL_NAME_KEY, NN_OPENAI_MODEL_NAME_TEXT
from nn4k.consts import NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_KEY_TEXT
from nn4k.consts import NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
from nn4k.utils.config_parsing import get_string_field
from nn4k.utils.config_parsing import get_positive_int_field
self.openai_model_name = get_string_field(
self.init_args, NN_OPENAI_MODEL_NAME_KEY, NN_OPENAI_MODEL_NAME_TEXT
)
self.openai_api_key = get_string_field(
self.init_args, NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_KEY_TEXT
)
self.openai_api_base = get_string_field(
self.init_args, NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
)
self.openai_max_tokens = get_positive_int_field(
self.init_args, NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
)
openai.api_key = self.openai_api_key
openai.api_base = self.openai_api_base
@classmethod
def from_config(cls, nn_config: dict) -> "OpenAIInvoker":
invoker = cls(nn_config)
return invoker
def _create_prompt(self, input, **kwargs):
if isinstance(input, list):
prompt = input
else:
prompt = [input]
return prompt
def _create_output(self, input, prompt, completion, **kwargs):
output = [choice.text for choice in completion.choices]
return output
def remote_inference(
self, input, max_output_length: Optional[int] = None, **kwargs
):
import openai
if max_output_length is None:
max_output_length = self.openai_max_tokens
prompt = self._create_prompt(input, **kwargs)
completion = openai.Completion.create(
model=self.openai_model_name,
prompt=prompt,
max_tokens=max_output_length,
)
output = self._create_output(input, prompt, completion, **kwargs)
return output

View File

@ -0,0 +1,165 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from abc import ABC, abstractmethod
from typing import Optional, Union, Tuple, Type
from nn4k.executor import NNExecutor
class NNHub(ABC):
_hub_instance = None
@staticmethod
def get_instance() -> "NNHub":
"""
Get the NNHub instance. If the instance is not initialized, create a stub `SimpleNNHub`.
"""
if NNHub._hub_instance is None:
NNHub._hub_instance = SimpleNNHub()
return NNHub._hub_instance
@abstractmethod
def publish(
self,
model_executor: Union[NNExecutor, Tuple[Type[NNExecutor], tuple, dict, tuple]],
name: str,
version: str = None,
) -> str:
"""
Publish a model(executor) to hub.
:param model_executor: An NNExecutor object, which is pickleable.
Or a tuple of (class, init_args, kwargs, weight_ids) for creating an NNExecutor,
while all these 4 augments are pickleable.
:param str name: The name of a model, like `llama2`. We do not have a `namespace`.
Use a joined name like `alibaba/qwen` to support such features.
:param str version: Optional. Auto generate a version if this param is not given.
:return: The published model version.
:rtype: str
"""
pass
@abstractmethod
def get_model_executor(
self, name: str, version: str = None
) -> Optional[NNExecutor]:
"""
Get an NNExecutor instance from Hub.
:param str name: The name of a model.
:param str version: The version of a model. Get default version of a model if this param is not given.
:return: The ModelExecutor Instance. None for NotFound.
:rtype: Optional[NNExecutor]
"""
pass
@abstractmethod
def get_invoker(self, nn_config: dict) -> Optional["NNInvoker"]:
"""
Get an NNExecutor instance from Hub.
:param dict nn_config: The config dictionary.
:return: The NNExecutor Instance. None for NotFound.
:rtype: Optional[NNInvoker]
"""
pass
def start_service(self, name: str, version: str, service_id: str = None, **kwargs):
raise NotImplementedError("This Hub does not support starting model service.")
def stop_service(self, name: str, version: str, service_id: str = None, **kwargs):
raise NotImplementedError("This Hub does not support stopping model service.")
def get_service(self, name: str, version: str, service_id: str = None):
raise NotImplementedError("This Hub does not support model services.")
class SimpleNNHub(NNHub):
def __init__(self) -> None:
super().__init__()
self._model_executors = {}
def _add_executor(
self,
executor: Union[NNExecutor, Tuple[Type[NNExecutor], tuple, dict, tuple]],
name: str,
version: str = None,
):
from nn4k.consts import NN_VERSION_DEFAULT
if version is None:
version = NN_VERSION_DEFAULT
if self._model_executors.get(name) is None:
self._model_executors[name] = {version: executor}
else:
self._model_executors[name][version] = executor
def publish(
self, model_executor: NNExecutor, name: str, version: str = None
) -> str:
from nn4k.consts import NN_VERSION_DEFAULT
print(
"WARNING: You are using SimpleNNHub which can only maintain models in memory without data persistence!"
)
if version is None:
version = NN_VERSION_DEFAULT
self._add_executor(model_executor, name, version)
return version
def _create_model_executor(self, cls, init_args, kwargs, weights):
raise NotImplementedError()
def get_model_executor(
self, name: str, version: str = None
) -> Optional[NNExecutor]:
if self._model_executors.get(name) is None:
return None
executor = self._model_executors.get(name).get(version)
if isinstance(executor, NNExecutor):
return executor
cls, init_args, kwargs, weights = executor
executor = self._create_model_executor(cls, init_args, kwargs, weights)
return executor
def _add_local_executor(self, nn_config):
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.executor.hugging_face import HfLLMExecutor
from nn4k.utils.config_parsing import get_string_field
executor = HfLLMExecutor.from_config(nn_config)
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = nn_config.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
self.publish(executor, nn_name, nn_version)
def get_invoker(self, nn_config: dict) -> Optional["NNInvoker"]:
from nn4k.invoker import LLMInvoker
from nn4k.invoker.openai import OpenAIInvoker
from nn4k.utils.invoker_checking import is_openai_invoker
from nn4k.utils.invoker_checking import is_local_invoker
if is_openai_invoker(nn_config):
invoker = OpenAIInvoker.from_config(nn_config)
return invoker
if is_local_invoker(nn_config):
invoker = LLMInvoker.from_config(nn_config)
self._add_local_executor(nn_config)
return invoker
return None

View File

@ -0,0 +1,10 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,57 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import importlib
from typing import Tuple
def split_module_class_name(name: str, text: str) -> Tuple[str, str]:
"""
Split `name` as module name and class name pair.
:param name: fully qualified class name, e.g. ``foo.bar.MyClass``
:type name: str
:param text: describe the kind of the class, used in the exception message
:type text: str
:rtype: Tuple[str, str]
:raises RuntimeError: if `name` is not a fully qualified class name
"""
i = name.rfind(".")
if i == -1:
message = "invalid %s class name: %s" % (text, name)
raise RuntimeError(message)
module_name = name[:i]
class_name = name[i + 1 :]
return module_name, class_name
def dynamic_import_class(name: str, text: str):
"""
Import the class specified by `name` dyanmically.
:param name: fully qualified class name, e.g. ``foo.bar.MyClass``
:type name: str
:param text: describe the kind of the class, use in the exception message
:type text: str
:raises RuntimeError: if `name` is not a fully qualified class name, or
the class is not in the module specified by `name`
:raises ModuleNotFoundError: the module specified by `name` is not found
"""
module_name, class_name = split_module_class_name(name, text)
module = importlib.import_module(module_name)
class_ = getattr(module, class_name, None)
if class_ is None:
message = "class %r not found in module %r" % (class_name, module_name)
raise RuntimeError(message)
if not isinstance(class_, type):
message = "%r is not a class" % (name,)
raise RuntimeError(message)
return class_

View File

@ -0,0 +1,118 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from typing import Any
from typing import Union
def preprocess_config(nn_config: Union[str, dict]) -> dict:
"""
Preprocess config `nn_config` into a dictionary.
* If `nn_config` is already a dictionary, return it as is.
* If `nn_config` is a string, decode it as a JSON file.
:param nn_config: config to be preprocessed
:type nn_config: str or dict
:return: `nn_config` or `nn_config` decoded as JSON
:rtype: dict
:raises ValueError: if cannot decode config file specified by
`nn_config` as JSON
"""
try:
if isinstance(nn_config, str):
with open(nn_config, "r") as f:
nn_config = json.load(f)
except:
raise ValueError("cannot decode config file")
return nn_config
def get_field(nn_config: dict, name: str, text: str) -> Any:
"""
Get the value of the field specified by `name` from the configuration
dictionary `nn_config`.
:param str name: name of the field
:param str name: descriptive text of the name of the field
:return: value of the field
:rtype: Any
:raises ValueError: if the field is not specified in `nn_config`
"""
value = nn_config.get(name)
if value is None:
message = "%s %r not found" % (text, name)
raise ValueError(message)
return value
def get_string_field(nn_config: dict, name: str, text: str) -> str:
"""
Get the value of the string field specified by `name` from the
configuration dictionary `nn_config`.
:param str name: name of the field
:param str name: descriptive text of the name of the field
:return: value of the field
:rtype: str
:raises ValueError: if the field is not specified in `nn_config`
:raises TypeError: if the value of the field is not a string
"""
value = get_field(nn_config, name, text)
if not isinstance(value, str):
message = "%s %r must be string; " % (text, name)
message += "%r is invalid" % (value,)
raise TypeError(message)
return value
def get_int_field(nn_config: dict, name: str, text: str) -> int:
"""
Get the value of the integer field specified by `name` from the
configuration dictionary `nn_config`.
:param str name: name of the field
:param str name: descriptive text of the name of the field
:return: value of the field
:rtype: int
:raises ValueError: if the field is not specified in `nn_config`
:raises TypeError: if the value of the field is not an integer
"""
value = get_field(nn_config, name, text)
if not isinstance(value, int):
message = "%s %r must be integer; " % (text, name)
message += "%r is invalid" % (value,)
raise TypeError(message)
return value
def get_positive_int_field(nn_config: dict, name: str, text: str) -> int:
"""
Get the value of the positive integer field specified by `name`
from the configuration dictionary `nn_config`.
:param str name: name of the field
:param str name: descriptive text of the name of the field
:return: value of the field
:rtype: int
:raises ValueError: if the field is not specified in `nn_config`, or the
value of the field is not a positive integer
:raises TypeError: if the value of the field is not an integer
"""
value = get_int_field(nn_config, name, text)
if value <= 0:
message = "%s %r must be positive integer; " % (text, name)
message += "%r is invalid" % (value,)
raise ValueError(message)
return value

View File

@ -0,0 +1,62 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
def is_openai_invoker(nn_config: dict) -> bool:
"""
Check whether `nn_config` specifies OpenAI invoker.
:type nn_config: dict
:rtype: bool
"""
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_OPENAI_API_KEY_KEY
from nn4k.consts import NN_OPENAI_API_BASE_KEY
from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY
from nn4k.consts import NN_OPENAI_GPT4_PREFIX
from nn4k.consts import NN_OPENAI_GPT35_PREFIX
from nn4k.utils.config_parsing import get_string_field
nn_name = nn_config.get(NN_NAME_KEY)
if nn_name is not None:
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
if nn_name.startswith(NN_OPENAI_GPT4_PREFIX) or nn_name.startswith(
NN_OPENAI_GPT35_PREFIX
):
return True
keys = (NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_BASE_KEY, NN_OPENAI_MAX_TOKENS_KEY)
for key in keys:
if key in nn_config:
return True
return False
def is_local_invoker(nn_config: dict) -> bool:
"""
Check whether `nn_config` specifies local invoker.
:type nn_config: dict
:rtype: bool
"""
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_LOCAL_HF_MODEL_CONFIG_FILE
from nn4k.utils.config_parsing import get_string_field
nn_name = nn_config.get(NN_NAME_KEY)
if nn_name is not None:
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
if os.path.isdir(nn_name):
file_path = os.path.join(nn_name, NN_LOCAL_HF_MODEL_CONFIG_FILE)
if os.path.isfile(file_path):
return True
return False

View File

@ -0,0 +1 @@
openai<1

78
python/nn4k/setup.py Normal file
View File

@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from setuptools import setup, find_packages
package_name = "openspg-nn4k"
# version
cwd = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(cwd, "NN4K_VERSION"), "r") as rf:
version = rf.readline().strip("\n").strip()
# license
license = ""
with open(os.path.join(cwd, "LICENSE"), "r") as rf:
line = rf.readline()
while line:
line = line.strip()
if line:
license += "# " + line + "\n"
else:
license += "#\n"
line = rf.readline()
# Generate nn4k.__init__.py
with open(os.path.join(cwd, "nn4k/__init__.py"), "w") as wf:
content = f"""{license}
__package_name__ = "{package_name}"
__version__ = "{version}"
"""
wf.write(content)
setup(
name=package_name,
version=version,
description="nn4k",
url="https://github.com/OpenSPG/openspg",
packages=find_packages(
where=".",
exclude=[
".*test.py",
"*_test.py",
"*_debug.py",
"*.txt",
"tests",
"tests.*",
"configs",
"configs.*",
"test",
"test.*",
"*.tests",
"*.tests.*",
"*.pyc",
],
),
python_requires=">=3.8",
install_requires=[
r.strip()
for r in open("requirements.txt", "r")
if not r.strip().startswith("#")
],
include_package_data=True,
package_data={
"bin": ["*"],
},
)

View File

@ -0,0 +1,10 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,10 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,69 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import sys
import unittest
class TestBaseExecutor(unittest.TestCase):
"""
NNExecutor and LLMExecutor unittest
"""
def setUp(self):
# for importing test_stub.py
dir_path = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, dir_path)
from nn4k.nnhub import NNHub
from test_stub import StubHub
NNHub._hub_instance = StubHub()
def tearDown(self):
from nn4k.nnhub import NNHub
sys.path.pop(0)
NNHub._hub_instance = None
def testCustomNNExecutor(self):
from nn4k.executor import NNExecutor
from test_stub import StubExecutor
nn_config = {"nn_executor": "test_stub.StubExecutor"}
executor = NNExecutor.from_config(nn_config)
self.assertTrue(isinstance(executor, StubExecutor))
self.assertEqual(executor.init_args, nn_config)
self.assertEqual(executor.kwargs, {})
with self.assertRaises(RuntimeError):
executor = NNExecutor.from_config({"nn_executor": "test_stub.NotExecutor"})
def testHubExecutor(self):
from nn4k.executor import NNExecutor
from test_stub import StubExecutor
nn_config = {"nn_name": "test_stub", "nn_version": "default"}
executor = NNExecutor.from_config(nn_config)
self.assertTrue(isinstance(executor, StubExecutor))
self.assertEqual(executor.init_args, nn_config)
self.assertEqual(executor.kwargs, {"test_stub_executor": True})
def testExecutorNotExists(self):
from nn4k.executor import NNExecutor
with self.assertRaises(RuntimeError):
executor = NNExecutor.from_config({"nn_name": "not_exists"})
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,57 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import sys
import unittest
import unittest.mock
from nn4k.executor.hugging_face import HfLLMExecutor
class TestHfLLMExecutor(unittest.TestCase):
"""
HfLLMExecutor unittest
"""
def setUp(self):
self._saved_torch = sys.modules.get("torch")
self._mocked_torch = unittest.mock.MagicMock()
sys.modules["torch"] = self._mocked_torch
self._saved_transformers = sys.modules.get("transformers")
self._mocked_transformers = unittest.mock.MagicMock()
sys.modules["transformers"] = self._mocked_transformers
def tearDown(self):
del sys.modules["torch"]
if self._saved_torch is not None:
sys.modules["torch"] = self._saved_torch
del sys.modules["transformers"]
if self._saved_transformers is not None:
sys.modules["transformers"] = self._saved_transformers
def testHfLLMExecutor(self):
nn_config = {
"nn_name": "/opt/test_model_dir",
"nn_version": "default",
}
executor = HfLLMExecutor.from_config(nn_config)
executor.load_model()
executor.inference("input")
self._mocked_transformers.AutoTokenizer.from_pretrained.assert_called()
self._mocked_transformers.AutoModelForCausalLM.from_pretrained.assert_called()
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,52 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Optional
from nn4k.executor import NNExecutor, LLMExecutor
from nn4k.nnhub import SimpleNNHub
class StubExecutor(LLMExecutor):
def load_model(self, args=None, mode=None, **kwargs):
pass
def warmup_inference(self, args=None, **kwargs):
pass
def inference(self, data, args=None, **kwargs):
pass
@classmethod
def from_config(cls, nn_config: dict) -> "StubExecutor":
"""
Create a StubExecutor instance from `nn_config`.
"""
executor = cls(nn_config)
return executor
class NotExecutor:
pass
class StubHub(SimpleNNHub):
def get_model_executor(
self, name: str, version: str = None
) -> Optional[NNExecutor]:
if name == "test_stub":
if version is None:
version = "default"
executor = StubExecutor(
{"nn_name": name, "nn_version": version}, test_stub_executor=True
)
return executor
return super().get_model_executor(name, version)

View File

@ -0,0 +1,10 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,84 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import sys
import unittest
class TestBaseInvoker(unittest.TestCase):
"""
NNInvoker and LLMInvoker unittest
"""
def setUp(self):
# for importing test_stub.py
dir_path = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, dir_path)
from nn4k.nnhub import NNHub
from test_stub import StubHub
NNHub._hub_instance = StubHub()
def tearDown(self):
from nn4k.nnhub import NNHub
sys.path.pop(0)
NNHub._hub_instance = None
def testCustomNNInvoker(self):
from nn4k.invoker import NNInvoker
from test_stub import StubInvoker
nn_config = {"nn_invoker": "test_stub.StubInvoker"}
invoker = NNInvoker.from_config(nn_config)
self.assertTrue(isinstance(invoker, StubInvoker))
self.assertEqual(invoker.init_args, nn_config)
self.assertEqual(invoker.kwargs, {})
with self.assertRaises(RuntimeError):
invoker = NNInvoker.from_config({"nn_invoker": "test_stub.NotInvoker"})
def testHubInvoker(self):
from nn4k.invoker import NNInvoker
from test_stub import StubInvoker
nn_config = {"nn_name": "test_stub"}
invoker = NNInvoker.from_config(nn_config)
self.assertTrue(isinstance(invoker, StubInvoker))
self.assertEqual(invoker.init_args, nn_config)
self.assertEqual(invoker.kwargs, {"test_stub_invoker": True})
def testInvokerNotExists(self):
from nn4k.invoker import NNInvoker
with self.assertRaises(RuntimeError):
invoker = NNInvoker.from_config({"nn_name": "not_exists"})
def testLocalInvoker(self):
from nn4k.invoker import NNInvoker
from test_stub import StubInvoker
nn_config = {"nn_name": "test_stub"}
invoker = NNInvoker.from_config(nn_config)
self.assertTrue(isinstance(invoker, StubInvoker))
self.assertEqual(invoker.init_args, nn_config)
self.assertEqual(invoker.kwargs, {"test_stub_invoker": True})
invoker.warmup_local_model()
invoker._nn_executor.inference_result = "inference result"
result = invoker.local_inference("input")
self.assertEqual(result, invoker._nn_executor.inference_result)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,70 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import sys
import unittest
import unittest.mock
from dataclasses import dataclass
from nn4k.invoker import NNInvoker
@dataclass
class MockCompletion:
choices: list
@dataclass
class MockChoice:
text: str
class TestOpenAIInvoker(unittest.TestCase):
"""
OpenAIInvoker unittest
"""
def setUp(self):
self._saved_openai = sys.modules.get("openai")
self._mocked_openai = unittest.mock.MagicMock()
sys.modules["openai"] = self._mocked_openai
def tearDown(self):
del sys.modules["openai"]
if self._saved_openai is not None:
sys.modules["openai"] = self._saved_openai
def testOpenAIInvoker(self):
nn_config = {
"nn_name": "gpt-3.5-turbo",
"openai_api_key": "EMPTY",
"openai_api_base": "http://localhost:38080/v1",
"openai_max_tokens": 2000,
}
invoker = NNInvoker.from_config(nn_config)
self.assertEqual(invoker.init_args, nn_config)
self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])
mock_completion = MockCompletion(choices=[MockChoice("a dog named Bolt ...")])
self._mocked_openai.Completion.create.return_value = mock_completion
result = invoker.remote_inference("Long long ago, ")
self._mocked_openai.Completion.create.assert_called_with(
prompt=["Long long ago, "],
model=nn_config["nn_name"],
max_tokens=nn_config["openai_max_tokens"],
)
self.assertEqual(result, [mock_completion.choices[0].text])
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,70 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from typing import Optional
from nn4k.invoker import NNInvoker, LLMInvoker
from nn4k.executor import NNExecutor
from nn4k.nnhub import SimpleNNHub
class StubInvoker(LLMInvoker):
@classmethod
def from_config(cls, nn_config: dict) -> "StubInvoker":
"""
Create a StubInvoker instance from `nn_config`.
"""
invoker = cls(nn_config)
return invoker
class NotInvoker:
pass
class StubExecutor(NNExecutor):
def load_model(self, args=None, mode=None, **kwargs):
self.load_model_called = True
def warmup_inference(self, args=None, **kwargs):
self.warmup_inference_called = True
def inference(self, data, args=None, **kwargs):
return self.inference_result
@classmethod
def from_config(cls, nn_config: dict) -> "StubExecutor":
"""
Create a StubExecutor instance from `nn_config`.
"""
executor = cls(nn_config)
return executor
class StubHub(SimpleNNHub):
def get_invoker(self, nn_config: dict) -> Optional[NNInvoker]:
nn_name = nn_config.get("nn_name")
if nn_name is not None and nn_name == "test_stub":
invoker = StubInvoker(nn_config, test_stub_invoker=True)
return invoker
return super().get_invoker(nn_config)
def get_model_executor(
self, name: str, version: str = None
) -> Optional[NNExecutor]:
if name == "test_stub":
if version is None:
version = "default"
executor = StubExecutor(
{"nn_name": name, "nn_version": version}, test_stub_executor=True
)
return executor
return super().get_model_executor(name, version)

View File

@ -0,0 +1,10 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,34 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import unittest
class TestBaseNNHub(unittest.TestCase):
"""
NNHub and SimpleNNHub unittest
The interface and implementation of NNHub may be revised later,
then unittests will be added.
"""
def setUp(self):
pass
def tearDown(self):
pass
def testBaseNNHub(self):
pass
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,10 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

View File

@ -0,0 +1,58 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import unittest
class TestClassImporting(unittest.TestCase):
"""
module nn4k.utils.class_importing unittest
"""
def testSplitModuleClassName(self):
from nn4k.utils.class_importing import split_module_class_name
pair = split_module_class_name("foo.bar.Baz", "test")
self.assertEqual(pair, ("foo.bar", "Baz"))
def testSplitModuleClassNameInvalid(self):
from nn4k.utils.class_importing import split_module_class_name
with self.assertRaises(RuntimeError):
pair = split_module_class_name("foo", "test")
def testDynamicImportClass(self):
from nn4k.utils.class_importing import dynamic_import_class
class_ = dynamic_import_class("unittest.TestCase", "test")
self.assertEqual(class_, unittest.TestCase)
def testDynamicImportClassModuleNotFound(self):
from nn4k.utils.class_importing import dynamic_import_class
with self.assertRaises(ModuleNotFoundError):
class_ = dynamic_import_class("not_exists.ClassName", "test")
def testDynamicImportClassClassNotFound(self):
from nn4k.utils.class_importing import dynamic_import_class
with self.assertRaises(RuntimeError):
class_ = dynamic_import_class("unittest.NotExists", "test")
def testDynamicImportClassNotClass(self):
from nn4k.utils.class_importing import dynamic_import_class
with self.assertRaises(RuntimeError):
class_ = dynamic_import_class("unittest.mock", "test")
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,3 @@
{
"foo": "bar"
}

View File

@ -0,0 +1,103 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import unittest
class TestConfigParsing(unittest.TestCase):
"""
module nn4k.utils.config_parsing unittest
"""
def testPreprocessConfigFile(self):
import os
from nn4k.utils.config_parsing import preprocess_config
dir_path = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(dir_path, "test_config.json")
nn_config = preprocess_config(file_path)
self.assertEqual(nn_config, {"foo": "bar"})
def testPreprocessConfigFileNotExists(self):
import os
from nn4k.utils.config_parsing import preprocess_config
dir_path = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(dir_path, "not_exists.json")
with self.assertRaises(ValueError):
nn_config = preprocess_config(file_path)
def testPreprocessConfigDict(self):
from nn4k.utils.config_parsing import preprocess_config
conf = {"foo": "bar"}
nn_config = preprocess_config(conf)
self.assertEqual(nn_config, conf)
def testGetField(self):
from nn4k.utils.config_parsing import get_field
nn_config = {"foo": "bar"}
value = get_field(nn_config, "foo", "Foo")
self.assertEqual(value, "bar")
def testGetFieldNotExists(self):
from nn4k.utils.config_parsing import get_field
nn_config = {"foo": "bar"}
with self.assertRaises(ValueError):
value = get_field(nn_config, "not_exists", "not exists")
def testGetStringField(self):
from nn4k.utils.config_parsing import get_string_field
nn_config = {"foo": "bar"}
value = get_string_field(nn_config, "foo", "Foo")
self.assertEqual(value, "bar")
def testGetStringFieldNotString(self):
from nn4k.utils.config_parsing import get_string_field
nn_config = {"foo": "bar", "baz": True}
with self.assertRaises(TypeError):
value = get_string_field(nn_config, "baz", "Baz")
def testGetIntField(self):
from nn4k.utils.config_parsing import get_int_field
nn_config = {"foo": "bar", "baz": 1000}
value = get_int_field(nn_config, "baz", "Baz")
self.assertEqual(value, 1000)
def testGetIntFieldNotInteger(self):
from nn4k.utils.config_parsing import get_int_field
nn_config = {"foo": "bar", "baz": "quux"}
with self.assertRaises(TypeError):
value = get_int_field(nn_config, "baz", "Baz")
def testGetPositiveIntField(self):
from nn4k.utils.config_parsing import get_positive_int_field
nn_config = {"foo": "bar", "baz": 1000}
value = get_positive_int_field(nn_config, "baz", "Baz")
self.assertEqual(value, 1000)
def testGetPositiveIntFieldNotPositive(self):
from nn4k.utils.config_parsing import get_positive_int_field
nn_config = {"foo": "bar", "baz": 0}
with self.assertRaises(ValueError):
value = get_positive_int_field(nn_config, "baz", "Baz")
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,49 @@
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import unittest
class TestInvokerChecking(unittest.TestCase):
"""
module nn4k.utils.invoker_checking unittest
"""
def testIsOpenAIInvoker(self):
from nn4k.utils.invoker_checking import is_openai_invoker
self.assertTrue(is_openai_invoker({"nn_name": "gpt-3.5-turbo"}))
self.assertTrue(is_openai_invoker({"nn_name": "gpt-4"}))
self.assertFalse(is_openai_invoker({"nn_name": "dummy"}))
self.assertTrue(is_openai_invoker({"openai_api_key": "EMPTY"}))
self.assertTrue(
is_openai_invoker({"openai_api_base": "http://localhost:38000/v1"})
)
self.assertTrue(is_openai_invoker({"openai_max_tokens": 1000}))
self.assertFalse(is_openai_invoker({"foo": "bar"}))
def testIsLocalInvoker(self):
import os
from nn4k.utils.invoker_checking import is_local_invoker
dir_path = os.path.dirname(os.path.abspath(__file__))
self.assertFalse(is_local_invoker({"nn_name": dir_path}))
model_dir_path = os.path.join(dir_path, "test_model_dir")
self.assertTrue(is_local_invoker({"nn_name": model_dir_path}))
self.assertFalse(is_local_invoker({"nn_name": "/not_exists"}))
self.assertFalse(is_local_invoker({"foo": "bar"}))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1 @@
{}