mirror of
				https://github.com/infiniflow/ragflow.git
				synced 2025-11-04 03:39:41 +00:00 
			
		
		
		
	### What problem does this PR solve? Bump infinity to v0.5.0-dev5 ### Type of change - [x] Refactoring
		
			
				
	
	
		
			361 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			361 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#
 | 
						|
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
 | 
						|
#
 | 
						|
#  Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
#  you may not use this file except in compliance with the License.
 | 
						|
#  You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
#  Unless required by applicable law or agreed to in writing, software
 | 
						|
#  distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
#  See the License for the specific language governing permissions and
 | 
						|
#  limitations under the License.
 | 
						|
#
 | 
						|
import base64
 | 
						|
import datetime
 | 
						|
import io
 | 
						|
import json
 | 
						|
import os
 | 
						|
import pickle
 | 
						|
import socket
 | 
						|
import time
 | 
						|
import uuid
 | 
						|
import requests
 | 
						|
import logging
 | 
						|
from enum import Enum, IntEnum
 | 
						|
import importlib
 | 
						|
from Cryptodome.PublicKey import RSA
 | 
						|
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
 | 
						|
from filelock import FileLock
 | 
						|
from api.constants import SERVICE_CONF
 | 
						|
 | 
						|
from . import file_utils
 | 
						|
 | 
						|
 | 
						|
def conf_realpath(conf_name):
 | 
						|
    conf_path = f"conf/{conf_name}"
 | 
						|
    return os.path.join(file_utils.get_project_base_directory(), conf_path)
 | 
						|
 | 
						|
 | 
						|
def read_config(conf_name=SERVICE_CONF):
 | 
						|
    local_config = {}
 | 
						|
    local_path = conf_realpath(f'local.{conf_name}')
 | 
						|
 | 
						|
    # load local config file
 | 
						|
    if os.path.exists(local_path):
 | 
						|
        local_config = file_utils.load_yaml_conf(local_path)
 | 
						|
        if not isinstance(local_config, dict):
 | 
						|
            raise ValueError(f'Invalid config file: "{local_path}".')
 | 
						|
 | 
						|
    global_config_path = conf_realpath(conf_name)
 | 
						|
    global_config = file_utils.load_yaml_conf(global_config_path)
 | 
						|
 | 
						|
    if not isinstance(global_config, dict):
 | 
						|
        raise ValueError(f'Invalid config file: "{global_config_path}".')
 | 
						|
 | 
						|
    global_config.update(local_config)
 | 
						|
    return global_config
 | 
						|
 | 
						|
 | 
						|
CONFIGS = read_config()
 | 
						|
 | 
						|
 | 
						|
def show_configs():
 | 
						|
    msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
 | 
						|
    for k, v in CONFIGS.items():
 | 
						|
        msg += f"\n\t{k}: {v}"
 | 
						|
    logging.info(msg)
 | 
						|
 | 
						|
 | 
						|
def get_base_config(key, default=None):
 | 
						|
    if key is None:
 | 
						|
        return None
 | 
						|
    if default is None:
 | 
						|
        default = os.environ.get(key.upper())
 | 
						|
    return CONFIGS.get(key, default)
 | 
						|
 | 
						|
 | 
						|
use_deserialize_safe_module = get_base_config(
 | 
						|
    'use_deserialize_safe_module', False)
 | 
						|
 | 
						|
 | 
						|
class BaseType:
 | 
						|
    def to_dict(self):
 | 
						|
        return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
 | 
						|
 | 
						|
    def to_dict_with_type(self):
 | 
						|
        def _dict(obj):
 | 
						|
            module = None
 | 
						|
            if issubclass(obj.__class__, BaseType):
 | 
						|
                data = {}
 | 
						|
                for attr, v in obj.__dict__.items():
 | 
						|
                    k = attr.lstrip("_")
 | 
						|
                    data[k] = _dict(v)
 | 
						|
                module = obj.__module__
 | 
						|
            elif isinstance(obj, (list, tuple)):
 | 
						|
                data = []
 | 
						|
                for i, vv in enumerate(obj):
 | 
						|
                    data.append(_dict(vv))
 | 
						|
            elif isinstance(obj, dict):
 | 
						|
                data = {}
 | 
						|
                for _k, vv in obj.items():
 | 
						|
                    data[_k] = _dict(vv)
 | 
						|
            else:
 | 
						|
                data = obj
 | 
						|
            return {"type": obj.__class__.__name__,
 | 
						|
                    "data": data, "module": module}
 | 
						|
 | 
						|
        return _dict(self)
 | 
						|
 | 
						|
 | 
						|
class CustomJSONEncoder(json.JSONEncoder):
 | 
						|
    def __init__(self, **kwargs):
 | 
						|
        self._with_type = kwargs.pop("with_type", False)
 | 
						|
        super().__init__(**kwargs)
 | 
						|
 | 
						|
    def default(self, obj):
 | 
						|
        if isinstance(obj, datetime.datetime):
 | 
						|
            return obj.strftime('%Y-%m-%d %H:%M:%S')
 | 
						|
        elif isinstance(obj, datetime.date):
 | 
						|
            return obj.strftime('%Y-%m-%d')
 | 
						|
        elif isinstance(obj, datetime.timedelta):
 | 
						|
            return str(obj)
 | 
						|
        elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
 | 
						|
            return obj.value
 | 
						|
        elif isinstance(obj, set):
 | 
						|
            return list(obj)
 | 
						|
        elif issubclass(type(obj), BaseType):
 | 
						|
            if not self._with_type:
 | 
						|
                return obj.to_dict()
 | 
						|
            else:
 | 
						|
                return obj.to_dict_with_type()
 | 
						|
        elif isinstance(obj, type):
 | 
						|
            return obj.__name__
 | 
						|
        else:
 | 
						|
            return json.JSONEncoder.default(self, obj)
 | 
						|
 | 
						|
 | 
						|
def rag_uuid():
 | 
						|
    return uuid.uuid1().hex
 | 
						|
 | 
						|
 | 
						|
def string_to_bytes(string):
 | 
						|
    return string if isinstance(
 | 
						|
        string, bytes) else string.encode(encoding="utf-8")
 | 
						|
 | 
						|
 | 
						|
def bytes_to_string(byte):
 | 
						|
    return byte.decode(encoding="utf-8")
 | 
						|
 | 
						|
 | 
						|
def json_dumps(src, byte=False, indent=None, with_type=False):
 | 
						|
    dest = json.dumps(
 | 
						|
        src,
 | 
						|
        indent=indent,
 | 
						|
        cls=CustomJSONEncoder,
 | 
						|
        with_type=with_type)
 | 
						|
    if byte:
 | 
						|
        dest = string_to_bytes(dest)
 | 
						|
    return dest
 | 
						|
 | 
						|
 | 
						|
def json_loads(src, object_hook=None, object_pairs_hook=None):
 | 
						|
    if isinstance(src, bytes):
 | 
						|
        src = bytes_to_string(src)
 | 
						|
    return json.loads(src, object_hook=object_hook,
 | 
						|
                      object_pairs_hook=object_pairs_hook)
 | 
						|
 | 
						|
 | 
						|
def current_timestamp():
 | 
						|
    return int(time.time() * 1000)
 | 
						|
 | 
						|
 | 
						|
def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
 | 
						|
    if not timestamp:
 | 
						|
        timestamp = time.time()
 | 
						|
    timestamp = int(timestamp) / 1000
 | 
						|
    time_array = time.localtime(timestamp)
 | 
						|
    str_date = time.strftime(format_string, time_array)
 | 
						|
    return str_date
 | 
						|
 | 
						|
 | 
						|
def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
 | 
						|
    time_array = time.strptime(time_str, format_string)
 | 
						|
    time_stamp = int(time.mktime(time_array) * 1000)
 | 
						|
    return time_stamp
 | 
						|
 | 
						|
 | 
						|
def serialize_b64(src, to_str=False):
 | 
						|
    dest = base64.b64encode(pickle.dumps(src))
 | 
						|
    if not to_str:
 | 
						|
        return dest
 | 
						|
    else:
 | 
						|
        return bytes_to_string(dest)
 | 
						|
 | 
						|
 | 
						|
def deserialize_b64(src):
 | 
						|
    src = base64.b64decode(
 | 
						|
        string_to_bytes(src) if isinstance(
 | 
						|
            src, str) else src)
 | 
						|
    if use_deserialize_safe_module:
 | 
						|
        return restricted_loads(src)
 | 
						|
    return pickle.loads(src)
 | 
						|
 | 
						|
 | 
						|
safe_module = {
 | 
						|
    'numpy',
 | 
						|
    'rag_flow'
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
class RestrictedUnpickler(pickle.Unpickler):
 | 
						|
    def find_class(self, module, name):
 | 
						|
        import importlib
 | 
						|
        if module.split('.')[0] in safe_module:
 | 
						|
            _module = importlib.import_module(module)
 | 
						|
            return getattr(_module, name)
 | 
						|
        # Forbid everything else.
 | 
						|
        raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
 | 
						|
                                     (module, name))
 | 
						|
 | 
						|
 | 
						|
def restricted_loads(src):
 | 
						|
    """Helper function analogous to pickle.loads()."""
 | 
						|
    return RestrictedUnpickler(io.BytesIO(src)).load()
 | 
						|
 | 
						|
 | 
						|
def get_lan_ip():
 | 
						|
    if os.name != "nt":
 | 
						|
        import fcntl
 | 
						|
        import struct
 | 
						|
 | 
						|
        def get_interface_ip(ifname):
 | 
						|
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 | 
						|
            return socket.inet_ntoa(
 | 
						|
                fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24])
 | 
						|
 | 
						|
    ip = socket.gethostbyname(socket.getfqdn())
 | 
						|
    if ip.startswith("127.") and os.name != "nt":
 | 
						|
        interfaces = [
 | 
						|
            "bond1",
 | 
						|
            "eth0",
 | 
						|
            "eth1",
 | 
						|
            "eth2",
 | 
						|
            "wlan0",
 | 
						|
            "wlan1",
 | 
						|
            "wifi0",
 | 
						|
            "ath0",
 | 
						|
            "ath1",
 | 
						|
            "ppp0",
 | 
						|
        ]
 | 
						|
        for ifname in interfaces:
 | 
						|
            try:
 | 
						|
                ip = get_interface_ip(ifname)
 | 
						|
                break
 | 
						|
            except IOError:
 | 
						|
                pass
 | 
						|
    return ip or ''
 | 
						|
 | 
						|
 | 
						|
def from_dict_hook(in_dict: dict):
 | 
						|
    if "type" in in_dict and "data" in in_dict:
 | 
						|
        if in_dict["module"] is None:
 | 
						|
            return in_dict["data"]
 | 
						|
        else:
 | 
						|
            return getattr(importlib.import_module(
 | 
						|
                in_dict["module"]), in_dict["type"])(**in_dict["data"])
 | 
						|
    else:
 | 
						|
        return in_dict
 | 
						|
 | 
						|
 | 
						|
def decrypt_database_password(password):
 | 
						|
    encrypt_password = get_base_config("encrypt_password", False)
 | 
						|
    encrypt_module = get_base_config("encrypt_module", False)
 | 
						|
    private_key = get_base_config("private_key", None)
 | 
						|
 | 
						|
    if not password or not encrypt_password:
 | 
						|
        return password
 | 
						|
 | 
						|
    if not private_key:
 | 
						|
        raise ValueError("No private key")
 | 
						|
 | 
						|
    module_fun = encrypt_module.split("#")
 | 
						|
    pwdecrypt_fun = getattr(
 | 
						|
        importlib.import_module(
 | 
						|
            module_fun[0]),
 | 
						|
        module_fun[1])
 | 
						|
 | 
						|
    return pwdecrypt_fun(private_key, password)
 | 
						|
 | 
						|
 | 
						|
def decrypt_database_config(
 | 
						|
        database=None, passwd_key="password", name="database"):
 | 
						|
    if not database:
 | 
						|
        database = get_base_config(name, {})
 | 
						|
 | 
						|
    database[passwd_key] = decrypt_database_password(database[passwd_key])
 | 
						|
    return database
 | 
						|
 | 
						|
 | 
						|
def update_config(key, value, conf_name=SERVICE_CONF):
 | 
						|
    conf_path = conf_realpath(conf_name=conf_name)
 | 
						|
    if not os.path.isabs(conf_path):
 | 
						|
        conf_path = os.path.join(
 | 
						|
            file_utils.get_project_base_directory(), conf_path)
 | 
						|
 | 
						|
    with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
 | 
						|
        config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
 | 
						|
        config[key] = value
 | 
						|
        file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
 | 
						|
 | 
						|
 | 
						|
def get_uuid():
 | 
						|
    return uuid.uuid1().hex
 | 
						|
 | 
						|
 | 
						|
def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
 | 
						|
    return datetime.datetime(date_time.year, date_time.month, date_time.day,
 | 
						|
                             date_time.hour, date_time.minute, date_time.second)
 | 
						|
 | 
						|
 | 
						|
def get_format_time() -> datetime.datetime:
 | 
						|
    return datetime_format(datetime.datetime.now())
 | 
						|
 | 
						|
 | 
						|
def str2date(date_time: str):
 | 
						|
    return datetime.datetime.strptime(date_time, '%Y-%m-%d')
 | 
						|
 | 
						|
 | 
						|
def elapsed2time(elapsed):
 | 
						|
    seconds = elapsed / 1000
 | 
						|
    minuter, second = divmod(seconds, 60)
 | 
						|
    hour, minuter = divmod(minuter, 60)
 | 
						|
    return '%02d:%02d:%02d' % (hour, minuter, second)
 | 
						|
 | 
						|
 | 
						|
def decrypt(line):
 | 
						|
    file_path = os.path.join(
 | 
						|
        file_utils.get_project_base_directory(),
 | 
						|
        "conf",
 | 
						|
        "private.pem")
 | 
						|
    rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
 | 
						|
    cipher = Cipher_pkcs1_v1_5.new(rsa_key)
 | 
						|
    return cipher.decrypt(base64.b64decode(
 | 
						|
        line), "Fail to decrypt password!").decode('utf-8')
 | 
						|
 | 
						|
 | 
						|
def download_img(url):
 | 
						|
    if not url:
 | 
						|
        return ""
 | 
						|
    response = requests.get(url)
 | 
						|
    return "data:" + \
 | 
						|
        response.headers.get('Content-Type', 'image/jpg') + ";" + \
 | 
						|
        "base64," + base64.b64encode(response.content).decode("utf-8")
 | 
						|
 | 
						|
 | 
						|
def delta_seconds(date_string: str):
 | 
						|
    dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
 | 
						|
    return (datetime.datetime.now() - dt).total_seconds()
 |