feat(nn4k): implement text embeddings (#104)

This commit is contained in:
xionghuaidong 2024-02-02 16:29:24 +08:00 committed by GitHub
parent 1b3b78b02e
commit 945cf8fbbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 733 additions and 78 deletions

View File

@ -1 +1 @@
0.0.2-beta2
0.0.2-beta3

View File

@ -11,4 +11,4 @@
__package_name__ = "openspg-nn4k"
__version__ = "0.0.2-beta2"
__version__ = "0.0.2-beta3"

View File

@ -37,7 +37,12 @@ NN_OPENAI_API_BASE_TEXT = "openai api base"
NN_OPENAI_MAX_TOKENS_KEY = "openai_max_tokens"
NN_OPENAI_MAX_TOKENS_TEXT = "openai max tokens"
NN_OPENAI_ORGANIZATION_KEY = "openai_organization"
NN_OPENAI_ORGANIZATION_TEXT = "openai organization"
NN_OPENAI_GPT4_PREFIX = "gpt-4"
NN_OPENAI_GPT35_PREFIX = "gpt-3.5"
NN_OPENAI_EMBEDDING_PREFIX = "text-embedding"
NN_LOCAL_HF_MODEL_CONFIG_FILE = "config.json"
NN_LOCAL_SENTENCE_TRANSFORMERS_CONFIG_FILE = "config_sentence_transformers.json"

View File

@ -0,0 +1,29 @@
# NN4K example: text embedding with OpenAI
## Install dependencies
```bash
python3 -m venv .env
source .env/bin/activate
python -m pip install --upgrade pip
python -m pip install openspg-nn4k
```
## Edit configurations
Edit configurations in [openai_emb.json](./openai_emb.json).
* Set ``openai_api_base`` to an OpenAI api compatible base url.
To invoke the official OpenAI api service, set this field to
``https://api.openai.com/v1``.
* Set ``openai_api_key`` to a valid OpenAI api key.
For the official OpenAI api service, you can find your api key
``sk-xxx`` in the [API keys](https://platform.openai.com/api-keys) page.
## Run the example
```bash
python openai_emb.py
```

View File

@ -0,0 +1,5 @@
{
"nn_name": "text-embedding-ada-002",
"openai_api_key": "EMPTY",
"openai_api_base": "http://127.0.0.1:38080/v1"
}

View File

@ -0,0 +1,19 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from nn4k.invoker import NNInvoker
invoker = NNInvoker.from_config("openai_emb.json")
vecs = invoker.remote_inference(
["How old are you?", "What is your age?"], type="Embedding"
)
similarity = sum(x * y for x, y in zip(*vecs))
print("similarity: %g" % similarity)

View File

@ -0,0 +1,29 @@
# NN4K example: inference with OpenAI
## Install dependencies
```bash
python3 -m venv .env
source .env/bin/activate
python -m pip install --upgrade pip
python -m pip install openspg-nn4k
```
## Edit configurations
Edit configurations in [openai_infer.json](./openai_infer.json).
* Set ``openai_api_base`` to an OpenAI api compatible base url.
To invoke the official OpenAI api service, set this field to
``https://api.openai.com/v1``.
* Set ``openai_api_key`` to a valid OpenAI api key.
For the official OpenAI api service, you can find your api key
``sk-xxx`` in the [API keys](https://platform.openai.com/api-keys) page.
## Run the example
```bash
python openai_infer.py
```

View File

@ -0,0 +1,6 @@
{
"nn_name": "gpt-3.5-turbo",
"openai_api_key": "EMPTY",
"openai_api_base": "http://127.0.0.1:38080/v1",
"openai_max_tokens": 64
}

View File

@ -0,0 +1,16 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from nn4k.invoker import NNInvoker
invoker = NNInvoker.from_config("openai_infer.json")
output = invoker.remote_inference("Say this is a test.")
print(output)

View File

@ -13,11 +13,11 @@ from typing import Union
from nn4k.executor import LLMExecutor
class HfLLMExecutor(LLMExecutor):
class HFLLMExecutor(LLMExecutor):
@classmethod
def from_config(cls, nn_config: dict) -> "HfLLMExecutor":
def from_config(cls, nn_config: dict) -> "HFLLMExecutor":
"""
Create an HfLLMExecutor instance from `nn_config`.
Create an HFLLMExecutor instance from `nn_config`.
"""
executor = cls(nn_config)
return executor
@ -102,3 +102,51 @@ class HfLLMExecutor(LLMExecutor):
for idx, output_id in enumerate(output_ids)
]
return outputs
class HFEmbeddingExecutor(LLMExecutor):
@classmethod
def from_config(cls, nn_config: dict) -> "HFEmbeddingExecutor":
"""
Create an HFEmbeddingExecutor instance from `nn_config`.
"""
executor = cls(nn_config)
return executor
def load_model(self, args=None, **kwargs):
import torch
from sentence_transformers import SentenceTransformer
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.consts import NN_DEVICE_KEY
from nn4k.utils.config_parsing import get_string_field
nn_config: dict = args or self.init_args
if self._model is None:
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = nn_config.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(
nn_config, NN_VERSION_KEY, NN_VERSION_TEXT
)
model_path = nn_name
revision = nn_version
use_fast_tokenizer = False
device = nn_config.get(NN_DEVICE_KEY)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
#
# SentenceTransformer will support `revision` soon. See:
#
# https://github.com/UKPLab/sentence-transformers/pull/2419
#
model = SentenceTransformer(
model_path,
device=device,
)
self._model = model
def inference(self, data, args=None, **kwargs):
model = self.model
embeddings = model.encode(data)
return embeddings

View File

@ -13,7 +13,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Union
from nn4k.executor import LLMExecutor
from nn4k.executor import NNExecutor
class SubmitMode(Enum):
@ -157,23 +157,36 @@ class LLMInvoker(NNInvoker):
Implement local model warming up logic for local invoker.
"""
from nn4k.nnhub import NNHub
from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.utils.config_parsing import get_string_field
from nn4k.utils.class_importing import dynamic_import_class
nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = self.init_args.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(
self.init_args, NN_VERSION_KEY, NN_VERSION_TEXT
nn_executor = self.init_args.get(NN_EXECUTOR_KEY)
if nn_executor is not None:
nn_executor = get_string_field(
self.init_args, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
)
hub = NNHub.get_instance()
executor = hub.get_model_executor(nn_name, nn_version)
if executor is None:
message = "model %r version %r " % (nn_name, nn_version)
message += "is not found in the model hub"
raise RuntimeError(message)
self._nn_executor: LLMExecutor = executor
executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
if not issubclass(executor_class, NNExecutor):
message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
raise RuntimeError(message)
executor = executor_class.from_config(self.init_args)
else:
nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = self.init_args.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(
self.init_args, NN_VERSION_KEY, NN_VERSION_TEXT
)
hub = NNHub.get_instance()
executor = hub.get_model_executor(nn_name, nn_version)
if executor is None:
message = "model %r version %r " % (nn_name, nn_version)
message += "is not found in the model hub"
raise RuntimeError(message)
self._nn_executor: NNExecutor = executor
self._nn_executor.load_model()
self._nn_executor.warmup_inference()

View File

@ -23,6 +23,7 @@ class OpenAIInvoker(NNInvoker):
from nn4k.consts import NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_KEY_TEXT
from nn4k.consts import NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
from nn4k.consts import NN_OPENAI_ORGANIZATION_KEY, NN_OPENAI_ORGANIZATION_TEXT
from nn4k.utils.config_parsing import get_string_field
from nn4k.utils.config_parsing import get_positive_int_field
@ -35,41 +36,107 @@ class OpenAIInvoker(NNInvoker):
self.openai_api_base = get_string_field(
self.init_args, NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
)
self.openai_max_tokens = get_positive_int_field(
self.init_args, NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
)
self.openai_max_tokens = self.init_args.get(NN_OPENAI_MAX_TOKENS_KEY)
if self.openai_max_tokens is not None:
self.openai_max_tokens = get_positive_int_field(
self.init_args, NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
)
self.openai_organization = self.init_args.get(NN_OPENAI_ORGANIZATION_KEY)
if self.openai_organization is not None:
self.openai_organization = get_string_field(
self.init_args, NN_OPENAI_ORGANIZATION_KEY, NN_OPENAI_ORGANIZATION_TEXT
)
openai.api_key = self.openai_api_key
openai.api_base = self.openai_api_base
if self._is_legacy_openai_api:
self.client = None
openai.api_key = self.openai_api_key
openai.api_base = self.openai_api_base
openai.organization = self.openai_organization
else:
self.client = openai.OpenAI(
api_key=self.openai_api_key,
base_url=self.openai_api_base,
organization=self.openai_organization,
)
@classmethod
def from_config(cls, nn_config: dict) -> "OpenAIInvoker":
invoker = cls(nn_config)
return invoker
@property
def _is_legacy_openai_api(self):
import openai
return openai.__version__.startswith("0.")
def _create_prompt(self, input, **kwargs):
if isinstance(input, list):
prompt = input
else:
prompt = [input]
prompt = [{"role": "user", "content": input}]
return prompt
def _create_output(self, input, prompt, completion, **kwargs):
output = [choice.text for choice in completion.choices]
return output
def remote_inference(
self, input, max_output_length: Optional[int] = None, **kwargs
):
def _create_completion(self, input, prompt, max_output_length, **kwargs):
import openai
if max_output_length is None:
max_output_length = self.openai_max_tokens
if self._is_legacy_openai_api:
completion = openai.ChatCompletion.create(
model=self.openai_model_name,
messages=prompt,
max_tokens=max_output_length,
)
else:
completion = self.client.chat.completions.create(
model=self.openai_model_name,
messages=prompt,
max_tokens=max_output_length,
)
return completion
def _create_output(self, input, prompt, completion, **kwargs):
output = [choice.message.content for choice in completion.choices]
return output
def _get_completion(self, input, max_output_length, **kwargs):
prompt = self._create_prompt(input, **kwargs)
completion = openai.Completion.create(
model=self.openai_model_name,
prompt=prompt,
max_tokens=max_output_length,
)
completion = self._create_completion(input, prompt, max_output_length, **kwargs)
output = self._create_output(input, prompt, completion, **kwargs)
return output
def _get_embeddings(self, input, **kwargs):
import openai
if isinstance(input, list):
inputs = input
else:
inputs = [input]
if self._is_legacy_openai_api:
response = openai.Embedding.create(
model=self.openai_model_name,
input=inputs,
)
else:
response = self.client.embeddings.create(
model=self.openai_model_name,
input=inputs,
)
embeddings = [emb.embedding for emb in response.data]
return embeddings
def remote_inference(
self,
input,
type: Optional[str] = None,
max_output_length: Optional[int] = None,
**kwargs
):
if type == "Embedding":
output = self._get_embeddings(input, **kwargs)
else:
if max_output_length is None:
max_output_length = self.openai_max_tokens
output = self._get_completion(input, max_output_length, **kwargs)
return output

View File

@ -9,6 +9,8 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
from abc import ABC, abstractmethod
from typing import Optional, Union, Tuple, Type
@ -125,6 +127,10 @@ class SimpleNNHub(NNHub):
def get_model_executor(
self, name: str, version: str = None
) -> Optional[NNExecutor]:
from nn4k.consts import NN_VERSION_DEFAULT
if version is None:
version = NN_VERSION_DEFAULT
if self._model_executors.get(name) is None:
return None
executor = self._model_executors.get(name).get(version)
@ -134,13 +140,59 @@ class SimpleNNHub(NNHub):
executor = self._create_model_executor(cls, init_args, kwargs, weights)
return executor
def _add_local_executor(self, nn_config):
def _get_local_executor_class(self, nn_config: dict) -> Type[NNExecutor]:
from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.executor.hugging_face import HfLLMExecutor
from nn4k.consts import NN_LOCAL_HF_MODEL_CONFIG_FILE
from nn4k.consts import NN_LOCAL_SENTENCE_TRANSFORMERS_CONFIG_FILE
from nn4k.executor.hugging_face import HFLLMExecutor
from nn4k.executor.hugging_face import HFEmbeddingExecutor
from nn4k.utils.config_parsing import get_string_field
executor = HfLLMExecutor.from_config(nn_config)
nn_executor = nn_config.get(NN_EXECUTOR_KEY)
if nn_executor is not None:
nn_executor = get_string_field(
self.init_args, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
)
executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
if not issubclass(executor_class, NNExecutor):
message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
raise RuntimeError(message)
return executor_class
nn_name = nn_config.get(NN_NAME_KEY)
if nn_name is not None:
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
if os.path.isdir(nn_name):
file_path = os.path.join(nn_name, NN_LOCAL_HF_MODEL_CONFIG_FILE)
if os.path.isfile(file_path):
file_path = os.path.join(
nn_name, NN_LOCAL_SENTENCE_TRANSFORMERS_CONFIG_FILE
)
if os.path.isfile(file_path):
executor_class = HFEmbeddingExecutor
else:
executor_class = HFLLMExecutor
return executor_class
nn_version = nn_config.get(NN_VERSION_KEY)
if nn_version is not None:
nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
message = "can not determine local executor class for NN config"
if nn_name is not None:
message += "; model: %r" % nn_name
if nn_version is not None:
message += ", version: %r" % nn_version
raise RuntimeError(message)
def _add_local_executor(self, nn_config: dict):
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
from nn4k.utils.config_parsing import get_string_field
executor_class = self._get_local_executor_class(nn_config)
executor = executor_class.from_config(nn_config)
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
nn_version = nn_config.get(NN_VERSION_KEY)
if nn_version is not None:

View File

@ -23,18 +23,27 @@ def is_openai_invoker(nn_config: dict) -> bool:
from nn4k.consts import NN_OPENAI_API_KEY_KEY
from nn4k.consts import NN_OPENAI_API_BASE_KEY
from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY
from nn4k.consts import NN_OPENAI_ORGANIZATION_KEY
from nn4k.consts import NN_OPENAI_GPT4_PREFIX
from nn4k.consts import NN_OPENAI_GPT35_PREFIX
from nn4k.consts import NN_OPENAI_EMBEDDING_PREFIX
from nn4k.utils.config_parsing import get_string_field
nn_name = nn_config.get(NN_NAME_KEY)
if nn_name is not None:
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
if nn_name.startswith(NN_OPENAI_GPT4_PREFIX) or nn_name.startswith(
NN_OPENAI_GPT35_PREFIX
if (
nn_name.startswith(NN_OPENAI_GPT4_PREFIX)
or nn_name.startswith(NN_OPENAI_GPT35_PREFIX)
or nn_name.startswith(NN_OPENAI_EMBEDDING_PREFIX)
):
return True
keys = (NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_BASE_KEY, NN_OPENAI_MAX_TOKENS_KEY)
keys = (
NN_OPENAI_API_KEY_KEY,
NN_OPENAI_API_BASE_KEY,
NN_OPENAI_MAX_TOKENS_KEY,
NN_OPENAI_ORGANIZATION_KEY,
)
for key in keys:
if key in nn_config:
return True

View File

@ -1 +1 @@
openai<1
openai

View File

@ -42,7 +42,7 @@ class StubHub(SimpleNNHub):
def get_model_executor(
self, name: str, version: str = None
) -> Optional[NNExecutor]:
if name == "test_stub":
if name == "executor_test_stub":
if version is None:
version = "default"
executor = StubExecutor(

View File

@ -20,12 +20,12 @@ class TestBaseExecutor(unittest.TestCase):
"""
def setUp(self):
# for importing test_stub.py
# for importing executor_test_stub.py
dir_path = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, dir_path)
from nn4k.nnhub import NNHub
from test_stub import StubHub
from executor_test_stub import StubHub
NNHub._hub_instance = StubHub()
@ -37,22 +37,24 @@ class TestBaseExecutor(unittest.TestCase):
def testCustomNNExecutor(self):
from nn4k.executor import NNExecutor
from test_stub import StubExecutor
from executor_test_stub import StubExecutor
nn_config = {"nn_executor": "test_stub.StubExecutor"}
nn_config = {"nn_executor": "executor_test_stub.StubExecutor"}
executor = NNExecutor.from_config(nn_config)
self.assertTrue(isinstance(executor, StubExecutor))
self.assertEqual(executor.init_args, nn_config)
self.assertEqual(executor.kwargs, {})
with self.assertRaises(RuntimeError):
executor = NNExecutor.from_config({"nn_executor": "test_stub.NotExecutor"})
executor = NNExecutor.from_config(
{"nn_executor": "executor_test_stub.NotExecutor"}
)
def testHubExecutor(self):
from nn4k.executor import NNExecutor
from test_stub import StubExecutor
from executor_test_stub import StubExecutor
nn_config = {"nn_name": "test_stub", "nn_version": "default"}
nn_config = {"nn_name": "executor_test_stub", "nn_version": "default"}
executor = NNExecutor.from_config(nn_config)
self.assertTrue(isinstance(executor, StubExecutor))
self.assertEqual(executor.init_args, nn_config)

View File

@ -0,0 +1,57 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import sys
import unittest
import unittest.mock
from nn4k.executor.hugging_face import HFEmbeddingExecutor
class TestHFEmbeddingExecutor(unittest.TestCase):
"""
HFEmbeddingExecutor unittest
"""
def setUp(self):
self._saved_torch = sys.modules.get("torch")
self._mocked_torch = unittest.mock.MagicMock()
sys.modules["torch"] = self._mocked_torch
self._saved_sentence_transformers = sys.modules.get("sentence_transformers")
self._mocked_sentence_transformers = unittest.mock.MagicMock()
sys.modules["sentence_transformers"] = self._mocked_sentence_transformers
def tearDown(self):
del sys.modules["torch"]
if self._saved_torch is not None:
sys.modules["torch"] = self._saved_torch
del sys.modules["sentence_transformers"]
if self._saved_sentence_transformers is not None:
sys.modules["sentence_transformers"] = self._saved_sentence_transformers
def testHFEmbeddingExecutor(self):
nn_config = {
"nn_name": "/opt/test_model_dir",
"nn_version": "default",
}
executor = HFEmbeddingExecutor.from_config(nn_config)
executor.load_model()
executor.inference("input")
self._mocked_sentence_transformers.SentenceTransformer.assert_called()
executor.model.encode.assert_called()
if __name__ == "__main__":
unittest.main()

View File

@ -13,12 +13,12 @@ import sys
import unittest
import unittest.mock
from nn4k.executor.hugging_face import HfLLMExecutor
from nn4k.executor.hugging_face import HFLLMExecutor
class TestHfLLMExecutor(unittest.TestCase):
class TestHFLLMExecutor(unittest.TestCase):
"""
HfLLMExecutor unittest
HFLLMExecutor unittest
"""
def setUp(self):
@ -39,18 +39,20 @@ class TestHfLLMExecutor(unittest.TestCase):
if self._saved_transformers is not None:
sys.modules["transformers"] = self._saved_transformers
def testHfLLMExecutor(self):
def testHFLLMExecutor(self):
nn_config = {
"nn_name": "/opt/test_model_dir",
"nn_version": "default",
}
executor = HfLLMExecutor.from_config(nn_config)
executor = HFLLMExecutor.from_config(nn_config)
executor.load_model()
executor.inference("input")
self._mocked_transformers.AutoTokenizer.from_pretrained.assert_called()
self._mocked_transformers.AutoModelForCausalLM.from_pretrained.assert_called()
executor.tokenizer.assert_called()
executor.model.generate.assert_called()
if __name__ == "__main__":

View File

@ -52,7 +52,7 @@ class StubExecutor(NNExecutor):
class StubHub(SimpleNNHub):
def get_invoker(self, nn_config: dict) -> Optional[NNInvoker]:
nn_name = nn_config.get("nn_name")
if nn_name is not None and nn_name == "test_stub":
if nn_name is not None and nn_name == "invoker_test_stub":
invoker = StubInvoker(nn_config, test_stub_invoker=True)
return invoker
return super().get_invoker(nn_config)
@ -60,7 +60,7 @@ class StubHub(SimpleNNHub):
def get_model_executor(
self, name: str, version: str = None
) -> Optional[NNExecutor]:
if name == "test_stub":
if name == "invoker_test_stub":
if version is None:
version = "default"
executor = StubExecutor(

View File

@ -20,12 +20,12 @@ class TestBaseInvoker(unittest.TestCase):
"""
def setUp(self):
# for importing test_stub.py
# for importing invoker_test_stub.py
dir_path = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, dir_path)
from nn4k.nnhub import NNHub
from test_stub import StubHub
from invoker_test_stub import StubHub
NNHub._hub_instance = StubHub()
@ -37,22 +37,24 @@ class TestBaseInvoker(unittest.TestCase):
def testCustomNNInvoker(self):
from nn4k.invoker import NNInvoker
from test_stub import StubInvoker
from invoker_test_stub import StubInvoker
nn_config = {"nn_invoker": "test_stub.StubInvoker"}
nn_config = {"nn_invoker": "invoker_test_stub.StubInvoker"}
invoker = NNInvoker.from_config(nn_config)
self.assertTrue(isinstance(invoker, StubInvoker))
self.assertEqual(invoker.init_args, nn_config)
self.assertEqual(invoker.kwargs, {})
with self.assertRaises(RuntimeError):
invoker = NNInvoker.from_config({"nn_invoker": "test_stub.NotInvoker"})
invoker = NNInvoker.from_config(
{"nn_invoker": "invoker_test_stub.NotInvoker"}
)
def testHubInvoker(self):
from nn4k.invoker import NNInvoker
from test_stub import StubInvoker
from invoker_test_stub import StubInvoker
nn_config = {"nn_name": "test_stub"}
nn_config = {"nn_name": "invoker_test_stub"}
invoker = NNInvoker.from_config(nn_config)
self.assertTrue(isinstance(invoker, StubInvoker))
self.assertEqual(invoker.init_args, nn_config)
@ -66,9 +68,9 @@ class TestBaseInvoker(unittest.TestCase):
def testLocalInvoker(self):
from nn4k.invoker import NNInvoker
from test_stub import StubInvoker
from invoker_test_stub import StubInvoker
nn_config = {"nn_name": "test_stub"}
nn_config = {"nn_name": "invoker_test_stub"}
invoker = NNInvoker.from_config(nn_config)
self.assertTrue(isinstance(invoker, StubInvoker))
self.assertEqual(invoker.init_args, nn_config)
@ -79,6 +81,19 @@ class TestBaseInvoker(unittest.TestCase):
result = invoker.local_inference("input")
self.assertEqual(result, invoker._nn_executor.inference_result)
def testLocalLLMInvokerWithCustomExecutor(self):
from nn4k.invoker import LLMInvoker
nn_config = {"nn_executor": "invoker_test_stub.StubExecutor"}
invoker = LLMInvoker.from_config(nn_config)
self.assertTrue(isinstance(invoker, LLMInvoker))
self.assertEqual(invoker.init_args, nn_config)
invoker.warmup_local_model()
invoker._nn_executor.inference_result = "inference result"
result = invoker.local_inference("input")
self.assertEqual(result, invoker._nn_executor.inference_result)
if __name__ == "__main__":
unittest.main()

View File

@ -22,9 +22,24 @@ class MockCompletion:
choices: list
@dataclass
class MockMessage:
content: str
@dataclass
class MockChoice:
text: str
message: MockMessage
@dataclass
class MockEmbeddings:
data: list
@dataclass
class MockEmbedding:
embedding: list
class TestOpenAIInvoker(unittest.TestCase):
@ -42,7 +57,10 @@ class TestOpenAIInvoker(unittest.TestCase):
if self._saved_openai is not None:
sys.modules["openai"] = self._saved_openai
def testOpenAIInvoker(self):
def testOpenAICompletion(self):
self._mocked_openai.__version__ = "1.7.0"
self._mocked_openai.OpenAI = unittest.mock.MagicMock
nn_config = {
"nn_name": "gpt-3.5-turbo",
"openai_api_key": "EMPTY",
@ -51,19 +69,45 @@ class TestOpenAIInvoker(unittest.TestCase):
}
invoker = NNInvoker.from_config(nn_config)
self.assertEqual(invoker.init_args, nn_config)
self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])
self.assertEqual(invoker.client.api_key, nn_config["openai_api_key"])
self.assertEqual(invoker.client.base_url, nn_config["openai_api_base"])
mock_completion = MockCompletion(choices=[MockChoice("a dog named Bolt ...")])
self._mocked_openai.Completion.create.return_value = mock_completion
mock_completion = MockCompletion(
choices=[MockChoice(message=MockMessage(content="a dog named Bolt ..."))]
)
invoker.client.chat.completions.create.return_value = mock_completion
result = invoker.remote_inference("Long long ago, ")
self._mocked_openai.Completion.create.assert_called_with(
prompt=["Long long ago, "],
invoker.client.chat.completions.create.assert_called_with(
model=nn_config["nn_name"],
messages=[{"role": "user", "content": "Long long ago, "}],
max_tokens=nn_config["openai_max_tokens"],
)
self.assertEqual(result, [mock_completion.choices[0].text])
self.assertEqual(result, [mock_completion.choices[0].message.content])
def testOpenAIEmbedding(self):
self._mocked_openai.__version__ = "1.7.0"
self._mocked_openai.OpenAI = unittest.mock.MagicMock
nn_config = {
"nn_name": "text-embedding-ada-002",
"openai_api_key": "EMPTY",
"openai_api_base": "http://localhost:38080/v1",
}
invoker = NNInvoker.from_config(nn_config)
self.assertEqual(invoker.init_args, nn_config)
self.assertEqual(invoker.client.api_key, nn_config["openai_api_key"])
self.assertEqual(invoker.client.base_url, nn_config["openai_api_base"])
mock_embeddings = MockEmbeddings(data=[MockEmbedding(embedding=[0.1, 0.2])])
invoker.client.embeddings.create.return_value = mock_embeddings
result = invoker.remote_inference("How old are you?", type="Embedding")
invoker.client.embeddings.create.assert_called_with(
model=nn_config["nn_name"],
input=["How old are you?"],
)
self.assertEqual(result, [mock_embeddings.data[0].embedding])
if __name__ == "__main__":

View File

@ -0,0 +1,113 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import sys
import unittest
import unittest.mock
from dataclasses import dataclass
from nn4k.invoker import NNInvoker
@dataclass
class MockCompletion:
choices: list
@dataclass
class MockMessage:
content: str
@dataclass
class MockChoice:
message: MockMessage
@dataclass
class MockEmbeddings:
data: list
@dataclass
class MockEmbedding:
embedding: list
class TestOpenAIInvokerLegacyAPI(unittest.TestCase):
"""
OpenAIInvoker unittest for legacy OpenAI api
"""
def setUp(self):
self._saved_openai = sys.modules.get("openai")
self._mocked_openai = unittest.mock.MagicMock()
sys.modules["openai"] = self._mocked_openai
def tearDown(self):
del sys.modules["openai"]
if self._saved_openai is not None:
sys.modules["openai"] = self._saved_openai
def testOpenAICompletion(self):
self._mocked_openai.__version__ = "0.28.1"
nn_config = {
"nn_name": "gpt-3.5-turbo",
"openai_api_key": "EMPTY",
"openai_api_base": "http://localhost:38080/v1",
"openai_max_tokens": 2000,
}
invoker = NNInvoker.from_config(nn_config)
self.assertEqual(invoker.init_args, nn_config)
self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])
mock_completion = MockCompletion(
choices=[MockChoice(message=MockMessage(content="a dog named Bolt ..."))]
)
self._mocked_openai.ChatCompletion.create.return_value = mock_completion
result = invoker.remote_inference("Long long ago, ")
self._mocked_openai.ChatCompletion.create.assert_called_with(
model=nn_config["nn_name"],
messages=[{"role": "user", "content": "Long long ago, "}],
max_tokens=nn_config["openai_max_tokens"],
)
self.assertEqual(result, [mock_completion.choices[0].message.content])
def testOpenAIEmbedding(self):
self._mocked_openai.__version__ = "0.28.1"
self._mocked_openai.OpenAI = unittest.mock.MagicMock
nn_config = {
"nn_name": "text-embedding-ada-002",
"openai_api_key": "EMPTY",
"openai_api_base": "http://localhost:38080/v1",
}
invoker = NNInvoker.from_config(nn_config)
self.assertEqual(invoker.init_args, nn_config)
self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])
mock_embeddings = MockEmbeddings(data=[MockEmbedding(embedding=[0.1, 0.2])])
self._mocked_openai.Embedding.create.return_value = mock_embeddings
result = invoker.remote_inference("How old are you?", type="Embedding")
self._mocked_openai.Embedding.create.assert_called_with(
model=nn_config["nn_name"],
input=["How old are you?"],
)
self.assertEqual(result, [mock_embeddings.data[0].embedding])
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,36 @@
#!/bin/bash
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
_SCRIPT_FILE_PATH="${BASH_SOURCE[0]}"
while [ -h ${_SCRIPT_FILE_PATH} ]
do
_SCRIPT_DIR_PATH=$(cd -P "$(dirname ${_SCRIPT_FILE_PATH})" && pwd)
_SCRIPT_FILE_PATH=$(readlink ${_SCRIPT_FILE_PATH})
case ${_SCRIPT_FILE_PATH} in
/*) ;;
*) _SCRIPT_FILE_PATH=${_SCRIPT_DIR_PATH}/${_SCRIPT_FILE_PATH} ;;
esac
done
_SCRIPT_DIR_PATH=$(cd -P "$(dirname ${_SCRIPT_FILE_PATH})" && pwd)
if [ -f ${_SCRIPT_DIR_PATH}/.env/requirements.txt ]
then
exit
fi
set -e
rm -rf ${_SCRIPT_DIR_PATH}/.env
python3 -m venv ${_SCRIPT_DIR_PATH}/.env
source ${_SCRIPT_DIR_PATH}/.env/bin/activate
python -m pip install --upgrade pip
python -m pip freeze > ${_SCRIPT_DIR_PATH}/.env/requirements.txt

View File

@ -0,0 +1 @@
/.env/

View File

@ -0,0 +1,25 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
_SCRIPT_FILE_PATH="${BASH_SOURCE[0]}"
while [ -h ${_SCRIPT_FILE_PATH} ]
do
_SCRIPT_DIR_PATH=$(cd -P "$(dirname ${_SCRIPT_FILE_PATH})" && pwd)
_SCRIPT_FILE_PATH=$(readlink ${_SCRIPT_FILE_PATH})
case ${_SCRIPT_FILE_PATH} in
/*) ;;
*) _SCRIPT_FILE_PATH=${_SCRIPT_DIR_PATH}/${_SCRIPT_FILE_PATH} ;;
esac
done
_SCRIPT_DIR_PATH=$(cd -P "$(dirname ${_SCRIPT_FILE_PATH})" && pwd)
${_SCRIPT_DIR_PATH}/.env.restore.sh
source ${_SCRIPT_DIR_PATH}/.env/bin/activate

View File

@ -0,0 +1,60 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# -*- mode: python -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
class NN4KTestsRunner(object):
def _run_all_tests(self):
import os
import sys
import subprocess
dir_path = os.path.dirname(os.path.abspath(__file__))
restore_script_path = os.path.join(dir_path, "python-env", ".env.restore.sh")
args = [restore_script_path]
subprocess.check_call(args)
nn4k_dir_path = os.path.dirname(dir_path)
python_executable_path = os.path.join(
dir_path, "python-env", ".env", "bin", "python"
)
saved_dir_path = os.getcwd()
os.chdir(dir_path)
args = ["env", "PYTHONPATH=%s" % nn4k_dir_path]
args += [python_executable_path]
args += ["-m", "unittest"]
try:
subprocess.check_call(args)
except subprocess.CalledProcessError:
raise SystemExit(1)
finally:
os.chdir(saved_dir_path)
def run(self):
self._run_all_tests()
def main():
runner = NN4KTestsRunner()
runner.run()
if __name__ == "__main__":
main()

View File

@ -22,6 +22,7 @@ class TestInvokerChecking(unittest.TestCase):
self.assertTrue(is_openai_invoker({"nn_name": "gpt-3.5-turbo"}))
self.assertTrue(is_openai_invoker({"nn_name": "gpt-4"}))
self.assertTrue(is_openai_invoker({"nn_name": "text-embedding-ada-002"}))
self.assertFalse(is_openai_invoker({"nn_name": "dummy"}))
self.assertTrue(is_openai_invoker({"openai_api_key": "EMPTY"}))
@ -29,6 +30,7 @@ class TestInvokerChecking(unittest.TestCase):
is_openai_invoker({"openai_api_base": "http://localhost:38000/v1"})
)
self.assertTrue(is_openai_invoker({"openai_max_tokens": 1000}))
self.assertTrue(is_openai_invoker({"openai_organization": "test_org"}))
self.assertFalse(is_openai_invoker({"foo": "bar"}))
def testIsLocalInvoker(self):