diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py index 185788c81..3d29debdb 100644 --- a/api/apps/mcp_server_app.py +++ b/api/apps/mcp_server_app.py @@ -8,7 +8,8 @@ from api.db.services.user_service import TenantService from api.settings import RetCode from api.utils import get_uuid from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request -from api.utils.web_utils import safe_json_parse +from api.utils.web_utils import get_float, safe_json_parse +from mcp_client.mcp_tool_call import MCPToolCallSession, close_multiple_mcp_toolcall_sessions @manager.route("/list", methods=["POST"]) # noqa: F821 @@ -95,8 +96,13 @@ def update() -> Response: if server_name and len(server_name.encode("utf-8")) > 255: return get_data_error_result(message=f"Invaild MCP name or length is {len(server_name)} which is large than 255.") - req["headers"] = safe_json_parse(req.get("headers", {})) - req["variables"] = safe_json_parse(req.get("variables", {})) + mcp_id = req.get("id", "") + e, mcp_server = MCPServerService.get_by_id(mcp_id) + if not e or mcp_server.tenant_id != current_user.id: + return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") + + req["headers"] = safe_json_parse(req.get("headers", mcp_server.headers)) + req["variables"] = safe_json_parse(req.get("variables", mcp_server.variables)) try: req["tenant_id"] = current_user.id @@ -212,3 +218,69 @@ def export_multiple() -> Response: return get_json_result(data={"mcpServers": exported_servers}) except Exception as e: return server_error_response(e) + + +@manager.route("/list_tools", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("mcp_ids") +def list_tools() -> Response: + req = request.get_json() + mcp_ids = req.get("mcp_ids", []) + if not mcp_ids: + return get_data_error_result(message="No MCP server IDs provided.") + + timeout = get_float(req, "timeout", 10) + + results = {} + tool_call_sessions = [] + try: + for mcp_id in mcp_ids: + e, mcp_server = MCPServerService.get_by_id(mcp_id) + + if e and mcp_server.tenant_id == current_user.id: + server_key = mcp_server.id + + tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables) + tool_call_sessions.append(tool_call_session) + tools = tool_call_session.get_tools(timeout) + + results[server_key] = [tool.model_dump() for tool in tools] + + # PERF: blocking call to close sessions — consider moving to background thread or task queue + close_multiple_mcp_toolcall_sessions(tool_call_sessions) + return get_json_result(data=results) + except Exception as e: + return server_error_response(e) + + +@manager.route("/test_tool", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("mcp_id", "tool_name", "arguments") +def test_tool() -> Response: + req = request.get_json() + mcp_id = req.get("mcp_id", "") + if not mcp_id: + return get_data_error_result(message="No MCP server ID provided.") + + timeout = get_float(req, "timeout", 10) + + tool_name = req.get("tool_name", "") + arguments = req.get("arguments", {}) + if not all([tool_name, arguments]): + return get_data_error_result(message="Require provide tool name and arguments.") + + tool_call_sessions = [] + try: + e, mcp_server = MCPServerService.get_by_id(mcp_id) + if not e or mcp_server.tenant_id != current_user.id: + return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") + + tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables) + tool_call_sessions.append(tool_call_session) + result = tool_call_session.tool_call(tool_name, arguments, timeout) + + # PERF: blocking call to close sessions — consider moving to background thread or task queue + close_multiple_mcp_toolcall_sessions(tool_call_sessions) + return get_json_result(data=result) + except Exception as e: + return server_error_response(e) diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 75bc8916c..288ae7fd3 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -19,6 +19,7 @@ # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code from api.utils.log_utils import init_root_logger +from mcp_client.mcp_tool_call import shutdown_all_mcp_sessions from plugin import GlobalPluginManager init_root_logger("ragflow_server") @@ -66,6 +67,7 @@ def update_progress(): def signal_handler(sig, frame): logging.info("Received interrupt signal, shutting down...") + shutdown_all_mcp_sessions() stop_event.set() time.sleep(1) sys.exit(0) diff --git a/api/utils/web_utils.py b/api/utils/web_utils.py index de3d692dd..5b89248d7 100644 --- a/api/utils/web_utils.py +++ b/api/utils/web_utils.py @@ -14,28 +14,28 @@ # limitations under the License. # +import base64 +import ipaddress +import json import re import socket from urllib.parse import urlparse -import ipaddress -import json -import base64 from selenium import webdriver +from selenium.common.exceptions import TimeoutException from selenium.webdriver.chrome.options import Options from selenium.webdriver.chrome.service import Service -from selenium.common.exceptions import TimeoutException -from selenium.webdriver.support.ui import WebDriverWait -from selenium.webdriver.support.expected_conditions import staleness_of -from webdriver_manager.chrome import ChromeDriverManager from selenium.webdriver.common.by import By +from selenium.webdriver.support.expected_conditions import staleness_of +from selenium.webdriver.support.ui import WebDriverWait +from webdriver_manager.chrome import ChromeDriverManager def html2pdf( - source: str, - timeout: int = 2, - install_driver: bool = True, - print_options: dict = {}, + source: str, + timeout: int = 2, + install_driver: bool = True, + print_options: dict = {}, ): result = __get_pdf_from_html(source, timeout, install_driver, print_options) return result @@ -53,12 +53,7 @@ def __send_devtools(driver, cmd, params={}): return response.get("value") -def __get_pdf_from_html( - path: str, - timeout: int, - install_driver: bool, - print_options: dict -): +def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_options: dict): webdriver_options = Options() webdriver_prefs = {} webdriver_options.add_argument("--headless") @@ -78,9 +73,7 @@ def __get_pdf_from_html( driver.get(path) try: - WebDriverWait(driver, timeout).until( - staleness_of(driver.find_element(by=By.TAG_NAME, value="html")) - ) + WebDriverWait(driver, timeout).until(staleness_of(driver.find_element(by=By.TAG_NAME, value="html"))) except TimeoutException: calculated_print_options = { "landscape": False, @@ -89,8 +82,7 @@ def __get_pdf_from_html( "preferCSSPageSize": True, } calculated_print_options.update(print_options) - result = __send_devtools( - driver, "Page.printToPDF", calculated_print_options) + result = __send_devtools(driver, "Page.printToPDF", calculated_print_options) driver.quit() return base64.b64decode(result["data"]) @@ -102,6 +94,7 @@ def is_private_ip(ip: str) -> bool: except ValueError: return False + def is_valid_url(url: str) -> bool: if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url): return False @@ -127,3 +120,10 @@ def safe_json_parse(data: str | dict) -> dict: except (json.JSONDecodeError, TypeError): return {} + +def get_float(req: dict, key: str, default: float | int = 10.0) -> float: + try: + parsed = float(req.get(key, default)) + return parsed if parsed > 0 else default + except (TypeError, ValueError): + return default diff --git a/mcp_client/mcp_tool_call.py b/mcp_client/mcp_tool_call.py index e0a7cf192..22fe5d20b 100644 --- a/mcp_client/mcp_tool_call.py +++ b/mcp_client/mcp_tool_call.py @@ -1,45 +1,43 @@ import asyncio -from concurrent.futures import ThreadPoolExecutor import logging +import threading +import weakref +from concurrent.futures import ThreadPoolExecutor from string import Template from typing import Any, Literal + from typing_extensions import override +from api.db import MCPServerType from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool - -from api.db import MCPServerType from rag.llm.chat_model import ToolCallSession - -MCPTaskType = Literal["list_tools", "tool_call", "stop"] +MCPTaskType = Literal["list_tools", "tool_call"] MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]] class MCPToolCallSession(ToolCallSession): - _EVENT_LOOP = asyncio.new_event_loop() - _THREAD_POOL = ThreadPoolExecutor(max_workers=1) - - _mcp_server: Any - _server_variables: dict[str, Any] - _queue: asyncio.Queue[MCPTask] - _stop = False - - @classmethod - def _init_thread_pool(cls) -> None: - cls._THREAD_POOL.submit(cls._EVENT_LOOP.run_forever) + _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet() def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None: + self.__class__._ALL_INSTANCES.add(self) + self._mcp_server = mcp_server self._server_variables = server_variables or {} self._queue = asyncio.Queue() + self._close = False - asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), MCPToolCallSession._EVENT_LOOP) + self._event_loop = asyncio.new_event_loop() + self._thread_pool = ThreadPoolExecutor(max_workers=1) + self._thread_pool.submit(self._event_loop.run_forever) + + asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), self._event_loop) async def _mcp_server_loop(self) -> None: - url = self._mcp_server.url + url = self._mcp_server.url.strip() raw_headers: dict[str, str] = self._mcp_server.headers or {} headers: dict[str, str] = {} @@ -48,45 +46,62 @@ class MCPToolCallSession(ToolCallSession): nv = Template(v).safe_substitute(self._server_variables) headers[nh] = nv - _streams_source: Any - if self._mcp_server.server_type == MCPServerType.SSE: - _streams_source = sse_client(url, headers) - elif self._mcp_server.server_type == MCPServerType.StreamableHttp: - _streams_source = streamablehttp_client(url, headers) + # SSE transport + async with sse_client(url, headers) as stream: + async with ClientSession(*stream) as client_session: + try: + await asyncio.wait_for(client_session.initialize(), timeout=5) + logging.info("client_session initialized successfully") + except asyncio.TimeoutError: + logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}") + return + await self._process_mcp_tasks(client_session) + + elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP: + # Streamable HTTP transport + async with streamablehttp_client(url, headers) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as client_session: + try: + await asyncio.wait_for(client_session.initialize(), timeout=5) + logging.info("client_session initialized successfully") + except asyncio.TimeoutError: + logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}") + return + await asyncio.wait_for(client_session.initialize(), timeout=5) + await self._process_mcp_tasks(client_session) else: raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}") - async with _streams_source as streams: - async with ClientSession(*streams) as client_session: - await client_session.initialize() + async def _process_mcp_tasks(self, client_session: ClientSession) -> None: + while not self._close: + try: + mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1) + except asyncio.TimeoutError: + continue - while not self._stop: - mcp_task, arguments, result_queue = await self._queue.get() - logging.debug(f"Got MCP task {mcp_task} arguments {arguments}") + logging.debug(f"Got MCP task {mcp_task} arguments {arguments}") - r: Any + r: Any = None + try: + if mcp_task == "list_tools": + r = await client_session.list_tools() + elif mcp_task == "tool_call": + r = await client_session.call_tool(**arguments) + else: + r = ValueError(f"Unknown MCP task {mcp_task}") + except Exception as e: + r = e - try: - if mcp_task == "list_tools": - r = await client_session.list_tools() - elif mcp_task == "tool_call": - r = await client_session.call_tool(**arguments) - elif mcp_task == "stop": - logging.debug(f"Shutting down MCPToolCallSession for server {self._mcp_server.id}") - self._stop = True - continue - else: - r = ValueError(f"MCPToolCallSession for server {self._mcp_server.id} received an unknown task {mcp_task}") - except Exception as e: - r = e + await result_queue.put(r) - await result_queue.put(r) - - async def _call_mcp_server(self, task_type: MCPTaskType, **kwargs) -> Any: + async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float = 8, **kwargs) -> Any: results = asyncio.Queue() await self._queue.put((task_type, kwargs, results)) - result: CallToolResult | Exception = await results.get() + try: + result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout) + except asyncio.TimeoutError: + raise TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s") if isinstance(result, Exception): raise result @@ -106,32 +121,84 @@ class MCPToolCallSession(ToolCallSession): return f"Unsupported content type {type(result.content)}" async def _get_tools_from_mcp_server(self) -> list[Tool]: - # For now we only fetch the first page of tools result: ListToolsResult = await self._call_mcp_server("list_tools") return result.tools - def get_tools(self) -> list[Tool]: - return asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), MCPToolCallSession._EVENT_LOOP).result() + def get_tools(self, timeout: float = 10) -> list[Tool]: + future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), self._event_loop) + try: + return future.result(timeout=timeout) + except TimeoutError: + logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id}") + return [] + except Exception: + logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}") + return [] @override - def tool_call(self, name: str, arguments: dict[str, Any]) -> str: - return asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), MCPToolCallSession._EVENT_LOOP).result() + def tool_call(self, name: str, arguments: dict[str, Any], timeout: float = 10) -> str: + future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop) + try: + return future.result(timeout=timeout) + except TimeoutError as te: + logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id}") + return f"Timeout calling tool '{name}': {te}." + except Exception as e: + logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}") + return f"Error calling tool '{name}': {e}." async def close(self) -> None: - await self._call_mcp_server("stop") + if self._close: + return - def close_sync(self) -> None: - asyncio.run_coroutine_threadsafe(self.close(), MCPToolCallSession._EVENT_LOOP).result() + self._close = True + self._event_loop.call_soon_threadsafe(self._event_loop.stop) + self._thread_pool.shutdown(wait=True) + self.__class__._ALL_INSTANCES.discard(self) + def close_sync(self, timeout: float = 5) -> None: + if not self._event_loop.is_running(): + logging.warning(f"Event loop already stopped for {self._mcp_server.id}") + return -MCPToolCallSession._init_thread_pool() + future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop) + try: + future.result(timeout=timeout) + except TimeoutError: + logging.error(f"Timeout while closing session for server {self._mcp_server.id}") + except Exception: + logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}") def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None: - async def _gather() -> None: - await asyncio.gather(*[s.close() for s in sessions], return_exceptions=True) + logging.info(f"Want to clean up {len(sessions)} MCP sessions") - asyncio.run_coroutine_threadsafe(_gather(), MCPToolCallSession._EVENT_LOOP).result() + async def _gather_and_stop() -> None: + try: + await asyncio.gather(*[s.close() for s in sessions if s is not None], return_exceptions=True) + finally: + loop.call_soon_threadsafe(loop.stop) + + loop = asyncio.new_event_loop() + thread = threading.Thread(target=loop.run_forever, daemon=True) + thread.start() + + asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result() + + thread.join() + logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.") + + +def shutdown_all_mcp_sessions(): + """Gracefully shutdown all active MCPToolCallSession instances.""" + sessions = list(MCPToolCallSession._ALL_INSTANCES) + if not sessions: + logging.info("No MCPToolCallSession instances to close.") + return + + logging.info(f"Shutting down {len(sessions)} MCPToolCallSession instances...") + close_multiple_mcp_toolcall_sessions(sessions) + logging.info("All MCPToolCallSession instances have been closed.") def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]: