From fd7ac176057bcea7e7caaa22d4317c01cb75ea45 Mon Sep 17 00:00:00 2001 From: Song Fuchang Date: Mon, 23 Jun 2025 17:45:35 +0800 Subject: [PATCH] Feat: Scratch MCP tool calling support. (#8263) ### What problem does this PR solve? This is a cherry-pick from #7781 as requested. ### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: Kevin Hu --- Dockerfile | 1 + Dockerfile.scratch.oc9 | 1 + api/apps/mcp_server_app.py | 107 ++++++++++++++++ api/db/__init__.py | 5 + api/db/db_models.py | 18 +++ api/db/services/mcp_server_service.py | 61 +++++++++ mcp/server/simple_tools_server.py | 23 ++++ mcp_client/__init__.py | 2 + mcp_client/mcp_tool_call.py | 145 ++++++++++++++++++++++ rag/llm/chat_model.py | 19 ++- uv.lock | 2 +- web/src/interfaces/database/mcp-server.ts | 19 +++ web/src/services/mcp-server-service.ts | 41 ++++++ web/src/utils/api.ts | 8 ++ 14 files changed, 445 insertions(+), 7 deletions(-) create mode 100644 api/apps/mcp_server_app.py create mode 100644 api/db/services/mcp_server_service.py create mode 100644 mcp/server/simple_tools_server.py create mode 100644 mcp_client/__init__.py create mode 100644 mcp_client/mcp_tool_call.py create mode 100644 web/src/interfaces/database/mcp-server.ts create mode 100644 web/src/services/mcp-server-service.ts diff --git a/Dockerfile b/Dockerfile index 67fd26456..0f0727b63 100644 --- a/Dockerfile +++ b/Dockerfile @@ -200,6 +200,7 @@ COPY graphrag graphrag COPY agentic_reasoning agentic_reasoning COPY pyproject.toml uv.lock ./ COPY mcp mcp +COPY mcp_client mcp_client COPY plugin plugin COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template diff --git a/Dockerfile.scratch.oc9 b/Dockerfile.scratch.oc9 index 64424735e..2403eae16 100644 --- a/Dockerfile.scratch.oc9 +++ b/Dockerfile.scratch.oc9 @@ -33,6 +33,7 @@ ADD ./rag ./rag ADD ./requirements.txt ./requirements.txt ADD ./agent ./agent ADD ./graphrag ./graphrag +ADD ./mcp_client ./mcp_client ADD ./plugin ./plugin RUN dnf install -y openmpi openmpi-devel python3-openmpi diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py new file mode 100644 index 000000000..188756167 --- /dev/null +++ b/api/apps/mcp_server_app.py @@ -0,0 +1,107 @@ +from flask import Response, request +from flask_login import current_user, login_required +from api.db.db_models import MCPServer +from api.db.services.mcp_server_service import MCPServerService +from api.db.services.user_service import TenantService +from api.settings import RetCode +from api.utils import get_uuid +from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request + + +@manager.route("/list", methods=["GET"]) # noqa: F821 +@login_required +def get_list() -> Response: + try: + return get_json_result(data=MCPServerService.get_servers(current_user.id) or []) + except Exception as e: + return server_error_response(e) + + +@manager.route("/get_multiple", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("id_list") +def get_multiple() -> Response: + req = request.json + + try: + return get_json_result(data=MCPServerService.get_servers(current_user.id, id_list=req["id_list"]) or []) + except Exception as e: + return server_error_response(e) + + +@manager.route("/get/", methods=["GET"]) # noqa: F821 +@login_required +def get(ms_id: str) -> Response: + try: + mcp_server = MCPServerService.get_or_none(id=ms_id, tenant_id=current_user.id) + + if mcp_server is None: + return get_json_result(code=RetCode.NOT_FOUND, data=None) + + return get_json_result(data=mcp_server.to_dict()) + except Exception as e: + return server_error_response(e) + + +@manager.route("/create", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("name", "url", "server_type") +def create() -> Response: + req = request.json + + try: + req["id"] = get_uuid() + req["tenant_id"] = current_user.id + + e, _ = TenantService.get_by_id(current_user.id) + + if not e: + return get_data_error_result(message="Tenant not found.") + + if not req.get("headers"): + req["headers"] = {} + + if not MCPServerService.insert(**req): + return get_data_error_result() + + return get_json_result(data={"id": req["id"]}) + except Exception as e: + return server_error_response(e) + + +@manager.route("/update", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("id", "name", "url", "server_type") +def update() -> Response: + req = request.json + + if not req.get("headers"): + req["headers"] = {} + + try: + req["tenant_id"] = current_user.id + + if not MCPServerService.filter_update([MCPServer.id == req["id"], MCPServer.tenant_id == req["tenant_id"]], req): + return get_data_error_result() + + return get_json_result(data={"id": req["id"]}) + except Exception as e: + return server_error_response(e) + + +@manager.route("/rm", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("id") +def rm() -> Response: + req = request.json + ms_id = req["id"] + + try: + req["tenant_id"] = current_user.id + + if not MCPServerService.filter_delete([MCPServer.id == ms_id, MCPServer.tenant_id == req["tenant_id"]]): + return get_data_error_result() + + return get_json_result(data={"id": req["id"]}) + except Exception as e: + return server_error_response(e) diff --git a/api/db/__init__.py b/api/db/__init__.py index a8c85ef4c..54cc98533 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -104,4 +104,9 @@ class CanvasType(StrEnum): ChatBot = "chatbot" DocBot = "docbot" + +class MCPServerType(StrEnum): + SSE = "sse" + StreamableHttp = "streamable-http" + KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase" diff --git a/api/db/db_models.py b/api/db/db_models.py index 3ccfbdba3..839203a5f 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -799,6 +799,20 @@ class UserCanvasVersion(DataBaseModel): db_table = "user_canvas_version" +class MCPServer(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + name = CharField(max_length=255, null=False, help_text="MCP Server name") + tenant_id = CharField(max_length=32, null=False, index=True) + url = CharField(max_length=2048, null=False, help_text="MCP Server URL") + server_type = CharField(max_length=32, null=False, help_text="MCP Server type") + description = TextField(null=True, help_text="MCP Server description") + variables = JSONField(null=True, default=[], help_text="MCP Server variables") + headers = JSONField(null=True, default={}, help_text="MCP Server additional request headers") + + class Meta: + db_table = "mcp_server" + + class Search(DataBaseModel): id = CharField(max_length=32, primary_key=True) avatar = TextField(null=True, help_text="avatar base64 string") @@ -934,3 +948,7 @@ def migrate_db(): migrate(migrator.add_column("llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False))) except Exception: pass + try: + migrate(migrator.add_column("mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=[]))) + except Exception: + pass diff --git a/api/db/services/mcp_server_service.py b/api/db/services/mcp_server_service.py new file mode 100644 index 000000000..43bc75f6c --- /dev/null +++ b/api/db/services/mcp_server_service.py @@ -0,0 +1,61 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from api.db.db_models import DB, MCPServer +from api.db.services.common_service import CommonService + + +class MCPServerService(CommonService): + """Service class for managing MCP server related database operations. + + This class extends CommonService to provide specialized functionality for MCP server management, + including MCP server creation, updates, and deletions. + + Attributes: + model: The MCPServer model class for database operations. + """ + + model = MCPServer + + @classmethod + @DB.connection_context() + def get_servers(cls, tenant_id: str, id_list: list[str] | None = None): + """Retrieve all MCP servers associated with a tenant. + + This method fetches all MCP servers for a given tenant, ordered by creation time. + It only includes fields for list display. + + Args: + tenant_id (str): The unique identifier of the tenant. + id_list (list[str]): Get servers by ID list. Will ignore this condition if None. + + Returns: + list[dict]: List of MCP server dictionaries containing MCP server details. + Returns None if no MCP servers are found. + """ + fields = [ + cls.model.id, cls.model.name, cls.model.server_type, cls.model.url, cls.model.description, + cls.model.variables, cls.model.update_date + ] + + servers = cls.model.select(*fields).order_by(cls.model.create_time.desc()).where(cls.model.tenant_id == tenant_id) + + if id_list is not None: + servers = servers.where(cls.model.id.in_(id_list)) + + servers = list(servers.dicts()) + if not servers: + return None + return servers diff --git a/mcp/server/simple_tools_server.py b/mcp/server/simple_tools_server.py new file mode 100644 index 000000000..f5f9a5257 --- /dev/null +++ b/mcp/server/simple_tools_server.py @@ -0,0 +1,23 @@ +from mcp.server import FastMCP + + +app = FastMCP("simple-tools", port=8080) + + +@app.tool() +async def bad_calculator(a: int, b: int) -> str: + """ + A calculator to sum up two numbers (will give wrong answer) + + Args: + a: The first number + b: The second number + + Returns: + Sum of a and b + """ + return str(a + b + 200) + + +if __name__ == "__main__": + app.run(transport="sse") diff --git a/mcp_client/__init__.py b/mcp_client/__init__.py new file mode 100644 index 000000000..a07125853 --- /dev/null +++ b/mcp_client/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401 +from .mcp_tool_call import MCPToolCallSession, mcp_tool_metadata_to_openai_tool, close_multiple_mcp_toolcall_sessions diff --git a/mcp_client/mcp_tool_call.py b/mcp_client/mcp_tool_call.py new file mode 100644 index 000000000..e0a7cf192 --- /dev/null +++ b/mcp_client/mcp_tool_call.py @@ -0,0 +1,145 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +import logging +from string import Template +from typing import Any, Literal +from typing_extensions import override + +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client +from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool + +from api.db import MCPServerType +from rag.llm.chat_model import ToolCallSession + + +MCPTaskType = Literal["list_tools", "tool_call", "stop"] +MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]] + + +class MCPToolCallSession(ToolCallSession): + _EVENT_LOOP = asyncio.new_event_loop() + _THREAD_POOL = ThreadPoolExecutor(max_workers=1) + + _mcp_server: Any + _server_variables: dict[str, Any] + _queue: asyncio.Queue[MCPTask] + _stop = False + + @classmethod + def _init_thread_pool(cls) -> None: + cls._THREAD_POOL.submit(cls._EVENT_LOOP.run_forever) + + def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None: + self._mcp_server = mcp_server + self._server_variables = server_variables or {} + self._queue = asyncio.Queue() + + asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), MCPToolCallSession._EVENT_LOOP) + + async def _mcp_server_loop(self) -> None: + url = self._mcp_server.url + raw_headers: dict[str, str] = self._mcp_server.headers or {} + headers: dict[str, str] = {} + + for h, v in raw_headers.items(): + nh = Template(h).safe_substitute(self._server_variables) + nv = Template(v).safe_substitute(self._server_variables) + headers[nh] = nv + + _streams_source: Any + + if self._mcp_server.server_type == MCPServerType.SSE: + _streams_source = sse_client(url, headers) + elif self._mcp_server.server_type == MCPServerType.StreamableHttp: + _streams_source = streamablehttp_client(url, headers) + else: + raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}") + + async with _streams_source as streams: + async with ClientSession(*streams) as client_session: + await client_session.initialize() + + while not self._stop: + mcp_task, arguments, result_queue = await self._queue.get() + logging.debug(f"Got MCP task {mcp_task} arguments {arguments}") + + r: Any + + try: + if mcp_task == "list_tools": + r = await client_session.list_tools() + elif mcp_task == "tool_call": + r = await client_session.call_tool(**arguments) + elif mcp_task == "stop": + logging.debug(f"Shutting down MCPToolCallSession for server {self._mcp_server.id}") + self._stop = True + continue + else: + r = ValueError(f"MCPToolCallSession for server {self._mcp_server.id} received an unknown task {mcp_task}") + except Exception as e: + r = e + + await result_queue.put(r) + + async def _call_mcp_server(self, task_type: MCPTaskType, **kwargs) -> Any: + results = asyncio.Queue() + await self._queue.put((task_type, kwargs, results)) + result: CallToolResult | Exception = await results.get() + + if isinstance(result, Exception): + raise result + + return result + + async def _call_mcp_tool(self, name: str, arguments: dict[str, Any]) -> str: + result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments) + + if result.isError: + return f"MCP server error: {result.content}" + + # For now we only support text content + if isinstance(result.content[0], TextContent): + return result.content[0].text + else: + return f"Unsupported content type {type(result.content)}" + + async def _get_tools_from_mcp_server(self) -> list[Tool]: + # For now we only fetch the first page of tools + result: ListToolsResult = await self._call_mcp_server("list_tools") + return result.tools + + def get_tools(self) -> list[Tool]: + return asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), MCPToolCallSession._EVENT_LOOP).result() + + @override + def tool_call(self, name: str, arguments: dict[str, Any]) -> str: + return asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), MCPToolCallSession._EVENT_LOOP).result() + + async def close(self) -> None: + await self._call_mcp_server("stop") + + def close_sync(self) -> None: + asyncio.run_coroutine_threadsafe(self.close(), MCPToolCallSession._EVENT_LOOP).result() + + +MCPToolCallSession._init_thread_pool() + + +def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None: + async def _gather() -> None: + await asyncio.gather(*[s.close() for s in sessions], return_exceptions=True) + + asyncio.run_coroutine_threadsafe(_gather(), MCPToolCallSession._EVENT_LOOP).result() + + +def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]: + return { + "type": "function", + "function": { + "name": mcp_tool.name, + "description": mcp_tool.description, + "parameters": mcp_tool.inputSchema, + }, + } diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index b1235ba62..3c6ea4d09 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -61,6 +61,9 @@ class ToolCallSession(Protocol): class Base(ABC): + tools: list[Any] + toolcall_sessions: dict[str, ToolCallSession] + def __init__(self, key, model_name, base_url, **kwargs): timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) @@ -70,6 +73,8 @@ class Base(ABC): self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0))) self.max_rounds = kwargs.get("max_rounds", 5) self.is_tools = False + self.tools = [] + self.toolcall_sessions = {} def _get_delay(self): """Calculate retry delay time""" @@ -145,8 +150,10 @@ class Base(ABC): if not (toolcall_session and tools): return self.is_tools = True - self.toolcall_session = toolcall_session - self.tools = tools + + for tool in tools: + self.toolcall_sessions[tool["function"]["name"]] = toolcall_session + self.tools.append(tool) def chat_with_tools(self, system: str, history: list, gen_conf: dict): gen_conf = self._clean_conf() @@ -180,7 +187,7 @@ class Base(ABC): name = tool_call.function.name try: args = json_repair.loads(tool_call.function.arguments) - tool_response = self.toolcall_session.tool_call(name, args) + tool_response = self.toolcall_sessions[name].tool_call(name, args) history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) except Exception as e: history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) @@ -286,7 +293,7 @@ class Base(ABC): name = tool_call.function.name try: args = json_repair.loads(tool_call.function.arguments) - tool_response = self.toolcall_session.tool_call(name, args) + tool_response = self.toolcall_sessions[name].tool_call(name, args) history.append( { "role": "assistant", @@ -585,7 +592,7 @@ class QWenChat(Base): tool_name = assistant_output.tool_calls[0]["function"]["name"] if tool_name: arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"]) - tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=arguments) + tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=arguments) history.append(tool_info) response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf) @@ -708,7 +715,7 @@ class QWenChat(Base): tool_name = toolcall_message.tool_calls[0]["function"]["name"] history.append(toolcall_message) - tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=tool_arguments) + tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=tool_arguments) history.append(tool_info) tool_info = {"content": "", "role": "tool"} tool_name = "" diff --git a/uv.lock b/uv.lock index cd5efce0e..7b9cf6b88 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 1 +revision = 2 requires-python = ">=3.10, <3.13" resolution-markers = [ "python_full_version >= '3.12' and sys_platform == 'darwin'", diff --git a/web/src/interfaces/database/mcp-server.ts b/web/src/interfaces/database/mcp-server.ts new file mode 100644 index 000000000..34ed7e4b9 --- /dev/null +++ b/web/src/interfaces/database/mcp-server.ts @@ -0,0 +1,19 @@ +export enum McpServerType { + Sse = 'sse', + StreamableHttp = 'streamable-http', +} + +export interface IMcpServerVariable { + key: string; + name: string; +} + +export interface IMcpServerInfo { + id: string; + name: string; + url: string; + server_type: McpServerType; + description?: string; + variables?: IMcpServerVariable[]; + headers: Map; +} diff --git a/web/src/services/mcp-server-service.ts b/web/src/services/mcp-server-service.ts new file mode 100644 index 000000000..a90fa0961 --- /dev/null +++ b/web/src/services/mcp-server-service.ts @@ -0,0 +1,41 @@ +import api from '@/utils/api'; +import registerServer from '@/utils/register-server'; +import request from '@/utils/request'; + +const { + getMcpServerList, + getMultipleMcpServers, + createMcpServer, + updateMcpServer, + deleteMcpServer, +} = api; + +const methods = { + get_list: { + url: getMcpServerList, + method: 'get', + }, + get_multiple: { + url: getMultipleMcpServers, + method: 'post', + }, + add: { + url: createMcpServer, + method: 'post' + }, + update: { + url: updateMcpServer, + method: 'post' + }, + rm: { + url: deleteMcpServer, + method: 'post' + }, +} as const; + +const mcpServerService = registerServer(methods, request); + +export const getMcpServer = (serverId: string) => + request.get(api.getMcpServer(serverId)); + +export default mcpServerService; diff --git a/web/src/utils/api.ts b/web/src/utils/api.ts index b0c32b123..d0369d1e8 100644 --- a/web/src/utils/api.ts +++ b/web/src/utils/api.ts @@ -143,4 +143,12 @@ export default { testDbConnect: `${api_host}/canvas/test_db_connect`, getInputElements: `${api_host}/canvas/input_elements`, debug: `${api_host}/canvas/debug`, + + // mcp server + getMcpServerList: `${api_host}/mcp_server/list`, + getMultipleMcpServers: `${api_host}/mcp_server/get_multiple`, + getMcpServer: (serverId: string) => `${api_host}/mcp_server/get/${serverId}`, + createMcpServer: `${api_host}/mcp_server/create`, + updateMcpServer: `${api_host}/mcp_server/update`, + deleteMcpServer: `${api_host}/mcp_server/rm`, };