mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-10-23 22:09:12 +00:00

### What problem does this PR solve? ### Type of change - [x] New Feature (non-breaking change which adds functionality)
386 lines
11 KiB
Python
386 lines
11 KiB
Python
#
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import base64
|
|
import datetime
|
|
import io
|
|
import json
|
|
import os
|
|
import pickle
|
|
import socket
|
|
import time
|
|
import uuid
|
|
import requests
|
|
import logging
|
|
import copy
|
|
from enum import Enum, IntEnum
|
|
import importlib
|
|
from Cryptodome.PublicKey import RSA
|
|
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
|
from filelock import FileLock
|
|
from api.constants import SERVICE_CONF
|
|
|
|
from . import file_utils
|
|
|
|
|
|
def conf_realpath(conf_name):
|
|
conf_path = f"conf/{conf_name}"
|
|
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
|
|
|
|
|
def read_config(conf_name=SERVICE_CONF):
|
|
local_config = {}
|
|
local_path = conf_realpath(f'local.{conf_name}')
|
|
|
|
# load local config file
|
|
if os.path.exists(local_path):
|
|
local_config = file_utils.load_yaml_conf(local_path)
|
|
if not isinstance(local_config, dict):
|
|
raise ValueError(f'Invalid config file: "{local_path}".')
|
|
|
|
global_config_path = conf_realpath(conf_name)
|
|
global_config = file_utils.load_yaml_conf(global_config_path)
|
|
|
|
if not isinstance(global_config, dict):
|
|
raise ValueError(f'Invalid config file: "{global_config_path}".')
|
|
|
|
global_config.update(local_config)
|
|
return global_config
|
|
|
|
|
|
CONFIGS = read_config()
|
|
|
|
|
|
def show_configs():
|
|
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
|
|
for k, v in CONFIGS.items():
|
|
if isinstance(v, dict):
|
|
if "password" in v:
|
|
v = copy.deepcopy(v)
|
|
v["password"] = "*" * 8
|
|
msg += f"\n\t{k}: {v}"
|
|
logging.info(msg)
|
|
|
|
|
|
def get_base_config(key, default=None):
|
|
if key is None:
|
|
return None
|
|
if default is None:
|
|
default = os.environ.get(key.upper())
|
|
return CONFIGS.get(key, default)
|
|
|
|
|
|
use_deserialize_safe_module = get_base_config(
|
|
'use_deserialize_safe_module', False)
|
|
|
|
|
|
class BaseType:
|
|
def to_dict(self):
|
|
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
|
|
|
|
def to_dict_with_type(self):
|
|
def _dict(obj):
|
|
module = None
|
|
if issubclass(obj.__class__, BaseType):
|
|
data = {}
|
|
for attr, v in obj.__dict__.items():
|
|
k = attr.lstrip("_")
|
|
data[k] = _dict(v)
|
|
module = obj.__module__
|
|
elif isinstance(obj, (list, tuple)):
|
|
data = []
|
|
for i, vv in enumerate(obj):
|
|
data.append(_dict(vv))
|
|
elif isinstance(obj, dict):
|
|
data = {}
|
|
for _k, vv in obj.items():
|
|
data[_k] = _dict(vv)
|
|
else:
|
|
data = obj
|
|
return {"type": obj.__class__.__name__,
|
|
"data": data, "module": module}
|
|
|
|
return _dict(self)
|
|
|
|
|
|
class CustomJSONEncoder(json.JSONEncoder):
|
|
def __init__(self, **kwargs):
|
|
self._with_type = kwargs.pop("with_type", False)
|
|
super().__init__(**kwargs)
|
|
|
|
def default(self, obj):
|
|
if isinstance(obj, datetime.datetime):
|
|
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
|
elif isinstance(obj, datetime.date):
|
|
return obj.strftime('%Y-%m-%d')
|
|
elif isinstance(obj, datetime.timedelta):
|
|
return str(obj)
|
|
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
|
|
return obj.value
|
|
elif isinstance(obj, set):
|
|
return list(obj)
|
|
elif issubclass(type(obj), BaseType):
|
|
if not self._with_type:
|
|
return obj.to_dict()
|
|
else:
|
|
return obj.to_dict_with_type()
|
|
elif isinstance(obj, type):
|
|
return obj.__name__
|
|
else:
|
|
return json.JSONEncoder.default(self, obj)
|
|
|
|
|
|
def rag_uuid():
|
|
return uuid.uuid1().hex
|
|
|
|
|
|
def string_to_bytes(string):
|
|
return string if isinstance(
|
|
string, bytes) else string.encode(encoding="utf-8")
|
|
|
|
|
|
def bytes_to_string(byte):
|
|
return byte.decode(encoding="utf-8")
|
|
|
|
|
|
def json_dumps(src, byte=False, indent=None, with_type=False):
|
|
dest = json.dumps(
|
|
src,
|
|
indent=indent,
|
|
cls=CustomJSONEncoder,
|
|
with_type=with_type)
|
|
if byte:
|
|
dest = string_to_bytes(dest)
|
|
return dest
|
|
|
|
|
|
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
|
if isinstance(src, bytes):
|
|
src = bytes_to_string(src)
|
|
return json.loads(src, object_hook=object_hook,
|
|
object_pairs_hook=object_pairs_hook)
|
|
|
|
|
|
def current_timestamp():
|
|
return int(time.time() * 1000)
|
|
|
|
|
|
def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
|
|
if not timestamp:
|
|
timestamp = time.time()
|
|
timestamp = int(timestamp) / 1000
|
|
time_array = time.localtime(timestamp)
|
|
str_date = time.strftime(format_string, time_array)
|
|
return str_date
|
|
|
|
|
|
def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
|
|
time_array = time.strptime(time_str, format_string)
|
|
time_stamp = int(time.mktime(time_array) * 1000)
|
|
return time_stamp
|
|
|
|
|
|
def serialize_b64(src, to_str=False):
|
|
dest = base64.b64encode(pickle.dumps(src))
|
|
if not to_str:
|
|
return dest
|
|
else:
|
|
return bytes_to_string(dest)
|
|
|
|
|
|
def deserialize_b64(src):
|
|
src = base64.b64decode(
|
|
string_to_bytes(src) if isinstance(
|
|
src, str) else src)
|
|
if use_deserialize_safe_module:
|
|
return restricted_loads(src)
|
|
return pickle.loads(src)
|
|
|
|
|
|
safe_module = {
|
|
'numpy',
|
|
'rag_flow'
|
|
}
|
|
|
|
|
|
class RestrictedUnpickler(pickle.Unpickler):
|
|
def find_class(self, module, name):
|
|
import importlib
|
|
if module.split('.')[0] in safe_module:
|
|
_module = importlib.import_module(module)
|
|
return getattr(_module, name)
|
|
# Forbid everything else.
|
|
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
|
(module, name))
|
|
|
|
|
|
def restricted_loads(src):
|
|
"""Helper function analogous to pickle.loads()."""
|
|
return RestrictedUnpickler(io.BytesIO(src)).load()
|
|
|
|
|
|
def get_lan_ip():
|
|
if os.name != "nt":
|
|
import fcntl
|
|
import struct
|
|
|
|
def get_interface_ip(ifname):
|
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
return socket.inet_ntoa(
|
|
fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24])
|
|
|
|
ip = socket.gethostbyname(socket.getfqdn())
|
|
if ip.startswith("127.") and os.name != "nt":
|
|
interfaces = [
|
|
"bond1",
|
|
"eth0",
|
|
"eth1",
|
|
"eth2",
|
|
"wlan0",
|
|
"wlan1",
|
|
"wifi0",
|
|
"ath0",
|
|
"ath1",
|
|
"ppp0",
|
|
]
|
|
for ifname in interfaces:
|
|
try:
|
|
ip = get_interface_ip(ifname)
|
|
break
|
|
except IOError:
|
|
pass
|
|
return ip or ''
|
|
|
|
|
|
def from_dict_hook(in_dict: dict):
|
|
if "type" in in_dict and "data" in in_dict:
|
|
if in_dict["module"] is None:
|
|
return in_dict["data"]
|
|
else:
|
|
return getattr(importlib.import_module(
|
|
in_dict["module"]), in_dict["type"])(**in_dict["data"])
|
|
else:
|
|
return in_dict
|
|
|
|
|
|
def decrypt_database_password(password):
|
|
encrypt_password = get_base_config("encrypt_password", False)
|
|
encrypt_module = get_base_config("encrypt_module", False)
|
|
private_key = get_base_config("private_key", None)
|
|
|
|
if not password or not encrypt_password:
|
|
return password
|
|
|
|
if not private_key:
|
|
raise ValueError("No private key")
|
|
|
|
module_fun = encrypt_module.split("#")
|
|
pwdecrypt_fun = getattr(
|
|
importlib.import_module(
|
|
module_fun[0]),
|
|
module_fun[1])
|
|
|
|
return pwdecrypt_fun(private_key, password)
|
|
|
|
|
|
def decrypt_database_config(
|
|
database=None, passwd_key="password", name="database"):
|
|
if not database:
|
|
database = get_base_config(name, {})
|
|
|
|
database[passwd_key] = decrypt_database_password(database[passwd_key])
|
|
return database
|
|
|
|
|
|
def update_config(key, value, conf_name=SERVICE_CONF):
|
|
conf_path = conf_realpath(conf_name=conf_name)
|
|
if not os.path.isabs(conf_path):
|
|
conf_path = os.path.join(
|
|
file_utils.get_project_base_directory(), conf_path)
|
|
|
|
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
|
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
|
config[key] = value
|
|
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
|
|
|
|
|
def get_uuid():
|
|
return uuid.uuid1().hex
|
|
|
|
|
|
def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
|
|
return datetime.datetime(date_time.year, date_time.month, date_time.day,
|
|
date_time.hour, date_time.minute, date_time.second)
|
|
|
|
|
|
def get_format_time() -> datetime.datetime:
|
|
return datetime_format(datetime.datetime.now())
|
|
|
|
|
|
def str2date(date_time: str):
|
|
return datetime.datetime.strptime(date_time, '%Y-%m-%d')
|
|
|
|
|
|
def elapsed2time(elapsed):
|
|
seconds = elapsed / 1000
|
|
minuter, second = divmod(seconds, 60)
|
|
hour, minuter = divmod(minuter, 60)
|
|
return '%02d:%02d:%02d' % (hour, minuter, second)
|
|
|
|
|
|
def decrypt(line):
|
|
file_path = os.path.join(
|
|
file_utils.get_project_base_directory(),
|
|
"conf",
|
|
"private.pem")
|
|
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
|
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
|
return cipher.decrypt(base64.b64decode(
|
|
line), "Fail to decrypt password!").decode('utf-8')
|
|
|
|
|
|
def decrypt2(crypt_text):
|
|
from base64 import b64decode, b16decode
|
|
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
|
|
from Crypto.PublicKey import RSA
|
|
decode_data = b64decode(crypt_text)
|
|
if len(decode_data) == 127:
|
|
hex_fixed = '00' + decode_data.hex()
|
|
decode_data = b16decode(hex_fixed.upper())
|
|
|
|
file_path = os.path.join(
|
|
file_utils.get_project_base_directory(),
|
|
"conf",
|
|
"private.pem")
|
|
pem = open(file_path).read()
|
|
rsa_key = RSA.importKey(pem, "Welcome")
|
|
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
|
|
decrypt_text = cipher.decrypt(decode_data, None)
|
|
return (b64decode(decrypt_text)).decode()
|
|
|
|
|
|
def download_img(url):
|
|
if not url:
|
|
return ""
|
|
response = requests.get(url)
|
|
return "data:" + \
|
|
response.headers.get('Content-Type', 'image/jpg') + ";" + \
|
|
"base64," + base64.b64encode(response.content).decode("utf-8")
|
|
|
|
|
|
def delta_seconds(date_string: str):
|
|
dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
|
|
return (datetime.datetime.now() - dt).total_seconds()
|