mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-11-24 14:07:04 +00:00
### What problem does this PR solve? Refactor import modules. ### Type of change - [x] Refactoring --------- Signed-off-by: jinhai <haijin.chn@gmail.com> Signed-off-by: Jin Hai <haijin.chn@gmail.com>
180 lines
5.4 KiB
Python
180 lines
5.4 KiB
Python
#
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import os
|
|
import io
|
|
import copy
|
|
import logging
|
|
import base64
|
|
import pickle
|
|
import importlib
|
|
|
|
from api.utils import file_utils
|
|
from filelock import FileLock
|
|
from api.utils.common import bytes_to_string, string_to_bytes
|
|
from api.constants import SERVICE_CONF
|
|
|
|
|
|
def conf_realpath(conf_name):
    """Return the path of ``conf/<conf_name>`` joined onto the project base directory.

    Args:
        conf_name: File name (or relative path) inside the ``conf`` directory.

    Returns:
        The joined path string.
    """
    project_root = file_utils.get_project_base_directory()
    return os.path.join(project_root, f"conf/{conf_name}")
|
|
|
|
|
|
def read_config(conf_name=SERVICE_CONF):
    """Load ``conf/<conf_name>`` and overlay ``conf/local.<conf_name>`` on top of it.

    The local override file is optional. When present, its top-level keys replace
    the matching top-level keys of the global file (a shallow ``dict.update``).

    Args:
        conf_name: Name of the YAML config file inside ``conf/``.

    Returns:
        The merged configuration dict.

    Raises:
        ValueError: if either file does not load as a dict.
    """
    local_path = conf_realpath(f'local.{conf_name}')
    overrides = {}
    # The local override file is optional; validate it only when it exists.
    if os.path.exists(local_path):
        overrides = file_utils.load_yaml_conf(local_path)
        if not isinstance(overrides, dict):
            raise ValueError(f'Invalid config file: "{local_path}".')

    global_config_path = conf_realpath(conf_name)
    merged = file_utils.load_yaml_conf(global_config_path)
    if not isinstance(merged, dict):
        raise ValueError(f'Invalid config file: "{global_config_path}".')

    merged.update(overrides)
    return merged
|
|
|
|
|
|
# Merged service configuration, loaded once when this module is imported.
CONFIGS = read_config(conf_name=SERVICE_CONF)
|
|
|
|
|
|
def _mask_config_section(key, value):
    """Return a printable version of config section ``value`` with secrets masked.

    Non-dict values are returned unchanged. A dict value is deep-copied once, so
    the caller's (live) configuration is never modified.

    Masking rules:
      * ``password``, ``access_key``, ``secret_key``, ``secret`` and ``sas_token``
        directly inside the section are replaced with ``"********"``.
      * If ``key`` contains ``"oauth"``, ``client_secret`` is masked inside each
        dict sub-section. If it contains ``"authentication"``, the same is done
        for ``http_secret_key``.

    Args:
        key: Top-level config key naming the section.
        value: The section's value as stored in the configuration.

    Returns:
        ``value`` itself when it is not a dict, otherwise a masked deep copy.
    """
    if not isinstance(value, dict):
        return value

    mask = "*" * 8
    masked = copy.deepcopy(value)
    for secret in ("password", "access_key", "secret_key", "secret", "sas_token"):
        if secret in masked:
            masked[secret] = mask

    # Some sections hold named sub-sections, each carrying its own secret field.
    for marker, secret in (("oauth", "client_secret"),
                           ("authentication", "http_secret_key")):
        if marker not in key:
            continue
        for sub_section in masked.values():
            # Only dict sub-sections can carry the secret. Skip None/str/list
            # entries instead of crashing on `in` / item assignment.
            if isinstance(sub_section, dict) and secret in sub_section:
                sub_section[secret] = mask
    return masked


def show_configs():
    """Log all CONFIGS entries, with secrets masked, as one ``logging.info`` message.

    Each top-level entry is rendered as ``<key>: <value>`` on its own tab-indented
    line. CONFIGS is not modified.
    """
    msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
    for k, v in CONFIGS.items():
        msg += f"\n\t{k}: {_mask_config_section(k, v)}"
    logging.info(msg)
|
|
|
|
|
|
def get_base_config(key, default=None):
    """Look up ``key`` in CONFIGS.

    When ``default`` is ``None``, the environment variable ``key.upper()`` is used
    as the fallback instead.

    Args:
        key: Top-level config key. ``None`` short-circuits to ``None``.
        default: Value returned when ``key`` is absent from CONFIGS.

    Returns:
        The configured value, or the fallback.
    """
    if key is None:
        return None
    fallback = os.environ.get(key.upper()) if default is None else default
    return CONFIGS.get(key, fallback)
|
|
|
|
|
|
def decrypt_database_password(password):
    """Decrypt ``password`` when config-driven password encryption is enabled.

    Reads ``encrypt_password``, ``private_key`` and ``encrypt_module`` from the
    base config. ``encrypt_module`` must look like ``"package.module#function"``.
    That function is imported and called as ``function(private_key, password)``.

    Args:
        password: The (possibly encrypted) password value.

    Returns:
        ``password`` unchanged when it is falsy or ``encrypt_password`` is not
        enabled; otherwise the result of the configured decrypt function.

    Raises:
        ValueError: if encryption is enabled but no private key is configured, or
            ``encrypt_module`` is not a ``"<module>#<function>"`` string.
    """
    # Nothing to decrypt; skip config lookups entirely.
    if not password:
        return password
    if not get_base_config("encrypt_password", False):
        return password

    private_key = get_base_config("private_key", None)
    if not private_key:
        raise ValueError("No private key")

    encrypt_module = get_base_config("encrypt_module", False)
    # Previously a missing/malformed value surfaced as AttributeError/IndexError.
    parts = encrypt_module.split("#") if isinstance(encrypt_module, str) else []
    if len(parts) < 2 or not parts[0] or not parts[1]:
        raise ValueError(
            f'Invalid encrypt_module "{encrypt_module}", expected "<module>#<function>"')

    pwdecrypt_fun = getattr(importlib.import_module(parts[0]), parts[1])
    return pwdecrypt_fun(private_key, password)
|
|
|
|
|
|
def decrypt_database_config(
        database=None, passwd_key="password", name="database"):
    """Return a copy of a database config section with its password decrypted.

    Args:
        database: Config section to process. When falsy, the ``name`` section of
            the base config is used instead.
        passwd_key: Key under which the password is stored.
        name: Base-config key to read when ``database`` is not supplied.

    Returns:
        A shallow copy of the section. If ``passwd_key`` is present, its value is
        replaced by ``decrypt_database_password(...)``. The input dict (and the
        global CONFIGS entry) is left untouched, so repeated calls never decrypt
        an already-decrypted value.
    """
    if not database:
        database = get_base_config(name, {}) or {}
    # Copy so the shared config dict / caller's dict is not mutated in place.
    database = copy.copy(database)
    if passwd_key in database:
        database[passwd_key] = decrypt_database_password(database[passwd_key])
    return database
|
|
|
|
|
|
def update_config(key, value, conf_name=SERVICE_CONF):
    """Set top-level ``key`` to ``value`` in config file ``conf/<conf_name>`` on disk.

    The file is read, modified and rewritten while holding a ``FileLock`` on
    ``.lock`` in the config file's directory. The in-memory CONFIGS is not
    refreshed by this function.

    Args:
        key: Top-level key to set.
        value: New value for ``key``.
        conf_name: Config file name inside ``conf/``.
    """
    path = conf_realpath(conf_name=conf_name)
    if not os.path.isabs(path):
        path = os.path.join(file_utils.get_project_base_directory(), path)

    lock_path = os.path.join(os.path.dirname(path), ".lock")
    with FileLock(lock_path):
        current = file_utils.load_yaml_conf(conf_path=path) or {}
        current[key] = value
        file_utils.rewrite_yaml_conf(conf_path=path, config=current)
|
|
|
|
|
|
# Top-level packages from which RestrictedUnpickler may resolve globals.
# NOTE(review): verify that "rag_flow" still matches a real top-level package
# name in this project.
safe_module = {"numpy", "rag_flow"}
|
|
|
|
|
|
class RestrictedUnpickler(pickle.Unpickler):
    """``pickle.Unpickler`` that resolves globals only from ``safe_module`` packages.

    NOTE(review): the check is on the top-level package name only, so any
    attribute reachable in an allowed package (all of ``numpy``) can be loaded.
    This narrows, but does not eliminate, the risk of unpickling untrusted data.
    """

    def find_class(self, module, name):
        top_level_package = module.partition('.')[0]
        if top_level_package not in safe_module:
            # Forbid everything else.
            raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                         (module, name))
        return getattr(importlib.import_module(module), name)
|
|
|
|
|
|
def restricted_loads(src):
    """Unpickle ``src`` bytes like ``pickle.loads``, but via ``RestrictedUnpickler``."""
    buffer = io.BytesIO(src)
    unpickler = RestrictedUnpickler(buffer)
    return unpickler.load()
|
|
|
|
|
|
def serialize_b64(src, to_str=False):
    """Pickle ``src`` and base64-encode the result.

    Args:
        src: Any picklable object.
        to_str: When true, convert the encoded bytes with ``bytes_to_string``.

    Returns:
        The base64 bytes, or their string form when ``to_str`` is true.
    """
    encoded = base64.b64encode(pickle.dumps(src))
    return bytes_to_string(encoded) if to_str else encoded
|
|
|
|
|
|
def deserialize_b64(src):
    """Base64-decode ``src`` (a ``str`` or bytes) and unpickle the result.

    When the ``use_deserialize_safe_module`` config flag is truthy the payload is
    loaded through ``restricted_loads``; otherwise ``pickle.loads`` is used.

    NOTE(review): with the flag off (the default passed here) this runs plain
    ``pickle.loads``, which can execute arbitrary code. Never pass data from an
    untrusted source.
    """
    if isinstance(src, str):
        src = string_to_bytes(src)
    payload = base64.b64decode(src)

    if not get_base_config('use_deserialize_safe_module', False):
        return pickle.loads(payload)
    return restricted_loads(payload)
|