mirror of
				https://github.com/langgenius/dify.git
				synced 2025-10-31 10:53:02 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			137 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			137 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import base64
 | |
| from abc import ABC, abstractmethod
 | |
| from typing import Optional, Union
 | |
| 
 | |
| from core import hosted_llm_credentials
 | |
| from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
 | |
| from extensions.ext_database import db
 | |
| from libs import rsa
 | |
| from models.account import Tenant
 | |
| from models.provider import Provider, ProviderType, ProviderName
 | |
| 
 | |
| 
 | |
class BaseProvider(ABC):
    """
    Base class for an LLM provider integration scoped to a single tenant.

    Subclasses identify themselves through get_provider_name() and implement
    credential/model lookup and config validation. This base class loads the
    tenant's Provider row, applies SYSTEM (hosted) quota checks, and handles
    encryption/decryption of the stored provider config.
    """

    def __init__(self, tenant_id: str):
        self.tenant_id = tenant_id

    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str, dict]:
        """
        Returns the API key (or decrypted config) for this tenant and provider.

        For a SYSTEM provider the hosted credentials are returned; for any other
        provider type the row's encrypted_config is decrypted and returned.

        :param model_id: the model being requested; 'gpt-4' is rejected on SYSTEM providers.
        :param prefer_custom: forwarded to get_provider().
        :raises ProviderTokenNotInitError: no valid provider row exists, or hosted credentials are unavailable.
        :raises ModelCurrentlyNotSupportError: model_id is 'gpt-4' and the provider is SYSTEM.
        :raises QuotaExceededError: the SYSTEM provider's quota_used has reached its quota_limit.
        """
        provider = self.get_provider(prefer_custom)
        if not provider:
            raise ProviderTokenNotInitError()

        if provider.provider_type == ProviderType.SYSTEM.value:
            # Missing quota values count as 0, so a SYSTEM provider without an
            # explicit quota_limit is always treated as exhausted (0 >= 0).
            quota_used = provider.quota_used if provider.quota_used is not None else 0
            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0

            # The hosted (SYSTEM) credentials are not used to serve gpt-4.
            if model_id == 'gpt-4':
                raise ModelCurrentlyNotSupportError()

            if quota_used >= quota_limit:
                raise QuotaExceededError()

            return self.get_hosted_credentials()

        return self.get_decrypted_token(provider.encrypted_config)

    def get_provider(self, prefer_custom: bool) -> Optional['Provider']:
        """
        Returns the valid Provider row for this tenant and this class's provider name.

        Delegates to BaseProvider.get_valid_provider(). A valid CUSTOM row is
        returned in preference to a SYSTEM row regardless of prefer_custom.
        """
        return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)

    @classmethod
    def get_valid_provider(cls, tenant_id: str, provider_name: Optional[str] = None,
                           prefer_custom: bool = False) -> Optional['Provider']:
        """
        Returns a valid Provider row for tenant_id (filtered by provider_name if given).

        A CUSTOM row is valid when is_valid is set and it has an encrypted_config;
        a SYSTEM row only needs is_valid. If any valid CUSTOM row exists it is
        returned; otherwise a valid SYSTEM row; otherwise None. When several
        valid rows of the same type match (possible when provider_name is
        omitted), the last one iterated is kept.

        prefer_custom only flips the ORDER BY direction on provider_type; it does
        not change which type is returned.
        NOTE(review): confirm that always preferring CUSTOM is the intended behavior.
        """
        query = db.session.query(Provider).filter(Provider.tenant_id == tenant_id)
        if provider_name:
            query = query.filter(Provider.provider_name == provider_name)

        ordering = Provider.provider_type.desc() if prefer_custom else Provider.provider_type
        providers = query.order_by(ordering).all()

        custom_provider = None
        system_provider = None
        for provider in providers:
            if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
                custom_provider = provider
            elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
                system_provider = provider

        # CUSTOM wins over SYSTEM; None when neither was found.
        return custom_provider or system_provider

    def get_hosted_credentials(self) -> str:
        """
        Returns the hosted OpenAI API key.

        Only the OpenAI provider has hosted credentials here.

        :raises ProviderTokenNotInitError: this provider is not OpenAI, or no hosted OpenAI key is configured.
        """
        if self.get_provider_name() != ProviderName.OPENAI:
            raise ProviderTokenNotInitError()

        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
            raise ProviderTokenNotInitError()

        return hosted_llm_credentials.openai.api_key

    def get_provider_configs(self, obfuscated: bool = False) -> Union[str, dict]:
        """
        Returns the provider configs, or '' when they cannot be loaded.

        Any Exception raised by get_provider_api_key() (e.g. ProviderTokenNotInitError,
        QuotaExceededError) is treated as "not configured".

        :param obfuscated: if True, the config is passed through obfuscated_token().
        """
        try:
            config = self.get_provider_api_key()
        except Exception:
            # Best-effort lookup. Catching Exception (not a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            config = ''

        if obfuscated:
            return self.obfuscated_token(config)

        return config

    def obfuscated_token(self, token: str):
        """
        Returns token masked for display.

        The first 6 and last 2 characters are kept, and the rest become '*'.
        Tokens of 8 characters or fewer are masked entirely, because keeping
        6 + 2 characters would otherwise reveal the whole secret.
        """
        if len(token) <= 8:
            return '*' * len(token)
        return token[:6] + '*' * (len(token) - 8) + token[-2:]

    def get_token_type(self):
        """Returns the Python type of this provider's token (str by default)."""
        return str

    def get_encrypted_token(self, config: Union[dict, str]):
        """Returns config encrypted via encrypt_token()."""
        return self.encrypt_token(config)

    def get_decrypted_token(self, token: str):
        """Returns token decrypted via decrypt_token()."""
        return self.decrypt_token(token)

    def encrypt_token(self, token):
        """
        Encrypts token with this tenant's encrypt_public_key via libs.rsa.

        Returns the ciphertext base64-encoded as a str.
        NOTE(review): if no Tenant row matches self.tenant_id, .first() yields None
        and this raises AttributeError.
        """
        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
        return base64.b64encode(encrypted_token).decode()

    def decrypt_token(self, token):
        """
        Base64-decodes token and decrypts it via libs.rsa.

        The tenant id is passed to rsa.decrypt (presumably to locate the
        tenant's private key — confirm in libs.rsa).
        """
        return rsa.decrypt(base64.b64decode(token), self.tenant_id)

    @abstractmethod
    def get_provider_name(self):
        """Returns this provider's ProviderName member."""
        raise NotImplementedError

    @abstractmethod
    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """Returns the credentials dict to use for model_id."""
        raise NotImplementedError

    @abstractmethod
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        """Returns the models available from this provider."""
        raise NotImplementedError

    @abstractmethod
    def config_validate(self, config: str):
        """Validates a provider config supplied by the user."""
        raise NotImplementedError
