Feat: add authorization header for MCP server based on OAuth 2.1 (#8292)

### What problem does this PR solve?

Add authorization header for MCP server based on [OAuth
2.1](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-5).

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei 2025-06-17 09:29:12 +08:00 committed by GitHub
parent efc3caf702
commit a9532cb9e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 85 additions and 48 deletions

View File

@ -23,6 +23,9 @@ async def main():
try:
# To access RAGFlow server in `host` mode, you need to attach `api_key` for each request to indicate identification.
# async with sse_client("http://localhost:9382/sse", headers={"api_key": "ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams:
# Or follow the requirements of OAuth 2.1 Section 5 with Authorization header
# async with sse_client("http://localhost:9382/sse", headers={"Authorization": "Bearer ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams:
async with sse_client("http://localhost:9382/sse") as streams:
async with ClientSession(
streams[0],

View File

@ -17,6 +17,7 @@
import json
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import wraps
import requests
from starlette.applications import Starlette
@ -127,22 +128,45 @@ app = Server("ragflow-server", lifespan=server_lifespan)
sse = SseServerTransport("/messages/")
def with_api_key(required=True):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
ctx = app.request_context
ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx")
if not ragflow_ctx:
raise ValueError("Get RAGFlow Context failed")
connector = ragflow_ctx.conn
if MODE == LaunchMode.HOST:
headers = ctx.session._init_options.capabilities.experimental.get("headers", {})
token = None
# lower case here, because of Starlette conversion
auth = headers.get("authorization", "")
if auth.startswith("Bearer "):
token = auth.removeprefix("Bearer ").strip()
elif "api_key" in headers:
token = headers["api_key"]
if required and not token:
raise ValueError("RAGFlow API key or Bearer token is required.")
connector.bind_api_key(token)
else:
connector.bind_api_key(HOST_API_KEY)
return await func(*args, connector=connector, **kwargs)
return wrapper
return decorator
@app.list_tools()
async def list_tools() -> list[types.Tool]:
ctx = app.request_context
ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
if not ragflow_ctx:
raise ValueError("Get RAGFlow Context failed")
connector = ragflow_ctx.conn
if MODE == LaunchMode.HOST:
api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
if not api_key:
raise ValueError("RAGFlow API_KEY is required.")
else:
api_key = HOST_API_KEY
connector.bind_api_key(api_key)
@with_api_key(required=True)
async def list_tools(*, connector) -> list[types.Tool]:
dataset_description = connector.list_datasets()
return [
@ -152,7 +176,17 @@ async def list_tools() -> list[types.Tool]:
+ dataset_description,
inputSchema={
"type": "object",
"properties": {"dataset_ids": {"type": "array", "items": {"type": "string"}}, "document_ids": {"type": "array", "items": {"type": "string"}}, "question": {"type": "string"}},
"properties": {
"dataset_ids": {
"type": "array",
"items": {"type": "string"},
},
"document_ids": {
"type": "array",
"items": {"type": "string"},
},
"question": {"type": "string"},
},
"required": ["dataset_ids", "question"],
},
),
@ -160,24 +194,15 @@ async def list_tools() -> list[types.Tool]:
@app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
ctx = app.request_context
ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
if not ragflow_ctx:
raise ValueError("Get RAGFlow Context failed")
connector = ragflow_ctx.conn
if MODE == LaunchMode.HOST:
api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
if not api_key:
raise ValueError("RAGFlow API_KEY is required.")
else:
api_key = HOST_API_KEY
connector.bind_api_key(api_key)
@with_api_key(required=True)
async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
if name == "ragflow_retrieval":
document_ids = arguments.get("document_ids", [])
return connector.retrieval(dataset_ids=arguments["dataset_ids"], document_ids=document_ids, question=arguments["question"])
return connector.retrieval(
dataset_ids=arguments["dataset_ids"],
document_ids=document_ids,
question=arguments["question"],
)
raise ValueError(f"Tool not found: {name}")
@ -188,25 +213,34 @@ async def handle_sse(request):
class AuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
# Authentication is deferred, will be handled by RAGFlow core service.
if request.url.path.startswith("/sse") or request.url.path.startswith("/messages"):
api_key = request.headers.get("api_key")
if not api_key:
return JSONResponse({"error": "Missing unauthorization header"}, status_code=401)
token = None
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header.removeprefix("Bearer ").strip()
elif request.headers.get("api_key"):
token = request.headers["api_key"]
if not token:
return JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
return await call_next(request)
middleware = None
if MODE == LaunchMode.HOST:
middleware = [Middleware(AuthMiddleware)]
def create_starlette_app():
middleware = None
if MODE == LaunchMode.HOST:
middleware = [Middleware(AuthMiddleware)]
starlette_app = Starlette(
debug=True,
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
middleware=middleware,
)
return Starlette(
debug=True,
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
middleware=middleware,
)
if __name__ == "__main__":
@ -236,7 +270,7 @@ if __name__ == "__main__":
default="self-host",
help="Launch mode options:\n"
" * self-host: Launches an MCP server to access a specific tenant space. The 'api_key' argument is required.\n"
" * host: Launches an MCP server that allows users to access their own spaces. Each request must include a header "
" * host: Launches an MCP server that allows users to access their own spaces. Each request must include a Authorization header "
"indicating the user's identification.",
)
parser.add_argument("--api_key", type=str, default="", help="RAGFlow MCP SERVER HOST API KEY")
@ -268,7 +302,7 @@ __ __ ____ ____ ____ _____ ______ _______ ____
print(f"MCP base_url: {BASE_URL}", flush=True)
uvicorn.run(
starlette_app,
create_starlette_app(),
host=HOST,
port=int(PORT),
)