Feat: add MCP dashboard functionalities list_tools and test_tool (#8505)

### What problem does this PR solve?

Add MCP dashboard functionalities list_tools and test_tool.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei 2025-06-26 13:52:01 +08:00 committed by GitHub
parent 6b1221d2f6
commit 0eb90e73a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 225 additions and 84 deletions

View File

@ -8,7 +8,8 @@ from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.web_utils import safe_json_parse
from api.utils.web_utils import get_float, safe_json_parse
from mcp_client.mcp_tool_call import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@manager.route("/list", methods=["POST"]) # noqa: F821
@ -95,8 +96,13 @@ def update() -> Response:
if server_name and len(server_name.encode("utf-8")) > 255:
return get_data_error_result(message=f"Invaild MCP name or length is {len(server_name)} which is large than 255.")
req["headers"] = safe_json_parse(req.get("headers", {}))
req["variables"] = safe_json_parse(req.get("variables", {}))
mcp_id = req.get("id", "")
e, mcp_server = MCPServerService.get_by_id(mcp_id)
if not e or mcp_server.tenant_id != current_user.id:
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
req["headers"] = safe_json_parse(req.get("headers", mcp_server.headers))
req["variables"] = safe_json_parse(req.get("variables", mcp_server.variables))
try:
req["tenant_id"] = current_user.id
@ -212,3 +218,69 @@ def export_multiple() -> Response:
return get_json_result(data={"mcpServers": exported_servers})
except Exception as e:
return server_error_response(e)
@manager.route("/list_tools", methods=["POST"])  # noqa: F821
@login_required
@validate_request("mcp_ids")
def list_tools() -> Response:
    """List the tools exposed by each requested MCP server.

    Reads ``mcp_ids`` (required) and ``timeout`` (optional, parsed with
    ``get_float`` and falling back to 10) from the JSON request body.
    Ids that cannot be resolved, or whose server's ``tenant_id`` differs
    from the current user's id, are skipped without an error.

    Returns:
        A JSON result mapping each accepted server id to a list of
        ``model_dump()``-ed tools, or an error response when no ids were
        supplied or an exception was raised.
    """
    req = request.get_json()
    mcp_ids = req.get("mcp_ids", [])
    if not mcp_ids:
        return get_data_error_result(message="No MCP server IDs provided.")

    timeout = get_float(req, "timeout", 10)

    results = {}
    tool_call_sessions = []
    try:
        try:
            for mcp_id in mcp_ids:
                found, mcp_server = MCPServerService.get_by_id(mcp_id)
                if not found or mcp_server.tenant_id != current_user.id:
                    continue

                tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
                # Register the session before using it so it is closed even
                # if fetching or serialising its tools raises.
                tool_call_sessions.append(tool_call_session)
                tools = tool_call_session.get_tools(timeout)
                results[mcp_server.id] = [tool.model_dump() for tool in tools]
        finally:
            # Always release the sessions opened so far; previously an
            # exception in the loop skipped this call and leaked them.
            # PERF: blocking call to close sessions — consider moving to background thread or task queue
            close_multiple_mcp_toolcall_sessions(tool_call_sessions)
        return get_json_result(data=results)
    except Exception as e:
        return server_error_response(e)
@manager.route("/test_tool", methods=["POST"])  # noqa: F821
@login_required
@validate_request("mcp_id", "tool_name", "arguments")
def test_tool() -> Response:
    """Invoke a single tool on one MCP server and return its result.

    Reads ``mcp_id``, ``tool_name`` and ``arguments`` (all required) and
    ``timeout`` (optional, parsed with ``get_float`` and falling back to
    10) from the JSON request body. ``arguments`` must be a JSON object;
    an empty object is accepted so that tools without parameters can be
    tested.

    Returns:
        A JSON result wrapping whatever ``MCPToolCallSession.tool_call``
        returned, or an error response when validation fails, the server
        cannot be resolved for the current user, or an exception is raised.
    """
    req = request.get_json()
    mcp_id = req.get("mcp_id", "")
    if not mcp_id:
        return get_data_error_result(message="No MCP server ID provided.")

    timeout = get_float(req, "timeout", 10)

    tool_name = req.get("tool_name", "")
    arguments = req.get("arguments", {})
    # The previous truthiness check rejected ``{}``, which made it impossible
    # to test a tool that takes no parameters.
    if not tool_name or not isinstance(arguments, dict):
        return get_data_error_result(message="Require provide tool name and arguments.")

    try:
        found, mcp_server = MCPServerService.get_by_id(mcp_id)
        if not found or mcp_server.tenant_id != current_user.id:
            return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")

        tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
        try:
            result = tool_call_session.tool_call(tool_name, arguments, timeout)
        finally:
            # Close the session on every path, not only after a successful call.
            # PERF: blocking call to close sessions — consider moving to background thread or task queue
            close_multiple_mcp_toolcall_sessions([tool_call_session])
        return get_json_result(data=result)
    except Exception as e:
        return server_error_response(e)

View File

@ -19,6 +19,7 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
from api.utils.log_utils import init_root_logger
from mcp_client.mcp_tool_call import shutdown_all_mcp_sessions
from plugin import GlobalPluginManager
init_root_logger("ragflow_server")
@ -66,6 +67,7 @@ def update_progress():
def signal_handler(sig, frame):
    """Handle an interrupt signal by cleaning up and exiting the process.

    Args:
        sig: Signal number passed by the ``signal`` machinery (unused).
        frame: Current stack frame passed by the ``signal`` machinery (unused).
    """
    logging.info("Received interrupt signal, shutting down...")
    try:
        shutdown_all_mcp_sessions()
    except Exception:
        # Session cleanup is best-effort: a failure here must not prevent
        # the stop event from being set or the process from exiting.
        logging.exception("Failed to shut down MCP sessions cleanly")
    stop_event.set()
    time.sleep(1)
    sys.exit(0)

View File

@ -14,28 +14,28 @@
# limitations under the License.
#
import base64
import ipaddress
import json
import re
import socket
from urllib.parse import urlparse
import ipaddress
import json
import base64
from selenium import webdriver
from selenium.common.exceptions import TimeoutException
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.common.exceptions import TimeoutException
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support.expected_conditions import staleness_of
from webdriver_manager.chrome import ChromeDriverManager
from selenium.webdriver.common.by import By
from selenium.webdriver.support.expected_conditions import staleness_of
from selenium.webdriver.support.ui import WebDriverWait
from webdriver_manager.chrome import ChromeDriverManager
def html2pdf(
source: str,
timeout: int = 2,
install_driver: bool = True,
print_options: dict = {},
source: str,
timeout: int = 2,
install_driver: bool = True,
print_options: dict = {},
):
result = __get_pdf_from_html(source, timeout, install_driver, print_options)
return result
@ -53,12 +53,7 @@ def __send_devtools(driver, cmd, params={}):
return response.get("value")
def __get_pdf_from_html(
path: str,
timeout: int,
install_driver: bool,
print_options: dict
):
def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_options: dict):
webdriver_options = Options()
webdriver_prefs = {}
webdriver_options.add_argument("--headless")
@ -78,9 +73,7 @@ def __get_pdf_from_html(
driver.get(path)
try:
WebDriverWait(driver, timeout).until(
staleness_of(driver.find_element(by=By.TAG_NAME, value="html"))
)
WebDriverWait(driver, timeout).until(staleness_of(driver.find_element(by=By.TAG_NAME, value="html")))
except TimeoutException:
calculated_print_options = {
"landscape": False,
@ -89,8 +82,7 @@ def __get_pdf_from_html(
"preferCSSPageSize": True,
}
calculated_print_options.update(print_options)
result = __send_devtools(
driver, "Page.printToPDF", calculated_print_options)
result = __send_devtools(driver, "Page.printToPDF", calculated_print_options)
driver.quit()
return base64.b64decode(result["data"])
@ -102,6 +94,7 @@ def is_private_ip(ip: str) -> bool:
except ValueError:
return False
def is_valid_url(url: str) -> bool:
if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url):
return False
@ -127,3 +120,10 @@ def safe_json_parse(data: str | dict) -> dict:
except (json.JSONDecodeError, TypeError):
return {}
def get_float(req: dict, key: str, default: float | int = 10.0) -> float:
try:
parsed = float(req.get(key, default))
return parsed if parsed > 0 else default
except (TypeError, ValueError):
return default

View File

@ -1,45 +1,43 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import threading
import weakref
from concurrent.futures import ThreadPoolExecutor
from string import Template
from typing import Any, Literal
from typing_extensions import override
from api.db import MCPServerType
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
from api.db import MCPServerType
from rag.llm.chat_model import ToolCallSession
MCPTaskType = Literal["list_tools", "tool_call", "stop"]
MCPTaskType = Literal["list_tools", "tool_call"]
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
class MCPToolCallSession(ToolCallSession):
_EVENT_LOOP = asyncio.new_event_loop()
_THREAD_POOL = ThreadPoolExecutor(max_workers=1)
_mcp_server: Any
_server_variables: dict[str, Any]
_queue: asyncio.Queue[MCPTask]
_stop = False
@classmethod
def _init_thread_pool(cls) -> None:
cls._THREAD_POOL.submit(cls._EVENT_LOOP.run_forever)
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
self.__class__._ALL_INSTANCES.add(self)
self._mcp_server = mcp_server
self._server_variables = server_variables or {}
self._queue = asyncio.Queue()
self._close = False
asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), MCPToolCallSession._EVENT_LOOP)
self._event_loop = asyncio.new_event_loop()
self._thread_pool = ThreadPoolExecutor(max_workers=1)
self._thread_pool.submit(self._event_loop.run_forever)
asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), self._event_loop)
async def _mcp_server_loop(self) -> None:
url = self._mcp_server.url
url = self._mcp_server.url.strip()
raw_headers: dict[str, str] = self._mcp_server.headers or {}
headers: dict[str, str] = {}
@ -48,45 +46,62 @@ class MCPToolCallSession(ToolCallSession):
nv = Template(v).safe_substitute(self._server_variables)
headers[nh] = nv
_streams_source: Any
if self._mcp_server.server_type == MCPServerType.SSE:
_streams_source = sse_client(url, headers)
elif self._mcp_server.server_type == MCPServerType.StreamableHttp:
_streams_source = streamablehttp_client(url, headers)
# SSE transport
async with sse_client(url, headers) as stream:
async with ClientSession(*stream) as client_session:
try:
await asyncio.wait_for(client_session.initialize(), timeout=5)
logging.info("client_session initialized successfully")
except asyncio.TimeoutError:
logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}")
return
await self._process_mcp_tasks(client_session)
elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
# Streamable HTTP transport
async with streamablehttp_client(url, headers) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as client_session:
try:
await asyncio.wait_for(client_session.initialize(), timeout=5)
logging.info("client_session initialized successfully")
except asyncio.TimeoutError:
logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}")
return
await asyncio.wait_for(client_session.initialize(), timeout=5)
await self._process_mcp_tasks(client_session)
else:
raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}")
async with _streams_source as streams:
async with ClientSession(*streams) as client_session:
await client_session.initialize()
async def _process_mcp_tasks(self, client_session: ClientSession) -> None:
while not self._close:
try:
mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
while not self._stop:
mcp_task, arguments, result_queue = await self._queue.get()
logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")
logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")
r: Any
r: Any = None
try:
if mcp_task == "list_tools":
r = await client_session.list_tools()
elif mcp_task == "tool_call":
r = await client_session.call_tool(**arguments)
else:
r = ValueError(f"Unknown MCP task {mcp_task}")
except Exception as e:
r = e
try:
if mcp_task == "list_tools":
r = await client_session.list_tools()
elif mcp_task == "tool_call":
r = await client_session.call_tool(**arguments)
elif mcp_task == "stop":
logging.debug(f"Shutting down MCPToolCallSession for server {self._mcp_server.id}")
self._stop = True
continue
else:
r = ValueError(f"MCPToolCallSession for server {self._mcp_server.id} received an unknown task {mcp_task}")
except Exception as e:
r = e
await result_queue.put(r)
await result_queue.put(r)
async def _call_mcp_server(self, task_type: MCPTaskType, **kwargs) -> Any:
async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float = 8, **kwargs) -> Any:
results = asyncio.Queue()
await self._queue.put((task_type, kwargs, results))
result: CallToolResult | Exception = await results.get()
try:
result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout)
except asyncio.TimeoutError:
raise TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s")
if isinstance(result, Exception):
raise result
@ -106,32 +121,84 @@ class MCPToolCallSession(ToolCallSession):
return f"Unsupported content type {type(result.content)}"
async def _get_tools_from_mcp_server(self) -> list[Tool]:
# For now we only fetch the first page of tools
result: ListToolsResult = await self._call_mcp_server("list_tools")
return result.tools
def get_tools(self) -> list[Tool]:
return asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), MCPToolCallSession._EVENT_LOOP).result()
def get_tools(self, timeout: float = 10) -> list[Tool]:
future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), self._event_loop)
try:
return future.result(timeout=timeout)
except TimeoutError:
logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id}")
return []
except Exception:
logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
return []
@override
def tool_call(self, name: str, arguments: dict[str, Any]) -> str:
return asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), MCPToolCallSession._EVENT_LOOP).result()
def tool_call(self, name: str, arguments: dict[str, Any], timeout: float = 10) -> str:
future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop)
try:
return future.result(timeout=timeout)
except TimeoutError as te:
logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id}")
return f"Timeout calling tool '{name}': {te}."
except Exception as e:
logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}")
return f"Error calling tool '{name}': {e}."
async def close(self) -> None:
await self._call_mcp_server("stop")
if self._close:
return
def close_sync(self) -> None:
asyncio.run_coroutine_threadsafe(self.close(), MCPToolCallSession._EVENT_LOOP).result()
self._close = True
self._event_loop.call_soon_threadsafe(self._event_loop.stop)
self._thread_pool.shutdown(wait=True)
self.__class__._ALL_INSTANCES.discard(self)
def close_sync(self, timeout: float = 5) -> None:
if not self._event_loop.is_running():
logging.warning(f"Event loop already stopped for {self._mcp_server.id}")
return
MCPToolCallSession._init_thread_pool()
future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop)
try:
future.result(timeout=timeout)
except TimeoutError:
logging.error(f"Timeout while closing session for server {self._mcp_server.id}")
except Exception:
logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}")
def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
async def _gather() -> None:
await asyncio.gather(*[s.close() for s in sessions], return_exceptions=True)
logging.info(f"Want to clean up {len(sessions)} MCP sessions")
asyncio.run_coroutine_threadsafe(_gather(), MCPToolCallSession._EVENT_LOOP).result()
async def _gather_and_stop() -> None:
try:
await asyncio.gather(*[s.close() for s in sessions if s is not None], return_exceptions=True)
finally:
loop.call_soon_threadsafe(loop.stop)
loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()
asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result()
thread.join()
logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
def shutdown_all_mcp_sessions():
    """Close every MCPToolCallSession currently tracked in ``_ALL_INSTANCES``.

    Takes a snapshot of the tracked instances and hands it to
    ``close_multiple_mcp_toolcall_sessions``; logs and does nothing else
    when the snapshot is empty.
    """
    # Copy into a list so the set is not iterated while sessions are closing.
    active = [*MCPToolCallSession._ALL_INSTANCES]
    if active:
        logging.info(f"Shutting down {len(active)} MCPToolCallSession instances...")
        close_multiple_mcp_toolcall_sessions(active)
        logging.info("All MCPToolCallSession instances have been closed.")
    else:
        logging.info("No MCPToolCallSession instances to close.")
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]: