mirror of
https://github.com/langgenius/dify.git
synced 2025-07-29 20:40:19 +00:00
82 lines
3.3 KiB
Python
82 lines
3.3 KiB
Python
from typing import Optional
|
|
|
|
from configs import dify_config
|
|
from core.mcp.types import (
|
|
OAuthClientInformation,
|
|
OAuthClientInformationFull,
|
|
OAuthClientMetadata,
|
|
OAuthTokens,
|
|
)
|
|
from models.tools import MCPToolProvider
|
|
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
|
|
|
LATEST_PROTOCOL_VERSION = "1.0"
|
|
|
|
|
|
class OAuthClientProvider:
|
|
mcp_provider: MCPToolProvider
|
|
|
|
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
|
|
if for_list:
|
|
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
|
else:
|
|
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
|
|
|
|
@property
|
|
def redirect_url(self) -> str:
|
|
"""The URL to redirect the user agent to after authorization."""
|
|
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
|
|
|
@property
|
|
def client_metadata(self) -> OAuthClientMetadata:
|
|
"""Metadata about this OAuth client."""
|
|
return OAuthClientMetadata(
|
|
redirect_uris=[self.redirect_url],
|
|
token_endpoint_auth_method="none",
|
|
grant_types=["authorization_code", "refresh_token"],
|
|
response_types=["code"],
|
|
client_name="Dify",
|
|
client_uri="https://github.com/langgenius/dify",
|
|
)
|
|
|
|
def client_information(self) -> Optional[OAuthClientInformation]:
|
|
"""Loads information about this OAuth client."""
|
|
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
|
|
if not client_information:
|
|
return None
|
|
return OAuthClientInformation.model_validate(client_information)
|
|
|
|
def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
|
|
"""Saves client information after dynamic registration."""
|
|
MCPToolManageService.update_mcp_provider_credentials(
|
|
self.mcp_provider,
|
|
{"client_information": client_information.model_dump()},
|
|
)
|
|
|
|
def tokens(self) -> Optional[OAuthTokens]:
|
|
"""Loads any existing OAuth tokens for the current session."""
|
|
credentials = self.mcp_provider.decrypted_credentials
|
|
if not credentials:
|
|
return None
|
|
return OAuthTokens(
|
|
access_token=credentials.get("access_token", ""),
|
|
token_type=credentials.get("token_type", "Bearer"),
|
|
expires_in=int(credentials.get("expires_in", "3600") or 3600),
|
|
refresh_token=credentials.get("refresh_token", ""),
|
|
)
|
|
|
|
def save_tokens(self, tokens: OAuthTokens) -> None:
|
|
"""Stores new OAuth tokens for the current session."""
|
|
# update mcp provider credentials
|
|
token_dict = tokens.model_dump()
|
|
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
|
|
|
def save_code_verifier(self, code_verifier: str) -> None:
|
|
"""Saves a PKCE code verifier for the current session."""
|
|
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
|
|
|
def code_verifier(self) -> str:
|
|
"""Loads the PKCE code verifier for the current session."""
|
|
# get code verifier from mcp provider credentials
|
|
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
|