mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-06-26 22:19:57 +00:00
Feat: add authorization header for MCP server based on OAuth 2.1 (#8292)
### What problem does this PR solve? Add authorization header for MCP server based on [OAuth 2.1](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-5). ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
efc3caf702
commit
a9532cb9e7
@ -23,6 +23,9 @@ async def main():
|
||||
try:
|
||||
# To access RAGFlow server in `host` mode, you need to attach `api_key` for each request to indicate identification.
|
||||
# async with sse_client("http://localhost:9382/sse", headers={"api_key": "ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams:
|
||||
# Or follow the requirements of OAuth 2.1 Section 5 with Authorization header
|
||||
# async with sse_client("http://localhost:9382/sse", headers={"Authorization": "Bearer ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm"}) as streams:
|
||||
|
||||
async with sse_client("http://localhost:9382/sse") as streams:
|
||||
async with ClientSession(
|
||||
streams[0],
|
||||
|
@ -17,6 +17,7 @@
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import wraps
|
||||
|
||||
import requests
|
||||
from starlette.applications import Starlette
|
||||
@ -127,22 +128,45 @@ app = Server("ragflow-server", lifespan=server_lifespan)
|
||||
sse = SseServerTransport("/messages/")
|
||||
|
||||
|
||||
def with_api_key(required=True):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
ctx = app.request_context
|
||||
ragflow_ctx = ctx.lifespan_context.get("ragflow_ctx")
|
||||
if not ragflow_ctx:
|
||||
raise ValueError("Get RAGFlow Context failed")
|
||||
|
||||
connector = ragflow_ctx.conn
|
||||
|
||||
if MODE == LaunchMode.HOST:
|
||||
headers = ctx.session._init_options.capabilities.experimental.get("headers", {})
|
||||
token = None
|
||||
|
||||
# lower case here, because of Starlette conversion
|
||||
auth = headers.get("authorization", "")
|
||||
if auth.startswith("Bearer "):
|
||||
token = auth.removeprefix("Bearer ").strip()
|
||||
elif "api_key" in headers:
|
||||
token = headers["api_key"]
|
||||
|
||||
if required and not token:
|
||||
raise ValueError("RAGFlow API key or Bearer token is required.")
|
||||
|
||||
connector.bind_api_key(token)
|
||||
else:
|
||||
connector.bind_api_key(HOST_API_KEY)
|
||||
|
||||
return await func(*args, connector=connector, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@app.list_tools()
|
||||
async def list_tools() -> list[types.Tool]:
|
||||
ctx = app.request_context
|
||||
ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
|
||||
if not ragflow_ctx:
|
||||
raise ValueError("Get RAGFlow Context failed")
|
||||
connector = ragflow_ctx.conn
|
||||
|
||||
if MODE == LaunchMode.HOST:
|
||||
api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
|
||||
if not api_key:
|
||||
raise ValueError("RAGFlow API_KEY is required.")
|
||||
else:
|
||||
api_key = HOST_API_KEY
|
||||
connector.bind_api_key(api_key)
|
||||
|
||||
@with_api_key(required=True)
|
||||
async def list_tools(*, connector) -> list[types.Tool]:
|
||||
dataset_description = connector.list_datasets()
|
||||
|
||||
return [
|
||||
@ -152,7 +176,17 @@ async def list_tools() -> list[types.Tool]:
|
||||
+ dataset_description,
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"dataset_ids": {"type": "array", "items": {"type": "string"}}, "document_ids": {"type": "array", "items": {"type": "string"}}, "question": {"type": "string"}},
|
||||
"properties": {
|
||||
"dataset_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"document_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"question": {"type": "string"},
|
||||
},
|
||||
"required": ["dataset_ids", "question"],
|
||||
},
|
||||
),
|
||||
@ -160,24 +194,15 @@ async def list_tools() -> list[types.Tool]:
|
||||
|
||||
|
||||
@app.call_tool()
|
||||
async def call_tool(name: str, arguments: dict) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
ctx = app.request_context
|
||||
ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
|
||||
if not ragflow_ctx:
|
||||
raise ValueError("Get RAGFlow Context failed")
|
||||
connector = ragflow_ctx.conn
|
||||
|
||||
if MODE == LaunchMode.HOST:
|
||||
api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
|
||||
if not api_key:
|
||||
raise ValueError("RAGFlow API_KEY is required.")
|
||||
else:
|
||||
api_key = HOST_API_KEY
|
||||
connector.bind_api_key(api_key)
|
||||
|
||||
@with_api_key(required=True)
|
||||
async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
if name == "ragflow_retrieval":
|
||||
document_ids = arguments.get("document_ids", [])
|
||||
return connector.retrieval(dataset_ids=arguments["dataset_ids"], document_ids=document_ids, question=arguments["question"])
|
||||
return connector.retrieval(
|
||||
dataset_ids=arguments["dataset_ids"],
|
||||
document_ids=document_ids,
|
||||
question=arguments["question"],
|
||||
)
|
||||
raise ValueError(f"Tool not found: {name}")
|
||||
|
||||
|
||||
@ -188,25 +213,34 @@ async def handle_sse(request):
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request, call_next):
|
||||
# Authentication is deferred, will be handled by RAGFlow core service.
|
||||
if request.url.path.startswith("/sse") or request.url.path.startswith("/messages"):
|
||||
api_key = request.headers.get("api_key")
|
||||
if not api_key:
|
||||
return JSONResponse({"error": "Missing unauthorization header"}, status_code=401)
|
||||
token = None
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token = auth_header.removeprefix("Bearer ").strip()
|
||||
elif request.headers.get("api_key"):
|
||||
token = request.headers["api_key"]
|
||||
|
||||
if not token:
|
||||
return JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
middleware = None
|
||||
if MODE == LaunchMode.HOST:
|
||||
middleware = [Middleware(AuthMiddleware)]
|
||||
def create_starlette_app():
|
||||
middleware = None
|
||||
if MODE == LaunchMode.HOST:
|
||||
middleware = [Middleware(AuthMiddleware)]
|
||||
|
||||
starlette_app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route("/sse", endpoint=handle_sse),
|
||||
Mount("/messages/", app=sse.handle_post_message),
|
||||
],
|
||||
middleware=middleware,
|
||||
)
|
||||
return Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route("/sse", endpoint=handle_sse),
|
||||
Mount("/messages/", app=sse.handle_post_message),
|
||||
],
|
||||
middleware=middleware,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -236,7 +270,7 @@ if __name__ == "__main__":
|
||||
default="self-host",
|
||||
help="Launch mode options:\n"
|
||||
" * self-host: Launches an MCP server to access a specific tenant space. The 'api_key' argument is required.\n"
|
||||
" * host: Launches an MCP server that allows users to access their own spaces. Each request must include a header "
|
||||
" * host: Launches an MCP server that allows users to access their own spaces. Each request must include a Authorization header "
|
||||
"indicating the user's identification.",
|
||||
)
|
||||
parser.add_argument("--api_key", type=str, default="", help="RAGFlow MCP SERVER HOST API KEY")
|
||||
@ -268,7 +302,7 @@ __ __ ____ ____ ____ _____ ______ _______ ____
|
||||
print(f"MCP base_url: {BASE_URL}", flush=True)
|
||||
|
||||
uvicorn.run(
|
||||
starlette_app,
|
||||
create_starlette_app(),
|
||||
host=HOST,
|
||||
port=int(PORT),
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user