mirror of
				https://github.com/infiniflow/ragflow.git
				synced 2025-10-31 09:50:00 +00:00 
			
		
		
		
	 e38e3bcc3b
			
		
	
	
		e38e3bcc3b
		
			
		
	
	
	
	
		
			
			### What problem does this PR solve?
			Mask passwords in logs.
			### Type of change
			- [x] Bug Fix (non-breaking change which fixes an issue)
		
			
				
	
	
		
			366 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			366 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #
 | |
| #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
 | |
| #
 | |
| #  Licensed under the Apache License, Version 2.0 (the "License");
 | |
| #  you may not use this file except in compliance with the License.
 | |
| #  You may obtain a copy of the License at
 | |
| #
 | |
| #      http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| #  Unless required by applicable law or agreed to in writing, software
 | |
| #  distributed under the License is distributed on an "AS IS" BASIS,
 | |
| #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| #  See the License for the specific language governing permissions and
 | |
| #  limitations under the License.
 | |
| #
 | |
| import base64
 | |
| import datetime
 | |
| import io
 | |
| import json
 | |
| import os
 | |
| import pickle
 | |
| import socket
 | |
| import time
 | |
| import uuid
 | |
| import requests
 | |
| import logging
 | |
| import copy
 | |
| from enum import Enum, IntEnum
 | |
| import importlib
 | |
| from Cryptodome.PublicKey import RSA
 | |
| from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
 | |
| from filelock import FileLock
 | |
| from api.constants import SERVICE_CONF
 | |
| 
 | |
| from . import file_utils
 | |
| 
 | |
| 
 | |
def conf_realpath(conf_name):
    """Build the path of ``conf_name`` inside the project's ``conf/`` directory.

    The base directory comes from ``file_utils.get_project_base_directory()``;
    the resulting file is not checked for existence.
    """
    base_dir = file_utils.get_project_base_directory()
    return os.path.join(base_dir, f"conf/{conf_name}")
 | |
| 
 | |
| 
 | |
def read_config(conf_name=SERVICE_CONF):
    """Load the service configuration, letting ``local.<conf_name>`` override it.

    The global file ``conf_name`` is always read; the ``local.`` variant is read
    only if it exists. Top-level keys of the local file replace those of the
    global one (a shallow merge: nested sections are replaced, not merged).

    Raises:
        ValueError: if either loaded file does not yield a dict.
    """
    local_path = conf_realpath(f'local.{conf_name}')
    overrides = {}
    if os.path.exists(local_path):
        overrides = file_utils.load_yaml_conf(local_path)
        if not isinstance(overrides, dict):
            raise ValueError(f'Invalid config file: "{local_path}".')

    global_config_path = conf_realpath(conf_name)
    merged = file_utils.load_yaml_conf(global_config_path)
    if not isinstance(merged, dict):
        raise ValueError(f'Invalid config file: "{global_config_path}".')

    merged.update(overrides)
    return merged


# Configuration loaded once at import time; the helpers below read from it.
CONFIGS = read_config()
 | |
| 
 | |
| 
 | |
# Lower-cased config keys whose values must never reach the logs.
_SENSITIVE_CONFIG_KEYS = frozenset({
    "password",
    "passwd",
    "secret",
    "secret_key",
    "access_key",
    "api_key",
    "client_secret",
    "sas_token",
    "token",
    "private_key",
})


def _mask_secrets(value):
    """Return a copy of ``value`` with the values of sensitive keys replaced by ``"********"``.

    Dicts, lists and tuples are rebuilt recursively, so secrets nested at any
    depth are masked; the input itself is never mutated. Other values are
    returned unchanged.
    """
    if isinstance(value, dict):
        masked = {}
        for key, item in value.items():
            if isinstance(key, str) and key.lower() in _SENSITIVE_CONFIG_KEYS:
                masked[key] = "*" * 8
            else:
                masked[key] = _mask_secrets(item)
        return masked
    if isinstance(value, list):
        return [_mask_secrets(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_mask_secrets(item) for item in value)
    return value


def show_configs():
    """Log every top-level entry of ``CONFIGS`` with credentials masked.

    Previously only a ``password`` key directly inside a top-level section was
    hidden; secrets under other key names or in nested sections were logged
    in clear text. ``CONFIGS`` itself is not modified.
    """
    msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
    for k, v in _mask_secrets(CONFIGS).items():
        msg += f"\n\t{k}: {v}"
    logging.info(msg)
 | |
| 
 | |
| 
 | |
def get_base_config(key, default=None):
    """Look up ``key`` in the loaded service configuration.

    When ``default`` is not given (``None``), the environment variable named
    ``key.upper()`` is used as the fallback for a key missing from ``CONFIGS``.
    A ``None`` key always yields ``None``.
    """
    if key is None:
        return None
    fallback = os.environ.get(key.upper()) if default is None else default
    return CONFIGS.get(key, fallback)


# When truthy, deserialize_b64() unpickles through restricted_loads() instead of
# plain pickle.loads(). NOTE(review): the default is False, i.e. the
# unrestricted path; only appropriate for trusted payloads.
use_deserialize_safe_module = get_base_config(
    'use_deserialize_safe_module', False)
 | |
| 
 | |
| 
 | |
class BaseType:
    """Mixin giving plain objects dict / JSON-friendly representations.

    Leading underscores are stripped from attribute names, so ``self._name``
    is exposed as ``"name"``.
    """

    def to_dict(self):
        """Return the instance attributes as a shallow ``dict``."""
        return {attr.lstrip("_"): value for attr, value in self.__dict__.items()}

    def to_dict_with_type(self):
        """Return a recursive, type-annotated representation of the instance.

        Every value is wrapped as ``{"type": <class name>, "data": ...,
        "module": ...}``. ``"module"`` is the defining module for ``BaseType``
        instances and ``None`` for everything else. Lists and tuples become
        lists of wrapped items; dicts keep their keys with wrapped values.
        """

        def _dict(obj):
            module = None
            if issubclass(obj.__class__, BaseType):
                data = {attr.lstrip("_"): _dict(v) for attr, v in obj.__dict__.items()}
                module = obj.__module__
            elif isinstance(obj, (list, tuple)):
                data = [_dict(item) for item in obj]
            elif isinstance(obj, dict):
                data = {k: _dict(v) for k, v in obj.items()}
            else:
                data = obj
            return {"type": obj.__class__.__name__, "data": data, "module": module}

        return _dict(self)
 | |
| 
 | |
| 
 | |
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that also handles dates, timedeltas, enums, sets, classes and ``BaseType``.

    Pass ``with_type=True`` (e.g. ``json.dumps(obj, cls=CustomJSONEncoder,
    with_type=True)``) to serialise ``BaseType`` objects through
    ``to_dict_with_type()`` rather than ``to_dict()``.
    """

    def __init__(self, *, with_type=False, **kwargs):
        self._with_type = with_type
        super().__init__(**kwargs)

    def default(self, obj):
        # datetime must be tested before date, because datetime subclasses date.
        if isinstance(obj, datetime.datetime):
            return obj.strftime('%Y-%m-%d %H:%M:%S')
        if isinstance(obj, datetime.date):
            return obj.strftime('%Y-%m-%d')
        if isinstance(obj, datetime.timedelta):
            return str(obj)
        if isinstance(obj, Enum):  # also covers IntEnum, an Enum subclass
            return obj.value
        if isinstance(obj, set):
            return list(obj)
        if isinstance(obj, BaseType):
            return obj.to_dict_with_type() if self._with_type else obj.to_dict()
        if isinstance(obj, type):
            return obj.__name__
        return json.JSONEncoder.default(self, obj)
 | |
| 
 | |
| 
 | |
def rag_uuid():
    """Return a new time-based (version 1) UUID as a 32-character hex string."""
    generated = uuid.uuid1()
    return generated.hex
 | |
| 
 | |
| 
 | |
def string_to_bytes(string):
    """Encode ``string`` as UTF-8; ``bytes`` input is returned unchanged."""
    if isinstance(string, bytes):
        return string
    return string.encode(encoding="utf-8")
 | |
| 
 | |
| 
 | |
def bytes_to_string(byte):
    """Decode the UTF-8 bytes ``byte`` into a ``str``."""
    text = byte.decode(encoding="utf-8")
    return text
 | |
| 
 | |
| 
 | |
def json_dumps(src, byte=False, indent=None, with_type=False):
    """Serialise ``src`` to JSON with ``CustomJSONEncoder``.

    Returns a ``str``, or UTF-8 ``bytes`` when ``byte`` is true. ``with_type``
    is forwarded to the encoder to select typed ``BaseType`` output.
    """
    text = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type)
    return string_to_bytes(text) if byte else text
 | |
| 
 | |
| 
 | |
def json_loads(src, object_hook=None, object_pairs_hook=None):
    """Parse JSON from a ``str`` or UTF-8 ``bytes``, forwarding both hooks to ``json.loads``."""
    text = bytes_to_string(src) if isinstance(src, bytes) else src
    return json.loads(text, object_hook=object_hook, object_pairs_hook=object_pairs_hook)
 | |
| 
 | |
| 
 | |
def current_timestamp():
    """Return the current Unix time in whole milliseconds."""
    millis = time.time() * 1000
    return int(millis)
 | |
| 
 | |
| 
 | |
def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
    """Format a millisecond Unix timestamp as a local-time string.

    Args:
        timestamp: Milliseconds since the epoch (anything ``int()`` accepts).
            A falsy value (``None``, ``0``, ``""``) means "now".
        format_string: ``time.strftime`` format to apply.

    Returns:
        The formatted local date/time string.
    """
    if not timestamp:
        # time.time() is in seconds; scale it to milliseconds so the division
        # below yields seconds again (previously "now" rendered as January 1970).
        timestamp = time.time() * 1000
    seconds = int(timestamp) / 1000
    return time.strftime(format_string, time.localtime(seconds))
 | |
| 
 | |
| 
 | |
def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
    """Parse ``time_str`` as local time and return milliseconds since the epoch."""
    seconds = time.mktime(time.strptime(time_str, format_string))
    return int(seconds * 1000)
 | |
| 
 | |
| 
 | |
def serialize_b64(src, to_str=False):
    """Pickle ``src`` and Base64-encode the result.

    Returns ``bytes`` by default, or a ``str`` when ``to_str`` is true.
    """
    encoded = base64.b64encode(pickle.dumps(src))
    return bytes_to_string(encoded) if to_str else encoded
 | |
| 
 | |
| 
 | |
def deserialize_b64(src):
    """Base64-decode ``src`` (``str`` or bytes-like) and unpickle the payload.

    The payload goes through ``restricted_loads`` when the module-level
    ``use_deserialize_safe_module`` flag is truthy, otherwise through plain
    ``pickle.loads``.

    WARNING: plain ``pickle.loads`` can execute arbitrary code embedded in the
    payload; only pass data from a trusted source.
    """
    raw = string_to_bytes(src) if isinstance(src, str) else src
    payload = base64.b64decode(raw)
    loader = restricted_loads if use_deserialize_safe_module else pickle.loads
    return loader(payload)
 | |
| 
 | |
| 
 | |
# Top-level package names whose globals RestrictedUnpickler may resolve.
# NOTE(review): every global of an allowed package is reachable, which is broad
# (numpy in particular exposes far more than array reconstruction); confirm the
# allow-list is tight enough for the data being loaded.
safe_module = {
    'numpy',
    'rag_flow'
}


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals from packages listed in ``safe_module``."""

    def find_class(self, module, name):
        top_level_package = module.partition('.')[0]
        if top_level_package not in safe_module:
            # Anything outside the allow-list is rejected outright.
            raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                         (module, name))
        return getattr(importlib.import_module(module), name)


def restricted_loads(src):
    """Unpickle the bytes ``src`` like ``pickle.loads``, but via ``RestrictedUnpickler``."""
    return RestrictedUnpickler(io.BytesIO(src)).load()
 | |
| 
 | |
| 
 | |
def get_lan_ip():
    """Best-effort lookup of this host's IPv4 address.

    The fully qualified host name is resolved first. If that gives a loopback
    address (``127.*``) on a non-Windows system, a fixed list of common
    interface names is probed via ``fcntl.ioctl`` and the first interface
    that answers without ``OSError`` wins. The result falls back to ``''``
    only if the resolved address is empty.
    """
    if os.name != "nt":
        import fcntl
        import struct

        def get_interface_ip(ifname):
            # The probe socket is closed even when the ioctl raises; the
            # previous version leaked one descriptor per probed interface.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                # NOTE(review): 0x8915 is presumably Linux SIOCGIFADDR.
                ifreq = fcntl.ioctl(
                    s.fileno(), 0x8915,
                    struct.pack('256s', string_to_bytes(ifname[:15])))
            # Bytes 20..24 of the returned request buffer are taken as the IPv4 address.
            return socket.inet_ntoa(ifreq[20:24])

    ip = socket.gethostbyname(socket.getfqdn())
    if ip.startswith("127.") and os.name != "nt":
        interfaces = [
            "bond1",
            "eth0",
            "eth1",
            "eth2",
            "wlan0",
            "wlan1",
            "wifi0",
            "ath0",
            "ath1",
            "ppp0",
        ]
        for ifname in interfaces:
            try:
                ip = get_interface_ip(ifname)
                break
            except OSError:  # IOError is an alias of OSError
                pass
    return ip or ''
 | |
| 
 | |
| 
 | |
def from_dict_hook(in_dict: dict):
    """``json.loads`` object hook that rebuilds values tagged by ``to_dict_with_type``.

    A dict that carries all of ``"type"``, ``"data"`` and ``"module"`` is decoded:
    with ``module`` ``None`` the raw ``data`` is returned, otherwise
    ``getattr(import_module(module), type)(**data)`` is constructed. Any other
    dict is returned unchanged. Previously, a dict with ``type`` and ``data``
    but no ``module`` raised ``KeyError``.

    SECURITY: the module and class names come from the JSON itself and are
    imported and called, so only use this hook on trusted input.
    """
    if not all(key in in_dict for key in ("type", "data", "module")):
        return in_dict
    if in_dict["module"] is None:
        return in_dict["data"]
    target = getattr(importlib.import_module(in_dict["module"]), in_dict["type"])
    return target(**in_dict["data"])
 | |
| 
 | |
| 
 | |
def decrypt_database_password(password):
    """Decrypt a configured database password when password encryption is enabled.

    Reads the ``encrypt_password``, ``encrypt_module`` and ``private_key``
    settings via ``get_base_config``. ``encrypt_module`` must look like
    ``"package.module#function"``; that function is called as
    ``function(private_key, password)``.

    Returns:
        ``password`` unchanged when it is empty or encryption is disabled,
        otherwise the value returned by the configured decrypt function.

    Raises:
        ValueError: if encryption is enabled but no private key is configured,
            or ``encrypt_module`` is not a ``"module#function"`` string.
    """
    encrypt_password = get_base_config("encrypt_password", False)
    encrypt_module = get_base_config("encrypt_module", False)
    private_key = get_base_config("private_key", None)

    if not password or not encrypt_password:
        return password

    if not private_key:
        raise ValueError("No private key")

    module_fun = encrypt_module.split("#") if isinstance(encrypt_module, str) else []
    if len(module_fun) < 2 or not module_fun[0] or not module_fun[1]:
        # Fail clearly instead of with an AttributeError (setting unset, so it
        # defaults to False) or an IndexError (no "#" separator).
        raise ValueError(
            f'Invalid encrypt_module "{encrypt_module}": expected "module#function"')
    pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1])

    return pwdecrypt_fun(private_key, password)
 | |
| 
 | |
| 
 | |
def decrypt_database_config(
        database=None, passwd_key="password", name="database"):
    """Return a database config dict with its password decrypted.

    Args:
        database: Config dict to use; when falsy, the ``name`` section of the
            service configuration is used instead.
        passwd_key: Key holding the (possibly encrypted) password.
        name: Config section to read when ``database`` is not given.

    Returns:
        A shallow copy of the config with ``passwd_key`` passed through
        ``decrypt_database_password``. The source dict (which may be the
        shared ``CONFIGS`` section) is left untouched, so a repeated call never
        tries to decrypt an already-decrypted password. A config without
        ``passwd_key`` is returned as a plain copy instead of raising
        ``KeyError``.
    """
    if not database:
        database = get_base_config(name, {})

    database = copy.copy(database)
    if passwd_key in database:
        database[passwd_key] = decrypt_database_password(database[passwd_key])
    return database
 | |
| 
 | |
| 
 | |
def update_config(key, value, conf_name=SERVICE_CONF):
    """Write ``key: value`` into the YAML config file ``conf_name``.

    The file is loaded (an empty or missing document becomes ``{}``), updated
    and rewritten while holding a ``FileLock`` on a ``.lock`` file in the same
    directory. Only the file on disk changes; ``CONFIGS`` is not touched.
    """
    conf_path = conf_realpath(conf_name=conf_name)
    if not os.path.isabs(conf_path):
        conf_path = os.path.join(file_utils.get_project_base_directory(), conf_path)

    lock_path = os.path.join(os.path.dirname(conf_path), ".lock")
    with FileLock(lock_path):
        config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
        config[key] = value
        file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
 | |
| 
 | |
| 
 | |
def get_uuid():
    """Return a fresh version-1 UUID rendered as 32 lowercase hex digits."""
    return str(uuid.uuid1()).replace("-", "")
 | |
| 
 | |
| 
 | |
def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
    """Truncate ``date_time`` to whole seconds.

    Returns a new naive ``datetime`` built from the calendar fields up to the
    second; microseconds and any ``tzinfo`` are dropped.
    """
    fields = (date_time.year, date_time.month, date_time.day,
              date_time.hour, date_time.minute, date_time.second)
    return datetime.datetime(*fields)
 | |
| 
 | |
| 
 | |
def get_format_time() -> datetime.datetime:
    """Return the current local time truncated to whole seconds, via ``datetime_format``."""
    now = datetime.datetime.now()
    return datetime_format(now)
 | |
| 
 | |
| 
 | |
def str2date(date_time: str):
    """Parse a ``YYYY-MM-DD`` string into a ``datetime`` at midnight."""
    date_format = '%Y-%m-%d'
    return datetime.datetime.strptime(date_time, date_format)
 | |
| 
 | |
| 
 | |
def elapsed2time(elapsed):
    """Format a duration given in milliseconds as ``HH:MM:SS``.

    Fractional seconds are truncated by ``%d``; hours are not wrapped at 24.
    """
    total_minutes, seconds = divmod(elapsed / 1000, 60)
    hours, minutes = divmod(total_minutes, 60)
    return '%02d:%02d:%02d' % (hours, minutes, seconds)
 | |
| 
 | |
| 
 | |
def decrypt(line):
    """Decrypt a Base64-encoded RSA (PKCS#1 v1.5) ciphertext into a UTF-8 string.

    The private key is read from ``conf/private.pem`` under the project base
    directory and imported with the passphrase ``"Welcome"``.

    Raises:
        ValueError: if the ciphertext cannot be decrypted with that key.
    """
    file_path = os.path.join(
        file_utils.get_project_base_directory(),
        "conf",
        "private.pem")
    # Close the key file deterministically instead of leaking the handle.
    with open(file_path) as key_file:
        rsa_key = RSA.importKey(key_file.read(), "Welcome")
    cipher = Cipher_pkcs1_v1_5.new(rsa_key)
    # PKCS1_v1_5.decrypt() returns the sentinel instead of raising on failure.
    # The former str sentinel made the .decode() below fail with an
    # AttributeError; use a unique object and raise a clear error instead.
    failed = object()
    plaintext = cipher.decrypt(base64.b64decode(line), failed)
    if plaintext is failed:
        raise ValueError("Fail to decrypt password!")
    return plaintext.decode('utf-8')
 | |
| 
 | |
| 
 | |
def download_img(url, timeout=30):
    """Fetch an image and return it as a Base64 ``data:`` URL.

    Args:
        url: Image URL. A falsy value returns ``""`` without any request.
        timeout: Seconds passed to ``requests.get``. Without a timeout, a
            stalled server could block the caller indefinitely.

    Returns:
        ``"data:<content-type>;base64,<payload>"``, using the response's
        ``Content-Type`` header, or ``image/jpeg`` when it is absent.
    """
    if not url:
        return ""
    response = requests.get(url, timeout=timeout)
    # "image/jpg" is not a registered MIME type; "image/jpeg" is.
    content_type = response.headers.get('Content-Type', 'image/jpeg')
    payload = base64.b64encode(response.content).decode("utf-8")
    return f"data:{content_type};base64,{payload}"
 | |
| 
 | |
| 
 | |
def delta_seconds(date_string: str):
    """Return the seconds elapsed from ``date_string`` (``%Y-%m-%d %H:%M:%S``) until now.

    Both the parsed value and ``datetime.now()`` are naive local datetimes.
    """
    then = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
    elapsed = datetime.datetime.now() - then
    return elapsed.total_seconds()
 |