dify/api/core/mcp/auth/auth_flow.py
Novice 6be013e072
feat: implement RFC-compliant OAuth discovery with dynamic scope selection for MCP providers (#28294)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-11-20 11:18:16 +08:00

706 lines
25 KiB
Python

import base64
import hashlib
import json
import os
import secrets
import urllib.parse
from urllib.parse import urljoin, urlparse
import httpx
from httpx import RequestError
from pydantic import ValidationError
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
from core.helper import ssrf_proxy
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
from core.mcp.error import MCPRefreshTokenError
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
ProtectedResourceMetadata,
)
from extensions.ext_redis import redis_client
# Time-to-live, in seconds, of a pending OAuth authorization state stored in Redis.
OAUTH_STATE_EXPIRY_SECONDS = 300  # 5 minutes
# Prefix applied to every Redis key that holds a pending OAuth state.
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
def generate_pkce_challenge() -> tuple[str, str]:
    """Generate a PKCE (RFC 7636) code verifier and its S256 code challenge.

    Returns:
        ``(code_verifier, code_challenge)``. Both are unpadded base64url strings.
        The verifier encodes 40 random bytes (54 characters, within RFC 7636's
        43..128 range). The challenge is ``BASE64URL(SHA256(verifier))``.
    """
    # token_urlsafe already yields the base64url alphabet ("-" and "_") without "=" padding,
    # so no character substitution is needed.
    code_verifier = secrets.token_urlsafe(40)
    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
    code_challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")
    return code_verifier, code_challenge
def build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url: str | None, server_url: str
) -> list[str]:
"""
Build a list of URLs to try for Protected Resource Metadata discovery.
Per SEP-985, supports fallback when discovery fails at one URL.
"""
urls = []
# First priority: URL from WWW-Authenticate header
if www_auth_resource_metadata_url:
urls.append(www_auth_resource_metadata_url)
# Fallback: construct from server URL
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
if fallback_url not in urls:
urls.append(fallback_url)
return urls
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
"""
Build a list of URLs to try for OAuth Authorization Server Metadata discovery.
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
Per RFC 8414 section 3:
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
Example:
- issuer: https://example.com/oauth
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
"""
urls = []
base_url = auth_server_url or server_url
parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/") # Remove trailing slash
# Try OpenID Connect discovery first (more common)
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
# Include the path component if present in the issuer URL
if path:
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
else:
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
return urls
def discover_protected_resource_metadata(
    prm_url: str | None, server_url: str, protocol_version: str | None = None
) -> ProtectedResourceMetadata | None:
    """Discover OAuth 2.0 Protected Resource Metadata (RFC 9728).

    Tries each URL from ``build_protected_resource_metadata_discovery_urls`` in order and
    returns the first document that answers HTTP 200 and validates as
    ``ProtectedResourceMetadata``.

    Args:
        prm_url: Resource metadata URL from the WWW-Authenticate header, if any.
        server_url: The MCP server URL, used to build the well-known fallbacks.
        protocol_version: MCP protocol version header value; defaults to ``LATEST_PROTOCOL_VERSION``.

    Returns:
        The parsed metadata, or ``None`` if no URL produced a valid document.
    """
    urls = build_protected_resource_metadata_discovery_urls(prm_url, server_url)
    headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}

    for url in urls:
        try:
            response = ssrf_proxy.get(url, headers=headers)
            if response.status_code != 200:
                continue  # 404 or any other status: try next URL
            return ProtectedResourceMetadata.model_validate(response.json())
        except (RequestError, ValidationError, json.JSONDecodeError):
            # Network failure, schema mismatch, or a non-JSON 200 body (e.g. an HTML page):
            # treat it as "not found here" and fall back to the next URL.
            continue
    return None
def discover_oauth_authorization_server_metadata(
    auth_server_url: str | None, server_url: str, protocol_version: str | None = None
) -> OAuthMetadata | None:
    """Discover OAuth 2.0 Authorization Server Metadata (RFC 8414 / OpenID Connect Discovery).

    Tries each URL from ``build_oauth_authorization_server_metadata_discovery_urls`` in
    order and returns the first document that answers HTTP 200 and validates as
    ``OAuthMetadata``.

    Args:
        auth_server_url: Authorization server issuer URL (e.g. from protected resource metadata), if known.
        server_url: The MCP server URL, used as the issuer when ``auth_server_url`` is empty.
        protocol_version: MCP protocol version header value; defaults to ``LATEST_PROTOCOL_VERSION``.

    Returns:
        The parsed metadata, or ``None`` if no URL produced a valid document.
    """
    urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
    headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}

    for url in urls:
        try:
            response = ssrf_proxy.get(url, headers=headers)
            if response.status_code != 200:
                continue  # 404 or any other status: try next URL
            return OAuthMetadata.model_validate(response.json())
        except (RequestError, ValidationError, json.JSONDecodeError):
            # Network failure, schema mismatch, or a non-JSON 200 body: fall back to the next URL.
            continue
    return None
def get_effective_scope(
    scope_from_www_auth: str | None,
    prm: ProtectedResourceMetadata | None,
    asm: OAuthMetadata | None,
    client_scope: str | None,
) -> str | None:
    """
    Pick the scope to request, taking the first non-empty source in priority order.

    1. The scope from the WWW-Authenticate header (an explicit server requirement).
    2. ``scopes_supported`` from the Protected Resource Metadata, space-joined.
    3. ``scopes_supported`` from the Authorization Server Metadata, space-joined.
    4. The scope configured on the client, returned as-is (may be ``None``).
    """
    if scope_from_www_auth:
        return scope_from_www_auth

    for metadata in (prm, asm):
        advertised = metadata.scopes_supported if metadata else None
        if advertised:
            return " ".join(advertised)

    return client_scope
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
    """Persist *state_data* in Redis under a fresh random key and return that key.

    The key is ``secrets.token_urlsafe(32)`` and is used as the OAuth ``state`` parameter.
    The entry expires after ``OAUTH_STATE_EXPIRY_SECONDS``.
    """
    token = secrets.token_urlsafe(32)
    redis_client.setex(
        OAUTH_STATE_REDIS_KEY_PREFIX + token,
        OAUTH_STATE_EXPIRY_SECONDS,
        state_data.model_dump_json(),
    )
    return token
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
    """Load the OAuth callback state stored under *state_key*, deleting it so it is single-use.

    Args:
        state_key: The ``state`` value returned by the authorization server.

    Returns:
        The validated ``OAuthCallbackState``.

    Raises:
        ValueError: If no state exists for the key (expired, never issued, or already
            consumed), or if the stored payload does not validate as ``OAuthCallbackState``.
    """
    redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"

    state_data = redis_client.get(redis_key)
    if not state_data:
        raise ValueError("State parameter has expired or does not exist")

    # Delete before parsing, so a state is consumed even if its payload turns out to be invalid.
    # NOTE(review): GET followed by DELETE is not atomic. Two concurrent callbacks carrying the
    # same state could both read it before either deletes it. Consider GETDEL (Redis >= 6.2)
    # or a MULTI/EXEC pipeline if strict single use is required.
    redis_client.delete(redis_key)

    try:
        return OAuthCallbackState.model_validate_json(state_data)
    except ValidationError as e:
        raise ValueError(f"Invalid state parameter: {e}") from e
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
    """
    Complete an OAuth redirect: consume the stored state and trade the code for tokens.

    Args:
        state_key: The ``state`` parameter echoed back by the authorization server.
        authorization_code: The ``code`` parameter from the redirect.

    Returns:
        ``(callback_state, tokens)`` so the caller can persist whatever it needs.
    """
    # The Redis entry is deleted as part of this lookup, so the state cannot be replayed.
    callback_state = _retrieve_redis_state(state_key)
    issued_tokens = exchange_authorization(
        callback_state.server_url,
        callback_state.metadata,
        callback_state.client_information,
        authorization_code,
        callback_state.code_verifier,
        callback_state.redirect_uri,
    )
    return callback_state, issued_tokens
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
    """Check whether the server publishes OAuth 2.0 Protected Resource Metadata.

    Requests ``/.well-known/oauth-protected-resource`` on the server's origin. The query
    and fragment of ``server_url`` are carried over.

    Returns:
        ``(True, authorization_server_url)`` when the document names an authorization
        server, otherwise ``(False, "")``. Both the plural ``authorization_servers`` field
        and the singular ``authorization_server_url`` field are accepted, each as either
        a list or a plain string.
    """
    b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
    url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
    if b_query:
        url_for_resource_discovery += f"?{b_query}"
    if b_fragment:
        url_for_resource_discovery += f"#{b_fragment}"

    try:
        headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
        response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
        if not 200 <= response.status_code < 300:
            return False, ""
        body = response.json()
    except (RequestError, json.JSONDecodeError, UnicodeDecodeError):
        # Resource discovery not supported (unreachable, or the body is not JSON):
        # the caller falls back to well-known OAuth metadata.
        return False, ""

    if not isinstance(body, dict):
        return False, ""

    for key in ("authorization_servers", "authorization_server_url"):
        value = body.get(key)
        if isinstance(value, list):
            value = value[0] if value else None
        # A bare string is the URL itself. Indexing it with [0] would yield only its first character.
        if isinstance(value, str) and value:
            return True, value
    return False, ""
def discover_oauth_metadata(
    server_url: str,
    resource_metadata_url: str | None = None,
    scope_hint: str | None = None,
    protocol_version: str | None = None,
) -> tuple[OAuthMetadata | None, ProtectedResourceMetadata | None, str | None]:
    """
    Run OAuth discovery: protected resource metadata first, then authorization server metadata.

    Args:
        server_url: The MCP server URL.
        resource_metadata_url: Protected Resource Metadata URL from the WWW-Authenticate header.
        scope_hint: Scope hint from the WWW-Authenticate header; passed through unchanged.
        protocol_version: MCP protocol version.

    Returns:
        ``(oauth_metadata, protected_resource_metadata, scope_hint)``. Either metadata
        value may be ``None`` when its discovery found nothing.
    """
    prm = discover_protected_resource_metadata(resource_metadata_url, server_url, protocol_version)

    # The first advertised authorization server is used as the issuer. Without one,
    # the lookup below falls back to server_url.
    auth_server_url = prm.authorization_servers[0] if prm and prm.authorization_servers else None

    asm = discover_oauth_authorization_server_metadata(auth_server_url, server_url, protocol_version)
    return asm, prm, scope_hint
def start_authorization(
    server_url: str,
    metadata: OAuthMetadata | None,
    client_information: OAuthClientInformation,
    redirect_url: str,
    provider_id: str,
    tenant_id: str,
    scope: str | None = None,
) -> tuple[str, str]:
    """Build the authorization-code + PKCE (S256) redirect URL and stash the callback state in Redis.

    Returns:
        ``(authorization_url, code_verifier)``.

    Raises:
        ValueError: If the server metadata does not list ``"code"`` in ``response_types_supported``.
    """
    response_type = "code"

    if not metadata:
        # No discovered metadata: assume the conventional endpoint on the server origin.
        endpoint = urljoin(server_url, "/authorize")
    elif response_type in metadata.response_types_supported:
        endpoint = metadata.authorization_endpoint
    else:
        raise ValueError(f"Incompatible auth server: does not support response type {response_type}")

    code_verifier, code_challenge = generate_pkce_challenge()

    # Everything the callback needs to finish the exchange lives in Redis. Only the
    # random key travels through the browser as "state".
    state_key = _create_secure_redis_state(
        OAuthCallbackState(
            provider_id=provider_id,
            tenant_id=tenant_id,
            server_url=server_url,
            metadata=metadata,
            client_information=client_information,
            code_verifier=code_verifier,
            redirect_uri=redirect_url,
        )
    )

    query = {
        "response_type": response_type,
        "client_id": client_information.client_id,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
        "redirect_uri": redirect_url,
        "state": state_key,
    }
    if scope:
        query["scope"] = scope

    return f"{endpoint}?{urllib.parse.urlencode(query)}", code_verifier
def _parse_token_response(response: httpx.Response) -> OAuthTokens:
    """
    Decode a token-endpoint response as JSON or as form-urlencoded data.

    RFC 6749 section 5.1 specifies JSON. Some legacy providers (e.g. early GitHub
    OAuth Apps) answer with ``application/x-www-form-urlencoded`` instead, so that
    format is accepted too. With a missing or unrecognised Content-Type, JSON is
    attempted first and form decoding is the fallback.

    Raises:
        ValueError: If neither format yields valid ``OAuthTokens`` (pydantic's
            ``ValidationError`` is a ``ValueError``).
    """
    content_type = response.headers.get("content-type", "").lower()

    def _from_form() -> OAuthTokens:
        # Form bodies arrive as key=value pairs; pydantic coerces numeric strings such as expires_in.
        return OAuthTokens.model_validate(dict(urllib.parse.parse_qsl(response.text)))

    if "application/json" in content_type:
        return OAuthTokens.model_validate(response.json())
    if "application/x-www-form-urlencoded" in content_type:
        return _from_form()

    try:
        return OAuthTokens.model_validate(response.json())
    except (ValidationError, json.JSONDecodeError):
        return _from_form()
def exchange_authorization(
    server_url: str,
    metadata: OAuthMetadata | None,
    client_information: OAuthClientInformation,
    authorization_code: str,
    code_verifier: str,
    redirect_uri: str,
) -> OAuthTokens:
    """Trade an authorization code (plus its PKCE verifier) for tokens at the token endpoint.

    Raises:
        ValueError: If the metadata lists grant types that exclude ``authorization_code``,
            or the token endpoint returns a non-success status.
    """
    grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value

    if not metadata:
        token_url = urljoin(server_url, "/token")
    else:
        advertised = metadata.grant_types_supported
        if advertised and grant_type not in advertised:
            raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
        token_url = metadata.token_endpoint

    form = {
        "grant_type": grant_type,
        "client_id": client_information.client_id,
        "code": authorization_code,
        "code_verifier": code_verifier,
        "redirect_uri": redirect_uri,
    }
    # Confidential clients authenticate with the secret in the request body.
    secret = client_information.client_secret
    if secret:
        form["client_secret"] = secret

    response = ssrf_proxy.post(token_url, data=form)
    if response.is_success:
        return _parse_token_response(response)
    raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
def refresh_authorization(
    server_url: str,
    metadata: OAuthMetadata | None,
    client_information: OAuthClientInformation,
    refresh_token: str,
) -> OAuthTokens:
    """Exchange a refresh token for an updated access token.

    Per RFC 6749 section 6, the server MAY issue a new refresh token. When the response
    omits one, the refresh token that was just used is carried into the returned tokens.
    This keeps later refreshes possible after the result is saved.

    Raises:
        ValueError: If the metadata lists grant types that exclude ``refresh_token``.
        MCPRefreshTokenError: If the request exhausts retries or returns a non-success status.
    """
    grant_type = MCPSupportGrantType.REFRESH_TOKEN.value

    if metadata:
        token_url = metadata.token_endpoint
        if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
            raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
    else:
        token_url = urljoin(server_url, "/token")

    params = {
        "grant_type": grant_type,
        "client_id": client_information.client_id,
        "refresh_token": refresh_token,
    }
    if client_information.client_secret:
        params["client_secret"] = client_information.client_secret

    try:
        response = ssrf_proxy.post(token_url, data=params)
    except ssrf_proxy.MaxRetriesExceededError as e:
        raise MCPRefreshTokenError(e) from e

    if not response.is_success:
        raise MCPRefreshTokenError(response.text)

    tokens = _parse_token_response(response)
    if not tokens.refresh_token:
        # No rotated refresh token was issued; the old one remains valid and must not be lost.
        tokens = tokens.model_copy(update={"refresh_token": refresh_token})
    return tokens
def client_credentials_flow(
    server_url: str,
    metadata: OAuthMetadata | None,
    client_information: OAuthClientInformation,
    scope: str | None = None,
) -> OAuthTokens:
    """Execute the Client Credentials grant to obtain an access token without user interaction.

    The client authenticates with HTTP Basic when a ``client_secret`` is available.
    Otherwise only ``client_id`` is sent in the request body.

    Raises:
        ValueError: If the metadata lists grant types that exclude ``client_credentials``,
            or the token endpoint returns a non-success status.
    """
    grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value

    if metadata:
        token_url = metadata.token_endpoint
        if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
            raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
    else:
        token_url = urljoin(server_url, "/token")

    headers = {"Content-Type": "application/x-www-form-urlencoded"}
    data = {"grant_type": grant_type}
    if scope:
        data["scope"] = scope

    if client_information.client_secret:
        # HTTP Basic client authentication (preferred method).
        # NOTE(review): RFC 6749 section 2.3.1 asks for client_id/secret to be form-urlencoded
        # before Base64. They are sent raw here, as before; confirm against target servers
        # before changing.
        credentials = f"{client_information.client_id}:{client_information.client_secret}"
        encoded_credentials = base64.b64encode(credentials.encode()).decode()
        headers["Authorization"] = f"Basic {encoded_credentials}"
    else:
        # No secret to present, so identify the client by id in the body.
        data["client_id"] = client_information.client_id

    response = ssrf_proxy.post(token_url, headers=headers, data=data)
    if not response.is_success:
        raise ValueError(
            f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
        )

    return _parse_token_response(response)
def register_client(
    server_url: str,
    metadata: OAuthMetadata | None,
    client_metadata: OAuthClientMetadata,
) -> OAuthClientInformationFull:
    """Performs OAuth 2.0 Dynamic Client Registration (RFC 7591).

    Posts the client metadata as JSON to the advertised ``registration_endpoint``.
    Without discovered metadata, it posts to ``/register`` on the server URL.

    Raises:
        ValueError: If the server metadata advertises no registration endpoint.
        httpx.HTTPStatusError: If the registration endpoint returns a non-success status
            (raised by ``raise_for_status``).
    """
    if metadata:
        if not metadata.registration_endpoint:
            raise ValueError("Incompatible auth server: does not support dynamic client registration")
        registration_url = metadata.registration_endpoint
    else:
        registration_url = urljoin(server_url, "/register")

    # Omit unset optional fields. Sending them as JSON null gives the server values it
    # may reject as invalid client metadata. mode="json" yields JSON-native values.
    payload = client_metadata.model_dump(mode="json", exclude_none=True)

    response = ssrf_proxy.post(
        registration_url,
        json=payload,
        headers={"Content-Type": "application/json"},
    )
    if not response.is_success:
        response.raise_for_status()
    return OAuthClientInformationFull.model_validate(response.json())
def auth(
    provider: MCPProviderEntity,
    authorization_code: str | None = None,
    state_param: str | None = None,
    resource_metadata_url: str | None = None,
    scope_hint: str | None = None,
) -> AuthResult:
    """
    Orchestrates the full auth flow with a server using secure Redis state storage.

    This function performs only network operations. It returns the actions the caller
    must carry out, such as saving data to the database.

    Flow selection:
    - The grant type comes from the server's ``grant_types_supported``. When that field
      is omitted, the RFC 8414 default (which includes ``authorization_code``) applies.
    - Client Credentials: requests tokens directly and returns a SAVE_TOKENS action.
    - Authorization Code:
      - with ``authorization_code``: exchanges it using the PKCE state stored in Redis;
      - else, with a stored refresh token: refreshes the tokens;
      - otherwise: starts a new authorization and returns ``authorization_url``.

    Args:
        provider: The MCP provider entity
        authorization_code: Optional authorization code from OAuth callback
        state_param: Optional state parameter from OAuth callback
        resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
        scope_hint: Optional scope hint from WWW-Authenticate header

    Returns:
        AuthResult containing actions to be performed and response data

    Raises:
        ValueError: On discovery failure, missing client information or state, or a failed
            token request in the client-credentials or refresh paths.
    """
    actions: list[AuthAction] = []
    server_url = provider.decrypt_server_url()

    # Discover OAuth metadata using RFC 8414/9728 standards
    server_metadata, prm, scope_from_www_auth = discover_oauth_metadata(
        server_url, resource_metadata_url, scope_hint, LATEST_PROTOCOL_VERSION
    )

    client_metadata = provider.client_metadata
    provider_id = provider.id
    tenant_id = provider.tenant_id
    client_information = provider.retrieve_client_information()
    redirect_url = provider.redirect_url
    credentials = provider.decrypt_credentials()

    if not server_metadata:
        raise ValueError("Failed to discover OAuth metadata from server")

    # RFC 8414 section 2: grant_types_supported is OPTIONAL, and an omitted list means
    # ["authorization_code", "implicit"]. Treating "omitted" as "no authorization_code" would
    # wrongly force the client-credentials flow on compliant servers.
    supported_grant_types = server_metadata.grant_types_supported or [MCPSupportGrantType.AUTHORIZATION_CODE.value]
    supported_grant_types_lower = {gt.lower() for gt in supported_grant_types}

    if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
        effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
    else:
        effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value

    # Determine effective scope using priority-based strategy
    effective_scope = get_effective_scope(scope_from_www_auth, prm, server_metadata, credentials.get("scope"))

    if not client_information:
        if authorization_code is not None:
            raise ValueError("Existing OAuth client information is required when exchanging an authorization code")

        # For client credentials flow, we don't need to register client dynamically
        if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
            # Client should provide client_id and client_secret directly
            raise ValueError("Client credentials flow requires client_id and client_secret to be provided")

        try:
            full_information = register_client(server_url, server_metadata, client_metadata)
        except RequestError as e:
            # NOTE(review): register_client signals HTTP error statuses via httpx.HTTPStatusError,
            # which is not a RequestError and therefore propagates unwrapped.
            raise ValueError(f"Could not register OAuth client: {e}") from e

        # Return action to save client information
        actions.append(
            AuthAction(
                action_type=AuthActionType.SAVE_CLIENT_INFO,
                data={"client_information": full_information.model_dump()},
                provider_id=provider_id,
                tenant_id=tenant_id,
            )
        )
        client_information = full_information

    # Handle client credentials flow
    if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
        # Direct token request without user interaction
        try:
            tokens = client_credentials_flow(
                server_url,
                server_metadata,
                client_information,
                effective_scope,
            )

            # Return action to save tokens and grant type
            token_data = tokens.model_dump()
            token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value

            actions.append(
                AuthAction(
                    action_type=AuthActionType.SAVE_TOKENS,
                    data=token_data,
                    provider_id=provider_id,
                    tenant_id=tenant_id,
                )
            )

            return AuthResult(actions=actions, response={"result": "success"})
        except (RequestError, ValueError, KeyError) as e:
            # RequestError: HTTP request failed
            # ValueError: Invalid response data
            # KeyError: Missing required fields in response
            raise ValueError(f"Client credentials flow failed: {e}") from e

    # Exchange authorization code for tokens (Authorization Code flow)
    if authorization_code is not None:
        if not state_param:
            raise ValueError("State parameter is required when exchanging authorization code")

        try:
            # Retrieve state data from Redis using state key (the entry is deleted on read)
            full_state_data = _retrieve_redis_state(state_param)

            code_verifier = full_state_data.code_verifier
            redirect_uri = full_state_data.redirect_uri

            if not code_verifier or not redirect_uri:
                raise ValueError("Missing code_verifier or redirect_uri in state data")

        except (json.JSONDecodeError, ValueError) as e:
            raise ValueError(f"Invalid state parameter: {e}") from e

        tokens = exchange_authorization(
            server_url,
            server_metadata,
            client_information,
            authorization_code,
            code_verifier,
            redirect_uri,
        )

        # Return action to save tokens
        actions.append(
            AuthAction(
                action_type=AuthActionType.SAVE_TOKENS,
                data=tokens.model_dump(),
                provider_id=provider_id,
                tenant_id=tenant_id,
            )
        )

        return AuthResult(actions=actions, response={"result": "success"})

    provider_tokens = provider.retrieve_tokens()

    # Handle token refresh or new authorization
    if provider_tokens and provider_tokens.refresh_token:
        try:
            new_tokens = refresh_authorization(
                server_url, server_metadata, client_information, provider_tokens.refresh_token
            )

            # Return action to save new tokens
            actions.append(
                AuthAction(
                    action_type=AuthActionType.SAVE_TOKENS,
                    data=new_tokens.model_dump(),
                    provider_id=provider_id,
                    tenant_id=tenant_id,
                )
            )

            return AuthResult(actions=actions, response={"result": "success"})
        except (RequestError, ValueError, KeyError) as e:
            # RequestError: HTTP request failed
            # ValueError: Invalid response data
            # KeyError: Missing required fields in response
            raise ValueError(f"Could not refresh OAuth tokens: {e}") from e

    # Start new authorization flow (only for authorization code flow)
    authorization_url, code_verifier = start_authorization(
        server_url,
        server_metadata,
        client_information,
        redirect_url,
        provider_id,
        tenant_id,
        effective_scope,
    )

    # Return action to save code verifier
    actions.append(
        AuthAction(
            action_type=AuthActionType.SAVE_CODE_VERIFIER,
            data={"code_verifier": code_verifier},
            provider_id=provider_id,
            tenant_id=tenant_id,
        )
    )

    return AuthResult(actions=actions, response={"authorization_url": authorization_url})