openspg/python/nn4k/tests/invoker/test_base_invoker.py
chenbin11200 bf57b3319f
feat(nn4k): support huggingface decode only model local inference (#128)
Co-authored-by: xionghuaidong <huaidong.xhd@antgroup.com>
2024-03-08 13:54:15 +08:00

108 lines
3.6 KiB
Python

# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import sys
import unittest
class TestBaseInvoker(unittest.TestCase):
"""
NNInvoker and LLMInvoker unittest
"""
def setUp(self):
# for importing invoker_test_stub.py
dir_path = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, dir_path)
from nn4k.nnhub import NNHub
from invoker_test_stub import StubHub
NNHub._hub_instance = StubHub()
def tearDown(self):
from nn4k.nnhub import NNHub
sys.path.pop(0)
NNHub._hub_instance = None
def testCustomNNInvoker(self):
from nn4k.invoker import NNInvoker
from invoker_test_stub import StubInvoker
nn_config = {"nn_invoker": "invoker_test_stub.StubInvoker"}
invoker = NNInvoker.from_config(nn_config)
self.assertTrue(isinstance(invoker, StubInvoker))
self.assertEqual(invoker.init_args, nn_config)
self.assertEqual(invoker.kwargs, {})
with self.assertRaises(RuntimeError):
invoker = NNInvoker.from_config(
{"nn_invoker": "invoker_test_stub.NotInvoker"}
)
def testHubInvoker(self):
from nn4k.invoker import NNInvoker
from invoker_test_stub import StubInvoker
nn_config = {"nn_name": "invoker_test_stub"}
invoker = NNInvoker.from_config(nn_config)
self.assertTrue(isinstance(invoker, StubInvoker))
self.assertEqual(invoker.init_args, nn_config)
self.assertEqual(invoker.kwargs, {"test_stub_invoker": True})
def testInvokerNotExists(self):
"""
now the default invoker is LLMInvoker
"""
from nn4k.invoker import NNInvoker
invoker = NNInvoker.from_config({"nn_name": "not_exists"})
from nn4k.invoker.base import LLMInvoker
assert type(invoker) == LLMInvoker
def testLocalInvoker(self):
from nn4k.invoker import NNInvoker
from invoker_test_stub import StubInvoker
nn_config = {"nn_name": "invoker_test_stub"}
invoker = NNInvoker.from_config(nn_config)
self.assertTrue(isinstance(invoker, StubInvoker))
self.assertEqual(invoker.init_args, nn_config)
self.assertEqual(invoker.kwargs, {"test_stub_invoker": True})
invoker.warmup_local_model()
invoker._nn_executor.inference_result = "inference result"
result = invoker.local_inference("input")
self.assertEqual(result, invoker._nn_executor.inference_result)
def testLocalLLMInvokerWithCustomExecutor(self):
from nn4k.invoker import LLMInvoker
nn_config = {
"nn_model_path": "/path/to/model",
"nn_executor": "invoker_test_stub.StubExecutor",
}
invoker = LLMInvoker.from_config(nn_config)
self.assertTrue(isinstance(invoker, LLMInvoker))
self.assertEqual(invoker.init_args, nn_config)
invoker.warmup_local_model()
invoker._nn_executor.inference_result = "inference result"
result = invoker.local_inference("input")
self.assertEqual(result, invoker._nn_executor.inference_result)
if __name__ == "__main__":
unittest.main()