mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-08-31 12:25:36 +00:00
feat(nn4k): implement text embeddings (#104)
This commit is contained in:
parent
1b3b78b02e
commit
945cf8fbbd
@ -1 +1 @@
|
||||
0.0.2-beta2
|
||||
0.0.2-beta3
|
||||
|
@ -11,4 +11,4 @@
|
||||
|
||||
|
||||
__package_name__ = "openspg-nn4k"
|
||||
__version__ = "0.0.2-beta2"
|
||||
__version__ = "0.0.2-beta3"
|
||||
|
@ -37,7 +37,12 @@ NN_OPENAI_API_BASE_TEXT = "openai api base"
|
||||
NN_OPENAI_MAX_TOKENS_KEY = "openai_max_tokens"
|
||||
NN_OPENAI_MAX_TOKENS_TEXT = "openai max tokens"
|
||||
|
||||
NN_OPENAI_ORGANIZATION_KEY = "openai_organization"
|
||||
NN_OPENAI_ORGANIZATION_TEXT = "openai organization"
|
||||
|
||||
NN_OPENAI_GPT4_PREFIX = "gpt-4"
|
||||
NN_OPENAI_GPT35_PREFIX = "gpt-3.5"
|
||||
NN_OPENAI_EMBEDDING_PREFIX = "text-embedding"
|
||||
|
||||
NN_LOCAL_HF_MODEL_CONFIG_FILE = "config.json"
|
||||
NN_LOCAL_SENTENCE_TRANSFORMERS_CONFIG_FILE = "config_sentence_transformers.json"
|
||||
|
29
python/nn4k/nn4k/examples/openai-embeddings/README.md
Normal file
29
python/nn4k/nn4k/examples/openai-embeddings/README.md
Normal file
@ -0,0 +1,29 @@
|
||||
# NN4K example: text embedding with OpenAI
|
||||
|
||||
## Install dependencies
|
||||
|
||||
```bash
|
||||
python3 -m venv .env
|
||||
source .env/bin/activate
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install openspg-nn4k
|
||||
```
|
||||
|
||||
## Edit configurations
|
||||
|
||||
Edit configurations in [openai_emb.json](./openai_emb.json).
|
||||
|
||||
* Set ``openai_api_base`` to an OpenAI api compatible base url.
|
||||
To invoke the official OpenAI api service, set this field to
|
||||
``https://api.openai.com/v1``.
|
||||
|
||||
* Set ``openai_api_key`` to a valid OpenAI api key.
|
||||
For the official OpenAI api service, you can find your api key
|
||||
``sk-xxx`` in the [API keys](https://platform.openai.com/api-keys) page.
|
||||
|
||||
## Run the example
|
||||
|
||||
```bash
|
||||
python openai_emb.py
|
||||
```
|
||||
|
@ -0,0 +1,5 @@
|
||||
{
|
||||
"nn_name": "text-embedding-ada-002",
|
||||
"openai_api_key": "EMPTY",
|
||||
"openai_api_base": "http://127.0.0.1:38080/v1"
|
||||
}
|
19
python/nn4k/nn4k/examples/openai-embeddings/openai_emb.py
Normal file
19
python/nn4k/nn4k/examples/openai-embeddings/openai_emb.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
from nn4k.invoker import NNInvoker
|
||||
|
||||
invoker = NNInvoker.from_config("openai_emb.json")
|
||||
vecs = invoker.remote_inference(
|
||||
["How old are you?", "What is your age?"], type="Embedding"
|
||||
)
|
||||
similarity = sum(x * y for x, y in zip(*vecs))
|
||||
print("similarity: %g" % similarity)
|
29
python/nn4k/nn4k/examples/openai/README.md
Normal file
29
python/nn4k/nn4k/examples/openai/README.md
Normal file
@ -0,0 +1,29 @@
|
||||
# NN4K example: inference with OpenAI
|
||||
|
||||
## Install dependencies
|
||||
|
||||
```bash
|
||||
python3 -m venv .env
|
||||
source .env/bin/activate
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install openspg-nn4k
|
||||
```
|
||||
|
||||
## Edit configurations
|
||||
|
||||
Edit configurations in [openai_infer.json](./openai_infer.json).
|
||||
|
||||
* Set ``openai_api_base`` to an OpenAI api compatible base url.
|
||||
To invoke the official OpenAI api service, set this field to
|
||||
``https://api.openai.com/v1``.
|
||||
|
||||
* Set ``openai_api_key`` to a valid OpenAI api key.
|
||||
For the official OpenAI api service, you can find your api key
|
||||
``sk-xxx`` in the [API keys](https://platform.openai.com/api-keys) page.
|
||||
|
||||
## Run the example
|
||||
|
||||
```bash
|
||||
python openai_infer.py
|
||||
```
|
||||
|
6
python/nn4k/nn4k/examples/openai/openai_infer.json
Normal file
6
python/nn4k/nn4k/examples/openai/openai_infer.json
Normal file
@ -0,0 +1,6 @@
|
||||
{
|
||||
"nn_name": "gpt-3.5-turbo",
|
||||
"openai_api_key": "EMPTY",
|
||||
"openai_api_base": "http://127.0.0.1:38080/v1",
|
||||
"openai_max_tokens": 64
|
||||
}
|
16
python/nn4k/nn4k/examples/openai/openai_infer.py
Normal file
16
python/nn4k/nn4k/examples/openai/openai_infer.py
Normal file
@ -0,0 +1,16 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
from nn4k.invoker import NNInvoker
|
||||
|
||||
invoker = NNInvoker.from_config("openai_infer.json")
|
||||
output = invoker.remote_inference("Say this is a test.")
|
||||
print(output)
|
@ -13,11 +13,11 @@ from typing import Union
|
||||
from nn4k.executor import LLMExecutor
|
||||
|
||||
|
||||
class HfLLMExecutor(LLMExecutor):
|
||||
class HFLLMExecutor(LLMExecutor):
|
||||
@classmethod
|
||||
def from_config(cls, nn_config: dict) -> "HfLLMExecutor":
|
||||
def from_config(cls, nn_config: dict) -> "HFLLMExecutor":
|
||||
"""
|
||||
Create an HfLLMExecutor instance from `nn_config`.
|
||||
Create an HFLLMExecutor instance from `nn_config`.
|
||||
"""
|
||||
executor = cls(nn_config)
|
||||
return executor
|
||||
@ -102,3 +102,51 @@ class HfLLMExecutor(LLMExecutor):
|
||||
for idx, output_id in enumerate(output_ids)
|
||||
]
|
||||
return outputs
|
||||
|
||||
|
||||
class HFEmbeddingExecutor(LLMExecutor):
|
||||
@classmethod
|
||||
def from_config(cls, nn_config: dict) -> "HFEmbeddingExecutor":
|
||||
"""
|
||||
Create an HFEmbeddingExecutor instance from `nn_config`.
|
||||
"""
|
||||
executor = cls(nn_config)
|
||||
return executor
|
||||
|
||||
def load_model(self, args=None, **kwargs):
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
from nn4k.consts import NN_DEVICE_KEY
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
|
||||
nn_config: dict = args or self.init_args
|
||||
if self._model is None:
|
||||
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||
nn_version = nn_config.get(NN_VERSION_KEY)
|
||||
if nn_version is not None:
|
||||
nn_version = get_string_field(
|
||||
nn_config, NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
)
|
||||
model_path = nn_name
|
||||
revision = nn_version
|
||||
use_fast_tokenizer = False
|
||||
device = nn_config.get(NN_DEVICE_KEY)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
#
|
||||
# SentenceTransformer will support `revision` soon. See:
|
||||
#
|
||||
# https://github.com/UKPLab/sentence-transformers/pull/2419
|
||||
#
|
||||
model = SentenceTransformer(
|
||||
model_path,
|
||||
device=device,
|
||||
)
|
||||
self._model = model
|
||||
|
||||
def inference(self, data, args=None, **kwargs):
|
||||
model = self.model
|
||||
embeddings = model.encode(data)
|
||||
return embeddings
|
||||
|
@ -13,7 +13,7 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
from nn4k.executor import LLMExecutor
|
||||
from nn4k.executor import NNExecutor
|
||||
|
||||
|
||||
class SubmitMode(Enum):
|
||||
@ -157,23 +157,36 @@ class LLMInvoker(NNInvoker):
|
||||
Implement local model warming up logic for local invoker.
|
||||
"""
|
||||
from nn4k.nnhub import NNHub
|
||||
from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
|
||||
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
from nn4k.utils.class_importing import dynamic_import_class
|
||||
|
||||
nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
|
||||
nn_version = self.init_args.get(NN_VERSION_KEY)
|
||||
if nn_version is not None:
|
||||
nn_version = get_string_field(
|
||||
self.init_args, NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
nn_executor = self.init_args.get(NN_EXECUTOR_KEY)
|
||||
if nn_executor is not None:
|
||||
nn_executor = get_string_field(
|
||||
self.init_args, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
|
||||
)
|
||||
hub = NNHub.get_instance()
|
||||
executor = hub.get_model_executor(nn_name, nn_version)
|
||||
if executor is None:
|
||||
message = "model %r version %r " % (nn_name, nn_version)
|
||||
message += "is not found in the model hub"
|
||||
raise RuntimeError(message)
|
||||
self._nn_executor: LLMExecutor = executor
|
||||
executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
|
||||
if not issubclass(executor_class, NNExecutor):
|
||||
message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
|
||||
raise RuntimeError(message)
|
||||
executor = executor_class.from_config(self.init_args)
|
||||
else:
|
||||
nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
|
||||
nn_version = self.init_args.get(NN_VERSION_KEY)
|
||||
if nn_version is not None:
|
||||
nn_version = get_string_field(
|
||||
self.init_args, NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
)
|
||||
hub = NNHub.get_instance()
|
||||
executor = hub.get_model_executor(nn_name, nn_version)
|
||||
if executor is None:
|
||||
message = "model %r version %r " % (nn_name, nn_version)
|
||||
message += "is not found in the model hub"
|
||||
raise RuntimeError(message)
|
||||
self._nn_executor: NNExecutor = executor
|
||||
self._nn_executor.load_model()
|
||||
self._nn_executor.warmup_inference()
|
||||
|
||||
|
@ -23,6 +23,7 @@ class OpenAIInvoker(NNInvoker):
|
||||
from nn4k.consts import NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_KEY_TEXT
|
||||
from nn4k.consts import NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
|
||||
from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
|
||||
from nn4k.consts import NN_OPENAI_ORGANIZATION_KEY, NN_OPENAI_ORGANIZATION_TEXT
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
from nn4k.utils.config_parsing import get_positive_int_field
|
||||
|
||||
@ -35,41 +36,107 @@ class OpenAIInvoker(NNInvoker):
|
||||
self.openai_api_base = get_string_field(
|
||||
self.init_args, NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
|
||||
)
|
||||
self.openai_max_tokens = get_positive_int_field(
|
||||
self.init_args, NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
|
||||
)
|
||||
self.openai_max_tokens = self.init_args.get(NN_OPENAI_MAX_TOKENS_KEY)
|
||||
if self.openai_max_tokens is not None:
|
||||
self.openai_max_tokens = get_positive_int_field(
|
||||
self.init_args, NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
|
||||
)
|
||||
self.openai_organization = self.init_args.get(NN_OPENAI_ORGANIZATION_KEY)
|
||||
if self.openai_organization is not None:
|
||||
self.openai_organization = get_string_field(
|
||||
self.init_args, NN_OPENAI_ORGANIZATION_KEY, NN_OPENAI_ORGANIZATION_TEXT
|
||||
)
|
||||
|
||||
openai.api_key = self.openai_api_key
|
||||
openai.api_base = self.openai_api_base
|
||||
if self._is_legacy_openai_api:
|
||||
self.client = None
|
||||
openai.api_key = self.openai_api_key
|
||||
openai.api_base = self.openai_api_base
|
||||
openai.organization = self.openai_organization
|
||||
else:
|
||||
self.client = openai.OpenAI(
|
||||
api_key=self.openai_api_key,
|
||||
base_url=self.openai_api_base,
|
||||
organization=self.openai_organization,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, nn_config: dict) -> "OpenAIInvoker":
|
||||
invoker = cls(nn_config)
|
||||
return invoker
|
||||
|
||||
@property
|
||||
def _is_legacy_openai_api(self):
|
||||
import openai
|
||||
|
||||
return openai.__version__.startswith("0.")
|
||||
|
||||
def _create_prompt(self, input, **kwargs):
|
||||
if isinstance(input, list):
|
||||
prompt = input
|
||||
else:
|
||||
prompt = [input]
|
||||
prompt = [{"role": "user", "content": input}]
|
||||
return prompt
|
||||
|
||||
def _create_output(self, input, prompt, completion, **kwargs):
|
||||
output = [choice.text for choice in completion.choices]
|
||||
return output
|
||||
|
||||
def remote_inference(
|
||||
self, input, max_output_length: Optional[int] = None, **kwargs
|
||||
):
|
||||
def _create_completion(self, input, prompt, max_output_length, **kwargs):
|
||||
import openai
|
||||
|
||||
if max_output_length is None:
|
||||
max_output_length = self.openai_max_tokens
|
||||
if self._is_legacy_openai_api:
|
||||
completion = openai.ChatCompletion.create(
|
||||
model=self.openai_model_name,
|
||||
messages=prompt,
|
||||
max_tokens=max_output_length,
|
||||
)
|
||||
else:
|
||||
completion = self.client.chat.completions.create(
|
||||
model=self.openai_model_name,
|
||||
messages=prompt,
|
||||
max_tokens=max_output_length,
|
||||
)
|
||||
return completion
|
||||
|
||||
def _create_output(self, input, prompt, completion, **kwargs):
|
||||
output = [choice.message.content for choice in completion.choices]
|
||||
return output
|
||||
|
||||
def _get_completion(self, input, max_output_length, **kwargs):
|
||||
prompt = self._create_prompt(input, **kwargs)
|
||||
completion = openai.Completion.create(
|
||||
model=self.openai_model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=max_output_length,
|
||||
)
|
||||
completion = self._create_completion(input, prompt, max_output_length, **kwargs)
|
||||
output = self._create_output(input, prompt, completion, **kwargs)
|
||||
return output
|
||||
|
||||
def _get_embeddings(self, input, **kwargs):
|
||||
import openai
|
||||
|
||||
if isinstance(input, list):
|
||||
inputs = input
|
||||
else:
|
||||
inputs = [input]
|
||||
|
||||
if self._is_legacy_openai_api:
|
||||
response = openai.Embedding.create(
|
||||
model=self.openai_model_name,
|
||||
input=inputs,
|
||||
)
|
||||
else:
|
||||
response = self.client.embeddings.create(
|
||||
model=self.openai_model_name,
|
||||
input=inputs,
|
||||
)
|
||||
|
||||
embeddings = [emb.embedding for emb in response.data]
|
||||
return embeddings
|
||||
|
||||
def remote_inference(
|
||||
self,
|
||||
input,
|
||||
type: Optional[str] = None,
|
||||
max_output_length: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
if type == "Embedding":
|
||||
output = self._get_embeddings(input, **kwargs)
|
||||
else:
|
||||
if max_output_length is None:
|
||||
max_output_length = self.openai_max_tokens
|
||||
output = self._get_completion(input, max_output_length, **kwargs)
|
||||
return output
|
||||
|
@ -9,6 +9,8 @@
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
import os
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union, Tuple, Type
|
||||
|
||||
@ -125,6 +127,10 @@ class SimpleNNHub(NNHub):
|
||||
def get_model_executor(
|
||||
self, name: str, version: str = None
|
||||
) -> Optional[NNExecutor]:
|
||||
from nn4k.consts import NN_VERSION_DEFAULT
|
||||
|
||||
if version is None:
|
||||
version = NN_VERSION_DEFAULT
|
||||
if self._model_executors.get(name) is None:
|
||||
return None
|
||||
executor = self._model_executors.get(name).get(version)
|
||||
@ -134,13 +140,59 @@ class SimpleNNHub(NNHub):
|
||||
executor = self._create_model_executor(cls, init_args, kwargs, weights)
|
||||
return executor
|
||||
|
||||
def _add_local_executor(self, nn_config):
|
||||
def _get_local_executor_class(self, nn_config: dict) -> Type[NNExecutor]:
|
||||
from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
|
||||
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
from nn4k.executor.hugging_face import HfLLMExecutor
|
||||
from nn4k.consts import NN_LOCAL_HF_MODEL_CONFIG_FILE
|
||||
from nn4k.consts import NN_LOCAL_SENTENCE_TRANSFORMERS_CONFIG_FILE
|
||||
from nn4k.executor.hugging_face import HFLLMExecutor
|
||||
from nn4k.executor.hugging_face import HFEmbeddingExecutor
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
|
||||
executor = HfLLMExecutor.from_config(nn_config)
|
||||
nn_executor = nn_config.get(NN_EXECUTOR_KEY)
|
||||
if nn_executor is not None:
|
||||
nn_executor = get_string_field(
|
||||
self.init_args, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
|
||||
)
|
||||
executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
|
||||
if not issubclass(executor_class, NNExecutor):
|
||||
message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
|
||||
raise RuntimeError(message)
|
||||
return executor_class
|
||||
|
||||
nn_name = nn_config.get(NN_NAME_KEY)
|
||||
if nn_name is not None:
|
||||
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||
if os.path.isdir(nn_name):
|
||||
file_path = os.path.join(nn_name, NN_LOCAL_HF_MODEL_CONFIG_FILE)
|
||||
if os.path.isfile(file_path):
|
||||
file_path = os.path.join(
|
||||
nn_name, NN_LOCAL_SENTENCE_TRANSFORMERS_CONFIG_FILE
|
||||
)
|
||||
if os.path.isfile(file_path):
|
||||
executor_class = HFEmbeddingExecutor
|
||||
else:
|
||||
executor_class = HFLLMExecutor
|
||||
return executor_class
|
||||
|
||||
nn_version = nn_config.get(NN_VERSION_KEY)
|
||||
if nn_version is not None:
|
||||
nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
|
||||
message = "can not determine local executor class for NN config"
|
||||
if nn_name is not None:
|
||||
message += "; model: %r" % nn_name
|
||||
if nn_version is not None:
|
||||
message += ", version: %r" % nn_version
|
||||
raise RuntimeError(message)
|
||||
|
||||
def _add_local_executor(self, nn_config: dict):
|
||||
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
|
||||
executor_class = self._get_local_executor_class(nn_config)
|
||||
executor = executor_class.from_config(nn_config)
|
||||
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||
nn_version = nn_config.get(NN_VERSION_KEY)
|
||||
if nn_version is not None:
|
||||
|
@ -23,18 +23,27 @@ def is_openai_invoker(nn_config: dict) -> bool:
|
||||
from nn4k.consts import NN_OPENAI_API_KEY_KEY
|
||||
from nn4k.consts import NN_OPENAI_API_BASE_KEY
|
||||
from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY
|
||||
from nn4k.consts import NN_OPENAI_ORGANIZATION_KEY
|
||||
from nn4k.consts import NN_OPENAI_GPT4_PREFIX
|
||||
from nn4k.consts import NN_OPENAI_GPT35_PREFIX
|
||||
from nn4k.consts import NN_OPENAI_EMBEDDING_PREFIX
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
|
||||
nn_name = nn_config.get(NN_NAME_KEY)
|
||||
if nn_name is not None:
|
||||
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||
if nn_name.startswith(NN_OPENAI_GPT4_PREFIX) or nn_name.startswith(
|
||||
NN_OPENAI_GPT35_PREFIX
|
||||
if (
|
||||
nn_name.startswith(NN_OPENAI_GPT4_PREFIX)
|
||||
or nn_name.startswith(NN_OPENAI_GPT35_PREFIX)
|
||||
or nn_name.startswith(NN_OPENAI_EMBEDDING_PREFIX)
|
||||
):
|
||||
return True
|
||||
keys = (NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_BASE_KEY, NN_OPENAI_MAX_TOKENS_KEY)
|
||||
keys = (
|
||||
NN_OPENAI_API_KEY_KEY,
|
||||
NN_OPENAI_API_BASE_KEY,
|
||||
NN_OPENAI_MAX_TOKENS_KEY,
|
||||
NN_OPENAI_ORGANIZATION_KEY,
|
||||
)
|
||||
for key in keys:
|
||||
if key in nn_config:
|
||||
return True
|
||||
|
@ -1 +1 @@
|
||||
openai<1
|
||||
openai
|
||||
|
@ -42,7 +42,7 @@ class StubHub(SimpleNNHub):
|
||||
def get_model_executor(
|
||||
self, name: str, version: str = None
|
||||
) -> Optional[NNExecutor]:
|
||||
if name == "test_stub":
|
||||
if name == "executor_test_stub":
|
||||
if version is None:
|
||||
version = "default"
|
||||
executor = StubExecutor(
|
@ -20,12 +20,12 @@ class TestBaseExecutor(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
# for importing test_stub.py
|
||||
# for importing executor_test_stub.py
|
||||
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, dir_path)
|
||||
|
||||
from nn4k.nnhub import NNHub
|
||||
from test_stub import StubHub
|
||||
from executor_test_stub import StubHub
|
||||
|
||||
NNHub._hub_instance = StubHub()
|
||||
|
||||
@ -37,22 +37,24 @@ class TestBaseExecutor(unittest.TestCase):
|
||||
|
||||
def testCustomNNExecutor(self):
|
||||
from nn4k.executor import NNExecutor
|
||||
from test_stub import StubExecutor
|
||||
from executor_test_stub import StubExecutor
|
||||
|
||||
nn_config = {"nn_executor": "test_stub.StubExecutor"}
|
||||
nn_config = {"nn_executor": "executor_test_stub.StubExecutor"}
|
||||
executor = NNExecutor.from_config(nn_config)
|
||||
self.assertTrue(isinstance(executor, StubExecutor))
|
||||
self.assertEqual(executor.init_args, nn_config)
|
||||
self.assertEqual(executor.kwargs, {})
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
executor = NNExecutor.from_config({"nn_executor": "test_stub.NotExecutor"})
|
||||
executor = NNExecutor.from_config(
|
||||
{"nn_executor": "executor_test_stub.NotExecutor"}
|
||||
)
|
||||
|
||||
def testHubExecutor(self):
|
||||
from nn4k.executor import NNExecutor
|
||||
from test_stub import StubExecutor
|
||||
from executor_test_stub import StubExecutor
|
||||
|
||||
nn_config = {"nn_name": "test_stub", "nn_version": "default"}
|
||||
nn_config = {"nn_name": "executor_test_stub", "nn_version": "default"}
|
||||
executor = NNExecutor.from_config(nn_config)
|
||||
self.assertTrue(isinstance(executor, StubExecutor))
|
||||
self.assertEqual(executor.init_args, nn_config)
|
||||
|
57
python/nn4k/tests/executor/test_hf_embedding_executor.py
Normal file
57
python/nn4k/tests/executor/test_hf_embedding_executor.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
from nn4k.executor.hugging_face import HFEmbeddingExecutor
|
||||
|
||||
|
||||
class TestHFEmbeddingExecutor(unittest.TestCase):
|
||||
"""
|
||||
HFEmbeddingExecutor unittest
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self._saved_torch = sys.modules.get("torch")
|
||||
self._mocked_torch = unittest.mock.MagicMock()
|
||||
sys.modules["torch"] = self._mocked_torch
|
||||
|
||||
self._saved_sentence_transformers = sys.modules.get("sentence_transformers")
|
||||
self._mocked_sentence_transformers = unittest.mock.MagicMock()
|
||||
sys.modules["sentence_transformers"] = self._mocked_sentence_transformers
|
||||
|
||||
def tearDown(self):
|
||||
del sys.modules["torch"]
|
||||
if self._saved_torch is not None:
|
||||
sys.modules["torch"] = self._saved_torch
|
||||
|
||||
del sys.modules["sentence_transformers"]
|
||||
if self._saved_sentence_transformers is not None:
|
||||
sys.modules["sentence_transformers"] = self._saved_sentence_transformers
|
||||
|
||||
def testHFEmbeddingExecutor(self):
|
||||
nn_config = {
|
||||
"nn_name": "/opt/test_model_dir",
|
||||
"nn_version": "default",
|
||||
}
|
||||
|
||||
executor = HFEmbeddingExecutor.from_config(nn_config)
|
||||
executor.load_model()
|
||||
executor.inference("input")
|
||||
|
||||
self._mocked_sentence_transformers.SentenceTransformer.assert_called()
|
||||
executor.model.encode.assert_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -13,12 +13,12 @@ import sys
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
from nn4k.executor.hugging_face import HfLLMExecutor
|
||||
from nn4k.executor.hugging_face import HFLLMExecutor
|
||||
|
||||
|
||||
class TestHfLLMExecutor(unittest.TestCase):
|
||||
class TestHFLLMExecutor(unittest.TestCase):
|
||||
"""
|
||||
HfLLMExecutor unittest
|
||||
HFLLMExecutor unittest
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
@ -39,18 +39,20 @@ class TestHfLLMExecutor(unittest.TestCase):
|
||||
if self._saved_transformers is not None:
|
||||
sys.modules["transformers"] = self._saved_transformers
|
||||
|
||||
def testHfLLMExecutor(self):
|
||||
def testHFLLMExecutor(self):
|
||||
nn_config = {
|
||||
"nn_name": "/opt/test_model_dir",
|
||||
"nn_version": "default",
|
||||
}
|
||||
|
||||
executor = HfLLMExecutor.from_config(nn_config)
|
||||
executor = HFLLMExecutor.from_config(nn_config)
|
||||
executor.load_model()
|
||||
executor.inference("input")
|
||||
|
||||
self._mocked_transformers.AutoTokenizer.from_pretrained.assert_called()
|
||||
self._mocked_transformers.AutoModelForCausalLM.from_pretrained.assert_called()
|
||||
executor.tokenizer.assert_called()
|
||||
executor.model.generate.assert_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -52,7 +52,7 @@ class StubExecutor(NNExecutor):
|
||||
class StubHub(SimpleNNHub):
|
||||
def get_invoker(self, nn_config: dict) -> Optional[NNInvoker]:
|
||||
nn_name = nn_config.get("nn_name")
|
||||
if nn_name is not None and nn_name == "test_stub":
|
||||
if nn_name is not None and nn_name == "invoker_test_stub":
|
||||
invoker = StubInvoker(nn_config, test_stub_invoker=True)
|
||||
return invoker
|
||||
return super().get_invoker(nn_config)
|
||||
@ -60,7 +60,7 @@ class StubHub(SimpleNNHub):
|
||||
def get_model_executor(
|
||||
self, name: str, version: str = None
|
||||
) -> Optional[NNExecutor]:
|
||||
if name == "test_stub":
|
||||
if name == "invoker_test_stub":
|
||||
if version is None:
|
||||
version = "default"
|
||||
executor = StubExecutor(
|
@ -20,12 +20,12 @@ class TestBaseInvoker(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
# for importing test_stub.py
|
||||
# for importing invoker_test_stub.py
|
||||
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, dir_path)
|
||||
|
||||
from nn4k.nnhub import NNHub
|
||||
from test_stub import StubHub
|
||||
from invoker_test_stub import StubHub
|
||||
|
||||
NNHub._hub_instance = StubHub()
|
||||
|
||||
@ -37,22 +37,24 @@ class TestBaseInvoker(unittest.TestCase):
|
||||
|
||||
def testCustomNNInvoker(self):
|
||||
from nn4k.invoker import NNInvoker
|
||||
from test_stub import StubInvoker
|
||||
from invoker_test_stub import StubInvoker
|
||||
|
||||
nn_config = {"nn_invoker": "test_stub.StubInvoker"}
|
||||
nn_config = {"nn_invoker": "invoker_test_stub.StubInvoker"}
|
||||
invoker = NNInvoker.from_config(nn_config)
|
||||
self.assertTrue(isinstance(invoker, StubInvoker))
|
||||
self.assertEqual(invoker.init_args, nn_config)
|
||||
self.assertEqual(invoker.kwargs, {})
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
invoker = NNInvoker.from_config({"nn_invoker": "test_stub.NotInvoker"})
|
||||
invoker = NNInvoker.from_config(
|
||||
{"nn_invoker": "invoker_test_stub.NotInvoker"}
|
||||
)
|
||||
|
||||
def testHubInvoker(self):
|
||||
from nn4k.invoker import NNInvoker
|
||||
from test_stub import StubInvoker
|
||||
from invoker_test_stub import StubInvoker
|
||||
|
||||
nn_config = {"nn_name": "test_stub"}
|
||||
nn_config = {"nn_name": "invoker_test_stub"}
|
||||
invoker = NNInvoker.from_config(nn_config)
|
||||
self.assertTrue(isinstance(invoker, StubInvoker))
|
||||
self.assertEqual(invoker.init_args, nn_config)
|
||||
@ -66,9 +68,9 @@ class TestBaseInvoker(unittest.TestCase):
|
||||
|
||||
def testLocalInvoker(self):
|
||||
from nn4k.invoker import NNInvoker
|
||||
from test_stub import StubInvoker
|
||||
from invoker_test_stub import StubInvoker
|
||||
|
||||
nn_config = {"nn_name": "test_stub"}
|
||||
nn_config = {"nn_name": "invoker_test_stub"}
|
||||
invoker = NNInvoker.from_config(nn_config)
|
||||
self.assertTrue(isinstance(invoker, StubInvoker))
|
||||
self.assertEqual(invoker.init_args, nn_config)
|
||||
@ -79,6 +81,19 @@ class TestBaseInvoker(unittest.TestCase):
|
||||
result = invoker.local_inference("input")
|
||||
self.assertEqual(result, invoker._nn_executor.inference_result)
|
||||
|
||||
def testLocalLLMInvokerWithCustomExecutor(self):
|
||||
from nn4k.invoker import LLMInvoker
|
||||
|
||||
nn_config = {"nn_executor": "invoker_test_stub.StubExecutor"}
|
||||
invoker = LLMInvoker.from_config(nn_config)
|
||||
self.assertTrue(isinstance(invoker, LLMInvoker))
|
||||
self.assertEqual(invoker.init_args, nn_config)
|
||||
|
||||
invoker.warmup_local_model()
|
||||
invoker._nn_executor.inference_result = "inference result"
|
||||
result = invoker.local_inference("input")
|
||||
self.assertEqual(result, invoker._nn_executor.inference_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -22,9 +22,24 @@ class MockCompletion:
|
||||
choices: list
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockMessage:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockChoice:
|
||||
text: str
|
||||
message: MockMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockEmbeddings:
|
||||
data: list
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockEmbedding:
|
||||
embedding: list
|
||||
|
||||
|
||||
class TestOpenAIInvoker(unittest.TestCase):
|
||||
@ -42,7 +57,10 @@ class TestOpenAIInvoker(unittest.TestCase):
|
||||
if self._saved_openai is not None:
|
||||
sys.modules["openai"] = self._saved_openai
|
||||
|
||||
def testOpenAIInvoker(self):
|
||||
def testOpenAICompletion(self):
|
||||
self._mocked_openai.__version__ = "1.7.0"
|
||||
self._mocked_openai.OpenAI = unittest.mock.MagicMock
|
||||
|
||||
nn_config = {
|
||||
"nn_name": "gpt-3.5-turbo",
|
||||
"openai_api_key": "EMPTY",
|
||||
@ -51,19 +69,45 @@ class TestOpenAIInvoker(unittest.TestCase):
|
||||
}
|
||||
invoker = NNInvoker.from_config(nn_config)
|
||||
self.assertEqual(invoker.init_args, nn_config)
|
||||
self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
|
||||
self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])
|
||||
self.assertEqual(invoker.client.api_key, nn_config["openai_api_key"])
|
||||
self.assertEqual(invoker.client.base_url, nn_config["openai_api_base"])
|
||||
|
||||
mock_completion = MockCompletion(choices=[MockChoice("a dog named Bolt ...")])
|
||||
self._mocked_openai.Completion.create.return_value = mock_completion
|
||||
mock_completion = MockCompletion(
|
||||
choices=[MockChoice(message=MockMessage(content="a dog named Bolt ..."))]
|
||||
)
|
||||
invoker.client.chat.completions.create.return_value = mock_completion
|
||||
|
||||
result = invoker.remote_inference("Long long ago, ")
|
||||
self._mocked_openai.Completion.create.assert_called_with(
|
||||
prompt=["Long long ago, "],
|
||||
invoker.client.chat.completions.create.assert_called_with(
|
||||
model=nn_config["nn_name"],
|
||||
messages=[{"role": "user", "content": "Long long ago, "}],
|
||||
max_tokens=nn_config["openai_max_tokens"],
|
||||
)
|
||||
self.assertEqual(result, [mock_completion.choices[0].text])
|
||||
self.assertEqual(result, [mock_completion.choices[0].message.content])
|
||||
|
||||
def testOpenAIEmbedding(self):
|
||||
self._mocked_openai.__version__ = "1.7.0"
|
||||
self._mocked_openai.OpenAI = unittest.mock.MagicMock
|
||||
|
||||
nn_config = {
|
||||
"nn_name": "text-embedding-ada-002",
|
||||
"openai_api_key": "EMPTY",
|
||||
"openai_api_base": "http://localhost:38080/v1",
|
||||
}
|
||||
invoker = NNInvoker.from_config(nn_config)
|
||||
self.assertEqual(invoker.init_args, nn_config)
|
||||
self.assertEqual(invoker.client.api_key, nn_config["openai_api_key"])
|
||||
self.assertEqual(invoker.client.base_url, nn_config["openai_api_base"])
|
||||
|
||||
mock_embeddings = MockEmbeddings(data=[MockEmbedding(embedding=[0.1, 0.2])])
|
||||
invoker.client.embeddings.create.return_value = mock_embeddings
|
||||
|
||||
result = invoker.remote_inference("How old are you?", type="Embedding")
|
||||
invoker.client.embeddings.create.assert_called_with(
|
||||
model=nn_config["nn_name"],
|
||||
input=["How old are you?"],
|
||||
)
|
||||
self.assertEqual(result, [mock_embeddings.data[0].embedding])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
113
python/nn4k/tests/invoker/test_openai_invoker_legacy_api.py
Normal file
113
python/nn4k/tests/invoker/test_openai_invoker_legacy_api.py
Normal file
@ -0,0 +1,113 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import unittest.mock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from nn4k.invoker import NNInvoker
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockCompletion:
|
||||
choices: list
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockMessage:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockChoice:
|
||||
message: MockMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockEmbeddings:
|
||||
data: list
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockEmbedding:
|
||||
embedding: list
|
||||
|
||||
|
||||
class TestOpenAIInvokerLegacyAPI(unittest.TestCase):
|
||||
"""
|
||||
OpenAIInvoker unittest for legacy OpenAI api
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self._saved_openai = sys.modules.get("openai")
|
||||
self._mocked_openai = unittest.mock.MagicMock()
|
||||
sys.modules["openai"] = self._mocked_openai
|
||||
|
||||
def tearDown(self):
|
||||
del sys.modules["openai"]
|
||||
if self._saved_openai is not None:
|
||||
sys.modules["openai"] = self._saved_openai
|
||||
|
||||
def testOpenAICompletion(self):
|
||||
self._mocked_openai.__version__ = "0.28.1"
|
||||
|
||||
nn_config = {
|
||||
"nn_name": "gpt-3.5-turbo",
|
||||
"openai_api_key": "EMPTY",
|
||||
"openai_api_base": "http://localhost:38080/v1",
|
||||
"openai_max_tokens": 2000,
|
||||
}
|
||||
invoker = NNInvoker.from_config(nn_config)
|
||||
self.assertEqual(invoker.init_args, nn_config)
|
||||
self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
|
||||
self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])
|
||||
|
||||
mock_completion = MockCompletion(
|
||||
choices=[MockChoice(message=MockMessage(content="a dog named Bolt ..."))]
|
||||
)
|
||||
self._mocked_openai.ChatCompletion.create.return_value = mock_completion
|
||||
|
||||
result = invoker.remote_inference("Long long ago, ")
|
||||
self._mocked_openai.ChatCompletion.create.assert_called_with(
|
||||
model=nn_config["nn_name"],
|
||||
messages=[{"role": "user", "content": "Long long ago, "}],
|
||||
max_tokens=nn_config["openai_max_tokens"],
|
||||
)
|
||||
self.assertEqual(result, [mock_completion.choices[0].message.content])
|
||||
|
||||
def testOpenAIEmbedding(self):
|
||||
self._mocked_openai.__version__ = "0.28.1"
|
||||
self._mocked_openai.OpenAI = unittest.mock.MagicMock
|
||||
|
||||
nn_config = {
|
||||
"nn_name": "text-embedding-ada-002",
|
||||
"openai_api_key": "EMPTY",
|
||||
"openai_api_base": "http://localhost:38080/v1",
|
||||
}
|
||||
invoker = NNInvoker.from_config(nn_config)
|
||||
self.assertEqual(invoker.init_args, nn_config)
|
||||
self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
|
||||
self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])
|
||||
|
||||
mock_embeddings = MockEmbeddings(data=[MockEmbedding(embedding=[0.1, 0.2])])
|
||||
self._mocked_openai.Embedding.create.return_value = mock_embeddings
|
||||
|
||||
result = invoker.remote_inference("How old are you?", type="Embedding")
|
||||
self._mocked_openai.Embedding.create.assert_called_with(
|
||||
model=nn_config["nn_name"],
|
||||
input=["How old are you?"],
|
||||
)
|
||||
self.assertEqual(result, [mock_embeddings.data[0].embedding])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
36
python/nn4k/tests/python-env/.env.restore.sh
Executable file
36
python/nn4k/tests/python-env/.env.restore.sh
Executable file
@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
_SCRIPT_FILE_PATH="${BASH_SOURCE[0]}"
|
||||
while [ -h ${_SCRIPT_FILE_PATH} ]
|
||||
do
|
||||
_SCRIPT_DIR_PATH=$(cd -P "$(dirname ${_SCRIPT_FILE_PATH})" && pwd)
|
||||
_SCRIPT_FILE_PATH=$(readlink ${_SCRIPT_FILE_PATH})
|
||||
case ${_SCRIPT_FILE_PATH} in
|
||||
/*) ;;
|
||||
*) _SCRIPT_FILE_PATH=${_SCRIPT_DIR_PATH}/${_SCRIPT_FILE_PATH} ;;
|
||||
esac
|
||||
done
|
||||
_SCRIPT_DIR_PATH=$(cd -P "$(dirname ${_SCRIPT_FILE_PATH})" && pwd)
|
||||
|
||||
if [ -f ${_SCRIPT_DIR_PATH}/.env/requirements.txt ]
|
||||
then
|
||||
exit
|
||||
fi
|
||||
|
||||
set -e
|
||||
rm -rf ${_SCRIPT_DIR_PATH}/.env
|
||||
python3 -m venv ${_SCRIPT_DIR_PATH}/.env
|
||||
source ${_SCRIPT_DIR_PATH}/.env/bin/activate
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip freeze > ${_SCRIPT_DIR_PATH}/.env/requirements.txt
|
1
python/nn4k/tests/python-env/.gitignore
vendored
Normal file
1
python/nn4k/tests/python-env/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/.env/
|
25
python/nn4k/tests/python-env/env.sh
Normal file
25
python/nn4k/tests/python-env/env.sh
Normal file
@ -0,0 +1,25 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
_SCRIPT_FILE_PATH="${BASH_SOURCE[0]}"
|
||||
while [ -h ${_SCRIPT_FILE_PATH} ]
|
||||
do
|
||||
_SCRIPT_DIR_PATH=$(cd -P "$(dirname ${_SCRIPT_FILE_PATH})" && pwd)
|
||||
_SCRIPT_FILE_PATH=$(readlink ${_SCRIPT_FILE_PATH})
|
||||
case ${_SCRIPT_FILE_PATH} in
|
||||
/*) ;;
|
||||
*) _SCRIPT_FILE_PATH=${_SCRIPT_DIR_PATH}/${_SCRIPT_FILE_PATH} ;;
|
||||
esac
|
||||
done
|
||||
_SCRIPT_DIR_PATH=$(cd -P "$(dirname ${_SCRIPT_FILE_PATH})" && pwd)
|
||||
|
||||
${_SCRIPT_DIR_PATH}/.env.restore.sh
|
||||
source ${_SCRIPT_DIR_PATH}/.env/bin/activate
|
60
python/nn4k/tests/run_all_tests.py
Executable file
60
python/nn4k/tests/run_all_tests.py
Executable file
@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
# -*- mode: python -*-
|
||||
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
|
||||
class NN4KTestsRunner(object):
|
||||
def _run_all_tests(self):
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
restore_script_path = os.path.join(dir_path, "python-env", ".env.restore.sh")
|
||||
args = [restore_script_path]
|
||||
subprocess.check_call(args)
|
||||
|
||||
nn4k_dir_path = os.path.dirname(dir_path)
|
||||
python_executable_path = os.path.join(
|
||||
dir_path, "python-env", ".env", "bin", "python"
|
||||
)
|
||||
saved_dir_path = os.getcwd()
|
||||
os.chdir(dir_path)
|
||||
|
||||
args = ["env", "PYTHONPATH=%s" % nn4k_dir_path]
|
||||
args += [python_executable_path]
|
||||
args += ["-m", "unittest"]
|
||||
try:
|
||||
subprocess.check_call(args)
|
||||
except subprocess.CalledProcessError:
|
||||
raise SystemExit(1)
|
||||
finally:
|
||||
os.chdir(saved_dir_path)
|
||||
|
||||
def run(self):
|
||||
self._run_all_tests()
|
||||
|
||||
|
||||
def main():
|
||||
runner = NN4KTestsRunner()
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -22,6 +22,7 @@ class TestInvokerChecking(unittest.TestCase):
|
||||
|
||||
self.assertTrue(is_openai_invoker({"nn_name": "gpt-3.5-turbo"}))
|
||||
self.assertTrue(is_openai_invoker({"nn_name": "gpt-4"}))
|
||||
self.assertTrue(is_openai_invoker({"nn_name": "text-embedding-ada-002"}))
|
||||
self.assertFalse(is_openai_invoker({"nn_name": "dummy"}))
|
||||
|
||||
self.assertTrue(is_openai_invoker({"openai_api_key": "EMPTY"}))
|
||||
@ -29,6 +30,7 @@ class TestInvokerChecking(unittest.TestCase):
|
||||
is_openai_invoker({"openai_api_base": "http://localhost:38000/v1"})
|
||||
)
|
||||
self.assertTrue(is_openai_invoker({"openai_max_tokens": 1000}))
|
||||
self.assertTrue(is_openai_invoker({"openai_organization": "test_org"}))
|
||||
self.assertFalse(is_openai_invoker({"foo": "bar"}))
|
||||
|
||||
def testIsLocalInvoker(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user