From dc25350ea26f8f0be6af2c6c07a95109cdaf395b Mon Sep 17 00:00:00 2001 From: Mohit Yadav <105265192+mohityadav766@users.noreply.github.com> Date: Tue, 10 Jun 2025 09:42:24 +0530 Subject: [PATCH] MCP Core Items Improvements (#21643) * Search Util fix and added tableQueries * some json input fix * Add team and user * WIP : Add Streamable HTTP * - Add proper tools/list schema and tools/call * - auth filter exact match * - Add Tools Class to dynamically build tools * Add Origin Validation Mandate * Refactor MCP Stream * comment * Cleanups * Typo * Typo --- conf/openmetadata.yaml | 4 +- .../service/OpenMetadataApplication.java | 3 +- .../service/config/MCPConfiguration.java | 3 + .../service/jdbi3/TableRepository.java | 12 + ...HttpServletSseServerTransportProvider.java | 437 +++++--- .../service/mcp/MCPStreamableHttpServlet.java | 971 ++++++++++++++++++ .../openmetadata/service/mcp/McpServer.java | 200 +--- .../openmetadata/service/mcp/McpUtils.java | 94 ++ .../service/mcp/tools/DefaultToolContext.java | 75 ++ .../service/mcp/tools/GetEntityTool.java | 42 + .../service/mcp/tools/GlossaryTermTool.java | 33 +- .../service/mcp/tools/GlossaryTool.java | 25 +- .../service/mcp/tools/McpTool.java | 7 +- .../service/mcp/tools/SearchMetadataTool.java | 154 +++ .../service/search/SearchUtil.java | 137 +-- .../service/search/indexes/TableIndex.java | 1 + .../openmetadata/service/util/EntityUtil.java | 16 - .../main/resources/json/data/mcp/tools.json | 2 +- .../json/schema/entity/data/table.json | 8 + .../ui/src/generated/entity/data/table.ts | 6 +- 20 files changed, 1728 insertions(+), 502 deletions(-) create mode 100644 openmetadata-service/src/main/java/org/openmetadata/service/mcp/MCPStreamableHttpServlet.java create mode 100644 openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpUtils.java create mode 100644 openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/DefaultToolContext.java create mode 100644 openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/GetEntityTool.java create mode 100644 openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/SearchMetadataTool.java diff --git a/conf/openmetadata.yaml b/conf/openmetadata.yaml index dff1f326091..83226a73e40 100644 --- a/conf/openmetadata.yaml +++ b/conf/openmetadata.yaml @@ -443,6 +443,8 @@ operationalConfig: mcpConfiguration: enabled: ${MCP_ENABLED:-true} path: ${MCP_PATH:-"/mcp"} - mcpServerVersion: ${MCP_SERVER_VERSION:-"1.0.0"} + mcpServerVersion: ${MCP_SERVER_VERSION:-"1.0.0"} + originHeaderUri: ${OPENMETADATA_SERVER_URL:-"http://localhost"} + diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplication.java b/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplication.java index fe50a58de71..ac8a8e0ca78 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplication.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplication.java @@ -91,6 +91,7 @@ import org.openmetadata.service.jobs.JobHandlerRegistry; import org.openmetadata.service.limits.DefaultLimits; import org.openmetadata.service.limits.Limits; import org.openmetadata.service.mcp.McpServer; +import org.openmetadata.service.mcp.tools.DefaultToolContext; import org.openmetadata.service.migration.Migration; import org.openmetadata.service.migration.MigrationValidationClient; import org.openmetadata.service.migration.api.MigrationWorkflow; @@ -289,7 +290,7 @@ public class OpenMetadataApplication extends Application { column.setCustomMetrics(getCustomMetrics(table, column.getName())); } } + if (fields.contains("tableQueries")) { + List queriesEntity = + getEntities( + listOrEmpty(findTo(table.getId(), TABLE, Relationship.MENTIONED_IN, QUERY)), + "id", + ALL); + List queries = listOrEmpty(queriesEntity).stream().map(Query::getQuery).toList(); + table.setTableQueries(queries); + } } @Override diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/HttpServletSseServerTransportProvider.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/HttpServletSseServerTransportProvider.java index 3d931f7415b..3a8e22932a2 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/HttpServletSseServerTransportProvider.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/HttpServletSseServerTransportProvider.java @@ -1,9 +1,7 @@ -/* -HttpServletSseServerTransportProvider - Jakarta servlet-based MCP server transport -*/ - package org.openmetadata.service.mcp; +import static org.openmetadata.service.mcp.McpUtils.getJsonRpcMessageWithAuthorizationParam; + import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; @@ -11,6 +9,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.Assert; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; import jakarta.servlet.annotation.WebServlet; @@ -23,9 +22,10 @@ import java.io.PrintWriter; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import org.openmetadata.service.security.JwtFilter; -import org.openmetadata.service.util.JsonUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -42,171 +42,220 @@ public class HttpServletSseServerTransportProvider extends HttpServlet public static final String DEFAULT_SSE_ENDPOINT = "/sse"; public static final String MESSAGE_EVENT_TYPE = "message"; public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + public static final String DEFAULT_BASE_URL = ""; private final ObjectMapper objectMapper; + private final String baseUrl; private final String messageEndpoint; private final String sseEndpoint; - private final Map sessions; - private final AtomicBoolean isClosing; + private final Map sessions = new ConcurrentHashMap<>(); + private final AtomicBoolean isClosing = new AtomicBoolean(false); private McpServerSession.Factory sessionFactory; + private ExecutorService executorService = + Executors.newCachedThreadPool( + r -> { + Thread t = new Thread(r, "MCP-Worker-SSE"); + t.setDaemon(true); + return t; + }); - public HttpServletSseServerTransportProvider(String messageEndpoint, String sseEndpoint) { - this.sessions = new ConcurrentHashMap<>(); - this.isClosing = new AtomicBoolean(false); - this.objectMapper = JsonUtils.getObjectMapper(); + public HttpServletSseServerTransportProvider( + ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + } + + public HttpServletSseServerTransportProvider( + ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this.objectMapper = objectMapper; + this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; } - public HttpServletSseServerTransportProvider(String messageEndpoint) { - this(messageEndpoint, "/sse"); + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); } + @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.sessionFactory = sessionFactory; } + @Override public Mono notifyClients(String method, Map params) { - if (this.sessions.isEmpty()) { + if (sessions.isEmpty()) { logger.debug("No active sessions to broadcast message to"); return Mono.empty(); - } else { - logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size()); - return Flux.fromIterable(this.sessions.values()) - .flatMap( - (session) -> - session - .sendNotification(method, params) - .doOnError( - (e) -> - logger.error( - "Failed to send message to session {}: {}", - session.getId(), - e.getMessage())) - .onErrorComplete()) - .then(); } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + return Flux.fromIterable(sessions.values()) + .flatMap( + session -> + session + .sendNotification(method, params) + .doOnError( + e -> + logger.error( + "Failed to send message to session {}: {}", + session.getId(), + e.getMessage())) + .onErrorComplete()) + .then(); } + @Override protected void doGet(HttpServletRequest request, HttpServletResponse response) - throws ServletException, IOException { + throws IOException { handleSseEvent(request, response); } private void handleSseEvent(HttpServletRequest request, HttpServletResponse response) - throws ServletException, IOException { - String pathInfo = request.getPathInfo(); - if (!this.sseEndpoint.contains(pathInfo)) { - response.sendError(404); - } else if (this.isClosing.get()) { - response.sendError(503, "Server is shutting down"); - } else { - response.setContentType("text/event-stream"); - response.setCharacterEncoding("UTF-8"); - response.setHeader("Cache-Control", "no-cache"); - response.setHeader("Connection", "keep-alive"); - response.setHeader("Access-Control-Allow-Origin", "*"); - String sessionId = UUID.randomUUID().toString(); - AsyncContext asyncContext = request.startAsync(); - asyncContext.setTimeout(0L); - PrintWriter writer = response.getWriter(); - HttpServletMcpSessionTransport sessionTransport = - new HttpServletMcpSessionTransport(sessionId, asyncContext, writer); - McpServerSession session = this.sessionFactory.create(sessionTransport); - this.sessions.put(sessionId, session); - this.sendEvent(writer, "endpoint", this.messageEndpoint + "?sessionId=" + sessionId); + throws IOException { + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(sseEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; } - } - protected void doPost(HttpServletRequest request, HttpServletResponse response) - throws ServletException, IOException { - if (this.isClosing.get()) { - response.sendError(503, "Server is shutting down"); - } else { - String requestURI = request.getRequestURI(); - if (!requestURI.endsWith(this.messageEndpoint)) { - response.sendError(404); - } else { - String sessionId = request.getParameter("sessionId"); - if (sessionId == null) { - response.setContentType("application/json"); - response.setCharacterEncoding("UTF-8"); - response.setStatus(400); - String jsonError = - this.objectMapper.writeValueAsString( - new McpError("Session ID missing in message endpoint")); - PrintWriter writer = response.getWriter(); - writer.write(jsonError); - writer.flush(); - } else { - McpServerSession session = (McpServerSession) this.sessions.get(sessionId); - if (session == null) { - response.setContentType("application/json"); - response.setCharacterEncoding("UTF-8"); - response.setStatus(404); - String jsonError = - this.objectMapper.writeValueAsString( - new McpError("Session not found: " + sessionId)); - PrintWriter writer = response.getWriter(); - writer.write(jsonError); - writer.flush(); - } else { + if (isClosing.get()) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + response.setContentType("text/event-stream"); + response.setCharacterEncoding(UTF_8); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + String sessionId = UUID.randomUUID().toString(); + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + PrintWriter writer = response.getWriter(); + + // Create a new session transport + HttpServletMcpSessionTransport sessionTransport = + new HttpServletMcpSessionTransport(sessionId, asyncContext, writer); + + // Create a new session using the session factory + McpServerSession session = sessionFactory.create(sessionTransport); + this.sessions.put(sessionId, session); + + executorService.submit( + () -> { + // TODO: Handle session lifecycle and keepalive + try { + while (sessions.containsKey(sessionId)) { + // Send keepalive every 30 seconds + Thread.sleep(30000); + writer.write(": keep-alive\n\n"); + writer.flush(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (Exception e) { + log("SSE error", e); + } finally { try { - BufferedReader reader = request.getReader(); - StringBuilder body = new StringBuilder(); - - String line; - while ((line = reader.readLine()) != null) { - body.append(line); - } - - McpSchema.JSONRPCMessage message = - getJsonRpcMessageWithAuthorizationParam(request, body.toString()); - session.handle(message).block(); - response.setStatus(200); - } catch (Exception var11) { - Exception e = var11; - logger.error("Error processing message: {}", var11.getMessage()); - - try { - McpError mcpError = new McpError(e.getMessage()); - response.setContentType("application/json"); - response.setCharacterEncoding("UTF-8"); - response.setStatus(500); - String jsonError = this.objectMapper.writeValueAsString(mcpError); - PrintWriter writer = response.getWriter(); - writer.write(jsonError); - writer.flush(); - } catch (IOException ex) { - logger.error("Failed to send error response: {}", ex.getMessage()); - response.sendError(500, "Error processing message"); - } + session.closeGracefully(); + asyncContext.complete(); + } catch (Exception e) { + log("Error closing long-lived SSE connection", e); } } - } + }); + + // Send initial endpoint event + this.sendEvent( + writer, + ENDPOINT_EVENT_TYPE, + this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); + } + + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + if (isClosing.get()) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(messageEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + // Get the session ID from the request parameter + String sessionId = request.getParameter("sessionId"); + if (sessionId == null) { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + String jsonError = + objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint")); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + // Get the session from the sessions map + McpServerSession session = sessions.get(sessionId); + if (session == null) { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_NOT_FOUND); + String jsonError = + objectMapper.writeValueAsString(new McpError("Session not found: " + sessionId)); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + McpSchema.JSONRPCMessage message = + getJsonRpcMessageWithAuthorizationParam(this.objectMapper, request, body.toString()); + + // Process the message through the session's handle method + session.handle(message).block(); // Block for Servlet compatibility + + response.setStatus(HttpServletResponse.SC_OK); + } catch (Exception e) { + logger.error("Error processing message: {}", e.getMessage()); + try { + McpError mcpError = new McpError(e.getMessage()); + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + String jsonError = objectMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } catch (IOException ex) { + logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); + response.sendError( + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message"); } } } - private McpSchema.JSONRPCMessage getJsonRpcMessageWithAuthorizationParam( - HttpServletRequest request, String body) throws IOException { - Map requestMessage = JsonUtils.getMap(JsonUtils.readTree(body)); - Map params = (Map) requestMessage.get("params"); - if (params != null) { - Map arguments = (Map) params.get("arguments"); - if (arguments != null) { - arguments.put("Authorization", JwtFilter.extractToken(request.getHeader("Authorization"))); - } - } - return McpSchema.deserializeJsonRpcMessage( - this.objectMapper, JsonUtils.pojoToJson(requestMessage)); - } - + @Override public Mono closeGracefully() { - this.isClosing.set(true); - logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()); - return Flux.fromIterable(this.sessions.values()) - .flatMap(McpServerSession::closeGracefully) - .then(); + isClosing.set(true); + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then(); } private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException { @@ -218,8 +267,19 @@ public class HttpServletSseServerTransportProvider extends HttpServlet } } + @Override public void destroy() { - this.closeGracefully().block(); + closeGracefully().block(); + if (executorService != null) { + executorService.shutdown(); + try { + if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) { + executorService.shutdownNow(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } super.destroy(); } @@ -233,65 +293,106 @@ public class HttpServletSseServerTransportProvider extends HttpServlet this.sessionId = sessionId; this.asyncContext = asyncContext; this.writer = writer; - HttpServletSseServerTransportProvider.logger.debug( - "Session transport {} initialized with SSE writer", sessionId); + logger.debug("Session transport {} initialized with SSE writer", sessionId); } + @Override public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.fromRunnable( () -> { try { - String jsonText = - HttpServletSseServerTransportProvider.this.objectMapper.writeValueAsString( - message); - HttpServletSseServerTransportProvider.this.sendEvent( - this.writer, "message", jsonText); - HttpServletSseServerTransportProvider.logger.debug( - "Message sent to session {}", this.sessionId); + String jsonText = objectMapper.writeValueAsString(message); + sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText); + logger.debug("Message sent to session {}", sessionId); } catch (Exception e) { - HttpServletSseServerTransportProvider.logger.error( - "Failed to send message to session {}: {}", this.sessionId, e.getMessage()); - HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId); - this.asyncContext.complete(); + logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); + sessions.remove(sessionId); + asyncContext.complete(); } }); } + @Override public T unmarshalFrom(Object data, TypeReference typeRef) { - return (T) - HttpServletSseServerTransportProvider.this.objectMapper.convertValue(data, typeRef); + return objectMapper.convertValue(data, typeRef); } + @Override public Mono closeGracefully() { return Mono.fromRunnable( () -> { - HttpServletSseServerTransportProvider.logger.debug( - "Closing session transport: {}", this.sessionId); - + logger.debug("Closing session transport: {}", sessionId); try { - HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId); - this.asyncContext.complete(); - HttpServletSseServerTransportProvider.logger.debug( - "Successfully completed async context for session {}", this.sessionId); + sessions.remove(sessionId); + asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); } catch (Exception e) { - HttpServletSseServerTransportProvider.logger.warn( - "Failed to complete async context for session {}: {}", - this.sessionId, - e.getMessage()); + logger.warn( + "Failed to complete async context for session {}: {}", sessionId, e.getMessage()); } }); } + @Override public void close() { try { - HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId); - this.asyncContext.complete(); - HttpServletSseServerTransportProvider.logger.debug( - "Successfully completed async context for session {}", this.sessionId); + sessions.remove(sessionId); + asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); } catch (Exception e) { - HttpServletSseServerTransportProvider.logger.warn( - "Failed to complete async context for session {}: {}", this.sessionId, e.getMessage()); + logger.warn( + "Failed to complete async context for session {}: {}", sessionId, e.getMessage()); } } } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ObjectMapper objectMapper = new ObjectMapper(); + + private String baseUrl = DEFAULT_BASE_URL; + + private String messageEndpoint; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + public Builder baseUrl(String baseUrl) { + Assert.notNull(baseUrl, "Base URL must not be null"); + this.baseUrl = baseUrl; + return this; + } + + public Builder messageEndpoint(String messageEndpoint) { + Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); + this.messageEndpoint = messageEndpoint; + return this; + } + + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + public HttpServletSseServerTransportProvider build() { + if (objectMapper == null) { + throw new IllegalStateException("ObjectMapper must be set"); + } + if (messageEndpoint == null) { + throw new IllegalStateException("MessageEndpoint must be set"); + } + return new HttpServletSseServerTransportProvider( + objectMapper, baseUrl, messageEndpoint, sseEndpoint); + } + } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/MCPStreamableHttpServlet.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/MCPStreamableHttpServlet.java new file mode 100644 index 00000000000..f96d69f91b1 --- /dev/null +++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/MCPStreamableHttpServlet.java @@ -0,0 +1,971 @@ +package org.openmetadata.service.mcp; + +import static org.openmetadata.service.mcp.McpUtils.getJsonRpcMessageWithAuthorizationParam; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.openmetadata.service.config.MCPConfiguration; +import org.openmetadata.service.exception.BadRequestException; +import org.openmetadata.service.limits.Limits; +import org.openmetadata.service.mcp.tools.DefaultToolContext; +import org.openmetadata.service.security.Authorizer; +import org.openmetadata.service.security.JwtFilter; +import org.openmetadata.service.util.JsonUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * MCP (Model Context Protocol) Streamable HTTP Servlet + * This servlet implements the Streamable HTTP transport specification for MCP. + */ +@WebServlet(value = "/mcp", asyncSupported = true) +@Slf4j +public class MCPStreamableHttpServlet extends HttpServlet implements McpServerTransportProvider { + + private static final String SESSION_HEADER = "Mcp-Session-Id"; + private static final String CONTENT_TYPE_JSON = "application/json"; + private static final String CONTENT_TYPE_SSE = "text/event-stream"; + + private final transient ObjectMapper objectMapper = new ObjectMapper(); + private final transient Map sessions = new ConcurrentHashMap<>(); + private final transient Map sseConnections = new ConcurrentHashMap<>(); + private final transient SecureRandom secureRandom = new SecureRandom(); + private final transient ExecutorService executorService = + Executors.newCachedThreadPool( + r -> { + Thread t = new Thread(r, "MCP-Worker-Streamable"); + t.setDaemon(true); + return t; + }); + private transient McpServerSession.Factory sessionFactory; + private final transient JwtFilter jwtFilter; + private final transient Authorizer authorizer; + private final transient List tools = new ArrayList<>(); + private final transient Limits limits; + private final transient DefaultToolContext toolContext; + private final transient MCPConfiguration mcpConfiguration; + + public MCPStreamableHttpServlet( + MCPConfiguration mcpConfiguration, + JwtFilter jwtFilter, + Authorizer authorizer, + Limits limits, + DefaultToolContext toolContext, + List tools) { + this.mcpConfiguration = mcpConfiguration; + this.jwtFilter = jwtFilter; + this.authorizer = authorizer; + this.limits = limits; + this.toolContext = toolContext; + this.tools.addAll(tools); + } + + @Override + public void init() throws ServletException { + super.init(); + log("MCP Streamable HTTP Servlet initialized"); + } + + @Override + public void destroy() { + if (executorService != null) { + executorService.shutdown(); + try { + if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) { + executorService.shutdownNow(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + // Close all SSE connections + for (SSEConnection connection : sseConnections.values()) { + try { + connection.close(); + } catch (IOException e) { + log("Error closing SSE connection", e); + } + } + + super.destroy(); + } + + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + String origin = request.getHeader("Origin"); + if (origin != null && !isValidOrigin(origin)) { + sendError(response, HttpServletResponse.SC_FORBIDDEN, "Invalid origin"); + return; + } + + String acceptHeader = request.getHeader("Accept"); + if (acceptHeader == null + || !acceptHeader.contains(CONTENT_TYPE_JSON) + || !acceptHeader.contains(CONTENT_TYPE_SSE)) { + sendError( + response, + HttpServletResponse.SC_BAD_REQUEST, + "Accept header must include both application/json and text/event-stream"); + return; + } + + try { + String requestBody = readRequestBody(request); + String sessionId = request.getHeader(SESSION_HEADER); + + // Parse JSON-RPC message(s) + McpSchema.JSONRPCMessage message = + getJsonRpcMessageWithAuthorizationParam(this.objectMapper, request, requestBody); + JsonNode jsonNode = objectMapper.valueToTree(message); + + if (jsonNode.isArray()) { + handleBatchRequest(request, response, jsonNode, sessionId); + } else { + handleSingleMessage(request, response, jsonNode, sessionId); + } + } catch (Exception e) { + log("Error handling POST request", e); + sendError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Internal server error"); + } + } + + /** + * Handle a single JSON-RPC message according to MCP specification + */ + private void handleSingleMessage( + HttpServletRequest request, + HttpServletResponse response, + JsonNode jsonMessage, + String sessionId) + throws IOException { + + String method = jsonMessage.has("method") ? jsonMessage.get("method").asText() : null; + boolean hasId = jsonMessage.has("id"); + boolean isResponse = jsonMessage.has("result") || jsonMessage.has("error"); + + // Handle initialization specially + if ("initialize".equals(method)) { + handleInitialize(response, jsonMessage); + return; + } + + // Validate session for non-initialization requests (except for responses/notifications) + if (sessionId != null && !sessions.containsKey(sessionId) && !isResponse) { + sendError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found"); + return; + } + + if (isResponse || (!hasId && method != null)) { + // This is either a response or a notification + handleResponseOrNotification(response, jsonMessage, sessionId, isResponse); + } else if (hasId && method != null) { + // This is a request - may return JSON or start SSE stream + handleRequest(request, response, jsonMessage, sessionId); + } else { + sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Invalid JSON-RPC message format"); + } + } + + /** + * Handle responses and notifications from client + */ + private void handleResponseOrNotification( + HttpServletResponse response, JsonNode message, String sessionId, boolean isResponse) + throws IOException { + + try { + if (isResponse) { + processClientResponse(message, sessionId); + log("Processed client response for session: " + sessionId); + } else { + processNotification(message, sessionId); + log("Processed client notification for session: " + sessionId); + } + + response.setStatus(HttpServletResponse.SC_ACCEPTED); + response.setContentLength(0); + + } catch (Exception e) { + log("Error processing response/notification", e); + sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Failed to process message"); + } + } + + /** + * Handle JSON-RPC requests from client + */ + private void handleRequest( + HttpServletRequest request, + HttpServletResponse response, + JsonNode jsonRequest, + String sessionId) + throws IOException { + + String acceptHeader = request.getHeader("Accept"); + boolean supportsSSE = acceptHeader != null && acceptHeader.contains(CONTENT_TYPE_SSE); + + // Determine whether to use SSE stream or direct JSON response + boolean useSSE = supportsSSE && shouldUseSSEForRequest(jsonRequest, sessionId); + + if (useSSE) { + // Initiate SSE stream and process request asynchronously + startSSEStreamForRequest(request, response, jsonRequest, sessionId); + } else { + // Send direct JSON response + sendDirectJSONResponse(response, jsonRequest, sessionId); + } + } + + /** + * Send direct JSON response for requests + */ + private void sendDirectJSONResponse( + HttpServletResponse response, JsonNode request, String sessionId) throws IOException { + + Map jsonResponse = processRequest(request, sessionId); + String responseJson = objectMapper.writeValueAsString(jsonResponse); + + response.setContentType(CONTENT_TYPE_JSON); + response.setStatus(HttpServletResponse.SC_OK); + + // Include session ID in response if present + if (sessionId != null) { + response.setHeader(SESSION_HEADER, sessionId); + } + + response.getWriter().write(responseJson); + response.getWriter().flush(); + } + + /** + * Start SSE stream for processing requests + */ + private void startSSEStreamForRequest( + HttpServletRequest request, + HttpServletResponse response, + JsonNode jsonRequest, + String sessionId) + throws IOException { + + // Set up SSE response headers + response.setContentType(CONTENT_TYPE_SSE); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + // Include session ID in response if present + if (sessionId != null) { + response.setHeader(SESSION_HEADER, sessionId); + } + + response.setStatus(HttpServletResponse.SC_OK); + + // Start async processing + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(300000); // 5 minutes timeout + + SSEConnection connection = new SSEConnection(response.getWriter(), sessionId); + String connectionId = UUID.randomUUID().toString(); + sseConnections.put(connectionId, connection); + + // Process request asynchronously + executorService.submit( + () -> { + try { + // Send server-initiated messages first + // sendServerInitiatedMessages(connection); + + // Process the actual request and send response + Map jsonResponse = processRequest(jsonRequest, sessionId); + String eventId = generateEventId(); + connection.sendEventWithId(eventId, objectMapper.writeValueAsString(jsonResponse)); + + // Close the stream after sending response + connection.close(); + + } catch (Exception e) { + log("Error in SSE stream processing for request", e); + try { + // Send error response before closing + Map errorResponse = + createErrorResponse( + jsonRequest.get("id"), -32603, "Internal error: " + e.getMessage()); + connection.sendEvent(objectMapper.writeValueAsString(errorResponse)); + connection.close(); + } catch (Exception ex) { + log("Error sending error response in SSE stream", ex); + } + } finally { + try { + asyncContext.complete(); + } catch (Exception e) { + log("Error completing async context", e); + } + sseConnections.remove(connectionId); + } + }); + } + + /** + * Handle batch requests according to MCP specification + */ + private void handleBatchRequest( + HttpServletRequest request, + HttpServletResponse response, + JsonNode batchRequest, + String sessionId) + throws IOException { + + if (!batchRequest.isArray() || batchRequest.size() == 0) { + sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Invalid batch request format"); + return; + } + + // Analyze the batch to determine message types + boolean hasRequests = false; + boolean hasResponsesOrNotifications = false; + + for (JsonNode message : batchRequest) { + boolean hasId = message.has("id"); + boolean isResponse = message.has("result") || message.has("error"); + boolean hasMethod = message.has("method"); + + if (hasId && hasMethod && !isResponse) { + hasRequests = true; + } else if (isResponse || (!hasId && hasMethod)) { + hasResponsesOrNotifications = true; + } + } + + if (hasRequests && hasResponsesOrNotifications) { + sendError( + response, + HttpServletResponse.SC_BAD_REQUEST, + "Batch cannot mix requests with responses/notifications"); + return; + } + + if (hasResponsesOrNotifications) { + // Process responses and notifications, return 202 Accepted + processBatchResponsesOrNotifications(batchRequest, sessionId); + response.setStatus(HttpServletResponse.SC_ACCEPTED); + response.setContentLength(0); + } else if (hasRequests) { + // Process batch requests - determine if SSE or direct JSON + String acceptHeader = request.getHeader("Accept"); + boolean supportsSSE = acceptHeader != null && acceptHeader.contains(CONTENT_TYPE_SSE); + + if (supportsSSE && shouldUseSSEForBatch(batchRequest, sessionId)) { + startSSEStreamForBatchRequests(request, response, batchRequest, sessionId); + } else { + sendBatchJSONResponse(response, batchRequest, sessionId); + } + } else { + sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Invalid batch content"); + } + } + + /** + * Process client responses (for server-initiated requests) + */ + private void processClientResponse(JsonNode response, String sessionId) { + // Handle responses to server-initiated requests + Object id = response.has("id") ? response.get("id") : null; + LOG.info("Received client response for request ID: " + id + " (session: " + sessionId + ")"); + + // TODO: Match with pending server requests and complete them + // This would involve maintaining a map of pending requests by ID + } + + /** + * Determine if SSE should be used for a specific request + */ + private boolean shouldUseSSEForRequest(JsonNode request, String sessionId) { + // TODO: This is good for now, but we can enhance this logic later, like tools/call can be long + // running for our use case should be fine + // Use SSE for requests that are streaming operations or have specific methods + MCPSession session = sessions.get(sessionId); + return session != null; + } + + /** + * Determine if SSE should be used for batch requests + */ + private boolean shouldUseSSEForBatch(JsonNode batchRequest, String sessionId) { + // Use SSE for batches containing streaming operations + for (JsonNode request : batchRequest) { + if (shouldUseSSEForRequest(request, sessionId)) { + return true; + } + } + return false; + } + + /** + * Process batch responses and notifications + */ + private void processBatchResponsesOrNotifications(JsonNode batch, String sessionId) { + for (JsonNode message : batch) { + boolean isResponse = message.has("result") || message.has("error"); + if (isResponse) { + processClientResponse(message, sessionId); + } else { + processNotification(message, sessionId); + } + } + } + + /** + * Send batch JSON response + */ + private void sendBatchJSONResponse( + HttpServletResponse response, JsonNode batchRequest, String sessionId) throws IOException { + + List> responses = new ArrayList<>(); + + for (JsonNode request : batchRequest) { + Map jsonResponse = processRequest(request, sessionId); + responses.add(jsonResponse); + } + + String responseJson = objectMapper.writeValueAsString(responses); + + response.setContentType(CONTENT_TYPE_JSON); + response.setStatus(HttpServletResponse.SC_OK); + + if (sessionId != null) { + response.setHeader(SESSION_HEADER, sessionId); + } + + response.getWriter().write(responseJson); + response.getWriter().flush(); + } + + /** + * Start SSE stream for batch requests + */ + private void startSSEStreamForBatchRequests( + HttpServletRequest request, + HttpServletResponse response, + JsonNode batchRequest, + String sessionId) + throws IOException { + + // Set up SSE response headers + response.setContentType(CONTENT_TYPE_SSE); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + if (sessionId != null) { + response.setHeader(SESSION_HEADER, sessionId); + } + + response.setStatus(HttpServletResponse.SC_OK); + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(300000); // 5 minutes timeout + + SSEConnection connection = new SSEConnection(response.getWriter(), sessionId); + String connectionId = UUID.randomUUID().toString(); + sseConnections.put(connectionId, connection); + + // Process batch asynchronously + executorService.submit( + () -> { + try { + // Send server-initiated messages first + // sendServerInitiatedMessages(connection); + + // Process each request in the batch + List> responses = new ArrayList<>(); + for (JsonNode req : batchRequest) { + Map jsonResponse = processRequest(req, sessionId); + responses.add(jsonResponse); + } + + // Send batch response + String eventId = generateEventId(); + connection.sendEventWithId(eventId, objectMapper.writeValueAsString(responses)); + + connection.close(); + + } catch (Exception e) { + log("Error in SSE stream processing for batch", e); + try { + Map errorResponse = + createErrorResponse(null, -32603, "Internal error: " + e.getMessage()); + connection.sendEvent(objectMapper.writeValueAsString(errorResponse)); + connection.close(); + } catch (Exception ex) { + log("Error sending error response in batch SSE stream", ex); + } + } finally { + try { + asyncContext.complete(); + } catch (Exception e) { + log("Error completing async context for batch", e); + } + sseConnections.remove(connectionId); + } + }); + } + + /** + * Create error response object + */ + private Map createErrorResponse(JsonNode id, int code, String message) { + Map response = new HashMap<>(); + response.put("jsonrpc", "2.0"); + response.put("id", id); + response.put("error", createError(code, message)); + return response; + } + + /** + * Generate unique event ID for SSE + */ + private String generateEventId() { + return UUID.randomUUID().toString(); + } + + /** + * Enhanced SSE connection class with event ID support + */ + public static class SSEConnection { + private final PrintWriter writer; + private final String sessionId; + private volatile boolean closed = false; + private int eventCounter = 0; + + public SSEConnection(PrintWriter writer, String sessionId) { + this.writer = writer; + this.sessionId = sessionId; + } + + public synchronized void sendEvent(String data) throws IOException { + if (closed) return; + + writer.printf("id: %d%n", ++eventCounter); + writer.printf("data: %s%n%n", data); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("SSE write error"); + } + } + + public synchronized void sendEventWithId(String id, String data) throws IOException { + if (closed) return; + + writer.printf("id: %s%n", id); + writer.printf("data: %s%n%n", data); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("SSE write error"); + } + } + + public synchronized void sendComment(String comment) throws IOException { + if (closed) return; + + writer.printf(": %s%n%n", comment); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("SSE write error"); + } + } + + public void close() throws IOException { + closed = true; + if (writer != null) { + writer.close(); + } + } + + public boolean isClosed() { + return closed; + } + + public String getSessionId() { + return sessionId; + } + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String sessionId = request.getHeader(SESSION_HEADER); + String acceptHeader = request.getHeader("Accept"); + + if (acceptHeader == null || !acceptHeader.contains(CONTENT_TYPE_SSE)) { + sendError( + response, + HttpServletResponse.SC_BAD_REQUEST, + "Accept header must include text/event-stream"); + return; + } + + if (sessionId != null && !sessions.containsKey(sessionId)) { + sendError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found"); + return; + } + + startSSEStreamForGet(request, response, sessionId); + } + + @Override + protected void doDelete(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + String sessionId = request.getHeader(SESSION_HEADER); + removeMCPSession(sessionId); + response.setStatus(HttpServletResponse.SC_OK); + } + + private void removeMCPSession(String sessionId) { + if (sessionId == null) { + throw BadRequestException.of("Session ID required"); + } + + // Terminate session + MCPSession session = sessions.remove(sessionId); + if (session != null) { + LOG.info("Session terminated: " + sessionId); + } + } + + private void handleInitialize(HttpServletResponse response, JsonNode request) throws IOException { + // Create new session + String sessionId = generateSessionId(); + MCPSession session = new MCPSession(this.objectMapper, sessionId, response.getWriter()); + sessions.put(sessionId, session); + + // Create initialize response + Map jsonResponse = new HashMap<>(); + jsonResponse.put("jsonrpc", "2.0"); + jsonResponse.put("id", request.get("id")); + + Map result = new HashMap<>(); + result.put("protocolVersion", "2024-11-05"); + result.put("capabilities", getServerCapabilities()); + result.put("serverInfo", getServerInfo()); + jsonResponse.put("result", result); + + // Send response with session ID + String responseJson = objectMapper.writeValueAsString(jsonResponse); + response.setContentType(CONTENT_TYPE_JSON); + response.setHeader(SESSION_HEADER, sessionId); + response.setStatus(HttpServletResponse.SC_OK); + response.getWriter().write(responseJson); + + log("New session initialized: " + sessionId); + } + + private void startSSEStreamForGet( + HttpServletRequest request, HttpServletResponse response, String sessionId) + throws IOException { + + response.setContentType(CONTENT_TYPE_SSE); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + response.setStatus(HttpServletResponse.SC_OK); + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); // No timeout for long-lived connections + + SSEConnection connection = new SSEConnection(response.getWriter(), sessionId); + String connectionId = UUID.randomUUID().toString(); + sseConnections.put(connectionId, connection); + + // Keep connection alive and handle server-initiated messages + executorService.submit( + () -> { + try { + while (!connection.isClosed()) { + // Send keepalive every 30 seconds + Thread.sleep(30000); + connection.sendComment("keepalive"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (IOException e) { + LOG.error("SSE connection error for connection ID: {}", connectionId, e); + } finally { + try { + connection.close(); + asyncContext.complete(); + } catch (Exception e) { + log("Error closing long-lived SSE connection", e); + } + sseConnections.remove(connectionId); + } + }); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + @Override + public Mono notifyClients(String method, Map params) { + if (sessions.isEmpty()) { + LOG.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + LOG.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + return Flux.fromIterable(sessions.values()) + .flatMap( + session -> + session + .sendNotification(method, params) + .doOnError( + e -> + LOG.error( + "Failed to send message to session {}: {}", + session.getSessionId(), + e.getMessage())) + .onErrorComplete()) + .then(); + } + + @Override + public Mono closeGracefully() { + LOG.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + return Flux.fromIterable(sessions.values()).flatMap(MCPSession::closeGracefully).then(); + } + + private static class MCPSession { + @Getter private final String sessionId; + private final long createdAt; + private final Map state; + private final PrintWriter outputStream; + private final ObjectMapper objectMapper; + + public MCPSession(ObjectMapper objectMapper, String sessionId, PrintWriter outputStream) { + this.sessionId = sessionId; + this.createdAt = System.currentTimeMillis(); + this.state = new ConcurrentHashMap<>(); + this.objectMapper = objectMapper; + this.outputStream = outputStream; + } + + public String getSessionId() { + return sessionId; + } + + public long getCreatedAt() { + return createdAt; + } + + public Map getState() { + return state; + } + + public Mono sendNotification(String method, Map params) { + McpSchema.JSONRPCNotification jsonrpcNotification = + new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); + return Mono.fromRunnable( + () -> { + try { + String json = objectMapper.writeValueAsString(jsonrpcNotification); + outputStream.write(json); + outputStream.write('\n'); + outputStream.flush(); + } catch (IOException e) { + LOG.error("Failed to send message", e); + } + }); + } + + public Mono closeGracefully() { + return Mono.fromRunnable( + () -> { + outputStream.flush(); + outputStream.close(); + }); + } + } + + private Map processRequest(JsonNode request, String sessionId) { + Map response = new HashMap<>(); + response.put("jsonrpc", "2.0"); + response.put("id", request.get("id")); + + String method = request.get("method").asText(); + + try { + switch (method) { + case "ping": + response.put("result", "pong"); + break; + + case "resources/list": + response.put("result", new McpSchema.ListResourcesResult(new ArrayList<>(), null)); + break; + + case "resources/read": + // TODO: Implement resource reading logic + break; + + case "tools/list": + response.put("result", new McpSchema.ListToolsResult(tools, null)); + break; + + case "tools/call": + JsonNode toolParams = request.get("params"); + if (toolParams != null && toolParams.has("name")) { + String toolName = toolParams.get("name").asText(); + JsonNode arguments = toolParams.get("arguments"); + McpSchema.Content content = + new McpSchema.TextContent( + JsonUtils.pojoToJson( + toolContext.callTool( + authorizer, jwtFilter, limits, toolName, JsonUtils.getMap(arguments)))); + response.put("result", new McpSchema.CallToolResult(List.of(content), false)); + } else { + response.put("error", createError(-32602, "Invalid params")); + } + break; + + default: + response.put("error", createError(-32601, "Method not found: " + method)); + } + } catch (Exception e) { + log("Error processing request: " + method, e); + response.put("error", createError(-32603, "Internal error: " + e.getMessage())); + } + + return response; + } + + private void processNotification(JsonNode notification, String sessionId) { + String method = notification.get("method").asText(); + log("Received notification: " + method + " (session: " + sessionId + ")"); + + // Handle specific notifications + switch (method) { + case "notifications/initialized": + LOG.info("Client initialized for session: {}", sessionId); + break; + case "notifications/cancelled": + // TODO: Check this + LOG.info("Client sent a cancellation request for session: {}", sessionId); + removeMCPSession(sessionId); + break; + default: + log("Unknown notification: " + method); + } + } + + // Utility methods + private String generateSessionId() { + byte[] bytes = new byte[32]; + secureRandom.nextBytes(bytes); + return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes); + } + + private boolean isValidOrigin(String origin) { + if (null == origin || origin.isEmpty()) { + return false; + } + return origin.startsWith(mcpConfiguration.getOriginHeaderUri()); + } + + private String readRequestBody(HttpServletRequest request) throws IOException { + StringBuilder body = new StringBuilder(); + try (BufferedReader reader = request.getReader()) { + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + } + return body.toString(); + } + + private void sendError(HttpServletResponse response, int statusCode, String message) + throws IOException { + Map error = new HashMap<>(); + error.put("error", message); + + String errorJson = objectMapper.writeValueAsString(error); + response.setContentType(CONTENT_TYPE_JSON); + response.setStatus(statusCode); + response.getWriter().write(errorJson); + } + + private Map createError(int code, String message) { + Map error = new HashMap<>(); + error.put("code", code); + error.put("message", message); + return error; + } + + private McpSchema.ServerCapabilities getServerCapabilities() { + return McpSchema.ServerCapabilities.builder() + .tools(true) + .prompts(true) + .resources(true, true) + .build(); + } + + private McpSchema.Implementation getServerInfo() { + return new McpSchema.Implementation("OpenMetadata MCP Server - Streamable", "1.0.0"); + } + + /** + * Send a server-initiated notification to a specific session + */ + public void sendNotificationToSession( + String sessionId, String method, Map params) { + // Find SSE connections for this session + for (SSEConnection connection : sseConnections.values()) { + if (sessionId.equals(connection.getSessionId())) { + try { + Map notification = new HashMap<>(); + notification.put("jsonrpc", "2.0"); + notification.put("method", method); + if (params != null) { + notification.put("params", params); + } + + connection.sendEvent(objectMapper.writeValueAsString(notification)); + } catch (IOException e) { + log("Error sending notification to session: " + sessionId, e); + } + } + } + } +} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpServer.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpServer.java index 3decc827e5c..25d476b559c 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpServer.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpServer.java @@ -1,33 +1,23 @@ package org.openmetadata.service.mcp; -import static org.openmetadata.service.search.SearchUtil.searchMetadata; - -import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import io.dropwizard.core.setup.Environment; import io.dropwizard.jetty.MutableServletContextHandler; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.McpSyncServer; import io.modelcontextprotocol.spec.McpSchema; import jakarta.servlet.DispatcherType; -import jakarta.servlet.Filter; -import java.util.ArrayList; import java.util.EnumSet; import java.util.List; -import java.util.Map; import lombok.extern.slf4j.Slf4j; import org.eclipse.jetty.servlet.FilterHolder; import org.eclipse.jetty.servlet.ServletHolder; -import org.openmetadata.common.utils.CommonUtil; import org.openmetadata.service.OpenMetadataApplicationConfig; +import org.openmetadata.service.config.MCPConfiguration; import org.openmetadata.service.limits.Limits; -import org.openmetadata.service.mcp.tools.GlossaryTermTool; -import org.openmetadata.service.mcp.tools.GlossaryTool; -import org.openmetadata.service.mcp.tools.PatchEntityTool; -import org.openmetadata.service.security.AuthorizationException; +import org.openmetadata.service.mcp.tools.DefaultToolContext; import org.openmetadata.service.security.Authorizer; import org.openmetadata.service.security.JwtFilter; -import org.openmetadata.service.security.auth.CatalogSecurityContext; -import org.openmetadata.service.util.EntityUtil; import org.openmetadata.service.util.JsonUtils; @Slf4j @@ -35,8 +25,11 @@ public class McpServer { private JwtFilter jwtFilter; private Authorizer authorizer; private Limits limits; + protected DefaultToolContext toolContext; - public McpServer() {} + public McpServer(DefaultToolContext toolContext) { + this.toolContext = toolContext; + } public void initializeMcpServer( Environment environment, @@ -47,6 +40,24 @@ public class McpServer { new JwtFilter(config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration()); this.authorizer = authorizer; this.limits = limits; + MutableServletContextHandler contextHandler = environment.getApplicationContext(); + McpAuthFilter authFilter = + new McpAuthFilter( + new JwtFilter( + config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration())); + List tools = getTools(); + addSSETransport(contextHandler, authFilter, tools); + addStreamableHttpServlet(config.getMcpConfiguration(), contextHandler, authFilter, tools); + } + + protected List getTools() { + return toolContext.loadToolsDefinitionsFromJson("json/data/mcp/tools.json"); + } + + private void addSSETransport( + MutableServletContextHandler contextHandler, + McpAuthFilter authFilter, + List tools) { McpSchema.ServerCapabilities serverCapabilities = McpSchema.ServerCapabilities.builder() .tools(true) @@ -54,152 +65,55 @@ public class McpServer { .resources(true, true) .build(); - HttpServletSseServerTransportProvider transport = - new HttpServletSseServerTransportProvider("/mcp/messages", "/mcp/sse"); + HttpServletSseServerTransportProvider sseTransport = + new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/messages", "/mcp/sse"); + McpSyncServer server = - io.modelcontextprotocol.server.McpServer.sync(transport) - .serverInfo("openmetadata-mcp", "0.1.0") + io.modelcontextprotocol.server.McpServer.sync(sseTransport) + .serverInfo("openmetadata-mcp-sse", "0.1.0") .capabilities(serverCapabilities) .build(); + addToolsToServer(server, tools); - // Add resources, prompts, and tools to the MCP server - addTools(server); + // SSE transport for MCP + ServletHolder servletHolderSSE = new ServletHolder(sseTransport); + contextHandler.addServlet(servletHolderSSE, "/mcp/*"); - MutableServletContextHandler contextHandler = environment.getApplicationContext(); - ServletHolder servletHolder = new ServletHolder(transport); - contextHandler.addServlet(servletHolder, "/mcp/*"); - - McpAuthFilter authFilter = - new McpAuthFilter( - new JwtFilter( - config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration())); contextHandler.addFilter( - new FilterHolder((Filter) authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST)); + new FilterHolder(authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST)); } - public void addTools(McpSyncServer server) { - try { - LOG.info("Loading tool definitions..."); - List> cachedTools = loadToolsDefinitionsFromJson(); - if (cachedTools == null || cachedTools.isEmpty()) { - LOG.error("No tool definitions were loaded!"); - throw new RuntimeException("Failed to load tool definitions"); - } - LOG.info("Successfully loaded {} tool definitions", cachedTools.size()); + private void addStreamableHttpServlet( + MCPConfiguration configuration, + MutableServletContextHandler contextHandler, + McpAuthFilter authFilter, + List tools) { + // Streamable HTTP servlet for MCP + MCPStreamableHttpServlet streamableHttpServlet = + new MCPStreamableHttpServlet( + configuration, jwtFilter, authorizer, limits, toolContext, tools); + ServletHolder servletHolderStreamableHttp = new ServletHolder(streamableHttpServlet); + contextHandler.addServlet(servletHolderStreamableHttp, "/mcp"); - for (Map toolDef : cachedTools) { - try { - String name = (String) toolDef.get("name"); - String description = (String) toolDef.get("description"); - Map schema = JsonUtils.getMap(toolDef.get("parameters")); - server.addTool(getTool(JsonUtils.pojoToJson(schema), name, description)); - } catch (Exception e) { - LOG.error("Error processing tool definition: {}", toolDef, e); - } - } - LOG.info("Initializing request handlers..."); - } catch (Exception e) { - LOG.error("Error during server startup", e); - throw new RuntimeException("Failed to start MCP server", e); + contextHandler.addFilter( + new FilterHolder(authFilter), "/mcp", EnumSet.of(DispatcherType.REQUEST)); + } + + public void addToolsToServer(McpSyncServer server, List tools) { + for (McpSchema.Tool tool : tools) { + server.addTool(getTool(tool)); } } - protected List> loadToolsDefinitionsFromJson() { - String json = getJsonFromFile("json/data/mcp/tools.json"); - return loadToolDefinitionsFromJson(json); - } - - protected static String getJsonFromFile(String path) { - try { - return CommonUtil.getResourceAsStream(McpServer.class.getClassLoader(), path); - } catch (Exception ex) { - LOG.error("Error loading JSON file: {}", path, ex); - return null; - } - } - - @SuppressWarnings("unchecked") - public List> loadToolDefinitionsFromJson(String json) { - try { - LOG.info("Loaded tool definitions, content length: {}", json.length()); - LOG.info("Raw tools.json content: {}", json); - - JsonNode toolsJson = JsonUtils.readTree(json); - JsonNode toolsArray = toolsJson.get("tools"); - - if (toolsArray == null || !toolsArray.isArray()) { - LOG.error("Invalid MCP tools file format. Expected 'tools' array."); - return new ArrayList<>(); - } - - List> tools = new ArrayList<>(); - for (JsonNode toolNode : toolsArray) { - String name = toolNode.get("name").asText(); - Map toolDef = JsonUtils.convertValue(toolNode, Map.class); - tools.add(toolDef); - LOG.info("Tool found: {} with definition: {}", name, toolDef); - } - - LOG.info("Found {} tool definitions", tools.size()); - return tools; - } catch (Exception e) { - LOG.error("Error loading tool definitions: {}", e.getMessage(), e); - throw e; - } - } - - private McpServerFeatures.SyncToolSpecification getTool( - String schema, String toolName, String description) { - McpSchema.Tool tool = new McpSchema.Tool(toolName, description, schema); - + private McpServerFeatures.SyncToolSpecification getTool(McpSchema.Tool tool) { return new McpServerFeatures.SyncToolSpecification( tool, (exchange, arguments) -> { McpSchema.Content content = - new McpSchema.TextContent(JsonUtils.pojoToJson(runMethod(toolName, arguments))); + new McpSchema.TextContent( + JsonUtils.pojoToJson( + toolContext.callTool(authorizer, jwtFilter, limits, tool.name(), arguments))); return new McpSchema.CallToolResult(List.of(content), false); }); } - - protected Object runMethod(String toolName, Map params) { - CatalogSecurityContext securityContext = - jwtFilter.getCatalogSecurityContext((String) params.get("Authorization")); - LOG.info( - "Catalog Principal: {} is trying to call the tool: {}", - securityContext.getUserPrincipal().getName(), - toolName); - Object result; - try { - switch (toolName) { - case "search_metadata": - result = searchMetadata(params); - break; - case "get_entity_details": - result = EntityUtil.getEntityDetails(params); - break; - case "create_glossary": - result = new GlossaryTool().execute(authorizer, limits, securityContext, params); - break; - case "create_glossary_term": - result = new GlossaryTermTool().execute(authorizer, limits, securityContext, params); - break; - case "patch_entity": - result = new PatchEntityTool().execute(authorizer, securityContext, params); - break; - default: - result = Map.of("error", "Unknown function: " + toolName); - break; - } - - return result; - } catch (AuthorizationException ex) { - LOG.error("Authorization error: {}", ex.getMessage()); - return Map.of( - "error", String.format("Authorization error: %s", ex.getMessage()), "statusCode", 403); - } catch (Exception ex) { - LOG.error("Error executing tool: {}", ex.getMessage()); - return Map.of( - "error", String.format("Error executing tool: %s", ex.getMessage()), "statusCode", 500); - } - } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpUtils.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpUtils.java new file mode 100644 index 00000000000..3547ad92323 --- /dev/null +++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpUtils.java @@ -0,0 +1,94 @@ +package org.openmetadata.service.mcp; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.servlet.http.HttpServletRequest; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.openmetadata.common.utils.CommonUtil; +import org.openmetadata.service.security.JwtFilter; +import org.openmetadata.service.util.JsonUtils; + +@Slf4j +public class McpUtils { + + public static McpSchema.JSONRPCMessage getJsonRpcMessageWithAuthorizationParam( + ObjectMapper objectMapper, HttpServletRequest request, String body) throws IOException { + Map requestMessage = JsonUtils.getMap(JsonUtils.readTree(body)); + Map params = (Map) requestMessage.get("params"); + if (params != null) { + Map arguments = (Map) params.get("arguments"); + if (arguments != null) { + arguments.put("Authorization", JwtFilter.extractToken(request.getHeader("Authorization"))); + } + } + return McpSchema.deserializeJsonRpcMessage(objectMapper, JsonUtils.pojoToJson(requestMessage)); + } + + @SuppressWarnings("unchecked") + public static List> loadToolDefinitionsFromJson(String json) { + try { + LOG.info("Loaded tool definitions, content length: {}", json.length()); + LOG.info("Raw tools.json content: {}", json); + + JsonNode toolsJson = JsonUtils.readTree(json); + JsonNode toolsArray = toolsJson.get("tools"); + + if (toolsArray == null || !toolsArray.isArray()) { + LOG.error("Invalid MCP tools file format. Expected 'tools' array."); + return new ArrayList<>(); + } + + List> tools = new ArrayList<>(); + for (JsonNode toolNode : toolsArray) { + String name = toolNode.get("name").asText(); + Map toolDef = JsonUtils.convertValue(toolNode, Map.class); + tools.add(toolDef); + LOG.info("Tool found: {} with definition: {}", name, toolDef); + } + + LOG.info("Found {} tool definitions", tools.size()); + return tools; + } catch (Exception e) { + LOG.error("Error loading tool definitions: {}", e.getMessage(), e); + throw e; + } + } + + public static String getJsonFromFile(String path) { + try { + return CommonUtil.getResourceAsStream(McpServer.class.getClassLoader(), path); + } catch (Exception ex) { + LOG.error("Error loading JSON file: {}", path, ex); + return null; + } + } + + public static List getToolProperties(String jsonFilePath) { + try { + List result = new ArrayList<>(); + String json = getJsonFromFile(jsonFilePath); + List> cachedTools = loadToolDefinitionsFromJson(json); + if (cachedTools == null || cachedTools.isEmpty()) { + LOG.error("No tool definitions were loaded!"); + throw new RuntimeException("Failed to load tool definitions"); + } + LOG.debug("Successfully loaded {} tool definitions", cachedTools.size()); + for (int i = 0; i < cachedTools.size(); i++) { + Map toolDef = cachedTools.get(i); + String name = (String) toolDef.get("name"); + String description = (String) toolDef.get("description"); + Map schema = JsonUtils.getMap(toolDef.get("parameters")); + result.add(new McpSchema.Tool(name, description, JsonUtils.pojoToJson(schema))); + } + return result; + } catch (Exception e) { + LOG.error("Error during server startup", e); + throw new RuntimeException("Failed to start MCP server", e); + } + } +} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/DefaultToolContext.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/DefaultToolContext.java new file mode 100644 index 00000000000..1af07246787 --- /dev/null +++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/DefaultToolContext.java @@ -0,0 +1,75 @@ +package org.openmetadata.service.mcp.tools; + +import static org.openmetadata.service.mcp.McpUtils.getToolProperties; + +import io.modelcontextprotocol.spec.McpSchema; +import java.util.List; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.openmetadata.service.limits.Limits; +import org.openmetadata.service.security.AuthorizationException; +import org.openmetadata.service.security.Authorizer; +import org.openmetadata.service.security.JwtFilter; +import org.openmetadata.service.security.auth.CatalogSecurityContext; + +@Slf4j +public class DefaultToolContext { + public DefaultToolContext() {} + + /** + * Loads tool definitions from a JSON file located at the specified path. + * The JSON file should contain an array of tool definitions under the "tools" key. + * + * @return List of McpSchema.Tool objects loaded from the JSON file. + */ + public List loadToolsDefinitionsFromJson(String toolFilePath) { + return getToolProperties(toolFilePath); + } + + public Object callTool( + Authorizer authorizer, + JwtFilter jwtFilter, + Limits limits, + String toolName, + Map params) { + CatalogSecurityContext securityContext = + jwtFilter.getCatalogSecurityContext((String) params.get("Authorization")); + LOG.info( + "Catalog Principal: {} is trying to call the tool: {}", + securityContext.getUserPrincipal().getName(), + toolName); + Object result; + try { + switch (toolName) { + case "search_metadata": + result = new SearchMetadataTool().execute(authorizer, securityContext, params); + break; + case "get_entity_details": + result = new GetEntityTool().execute(authorizer, securityContext, params); + break; + case "create_glossary": + result = new GlossaryTool().execute(authorizer, limits, securityContext, params); + break; + case "create_glossary_term": + result = new GlossaryTermTool().execute(authorizer, limits, securityContext, params); + break; + case "patch_entity": + result = new PatchEntityTool().execute(authorizer, limits, securityContext, params); + break; + default: + result = Map.of("error", "Unknown function: " + toolName); + break; + } + + return result; + } catch (AuthorizationException ex) { + LOG.error("Authorization error: {}", ex.getMessage()); + return Map.of( + "error", String.format("Authorization error: %s", ex.getMessage()), "statusCode", 403); + } catch (Exception ex) { + LOG.error("Error executing tool: {}", ex.getMessage()); + return Map.of( + "error", String.format("Error executing tool: %s", ex.getMessage()), "statusCode", 500); + } + } +} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/GetEntityTool.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/GetEntityTool.java new file mode 100644 index 00000000000..6580115b5c3 --- /dev/null +++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/GetEntityTool.java @@ -0,0 +1,42 @@ +package org.openmetadata.service.mcp.tools; + +import static org.openmetadata.schema.type.MetadataOperation.VIEW_ALL; + +import java.io.IOException; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.openmetadata.service.Entity; +import org.openmetadata.service.limits.Limits; +import org.openmetadata.service.security.Authorizer; +import org.openmetadata.service.security.auth.CatalogSecurityContext; +import org.openmetadata.service.security.policyevaluator.OperationContext; +import org.openmetadata.service.security.policyevaluator.ResourceContext; +import org.openmetadata.service.util.JsonUtils; + +@Slf4j +public class GetEntityTool implements McpTool { + @Override + public Map execute( + Authorizer authorizer, CatalogSecurityContext securityContext, Map params) + throws IOException { + String entityType = (String) params.get("entity_type"); + String fqn = (String) params.get("fqn"); + authorizer.authorize( + securityContext, + new OperationContext(entityType, VIEW_ALL), + new ResourceContext<>(entityType)); + LOG.info("Getting details for entity type: {}, FQN: {}", entityType, fqn); + String fields = "*"; + return JsonUtils.getMap(Entity.getEntityByName(entityType, fqn, fields, null)); + } + + @Override + public Map execute( + Authorizer authorizer, + Limits limits, + CatalogSecurityContext securityContext, + Map params) + throws IOException { + throw new UnsupportedOperationException("GetEntityTool does not requires limit validation."); + } +} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/GlossaryTermTool.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/GlossaryTermTool.java index 4676e37826f..e8b8d437221 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/GlossaryTermTool.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/GlossaryTermTool.java @@ -2,7 +2,6 @@ package org.openmetadata.service.mcp.tools; import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty; -import java.util.HashMap; import java.util.Map; import lombok.extern.slf4j.Slf4j; import org.openmetadata.schema.entity.data.Glossary; @@ -59,25 +58,19 @@ public class GlossaryTermTool implements McpTool { limits.enforceLimits(securityContext, createResourceContext, operationContext); authorizer.authorize(securityContext, operationContext, createResourceContext); - try { - GlossaryRepository glossaryRepository = - (GlossaryRepository) Entity.getEntityRepository(Entity.GLOSSARY); - Glossary glossary = - glossaryRepository.findByNameOrNull(createGlossaryTerm.getGlossary(), Include.ALL); + GlossaryRepository glossaryRepository = + (GlossaryRepository) Entity.getEntityRepository(Entity.GLOSSARY); + Glossary glossary = + glossaryRepository.findByNameOrNull(createGlossaryTerm.getGlossary(), Include.ALL); - GlossaryTermRepository glossaryTermRepository = - (GlossaryTermRepository) Entity.getEntityRepository(Entity.GLOSSARY_TERM); - // TODO: Get the updatedBy from the tool request. - glossaryTermRepository.prepare(glossaryTerm, nullOrEmpty(glossary)); - glossaryTermRepository.setFullyQualifiedName(glossaryTerm); - RestUtil.PutResponse response = - glossaryTermRepository.createOrUpdate( - null, glossaryTerm, securityContext.getUserPrincipal().getName()); - return JsonUtils.convertValue(response.getEntity(), Map.class); - } catch (Exception e) { - Map error = new HashMap<>(); - error.put("error", e.getMessage()); - return error; - } + GlossaryTermRepository glossaryTermRepository = + (GlossaryTermRepository) Entity.getEntityRepository(Entity.GLOSSARY_TERM); + // TODO: Get the updatedBy from the tool request. + glossaryTermRepository.prepare(glossaryTerm, nullOrEmpty(glossary)); + glossaryTermRepository.setFullyQualifiedName(glossaryTerm); + RestUtil.PutResponse response = + glossaryTermRepository.createOrUpdate( + null, glossaryTerm, securityContext.getUserPrincipal().getName()); + return JsonUtils.getMap(response.getEntity()); } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/GlossaryTool.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/GlossaryTool.java index bce61770d02..8b7275f5d5a 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/GlossaryTool.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/GlossaryTool.java @@ -1,6 +1,5 @@ package org.openmetadata.service.mcp.tools; -import java.util.HashMap; import java.util.List; import java.util.Map; import lombok.extern.slf4j.Slf4j; @@ -26,7 +25,7 @@ public class GlossaryTool implements McpTool { @Override public Map execute( Authorizer authorizer, CatalogSecurityContext securityContext, Map params) { - throw new UnsupportedOperationException("GlossaryTermTool requires limit validation."); + throw new UnsupportedOperationException("GlossaryTool requires limit validation."); } @Override @@ -56,21 +55,15 @@ public class GlossaryTool implements McpTool { limits.enforceLimits(securityContext, createResourceContext, operationContext); authorizer.authorize(securityContext, operationContext, createResourceContext); - try { - GlossaryRepository glossaryRepository = - (GlossaryRepository) Entity.getEntityRepository(Entity.GLOSSARY); + GlossaryRepository glossaryRepository = + (GlossaryRepository) Entity.getEntityRepository(Entity.GLOSSARY); - glossaryRepository.prepare(glossary, true); - glossaryRepository.setFullyQualifiedName(glossary); - RestUtil.PutResponse response = - glossaryRepository.createOrUpdate( - null, glossary, securityContext.getUserPrincipal().getName()); - return JsonUtils.convertValue(response.getEntity(), Map.class); - } catch (Exception e) { - Map error = new HashMap<>(); - error.put("error", e.getMessage()); - return error; - } + glossaryRepository.prepare(glossary, true); + glossaryRepository.setFullyQualifiedName(glossary); + RestUtil.PutResponse response = + glossaryRepository.createOrUpdate( + null, glossary, securityContext.getUserPrincipal().getName()); + return JsonUtils.convertValue(response.getEntity(), Map.class); } public static void setReviewers(CreateGlossary entity, Map params) { diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/McpTool.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/McpTool.java index 7fdb82ad890..41ee16cda2d 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/McpTool.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/McpTool.java @@ -1,5 +1,6 @@ package org.openmetadata.service.mcp.tools; +import java.io.IOException; import java.util.Map; import org.openmetadata.service.limits.Limits; import org.openmetadata.service.security.Authorizer; @@ -7,11 +8,13 @@ import org.openmetadata.service.security.auth.CatalogSecurityContext; public interface McpTool { Map execute( - Authorizer authorizer, CatalogSecurityContext securityContext, Map params); + Authorizer authorizer, CatalogSecurityContext securityContext, Map params) + throws IOException; Map execute( Authorizer authorizer, Limits limits, CatalogSecurityContext securityContext, - Map params); + Map params) + throws IOException; } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/SearchMetadataTool.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/SearchMetadataTool.java new file mode 100644 index 00000000000..8df71a1cdae --- /dev/null +++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/tools/SearchMetadataTool.java @@ -0,0 +1,154 @@ +package org.openmetadata.service.mcp.tools; + +import static org.openmetadata.service.search.SearchUtil.mapEntityTypesToIndexNames; +import static org.openmetadata.service.security.DefaultAuthorizer.getSubjectContext; + +import com.fasterxml.jackson.databind.JsonNode; +import jakarta.ws.rs.core.Response; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.openmetadata.schema.search.SearchRequest; +import org.openmetadata.service.Entity; +import org.openmetadata.service.limits.Limits; +import org.openmetadata.service.security.Authorizer; +import org.openmetadata.service.security.auth.CatalogSecurityContext; +import org.openmetadata.service.security.policyevaluator.SubjectContext; +import org.openmetadata.service.util.JsonUtils; + +@Slf4j +public class SearchMetadataTool implements McpTool { + + private static final List IGNORE_SEARCH_KEYS = + List.of( + "id", + "version", + "updatedAt", + "updatedBy", + "usageSummary", + "followers", + "deleted", + "votes", + "lifeCycle", + "sourceHash", + "processedLineage", + "totalVotes", + "fqnParts", + "service_suggest", + "column_suggest", + "schema_suggest", + "database_suggest", + "upstreamLineage", + "entityRelationship", + "changeSummary", + "fqnHash"); + + @Override + public Map execute( + Authorizer authorizer, CatalogSecurityContext securityContext, Map params) + throws IOException { + LOG.info("Executing searchMetadata with params: {}", params); + String query = params.containsKey("query") ? (String) params.get("query") : "*"; + int limit = 10; + if (params.containsKey("limit")) { + Object limitObj = params.get("limit"); + if (limitObj instanceof Number) { + limit = ((Number) limitObj).intValue(); + } else if (limitObj instanceof String) { + limit = Integer.parseInt((String) limitObj); + } + } + + boolean includeDeleted = false; + if (params.containsKey("include_deleted")) { + Object deletedObj = params.get("include_deleted"); + if (deletedObj instanceof Boolean) { + includeDeleted = (Boolean) deletedObj; + } else if (deletedObj instanceof String) { + includeDeleted = "true".equals(deletedObj); + } + } + + String entityType = + params.containsKey("entity_type") ? (String) params.get("entity_type") : null; + String index = + (entityType != null && !entityType.isEmpty()) + ? mapEntityTypesToIndexNames(entityType) + : Entity.TABLE; + + LOG.info( + "Search query: {}, index: {}, limit: {}, includeDeleted: {}", + query, + index, + limit, + includeDeleted); + + SearchRequest searchRequest = + new SearchRequest() + .withQuery(query) + .withIndex(index) + .withSize(limit) + .withFrom(0) + .withFetchSource(true) + .withDeleted(includeDeleted); + + SubjectContext subjectContext = getSubjectContext(securityContext); + Response response = Entity.getSearchRepository().search(searchRequest, subjectContext); + + Map searchResponse; + if (response.getEntity() instanceof String responseStr) { + LOG.info("Search returned string response"); + JsonNode jsonNode = JsonUtils.readTree(responseStr); + searchResponse = JsonUtils.convertValue(jsonNode, Map.class); + } else { + LOG.info("Search returned object response: {}", response.getEntity().getClass().getName()); + searchResponse = JsonUtils.convertValue(response.getEntity(), Map.class); + } + return SearchMetadataTool.cleanSearchResponse(searchResponse); + } + + @Override + public Map execute( + Authorizer authorizer, + Limits limits, + CatalogSecurityContext securityContext, + Map params) { + throw new UnsupportedOperationException( + "SearchMetadataTool does not support limits enforcement."); + } + + private static Map cleanSearchResponse(Map searchResponse) { + if (searchResponse == null) return Collections.emptyMap(); + + Map topHits = safeGetMap(searchResponse.get("hits")); + if (topHits == null) return Collections.emptyMap(); + + List hits = safeGetList(topHits.get("hits")); + if (hits == null || hits.isEmpty()) return Collections.emptyMap(); + + for (Object hitObj : hits) { + Map hit = safeGetMap(hitObj); + if (hit == null) continue; + + Map source = safeGetMap(hit.get("_source")); + if (source == null) continue; + + IGNORE_SEARCH_KEYS.forEach(source::remove); + return source; // Return the first valid, cleaned _source + } + + return Collections.emptyMap(); + } + + @SuppressWarnings("unchecked") + private static Map safeGetMap(Object obj) { + return (obj instanceof Map) ? (Map) obj : null; + } + + @SuppressWarnings("unchecked") + private static List safeGetList(Object obj) { + return (obj instanceof List) ? (List) obj : null; + } +} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchUtil.java b/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchUtil.java index 47dd5fd21ac..40848550b41 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchUtil.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchUtil.java @@ -1,44 +1,10 @@ package org.openmetadata.service.search; -import com.fasterxml.jackson.databind.JsonNode; -import jakarta.ws.rs.core.Response; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; -import org.openmetadata.schema.search.SearchRequest; import org.openmetadata.service.Entity; -import org.openmetadata.service.util.JsonUtils; @Slf4j public class SearchUtil { - - private static final List IGNORE_SEARCH_KEYS = - List.of( - "id", - "version", - "updatedAt", - "updatedBy", - "usageSummary", - "followers", - "deleted", - "votes", - "lifeCycle", - "sourceHash", - "processedLineage", - "totalVotes", - "fqnParts", - "service_suggest", - "column_suggest", - "schema_suggest", - "database_suggest", - "upstreamLineage", - "entityRelationship", - "changeSummary", - "fqnHash"); - /** * Check if the index is a data asset index * @param indexName name of the index to check @@ -141,105 +107,10 @@ public class SearchUtil { case "glossary_search_index", Entity.GLOSSARY -> Entity.GLOSSARY; case "domain_search_index", Entity.DOMAIN -> Entity.DOMAIN; case "data_product_search_index", Entity.DATA_PRODUCT -> Entity.DATA_PRODUCT; - default -> "default"; + case "team_search_index", Entity.TEAM -> Entity.TEAM; + case "user_Search_index", Entity.USER -> Entity.USER; + case "dataAsset" -> "dataAsset"; + default -> "dataAsset"; }; } - - public static List searchMetadata(Map params) { - try { - LOG.info("Executing searchMetadata with params: {}", params); - String query = params.containsKey("query") ? (String) params.get("query") : "*"; - int limit = 10; - if (params.containsKey("limit")) { - Object limitObj = params.get("limit"); - if (limitObj instanceof Number) { - limit = ((Number) limitObj).intValue(); - } else if (limitObj instanceof String) { - limit = Integer.parseInt((String) limitObj); - } - } - - boolean includeDeleted = false; - if (params.containsKey("include_deleted")) { - Object deletedObj = params.get("include_deleted"); - if (deletedObj instanceof Boolean) { - includeDeleted = (Boolean) deletedObj; - } else if (deletedObj instanceof String) { - includeDeleted = "true".equals(deletedObj); - } - } - - String entityType = - params.containsKey("entity_type") ? (String) params.get("entity_type") : null; - String index = - (entityType != null && !entityType.isEmpty()) - ? mapEntityTypesToIndexNames(entityType) - : Entity.TABLE; - - LOG.info( - "Search query: {}, index: {}, limit: {}, includeDeleted: {}", - query, - index, - limit, - includeDeleted); - - SearchRequest searchRequest = - new SearchRequest() - .withQuery(query) - .withIndex(index) - .withSize(limit) - .withFrom(0) - .withFetchSource(true) - .withDeleted(includeDeleted); - - Response response = Entity.getSearchRepository().search(searchRequest, null); - - Map searchResponse; - if (response.getEntity() instanceof String responseStr) { - LOG.info("Search returned string response"); - JsonNode jsonNode = JsonUtils.readTree(responseStr); - searchResponse = JsonUtils.convertValue(jsonNode, Map.class); - } else { - LOG.info("Search returned object response: {}", response.getEntity().getClass().getName()); - searchResponse = JsonUtils.convertValue(response.getEntity(), Map.class); - } - return cleanSearchResponse(searchResponse); - } catch (Exception e) { - LOG.error("Error in searchMetadata", e); - return Collections.emptyList(); - } - } - - public static List cleanSearchResponse(Map searchResponse) { - if (searchResponse == null) return Collections.emptyList(); - - Map topHits = safeGetMap(searchResponse.get("hits")); - if (topHits == null) return Collections.emptyList(); - - List hits = safeGetList(topHits.get("hits")); - if (hits == null) return Collections.emptyList(); - - return hits.stream() - .map(SearchUtil::safeGetMap) - .filter(Objects::nonNull) - .map( - hit -> { - Map source = safeGetMap(hit.get("_source")); - if (source == null) return null; - IGNORE_SEARCH_KEYS.forEach(source::remove); - return source; - }) - .filter(Objects::nonNull) - .collect(Collectors.toList()); - } - - @SuppressWarnings("unchecked") - private static Map safeGetMap(Object obj) { - return (obj instanceof Map) ? (Map) obj : null; - } - - @SuppressWarnings("unchecked") - private static List safeGetList(Object obj) { - return (obj instanceof List) ? (List) obj : null; - } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/indexes/TableIndex.java b/openmetadata-service/src/main/java/org/openmetadata/service/search/indexes/TableIndex.java index ab3441035c4..a750975bc7b 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/search/indexes/TableIndex.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/indexes/TableIndex.java @@ -73,6 +73,7 @@ public record TableIndex(Table table) implements ColumnIndex { doc.put("processedLineage", table.getProcessedLineage()); doc.put("entityRelationship", SearchIndex.populateEntityRelationshipData(table)); doc.put("databaseSchema", getEntityWithDisplayName(table.getDatabaseSchema())); + doc.put("tableQueries", table.getTableQueries()); doc.put( "changeSummary", Optional.ofNullable(table.getChangeDescription()) diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/util/EntityUtil.java b/openmetadata-service/src/main/java/org/openmetadata/service/util/EntityUtil.java index 591605ca522..22a69cf0507 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/util/EntityUtil.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/util/EntityUtil.java @@ -37,7 +37,6 @@ import java.util.Comparator; import java.util.HashSet; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.TreeSet; import java.util.UUID; @@ -799,19 +798,4 @@ public final class EntityUtil { && changeDescription.getFieldsUpdated().isEmpty() && changeDescription.getFieldsDeleted().isEmpty(); } - - public static Object getEntityDetails(Map params) { - try { - String entityType = (String) params.get("entity_type"); - String fqn = (String) params.get("fqn"); - - LOG.info("Getting details for entity type: {}, FQN: {}", entityType, fqn); - String fields = "*"; - Object entity = Entity.getEntityByName(entityType, fqn, fields, null); - return entity; - } catch (Exception e) { - LOG.error("Error getting entity details", e); - return Map.of("error", e.getMessage()); - } - } } diff --git a/openmetadata-service/src/main/resources/json/data/mcp/tools.json b/openmetadata-service/src/main/resources/json/data/mcp/tools.json index 80135739855..030ff19f885 100644 --- a/openmetadata-service/src/main/resources/json/data/mcp/tools.json +++ b/openmetadata-service/src/main/resources/json/data/mcp/tools.json @@ -2,7 +2,7 @@ "tools": [ { "name": "search_metadata", - "description": "Find your data and business terms in OpenMetadata. For example if the user asks to 'find tables that contain customers information', then 'customers' should be the query, and the entity_type should be 'table'.", + "description": "Find your data and business terms in OpenMetadata. For example if the user asks to 'find tables that contain customers information', then 'customers' should be the query, and the entity_type should be 'table'. Here make sure to use 'Href' is available in result to create a hyperlink to the entity in OpenMetadata.", "parameters": { "description": "The search query to find metadata in the OpenMetadata catalog, entity type could be table, topic etc. Limit can be used to paginate on the data.", "type": "object", diff --git a/openmetadata-spec/src/main/resources/json/schema/entity/data/table.json b/openmetadata-spec/src/main/resources/json/schema/entity/data/table.json index c0559cf884b..3fe1d444fb1 100644 --- a/openmetadata-spec/src/main/resources/json/schema/entity/data/table.json +++ b/openmetadata-spec/src/main/resources/json/schema/entity/data/table.json @@ -1162,6 +1162,14 @@ "description": "Processed lineage for the table", "type": "boolean", "default": false + }, + "tableQueries": { + "description": "List of queries that are used to create this table.", + "type": "array", + "items": { + "$ref": "../../type/basic.json#/definitions/sqlQuery" + }, + "default": null } }, "required": [ diff --git a/openmetadata-ui/src/main/resources/ui/src/generated/entity/data/table.ts b/openmetadata-ui/src/main/resources/ui/src/generated/entity/data/table.ts index cbefaa21e5f..56e8bd4f52a 100644 --- a/openmetadata-ui/src/main/resources/ui/src/generated/entity/data/table.ts +++ b/openmetadata-ui/src/main/resources/ui/src/generated/entity/data/table.ts @@ -161,7 +161,11 @@ export interface Table { * Table Profiler Config to include or exclude columns from profiling. */ tableProfilerConfig?: TableProfilerConfig; - tableType?: TableType; + /** + * List of queries that are used to create this table. + */ + tableQueries?: string[]; + tableType?: TableType; /** * Tags for this table. */