mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-08-08 08:52:35 +00:00
feat(nn4k): add nn4k.utils.config_parsing
This commit is contained in:
parent
4e11000e2f
commit
0a9c4976d0
@ -1,4 +1,3 @@
|
||||
# coding: utf-8
|
||||
# Copyright 2023 Ant Group CO., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
@ -10,8 +9,6 @@
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
# Copyright (c) Antfin, Inc. All rights reserved.
|
||||
import json
|
||||
import os
|
||||
from abc import ABC
|
||||
from typing import Union
|
||||
@ -85,13 +82,9 @@ class LLMInvoker(NNInvoker):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, nn_config: Union[str, dict]):
|
||||
try:
|
||||
if isinstance(nn_config, str):
|
||||
with open(nn_config, "r") as f:
|
||||
nn_config = json.load(f)
|
||||
except:
|
||||
raise ValueError("cannot decode config file")
|
||||
from nn4k.utils.config_parsing import preprocess_config
|
||||
|
||||
nn_config = preprocess_config(nn_config)
|
||||
if nn_config.get("invoker_type", "LLM") == "LLM":
|
||||
|
||||
o = cls.__new__(cls)
|
||||
|
@ -9,70 +9,27 @@
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
from typing import Any
|
||||
from typing import Union
|
||||
|
||||
from nn4k.invoker import NNInvoker
|
||||
|
||||
|
||||
class OpenAIInvoker(NNInvoker):
|
||||
@classmethod
|
||||
def _preprocess_config(cls, nn_config: Union[str, dict]) -> dict:
|
||||
try:
|
||||
if isinstance(nn_config, str):
|
||||
with open(nn_config, "r") as f:
|
||||
nn_config = json.load(f)
|
||||
except:
|
||||
raise ValueError("cannot decode config file")
|
||||
return nn_config
|
||||
|
||||
@classmethod
|
||||
def _get_field(cls, nn_config: dict, name: str, text: str) -> Any:
|
||||
value = nn_config.get(name)
|
||||
if value is None:
|
||||
message = "%s %r not found" % (text, name)
|
||||
raise ValueError(message)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _get_string_field(cls, nn_config: dict, name: str, text: str) -> str:
|
||||
value = cls._get_field(nn_config, name, text)
|
||||
if not isinstance(value, str):
|
||||
message = "%s %r must be string; " % (text, name)
|
||||
message += "%r is invalid" % (value,)
|
||||
raise TypeError(message)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _get_int_field(cls, nn_config: dict, name: str, text: str) -> int:
|
||||
value = cls._get_field(nn_config, name, text)
|
||||
if not isinstance(value, int):
|
||||
message = "%s %r must be integer; " % (text, name)
|
||||
message += "%r is invalid" % (value,)
|
||||
raise TypeError(message)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _get_positive_int_field(cls, nn_config: dict, name: str, text: str) -> int:
|
||||
value = cls._get_int_field(nn_config, name, text)
|
||||
if value <= 0:
|
||||
message = "%s %r must be positive integer; " % (text, name)
|
||||
message += "%r is invalid" % (value,)
|
||||
raise ValueError(message)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _parse_config(cls, nn_config: dict) -> dict:
|
||||
openai_api_key = cls._get_string_field(
|
||||
from nn4k.utils.config_parsing import get_string_field
|
||||
from nn4k.utils.config_parsing import get_positive_int_field
|
||||
|
||||
openai_api_key = get_string_field(
|
||||
nn_config, "openai_api_key", "openai api key"
|
||||
)
|
||||
openai_api_base = cls._get_string_field(
|
||||
openai_api_base = get_string_field(
|
||||
nn_config, "openai_api_base", "openai api base"
|
||||
)
|
||||
openai_model_name = cls._get_string_field(
|
||||
openai_model_name = get_string_field(
|
||||
nn_config, "openai_model_name", "openai model name"
|
||||
)
|
||||
openai_max_tokens = cls._get_positive_int_field(
|
||||
openai_max_tokens = get_positive_int_field(
|
||||
nn_config, "openai_max_tokens", "openai max tokens"
|
||||
)
|
||||
config = dict(
|
||||
@ -86,8 +43,9 @@ class OpenAIInvoker(NNInvoker):
|
||||
@classmethod
|
||||
def from_config(cls, nn_config: Union[str, dict]):
|
||||
import openai
|
||||
from nn4k.utils.config_parsing import preprocess_config
|
||||
|
||||
nn_config = cls._preprocess_config(nn_config)
|
||||
nn_config = preprocess_config(nn_config)
|
||||
config = cls._parse_config(nn_config)
|
||||
|
||||
o = cls.__new__(cls)
|
||||
|
10
python/nn4k/utils/__init__.py
Normal file
10
python/nn4k/utils/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
# Copyright 2023 Ant Group CO., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
56
python/nn4k/utils/config_parsing.py
Normal file
56
python/nn4k/utils/config_parsing.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright 2023 Ant Group CO., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
|
||||
import json
|
||||
|
||||
from typing import Any
|
||||
from typing import Union
|
||||
|
||||
|
||||
def preprocess_config(nn_config: Union[str, dict]) -> dict:
|
||||
try:
|
||||
if isinstance(nn_config, str):
|
||||
with open(nn_config, "r") as f:
|
||||
nn_config = json.load(f)
|
||||
except:
|
||||
raise ValueError("cannot decode config file")
|
||||
return nn_config
|
||||
|
||||
def get_field(nn_config: dict, name: str, text: str) -> Any:
|
||||
value = nn_config.get(name)
|
||||
if value is None:
|
||||
message = "%s %r not found" % (text, name)
|
||||
raise ValueError(message)
|
||||
return value
|
||||
|
||||
def get_string_field(nn_config: dict, name: str, text: str) -> str:
|
||||
value = get_field(nn_config, name, text)
|
||||
if not isinstance(value, str):
|
||||
message = "%s %r must be string; " % (text, name)
|
||||
message += "%r is invalid" % (value,)
|
||||
raise TypeError(message)
|
||||
return value
|
||||
|
||||
def get_int_field(nn_config: dict, name: str, text: str) -> int:
|
||||
value = get_field(nn_config, name, text)
|
||||
if not isinstance(value, int):
|
||||
message = "%s %r must be integer; " % (text, name)
|
||||
message += "%r is invalid" % (value,)
|
||||
raise TypeError(message)
|
||||
return value
|
||||
|
||||
def get_positive_int_field(nn_config: dict, name: str, text: str) -> int:
|
||||
value = get_int_field(nn_config, name, text)
|
||||
if value <= 0:
|
||||
message = "%s %r must be positive integer; " % (text, name)
|
||||
message += "%r is invalid" % (value,)
|
||||
raise ValueError(message)
|
||||
return value
|
Loading…
x
Reference in New Issue
Block a user