mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-06-26 22:19:57 +00:00
Feat: add MCP dashboard functionalities list_tools and test_tool (#8505)
### What problem does this PR solve? Add MCP dashboard functionalities list_tools and test_tool. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
6b1221d2f6
commit
0eb90e73a5
@ -8,7 +8,8 @@ from api.db.services.user_service import TenantService
|
|||||||
from api.settings import RetCode
|
from api.settings import RetCode
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||||
from api.utils.web_utils import safe_json_parse
|
from api.utils.web_utils import get_float, safe_json_parse
|
||||||
|
from mcp_client.mcp_tool_call import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||||
@ -95,8 +96,13 @@ def update() -> Response:
|
|||||||
if server_name and len(server_name.encode("utf-8")) > 255:
|
if server_name and len(server_name.encode("utf-8")) > 255:
|
||||||
return get_data_error_result(message=f"Invaild MCP name or length is {len(server_name)} which is large than 255.")
|
return get_data_error_result(message=f"Invaild MCP name or length is {len(server_name)} which is large than 255.")
|
||||||
|
|
||||||
req["headers"] = safe_json_parse(req.get("headers", {}))
|
mcp_id = req.get("id", "")
|
||||||
req["variables"] = safe_json_parse(req.get("variables", {}))
|
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||||
|
if not e or mcp_server.tenant_id != current_user.id:
|
||||||
|
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
|
||||||
|
|
||||||
|
req["headers"] = safe_json_parse(req.get("headers", mcp_server.headers))
|
||||||
|
req["variables"] = safe_json_parse(req.get("variables", mcp_server.variables))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
req["tenant_id"] = current_user.id
|
req["tenant_id"] = current_user.id
|
||||||
@ -212,3 +218,69 @@ def export_multiple() -> Response:
|
|||||||
return get_json_result(data={"mcpServers": exported_servers})
|
return get_json_result(data={"mcpServers": exported_servers})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/list_tools", methods=["POST"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
@validate_request("mcp_ids")
|
||||||
|
def list_tools() -> Response:
|
||||||
|
req = request.get_json()
|
||||||
|
mcp_ids = req.get("mcp_ids", [])
|
||||||
|
if not mcp_ids:
|
||||||
|
return get_data_error_result(message="No MCP server IDs provided.")
|
||||||
|
|
||||||
|
timeout = get_float(req, "timeout", 10)
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
tool_call_sessions = []
|
||||||
|
try:
|
||||||
|
for mcp_id in mcp_ids:
|
||||||
|
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||||
|
|
||||||
|
if e and mcp_server.tenant_id == current_user.id:
|
||||||
|
server_key = mcp_server.id
|
||||||
|
|
||||||
|
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||||
|
tool_call_sessions.append(tool_call_session)
|
||||||
|
tools = tool_call_session.get_tools(timeout)
|
||||||
|
|
||||||
|
results[server_key] = [tool.model_dump() for tool in tools]
|
||||||
|
|
||||||
|
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||||
|
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
|
||||||
|
return get_json_result(data=results)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
@validate_request("mcp_id", "tool_name", "arguments")
|
||||||
|
def test_tool() -> Response:
|
||||||
|
req = request.get_json()
|
||||||
|
mcp_id = req.get("mcp_id", "")
|
||||||
|
if not mcp_id:
|
||||||
|
return get_data_error_result(message="No MCP server ID provided.")
|
||||||
|
|
||||||
|
timeout = get_float(req, "timeout", 10)
|
||||||
|
|
||||||
|
tool_name = req.get("tool_name", "")
|
||||||
|
arguments = req.get("arguments", {})
|
||||||
|
if not all([tool_name, arguments]):
|
||||||
|
return get_data_error_result(message="Require provide tool name and arguments.")
|
||||||
|
|
||||||
|
tool_call_sessions = []
|
||||||
|
try:
|
||||||
|
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||||
|
if not e or mcp_server.tenant_id != current_user.id:
|
||||||
|
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
|
||||||
|
|
||||||
|
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||||
|
tool_call_sessions.append(tool_call_session)
|
||||||
|
result = tool_call_session.tool_call(tool_name, arguments, timeout)
|
||||||
|
|
||||||
|
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||||
|
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
|
||||||
|
return get_json_result(data=result)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
||||||
|
|
||||||
from api.utils.log_utils import init_root_logger
|
from api.utils.log_utils import init_root_logger
|
||||||
|
from mcp_client.mcp_tool_call import shutdown_all_mcp_sessions
|
||||||
from plugin import GlobalPluginManager
|
from plugin import GlobalPluginManager
|
||||||
init_root_logger("ragflow_server")
|
init_root_logger("ragflow_server")
|
||||||
|
|
||||||
@ -66,6 +67,7 @@ def update_progress():
|
|||||||
|
|
||||||
def signal_handler(sig, frame):
|
def signal_handler(sig, frame):
|
||||||
logging.info("Received interrupt signal, shutting down...")
|
logging.info("Received interrupt signal, shutting down...")
|
||||||
|
shutdown_all_mcp_sessions()
|
||||||
stop_event.set()
|
stop_event.set()
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
@ -14,21 +14,21 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import ipaddress
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
import ipaddress
|
|
||||||
import json
|
|
||||||
import base64
|
|
||||||
|
|
||||||
from selenium import webdriver
|
from selenium import webdriver
|
||||||
|
from selenium.common.exceptions import TimeoutException
|
||||||
from selenium.webdriver.chrome.options import Options
|
from selenium.webdriver.chrome.options import Options
|
||||||
from selenium.webdriver.chrome.service import Service
|
from selenium.webdriver.chrome.service import Service
|
||||||
from selenium.common.exceptions import TimeoutException
|
|
||||||
from selenium.webdriver.support.ui import WebDriverWait
|
|
||||||
from selenium.webdriver.support.expected_conditions import staleness_of
|
|
||||||
from webdriver_manager.chrome import ChromeDriverManager
|
|
||||||
from selenium.webdriver.common.by import By
|
from selenium.webdriver.common.by import By
|
||||||
|
from selenium.webdriver.support.expected_conditions import staleness_of
|
||||||
|
from selenium.webdriver.support.ui import WebDriverWait
|
||||||
|
from webdriver_manager.chrome import ChromeDriverManager
|
||||||
|
|
||||||
|
|
||||||
def html2pdf(
|
def html2pdf(
|
||||||
@ -53,12 +53,7 @@ def __send_devtools(driver, cmd, params={}):
|
|||||||
return response.get("value")
|
return response.get("value")
|
||||||
|
|
||||||
|
|
||||||
def __get_pdf_from_html(
|
def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_options: dict):
|
||||||
path: str,
|
|
||||||
timeout: int,
|
|
||||||
install_driver: bool,
|
|
||||||
print_options: dict
|
|
||||||
):
|
|
||||||
webdriver_options = Options()
|
webdriver_options = Options()
|
||||||
webdriver_prefs = {}
|
webdriver_prefs = {}
|
||||||
webdriver_options.add_argument("--headless")
|
webdriver_options.add_argument("--headless")
|
||||||
@ -78,9 +73,7 @@ def __get_pdf_from_html(
|
|||||||
driver.get(path)
|
driver.get(path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
WebDriverWait(driver, timeout).until(
|
WebDriverWait(driver, timeout).until(staleness_of(driver.find_element(by=By.TAG_NAME, value="html")))
|
||||||
staleness_of(driver.find_element(by=By.TAG_NAME, value="html"))
|
|
||||||
)
|
|
||||||
except TimeoutException:
|
except TimeoutException:
|
||||||
calculated_print_options = {
|
calculated_print_options = {
|
||||||
"landscape": False,
|
"landscape": False,
|
||||||
@ -89,8 +82,7 @@ def __get_pdf_from_html(
|
|||||||
"preferCSSPageSize": True,
|
"preferCSSPageSize": True,
|
||||||
}
|
}
|
||||||
calculated_print_options.update(print_options)
|
calculated_print_options.update(print_options)
|
||||||
result = __send_devtools(
|
result = __send_devtools(driver, "Page.printToPDF", calculated_print_options)
|
||||||
driver, "Page.printToPDF", calculated_print_options)
|
|
||||||
driver.quit()
|
driver.quit()
|
||||||
return base64.b64decode(result["data"])
|
return base64.b64decode(result["data"])
|
||||||
|
|
||||||
@ -102,6 +94,7 @@ def is_private_ip(ip: str) -> bool:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_valid_url(url: str) -> bool:
|
def is_valid_url(url: str) -> bool:
|
||||||
if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url):
|
if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url):
|
||||||
return False
|
return False
|
||||||
@ -127,3 +120,10 @@ def safe_json_parse(data: str | dict) -> dict:
|
|||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_float(req: dict, key: str, default: float | int = 10.0) -> float:
|
||||||
|
try:
|
||||||
|
parsed = float(req.get(key, default))
|
||||||
|
return parsed if parsed > 0 else default
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return default
|
||||||
|
@ -1,45 +1,43 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
|
import weakref
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from string import Template
|
from string import Template
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from api.db import MCPServerType
|
||||||
from mcp.client.session import ClientSession
|
from mcp.client.session import ClientSession
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
|
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
|
||||||
|
|
||||||
from api.db import MCPServerType
|
|
||||||
from rag.llm.chat_model import ToolCallSession
|
from rag.llm.chat_model import ToolCallSession
|
||||||
|
|
||||||
|
MCPTaskType = Literal["list_tools", "tool_call"]
|
||||||
MCPTaskType = Literal["list_tools", "tool_call", "stop"]
|
|
||||||
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
|
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
|
||||||
|
|
||||||
|
|
||||||
class MCPToolCallSession(ToolCallSession):
|
class MCPToolCallSession(ToolCallSession):
|
||||||
_EVENT_LOOP = asyncio.new_event_loop()
|
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
|
||||||
_THREAD_POOL = ThreadPoolExecutor(max_workers=1)
|
|
||||||
|
|
||||||
_mcp_server: Any
|
|
||||||
_server_variables: dict[str, Any]
|
|
||||||
_queue: asyncio.Queue[MCPTask]
|
|
||||||
_stop = False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _init_thread_pool(cls) -> None:
|
|
||||||
cls._THREAD_POOL.submit(cls._EVENT_LOOP.run_forever)
|
|
||||||
|
|
||||||
def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
|
def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
|
||||||
|
self.__class__._ALL_INSTANCES.add(self)
|
||||||
|
|
||||||
self._mcp_server = mcp_server
|
self._mcp_server = mcp_server
|
||||||
self._server_variables = server_variables or {}
|
self._server_variables = server_variables or {}
|
||||||
self._queue = asyncio.Queue()
|
self._queue = asyncio.Queue()
|
||||||
|
self._close = False
|
||||||
|
|
||||||
asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), MCPToolCallSession._EVENT_LOOP)
|
self._event_loop = asyncio.new_event_loop()
|
||||||
|
self._thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||||
|
self._thread_pool.submit(self._event_loop.run_forever)
|
||||||
|
|
||||||
|
asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), self._event_loop)
|
||||||
|
|
||||||
async def _mcp_server_loop(self) -> None:
|
async def _mcp_server_loop(self) -> None:
|
||||||
url = self._mcp_server.url
|
url = self._mcp_server.url.strip()
|
||||||
raw_headers: dict[str, str] = self._mcp_server.headers or {}
|
raw_headers: dict[str, str] = self._mcp_server.headers or {}
|
||||||
headers: dict[str, str] = {}
|
headers: dict[str, str] = {}
|
||||||
|
|
||||||
@ -48,45 +46,62 @@ class MCPToolCallSession(ToolCallSession):
|
|||||||
nv = Template(v).safe_substitute(self._server_variables)
|
nv = Template(v).safe_substitute(self._server_variables)
|
||||||
headers[nh] = nv
|
headers[nh] = nv
|
||||||
|
|
||||||
_streams_source: Any
|
|
||||||
|
|
||||||
if self._mcp_server.server_type == MCPServerType.SSE:
|
if self._mcp_server.server_type == MCPServerType.SSE:
|
||||||
_streams_source = sse_client(url, headers)
|
# SSE transport
|
||||||
elif self._mcp_server.server_type == MCPServerType.StreamableHttp:
|
async with sse_client(url, headers) as stream:
|
||||||
_streams_source = streamablehttp_client(url, headers)
|
async with ClientSession(*stream) as client_session:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(client_session.initialize(), timeout=5)
|
||||||
|
logging.info("client_session initialized successfully")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}")
|
||||||
|
return
|
||||||
|
await self._process_mcp_tasks(client_session)
|
||||||
|
|
||||||
|
elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
|
||||||
|
# Streamable HTTP transport
|
||||||
|
async with streamablehttp_client(url, headers) as (read_stream, write_stream, _):
|
||||||
|
async with ClientSession(read_stream, write_stream) as client_session:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(client_session.initialize(), timeout=5)
|
||||||
|
logging.info("client_session initialized successfully")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}")
|
||||||
|
return
|
||||||
|
await asyncio.wait_for(client_session.initialize(), timeout=5)
|
||||||
|
await self._process_mcp_tasks(client_session)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}")
|
raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}")
|
||||||
|
|
||||||
async with _streams_source as streams:
|
async def _process_mcp_tasks(self, client_session: ClientSession) -> None:
|
||||||
async with ClientSession(*streams) as client_session:
|
while not self._close:
|
||||||
await client_session.initialize()
|
try:
|
||||||
|
mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
|
||||||
while not self._stop:
|
|
||||||
mcp_task, arguments, result_queue = await self._queue.get()
|
|
||||||
logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")
|
logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")
|
||||||
|
|
||||||
r: Any
|
r: Any = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if mcp_task == "list_tools":
|
if mcp_task == "list_tools":
|
||||||
r = await client_session.list_tools()
|
r = await client_session.list_tools()
|
||||||
elif mcp_task == "tool_call":
|
elif mcp_task == "tool_call":
|
||||||
r = await client_session.call_tool(**arguments)
|
r = await client_session.call_tool(**arguments)
|
||||||
elif mcp_task == "stop":
|
|
||||||
logging.debug(f"Shutting down MCPToolCallSession for server {self._mcp_server.id}")
|
|
||||||
self._stop = True
|
|
||||||
continue
|
|
||||||
else:
|
else:
|
||||||
r = ValueError(f"MCPToolCallSession for server {self._mcp_server.id} received an unknown task {mcp_task}")
|
r = ValueError(f"Unknown MCP task {mcp_task}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
r = e
|
r = e
|
||||||
|
|
||||||
await result_queue.put(r)
|
await result_queue.put(r)
|
||||||
|
|
||||||
async def _call_mcp_server(self, task_type: MCPTaskType, **kwargs) -> Any:
|
async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float = 8, **kwargs) -> Any:
|
||||||
results = asyncio.Queue()
|
results = asyncio.Queue()
|
||||||
await self._queue.put((task_type, kwargs, results))
|
await self._queue.put((task_type, kwargs, results))
|
||||||
result: CallToolResult | Exception = await results.get()
|
try:
|
||||||
|
result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s")
|
||||||
|
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
raise result
|
raise result
|
||||||
@ -106,32 +121,84 @@ class MCPToolCallSession(ToolCallSession):
|
|||||||
return f"Unsupported content type {type(result.content)}"
|
return f"Unsupported content type {type(result.content)}"
|
||||||
|
|
||||||
async def _get_tools_from_mcp_server(self) -> list[Tool]:
|
async def _get_tools_from_mcp_server(self) -> list[Tool]:
|
||||||
# For now we only fetch the first page of tools
|
|
||||||
result: ListToolsResult = await self._call_mcp_server("list_tools")
|
result: ListToolsResult = await self._call_mcp_server("list_tools")
|
||||||
return result.tools
|
return result.tools
|
||||||
|
|
||||||
def get_tools(self) -> list[Tool]:
|
def get_tools(self, timeout: float = 10) -> list[Tool]:
|
||||||
return asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), MCPToolCallSession._EVENT_LOOP).result()
|
future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), self._event_loop)
|
||||||
|
try:
|
||||||
|
return future.result(timeout=timeout)
|
||||||
|
except TimeoutError:
|
||||||
|
logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id}")
|
||||||
|
return []
|
||||||
|
except Exception:
|
||||||
|
logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
|
||||||
|
return []
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> str:
|
def tool_call(self, name: str, arguments: dict[str, Any], timeout: float = 10) -> str:
|
||||||
return asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), MCPToolCallSession._EVENT_LOOP).result()
|
future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop)
|
||||||
|
try:
|
||||||
|
return future.result(timeout=timeout)
|
||||||
|
except TimeoutError as te:
|
||||||
|
logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id}")
|
||||||
|
return f"Timeout calling tool '{name}': {te}."
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}")
|
||||||
|
return f"Error calling tool '{name}': {e}."
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
await self._call_mcp_server("stop")
|
if self._close:
|
||||||
|
return
|
||||||
|
|
||||||
def close_sync(self) -> None:
|
self._close = True
|
||||||
asyncio.run_coroutine_threadsafe(self.close(), MCPToolCallSession._EVENT_LOOP).result()
|
self._event_loop.call_soon_threadsafe(self._event_loop.stop)
|
||||||
|
self._thread_pool.shutdown(wait=True)
|
||||||
|
self.__class__._ALL_INSTANCES.discard(self)
|
||||||
|
|
||||||
|
def close_sync(self, timeout: float = 5) -> None:
|
||||||
|
if not self._event_loop.is_running():
|
||||||
|
logging.warning(f"Event loop already stopped for {self._mcp_server.id}")
|
||||||
|
return
|
||||||
|
|
||||||
MCPToolCallSession._init_thread_pool()
|
future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop)
|
||||||
|
try:
|
||||||
|
future.result(timeout=timeout)
|
||||||
|
except TimeoutError:
|
||||||
|
logging.error(f"Timeout while closing session for server {self._mcp_server.id}")
|
||||||
|
except Exception:
|
||||||
|
logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}")
|
||||||
|
|
||||||
|
|
||||||
def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
|
def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
|
||||||
async def _gather() -> None:
|
logging.info(f"Want to clean up {len(sessions)} MCP sessions")
|
||||||
await asyncio.gather(*[s.close() for s in sessions], return_exceptions=True)
|
|
||||||
|
|
||||||
asyncio.run_coroutine_threadsafe(_gather(), MCPToolCallSession._EVENT_LOOP).result()
|
async def _gather_and_stop() -> None:
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*[s.close() for s in sessions if s is not None], return_exceptions=True)
|
||||||
|
finally:
|
||||||
|
loop.call_soon_threadsafe(loop.stop)
|
||||||
|
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
thread = threading.Thread(target=loop.run_forever, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result()
|
||||||
|
|
||||||
|
thread.join()
|
||||||
|
logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
|
||||||
|
|
||||||
|
|
||||||
|
def shutdown_all_mcp_sessions():
|
||||||
|
"""Gracefully shutdown all active MCPToolCallSession instances."""
|
||||||
|
sessions = list(MCPToolCallSession._ALL_INSTANCES)
|
||||||
|
if not sessions:
|
||||||
|
logging.info("No MCPToolCallSession instances to close.")
|
||||||
|
return
|
||||||
|
|
||||||
|
logging.info(f"Shutting down {len(sessions)} MCPToolCallSession instances...")
|
||||||
|
close_multiple_mcp_toolcall_sessions(sessions)
|
||||||
|
logging.info("All MCPToolCallSession instances have been closed.")
|
||||||
|
|
||||||
|
|
||||||
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]:
|
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user