feat(oauth): refactor proxy context (#21483)

This commit is contained in:
Maries 2025-06-25 15:10:45 +08:00 committed by GitHub
parent 164e5481c5
commit 1dd2607dfd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,9 +8,10 @@ from extensions.ext_redis import redis_client
class OAuthProxyService(BasePluginClient):
# Default max age for proxy context parameter in seconds
__MAX_AGE__ = 5 * 60 # 5 minutes
__KEY_PREFIX__ = "oauth_proxy_context:"
@staticmethod
def create_proxy_context(user_id, tenant_id, plugin_id, provider):
def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str):
"""
Create a proxy context for an OAuth 2.0 authorization request.
@ -23,26 +24,22 @@ class OAuthProxyService(BasePluginClient):
is used to verify the state, ensuring the request's integrity and authenticity,
and mitigating replay attacks.
"""
seconds, _ = redis_client.time()
context_id = str(uuid.uuid4())
data = {
"user_id": user_id,
"plugin_id": plugin_id,
"tenant_id": tenant_id,
"provider": provider,
# encode redis time to avoid distribution time skew
"timestamp": seconds,
}
# ignore nonce collision
redis_client.setex(
f"oauth_proxy_context:{context_id}",
f"{OAuthProxyService.__KEY_PREFIX__}{context_id}",
OAuthProxyService.__MAX_AGE__,
json.dumps(data),
)
return context_id
@staticmethod
def use_proxy_context(context_id, max_age=__MAX_AGE__):
def use_proxy_context(context_id: str):
"""
Validate the proxy context parameter.
This checks if the context_id is valid and not expired.
@ -50,12 +47,7 @@ class OAuthProxyService(BasePluginClient):
if not context_id:
raise ValueError("context_id is required")
# get data from redis
data = redis_client.getdel(f"oauth_proxy_context:{context_id}")
data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}")
if not data:
raise ValueError("context_id is invalid")
# check if data is expired
seconds, _ = redis_client.time()
state = json.loads(data)
if state.get("timestamp") < seconds - max_age:
raise ValueError("context_id is expired")
return state
return json.loads(data)