# dify/api/services/provider_service.py
#
# 89 lines
# 3.5 KiB
# Python
# Raw Normal View History
#
# 2023-05-15 08:51:32 +08:00
from typing import Union
from flask import current_app
from core.llm.provider.llm_provider_service import LLMProviderService
from models.account import Tenant
from models.provider import *
class ProviderService:
    """Tenant-scoped helpers for model-provider records and provider credentials.

    Most methods are thin wrappers: they build an ``LLMProviderService`` for the
    given ``(tenant.id, provider_name.value)`` pair and delegate to it. The two
    methods that touch the database directly (``init_supported_provider`` and
    ``create_system_provider``) commit through ``db.session``.
    """

    @staticmethod
    def init_supported_provider(tenant):
        """Initialize the model provider, check whether the supported provider has a record.

        Ensures the tenant has a CUSTOM-type ``Provider`` row for each of
        OpenAI, Azure OpenAI and Anthropic. Missing rows are created with
        ``is_valid=False`` and committed together. If every row already exists,
        nothing is added and no commit is issued.

        :param tenant: the tenant to initialize; only ``tenant.id`` is read.
        """
        need_init_provider_names = [
            ProviderName.OPENAI.value,
            ProviderName.AZURE_OPENAI.value,
            ProviderName.ANTHROPIC.value,
        ]

        providers = db.session.query(Provider).filter(
            Provider.tenant_id == tenant.id,
            Provider.provider_type == ProviderType.CUSTOM.value,
            Provider.provider_name.in_(need_init_provider_names)
        ).all()

        exists_provider_names = {provider.provider_name for provider in providers}

        # Preserve the declared order. A set difference would insert the rows
        # in an arbitrary order that can change from run to run.
        not_exists_provider_names = [
            name for name in need_init_provider_names if name not in exists_provider_names
        ]

        if not_exists_provider_names:
            for provider_name in not_exists_provider_names:
                provider = Provider(
                    tenant_id=tenant.id,
                    provider_name=provider_name,
                    provider_type=ProviderType.CUSTOM.value,
                    is_valid=False
                )
                db.session.add(provider)

            db.session.commit()

    @staticmethod
    def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
        """Return the provider's configs with obfuscation enabled.

        Delegates to ``LLMProviderService.get_provider_configs(obfuscated=True, ...)``.
        The result is presumably safe to show in the UI; confirm in LLMProviderService.

        :param tenant: the owning tenant; only ``tenant.id`` is read.
        :param provider_name: the provider enum member; its ``.value`` is passed on.
        :param only_custom: forwarded unchanged to ``get_provider_configs``.
        """
        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)

    @staticmethod
    def get_token_type(tenant, provider_name: ProviderName):
        """Return whatever ``LLMProviderService.get_token_type()`` reports for this provider.

        :param tenant: the owning tenant; only ``tenant.id`` is read.
        :param provider_name: the provider enum member; its ``.value`` is passed on.
        """
        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.get_token_type()

    @staticmethod
    def validate_provider_configs(tenant, provider_name: ProviderName, configs: Union[dict, str]):
        """Validate provider configs, unless validation is disabled in app config.

        If ``DISABLE_PROVIDER_CONFIG_VALIDATION`` is truthy in the Flask app
        config, this returns ``None`` without validating. Otherwise it returns
        the result of ``LLMProviderService.config_validate(configs)``. Any
        exception that call raises propagates unchanged.

        :param tenant: the owning tenant; only ``tenant.id`` is read.
        :param provider_name: the provider enum member; its ``.value`` is passed on.
        :param configs: the raw provider configuration, either a dict or a string.
        """
        if current_app.config['DISABLE_PROVIDER_CONFIG_VALIDATION']:
            return

        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.config_validate(configs)

    @staticmethod
    def get_encrypted_token(tenant, provider_name: ProviderName, configs: Union[dict, str]):
        """Return ``LLMProviderService.get_encrypted_token(configs)`` for this provider.

        Presumably this encrypts the credentials in ``configs`` for storage;
        verify in LLMProviderService.

        :param tenant: the owning tenant; only ``tenant.id`` is read.
        :param provider_name: the provider enum member; its ``.value`` is passed on.
        :param configs: the raw provider configuration, either a dict or a string.
        """
        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.get_encrypted_token(configs)

    @staticmethod
    def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
                               is_valid: bool = True):
        """Create the tenant's SYSTEM-type trial provider row if it does not exist yet.

        This only acts when the app config ``EDITION`` is ``'CLOUD'``; on any
        other edition it returns immediately. If a SYSTEM row for
        ``(tenant, provider_name)`` already exists, it is left untouched, so
        ``quota_limit`` and ``is_valid`` are not applied to it.

        :param tenant: the owning tenant; only ``tenant.id`` is read.
        :param provider_name: the provider name string (not the enum member).
        :param quota_limit: the TRIAL quota limit stored on a newly created row.
        :param is_valid: the validity flag stored on a newly created row.
        """
        if current_app.config['EDITION'] != 'CLOUD':
            return

        provider = db.session.query(Provider).filter(
            Provider.tenant_id == tenant.id,
            Provider.provider_name == provider_name,
            Provider.provider_type == ProviderType.SYSTEM.value
        ).one_or_none()

        if provider is not None:
            return

        provider = Provider(
            tenant_id=tenant.id,
            provider_name=provider_name,
            provider_type=ProviderType.SYSTEM.value,
            quota_type=ProviderQuotaType.TRIAL.value,
            quota_limit=quota_limit,
            encrypted_config='',
            is_valid=is_valid,
        )
        db.session.add(provider)
        db.session.commit()