mirror of
https://github.com/langgenius/dify.git
synced 2025-07-18 23:02:25 +00:00
feat(oauth): refactor proxy context (#21483)
This commit is contained in:
parent
164e5481c5
commit
1dd2607dfd
@ -8,9 +8,10 @@ from extensions.ext_redis import redis_client
|
|||||||
class OAuthProxyService(BasePluginClient):
|
class OAuthProxyService(BasePluginClient):
|
||||||
# Default max age for proxy context parameter in seconds
|
# Default max age for proxy context parameter in seconds
|
||||||
__MAX_AGE__ = 5 * 60 # 5 minutes
|
__MAX_AGE__ = 5 * 60 # 5 minutes
|
||||||
|
__KEY_PREFIX__ = "oauth_proxy_context:"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_proxy_context(user_id, tenant_id, plugin_id, provider):
|
def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str):
|
||||||
"""
|
"""
|
||||||
Create a proxy context for an OAuth 2.0 authorization request.
|
Create a proxy context for an OAuth 2.0 authorization request.
|
||||||
|
|
||||||
@ -23,26 +24,22 @@ class OAuthProxyService(BasePluginClient):
|
|||||||
is used to verify the state, ensuring the request's integrity and authenticity,
|
is used to verify the state, ensuring the request's integrity and authenticity,
|
||||||
and mitigating replay attacks.
|
and mitigating replay attacks.
|
||||||
"""
|
"""
|
||||||
seconds, _ = redis_client.time()
|
|
||||||
context_id = str(uuid.uuid4())
|
context_id = str(uuid.uuid4())
|
||||||
data = {
|
data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"plugin_id": plugin_id,
|
"plugin_id": plugin_id,
|
||||||
"tenant_id": tenant_id,
|
"tenant_id": tenant_id,
|
||||||
"provider": provider,
|
"provider": provider,
|
||||||
# encode redis time to avoid distribution time skew
|
|
||||||
"timestamp": seconds,
|
|
||||||
}
|
}
|
||||||
# ignore nonce collision
|
|
||||||
redis_client.setex(
|
redis_client.setex(
|
||||||
f"oauth_proxy_context:{context_id}",
|
f"{OAuthProxyService.__KEY_PREFIX__}{context_id}",
|
||||||
OAuthProxyService.__MAX_AGE__,
|
OAuthProxyService.__MAX_AGE__,
|
||||||
json.dumps(data),
|
json.dumps(data),
|
||||||
)
|
)
|
||||||
return context_id
|
return context_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def use_proxy_context(context_id, max_age=__MAX_AGE__):
|
def use_proxy_context(context_id: str):
|
||||||
"""
|
"""
|
||||||
Validate the proxy context parameter.
|
Validate the proxy context parameter.
|
||||||
This checks if the context_id is valid and not expired.
|
This checks if the context_id is valid and not expired.
|
||||||
@ -50,12 +47,7 @@ class OAuthProxyService(BasePluginClient):
|
|||||||
if not context_id:
|
if not context_id:
|
||||||
raise ValueError("context_id is required")
|
raise ValueError("context_id is required")
|
||||||
# get data from redis
|
# get data from redis
|
||||||
data = redis_client.getdel(f"oauth_proxy_context:{context_id}")
|
data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}")
|
||||||
if not data:
|
if not data:
|
||||||
raise ValueError("context_id is invalid")
|
raise ValueError("context_id is invalid")
|
||||||
# check if data is expired
|
return json.loads(data)
|
||||||
seconds, _ = redis_client.time()
|
|
||||||
state = json.loads(data)
|
|
||||||
if state.get("timestamp") < seconds - max_age:
|
|
||||||
raise ValueError("context_id is expired")
|
|
||||||
return state
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user