mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-06-26 22:19:57 +00:00
Feat: add MCP treamable-http transport (#8449)
### What problem does this PR solve? Add MCP treamable-http transport. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
8d9d2cc0a9
commit
f21827bc28
36
mcp/client/streamable_http_client.py
Normal file
36
mcp/client/streamable_http_client.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
from mcp import ClientSession
|
||||||
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
try:
|
||||||
|
async with streamablehttp_client("http://localhost:9382/mcp/") as (read_stream, write_stream, _):
|
||||||
|
async with ClientSession(read_stream, write_stream) as session:
|
||||||
|
await session.initialize()
|
||||||
|
tools = await session.list_tools()
|
||||||
|
print(f"{tools.tools=}")
|
||||||
|
response = await session.call_tool(name="ragflow_retrieval", arguments={"dataset_ids": ["bc4177924a7a11f09eff238aa5c10c94"], "document_ids": [], "question": "How to install neovim?"})
|
||||||
|
print(f"Tool response: {response.model_dump()}")
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from anyio import run
|
||||||
|
|
||||||
|
run(main)
|
@ -15,6 +15,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
@ -29,7 +30,6 @@ from strenum import StrEnum
|
|||||||
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
from mcp.server.lowlevel import Server
|
from mcp.server.lowlevel import Server
|
||||||
from mcp.server.sse import SseServerTransport
|
|
||||||
|
|
||||||
|
|
||||||
class LaunchMode(StrEnum):
|
class LaunchMode(StrEnum):
|
||||||
@ -37,11 +37,19 @@ class LaunchMode(StrEnum):
|
|||||||
HOST = "host"
|
HOST = "host"
|
||||||
|
|
||||||
|
|
||||||
|
class Transport(StrEnum):
|
||||||
|
SSE = "sse"
|
||||||
|
STEAMABLE_HTTP = "streamable-http"
|
||||||
|
|
||||||
|
|
||||||
BASE_URL = "http://127.0.0.1:9380"
|
BASE_URL = "http://127.0.0.1:9380"
|
||||||
HOST = "127.0.0.1"
|
HOST = "127.0.0.1"
|
||||||
PORT = "9382"
|
PORT = "9382"
|
||||||
HOST_API_KEY = ""
|
HOST_API_KEY = ""
|
||||||
MODE = ""
|
MODE = ""
|
||||||
|
TRANSPORT_SSE_ENABLED = True
|
||||||
|
TRANSPORT_STREAMABLE_HTTP_ENABLED = True
|
||||||
|
JSON_RESPONSE = True
|
||||||
|
|
||||||
|
|
||||||
class RAGFlowConnector:
|
class RAGFlowConnector:
|
||||||
@ -115,17 +123,17 @@ class RAGFlowCtx:
|
|||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def server_lifespan(server: Server) -> AsyncIterator[dict]:
|
async def sse_lifespan(server: Server) -> AsyncIterator[dict]:
|
||||||
ctx = RAGFlowCtx(RAGFlowConnector(base_url=BASE_URL))
|
ctx = RAGFlowCtx(RAGFlowConnector(base_url=BASE_URL))
|
||||||
|
|
||||||
|
logging.info("Legacy SSE application started with StreamableHTTP session manager!")
|
||||||
try:
|
try:
|
||||||
yield {"ragflow_ctx": ctx}
|
yield {"ragflow_ctx": ctx}
|
||||||
finally:
|
finally:
|
||||||
pass
|
logging.info("Legacy SSE application shutting down...")
|
||||||
|
|
||||||
|
|
||||||
app = Server("ragflow-server", lifespan=server_lifespan)
|
app = Server("ragflow-mcp-server", lifespan=sse_lifespan)
|
||||||
sse = SseServerTransport("/messages/")
|
|
||||||
|
|
||||||
|
|
||||||
def with_api_key(required=True):
|
def with_api_key(required=True):
|
||||||
@ -206,13 +214,8 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text
|
|||||||
raise ValueError(f"Tool not found: {name}")
|
raise ValueError(f"Tool not found: {name}")
|
||||||
|
|
||||||
|
|
||||||
async def handle_sse(request):
|
|
||||||
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
|
|
||||||
await app.run(streams[0], streams[1], app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)}))
|
|
||||||
return Response()
|
|
||||||
|
|
||||||
|
|
||||||
def create_starlette_app():
|
def create_starlette_app():
|
||||||
|
routes = []
|
||||||
middleware = None
|
middleware = None
|
||||||
if MODE == LaunchMode.HOST:
|
if MODE == LaunchMode.HOST:
|
||||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
@ -227,7 +230,7 @@ def create_starlette_app():
|
|||||||
return
|
return
|
||||||
|
|
||||||
path = scope["path"]
|
path = scope["path"]
|
||||||
if path.startswith("/messages/") or path.startswith("/sse"):
|
if path.startswith("/messages/") or path.startswith("/sse") or path.startswith("/mcp"):
|
||||||
headers = dict(scope["headers"])
|
headers = dict(scope["headers"])
|
||||||
token = None
|
token = None
|
||||||
auth_header = headers.get(b"authorization")
|
auth_header = headers.get(b"authorization")
|
||||||
@ -245,13 +248,57 @@ def create_starlette_app():
|
|||||||
|
|
||||||
middleware = [Middleware(AuthMiddleware)]
|
middleware = [Middleware(AuthMiddleware)]
|
||||||
|
|
||||||
return Starlette(
|
# Add SSE routes if enabled
|
||||||
debug=True,
|
if TRANSPORT_SSE_ENABLED:
|
||||||
routes=[
|
from mcp.server.sse import SseServerTransport
|
||||||
|
|
||||||
|
sse = SseServerTransport("/messages/")
|
||||||
|
|
||||||
|
async def handle_sse(request):
|
||||||
|
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
|
||||||
|
await app.run(streams[0], streams[1], app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)}))
|
||||||
|
return Response()
|
||||||
|
|
||||||
|
routes.extend(
|
||||||
|
[
|
||||||
Route("/sse", endpoint=handle_sse, methods=["GET"]),
|
Route("/sse", endpoint=handle_sse, methods=["GET"]),
|
||||||
Mount("/messages/", app=sse.handle_post_message),
|
Mount("/messages/", app=sse.handle_post_message),
|
||||||
],
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add streamable HTTP route if enabled
|
||||||
|
streamablehttp_lifespan = None
|
||||||
|
if TRANSPORT_STREAMABLE_HTTP_ENABLED:
|
||||||
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
|
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||||
|
|
||||||
|
session_manager = StreamableHTTPSessionManager(
|
||||||
|
app=app,
|
||||||
|
event_store=None,
|
||||||
|
json_response=JSON_RESPONSE,
|
||||||
|
stateless=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
await session_manager.handle_request(scope, receive, send)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]:
|
||||||
|
async with session_manager.run():
|
||||||
|
logging.info("StreamableHTTP application started with StreamableHTTP session manager!")
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
logging.info("StreamableHTTP application shutting down...")
|
||||||
|
|
||||||
|
routes.append(Mount("/mcp", app=handle_streamable_http))
|
||||||
|
|
||||||
|
return Starlette(
|
||||||
|
debug=True,
|
||||||
|
routes=routes,
|
||||||
middleware=middleware,
|
middleware=middleware,
|
||||||
|
lifespan=streamablehttp_lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -266,7 +313,22 @@ def create_starlette_app():
|
|||||||
help=("Launch mode:\n self-host: run MCP for a single tenant (requires --api-key)\n host: multi-tenant mode, users must provide Authorization headers"),
|
help=("Launch mode:\n self-host: run MCP for a single tenant (requires --api-key)\n host: multi-tenant mode, users must provide Authorization headers"),
|
||||||
)
|
)
|
||||||
@click.option("--api-key", type=str, default="", help="API key to use when in self-host mode")
|
@click.option("--api-key", type=str, default="", help="API key to use when in self-host mode")
|
||||||
def main(base_url, host, port, mode, api_key):
|
@click.option(
|
||||||
|
"--transport-sse-enabled/--no-transport-sse-enabled",
|
||||||
|
default=True,
|
||||||
|
help="Enable or disable legacy SSE transport mode (default: enabled)",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--transport-streamable-http-enabled/--no-transport-streamable-http-enabled",
|
||||||
|
default=True,
|
||||||
|
help="Enable or disable streamable-http transport mode (default: enabled)",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--json-response/--no-json-response",
|
||||||
|
default=True,
|
||||||
|
help="Enable or disable JSON response mode for streamable-http (default: enabled)",
|
||||||
|
)
|
||||||
|
def main(base_url, host, port, mode, api_key, transport_sse_enabled, transport_streamable_http_enabled, json_response):
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@ -274,16 +336,29 @@ def main(base_url, host, port, mode, api_key):
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
global BASE_URL, HOST, PORT, MODE, HOST_API_KEY
|
def parse_bool_flag(key: str, default: bool) -> bool:
|
||||||
|
val = os.environ.get(key, str(default))
|
||||||
|
return str(val).strip().lower() in ("1", "true", "yes", "on")
|
||||||
|
|
||||||
|
global BASE_URL, HOST, PORT, MODE, HOST_API_KEY, TRANSPORT_SSE_ENABLED, TRANSPORT_STREAMABLE_HTTP_ENABLED, JSON_RESPONSE
|
||||||
BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", base_url)
|
BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", base_url)
|
||||||
HOST = os.environ.get("RAGFLOW_MCP_HOST", host)
|
HOST = os.environ.get("RAGFLOW_MCP_HOST", host)
|
||||||
PORT = os.environ.get("RAGFLOW_MCP_PORT", str(port))
|
PORT = os.environ.get("RAGFLOW_MCP_PORT", str(port))
|
||||||
MODE = os.environ.get("RAGFLOW_MCP_LAUNCH_MODE", mode)
|
MODE = os.environ.get("RAGFLOW_MCP_LAUNCH_MODE", mode)
|
||||||
HOST_API_KEY = os.environ.get("RAGFLOW_MCP_HOST_API_KEY", api_key)
|
HOST_API_KEY = os.environ.get("RAGFLOW_MCP_HOST_API_KEY", api_key)
|
||||||
|
TRANSPORT_SSE_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_SSE_ENABLED", transport_sse_enabled)
|
||||||
|
TRANSPORT_STREAMABLE_HTTP_ENABLED = parse_bool_flag("RAGFLOW_MCP_TRANSPORT_STREAMABLE_ENABLED", transport_streamable_http_enabled)
|
||||||
|
JSON_RESPONSE = parse_bool_flag("RAGFLOW_MCP_JSON_RESPONSE", json_response)
|
||||||
|
|
||||||
if MODE == "self-host" and not HOST_API_KEY:
|
if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY:
|
||||||
raise click.UsageError("--api-key is required when --mode is 'self-host'")
|
raise click.UsageError("--api-key is required when --mode is 'self-host'")
|
||||||
|
|
||||||
|
if TRANSPORT_STREAMABLE_HTTP_ENABLED and MODE == LaunchMode.HOST:
|
||||||
|
raise click.UsageError("The --host mode is not supported with streamable-http transport yet.")
|
||||||
|
|
||||||
|
if not TRANSPORT_STREAMABLE_HTTP_ENABLED and JSON_RESPONSE:
|
||||||
|
JSON_RESPONSE = False
|
||||||
|
|
||||||
print(
|
print(
|
||||||
r"""
|
r"""
|
||||||
__ __ ____ ____ ____ _____ ______ _______ ____
|
__ __ ____ ____ ____ _____ ______ _______ ____
|
||||||
@ -299,6 +374,24 @@ __ __ ____ ____ ____ _____ ______ _______ ____
|
|||||||
print(f"MCP port: {PORT}", flush=True)
|
print(f"MCP port: {PORT}", flush=True)
|
||||||
print(f"MCP base_url: {BASE_URL}", flush=True)
|
print(f"MCP base_url: {BASE_URL}", flush=True)
|
||||||
|
|
||||||
|
if TRANSPORT_SSE_ENABLED:
|
||||||
|
print("SSE transport enabled: yes", flush=True)
|
||||||
|
print("SSE endpoint available at /sse", flush=True)
|
||||||
|
else:
|
||||||
|
print("SSE transport enabled: no", flush=True)
|
||||||
|
|
||||||
|
if TRANSPORT_STREAMABLE_HTTP_ENABLED:
|
||||||
|
print("Streamable HTTP transport enabled: yes", flush=True)
|
||||||
|
print("Streamable HTTP endpoint available at /mcp", flush=True)
|
||||||
|
if JSON_RESPONSE:
|
||||||
|
print("Streamable HTTP mode: JSON response enabled", flush=True)
|
||||||
|
else:
|
||||||
|
print("Streamable HTTP mode: SSE over HTTP enabled", flush=True)
|
||||||
|
else:
|
||||||
|
print("Streamable HTTP transport enabled: no", flush=True)
|
||||||
|
if JSON_RESPONSE:
|
||||||
|
print("Warning: --json-response ignored because streamable transport is disabled.", flush=True)
|
||||||
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
create_starlette_app(),
|
create_starlette_app(),
|
||||||
host=HOST,
|
host=HOST,
|
||||||
@ -308,10 +401,32 @@ __ __ ____ ____ ____ _____ ______ _______ ____
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
"""
|
"""
|
||||||
Launch example:
|
Launch examples:
|
||||||
self-host:
|
|
||||||
uv run mcp/server/server.py --host=127.0.0.1 --port=9382 --base-url=http://127.0.0.1:9380 --mode=self-host --api-key=ragflow-xxxxx
|
1. Self-host mode with both SSE and Streamable HTTP (in JSON response mode) enabled (default):
|
||||||
host:
|
uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
|
||||||
uv run mcp/server/server.py --host=127.0.0.1 --port=9382 --base-url=http://127.0.0.1:9380 --mode=host
|
--base-url=http://127.0.0.1:9380 \
|
||||||
|
--mode=self-host --api-key=ragflow-xxxxx
|
||||||
|
|
||||||
|
2. Host mode (multi-tenant, self-host only, clients must provide Authorization headers):
|
||||||
|
uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
|
||||||
|
--base-url=http://127.0.0.1:9380 \
|
||||||
|
--mode=host
|
||||||
|
|
||||||
|
3. Disable legacy SSE (only streamable HTTP will be active):
|
||||||
|
uv run mcp/server/server.py --no-transport-sse-enabled \
|
||||||
|
--mode=self-host --api-key=ragflow-xxxxx
|
||||||
|
|
||||||
|
4. Disable streamable HTTP (only legacy SSE will be active):
|
||||||
|
uv run mcp/server/server.py --no-transport-streamable-http-enabled \
|
||||||
|
--mode=self-host --api-key=ragflow-xxxxx
|
||||||
|
|
||||||
|
5. Use streamable HTTP with SSE-style events (disable JSON response):
|
||||||
|
uv run mcp/server/server.py --transport-streamable-http-enabled --no-json-response \
|
||||||
|
--mode=self-host --api-key=ragflow-xxxxx
|
||||||
|
|
||||||
|
6. Disable both transports (for testing):
|
||||||
|
uv run mcp/server/server.py --no-transport-sse-enabled --no-transport-streamable-http-enabled \
|
||||||
|
--mode=self-host --api-key=ragflow-xxxxx
|
||||||
"""
|
"""
|
||||||
main()
|
main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user