# openspg/python/nn4k/tests/utils/test_config_parsing.py
#
# 144 lines
# 4.7 KiB
# Python
# Raw Normal View History
#
# 2024-01-16 14:38:37 +08:00
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import unittest
from dataclasses import dataclass, field
from typing import List, Optional
@dataclass
class TestArgs:
    """
    Argument dataclass used as the parsing target in this module's tests.

    Every field has a default, so the class can be instantiated without
    arguments.
    """

    # The class name starts with "Test", so pytest would otherwise try to
    # collect it and warn about its generated ``__init__``. The attribute is
    # deliberately un-annotated: it is a plain class attribute, not a
    # dataclass field, so the constructor signature is unaffected.
    __test__ = False

    # Optional list of strings; defaults to None.
    input_columns: Optional[List[str]] = field(
        default=None,
        metadata={"help": ""},
    )
    # Optional boolean flag; defaults to None (i.e. "not set").
    is_bool: Optional[bool] = field(
        default=None,
        metadata={"help": ""},
    )
    # Integer setting; defaults to 1024.
    max_input_length: int = field(
        default=1024,
        metadata={"help": ""},
    )
    # Optional free-form dict; defaults to None.
    lora_config: Optional[dict] = field(default=None)
class TestConfigParsing(unittest.TestCase):
    """
    Unit tests for the helpers in ``nn4k.utils.config_parsing``.

    The code under test is imported inside each test method, so an import
    failure surfaces in the individual test rather than when this module is
    loaded.
    """

    @staticmethod
    def _sibling_path(file_name):
        """Return the absolute path of ``file_name`` in this file's directory."""
        import os

        return os.path.join(os.path.dirname(os.path.abspath(__file__)), file_name)

    def testPreprocessConfigFile(self):
        """A config given as a file path is expected to come back as a dict."""
        from nn4k.utils.config_parsing import preprocess_config

        # Relies on a ``test_config.json`` fixture located next to this file;
        # the assertion below expects it to contain {"foo": "bar"}.
        nn_config = preprocess_config(self._sibling_path("test_config.json"))
        self.assertEqual(nn_config, {"foo": "bar"})

    def testPreprocessConfigFileNotExists(self):
        """A path to a file that does not exist is expected to raise ValueError."""
        from nn4k.utils.config_parsing import preprocess_config

        missing_path = self._sibling_path("not_exists.json")
        with self.assertRaises(ValueError):
            preprocess_config(missing_path)

    def testPreprocessConfigDict(self):
        """A config given as a dict is expected to compare equal to the input."""
        from nn4k.utils.config_parsing import preprocess_config

        conf = {"foo": "bar"}
        nn_config = preprocess_config(conf)
        self.assertEqual(nn_config, conf)

    def testGetField(self):
        """``get_field`` is expected to return the value stored under the key."""
        from nn4k.utils.config_parsing import get_field

        nn_config = {"foo": "bar"}
        # NOTE(review): the third argument looks like a display name for the
        # field -- confirm against ``get_field``'s signature.
        value = get_field(nn_config, "foo", "Foo")
        self.assertEqual(value, "bar")

    def testGetFieldNotExists(self):
        """``get_field`` is expected to raise ValueError for a missing key."""
        from nn4k.utils.config_parsing import get_field

        nn_config = {"foo": "bar"}
        with self.assertRaises(ValueError):
            get_field(nn_config, "not_exists", "not exists")

    def testGetStringField(self):
        """``get_string_field`` is expected to return a string value as-is."""
        from nn4k.utils.config_parsing import get_string_field

        nn_config = {"foo": "bar"}
        value = get_string_field(nn_config, "foo", "Foo")
        self.assertEqual(value, "bar")

    def testGetStringFieldNotString(self):
        """``get_string_field`` is expected to raise TypeError for a bool value."""
        from nn4k.utils.config_parsing import get_string_field

        nn_config = {"foo": "bar", "baz": True}
        with self.assertRaises(TypeError):
            get_string_field(nn_config, "baz", "Baz")

    def testGetIntField(self):
        """``get_int_field`` is expected to return an integer value as-is."""
        from nn4k.utils.config_parsing import get_int_field

        nn_config = {"foo": "bar", "baz": 1000}
        value = get_int_field(nn_config, "baz", "Baz")
        self.assertEqual(value, 1000)

    def testGetIntFieldNotInteger(self):
        """``get_int_field`` is expected to raise TypeError for a string value."""
        from nn4k.utils.config_parsing import get_int_field

        nn_config = {"foo": "bar", "baz": "quux"}
        with self.assertRaises(TypeError):
            get_int_field(nn_config, "baz", "Baz")

    def testGetPositiveIntField(self):
        """``get_positive_int_field`` is expected to return a positive integer."""
        from nn4k.utils.config_parsing import get_positive_int_field

        nn_config = {"foo": "bar", "baz": 1000}
        value = get_positive_int_field(nn_config, "baz", "Baz")
        self.assertEqual(value, 1000)

    def testGetPositiveIntFieldNotPositive(self):
        """``get_positive_int_field`` is expected to raise ValueError for zero."""
        from nn4k.utils.config_parsing import get_positive_int_field

        nn_config = {"foo": "bar", "baz": 0}
        with self.assertRaises(ValueError):
            get_positive_int_field(nn_config, "baz", "Baz")

    def testTransformerArgsParseDict(self):
        """
        ``HfArgumentParser.parse_dict`` is expected to populate ``TestArgs``
        from a dict while tolerating keys that are not dataclass fields.
        """
        from transformers import HfArgumentParser

        args = {
            "input_columns": ["column1", "column2"],
            "is_bool": False,
            "max_input_length": 256,
            "lora_config": {"r": 1, "type": "lora"},
            # The two keys below are not fields of TestArgs; they exercise
            # ``allow_extra_keys=True``.
            "is_bool_int": 1,
            "extra_arg": "extra_configs",
        }
        parser = HfArgumentParser(TestArgs)
        parsed_args: TestArgs
        # Only the first returned element is inspected; anything after it is
        # intentionally discarded.
        parsed_args, *_ = parser.parse_dict(args, allow_extra_keys=True)
        self.assertEqual(parsed_args.input_columns, ["column1", "column2"])
        # Identity check so that a falsy non-bool (e.g. 0 or None) cannot pass.
        self.assertIs(parsed_args.is_bool, False)
        self.assertEqual(parsed_args.lora_config, {"type": "lora", "r": 1})
        self.assertEqual(parsed_args.max_input_length, 256)
# Run this module's test cases when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()