mirror of
				https://github.com/OpenSPG/openspg.git
				synced 2025-11-03 19:45:23 +00:00 
			
		
		
		
	feat(nn4k): implement openai invoker and local hf executor (#57)
Co-authored-by: 基尔 <qy266141@antgroup.com> Co-authored-by: didicout <julin.jl@antgroup.com>
This commit is contained in:
		
							parent
							
								
									22ea3ee395
								
							
						
					
					
						commit
						6c3f8584ec
					
				@ -42,4 +42,4 @@ header:
 | 
			
		||||
# If you don't want to check dependencies' license compatibility, remove the following part
 | 
			
		||||
dependency:
 | 
			
		||||
  files:
 | 
			
		||||
    - pom.xml
 | 
			
		||||
    - pom.xml
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										3
									
								
								python/nn4k/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								python/nn4k/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,3 @@
 | 
			
		||||
/*.whl
 | 
			
		||||
/*.egg-info/
 | 
			
		||||
/build/
 | 
			
		||||
							
								
								
									
										10
									
								
								python/nn4k/LICENSE
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								python/nn4k/LICENSE
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,10 @@
 | 
			
		||||
Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
or implied.
 | 
			
		||||
							
								
								
									
										2
									
								
								python/nn4k/MANIFEST.in
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								python/nn4k/MANIFEST.in
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,2 @@
 | 
			
		||||
recursive-include nn4k *
 | 
			
		||||
recursive-exclude nn4k/examples *
 | 
			
		||||
							
								
								
									
										1
									
								
								python/nn4k/NN4K_VERSION
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								python/nn4k/NN4K_VERSION
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
0.0.2-beta1
 | 
			
		||||
							
								
								
									
										14
									
								
								python/nn4k/nn4k/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								python/nn4k/nn4k/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,14 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__package_name__ = "openspg-nn4k"
 | 
			
		||||
__version__ = "0.0.2-beta1"
 | 
			
		||||
							
								
								
									
										43
									
								
								python/nn4k/nn4k/consts/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								python/nn4k/nn4k/consts/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,43 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
NN_NAME_KEY = "nn_name"
 | 
			
		||||
NN_NAME_TEXT = "NN model name"
 | 
			
		||||
 | 
			
		||||
NN_VERSION_KEY = "nn_version"
 | 
			
		||||
NN_VERSION_TEXT = "NN model version"
 | 
			
		||||
NN_VERSION_DEFAULT = "default"
 | 
			
		||||
 | 
			
		||||
NN_INVOKER_KEY = "nn_invoker"
 | 
			
		||||
NN_INVOKER_TEXT = "NN invoker"
 | 
			
		||||
 | 
			
		||||
NN_EXECUTOR_KEY = "nn_executor"
 | 
			
		||||
NN_EXECUTOR_TEXT = "NN executor"
 | 
			
		||||
 | 
			
		||||
NN_DEVICE_KEY = "device"
 | 
			
		||||
NN_TRUST_REMOTE_CODE_KEY = "trust_remote_code"
 | 
			
		||||
 | 
			
		||||
NN_OPENAI_MODEL_NAME_KEY = NN_NAME_KEY
 | 
			
		||||
NN_OPENAI_MODEL_NAME_TEXT = "openai model name"
 | 
			
		||||
 | 
			
		||||
NN_OPENAI_API_KEY_KEY = "openai_api_key"
 | 
			
		||||
NN_OPENAI_API_KEY_TEXT = "openai api key"
 | 
			
		||||
 | 
			
		||||
NN_OPENAI_API_BASE_KEY = "openai_api_base"
 | 
			
		||||
NN_OPENAI_API_BASE_TEXT = "openai api base"
 | 
			
		||||
 | 
			
		||||
NN_OPENAI_MAX_TOKENS_KEY = "openai_max_tokens"
 | 
			
		||||
NN_OPENAI_MAX_TOKENS_TEXT = "openai max tokens"
 | 
			
		||||
 | 
			
		||||
NN_OPENAI_GPT4_PREFIX = "gpt-4"
 | 
			
		||||
NN_OPENAI_GPT35_PREFIX = "gpt-3.5"
 | 
			
		||||
 | 
			
		||||
NN_LOCAL_HF_MODEL_CONFIG_FILE = "config.json"
 | 
			
		||||
							
								
								
									
										12
									
								
								python/nn4k/nn4k/executor/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								python/nn4k/nn4k/executor/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,12 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
from nn4k.executor.base import NNExecutor, LLMExecutor
 | 
			
		||||
							
								
								
									
										161
									
								
								python/nn4k/nn4k/executor/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										161
									
								
								python/nn4k/nn4k/executor/base.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,161 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from typing import Union
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NNExecutor(ABC):
 | 
			
		||||
    """
 | 
			
		||||
    Entry point of model execution in a certain pod.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, init_args: dict, **kwargs):
 | 
			
		||||
        self._init_args = init_args
 | 
			
		||||
        self._kwargs = kwargs
 | 
			
		||||
        self._model = None
 | 
			
		||||
        self._tokenizer = None
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def init_args(self):
 | 
			
		||||
        """
 | 
			
		||||
        Return the `init_args` passed to the executor constructor.
 | 
			
		||||
        """
 | 
			
		||||
        return self._init_args
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def kwargs(self):
 | 
			
		||||
        """
 | 
			
		||||
        Return the `kwargs` passed to the executor constructor.
 | 
			
		||||
        """
 | 
			
		||||
        return self._kwargs
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def model(self):
 | 
			
		||||
        """
 | 
			
		||||
        Return the model object managed by this executor.
 | 
			
		||||
 | 
			
		||||
        :raises RuntimeError: if the model is not initialized yet
 | 
			
		||||
        """
 | 
			
		||||
        if self._model is None:
 | 
			
		||||
            message = "model is not initialized yet"
 | 
			
		||||
            raise RuntimeError(message)
 | 
			
		||||
        return self._model
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def tokenizer(self):
 | 
			
		||||
        """
 | 
			
		||||
        Return the tokenizer object managed by this executor.
 | 
			
		||||
 | 
			
		||||
        :raises RuntimeError: if the tokenizer is not initialized yet
 | 
			
		||||
        """
 | 
			
		||||
        if self._tokenizer is None:
 | 
			
		||||
            message = "tokenizer is not initialized yet"
 | 
			
		||||
            raise RuntimeError(message)
 | 
			
		||||
        return self._tokenizer
 | 
			
		||||
 | 
			
		||||
    def execute_inference(self, args=None, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        The entry point of batch inference in a certain pod.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError(
 | 
			
		||||
            f"{self.__class__.__name__} does not support batch inference."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def inference(self, data, args=None, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        The entry point of inference. Usually for local invokers or model services.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def load_model(self, args=None, mode=None, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        Implement model loading logic in derived executor classes.
 | 
			
		||||
 | 
			
		||||
        Implementer should initialize `self._model` and `self._tokenizer`.
 | 
			
		||||
 | 
			
		||||
        This method will be called by several entry methods in executors and invokers.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def warmup_inference(self, args=None, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        Implement model warming up logic for inference in derived executor classes.
 | 
			
		||||
        """
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def from_config(cls, nn_config: Union[str, dict]) -> "NNExecutor":
 | 
			
		||||
        """
 | 
			
		||||
        Create an NN executor instance from `nn_config`.
 | 
			
		||||
 | 
			
		||||
        This method is abstract, derived class must override it by either
 | 
			
		||||
        creating executor instances or implementating dispatch logic.
 | 
			
		||||
 | 
			
		||||
        :param nn_config: config to use, can be dictionary or path to a JSON file
 | 
			
		||||
        :type nn_config: str or dict
 | 
			
		||||
        :rtype: NNExecutor
 | 
			
		||||
        """
 | 
			
		||||
        from nn4k.nnhub import NNHub
 | 
			
		||||
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
 | 
			
		||||
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
 | 
			
		||||
        from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
 | 
			
		||||
        from nn4k.utils.config_parsing import preprocess_config
 | 
			
		||||
        from nn4k.utils.config_parsing import get_string_field
 | 
			
		||||
        from nn4k.utils.class_importing import dynamic_import_class
 | 
			
		||||
 | 
			
		||||
        nn_config = preprocess_config(nn_config)
 | 
			
		||||
        nn_executor = nn_config.get(NN_EXECUTOR_KEY)
 | 
			
		||||
        if nn_executor is not None:
 | 
			
		||||
            nn_executor = get_string_field(nn_config, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT)
 | 
			
		||||
            executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
 | 
			
		||||
            if not issubclass(executor_class, NNExecutor):
 | 
			
		||||
                message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
 | 
			
		||||
                raise RuntimeError(message)
 | 
			
		||||
            executor = executor_class.from_config(nn_config)
 | 
			
		||||
            return executor
 | 
			
		||||
 | 
			
		||||
        nn_name = nn_config.get(NN_NAME_KEY)
 | 
			
		||||
        if nn_name is not None:
 | 
			
		||||
            nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
 | 
			
		||||
        nn_version = nn_config.get(NN_VERSION_KEY)
 | 
			
		||||
        if nn_version is not None:
 | 
			
		||||
            nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
 | 
			
		||||
        if nn_name is not None:
 | 
			
		||||
            hub = NNHub.get_instance()
 | 
			
		||||
            executor = hub.get_model_executor(nn_name, nn_version)
 | 
			
		||||
            if executor is not None:
 | 
			
		||||
                return executor
 | 
			
		||||
 | 
			
		||||
        message = "can not create executor for NN config"
 | 
			
		||||
        if nn_name is not None:
 | 
			
		||||
            message += "; model: %r" % nn_name
 | 
			
		||||
            if nn_version is not None:
 | 
			
		||||
                message += ", version: %r" % nn_version
 | 
			
		||||
        raise RuntimeError(message)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LLMExecutor(NNExecutor):
 | 
			
		||||
    def execute_sft(self, args=None, callbacks=None, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        The entry point of SFT execution in a certain pod.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")
 | 
			
		||||
 | 
			
		||||
    def execute_rl_tuning(self, args=None, callbacks=None, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        The entry point of SFT execution in a certain pod.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError(
 | 
			
		||||
            f"{self.__class__.__name__} does not support RL-Tuning."
 | 
			
		||||
        )
 | 
			
		||||
							
								
								
									
										17
									
								
								python/nn4k/nn4k/executor/deepke.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								python/nn4k/nn4k/executor/deepke.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,17 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
from nn4k.executor import NNExecutor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DeepKeExecutor(NNExecutor):
 | 
			
		||||
 | 
			
		||||
    pass
 | 
			
		||||
							
								
								
									
										104
									
								
								python/nn4k/nn4k/executor/hugging_face.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								python/nn4k/nn4k/executor/hugging_face.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,104 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
from typing import Union
 | 
			
		||||
from nn4k.executor import LLMExecutor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HfLLMExecutor(LLMExecutor):
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_config(cls, nn_config: dict) -> "HfLLMExecutor":
 | 
			
		||||
        """
 | 
			
		||||
        Create an HfLLMExecutor instance from `nn_config`.
 | 
			
		||||
        """
 | 
			
		||||
        executor = cls(nn_config)
 | 
			
		||||
        return executor
 | 
			
		||||
 | 
			
		||||
    def execute_sft(self, args=None, callbacks=None, **kwargs):
 | 
			
		||||
        raise NotImplementedError(
 | 
			
		||||
            f"{self.__class__.__name__} will support SFT in the next version."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def load_model(self, args=None, **kwargs):
 | 
			
		||||
        import torch
 | 
			
		||||
        from transformers import AutoTokenizer
 | 
			
		||||
        from transformers import AutoModelForCausalLM
 | 
			
		||||
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
 | 
			
		||||
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
 | 
			
		||||
        from nn4k.consts import NN_DEVICE_KEY, NN_TRUST_REMOTE_CODE_KEY
 | 
			
		||||
        from nn4k.utils.config_parsing import get_string_field
 | 
			
		||||
 | 
			
		||||
        nn_config: dict = args or self.init_args
 | 
			
		||||
        if self._model is None:
 | 
			
		||||
            nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
 | 
			
		||||
            nn_version = nn_config.get(NN_VERSION_KEY)
 | 
			
		||||
            if nn_version is not None:
 | 
			
		||||
                nn_version = get_string_field(
 | 
			
		||||
                    nn_config, NN_VERSION_KEY, NN_VERSION_TEXT
 | 
			
		||||
                )
 | 
			
		||||
            model_path = nn_name
 | 
			
		||||
            revision = nn_version
 | 
			
		||||
            use_fast_tokenizer = False
 | 
			
		||||
            device = nn_config.get(NN_DEVICE_KEY)
 | 
			
		||||
            trust_remote_code = nn_config.get(NN_TRUST_REMOTE_CODE_KEY, False)
 | 
			
		||||
            if device is None:
 | 
			
		||||
                device = "cuda" if torch.cuda.is_available() else "cpu"
 | 
			
		||||
            tokenizer = AutoTokenizer.from_pretrained(
 | 
			
		||||
                model_path,
 | 
			
		||||
                use_fast=use_fast_tokenizer,
 | 
			
		||||
                revision=revision,
 | 
			
		||||
                trust_remote_code=trust_remote_code,
 | 
			
		||||
            )
 | 
			
		||||
            model = AutoModelForCausalLM.from_pretrained(
 | 
			
		||||
                model_path,
 | 
			
		||||
                low_cpu_mem_usage=True,
 | 
			
		||||
                torch_dtype=torch.float16,
 | 
			
		||||
                revision=revision,
 | 
			
		||||
                trust_remote_code=trust_remote_code,
 | 
			
		||||
            )
 | 
			
		||||
            model.to(device)
 | 
			
		||||
            self._tokenizer = tokenizer
 | 
			
		||||
            self._model = model
 | 
			
		||||
 | 
			
		||||
    def inference(
 | 
			
		||||
        self,
 | 
			
		||||
        data,
 | 
			
		||||
        max_input_length: int = 1024,
 | 
			
		||||
        max_output_length: int = 1024,
 | 
			
		||||
        do_sample: bool = False,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
        model = self.model
 | 
			
		||||
        tokenizer = self.tokenizer
 | 
			
		||||
        input_ids = tokenizer(
 | 
			
		||||
            data,
 | 
			
		||||
            padding=True,
 | 
			
		||||
            return_token_type_ids=False,
 | 
			
		||||
            return_tensors="pt",
 | 
			
		||||
            truncation=True,
 | 
			
		||||
            max_length=max_input_length,
 | 
			
		||||
        ).to(model.device)
 | 
			
		||||
        output_ids = model.generate(
 | 
			
		||||
            **input_ids,
 | 
			
		||||
            max_new_tokens=max_output_length,
 | 
			
		||||
            do_sample=do_sample,
 | 
			
		||||
            eos_token_id=tokenizer.eos_token_id,
 | 
			
		||||
            pad_token_id=tokenizer.pad_token_id,
 | 
			
		||||
            **kwargs,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        outputs = [
 | 
			
		||||
            tokenizer.decode(
 | 
			
		||||
                output_id[len(input_ids["input_ids"][idx]) :], skip_special_tokens=True
 | 
			
		||||
            )
 | 
			
		||||
            for idx, output_id in enumerate(output_ids)
 | 
			
		||||
        ]
 | 
			
		||||
        return outputs
 | 
			
		||||
							
								
								
									
										12
									
								
								python/nn4k/nn4k/invoker/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								python/nn4k/nn4k/invoker/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,12 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
from nn4k.invoker.base import NNInvoker, LLMInvoker
 | 
			
		||||
							
								
								
									
										186
									
								
								python/nn4k/nn4k/invoker/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								python/nn4k/nn4k/invoker/base.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,186 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from typing import Union
 | 
			
		||||
 | 
			
		||||
from nn4k.executor import LLMExecutor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SubmitMode(Enum):
 | 
			
		||||
    K8s = "k8s"
 | 
			
		||||
    Docker = "docker"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NNInvoker(ABC):
 | 
			
		||||
    """
 | 
			
		||||
    Invoking Entry Interfaces for NN Models.
 | 
			
		||||
    One NNInvoker object is for one NN Model.
 | 
			
		||||
    - Interfaces starting with "submit_" means submitting a batch task to a remote execution engine.
 | 
			
		||||
    - Interfaces starting with "remote_" means querying a remote service for some results.
 | 
			
		||||
    - Interfaces starting with "local_"  means running something locally.
 | 
			
		||||
            Must call `warmup_local_model` before calling any local_xxx interface.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, init_args: dict, **kwargs):
 | 
			
		||||
        self._init_args = init_args
 | 
			
		||||
        self._kwargs = kwargs
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def init_args(self):
 | 
			
		||||
        """
 | 
			
		||||
        Return the `init_args` passed to the invoker constructor.
 | 
			
		||||
        """
 | 
			
		||||
        return self._init_args
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def kwargs(self):
 | 
			
		||||
        """
 | 
			
		||||
        Return the `kwargs` passed to the invoker constructor.
 | 
			
		||||
        """
 | 
			
		||||
        return self._kwargs
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def from_config(cls, nn_config: Union[str, dict]) -> "NNInvoker":
 | 
			
		||||
        """
 | 
			
		||||
        Create an NN invoker instance from `nn_config`.
 | 
			
		||||
 | 
			
		||||
        This method is abstract, derived class must override it by either
 | 
			
		||||
        creating invoker instances or implementating dispatch logic.
 | 
			
		||||
 | 
			
		||||
        :param nn_config: config to use, can be dictionary or path to a JSON file
 | 
			
		||||
        :type nn_config: str or dict
 | 
			
		||||
        :rtype: NNInvoker
 | 
			
		||||
        :raises RuntimeError: if the NN config is not recognized
 | 
			
		||||
        """
 | 
			
		||||
        from nn4k.nnhub import NNHub
 | 
			
		||||
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
 | 
			
		||||
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
 | 
			
		||||
        from nn4k.consts import NN_INVOKER_KEY, NN_INVOKER_TEXT
 | 
			
		||||
        from nn4k.utils.config_parsing import preprocess_config
 | 
			
		||||
        from nn4k.utils.config_parsing import get_string_field
 | 
			
		||||
        from nn4k.utils.class_importing import dynamic_import_class
 | 
			
		||||
 | 
			
		||||
        nn_config = preprocess_config(nn_config)
 | 
			
		||||
        nn_invoker = nn_config.get(NN_INVOKER_KEY)
 | 
			
		||||
        if nn_invoker is not None:
 | 
			
		||||
            nn_invoker = get_string_field(nn_config, NN_INVOKER_KEY, NN_INVOKER_TEXT)
 | 
			
		||||
            invoker_class = dynamic_import_class(nn_invoker, NN_INVOKER_TEXT)
 | 
			
		||||
            if not issubclass(invoker_class, NNInvoker):
 | 
			
		||||
                message = "%r is not an %s class" % (nn_invoker, NN_INVOKER_TEXT)
 | 
			
		||||
                raise RuntimeError(message)
 | 
			
		||||
            invoker = invoker_class.from_config(nn_config)
 | 
			
		||||
            return invoker
 | 
			
		||||
 | 
			
		||||
        hub = NNHub.get_instance()
 | 
			
		||||
        invoker = hub.get_invoker(nn_config)
 | 
			
		||||
        if invoker is not None:
 | 
			
		||||
            return invoker
 | 
			
		||||
 | 
			
		||||
        nn_name = nn_config.get(NN_NAME_KEY)
 | 
			
		||||
        if nn_name is not None:
 | 
			
		||||
            nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
 | 
			
		||||
        nn_version = nn_config.get(NN_VERSION_KEY)
 | 
			
		||||
        if nn_version is not None:
 | 
			
		||||
            nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
 | 
			
		||||
        message = "can not create invoker for NN config"
 | 
			
		||||
        if nn_name is not None:
 | 
			
		||||
            message += "; model: %r" % nn_name
 | 
			
		||||
            if nn_version is not None:
 | 
			
		||||
                message += ", version: %r" % nn_version
 | 
			
		||||
        raise RuntimeError(message)
 | 
			
		||||
 | 
			
		||||
    def submit_inference(self, submit_mode: SubmitMode = SubmitMode.K8s):
 | 
			
		||||
        """
 | 
			
		||||
        Submit remote batch inference execution.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError(
 | 
			
		||||
            f"{self.__class__.__name__} does not support batch inference."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def remote_inference(self, input, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        Inference via existing remote services.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError(
 | 
			
		||||
            f"{self.__class__.__name__} does not support remote inference."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def local_inference(self, data, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        Implement local inference in derived invoker classes.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError(
 | 
			
		||||
            f"{self.__class__.__name__} does not support local inference."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def warmup_local_model(self):
 | 
			
		||||
        """
 | 
			
		||||
        Implement local model warming up logic in derived invoker classes.
 | 
			
		||||
        """
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LLMInvoker(NNInvoker):
 | 
			
		||||
    def submit_sft(self, submit_mode: SubmitMode = SubmitMode.K8s):
 | 
			
		||||
        """
 | 
			
		||||
        Submit remote SFT execution.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError(f"{self.__class__.__name__} does not support SFT.")
 | 
			
		||||
 | 
			
		||||
    def submit_rl_tuning(self, submit_mode: SubmitMode = SubmitMode.K8s):
 | 
			
		||||
        """
 | 
			
		||||
        Submit remote RL-Tuning execution.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError(
 | 
			
		||||
            f"{self.__class__.__name__} does not support RL-Tuning."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def local_inference(self, data, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        Implement local inference for local invoker.
 | 
			
		||||
        """
 | 
			
		||||
        return self._nn_executor.inference(data, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def warmup_local_model(self):
 | 
			
		||||
        """
 | 
			
		||||
        Implement local model warming up logic for local invoker.
 | 
			
		||||
        """
 | 
			
		||||
        from nn4k.nnhub import NNHub
 | 
			
		||||
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
 | 
			
		||||
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
 | 
			
		||||
        from nn4k.utils.config_parsing import get_string_field
 | 
			
		||||
 | 
			
		||||
        nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
 | 
			
		||||
        nn_version = self.init_args.get(NN_VERSION_KEY)
 | 
			
		||||
        if nn_version is not None:
 | 
			
		||||
            nn_version = get_string_field(
 | 
			
		||||
                self.init_args, NN_VERSION_KEY, NN_VERSION_TEXT
 | 
			
		||||
            )
 | 
			
		||||
        hub = NNHub.get_instance()
 | 
			
		||||
        executor = hub.get_model_executor(nn_name, nn_version)
 | 
			
		||||
        if executor is None:
 | 
			
		||||
            message = "model %r version %r " % (nn_name, nn_version)
 | 
			
		||||
            message += "is not found in the model hub"
 | 
			
		||||
            raise RuntimeError(message)
 | 
			
		||||
        self._nn_executor: LLMExecutor = executor
 | 
			
		||||
        self._nn_executor.load_model()
 | 
			
		||||
        self._nn_executor.warmup_inference()
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_config(cls, nn_config: dict) -> "LLMInvoker":
 | 
			
		||||
        """
 | 
			
		||||
        Create an LLMInvoker instance from `nn_config`.
 | 
			
		||||
        """
 | 
			
		||||
        invoker = cls(nn_config)
 | 
			
		||||
        return invoker
 | 
			
		||||
							
								
								
									
										75
									
								
								python/nn4k/nn4k/invoker/openai.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								python/nn4k/nn4k/invoker/openai.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,75 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from nn4k.invoker import NNInvoker
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OpenAIInvoker(NNInvoker):
 | 
			
		||||
    def __init__(self, nn_config: dict):
 | 
			
		||||
        super().__init__(nn_config)
 | 
			
		||||
 | 
			
		||||
        import openai
 | 
			
		||||
        from nn4k.consts import NN_OPENAI_MODEL_NAME_KEY, NN_OPENAI_MODEL_NAME_TEXT
 | 
			
		||||
        from nn4k.consts import NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_KEY_TEXT
 | 
			
		||||
        from nn4k.consts import NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
 | 
			
		||||
        from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
 | 
			
		||||
        from nn4k.utils.config_parsing import get_string_field
 | 
			
		||||
        from nn4k.utils.config_parsing import get_positive_int_field
 | 
			
		||||
 | 
			
		||||
        self.openai_model_name = get_string_field(
 | 
			
		||||
            self.init_args, NN_OPENAI_MODEL_NAME_KEY, NN_OPENAI_MODEL_NAME_TEXT
 | 
			
		||||
        )
 | 
			
		||||
        self.openai_api_key = get_string_field(
 | 
			
		||||
            self.init_args, NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_KEY_TEXT
 | 
			
		||||
        )
 | 
			
		||||
        self.openai_api_base = get_string_field(
 | 
			
		||||
            self.init_args, NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
 | 
			
		||||
        )
 | 
			
		||||
        self.openai_max_tokens = get_positive_int_field(
 | 
			
		||||
            self.init_args, NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        openai.api_key = self.openai_api_key
 | 
			
		||||
        openai.api_base = self.openai_api_base
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_config(cls, nn_config: dict) -> "OpenAIInvoker":
 | 
			
		||||
        invoker = cls(nn_config)
 | 
			
		||||
        return invoker
 | 
			
		||||
 | 
			
		||||
    def _create_prompt(self, input, **kwargs):
 | 
			
		||||
        if isinstance(input, list):
 | 
			
		||||
            prompt = input
 | 
			
		||||
        else:
 | 
			
		||||
            prompt = [input]
 | 
			
		||||
        return prompt
 | 
			
		||||
 | 
			
		||||
    def _create_output(self, input, prompt, completion, **kwargs):
 | 
			
		||||
        output = [choice.text for choice in completion.choices]
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def remote_inference(
 | 
			
		||||
        self, input, max_output_length: Optional[int] = None, **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        import openai
 | 
			
		||||
 | 
			
		||||
        if max_output_length is None:
 | 
			
		||||
            max_output_length = self.openai_max_tokens
 | 
			
		||||
        prompt = self._create_prompt(input, **kwargs)
 | 
			
		||||
        completion = openai.Completion.create(
 | 
			
		||||
            model=self.openai_model_name,
 | 
			
		||||
            prompt=prompt,
 | 
			
		||||
            max_tokens=max_output_length,
 | 
			
		||||
        )
 | 
			
		||||
        output = self._create_output(input, prompt, completion, **kwargs)
 | 
			
		||||
        return output
 | 
			
		||||
							
								
								
									
										165
									
								
								python/nn4k/nn4k/nnhub/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										165
									
								
								python/nn4k/nn4k/nnhub/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,165 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from typing import Optional, Union, Tuple, Type
 | 
			
		||||
 | 
			
		||||
from nn4k.executor import NNExecutor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NNHub(ABC):
 | 
			
		||||
    _hub_instance = None
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_instance() -> "NNHub":
 | 
			
		||||
        """
 | 
			
		||||
        Get the NNHub instance. If the instance is not initialized, create a stub `SimpleNNHub`.
 | 
			
		||||
        """
 | 
			
		||||
        if NNHub._hub_instance is None:
 | 
			
		||||
            NNHub._hub_instance = SimpleNNHub()
 | 
			
		||||
        return NNHub._hub_instance
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def publish(
 | 
			
		||||
        self,
 | 
			
		||||
        model_executor: Union[NNExecutor, Tuple[Type[NNExecutor], tuple, dict, tuple]],
 | 
			
		||||
        name: str,
 | 
			
		||||
        version: str = None,
 | 
			
		||||
    ) -> str:
 | 
			
		||||
        """
 | 
			
		||||
        Publish a model(executor) to hub.
 | 
			
		||||
 | 
			
		||||
        :param model_executor: An NNExecutor object, which is pickleable.
 | 
			
		||||
                               Or a tuple of (class, init_args, kwargs, weight_ids) for creating an NNExecutor,
 | 
			
		||||
                               while all these 4 augments are pickleable.
 | 
			
		||||
 | 
			
		||||
        :param str name: The name of a model, like `llama2`. We do not have a `namespace`.
 | 
			
		||||
                         Use a joined name like `alibaba/qwen` to support such features.
 | 
			
		||||
 | 
			
		||||
        :param str version: Optional. Auto generate a version if this param is not given.
 | 
			
		||||
 | 
			
		||||
        :return: The published model version.
 | 
			
		||||
        :rtype: str
 | 
			
		||||
        """
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def get_model_executor(
 | 
			
		||||
        self, name: str, version: str = None
 | 
			
		||||
    ) -> Optional[NNExecutor]:
 | 
			
		||||
        """
 | 
			
		||||
        Get an NNExecutor instance from Hub.
 | 
			
		||||
 | 
			
		||||
        :param str name: The name of a model.
 | 
			
		||||
        :param str version: The version of a model. Get default version of a model if this param is not given.
 | 
			
		||||
        :return: The ModelExecutor Instance. None for NotFound.
 | 
			
		||||
        :rtype: Optional[NNExecutor]
 | 
			
		||||
        """
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def get_invoker(self, nn_config: dict) -> Optional["NNInvoker"]:
 | 
			
		||||
        """
 | 
			
		||||
        Get an NNExecutor instance from Hub.
 | 
			
		||||
 | 
			
		||||
        :param dict nn_config: The config dictionary.
 | 
			
		||||
        :return: The NNExecutor Instance. None for NotFound.
 | 
			
		||||
        :rtype: Optional[NNInvoker]
 | 
			
		||||
        """
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def start_service(self, name: str, version: str, service_id: str = None, **kwargs):
 | 
			
		||||
        raise NotImplementedError("This Hub does not support starting model service.")
 | 
			
		||||
 | 
			
		||||
    def stop_service(self, name: str, version: str, service_id: str = None, **kwargs):
 | 
			
		||||
        raise NotImplementedError("This Hub does not support stopping model service.")
 | 
			
		||||
 | 
			
		||||
    def get_service(self, name: str, version: str, service_id: str = None):
 | 
			
		||||
        raise NotImplementedError("This Hub does not support model services.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SimpleNNHub(NNHub):
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self._model_executors = {}
 | 
			
		||||
 | 
			
		||||
    def _add_executor(
 | 
			
		||||
        self,
 | 
			
		||||
        executor: Union[NNExecutor, Tuple[Type[NNExecutor], tuple, dict, tuple]],
 | 
			
		||||
        name: str,
 | 
			
		||||
        version: str = None,
 | 
			
		||||
    ):
 | 
			
		||||
        from nn4k.consts import NN_VERSION_DEFAULT
 | 
			
		||||
 | 
			
		||||
        if version is None:
 | 
			
		||||
            version = NN_VERSION_DEFAULT
 | 
			
		||||
        if self._model_executors.get(name) is None:
 | 
			
		||||
            self._model_executors[name] = {version: executor}
 | 
			
		||||
        else:
 | 
			
		||||
            self._model_executors[name][version] = executor
 | 
			
		||||
 | 
			
		||||
    def publish(
 | 
			
		||||
        self, model_executor: NNExecutor, name: str, version: str = None
 | 
			
		||||
    ) -> str:
 | 
			
		||||
        from nn4k.consts import NN_VERSION_DEFAULT
 | 
			
		||||
 | 
			
		||||
        print(
 | 
			
		||||
            "WARNING: You are using SimpleNNHub which can only maintain models in memory without data persistence!"
 | 
			
		||||
        )
 | 
			
		||||
        if version is None:
 | 
			
		||||
            version = NN_VERSION_DEFAULT
 | 
			
		||||
        self._add_executor(model_executor, name, version)
 | 
			
		||||
        return version
 | 
			
		||||
 | 
			
		||||
    def _create_model_executor(self, cls, init_args, kwargs, weights):
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def get_model_executor(
 | 
			
		||||
        self, name: str, version: str = None
 | 
			
		||||
    ) -> Optional[NNExecutor]:
 | 
			
		||||
        if self._model_executors.get(name) is None:
 | 
			
		||||
            return None
 | 
			
		||||
        executor = self._model_executors.get(name).get(version)
 | 
			
		||||
        if isinstance(executor, NNExecutor):
 | 
			
		||||
            return executor
 | 
			
		||||
        cls, init_args, kwargs, weights = executor
 | 
			
		||||
        executor = self._create_model_executor(cls, init_args, kwargs, weights)
 | 
			
		||||
        return executor
 | 
			
		||||
 | 
			
		||||
    def _add_local_executor(self, nn_config):
 | 
			
		||||
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
 | 
			
		||||
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
 | 
			
		||||
        from nn4k.executor.hugging_face import HfLLMExecutor
 | 
			
		||||
        from nn4k.utils.config_parsing import get_string_field
 | 
			
		||||
 | 
			
		||||
        executor = HfLLMExecutor.from_config(nn_config)
 | 
			
		||||
        nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
 | 
			
		||||
        nn_version = nn_config.get(NN_VERSION_KEY)
 | 
			
		||||
        if nn_version is not None:
 | 
			
		||||
            nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
 | 
			
		||||
        self.publish(executor, nn_name, nn_version)
 | 
			
		||||
 | 
			
		||||
    def get_invoker(self, nn_config: dict) -> Optional["NNInvoker"]:
 | 
			
		||||
        from nn4k.invoker import LLMInvoker
 | 
			
		||||
        from nn4k.invoker.openai import OpenAIInvoker
 | 
			
		||||
        from nn4k.utils.invoker_checking import is_openai_invoker
 | 
			
		||||
        from nn4k.utils.invoker_checking import is_local_invoker
 | 
			
		||||
 | 
			
		||||
        if is_openai_invoker(nn_config):
 | 
			
		||||
            invoker = OpenAIInvoker.from_config(nn_config)
 | 
			
		||||
            return invoker
 | 
			
		||||
 | 
			
		||||
        if is_local_invoker(nn_config):
 | 
			
		||||
            invoker = LLMInvoker.from_config(nn_config)
 | 
			
		||||
            self._add_local_executor(nn_config)
 | 
			
		||||
            return invoker
 | 
			
		||||
 | 
			
		||||
        return None
 | 
			
		||||
							
								
								
									
										10
									
								
								python/nn4k/nn4k/utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								python/nn4k/nn4k/utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,10 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
							
								
								
									
										57
									
								
								python/nn4k/nn4k/utils/class_importing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								python/nn4k/nn4k/utils/class_importing.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,57 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
import importlib
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_module_class_name(name: str, text: str) -> Tuple[str, str]:
 | 
			
		||||
    """
 | 
			
		||||
    Split `name` as module name and class name pair.
 | 
			
		||||
 | 
			
		||||
    :param name: fully qualified class name, e.g. ``foo.bar.MyClass``
 | 
			
		||||
    :type name: str
 | 
			
		||||
    :param text: describe the kind of the class, used in the exception message
 | 
			
		||||
    :type text: str
 | 
			
		||||
    :rtype: Tuple[str, str]
 | 
			
		||||
    :raises RuntimeError: if `name` is not a fully qualified class name
 | 
			
		||||
    """
 | 
			
		||||
    i = name.rfind(".")
 | 
			
		||||
    if i == -1:
 | 
			
		||||
        message = "invalid %s class name: %s" % (text, name)
 | 
			
		||||
        raise RuntimeError(message)
 | 
			
		||||
    module_name = name[:i]
 | 
			
		||||
    class_name = name[i + 1 :]
 | 
			
		||||
    return module_name, class_name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def dynamic_import_class(name: str, text: str):
 | 
			
		||||
    """
 | 
			
		||||
    Import the class specified by `name` dyanmically.
 | 
			
		||||
 | 
			
		||||
    :param name: fully qualified class name, e.g. ``foo.bar.MyClass``
 | 
			
		||||
    :type name: str
 | 
			
		||||
    :param text: describe the kind of the class, use in the exception message
 | 
			
		||||
    :type text: str
 | 
			
		||||
    :raises RuntimeError: if `name` is not a fully qualified class name, or
 | 
			
		||||
                          the class is not in the module specified by `name`
 | 
			
		||||
    :raises ModuleNotFoundError: the module specified by `name` is not found
 | 
			
		||||
    """
 | 
			
		||||
    module_name, class_name = split_module_class_name(name, text)
 | 
			
		||||
    module = importlib.import_module(module_name)
 | 
			
		||||
    class_ = getattr(module, class_name, None)
 | 
			
		||||
    if class_ is None:
 | 
			
		||||
        message = "class %r not found in module %r" % (class_name, module_name)
 | 
			
		||||
        raise RuntimeError(message)
 | 
			
		||||
    if not isinstance(class_, type):
 | 
			
		||||
        message = "%r is not a class" % (name,)
 | 
			
		||||
        raise RuntimeError(message)
 | 
			
		||||
    return class_
 | 
			
		||||
							
								
								
									
										118
									
								
								python/nn4k/nn4k/utils/config_parsing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										118
									
								
								python/nn4k/nn4k/utils/config_parsing.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,118 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
from typing import Any
 | 
			
		||||
from typing import Union
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_config(nn_config: Union[str, dict]) -> dict:
 | 
			
		||||
    """
 | 
			
		||||
    Preprocess config `nn_config` into a dictionary.
 | 
			
		||||
 | 
			
		||||
    * If `nn_config` is already a dictionary, return it as is.
 | 
			
		||||
 | 
			
		||||
    * If `nn_config` is a string, decode it as a JSON file.
 | 
			
		||||
 | 
			
		||||
    :param nn_config: config to be preprocessed
 | 
			
		||||
    :type nn_config: str or dict
 | 
			
		||||
    :return: `nn_config` or `nn_config` decoded as JSON
 | 
			
		||||
    :rtype: dict
 | 
			
		||||
    :raises ValueError: if cannot decode config file specified by
 | 
			
		||||
                        `nn_config` as JSON
 | 
			
		||||
    """
 | 
			
		||||
    try:
 | 
			
		||||
        if isinstance(nn_config, str):
 | 
			
		||||
            with open(nn_config, "r") as f:
 | 
			
		||||
                nn_config = json.load(f)
 | 
			
		||||
    except:
 | 
			
		||||
        raise ValueError("cannot decode config file")
 | 
			
		||||
    return nn_config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_field(nn_config: dict, name: str, text: str) -> Any:
 | 
			
		||||
    """
 | 
			
		||||
    Get the value of the field specified by `name` from the configuration
 | 
			
		||||
    dictionary `nn_config`.
 | 
			
		||||
 | 
			
		||||
    :param str name: name of the field
 | 
			
		||||
    :param str name: descriptive text of the name of the field
 | 
			
		||||
    :return: value of the field
 | 
			
		||||
    :rtype: Any
 | 
			
		||||
    :raises ValueError: if the field is not specified in `nn_config`
 | 
			
		||||
    """
 | 
			
		||||
    value = nn_config.get(name)
 | 
			
		||||
    if value is None:
 | 
			
		||||
        message = "%s %r not found" % (text, name)
 | 
			
		||||
        raise ValueError(message)
 | 
			
		||||
    return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_string_field(nn_config: dict, name: str, text: str) -> str:
 | 
			
		||||
    """
 | 
			
		||||
    Get the value of the string field specified by `name` from the
 | 
			
		||||
    configuration dictionary `nn_config`.
 | 
			
		||||
 | 
			
		||||
    :param str name: name of the field
 | 
			
		||||
    :param str name: descriptive text of the name of the field
 | 
			
		||||
    :return: value of the field
 | 
			
		||||
    :rtype: str
 | 
			
		||||
    :raises ValueError: if the field is not specified in `nn_config`
 | 
			
		||||
    :raises TypeError: if the value of the field is not a string
 | 
			
		||||
    """
 | 
			
		||||
    value = get_field(nn_config, name, text)
 | 
			
		||||
    if not isinstance(value, str):
 | 
			
		||||
        message = "%s %r must be string; " % (text, name)
 | 
			
		||||
        message += "%r is invalid" % (value,)
 | 
			
		||||
        raise TypeError(message)
 | 
			
		||||
    return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_int_field(nn_config: dict, name: str, text: str) -> int:
 | 
			
		||||
    """
 | 
			
		||||
    Get the value of the integer field specified by `name` from the
 | 
			
		||||
    configuration dictionary `nn_config`.
 | 
			
		||||
 | 
			
		||||
    :param str name: name of the field
 | 
			
		||||
    :param str name: descriptive text of the name of the field
 | 
			
		||||
    :return: value of the field
 | 
			
		||||
    :rtype: int
 | 
			
		||||
    :raises ValueError: if the field is not specified in `nn_config`
 | 
			
		||||
    :raises TypeError: if the value of the field is not an integer
 | 
			
		||||
    """
 | 
			
		||||
    value = get_field(nn_config, name, text)
 | 
			
		||||
    if not isinstance(value, int):
 | 
			
		||||
        message = "%s %r must be integer; " % (text, name)
 | 
			
		||||
        message += "%r is invalid" % (value,)
 | 
			
		||||
        raise TypeError(message)
 | 
			
		||||
    return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_positive_int_field(nn_config: dict, name: str, text: str) -> int:
 | 
			
		||||
    """
 | 
			
		||||
    Get the value of the positive integer field specified by `name`
 | 
			
		||||
    from the configuration dictionary `nn_config`.
 | 
			
		||||
 | 
			
		||||
    :param str name: name of the field
 | 
			
		||||
    :param str name: descriptive text of the name of the field
 | 
			
		||||
    :return: value of the field
 | 
			
		||||
    :rtype: int
 | 
			
		||||
    :raises ValueError: if the field is not specified in `nn_config`, or the
 | 
			
		||||
                        value of the field is not a positive integer
 | 
			
		||||
    :raises TypeError: if the value of the field is not an integer
 | 
			
		||||
    """
 | 
			
		||||
    value = get_int_field(nn_config, name, text)
 | 
			
		||||
    if value <= 0:
 | 
			
		||||
        message = "%s %r must be positive integer; " % (text, name)
 | 
			
		||||
        message += "%r is invalid" % (value,)
 | 
			
		||||
        raise ValueError(message)
 | 
			
		||||
    return value
 | 
			
		||||
							
								
								
									
										62
									
								
								python/nn4k/nn4k/utils/invoker_checking.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								python/nn4k/nn4k/utils/invoker_checking.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,62 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_openai_invoker(nn_config: dict) -> bool:
 | 
			
		||||
    """
 | 
			
		||||
    Check whether `nn_config` specifies OpenAI invoker.
 | 
			
		||||
 | 
			
		||||
    :type nn_config: dict
 | 
			
		||||
    :rtype: bool
 | 
			
		||||
    """
 | 
			
		||||
    from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
 | 
			
		||||
    from nn4k.consts import NN_OPENAI_API_KEY_KEY
 | 
			
		||||
    from nn4k.consts import NN_OPENAI_API_BASE_KEY
 | 
			
		||||
    from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY
 | 
			
		||||
    from nn4k.consts import NN_OPENAI_GPT4_PREFIX
 | 
			
		||||
    from nn4k.consts import NN_OPENAI_GPT35_PREFIX
 | 
			
		||||
    from nn4k.utils.config_parsing import get_string_field
 | 
			
		||||
 | 
			
		||||
    nn_name = nn_config.get(NN_NAME_KEY)
 | 
			
		||||
    if nn_name is not None:
 | 
			
		||||
        nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
 | 
			
		||||
        if nn_name.startswith(NN_OPENAI_GPT4_PREFIX) or nn_name.startswith(
 | 
			
		||||
            NN_OPENAI_GPT35_PREFIX
 | 
			
		||||
        ):
 | 
			
		||||
            return True
 | 
			
		||||
    keys = (NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_BASE_KEY, NN_OPENAI_MAX_TOKENS_KEY)
 | 
			
		||||
    for key in keys:
 | 
			
		||||
        if key in nn_config:
 | 
			
		||||
            return True
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_local_invoker(nn_config: dict) -> bool:
 | 
			
		||||
    """
 | 
			
		||||
    Check whether `nn_config` specifies local invoker.
 | 
			
		||||
 | 
			
		||||
    :type nn_config: dict
 | 
			
		||||
    :rtype: bool
 | 
			
		||||
    """
 | 
			
		||||
    from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
 | 
			
		||||
    from nn4k.consts import NN_LOCAL_HF_MODEL_CONFIG_FILE
 | 
			
		||||
    from nn4k.utils.config_parsing import get_string_field
 | 
			
		||||
 | 
			
		||||
    nn_name = nn_config.get(NN_NAME_KEY)
 | 
			
		||||
    if nn_name is not None:
 | 
			
		||||
        nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
 | 
			
		||||
        if os.path.isdir(nn_name):
 | 
			
		||||
            file_path = os.path.join(nn_name, NN_LOCAL_HF_MODEL_CONFIG_FILE)
 | 
			
		||||
            if os.path.isfile(file_path):
 | 
			
		||||
                return True
 | 
			
		||||
    return False
 | 
			
		||||
							
								
								
									
										1
									
								
								python/nn4k/requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								python/nn4k/requirements.txt
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
openai<1
 | 
			
		||||
							
								
								
									
										78
									
								
								python/nn4k/setup.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								python/nn4k/setup.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,78 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
from setuptools import setup, find_packages
 | 
			
		||||
 | 
			
		||||
package_name = "openspg-nn4k"
 | 
			
		||||
 | 
			
		||||
# version
 | 
			
		||||
cwd = os.path.abspath(os.path.dirname(__file__))
 | 
			
		||||
with open(os.path.join(cwd, "NN4K_VERSION"), "r") as rf:
 | 
			
		||||
    version = rf.readline().strip("\n").strip()
 | 
			
		||||
 | 
			
		||||
# license
 | 
			
		||||
license = ""
 | 
			
		||||
with open(os.path.join(cwd, "LICENSE"), "r") as rf:
 | 
			
		||||
    line = rf.readline()
 | 
			
		||||
    while line:
 | 
			
		||||
        line = line.strip()
 | 
			
		||||
        if line:
 | 
			
		||||
            license += "# " + line + "\n"
 | 
			
		||||
        else:
 | 
			
		||||
            license += "#\n"
 | 
			
		||||
        line = rf.readline()
 | 
			
		||||
 | 
			
		||||
# Generate nn4k.__init__.py
 | 
			
		||||
with open(os.path.join(cwd, "nn4k/__init__.py"), "w") as wf:
 | 
			
		||||
    content = f"""{license}
 | 
			
		||||
 | 
			
		||||
__package_name__ = "{package_name}"
 | 
			
		||||
__version__ = "{version}"
 | 
			
		||||
"""
 | 
			
		||||
    wf.write(content)
 | 
			
		||||
 | 
			
		||||
setup(
 | 
			
		||||
    name=package_name,
 | 
			
		||||
    version=version,
 | 
			
		||||
    description="nn4k",
 | 
			
		||||
    url="https://github.com/OpenSPG/openspg",
 | 
			
		||||
    packages=find_packages(
 | 
			
		||||
        where=".",
 | 
			
		||||
        exclude=[
 | 
			
		||||
            ".*test.py",
 | 
			
		||||
            "*_test.py",
 | 
			
		||||
            "*_debug.py",
 | 
			
		||||
            "*.txt",
 | 
			
		||||
            "tests",
 | 
			
		||||
            "tests.*",
 | 
			
		||||
            "configs",
 | 
			
		||||
            "configs.*",
 | 
			
		||||
            "test",
 | 
			
		||||
            "test.*",
 | 
			
		||||
            "*.tests",
 | 
			
		||||
            "*.tests.*",
 | 
			
		||||
            "*.pyc",
 | 
			
		||||
        ],
 | 
			
		||||
    ),
 | 
			
		||||
    python_requires=">=3.8",
 | 
			
		||||
    install_requires=[
 | 
			
		||||
        r.strip()
 | 
			
		||||
        for r in open("requirements.txt", "r")
 | 
			
		||||
        if not r.strip().startswith("#")
 | 
			
		||||
    ],
 | 
			
		||||
    include_package_data=True,
 | 
			
		||||
    package_data={
 | 
			
		||||
        "bin": ["*"],
 | 
			
		||||
    },
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										10
									
								
								python/nn4k/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								python/nn4k/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,10 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
							
								
								
									
										10
									
								
								python/nn4k/tests/executor/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								python/nn4k/tests/executor/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,10 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
							
								
								
									
										69
									
								
								python/nn4k/tests/executor/test_base_executor.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								python/nn4k/tests/executor/test_base_executor.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,69 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestBaseExecutor(unittest.TestCase):
 | 
			
		||||
    """
 | 
			
		||||
    NNExecutor and LLMExecutor unittest
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        # for importing test_stub.py
 | 
			
		||||
        dir_path = os.path.dirname(os.path.abspath(__file__))
 | 
			
		||||
        sys.path.insert(0, dir_path)
 | 
			
		||||
 | 
			
		||||
        from nn4k.nnhub import NNHub
 | 
			
		||||
        from test_stub import StubHub
 | 
			
		||||
 | 
			
		||||
        NNHub._hub_instance = StubHub()
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        from nn4k.nnhub import NNHub
 | 
			
		||||
 | 
			
		||||
        sys.path.pop(0)
 | 
			
		||||
        NNHub._hub_instance = None
 | 
			
		||||
 | 
			
		||||
    def testCustomNNExecutor(self):
 | 
			
		||||
        from nn4k.executor import NNExecutor
 | 
			
		||||
        from test_stub import StubExecutor
 | 
			
		||||
 | 
			
		||||
        nn_config = {"nn_executor": "test_stub.StubExecutor"}
 | 
			
		||||
        executor = NNExecutor.from_config(nn_config)
 | 
			
		||||
        self.assertTrue(isinstance(executor, StubExecutor))
 | 
			
		||||
        self.assertEqual(executor.init_args, nn_config)
 | 
			
		||||
        self.assertEqual(executor.kwargs, {})
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(RuntimeError):
 | 
			
		||||
            executor = NNExecutor.from_config({"nn_executor": "test_stub.NotExecutor"})
 | 
			
		||||
 | 
			
		||||
    def testHubExecutor(self):
 | 
			
		||||
        from nn4k.executor import NNExecutor
 | 
			
		||||
        from test_stub import StubExecutor
 | 
			
		||||
 | 
			
		||||
        nn_config = {"nn_name": "test_stub", "nn_version": "default"}
 | 
			
		||||
        executor = NNExecutor.from_config(nn_config)
 | 
			
		||||
        self.assertTrue(isinstance(executor, StubExecutor))
 | 
			
		||||
        self.assertEqual(executor.init_args, nn_config)
 | 
			
		||||
        self.assertEqual(executor.kwargs, {"test_stub_executor": True})
 | 
			
		||||
 | 
			
		||||
    def testExecutorNotExists(self):
 | 
			
		||||
        from nn4k.executor import NNExecutor
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(RuntimeError):
 | 
			
		||||
            executor = NNExecutor.from_config({"nn_name": "not_exists"})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
							
								
								
									
										57
									
								
								python/nn4k/tests/executor/test_hf_llm_executor.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								python/nn4k/tests/executor/test_hf_llm_executor.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,57 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
import sys
 | 
			
		||||
import unittest
 | 
			
		||||
import unittest.mock
 | 
			
		||||
 | 
			
		||||
from nn4k.executor.hugging_face import HfLLMExecutor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestHfLLMExecutor(unittest.TestCase):
 | 
			
		||||
    """
 | 
			
		||||
    HfLLMExecutor unittest
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self._saved_torch = sys.modules.get("torch")
 | 
			
		||||
        self._mocked_torch = unittest.mock.MagicMock()
 | 
			
		||||
        sys.modules["torch"] = self._mocked_torch
 | 
			
		||||
 | 
			
		||||
        self._saved_transformers = sys.modules.get("transformers")
 | 
			
		||||
        self._mocked_transformers = unittest.mock.MagicMock()
 | 
			
		||||
        sys.modules["transformers"] = self._mocked_transformers
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        del sys.modules["torch"]
 | 
			
		||||
        if self._saved_torch is not None:
 | 
			
		||||
            sys.modules["torch"] = self._saved_torch
 | 
			
		||||
 | 
			
		||||
        del sys.modules["transformers"]
 | 
			
		||||
        if self._saved_transformers is not None:
 | 
			
		||||
            sys.modules["transformers"] = self._saved_transformers
 | 
			
		||||
 | 
			
		||||
    def testHfLLMExecutor(self):
 | 
			
		||||
        nn_config = {
 | 
			
		||||
            "nn_name": "/opt/test_model_dir",
 | 
			
		||||
            "nn_version": "default",
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        executor = HfLLMExecutor.from_config(nn_config)
 | 
			
		||||
        executor.load_model()
 | 
			
		||||
        executor.inference("input")
 | 
			
		||||
 | 
			
		||||
        self._mocked_transformers.AutoTokenizer.from_pretrained.assert_called()
 | 
			
		||||
        self._mocked_transformers.AutoModelForCausalLM.from_pretrained.assert_called()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
							
								
								
									
										52
									
								
								python/nn4k/tests/executor/test_stub.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								python/nn4k/tests/executor/test_stub.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,52 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from nn4k.executor import NNExecutor, LLMExecutor
 | 
			
		||||
from nn4k.nnhub import SimpleNNHub
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StubExecutor(LLMExecutor):
 | 
			
		||||
    def load_model(self, args=None, mode=None, **kwargs):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def warmup_inference(self, args=None, **kwargs):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def inference(self, data, args=None, **kwargs):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_config(cls, nn_config: dict) -> "StubExecutor":
 | 
			
		||||
        """
 | 
			
		||||
        Create a StubExecutor instance from `nn_config`.
 | 
			
		||||
        """
 | 
			
		||||
        executor = cls(nn_config)
 | 
			
		||||
        return executor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NotExecutor:
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StubHub(SimpleNNHub):
 | 
			
		||||
    def get_model_executor(
 | 
			
		||||
        self, name: str, version: str = None
 | 
			
		||||
    ) -> Optional[NNExecutor]:
 | 
			
		||||
        if name == "test_stub":
 | 
			
		||||
            if version is None:
 | 
			
		||||
                version = "default"
 | 
			
		||||
            executor = StubExecutor(
 | 
			
		||||
                {"nn_name": name, "nn_version": version}, test_stub_executor=True
 | 
			
		||||
            )
 | 
			
		||||
            return executor
 | 
			
		||||
        return super().get_model_executor(name, version)
 | 
			
		||||
							
								
								
									
										10
									
								
								python/nn4k/tests/invoker/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								python/nn4k/tests/invoker/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,10 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
							
								
								
									
										84
									
								
								python/nn4k/tests/invoker/test_base_invoker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								python/nn4k/tests/invoker/test_base_invoker.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,84 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestBaseInvoker(unittest.TestCase):
 | 
			
		||||
    """
 | 
			
		||||
    NNInvoker and LLMInvoker unittest
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        # for importing test_stub.py
 | 
			
		||||
        dir_path = os.path.dirname(os.path.abspath(__file__))
 | 
			
		||||
        sys.path.insert(0, dir_path)
 | 
			
		||||
 | 
			
		||||
        from nn4k.nnhub import NNHub
 | 
			
		||||
        from test_stub import StubHub
 | 
			
		||||
 | 
			
		||||
        NNHub._hub_instance = StubHub()
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        from nn4k.nnhub import NNHub
 | 
			
		||||
 | 
			
		||||
        sys.path.pop(0)
 | 
			
		||||
        NNHub._hub_instance = None
 | 
			
		||||
 | 
			
		||||
    def testCustomNNInvoker(self):
 | 
			
		||||
        from nn4k.invoker import NNInvoker
 | 
			
		||||
        from test_stub import StubInvoker
 | 
			
		||||
 | 
			
		||||
        nn_config = {"nn_invoker": "test_stub.StubInvoker"}
 | 
			
		||||
        invoker = NNInvoker.from_config(nn_config)
 | 
			
		||||
        self.assertTrue(isinstance(invoker, StubInvoker))
 | 
			
		||||
        self.assertEqual(invoker.init_args, nn_config)
 | 
			
		||||
        self.assertEqual(invoker.kwargs, {})
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(RuntimeError):
 | 
			
		||||
            invoker = NNInvoker.from_config({"nn_invoker": "test_stub.NotInvoker"})
 | 
			
		||||
 | 
			
		||||
    def testHubInvoker(self):
 | 
			
		||||
        from nn4k.invoker import NNInvoker
 | 
			
		||||
        from test_stub import StubInvoker
 | 
			
		||||
 | 
			
		||||
        nn_config = {"nn_name": "test_stub"}
 | 
			
		||||
        invoker = NNInvoker.from_config(nn_config)
 | 
			
		||||
        self.assertTrue(isinstance(invoker, StubInvoker))
 | 
			
		||||
        self.assertEqual(invoker.init_args, nn_config)
 | 
			
		||||
        self.assertEqual(invoker.kwargs, {"test_stub_invoker": True})
 | 
			
		||||
 | 
			
		||||
    def testInvokerNotExists(self):
 | 
			
		||||
        from nn4k.invoker import NNInvoker
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(RuntimeError):
 | 
			
		||||
            invoker = NNInvoker.from_config({"nn_name": "not_exists"})
 | 
			
		||||
 | 
			
		||||
    def testLocalInvoker(self):
 | 
			
		||||
        from nn4k.invoker import NNInvoker
 | 
			
		||||
        from test_stub import StubInvoker
 | 
			
		||||
 | 
			
		||||
        nn_config = {"nn_name": "test_stub"}
 | 
			
		||||
        invoker = NNInvoker.from_config(nn_config)
 | 
			
		||||
        self.assertTrue(isinstance(invoker, StubInvoker))
 | 
			
		||||
        self.assertEqual(invoker.init_args, nn_config)
 | 
			
		||||
        self.assertEqual(invoker.kwargs, {"test_stub_invoker": True})
 | 
			
		||||
 | 
			
		||||
        invoker.warmup_local_model()
 | 
			
		||||
        invoker._nn_executor.inference_result = "inference result"
 | 
			
		||||
        result = invoker.local_inference("input")
 | 
			
		||||
        self.assertEqual(result, invoker._nn_executor.inference_result)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
							
								
								
									
										70
									
								
								python/nn4k/tests/invoker/test_openai_invoker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								python/nn4k/tests/invoker/test_openai_invoker.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,70 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
import sys
 | 
			
		||||
import unittest
 | 
			
		||||
import unittest.mock
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
 | 
			
		||||
from nn4k.invoker import NNInvoker
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class MockCompletion:
 | 
			
		||||
    choices: list
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class MockChoice:
 | 
			
		||||
    text: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestOpenAIInvoker(unittest.TestCase):
 | 
			
		||||
    """
 | 
			
		||||
    OpenAIInvoker unittest
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self._saved_openai = sys.modules.get("openai")
 | 
			
		||||
        self._mocked_openai = unittest.mock.MagicMock()
 | 
			
		||||
        sys.modules["openai"] = self._mocked_openai
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        del sys.modules["openai"]
 | 
			
		||||
        if self._saved_openai is not None:
 | 
			
		||||
            sys.modules["openai"] = self._saved_openai
 | 
			
		||||
 | 
			
		||||
    def testOpenAIInvoker(self):
 | 
			
		||||
        nn_config = {
 | 
			
		||||
            "nn_name": "gpt-3.5-turbo",
 | 
			
		||||
            "openai_api_key": "EMPTY",
 | 
			
		||||
            "openai_api_base": "http://localhost:38080/v1",
 | 
			
		||||
            "openai_max_tokens": 2000,
 | 
			
		||||
        }
 | 
			
		||||
        invoker = NNInvoker.from_config(nn_config)
 | 
			
		||||
        self.assertEqual(invoker.init_args, nn_config)
 | 
			
		||||
        self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
 | 
			
		||||
        self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])
 | 
			
		||||
 | 
			
		||||
        mock_completion = MockCompletion(choices=[MockChoice("a dog named Bolt ...")])
 | 
			
		||||
        self._mocked_openai.Completion.create.return_value = mock_completion
 | 
			
		||||
 | 
			
		||||
        result = invoker.remote_inference("Long long ago, ")
 | 
			
		||||
        self._mocked_openai.Completion.create.assert_called_with(
 | 
			
		||||
            prompt=["Long long ago, "],
 | 
			
		||||
            model=nn_config["nn_name"],
 | 
			
		||||
            max_tokens=nn_config["openai_max_tokens"],
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(result, [mock_completion.choices[0].text])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
							
								
								
									
										70
									
								
								python/nn4k/tests/invoker/test_stub.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								python/nn4k/tests/invoker/test_stub.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,70 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from nn4k.invoker import NNInvoker, LLMInvoker
 | 
			
		||||
from nn4k.executor import NNExecutor
 | 
			
		||||
from nn4k.nnhub import SimpleNNHub
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StubInvoker(LLMInvoker):
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_config(cls, nn_config: dict) -> "StubInvoker":
 | 
			
		||||
        """
 | 
			
		||||
        Create a StubInvoker instance from `nn_config`.
 | 
			
		||||
        """
 | 
			
		||||
        invoker = cls(nn_config)
 | 
			
		||||
        return invoker
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NotInvoker:
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StubExecutor(NNExecutor):
 | 
			
		||||
    def load_model(self, args=None, mode=None, **kwargs):
 | 
			
		||||
        self.load_model_called = True
 | 
			
		||||
 | 
			
		||||
    def warmup_inference(self, args=None, **kwargs):
 | 
			
		||||
        self.warmup_inference_called = True
 | 
			
		||||
 | 
			
		||||
    def inference(self, data, args=None, **kwargs):
 | 
			
		||||
        return self.inference_result
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_config(cls, nn_config: dict) -> "StubExecutor":
 | 
			
		||||
        """
 | 
			
		||||
        Create a StubExecutor instance from `nn_config`.
 | 
			
		||||
        """
 | 
			
		||||
        executor = cls(nn_config)
 | 
			
		||||
        return executor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StubHub(SimpleNNHub):
 | 
			
		||||
    def get_invoker(self, nn_config: dict) -> Optional[NNInvoker]:
 | 
			
		||||
        nn_name = nn_config.get("nn_name")
 | 
			
		||||
        if nn_name is not None and nn_name == "test_stub":
 | 
			
		||||
            invoker = StubInvoker(nn_config, test_stub_invoker=True)
 | 
			
		||||
            return invoker
 | 
			
		||||
        return super().get_invoker(nn_config)
 | 
			
		||||
 | 
			
		||||
    def get_model_executor(
 | 
			
		||||
        self, name: str, version: str = None
 | 
			
		||||
    ) -> Optional[NNExecutor]:
 | 
			
		||||
        if name == "test_stub":
 | 
			
		||||
            if version is None:
 | 
			
		||||
                version = "default"
 | 
			
		||||
            executor = StubExecutor(
 | 
			
		||||
                {"nn_name": name, "nn_version": version}, test_stub_executor=True
 | 
			
		||||
            )
 | 
			
		||||
            return executor
 | 
			
		||||
        return super().get_model_executor(name, version)
 | 
			
		||||
							
								
								
									
										10
									
								
								python/nn4k/tests/nnhub/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								python/nn4k/tests/nnhub/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,10 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
							
								
								
									
										34
									
								
								python/nn4k/tests/nnhub/test_base_nnhub.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								python/nn4k/tests/nnhub/test_base_nnhub.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,34 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestBaseNNHub(unittest.TestCase):
 | 
			
		||||
    """
 | 
			
		||||
    NNHub and SimpleNNHub unittest
 | 
			
		||||
 | 
			
		||||
    The interface and implementation of NNHub may be revised later,
 | 
			
		||||
    then unittests will be added.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def testBaseNNHub(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
							
								
								
									
										10
									
								
								python/nn4k/tests/utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								python/nn4k/tests/utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,10 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
							
								
								
									
										58
									
								
								python/nn4k/tests/utils/test_class_importing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								python/nn4k/tests/utils/test_class_importing.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,58 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestClassImporting(unittest.TestCase):
 | 
			
		||||
    """
 | 
			
		||||
    module nn4k.utils.class_importing unittest
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def testSplitModuleClassName(self):
 | 
			
		||||
        from nn4k.utils.class_importing import split_module_class_name
 | 
			
		||||
 | 
			
		||||
        pair = split_module_class_name("foo.bar.Baz", "test")
 | 
			
		||||
        self.assertEqual(pair, ("foo.bar", "Baz"))
 | 
			
		||||
 | 
			
		||||
    def testSplitModuleClassNameInvalid(self):
 | 
			
		||||
        from nn4k.utils.class_importing import split_module_class_name
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(RuntimeError):
 | 
			
		||||
            pair = split_module_class_name("foo", "test")
 | 
			
		||||
 | 
			
		||||
    def testDynamicImportClass(self):
 | 
			
		||||
        from nn4k.utils.class_importing import dynamic_import_class
 | 
			
		||||
 | 
			
		||||
        class_ = dynamic_import_class("unittest.TestCase", "test")
 | 
			
		||||
        self.assertEqual(class_, unittest.TestCase)
 | 
			
		||||
 | 
			
		||||
    def testDynamicImportClassModuleNotFound(self):
 | 
			
		||||
        from nn4k.utils.class_importing import dynamic_import_class
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(ModuleNotFoundError):
 | 
			
		||||
            class_ = dynamic_import_class("not_exists.ClassName", "test")
 | 
			
		||||
 | 
			
		||||
    def testDynamicImportClassClassNotFound(self):
 | 
			
		||||
        from nn4k.utils.class_importing import dynamic_import_class
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(RuntimeError):
 | 
			
		||||
            class_ = dynamic_import_class("unittest.NotExists", "test")
 | 
			
		||||
 | 
			
		||||
    def testDynamicImportClassNotClass(self):
 | 
			
		||||
        from nn4k.utils.class_importing import dynamic_import_class
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(RuntimeError):
 | 
			
		||||
            class_ = dynamic_import_class("unittest.mock", "test")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
							
								
								
									
										3
									
								
								python/nn4k/tests/utils/test_config.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								python/nn4k/tests/utils/test_config.json
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,3 @@
 | 
			
		||||
{
 | 
			
		||||
    "foo": "bar"
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										103
									
								
								python/nn4k/tests/utils/test_config_parsing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								python/nn4k/tests/utils/test_config_parsing.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,103 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestConfigParsing(unittest.TestCase):
 | 
			
		||||
    """
 | 
			
		||||
    module nn4k.utils.config_parsing unittest
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def testPreprocessConfigFile(self):
 | 
			
		||||
        import os
 | 
			
		||||
        from nn4k.utils.config_parsing import preprocess_config
 | 
			
		||||
 | 
			
		||||
        dir_path = os.path.dirname(os.path.abspath(__file__))
 | 
			
		||||
        file_path = os.path.join(dir_path, "test_config.json")
 | 
			
		||||
        nn_config = preprocess_config(file_path)
 | 
			
		||||
        self.assertEqual(nn_config, {"foo": "bar"})
 | 
			
		||||
 | 
			
		||||
    def testPreprocessConfigFileNotExists(self):
 | 
			
		||||
        import os
 | 
			
		||||
        from nn4k.utils.config_parsing import preprocess_config
 | 
			
		||||
 | 
			
		||||
        dir_path = os.path.dirname(os.path.abspath(__file__))
 | 
			
		||||
        file_path = os.path.join(dir_path, "not_exists.json")
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            nn_config = preprocess_config(file_path)
 | 
			
		||||
 | 
			
		||||
    def testPreprocessConfigDict(self):
 | 
			
		||||
        from nn4k.utils.config_parsing import preprocess_config
 | 
			
		||||
 | 
			
		||||
        conf = {"foo": "bar"}
 | 
			
		||||
        nn_config = preprocess_config(conf)
 | 
			
		||||
        self.assertEqual(nn_config, conf)
 | 
			
		||||
 | 
			
		||||
    def testGetField(self):
 | 
			
		||||
        from nn4k.utils.config_parsing import get_field
 | 
			
		||||
 | 
			
		||||
        nn_config = {"foo": "bar"}
 | 
			
		||||
        value = get_field(nn_config, "foo", "Foo")
 | 
			
		||||
        self.assertEqual(value, "bar")
 | 
			
		||||
 | 
			
		||||
    def testGetFieldNotExists(self):
 | 
			
		||||
        from nn4k.utils.config_parsing import get_field
 | 
			
		||||
 | 
			
		||||
        nn_config = {"foo": "bar"}
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            value = get_field(nn_config, "not_exists", "not exists")
 | 
			
		||||
 | 
			
		||||
    def testGetStringField(self):
 | 
			
		||||
        from nn4k.utils.config_parsing import get_string_field
 | 
			
		||||
 | 
			
		||||
        nn_config = {"foo": "bar"}
 | 
			
		||||
        value = get_string_field(nn_config, "foo", "Foo")
 | 
			
		||||
        self.assertEqual(value, "bar")
 | 
			
		||||
 | 
			
		||||
    def testGetStringFieldNotString(self):
 | 
			
		||||
        from nn4k.utils.config_parsing import get_string_field
 | 
			
		||||
 | 
			
		||||
        nn_config = {"foo": "bar", "baz": True}
 | 
			
		||||
        with self.assertRaises(TypeError):
 | 
			
		||||
            value = get_string_field(nn_config, "baz", "Baz")
 | 
			
		||||
 | 
			
		||||
    def testGetIntField(self):
 | 
			
		||||
        from nn4k.utils.config_parsing import get_int_field
 | 
			
		||||
 | 
			
		||||
        nn_config = {"foo": "bar", "baz": 1000}
 | 
			
		||||
        value = get_int_field(nn_config, "baz", "Baz")
 | 
			
		||||
        self.assertEqual(value, 1000)
 | 
			
		||||
 | 
			
		||||
    def testGetIntFieldNotInteger(self):
 | 
			
		||||
        from nn4k.utils.config_parsing import get_int_field
 | 
			
		||||
 | 
			
		||||
        nn_config = {"foo": "bar", "baz": "quux"}
 | 
			
		||||
        with self.assertRaises(TypeError):
 | 
			
		||||
            value = get_int_field(nn_config, "baz", "Baz")
 | 
			
		||||
 | 
			
		||||
    def testGetPositiveIntField(self):
 | 
			
		||||
        from nn4k.utils.config_parsing import get_positive_int_field
 | 
			
		||||
 | 
			
		||||
        nn_config = {"foo": "bar", "baz": 1000}
 | 
			
		||||
        value = get_positive_int_field(nn_config, "baz", "Baz")
 | 
			
		||||
        self.assertEqual(value, 1000)
 | 
			
		||||
 | 
			
		||||
    def testGetPositiveIntFieldNotPositive(self):
 | 
			
		||||
        from nn4k.utils.config_parsing import get_positive_int_field
 | 
			
		||||
 | 
			
		||||
        nn_config = {"foo": "bar", "baz": 0}
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            value = get_positive_int_field(nn_config, "baz", "Baz")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
							
								
								
									
										49
									
								
								python/nn4k/tests/utils/test_invoker_checking.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								python/nn4k/tests/utils/test_invoker_checking.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,49 @@
 | 
			
		||||
# Copyright 2023 Ant Group CO., Ltd.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 | 
			
		||||
# in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
# http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
 | 
			
		||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 | 
			
		||||
# or implied.
 | 
			
		||||
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestInvokerChecking(unittest.TestCase):
 | 
			
		||||
    """
 | 
			
		||||
    module nn4k.utils.invoker_checking unittest
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def testIsOpenAIInvoker(self):
 | 
			
		||||
        from nn4k.utils.invoker_checking import is_openai_invoker
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(is_openai_invoker({"nn_name": "gpt-3.5-turbo"}))
 | 
			
		||||
        self.assertTrue(is_openai_invoker({"nn_name": "gpt-4"}))
 | 
			
		||||
        self.assertFalse(is_openai_invoker({"nn_name": "dummy"}))
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(is_openai_invoker({"openai_api_key": "EMPTY"}))
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            is_openai_invoker({"openai_api_base": "http://localhost:38000/v1"})
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(is_openai_invoker({"openai_max_tokens": 1000}))
 | 
			
		||||
        self.assertFalse(is_openai_invoker({"foo": "bar"}))
 | 
			
		||||
 | 
			
		||||
    def testIsLocalInvoker(self):
 | 
			
		||||
        import os
 | 
			
		||||
        from nn4k.utils.invoker_checking import is_local_invoker
 | 
			
		||||
 | 
			
		||||
        dir_path = os.path.dirname(os.path.abspath(__file__))
 | 
			
		||||
        self.assertFalse(is_local_invoker({"nn_name": dir_path}))
 | 
			
		||||
 | 
			
		||||
        model_dir_path = os.path.join(dir_path, "test_model_dir")
 | 
			
		||||
        self.assertTrue(is_local_invoker({"nn_name": model_dir_path}))
 | 
			
		||||
 | 
			
		||||
        self.assertFalse(is_local_invoker({"nn_name": "/not_exists"}))
 | 
			
		||||
        self.assertFalse(is_local_invoker({"foo": "bar"}))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
							
								
								
									
										1
									
								
								python/nn4k/tests/utils/test_model_dir/config.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								python/nn4k/tests/utils/test_model_dir/config.json
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
{}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user