openspg/python/nn4k/tests/invoker/test_hub_invoker.py
2024-03-29 17:18:26 +08:00

109 lines
3.6 KiB
Python

# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import requests
import unittest
import unittest.mock
from nn4k.invoker import NNInvoker
from nn4k.invoker.hub_invoker import NN4KException
class TestHubInvoker(unittest.TestCase):
"""
HubInvoker unittest.
"""
def setUp(self):
self._save_response = requests.Response
self._mock_response = unittest.mock.MagicMock()
self._mock_response.status_code = 200
self._mock_response.json.return_value = {
"model_name": "vllm_model",
"model_version": "1",
"text_output": "Long long ago, 100 years ago, ",
}
self._nn_config = {
"hub_infer_url": "hub_infer_url",
"generate_config": {
"ignore_eos": True,
"skip_special_tokens": True,
"use_beam_search": True,
"best_of": 3,
"max_tokens": 127,
"n": 1,
"top_k": -1,
"frequency_penalty": 0.0,
"length_penalty": 1.0,
"presence_penalty": 0.0,
"temperature": 0.0,
"top_p": 1.0,
},
}
self._invoker = NNInvoker.from_config(self._nn_config)
def tearDown(self):
requests.Response = self._save_response
def testHubInvokerCallService(self):
self.assertEqual(self._invoker.init_args, self._nn_config)
self.assertEqual(
self._invoker.invoker_args.generate_config,
self._nn_config["generate_config"],
)
with unittest.mock.patch(
"requests.post", return_value=self._mock_response
) as mock_post:
result = self._invoker.remote_inference("Long long ago, ")
data = {
"text_input": "Long long ago, ",
"parameters": self._invoker.invoker_args.generate_config,
}
mock_post.assert_called_with(
url=self._invoker.invoker_args.hub_infer_url,
headers={"Content-Type": "application/json"},
data=json.dumps(data),
)
index = self._mock_response.json()["text_output"].find("\n")
self.assertEqual(
result, [self._mock_response.json()["text_output"][index + 1 :]]
)
def testHubInvokerWithNN4KException(self):
with self.assertRaises(NN4KException):
with unittest.mock.patch(
"requests.post", return_value=self._mock_response
) as mock_post:
self._mock_response.status_code = 500
result = self._invoker.remote_inference("Long long ago, ")
def testHubInvokerWithNotImplementException(self):
with self.assertRaises(NotImplementedError):
with unittest.mock.patch(
"requests.post", return_value=self._mock_response
) as mock_post:
self._mock_response.status_code = 200
result = self._invoker.remote_inference(
[
"Long long ago, ",
]
)
if __name__ == "__main__":
unittest.main()