mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-08-31 20:35:58 +00:00
feat(nn4k): implement text embeddings (#104)
This commit is contained in:
parent
1b3b78b02e
commit
945cf8fbbd
@ -1 +1 @@
|
|||||||
0.0.2-beta2
|
0.0.2-beta3
|
||||||
|
@ -11,4 +11,4 @@
|
|||||||
|
|
||||||
|
|
||||||
__package_name__ = "openspg-nn4k"
|
__package_name__ = "openspg-nn4k"
|
||||||
__version__ = "0.0.2-beta2"
|
__version__ = "0.0.2-beta3"
|
||||||
|
@ -37,7 +37,12 @@ NN_OPENAI_API_BASE_TEXT = "openai api base"
|
|||||||
NN_OPENAI_MAX_TOKENS_KEY = "openai_max_tokens"
|
NN_OPENAI_MAX_TOKENS_KEY = "openai_max_tokens"
|
||||||
NN_OPENAI_MAX_TOKENS_TEXT = "openai max tokens"
|
NN_OPENAI_MAX_TOKENS_TEXT = "openai max tokens"
|
||||||
|
|
||||||
|
NN_OPENAI_ORGANIZATION_KEY = "openai_organization"
|
||||||
|
NN_OPENAI_ORGANIZATION_TEXT = "openai organization"
|
||||||
|
|
||||||
NN_OPENAI_GPT4_PREFIX = "gpt-4"
|
NN_OPENAI_GPT4_PREFIX = "gpt-4"
|
||||||
NN_OPENAI_GPT35_PREFIX = "gpt-3.5"
|
NN_OPENAI_GPT35_PREFIX = "gpt-3.5"
|
||||||
|
NN_OPENAI_EMBEDDING_PREFIX = "text-embedding"
|
||||||
|
|
||||||
NN_LOCAL_HF_MODEL_CONFIG_FILE = "config.json"
|
NN_LOCAL_HF_MODEL_CONFIG_FILE = "config.json"
|
||||||
|
NN_LOCAL_SENTENCE_TRANSFORMERS_CONFIG_FILE = "config_sentence_transformers.json"
|
||||||
|
29
python/nn4k/nn4k/examples/openai-embeddings/README.md
Normal file
29
python/nn4k/nn4k/examples/openai-embeddings/README.md
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# NN4K example: text embedding with OpenAI
|
||||||
|
|
||||||
|
## Install dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m venv .env
|
||||||
|
source .env/bin/activate
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install openspg-nn4k
|
||||||
|
```
|
||||||
|
|
||||||
|
## Edit configurations
|
||||||
|
|
||||||
|
Edit configurations in [openai_emb.json](./openai_emb.json).
|
||||||
|
|
||||||
|
* Set ``openai_api_base`` to an OpenAI API-compatible base URL.
|
||||||
|
To invoke the official OpenAI api service, set this field to
|
||||||
|
``https://api.openai.com/v1``.
|
||||||
|
|
||||||
|
* Set ``openai_api_key`` to a valid OpenAI api key.
|
||||||
|
For the official OpenAI api service, you can find your api key
|
||||||
|
``sk-xxx`` in the [API keys](https://platform.openai.com/api-keys) page.
|
||||||
|
|
||||||
|
## Run the example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python openai_emb.py
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"nn_name": "text-embedding-ada-002",
|
||||||
|
"openai_api_key": "EMPTY",
|
||||||
|
"openai_api_base": "http://127.0.0.1:38080/v1"
|
||||||
|
}
|
19
python/nn4k/nn4k/examples/openai-embeddings/openai_emb.py
Normal file
19
python/nn4k/nn4k/examples/openai-embeddings/openai_emb.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# Copyright 2023 OpenSPG Authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from nn4k.invoker import NNInvoker

# Two differently worded versions of the same question.
sentences = ["How old are you?", "What is your age?"]

# Build the invoker from the JSON config that sits next to this script.
invoker = NNInvoker.from_config("openai_emb.json")

# Ask for embeddings instead of a completion; one vector per sentence.
vecs = invoker.remote_inference(sentences, type="Embedding")

# Dot product of the returned vectors, used here as the similarity score.
similarity = sum(x * y for x, y in zip(*vecs))
print("similarity: %g" % similarity)
|
29
python/nn4k/nn4k/examples/openai/README.md
Normal file
29
python/nn4k/nn4k/examples/openai/README.md
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# NN4K example: inference with OpenAI
|
||||||
|
|
||||||
|
## Install dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m venv .env
|
||||||
|
source .env/bin/activate
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install openspg-nn4k
|
||||||
|
```
|
||||||
|
|
||||||
|
## Edit configurations
|
||||||
|
|
||||||
|
Edit configurations in [openai_infer.json](./openai_infer.json).
|
||||||
|
|
||||||
|
* Set ``openai_api_base`` to an OpenAI API-compatible base URL.
|
||||||
|
To invoke the official OpenAI api service, set this field to
|
||||||
|
``https://api.openai.com/v1``.
|
||||||
|
|
||||||
|
* Set ``openai_api_key`` to a valid OpenAI api key.
|
||||||
|
For the official OpenAI api service, you can find your api key
|
||||||
|
``sk-xxx`` in the [API keys](https://platform.openai.com/api-keys) page.
|
||||||
|
|
||||||
|
## Run the example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python openai_infer.py
|
||||||
|
```
|
||||||
|
|
6
python/nn4k/nn4k/examples/openai/openai_infer.json
Normal file
6
python/nn4k/nn4k/examples/openai/openai_infer.json
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"nn_name": "gpt-3.5-turbo",
|
||||||
|
"openai_api_key": "EMPTY",
|
||||||
|
"openai_api_base": "http://127.0.0.1:38080/v1",
|
||||||
|
"openai_max_tokens": 64
|
||||||
|
}
|
16
python/nn4k/nn4k/examples/openai/openai_infer.py
Normal file
16
python/nn4k/nn4k/examples/openai/openai_infer.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# Copyright 2023 OpenSPG Authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from nn4k.invoker import NNInvoker
|
||||||
|
|
||||||
|
invoker = NNInvoker.from_config("openai_infer.json")
|
||||||
|
output = invoker.remote_inference("Say this is a test.")
|
||||||
|
print(output)
|
@ -13,11 +13,11 @@ from typing import Union
|
|||||||
from nn4k.executor import LLMExecutor
|
from nn4k.executor import LLMExecutor
|
||||||
|
|
||||||
|
|
||||||
class HfLLMExecutor(LLMExecutor):
|
class HFLLMExecutor(LLMExecutor):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, nn_config: dict) -> "HfLLMExecutor":
|
def from_config(cls, nn_config: dict) -> "HFLLMExecutor":
|
||||||
"""
|
"""
|
||||||
Create an HfLLMExecutor instance from `nn_config`.
|
Create an HFLLMExecutor instance from `nn_config`.
|
||||||
"""
|
"""
|
||||||
executor = cls(nn_config)
|
executor = cls(nn_config)
|
||||||
return executor
|
return executor
|
||||||
@ -102,3 +102,51 @@ class HfLLMExecutor(LLMExecutor):
|
|||||||
for idx, output_id in enumerate(output_ids)
|
for idx, output_id in enumerate(output_ids)
|
||||||
]
|
]
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class HFEmbeddingExecutor(LLMExecutor):
    """
    Executor that produces text embeddings with a local
    `sentence_transformers.SentenceTransformer` model.
    """

    @classmethod
    def from_config(cls, nn_config: dict) -> "HFEmbeddingExecutor":
        """
        Create an HFEmbeddingExecutor instance from `nn_config`.
        """
        executor = cls(nn_config)
        return executor

    def load_model(self, args=None, **kwargs):
        """
        Load the SentenceTransformer model named by the NN config.

        The config is `args` when it is truthy, otherwise `self.init_args`.
        Nothing is loaded when `self._model` is already set.

        :param args: optional NN config dict overriding `self.init_args`
        """
        import torch
        from sentence_transformers import SentenceTransformer
        from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
        from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
        from nn4k.consts import NN_DEVICE_KEY
        from nn4k.utils.config_parsing import get_string_field

        nn_config: dict = args or self.init_args
        if self._model is None:
            nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
            # The version is still run through `get_string_field`
            # (presumably a type check -- confirm against config_parsing),
            # but it is NOT applied to the model: see the note below.
            if nn_config.get(NN_VERSION_KEY) is not None:
                get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
            device = nn_config.get(NN_DEVICE_KEY)
            if device is None:
                # No explicit device configured: use CUDA when torch sees it.
                device = "cuda" if torch.cuda.is_available() else "cpu"
            #
            # SentenceTransformer will support `revision` soon. See:
            #
            #   https://github.com/UKPLab/sentence-transformers/pull/2419
            #
            model = SentenceTransformer(
                nn_name,
                device=device,
            )
            self._model = model

    def inference(self, data, args=None, **kwargs):
        """
        Encode `data` with the loaded model.

        :param data: input handed unchanged to the model's `encode`
        :return: whatever the model's `encode` returns
        """
        model = self.model
        embeddings = model.encode(data)
        return embeddings
|
||||||
|
@ -13,7 +13,7 @@ from abc import ABC, abstractmethod
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from nn4k.executor import LLMExecutor
|
from nn4k.executor import NNExecutor
|
||||||
|
|
||||||
|
|
||||||
class SubmitMode(Enum):
|
class SubmitMode(Enum):
|
||||||
@ -157,23 +157,36 @@ class LLMInvoker(NNInvoker):
|
|||||||
Implement local model warming up logic for local invoker.
|
Implement local model warming up logic for local invoker.
|
||||||
"""
|
"""
|
||||||
from nn4k.nnhub import NNHub
|
from nn4k.nnhub import NNHub
|
||||||
|
from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
|
||||||
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||||
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||||
from nn4k.utils.config_parsing import get_string_field
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
from nn4k.utils.class_importing import dynamic_import_class
|
||||||
|
|
||||||
nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
|
nn_executor = self.init_args.get(NN_EXECUTOR_KEY)
|
||||||
nn_version = self.init_args.get(NN_VERSION_KEY)
|
if nn_executor is not None:
|
||||||
if nn_version is not None:
|
nn_executor = get_string_field(
|
||||||
nn_version = get_string_field(
|
self.init_args, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
|
||||||
self.init_args, NN_VERSION_KEY, NN_VERSION_TEXT
|
|
||||||
)
|
)
|
||||||
hub = NNHub.get_instance()
|
executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
|
||||||
executor = hub.get_model_executor(nn_name, nn_version)
|
if not issubclass(executor_class, NNExecutor):
|
||||||
if executor is None:
|
message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
|
||||||
message = "model %r version %r " % (nn_name, nn_version)
|
raise RuntimeError(message)
|
||||||
message += "is not found in the model hub"
|
executor = executor_class.from_config(self.init_args)
|
||||||
raise RuntimeError(message)
|
else:
|
||||||
self._nn_executor: LLMExecutor = executor
|
nn_name = get_string_field(self.init_args, NN_NAME_KEY, NN_NAME_TEXT)
|
||||||
|
nn_version = self.init_args.get(NN_VERSION_KEY)
|
||||||
|
if nn_version is not None:
|
||||||
|
nn_version = get_string_field(
|
||||||
|
self.init_args, NN_VERSION_KEY, NN_VERSION_TEXT
|
||||||
|
)
|
||||||
|
hub = NNHub.get_instance()
|
||||||
|
executor = hub.get_model_executor(nn_name, nn_version)
|
||||||
|
if executor is None:
|
||||||
|
message = "model %r version %r " % (nn_name, nn_version)
|
||||||
|
message += "is not found in the model hub"
|
||||||
|
raise RuntimeError(message)
|
||||||
|
self._nn_executor: NNExecutor = executor
|
||||||
self._nn_executor.load_model()
|
self._nn_executor.load_model()
|
||||||
self._nn_executor.warmup_inference()
|
self._nn_executor.warmup_inference()
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ class OpenAIInvoker(NNInvoker):
|
|||||||
from nn4k.consts import NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_KEY_TEXT
|
from nn4k.consts import NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_KEY_TEXT
|
||||||
from nn4k.consts import NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
|
from nn4k.consts import NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
|
||||||
from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
|
from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
|
||||||
|
from nn4k.consts import NN_OPENAI_ORGANIZATION_KEY, NN_OPENAI_ORGANIZATION_TEXT
|
||||||
from nn4k.utils.config_parsing import get_string_field
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
from nn4k.utils.config_parsing import get_positive_int_field
|
from nn4k.utils.config_parsing import get_positive_int_field
|
||||||
|
|
||||||
@ -35,41 +36,107 @@ class OpenAIInvoker(NNInvoker):
|
|||||||
self.openai_api_base = get_string_field(
|
self.openai_api_base = get_string_field(
|
||||||
self.init_args, NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
|
self.init_args, NN_OPENAI_API_BASE_KEY, NN_OPENAI_API_BASE_TEXT
|
||||||
)
|
)
|
||||||
self.openai_max_tokens = get_positive_int_field(
|
self.openai_max_tokens = self.init_args.get(NN_OPENAI_MAX_TOKENS_KEY)
|
||||||
self.init_args, NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
|
if self.openai_max_tokens is not None:
|
||||||
)
|
self.openai_max_tokens = get_positive_int_field(
|
||||||
|
self.init_args, NN_OPENAI_MAX_TOKENS_KEY, NN_OPENAI_MAX_TOKENS_TEXT
|
||||||
|
)
|
||||||
|
self.openai_organization = self.init_args.get(NN_OPENAI_ORGANIZATION_KEY)
|
||||||
|
if self.openai_organization is not None:
|
||||||
|
self.openai_organization = get_string_field(
|
||||||
|
self.init_args, NN_OPENAI_ORGANIZATION_KEY, NN_OPENAI_ORGANIZATION_TEXT
|
||||||
|
)
|
||||||
|
|
||||||
openai.api_key = self.openai_api_key
|
if self._is_legacy_openai_api:
|
||||||
openai.api_base = self.openai_api_base
|
self.client = None
|
||||||
|
openai.api_key = self.openai_api_key
|
||||||
|
openai.api_base = self.openai_api_base
|
||||||
|
openai.organization = self.openai_organization
|
||||||
|
else:
|
||||||
|
self.client = openai.OpenAI(
|
||||||
|
api_key=self.openai_api_key,
|
||||||
|
base_url=self.openai_api_base,
|
||||||
|
organization=self.openai_organization,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, nn_config: dict) -> "OpenAIInvoker":
|
def from_config(cls, nn_config: dict) -> "OpenAIInvoker":
|
||||||
invoker = cls(nn_config)
|
invoker = cls(nn_config)
|
||||||
return invoker
|
return invoker
|
||||||
|
|
||||||
|
@property
def _is_legacy_openai_api(self):
    """
    Tell whether the installed `openai` package reports a "0." version.
    """
    import openai

    version = openai.__version__
    return version.startswith("0.")
|
||||||
|
|
||||||
def _create_prompt(self, input, **kwargs):
|
def _create_prompt(self, input, **kwargs):
|
||||||
if isinstance(input, list):
|
if isinstance(input, list):
|
||||||
prompt = input
|
prompt = input
|
||||||
else:
|
else:
|
||||||
prompt = [input]
|
prompt = [{"role": "user", "content": input}]
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def _create_output(self, input, prompt, completion, **kwargs):
|
def _create_completion(self, input, prompt, max_output_length, **kwargs):
|
||||||
output = [choice.text for choice in completion.choices]
|
|
||||||
return output
|
|
||||||
|
|
||||||
def remote_inference(
|
|
||||||
self, input, max_output_length: Optional[int] = None, **kwargs
|
|
||||||
):
|
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
if max_output_length is None:
|
if self._is_legacy_openai_api:
|
||||||
max_output_length = self.openai_max_tokens
|
completion = openai.ChatCompletion.create(
|
||||||
|
model=self.openai_model_name,
|
||||||
|
messages=prompt,
|
||||||
|
max_tokens=max_output_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
completion = self.client.chat.completions.create(
|
||||||
|
model=self.openai_model_name,
|
||||||
|
messages=prompt,
|
||||||
|
max_tokens=max_output_length,
|
||||||
|
)
|
||||||
|
return completion
|
||||||
|
|
||||||
|
def _create_output(self, input, prompt, completion, **kwargs):
|
||||||
|
output = [choice.message.content for choice in completion.choices]
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _get_completion(self, input, max_output_length, **kwargs):
|
||||||
prompt = self._create_prompt(input, **kwargs)
|
prompt = self._create_prompt(input, **kwargs)
|
||||||
completion = openai.Completion.create(
|
completion = self._create_completion(input, prompt, max_output_length, **kwargs)
|
||||||
model=self.openai_model_name,
|
|
||||||
prompt=prompt,
|
|
||||||
max_tokens=max_output_length,
|
|
||||||
)
|
|
||||||
output = self._create_output(input, prompt, completion, **kwargs)
|
output = self._create_output(input, prompt, completion, **kwargs)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def _get_embeddings(self, input, **kwargs):
    """
    Request embeddings for `input` and return them as a list.

    A non-list `input` is wrapped into a one-element list first. The
    request goes through the module-level `openai.Embedding` API for
    legacy installs and through `self.client` otherwise.

    :param input: a single item or a list of items to embed
    :return: list with the `embedding` of every entry in the response data
    """
    import openai

    inputs = input if isinstance(input, list) else [input]

    # Pick the endpoint matching the installed openai package flavour.
    if self._is_legacy_openai_api:
        create = openai.Embedding.create
    else:
        create = self.client.embeddings.create
    response = create(
        model=self.openai_model_name,
        input=inputs,
    )

    return [item.embedding for item in response.data]
|
||||||
|
|
||||||
|
def remote_inference(
    self,
    input,
    type: Optional[str] = None,
    max_output_length: Optional[int] = None,
    **kwargs
):
    """
    Run a remote request for `input`.

    :param input: payload forwarded to the embedding or completion helper
    :param type: "Embedding" selects embeddings; anything else selects
        a completion
    :param max_output_length: completion token limit; falls back to
        `self.openai_max_tokens` when None
    :return: the result of the selected helper
    """
    if type == "Embedding":
        return self._get_embeddings(input, **kwargs)

    # Completion path: apply the configured default limit when none given.
    if max_output_length is None:
        max_output_length = self.openai_max_tokens
    return self._get_completion(input, max_output_length, **kwargs)
|
||||||
|
@ -9,6 +9,8 @@
|
|||||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
# or implied.
|
# or implied.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional, Union, Tuple, Type
|
from typing import Optional, Union, Tuple, Type
|
||||||
|
|
||||||
@ -125,6 +127,10 @@ class SimpleNNHub(NNHub):
|
|||||||
def get_model_executor(
|
def get_model_executor(
|
||||||
self, name: str, version: str = None
|
self, name: str, version: str = None
|
||||||
) -> Optional[NNExecutor]:
|
) -> Optional[NNExecutor]:
|
||||||
|
from nn4k.consts import NN_VERSION_DEFAULT
|
||||||
|
|
||||||
|
if version is None:
|
||||||
|
version = NN_VERSION_DEFAULT
|
||||||
if self._model_executors.get(name) is None:
|
if self._model_executors.get(name) is None:
|
||||||
return None
|
return None
|
||||||
executor = self._model_executors.get(name).get(version)
|
executor = self._model_executors.get(name).get(version)
|
||||||
@ -134,13 +140,59 @@ class SimpleNNHub(NNHub):
|
|||||||
executor = self._create_model_executor(cls, init_args, kwargs, weights)
|
executor = self._create_model_executor(cls, init_args, kwargs, weights)
|
||||||
return executor
|
return executor
|
||||||
|
|
||||||
def _add_local_executor(self, nn_config):
|
def _get_local_executor_class(self, nn_config: dict) -> Type[NNExecutor]:
    """
    Determine the executor class to use for a local NN config.

    Resolution order:

    1. When the config has an executor entry, import that class
       dynamically and require it to be an `NNExecutor` subclass.
    2. Otherwise, when the model name is a directory containing the
       Hugging Face config file, pick `HFEmbeddingExecutor` if the
       sentence-transformers config file is also present, else
       `HFLLMExecutor`.

    :param nn_config: NN config dict
    :return: the executor class
    :raises RuntimeError: when the configured executor is not an
        `NNExecutor` subclass, or when no class can be determined
    """
    from nn4k.consts import NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
    from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
    from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
    from nn4k.consts import NN_LOCAL_HF_MODEL_CONFIG_FILE
    from nn4k.consts import NN_LOCAL_SENTENCE_TRANSFORMERS_CONFIG_FILE
    from nn4k.executor.hugging_face import HFLLMExecutor
    from nn4k.executor.hugging_face import HFEmbeddingExecutor
    from nn4k.utils.config_parsing import get_string_field
    from nn4k.utils.class_importing import dynamic_import_class

    nn_executor = nn_config.get(NN_EXECUTOR_KEY)
    if nn_executor is not None:
        # Validate against the same dict the value was read from.
        nn_executor = get_string_field(
            nn_config, NN_EXECUTOR_KEY, NN_EXECUTOR_TEXT
        )
        executor_class = dynamic_import_class(nn_executor, NN_EXECUTOR_TEXT)
        if not issubclass(executor_class, NNExecutor):
            message = "%r is not an %s class" % (nn_executor, NN_EXECUTOR_TEXT)
            raise RuntimeError(message)
        return executor_class

    nn_name = nn_config.get(NN_NAME_KEY)
    if nn_name is not None:
        nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
        if os.path.isdir(nn_name):
            file_path = os.path.join(nn_name, NN_LOCAL_HF_MODEL_CONFIG_FILE)
            if os.path.isfile(file_path):
                # A model directory: the extra sentence-transformers
                # config file distinguishes embedding models from LLMs.
                file_path = os.path.join(
                    nn_name, NN_LOCAL_SENTENCE_TRANSFORMERS_CONFIG_FILE
                )
                if os.path.isfile(file_path):
                    executor_class = HFEmbeddingExecutor
                else:
                    executor_class = HFLLMExecutor
                return executor_class

    # Nothing matched: report what was asked for.
    nn_version = nn_config.get(NN_VERSION_KEY)
    if nn_version is not None:
        nn_version = get_string_field(nn_config, NN_VERSION_KEY, NN_VERSION_TEXT)
    message = "can not determine local executor class for NN config"
    if nn_name is not None:
        message += "; model: %r" % nn_name
        if nn_version is not None:
            message += ", version: %r" % nn_version
    raise RuntimeError(message)
|
||||||
|
|
||||||
|
def _add_local_executor(self, nn_config: dict):
|
||||||
|
from nn4k.consts import NN_NAME_KEY, NN_NAME_TEXT
|
||||||
|
from nn4k.consts import NN_VERSION_KEY, NN_VERSION_TEXT
|
||||||
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
|
||||||
|
executor_class = self._get_local_executor_class(nn_config)
|
||||||
|
executor = executor_class.from_config(nn_config)
|
||||||
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||||
nn_version = nn_config.get(NN_VERSION_KEY)
|
nn_version = nn_config.get(NN_VERSION_KEY)
|
||||||
if nn_version is not None:
|
if nn_version is not None:
|
||||||
|
@ -23,18 +23,27 @@ def is_openai_invoker(nn_config: dict) -> bool:
|
|||||||
from nn4k.consts import NN_OPENAI_API_KEY_KEY
|
from nn4k.consts import NN_OPENAI_API_KEY_KEY
|
||||||
from nn4k.consts import NN_OPENAI_API_BASE_KEY
|
from nn4k.consts import NN_OPENAI_API_BASE_KEY
|
||||||
from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY
|
from nn4k.consts import NN_OPENAI_MAX_TOKENS_KEY
|
||||||
|
from nn4k.consts import NN_OPENAI_ORGANIZATION_KEY
|
||||||
from nn4k.consts import NN_OPENAI_GPT4_PREFIX
|
from nn4k.consts import NN_OPENAI_GPT4_PREFIX
|
||||||
from nn4k.consts import NN_OPENAI_GPT35_PREFIX
|
from nn4k.consts import NN_OPENAI_GPT35_PREFIX
|
||||||
|
from nn4k.consts import NN_OPENAI_EMBEDDING_PREFIX
|
||||||
from nn4k.utils.config_parsing import get_string_field
|
from nn4k.utils.config_parsing import get_string_field
|
||||||
|
|
||||||
nn_name = nn_config.get(NN_NAME_KEY)
|
nn_name = nn_config.get(NN_NAME_KEY)
|
||||||
if nn_name is not None:
|
if nn_name is not None:
|
||||||
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
nn_name = get_string_field(nn_config, NN_NAME_KEY, NN_NAME_TEXT)
|
||||||
if nn_name.startswith(NN_OPENAI_GPT4_PREFIX) or nn_name.startswith(
|
if (
|
||||||
NN_OPENAI_GPT35_PREFIX
|
nn_name.startswith(NN_OPENAI_GPT4_PREFIX)
|
||||||
|
or nn_name.startswith(NN_OPENAI_GPT35_PREFIX)
|
||||||
|
or nn_name.startswith(NN_OPENAI_EMBEDDING_PREFIX)
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
keys = (NN_OPENAI_API_KEY_KEY, NN_OPENAI_API_BASE_KEY, NN_OPENAI_MAX_TOKENS_KEY)
|
keys = (
|
||||||
|
NN_OPENAI_API_KEY_KEY,
|
||||||
|
NN_OPENAI_API_BASE_KEY,
|
||||||
|
NN_OPENAI_MAX_TOKENS_KEY,
|
||||||
|
NN_OPENAI_ORGANIZATION_KEY,
|
||||||
|
)
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key in nn_config:
|
if key in nn_config:
|
||||||
return True
|
return True
|
||||||
|
@ -1 +1 @@
|
|||||||
openai<1
|
openai
|
||||||
|
@ -42,7 +42,7 @@ class StubHub(SimpleNNHub):
|
|||||||
def get_model_executor(
|
def get_model_executor(
|
||||||
self, name: str, version: str = None
|
self, name: str, version: str = None
|
||||||
) -> Optional[NNExecutor]:
|
) -> Optional[NNExecutor]:
|
||||||
if name == "test_stub":
|
if name == "executor_test_stub":
|
||||||
if version is None:
|
if version is None:
|
||||||
version = "default"
|
version = "default"
|
||||||
executor = StubExecutor(
|
executor = StubExecutor(
|
@ -20,12 +20,12 @@ class TestBaseExecutor(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# for importing test_stub.py
|
# for importing executor_test_stub.py
|
||||||
dir_path = os.path.dirname(os.path.abspath(__file__))
|
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.insert(0, dir_path)
|
sys.path.insert(0, dir_path)
|
||||||
|
|
||||||
from nn4k.nnhub import NNHub
|
from nn4k.nnhub import NNHub
|
||||||
from test_stub import StubHub
|
from executor_test_stub import StubHub
|
||||||
|
|
||||||
NNHub._hub_instance = StubHub()
|
NNHub._hub_instance = StubHub()
|
||||||
|
|
||||||
@ -37,22 +37,24 @@ class TestBaseExecutor(unittest.TestCase):
|
|||||||
|
|
||||||
def testCustomNNExecutor(self):
|
def testCustomNNExecutor(self):
|
||||||
from nn4k.executor import NNExecutor
|
from nn4k.executor import NNExecutor
|
||||||
from test_stub import StubExecutor
|
from executor_test_stub import StubExecutor
|
||||||
|
|
||||||
nn_config = {"nn_executor": "test_stub.StubExecutor"}
|
nn_config = {"nn_executor": "executor_test_stub.StubExecutor"}
|
||||||
executor = NNExecutor.from_config(nn_config)
|
executor = NNExecutor.from_config(nn_config)
|
||||||
self.assertTrue(isinstance(executor, StubExecutor))
|
self.assertTrue(isinstance(executor, StubExecutor))
|
||||||
self.assertEqual(executor.init_args, nn_config)
|
self.assertEqual(executor.init_args, nn_config)
|
||||||
self.assertEqual(executor.kwargs, {})
|
self.assertEqual(executor.kwargs, {})
|
||||||
|
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
executor = NNExecutor.from_config({"nn_executor": "test_stub.NotExecutor"})
|
executor = NNExecutor.from_config(
|
||||||
|
{"nn_executor": "executor_test_stub.NotExecutor"}
|
||||||
|
)
|
||||||
|
|
||||||
def testHubExecutor(self):
|
def testHubExecutor(self):
|
||||||
from nn4k.executor import NNExecutor
|
from nn4k.executor import NNExecutor
|
||||||
from test_stub import StubExecutor
|
from executor_test_stub import StubExecutor
|
||||||
|
|
||||||
nn_config = {"nn_name": "test_stub", "nn_version": "default"}
|
nn_config = {"nn_name": "executor_test_stub", "nn_version": "default"}
|
||||||
executor = NNExecutor.from_config(nn_config)
|
executor = NNExecutor.from_config(nn_config)
|
||||||
self.assertTrue(isinstance(executor, StubExecutor))
|
self.assertTrue(isinstance(executor, StubExecutor))
|
||||||
self.assertEqual(executor.init_args, nn_config)
|
self.assertEqual(executor.init_args, nn_config)
|
||||||
|
57
python/nn4k/tests/executor/test_hf_embedding_executor.py
Normal file
57
python/nn4k/tests/executor/test_hf_embedding_executor.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
# Copyright 2023 OpenSPG Authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
import unittest.mock
|
||||||
|
|
||||||
|
from nn4k.executor.hugging_face import HFEmbeddingExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class TestHFEmbeddingExecutor(unittest.TestCase):
    """
    HFEmbeddingExecutor unittest
    """

    def setUp(self):
        # Register mocks for the heavy third-party modules, remembering
        # whatever was registered before so tearDown can put it back.
        modules = sys.modules

        self._saved_torch = modules.get("torch")
        self._mocked_torch = unittest.mock.MagicMock()
        modules["torch"] = self._mocked_torch

        self._saved_sentence_transformers = modules.get("sentence_transformers")
        self._mocked_sentence_transformers = unittest.mock.MagicMock()
        modules["sentence_transformers"] = self._mocked_sentence_transformers

    def tearDown(self):
        # Drop each mock, then restore the module it replaced (if any).
        replaced = (
            ("torch", self._saved_torch),
            ("sentence_transformers", self._saved_sentence_transformers),
        )
        for module_name, saved_module in replaced:
            del sys.modules[module_name]
            if saved_module is not None:
                sys.modules[module_name] = saved_module

    def testHFEmbeddingExecutor(self):
        nn_config = {
            "nn_name": "/opt/test_model_dir",
            "nn_version": "default",
        }

        executor = HFEmbeddingExecutor.from_config(nn_config)
        executor.load_model()
        executor.inference("input")

        # The model must have been built and then asked to encode.
        transformer_cls = self._mocked_sentence_transformers.SentenceTransformer
        transformer_cls.assert_called()
        executor.model.encode.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
@ -13,12 +13,12 @@ import sys
|
|||||||
import unittest
|
import unittest
|
||||||
import unittest.mock
|
import unittest.mock
|
||||||
|
|
||||||
from nn4k.executor.hugging_face import HfLLMExecutor
|
from nn4k.executor.hugging_face import HFLLMExecutor
|
||||||
|
|
||||||
|
|
||||||
class TestHfLLMExecutor(unittest.TestCase):
|
class TestHFLLMExecutor(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
HfLLMExecutor unittest
|
HFLLMExecutor unittest
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -39,18 +39,20 @@ class TestHfLLMExecutor(unittest.TestCase):
|
|||||||
if self._saved_transformers is not None:
|
if self._saved_transformers is not None:
|
||||||
sys.modules["transformers"] = self._saved_transformers
|
sys.modules["transformers"] = self._saved_transformers
|
||||||
|
|
||||||
def testHfLLMExecutor(self):
|
def testHFLLMExecutor(self):
|
||||||
nn_config = {
|
nn_config = {
|
||||||
"nn_name": "/opt/test_model_dir",
|
"nn_name": "/opt/test_model_dir",
|
||||||
"nn_version": "default",
|
"nn_version": "default",
|
||||||
}
|
}
|
||||||
|
|
||||||
executor = HfLLMExecutor.from_config(nn_config)
|
executor = HFLLMExecutor.from_config(nn_config)
|
||||||
executor.load_model()
|
executor.load_model()
|
||||||
executor.inference("input")
|
executor.inference("input")
|
||||||
|
|
||||||
self._mocked_transformers.AutoTokenizer.from_pretrained.assert_called()
|
self._mocked_transformers.AutoTokenizer.from_pretrained.assert_called()
|
||||||
self._mocked_transformers.AutoModelForCausalLM.from_pretrained.assert_called()
|
self._mocked_transformers.AutoModelForCausalLM.from_pretrained.assert_called()
|
||||||
|
executor.tokenizer.assert_called()
|
||||||
|
executor.model.generate.assert_called()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -52,7 +52,7 @@ class StubExecutor(NNExecutor):
|
|||||||
class StubHub(SimpleNNHub):
|
class StubHub(SimpleNNHub):
|
||||||
def get_invoker(self, nn_config: dict) -> Optional[NNInvoker]:
|
def get_invoker(self, nn_config: dict) -> Optional[NNInvoker]:
|
||||||
nn_name = nn_config.get("nn_name")
|
nn_name = nn_config.get("nn_name")
|
||||||
if nn_name is not None and nn_name == "test_stub":
|
if nn_name is not None and nn_name == "invoker_test_stub":
|
||||||
invoker = StubInvoker(nn_config, test_stub_invoker=True)
|
invoker = StubInvoker(nn_config, test_stub_invoker=True)
|
||||||
return invoker
|
return invoker
|
||||||
return super().get_invoker(nn_config)
|
return super().get_invoker(nn_config)
|
||||||
@ -60,7 +60,7 @@ class StubHub(SimpleNNHub):
|
|||||||
def get_model_executor(
|
def get_model_executor(
|
||||||
self, name: str, version: str = None
|
self, name: str, version: str = None
|
||||||
) -> Optional[NNExecutor]:
|
) -> Optional[NNExecutor]:
|
||||||
if name == "test_stub":
|
if name == "invoker_test_stub":
|
||||||
if version is None:
|
if version is None:
|
||||||
version = "default"
|
version = "default"
|
||||||
executor = StubExecutor(
|
executor = StubExecutor(
|
@ -20,12 +20,12 @@ class TestBaseInvoker(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# for importing test_stub.py
|
# for importing invoker_test_stub.py
|
||||||
dir_path = os.path.dirname(os.path.abspath(__file__))
|
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.insert(0, dir_path)
|
sys.path.insert(0, dir_path)
|
||||||
|
|
||||||
from nn4k.nnhub import NNHub
|
from nn4k.nnhub import NNHub
|
||||||
from test_stub import StubHub
|
from invoker_test_stub import StubHub
|
||||||
|
|
||||||
NNHub._hub_instance = StubHub()
|
NNHub._hub_instance = StubHub()
|
||||||
|
|
||||||
@ -37,22 +37,24 @@ class TestBaseInvoker(unittest.TestCase):
|
|||||||
|
|
||||||
def testCustomNNInvoker(self):
|
def testCustomNNInvoker(self):
|
||||||
from nn4k.invoker import NNInvoker
|
from nn4k.invoker import NNInvoker
|
||||||
from test_stub import StubInvoker
|
from invoker_test_stub import StubInvoker
|
||||||
|
|
||||||
nn_config = {"nn_invoker": "test_stub.StubInvoker"}
|
nn_config = {"nn_invoker": "invoker_test_stub.StubInvoker"}
|
||||||
invoker = NNInvoker.from_config(nn_config)
|
invoker = NNInvoker.from_config(nn_config)
|
||||||
self.assertTrue(isinstance(invoker, StubInvoker))
|
self.assertTrue(isinstance(invoker, StubInvoker))
|
||||||
self.assertEqual(invoker.init_args, nn_config)
|
self.assertEqual(invoker.init_args, nn_config)
|
||||||
self.assertEqual(invoker.kwargs, {})
|
self.assertEqual(invoker.kwargs, {})
|
||||||
|
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
invoker = NNInvoker.from_config({"nn_invoker": "test_stub.NotInvoker"})
|
invoker = NNInvoker.from_config(
|
||||||
|
{"nn_invoker": "invoker_test_stub.NotInvoker"}
|
||||||
|
)
|
||||||
|
|
||||||
def testHubInvoker(self):
|
def testHubInvoker(self):
|
||||||
from nn4k.invoker import NNInvoker
|
from nn4k.invoker import NNInvoker
|
||||||
from test_stub import StubInvoker
|
from invoker_test_stub import StubInvoker
|
||||||
|
|
||||||
nn_config = {"nn_name": "test_stub"}
|
nn_config = {"nn_name": "invoker_test_stub"}
|
||||||
invoker = NNInvoker.from_config(nn_config)
|
invoker = NNInvoker.from_config(nn_config)
|
||||||
self.assertTrue(isinstance(invoker, StubInvoker))
|
self.assertTrue(isinstance(invoker, StubInvoker))
|
||||||
self.assertEqual(invoker.init_args, nn_config)
|
self.assertEqual(invoker.init_args, nn_config)
|
||||||
@ -66,9 +68,9 @@ class TestBaseInvoker(unittest.TestCase):
|
|||||||
|
|
||||||
def testLocalInvoker(self):
|
def testLocalInvoker(self):
|
||||||
from nn4k.invoker import NNInvoker
|
from nn4k.invoker import NNInvoker
|
||||||
from test_stub import StubInvoker
|
from invoker_test_stub import StubInvoker
|
||||||
|
|
||||||
nn_config = {"nn_name": "test_stub"}
|
nn_config = {"nn_name": "invoker_test_stub"}
|
||||||
invoker = NNInvoker.from_config(nn_config)
|
invoker = NNInvoker.from_config(nn_config)
|
||||||
self.assertTrue(isinstance(invoker, StubInvoker))
|
self.assertTrue(isinstance(invoker, StubInvoker))
|
||||||
self.assertEqual(invoker.init_args, nn_config)
|
self.assertEqual(invoker.init_args, nn_config)
|
||||||
@ -79,6 +81,19 @@ class TestBaseInvoker(unittest.TestCase):
|
|||||||
result = invoker.local_inference("input")
|
result = invoker.local_inference("input")
|
||||||
self.assertEqual(result, invoker._nn_executor.inference_result)
|
self.assertEqual(result, invoker._nn_executor.inference_result)
|
||||||
|
|
||||||
|
def testLocalLLMInvokerWithCustomExecutor(self):
|
||||||
|
from nn4k.invoker import LLMInvoker
|
||||||
|
|
||||||
|
nn_config = {"nn_executor": "invoker_test_stub.StubExecutor"}
|
||||||
|
invoker = LLMInvoker.from_config(nn_config)
|
||||||
|
self.assertTrue(isinstance(invoker, LLMInvoker))
|
||||||
|
self.assertEqual(invoker.init_args, nn_config)
|
||||||
|
|
||||||
|
invoker.warmup_local_model()
|
||||||
|
invoker._nn_executor.inference_result = "inference result"
|
||||||
|
result = invoker.local_inference("input")
|
||||||
|
self.assertEqual(result, invoker._nn_executor.inference_result)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -22,9 +22,24 @@ class MockCompletion:
|
|||||||
choices: list
|
choices: list
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockMessage:
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockChoice:
|
class MockChoice:
|
||||||
text: str
|
message: MockMessage
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockEmbeddings:
|
||||||
|
data: list
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockEmbedding:
|
||||||
|
embedding: list
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIInvoker(unittest.TestCase):
|
class TestOpenAIInvoker(unittest.TestCase):
|
||||||
@ -42,7 +57,10 @@ class TestOpenAIInvoker(unittest.TestCase):
|
|||||||
if self._saved_openai is not None:
|
if self._saved_openai is not None:
|
||||||
sys.modules["openai"] = self._saved_openai
|
sys.modules["openai"] = self._saved_openai
|
||||||
|
|
||||||
def testOpenAIInvoker(self):
|
def testOpenAICompletion(self):
|
||||||
|
self._mocked_openai.__version__ = "1.7.0"
|
||||||
|
self._mocked_openai.OpenAI = unittest.mock.MagicMock
|
||||||
|
|
||||||
nn_config = {
|
nn_config = {
|
||||||
"nn_name": "gpt-3.5-turbo",
|
"nn_name": "gpt-3.5-turbo",
|
||||||
"openai_api_key": "EMPTY",
|
"openai_api_key": "EMPTY",
|
||||||
@ -51,19 +69,45 @@ class TestOpenAIInvoker(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
invoker = NNInvoker.from_config(nn_config)
|
invoker = NNInvoker.from_config(nn_config)
|
||||||
self.assertEqual(invoker.init_args, nn_config)
|
self.assertEqual(invoker.init_args, nn_config)
|
||||||
self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
|
self.assertEqual(invoker.client.api_key, nn_config["openai_api_key"])
|
||||||
self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])
|
self.assertEqual(invoker.client.base_url, nn_config["openai_api_base"])
|
||||||
|
|
||||||
mock_completion = MockCompletion(choices=[MockChoice("a dog named Bolt ...")])
|
mock_completion = MockCompletion(
|
||||||
self._mocked_openai.Completion.create.return_value = mock_completion
|
choices=[MockChoice(message=MockMessage(content="a dog named Bolt ..."))]
|
||||||
|
)
|
||||||
|
invoker.client.chat.completions.create.return_value = mock_completion
|
||||||
|
|
||||||
result = invoker.remote_inference("Long long ago, ")
|
result = invoker.remote_inference("Long long ago, ")
|
||||||
self._mocked_openai.Completion.create.assert_called_with(
|
invoker.client.chat.completions.create.assert_called_with(
|
||||||
prompt=["Long long ago, "],
|
|
||||||
model=nn_config["nn_name"],
|
model=nn_config["nn_name"],
|
||||||
|
messages=[{"role": "user", "content": "Long long ago, "}],
|
||||||
max_tokens=nn_config["openai_max_tokens"],
|
max_tokens=nn_config["openai_max_tokens"],
|
||||||
)
|
)
|
||||||
self.assertEqual(result, [mock_completion.choices[0].text])
|
self.assertEqual(result, [mock_completion.choices[0].message.content])
|
||||||
|
|
||||||
|
def testOpenAIEmbedding(self):
|
||||||
|
self._mocked_openai.__version__ = "1.7.0"
|
||||||
|
self._mocked_openai.OpenAI = unittest.mock.MagicMock
|
||||||
|
|
||||||
|
nn_config = {
|
||||||
|
"nn_name": "text-embedding-ada-002",
|
||||||
|
"openai_api_key": "EMPTY",
|
||||||
|
"openai_api_base": "http://localhost:38080/v1",
|
||||||
|
}
|
||||||
|
invoker = NNInvoker.from_config(nn_config)
|
||||||
|
self.assertEqual(invoker.init_args, nn_config)
|
||||||
|
self.assertEqual(invoker.client.api_key, nn_config["openai_api_key"])
|
||||||
|
self.assertEqual(invoker.client.base_url, nn_config["openai_api_base"])
|
||||||
|
|
||||||
|
mock_embeddings = MockEmbeddings(data=[MockEmbedding(embedding=[0.1, 0.2])])
|
||||||
|
invoker.client.embeddings.create.return_value = mock_embeddings
|
||||||
|
|
||||||
|
result = invoker.remote_inference("How old are you?", type="Embedding")
|
||||||
|
invoker.client.embeddings.create.assert_called_with(
|
||||||
|
model=nn_config["nn_name"],
|
||||||
|
input=["How old are you?"],
|
||||||
|
)
|
||||||
|
self.assertEqual(result, [mock_embeddings.data[0].embedding])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
113
python/nn4k/tests/invoker/test_openai_invoker_legacy_api.py
Normal file
113
python/nn4k/tests/invoker/test_openai_invoker_legacy_api.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
# Copyright 2023 OpenSPG Authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
import unittest.mock
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from nn4k.invoker import NNInvoker
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockCompletion:
|
||||||
|
choices: list
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockMessage:
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockChoice:
|
||||||
|
message: MockMessage
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockEmbeddings:
|
||||||
|
data: list
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockEmbedding:
|
||||||
|
embedding: list
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIInvokerLegacyAPI(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
OpenAIInvoker unittest for legacy OpenAI api
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._saved_openai = sys.modules.get("openai")
|
||||||
|
self._mocked_openai = unittest.mock.MagicMock()
|
||||||
|
sys.modules["openai"] = self._mocked_openai
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
del sys.modules["openai"]
|
||||||
|
if self._saved_openai is not None:
|
||||||
|
sys.modules["openai"] = self._saved_openai
|
||||||
|
|
||||||
|
def testOpenAICompletion(self):
|
||||||
|
self._mocked_openai.__version__ = "0.28.1"
|
||||||
|
|
||||||
|
nn_config = {
|
||||||
|
"nn_name": "gpt-3.5-turbo",
|
||||||
|
"openai_api_key": "EMPTY",
|
||||||
|
"openai_api_base": "http://localhost:38080/v1",
|
||||||
|
"openai_max_tokens": 2000,
|
||||||
|
}
|
||||||
|
invoker = NNInvoker.from_config(nn_config)
|
||||||
|
self.assertEqual(invoker.init_args, nn_config)
|
||||||
|
self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
|
||||||
|
self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])
|
||||||
|
|
||||||
|
mock_completion = MockCompletion(
|
||||||
|
choices=[MockChoice(message=MockMessage(content="a dog named Bolt ..."))]
|
||||||
|
)
|
||||||
|
self._mocked_openai.ChatCompletion.create.return_value = mock_completion
|
||||||
|
|
||||||
|
result = invoker.remote_inference("Long long ago, ")
|
||||||
|
self._mocked_openai.ChatCompletion.create.assert_called_with(
|
||||||
|
model=nn_config["nn_name"],
|
||||||
|
messages=[{"role": "user", "content": "Long long ago, "}],
|
||||||
|
max_tokens=nn_config["openai_max_tokens"],
|
||||||
|
)
|
||||||
|
self.assertEqual(result, [mock_completion.choices[0].message.content])
|
||||||
|
|
||||||
|
def testOpenAIEmbedding(self):
|
||||||
|
self._mocked_openai.__version__ = "0.28.1"
|
||||||
|
self._mocked_openai.OpenAI = unittest.mock.MagicMock
|
||||||
|
|
||||||
|
nn_config = {
|
||||||
|
"nn_name": "text-embedding-ada-002",
|
||||||
|
"openai_api_key": "EMPTY",
|
||||||
|
"openai_api_base": "http://localhost:38080/v1",
|
||||||
|
}
|
||||||
|
invoker = NNInvoker.from_config(nn_config)
|
||||||
|
self.assertEqual(invoker.init_args, nn_config)
|
||||||
|
self.assertEqual(self._mocked_openai.api_key, nn_config["openai_api_key"])
|
||||||
|
self.assertEqual(self._mocked_openai.api_base, nn_config["openai_api_base"])
|
||||||
|
|
||||||
|
mock_embeddings = MockEmbeddings(data=[MockEmbedding(embedding=[0.1, 0.2])])
|
||||||
|
self._mocked_openai.Embedding.create.return_value = mock_embeddings
|
||||||
|
|
||||||
|
result = invoker.remote_inference("How old are you?", type="Embedding")
|
||||||
|
self._mocked_openai.Embedding.create.assert_called_with(
|
||||||
|
model=nn_config["nn_name"],
|
||||||
|
input=["How old are you?"],
|
||||||
|
)
|
||||||
|
self.assertEqual(result, [mock_embeddings.data[0].embedding])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
36
python/nn4k/tests/python-env/.env.restore.sh
Executable file
36
python/nn4k/tests/python-env/.env.restore.sh
Executable file
@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Copyright 2023 OpenSPG Authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
_SCRIPT_FILE_PATH="${BASH_SOURCE[0]}"
|
||||||
|
while [ -h ${_SCRIPT_FILE_PATH} ]
|
||||||
|
do
|
||||||
|
_SCRIPT_DIR_PATH=$(cd -P "$(dirname ${_SCRIPT_FILE_PATH})" && pwd)
|
||||||
|
_SCRIPT_FILE_PATH=$(readlink ${_SCRIPT_FILE_PATH})
|
||||||
|
case ${_SCRIPT_FILE_PATH} in
|
||||||
|
/*) ;;
|
||||||
|
*) _SCRIPT_FILE_PATH=${_SCRIPT_DIR_PATH}/${_SCRIPT_FILE_PATH} ;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
_SCRIPT_DIR_PATH=$(cd -P "$(dirname ${_SCRIPT_FILE_PATH})" && pwd)
|
||||||
|
|
||||||
|
if [ -f ${_SCRIPT_DIR_PATH}/.env/requirements.txt ]
|
||||||
|
then
|
||||||
|
exit
|
||||||
|
fi
|
||||||
|
|
||||||
|
set -e
|
||||||
|
rm -rf ${_SCRIPT_DIR_PATH}/.env
|
||||||
|
python3 -m venv ${_SCRIPT_DIR_PATH}/.env
|
||||||
|
source ${_SCRIPT_DIR_PATH}/.env/bin/activate
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip freeze > ${_SCRIPT_DIR_PATH}/.env/requirements.txt
|
1
python/nn4k/tests/python-env/.gitignore
vendored
Normal file
1
python/nn4k/tests/python-env/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/.env/
|
25
python/nn4k/tests/python-env/env.sh
Normal file
25
python/nn4k/tests/python-env/env.sh
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Copyright 2023 OpenSPG Authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
_SCRIPT_FILE_PATH="${BASH_SOURCE[0]}"
|
||||||
|
while [ -h ${_SCRIPT_FILE_PATH} ]
|
||||||
|
do
|
||||||
|
_SCRIPT_DIR_PATH=$(cd -P "$(dirname ${_SCRIPT_FILE_PATH})" && pwd)
|
||||||
|
_SCRIPT_FILE_PATH=$(readlink ${_SCRIPT_FILE_PATH})
|
||||||
|
case ${_SCRIPT_FILE_PATH} in
|
||||||
|
/*) ;;
|
||||||
|
*) _SCRIPT_FILE_PATH=${_SCRIPT_DIR_PATH}/${_SCRIPT_FILE_PATH} ;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
_SCRIPT_DIR_PATH=$(cd -P "$(dirname ${_SCRIPT_FILE_PATH})" && pwd)
|
||||||
|
|
||||||
|
${_SCRIPT_DIR_PATH}/.env.restore.sh
|
||||||
|
source ${_SCRIPT_DIR_PATH}/.env/bin/activate
|
60
python/nn4k/tests/run_all_tests.py
Executable file
60
python/nn4k/tests/run_all_tests.py
Executable file
@ -0,0 +1,60 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
# -*- mode: python -*-
|
||||||
|
|
||||||
|
# Copyright 2023 OpenSPG Authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||||
|
# in compliance with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||||
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
|
# or implied.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
|
|
||||||
|
class NN4KTestsRunner(object):
|
||||||
|
def _run_all_tests(self):
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
restore_script_path = os.path.join(dir_path, "python-env", ".env.restore.sh")
|
||||||
|
args = [restore_script_path]
|
||||||
|
subprocess.check_call(args)
|
||||||
|
|
||||||
|
nn4k_dir_path = os.path.dirname(dir_path)
|
||||||
|
python_executable_path = os.path.join(
|
||||||
|
dir_path, "python-env", ".env", "bin", "python"
|
||||||
|
)
|
||||||
|
saved_dir_path = os.getcwd()
|
||||||
|
os.chdir(dir_path)
|
||||||
|
|
||||||
|
args = ["env", "PYTHONPATH=%s" % nn4k_dir_path]
|
||||||
|
args += [python_executable_path]
|
||||||
|
args += ["-m", "unittest"]
|
||||||
|
try:
|
||||||
|
subprocess.check_call(args)
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
raise SystemExit(1)
|
||||||
|
finally:
|
||||||
|
os.chdir(saved_dir_path)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self._run_all_tests()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
runner = NN4KTestsRunner()
|
||||||
|
runner.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -22,6 +22,7 @@ class TestInvokerChecking(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(is_openai_invoker({"nn_name": "gpt-3.5-turbo"}))
|
self.assertTrue(is_openai_invoker({"nn_name": "gpt-3.5-turbo"}))
|
||||||
self.assertTrue(is_openai_invoker({"nn_name": "gpt-4"}))
|
self.assertTrue(is_openai_invoker({"nn_name": "gpt-4"}))
|
||||||
|
self.assertTrue(is_openai_invoker({"nn_name": "text-embedding-ada-002"}))
|
||||||
self.assertFalse(is_openai_invoker({"nn_name": "dummy"}))
|
self.assertFalse(is_openai_invoker({"nn_name": "dummy"}))
|
||||||
|
|
||||||
self.assertTrue(is_openai_invoker({"openai_api_key": "EMPTY"}))
|
self.assertTrue(is_openai_invoker({"openai_api_key": "EMPTY"}))
|
||||||
@ -29,6 +30,7 @@ class TestInvokerChecking(unittest.TestCase):
|
|||||||
is_openai_invoker({"openai_api_base": "http://localhost:38000/v1"})
|
is_openai_invoker({"openai_api_base": "http://localhost:38000/v1"})
|
||||||
)
|
)
|
||||||
self.assertTrue(is_openai_invoker({"openai_max_tokens": 1000}))
|
self.assertTrue(is_openai_invoker({"openai_max_tokens": 1000}))
|
||||||
|
self.assertTrue(is_openai_invoker({"openai_organization": "test_org"}))
|
||||||
self.assertFalse(is_openai_invoker({"foo": "bar"}))
|
self.assertFalse(is_openai_invoker({"foo": "bar"}))
|
||||||
|
|
||||||
def testIsLocalInvoker(self):
|
def testIsLocalInvoker(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user