Mirror of https://github.com/langgenius/dify.git — synced 2025-10-31 19:03:09 +00:00 · Python · 109 lines · 3.5 KiB
|   | import os | ||
|  | import shutil | ||
|  | from contextlib import closing | ||
|  | 
 | ||
|  | import boto3 | ||
|  | from botocore.exceptions import ClientError | ||
|  | from flask import Flask | ||
|  | 
 | ||
|  | 
 | ||
class Storage:
    """Store and retrieve files in either an S3 bucket or a local directory.

    ``init_app`` picks the backend from the ``STORAGE_TYPE`` config value.
    ``'s3'`` selects S3; any other value selects the local filesystem rooted
    at ``STORAGE_LOCAL_PATH``.
    """

    def __init__(self):
        # Every field stays None until init_app() is called.
        self.storage_type = None  # 's3', or anything else for local storage
        self.bucket_name = None   # S3 bucket name (S3 backend only)
        self.client = None        # shared boto3 S3 client (S3 backend only)
        self.folder = None        # root directory (local backend only)

    def init_app(self, app: Flask):
        """Configure the storage backend from ``app.config``.

        For S3, build a single boto3 client that every operation reuses. For
        local storage, resolve a relative ``STORAGE_LOCAL_PATH`` against
        ``app.root_path``.
        """
        self.storage_type = app.config.get('STORAGE_TYPE')
        if self.storage_type == 's3':
            self.bucket_name = app.config.get('S3_BUCKET_NAME')
            self.client = boto3.client(
                's3',
                aws_secret_access_key=app.config.get('S3_SECRET_KEY'),
                aws_access_key_id=app.config.get('S3_ACCESS_KEY'),
                endpoint_url=app.config.get('S3_ENDPOINT'),
                region_name=app.config.get('S3_REGION')
            )
        else:
            self.folder = app.config.get('STORAGE_LOCAL_PATH')
            if not os.path.isabs(self.folder):
                self.folder = os.path.join(app.root_path, self.folder)

    def _local_path(self, filename):
        """Return the local filesystem path of ``filename`` under ``self.folder``."""
        # Concatenate directly when the folder is empty or already ends with
        # '/'; otherwise insert a single '/' separator.
        if not self.folder or self.folder.endswith('/'):
            return self.folder + filename
        return self.folder + '/' + filename

    def save(self, filename, data):
        """Write ``data`` under the key / relative path ``filename``.

        ``data`` is written in binary mode locally, so it is expected to be
        bytes-like (NOTE(review): confirm for all callers).
        """
        if self.storage_type == 's3':
            self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
        else:
            filepath = self._local_path(filename)
            directory = os.path.dirname(filepath)
            # os.makedirs('') raises, so only create a directory when the
            # path actually has one.
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(filepath, "wb") as f:
                f.write(data)

    def load(self, filename):
        """Return the full contents of ``filename`` as bytes.

        Raises:
            FileNotFoundError: the S3 key does not exist (``NoSuchKey``) or
                the local path does not exist.
            botocore.exceptions.ClientError: any other S3 error.
        """
        if self.storage_type == 's3':
            # The shared client is used directly. Closing it here would tear
            # down its connection pool for every later request.
            try:
                data = self.client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
            except ClientError as ex:
                if ex.response['Error']['Code'] == 'NoSuchKey':
                    raise FileNotFoundError("File not found") from ex
                raise
            return data

        filepath = self._local_path(filename)
        if not os.path.exists(filepath):
            raise FileNotFoundError("File not found")
        with open(filepath, "rb") as f:
            return f.read()

    def download(self, filename, target_filepath):
        """Copy the stored ``filename`` to the local path ``target_filepath``.

        Raises:
            FileNotFoundError: local backend only, when ``filename`` is missing.
        """
        if self.storage_type == 's3':
            self.client.download_file(self.bucket_name, filename, target_filepath)
        else:
            filepath = self._local_path(filename)
            if not os.path.exists(filepath):
                raise FileNotFoundError("File not found")
            shutil.copyfile(filepath, target_filepath)

    def exists(self, filename):
        """Return True if ``filename`` is present in the configured backend.

        For S3, any ``ClientError`` from ``head_object`` (e.g. 404, or 403 when
        listing is not permitted) is reported as "does not exist". Transport or
        configuration errors are no longer masked as a missing file; they
        propagate to the caller.
        """
        if self.storage_type == 's3':
            try:
                self.client.head_object(Bucket=self.bucket_name, Key=filename)
            except ClientError:
                return False
            return True

        return os.path.exists(self._local_path(filename))
# Module-level Storage instance. It stays unconfigured until init_app() below
# is called with the Flask application.
storage = Storage()


def init_app(app: Flask):
    """Configure the shared ``storage`` instance from ``app``'s config.

    Delegates entirely to ``Storage.init_app``.
    """
    storage.init_app(app)