diff --git a/python/nn4k/invoker/base.py b/python/nn4k/invoker/base.py index 617fae27..62087782 100644 --- a/python/nn4k/invoker/base.py +++ b/python/nn4k/invoker/base.py @@ -1,4 +1,3 @@ -# coding: utf-8 # Copyright 2023 Ant Group CO., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except @@ -10,8 +9,6 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. -# Copyright (c) Antfin, Inc. All rights reserved. -import json import os from abc import ABC from typing import Union @@ -85,13 +82,9 @@ class LLMInvoker(NNInvoker): @classmethod def from_config(cls, nn_config: Union[str, dict]): - try: - if isinstance(nn_config, str): - with open(nn_config, "r") as f: - nn_config = json.load(f) - except: - raise ValueError("cannot decode config file") + from nn4k.utils.config_parsing import preprocess_config + nn_config = preprocess_config(nn_config) if nn_config.get("invoker_type", "LLM") == "LLM": o = cls.__new__(cls) diff --git a/python/nn4k/invoker/openai_invoker.py b/python/nn4k/invoker/openai_invoker.py index b5a422b0..034da765 100644 --- a/python/nn4k/invoker/openai_invoker.py +++ b/python/nn4k/invoker/openai_invoker.py @@ -9,70 +9,27 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. -from typing import Any from typing import Union from nn4k.invoker import NNInvoker class OpenAIInvoker(NNInvoker): - @classmethod - def _preprocess_config(cls, nn_config: Union[str, dict]) -> dict: - try: - if isinstance(nn_config, str): - with open(nn_config, "r") as f: - nn_config = json.load(f) - except: - raise ValueError("cannot decode config file") - return nn_config - - @classmethod - def _get_field(cls, nn_config: dict, name: str, text: str) -> Any: - value = nn_config.get(name) - if value is None: - message = "%s %r not found" % (text, name) - raise ValueError(message) - return value - - @classmethod - def _get_string_field(cls, nn_config: dict, name: str, text: str) -> str: - value = cls._get_field(nn_config, name, text) - if not isinstance(value, str): - message = "%s %r must be string; " % (text, name) - message += "%r is invalid" % (value,) - raise TypeError(message) - return value - - @classmethod - def _get_int_field(cls, nn_config: dict, name: str, text: str) -> int: - value = cls._get_field(nn_config, name, text) - if not isinstance(value, int): - message = "%s %r must be integer; " % (text, name) - message += "%r is invalid" % (value,) - raise TypeError(message) - return value - - @classmethod - def _get_positive_int_field(cls, nn_config: dict, name: str, text: str) -> int: - value = cls._get_int_field(nn_config, name, text) - if value <= 0: - message = "%s %r must be positive integer; " % (text, name) - message += "%r is invalid" % (value,) - raise ValueError(message) - return value - @classmethod def _parse_config(cls, nn_config: dict) -> dict: - openai_api_key = cls._get_string_field( + from nn4k.utils.config_parsing import get_string_field + from nn4k.utils.config_parsing import get_positive_int_field + + openai_api_key = get_string_field( nn_config, "openai_api_key", "openai api key" ) - openai_api_base = cls._get_string_field( + openai_api_base = get_string_field( nn_config, "openai_api_base", "openai api base" ) - openai_model_name = cls._get_string_field( + openai_model_name = get_string_field( nn_config, "openai_model_name", "openai model name" ) - openai_max_tokens = cls._get_positive_int_field( + openai_max_tokens = get_positive_int_field( nn_config, "openai_max_tokens", "openai max tokens" ) config = dict( @@ -86,8 +43,9 @@ class OpenAIInvoker(NNInvoker): @classmethod def from_config(cls, nn_config: Union[str, dict]): import openai + from nn4k.utils.config_parsing import preprocess_config - nn_config = cls._preprocess_config(nn_config) + nn_config = preprocess_config(nn_config) config = cls._parse_config(nn_config) o = cls.__new__(cls) diff --git a/python/nn4k/utils/__init__.py b/python/nn4k/utils/__init__.py new file mode 100644 index 00000000..19913efa --- /dev/null +++ b/python/nn4k/utils/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2023 Ant Group CO., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. diff --git a/python/nn4k/utils/config_parsing.py b/python/nn4k/utils/config_parsing.py new file mode 100644 index 00000000..abc53849 --- /dev/null +++ b/python/nn4k/utils/config_parsing.py @@ -0,0 +1,56 @@ +# Copyright 2023 Ant Group CO., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. + +import json + +from typing import Any +from typing import Union + + +def preprocess_config(nn_config: Union[str, dict]) -> dict: + try: + if isinstance(nn_config, str): + with open(nn_config, "r") as f: + nn_config = json.load(f) + except: + raise ValueError("cannot decode config file") + return nn_config + +def get_field(nn_config: dict, name: str, text: str) -> Any: + value = nn_config.get(name) + if value is None: + message = "%s %r not found" % (text, name) + raise ValueError(message) + return value + +def get_string_field(nn_config: dict, name: str, text: str) -> str: + value = get_field(nn_config, name, text) + if not isinstance(value, str): + message = "%s %r must be string; " % (text, name) + message += "%r is invalid" % (value,) + raise TypeError(message) + return value + +def get_int_field(nn_config: dict, name: str, text: str) -> int: + value = get_field(nn_config, name, text) + if not isinstance(value, int): + message = "%s %r must be integer; " % (text, name) + message += "%r is invalid" % (value,) + raise TypeError(message) + return value + +def get_positive_int_field(nn_config: dict, name: str, text: str) -> int: + value = get_int_field(nn_config, name, text) + if value <= 0: + message = "%s %r must be positive integer; " % (text, name) + message += "%r is invalid" % (value,) + raise ValueError(message) + return value