From 635382dd1dac45172f8190b2d4a13ce947cdff00 Mon Sep 17 00:00:00 2001 From: Pere Miquel Brull Date: Fri, 6 Jun 2025 07:32:20 +0200 Subject: [PATCH] Revert "[WIP] MCP Core Items Improvements (#21598)" (#21614) This reverts commit 0b3bf4ac0d3a7ac74e39552ad49896d37e469516. --- .../service/jdbi3/TableRepository.java | 12 - ...HttpServletSseServerTransportProvider.java | 431 +++++------ .../service/mcp/MCPStreamableHttpServlet.java | 701 ------------------ .../openmetadata/service/mcp/McpServer.java | 188 +++-- .../openmetadata/service/mcp/McpUtils.java | 151 ---- .../service/search/SearchUtil.java | 5 +- .../service/search/indexes/TableIndex.java | 1 - .../json/schema/entity/data/table.json | 8 - .../ui/src/generated/entity/data/table.ts | 6 +- 9 files changed, 306 insertions(+), 1197 deletions(-) delete mode 100644 openmetadata-service/src/main/java/org/openmetadata/service/mcp/MCPStreamableHttpServlet.java delete mode 100644 openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpUtils.java diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TableRepository.java b/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TableRepository.java index 756f35732de..d52b2cb556e 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TableRepository.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TableRepository.java @@ -26,10 +26,8 @@ import static org.openmetadata.schema.type.Include.NON_DELETED; import static org.openmetadata.service.Entity.DATABASE_SCHEMA; import static org.openmetadata.service.Entity.FIELD_OWNERS; import static org.openmetadata.service.Entity.FIELD_TAGS; -import static org.openmetadata.service.Entity.QUERY; import static org.openmetadata.service.Entity.TABLE; import static org.openmetadata.service.Entity.TEST_SUITE; -import static org.openmetadata.service.Entity.getEntities; import static org.openmetadata.service.Entity.populateEntityFieldTags; import static org.openmetadata.service.util.EntityUtil.getLocalColumnName; import static org.openmetadata.service.util.FullyQualifiedName.getColumnName; @@ -64,7 +62,6 @@ import org.openmetadata.schema.EntityInterface; import org.openmetadata.schema.api.data.CreateTableProfile; import org.openmetadata.schema.api.feed.ResolveTask; import org.openmetadata.schema.entity.data.DatabaseSchema; -import org.openmetadata.schema.entity.data.Query; import org.openmetadata.schema.entity.data.Table; import org.openmetadata.schema.entity.feed.Suggestion; import org.openmetadata.schema.tests.CustomMetric; @@ -182,15 +179,6 @@ public class TableRepository extends EntityRepository { column.setCustomMetrics(getCustomMetrics(table, column.getName())); } } - if (fields.contains("tableQueries")) { - List queriesEntity = - getEntities( - listOrEmpty(findTo(table.getId(), TABLE, Relationship.MENTIONED_IN, QUERY)), - "id", - ALL); - List queries = listOrEmpty(queriesEntity).stream().map(Query::getQuery).toList(); - table.setTableQueries(queries); - } } @Override diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/HttpServletSseServerTransportProvider.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/HttpServletSseServerTransportProvider.java index 3a8e22932a2..3d931f7415b 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/HttpServletSseServerTransportProvider.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/HttpServletSseServerTransportProvider.java @@ -1,6 +1,8 @@ -package org.openmetadata.service.mcp; +/* +HttpServletSseServerTransportProvider - Jakarta servlet-based MCP server transport +*/ -import static org.openmetadata.service.mcp.McpUtils.getJsonRpcMessageWithAuthorizationParam; +package org.openmetadata.service.mcp; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; @@ -9,7 +11,6 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.util.Assert; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; import jakarta.servlet.annotation.WebServlet; @@ -22,10 +23,9 @@ import java.io.PrintWriter; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import org.openmetadata.service.security.JwtFilter; +import org.openmetadata.service.util.JsonUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -42,220 +42,171 @@ public class HttpServletSseServerTransportProvider extends HttpServlet public static final String DEFAULT_SSE_ENDPOINT = "/sse"; public static final String MESSAGE_EVENT_TYPE = "message"; public static final String ENDPOINT_EVENT_TYPE = "endpoint"; - public static final String DEFAULT_BASE_URL = ""; private final ObjectMapper objectMapper; - private final String baseUrl; private final String messageEndpoint; private final String sseEndpoint; - private final Map sessions = new ConcurrentHashMap<>(); - private final AtomicBoolean isClosing = new AtomicBoolean(false); + private final Map sessions; + private final AtomicBoolean isClosing; private McpServerSession.Factory sessionFactory; - private ExecutorService executorService = - Executors.newCachedThreadPool( - r -> { - Thread t = new Thread(r, "MCP-Worker-SSE"); - t.setDaemon(true); - return t; - }); - public HttpServletSseServerTransportProvider( - ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); - } - - public HttpServletSseServerTransportProvider( - ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { - this.objectMapper = objectMapper; - this.baseUrl = baseUrl; + public HttpServletSseServerTransportProvider(String messageEndpoint, String sseEndpoint) { + this.sessions = new ConcurrentHashMap<>(); + this.isClosing = new AtomicBoolean(false); + this.objectMapper = JsonUtils.getObjectMapper(); this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; } - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + public HttpServletSseServerTransportProvider(String messageEndpoint) { + this(messageEndpoint, "/sse"); } - @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.sessionFactory = sessionFactory; } - @Override public Mono notifyClients(String method, Map params) { - if (sessions.isEmpty()) { + if (this.sessions.isEmpty()) { logger.debug("No active sessions to broadcast message to"); return Mono.empty(); + } else { + logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size()); + return Flux.fromIterable(this.sessions.values()) + .flatMap( + (session) -> + session + .sendNotification(method, params) + .doOnError( + (e) -> + logger.error( + "Failed to send message to session {}: {}", + session.getId(), + e.getMessage())) + .onErrorComplete()) + .then(); } - - logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - return Flux.fromIterable(sessions.values()) - .flatMap( - session -> - session - .sendNotification(method, params) - .doOnError( - e -> - logger.error( - "Failed to send message to session {}: {}", - session.getId(), - e.getMessage())) - .onErrorComplete()) - .then(); } - @Override protected void doGet(HttpServletRequest request, HttpServletResponse response) - throws IOException { + throws ServletException, IOException { handleSseEvent(request, response); } private void handleSseEvent(HttpServletRequest request, HttpServletResponse response) - throws IOException { - String requestURI = request.getRequestURI(); - if (!requestURI.endsWith(sseEndpoint)) { - response.sendError(HttpServletResponse.SC_NOT_FOUND); - return; + throws ServletException, IOException { + String pathInfo = request.getPathInfo(); + if (!this.sseEndpoint.contains(pathInfo)) { + response.sendError(404); + } else if (this.isClosing.get()) { + response.sendError(503, "Server is shutting down"); + } else { + response.setContentType("text/event-stream"); + response.setCharacterEncoding("UTF-8"); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + String sessionId = UUID.randomUUID().toString(); + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0L); + PrintWriter writer = response.getWriter(); + HttpServletMcpSessionTransport sessionTransport = + new HttpServletMcpSessionTransport(sessionId, asyncContext, writer); + McpServerSession session = this.sessionFactory.create(sessionTransport); + this.sessions.put(sessionId, session); + this.sendEvent(writer, "endpoint", this.messageEndpoint + "?sessionId=" + sessionId); } - - if (isClosing.get()) { - response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); - return; - } - - response.setContentType("text/event-stream"); - response.setCharacterEncoding(UTF_8); - response.setHeader("Cache-Control", "no-cache"); - response.setHeader("Connection", "keep-alive"); - response.setHeader("Access-Control-Allow-Origin", "*"); - - String sessionId = UUID.randomUUID().toString(); - AsyncContext asyncContext = request.startAsync(); - asyncContext.setTimeout(0); - - PrintWriter writer = response.getWriter(); - - // Create a new session transport - HttpServletMcpSessionTransport sessionTransport = - new HttpServletMcpSessionTransport(sessionId, asyncContext, writer); - - // Create a new session using the session factory - McpServerSession session = sessionFactory.create(sessionTransport); - this.sessions.put(sessionId, session); - - executorService.submit( - () -> { - // TODO: Handle session lifecycle and keepalive - try { - while (sessions.containsKey(sessionId)) { - // Send keepalive every 30 seconds - Thread.sleep(30000); - writer.write(": keep-alive\n\n"); - writer.flush(); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } catch (Exception e) { - log("SSE error", e); - } finally { - try { - session.closeGracefully(); - asyncContext.complete(); - } catch (Exception e) { - log("Error closing long-lived SSE connection", e); - } - } - }); - - // Send initial endpoint event - this.sendEvent( - writer, - ENDPOINT_EVENT_TYPE, - this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); } - @Override protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { + if (this.isClosing.get()) { + response.sendError(503, "Server is shutting down"); + } else { + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(this.messageEndpoint)) { + response.sendError(404); + } else { + String sessionId = request.getParameter("sessionId"); + if (sessionId == null) { + response.setContentType("application/json"); + response.setCharacterEncoding("UTF-8"); + response.setStatus(400); + String jsonError = + this.objectMapper.writeValueAsString( + new McpError("Session ID missing in message endpoint")); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } else { + McpServerSession session = (McpServerSession) this.sessions.get(sessionId); + if (session == null) { + response.setContentType("application/json"); + response.setCharacterEncoding("UTF-8"); + response.setStatus(404); + String jsonError = + this.objectMapper.writeValueAsString( + new McpError("Session not found: " + sessionId)); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } else { + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); - if (isClosing.get()) { - response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); - return; - } + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } - String requestURI = request.getRequestURI(); - if (!requestURI.endsWith(messageEndpoint)) { - response.sendError(HttpServletResponse.SC_NOT_FOUND); - return; - } + McpSchema.JSONRPCMessage message = + getJsonRpcMessageWithAuthorizationParam(request, body.toString()); + session.handle(message).block(); + response.setStatus(200); + } catch (Exception var11) { + Exception e = var11; + logger.error("Error processing message: {}", var11.getMessage()); - // Get the session ID from the request parameter - String sessionId = request.getParameter("sessionId"); - if (sessionId == null) { - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - response.setStatus(HttpServletResponse.SC_BAD_REQUEST); - String jsonError = - objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint")); - PrintWriter writer = response.getWriter(); - writer.write(jsonError); - writer.flush(); - return; - } - - // Get the session from the sessions map - McpServerSession session = sessions.get(sessionId); - if (session == null) { - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - response.setStatus(HttpServletResponse.SC_NOT_FOUND); - String jsonError = - objectMapper.writeValueAsString(new McpError("Session not found: " + sessionId)); - PrintWriter writer = response.getWriter(); - writer.write(jsonError); - writer.flush(); - return; - } - - try { - BufferedReader reader = request.getReader(); - StringBuilder body = new StringBuilder(); - String line; - while ((line = reader.readLine()) != null) { - body.append(line); - } - - McpSchema.JSONRPCMessage message = - getJsonRpcMessageWithAuthorizationParam(this.objectMapper, request, body.toString()); - - // Process the message through the session's handle method - session.handle(message).block(); // Block for Servlet compatibility - - response.setStatus(HttpServletResponse.SC_OK); - } catch (Exception e) { - logger.error("Error processing message: {}", e.getMessage()); - try { - McpError mcpError = new McpError(e.getMessage()); - response.setContentType(APPLICATION_JSON); - response.setCharacterEncoding(UTF_8); - response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); - String jsonError = objectMapper.writeValueAsString(mcpError); - PrintWriter writer = response.getWriter(); - writer.write(jsonError); - writer.flush(); - } catch (IOException ex) { - logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); - response.sendError( - HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message"); + try { + McpError mcpError = new McpError(e.getMessage()); + response.setContentType("application/json"); + response.setCharacterEncoding("UTF-8"); + response.setStatus(500); + String jsonError = this.objectMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } catch (IOException ex) { + logger.error("Failed to send error response: {}", ex.getMessage()); + response.sendError(500, "Error processing message"); + } + } + } + } } } } - @Override - public Mono closeGracefully() { - isClosing.set(true); - logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + private McpSchema.JSONRPCMessage getJsonRpcMessageWithAuthorizationParam( + HttpServletRequest request, String body) throws IOException { + Map requestMessage = JsonUtils.getMap(JsonUtils.readTree(body)); + Map params = (Map) requestMessage.get("params"); + if (params != null) { + Map arguments = (Map) params.get("arguments"); + if (arguments != null) { + arguments.put("Authorization", JwtFilter.extractToken(request.getHeader("Authorization"))); + } + } + return McpSchema.deserializeJsonRpcMessage( + this.objectMapper, JsonUtils.pojoToJson(requestMessage)); + } - return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then(); + public Mono closeGracefully() { + this.isClosing.set(true); + logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()); + return Flux.fromIterable(this.sessions.values()) + .flatMap(McpServerSession::closeGracefully) + .then(); } private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException { @@ -267,19 +218,8 @@ public class HttpServletSseServerTransportProvider extends HttpServlet } } - @Override public void destroy() { - closeGracefully().block(); - if (executorService != null) { - executorService.shutdown(); - try { - if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) { - executorService.shutdownNow(); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - } + this.closeGracefully().block(); super.destroy(); } @@ -293,106 +233,65 @@ public class HttpServletSseServerTransportProvider extends HttpServlet this.sessionId = sessionId; this.asyncContext = asyncContext; this.writer = writer; - logger.debug("Session transport {} initialized with SSE writer", sessionId); + HttpServletSseServerTransportProvider.logger.debug( + "Session transport {} initialized with SSE writer", sessionId); } - @Override public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.fromRunnable( () -> { try { - String jsonText = objectMapper.writeValueAsString(message); - sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText); - logger.debug("Message sent to session {}", sessionId); + String jsonText = + HttpServletSseServerTransportProvider.this.objectMapper.writeValueAsString( + message); + HttpServletSseServerTransportProvider.this.sendEvent( + this.writer, "message", jsonText); + HttpServletSseServerTransportProvider.logger.debug( + "Message sent to session {}", this.sessionId); } catch (Exception e) { - logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); - sessions.remove(sessionId); - asyncContext.complete(); + HttpServletSseServerTransportProvider.logger.error( + "Failed to send message to session {}: {}", this.sessionId, e.getMessage()); + HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId); + this.asyncContext.complete(); } }); } - @Override public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + return (T) + HttpServletSseServerTransportProvider.this.objectMapper.convertValue(data, typeRef); } - @Override public Mono closeGracefully() { return Mono.fromRunnable( () -> { - logger.debug("Closing session transport: {}", sessionId); + HttpServletSseServerTransportProvider.logger.debug( + "Closing session transport: {}", this.sessionId); + try { - sessions.remove(sessionId); - asyncContext.complete(); - logger.debug("Successfully completed async context for session {}", sessionId); + HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId); + this.asyncContext.complete(); + HttpServletSseServerTransportProvider.logger.debug( + "Successfully completed async context for session {}", this.sessionId); } catch (Exception e) { - logger.warn( - "Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + HttpServletSseServerTransportProvider.logger.warn( + "Failed to complete async context for session {}: {}", + this.sessionId, + e.getMessage()); } }); } - @Override public void close() { try { - sessions.remove(sessionId); - asyncContext.complete(); - logger.debug("Successfully completed async context for session {}", sessionId); + HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId); + this.asyncContext.complete(); + HttpServletSseServerTransportProvider.logger.debug( + "Successfully completed async context for session {}", this.sessionId); } catch (Exception e) { - logger.warn( - "Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + HttpServletSseServerTransportProvider.logger.warn( + "Failed to complete async context for session {}: {}", this.sessionId, e.getMessage()); } } } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - private ObjectMapper objectMapper = new ObjectMapper(); - - private String baseUrl = DEFAULT_BASE_URL; - - private String messageEndpoint; - - private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - public Builder baseUrl(String baseUrl) { - Assert.notNull(baseUrl, "Base URL must not be null"); - this.baseUrl = baseUrl; - return this; - } - - public Builder messageEndpoint(String messageEndpoint) { - Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); - this.messageEndpoint = messageEndpoint; - return this; - } - - public Builder sseEndpoint(String sseEndpoint) { - Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); - this.sseEndpoint = sseEndpoint; - return this; - } - - public HttpServletSseServerTransportProvider build() { - if (objectMapper == null) { - throw new IllegalStateException("ObjectMapper must be set"); - } - if (messageEndpoint == null) { - throw new IllegalStateException("MessageEndpoint must be set"); - } - return new HttpServletSseServerTransportProvider( - objectMapper, baseUrl, messageEndpoint, sseEndpoint); - } - } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/MCPStreamableHttpServlet.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/MCPStreamableHttpServlet.java deleted file mode 100644 index 8e35120122d..00000000000 --- a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/MCPStreamableHttpServlet.java +++ /dev/null @@ -1,701 +0,0 @@ -package org.openmetadata.service.mcp; - -import static org.openmetadata.service.mcp.McpUtils.callTool; -import static org.openmetadata.service.mcp.McpUtils.getJsonRpcMessageWithAuthorizationParam; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import jakarta.servlet.AsyncContext; -import jakarta.servlet.ServletException; -import jakarta.servlet.annotation.WebServlet; -import jakarta.servlet.http.HttpServlet; -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.PrintWriter; -import java.security.SecureRandom; -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import lombok.Getter; -import lombok.extern.slf4j.Slf4j; -import org.openmetadata.service.limits.Limits; -import org.openmetadata.service.security.Authorizer; -import org.openmetadata.service.security.JwtFilter; -import org.openmetadata.service.util.JsonUtils; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** - * MCP (Model Context Protocol) Streamable HTTP Servlet - * This servlet implements the Streamable HTTP transport specification for MCP. - */ -@WebServlet(value = "/mcp", asyncSupported = true) -@Slf4j -public class MCPStreamableHttpServlet extends HttpServlet implements McpServerTransportProvider { - - private static final String SESSION_HEADER = "Mcp-Session-Id"; - private static final String CONTENT_TYPE_JSON = "application/json"; - private static final String CONTENT_TYPE_SSE = "text/event-stream"; - - private ObjectMapper objectMapper = new ObjectMapper(); - private Map sessions = new ConcurrentHashMap<>(); - private Map sseConnections = new ConcurrentHashMap<>(); - private SecureRandom secureRandom = new SecureRandom(); - private ExecutorService executorService = - Executors.newCachedThreadPool( - r -> { - Thread t = new Thread(r, "MCP-Worker-Streamable"); - t.setDaemon(true); - return t; - }); - private McpServerSession.Factory sessionFactory; - private final JwtFilter jwtFilter; - private final Authorizer authorizer; - private final List tools = new ArrayList<>(); - private final Limits limits; - - public MCPStreamableHttpServlet( - JwtFilter jwtFilter, Authorizer authorizer, Limits limits, List tools) { - this.jwtFilter = jwtFilter; - this.authorizer = authorizer; - this.limits = limits; - this.tools.addAll(tools); - } - - @Override - public void init() throws ServletException { - super.init(); - log("MCP Streamable HTTP Servlet initialized"); - } - - @Override - public void destroy() { - if (executorService != null) { - executorService.shutdown(); - try { - if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) { - executorService.shutdownNow(); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - } - - // Close all SSE connections - for (SSEConnection connection : sseConnections.values()) { - try { - connection.close(); - } catch (IOException e) { - log("Error closing SSE connection", e); - } - } - - super.destroy(); - } - - @Override - protected void doPost(HttpServletRequest request, HttpServletResponse response) - throws ServletException, IOException { - - // Security: Validate Origin header - String origin = request.getHeader("Origin"); - if (origin != null && !isValidOrigin(origin)) { - sendError(response, HttpServletResponse.SC_FORBIDDEN, "Invalid origin"); - return; - } - - // Validate Accept header - MUST include both application/json and text/event-stream - String acceptHeader = request.getHeader("Accept"); - if (acceptHeader == null - || !acceptHeader.contains(CONTENT_TYPE_JSON) - || !acceptHeader.contains(CONTENT_TYPE_SSE)) { - sendError( - response, - HttpServletResponse.SC_BAD_REQUEST, - "Accept header must include both application/json and text/event-stream"); - return; - } - - try { - String requestBody = readRequestBody(request); - String sessionId = request.getHeader(SESSION_HEADER); - - // Parse JSON-RPC message(s) - McpSchema.JSONRPCMessage message = - getJsonRpcMessageWithAuthorizationParam(this.objectMapper, request, requestBody); - JsonNode jsonNode = objectMapper.valueToTree(message); - - // TODO: here we need to see how to handle the batch request from the Spec - if (jsonNode.isArray()) { - handleBatchRequest(request, response, jsonNode, sessionId); - } else { - handleSingleRequest(request, response, jsonNode, sessionId); - } - } catch (Exception e) { - log("Error handling POST request", e); - sendError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Internal server error"); - } - } - - @Override - protected void doGet(HttpServletRequest request, HttpServletResponse response) - throws ServletException, IOException { - - String sessionId = request.getHeader(SESSION_HEADER); - String acceptHeader = request.getHeader("Accept"); - - if (acceptHeader == null || !acceptHeader.contains(CONTENT_TYPE_SSE)) { - sendError( - response, - HttpServletResponse.SC_BAD_REQUEST, - "Accept header must include text/event-stream"); - return; - } - - if (sessionId != null && !sessions.containsKey(sessionId)) { - sendError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found"); - return; - } - - startSSEStreamForGet(request, response, sessionId); - } - - @Override - protected void doDelete(HttpServletRequest request, HttpServletResponse response) - throws ServletException, IOException { - - String sessionId = request.getHeader(SESSION_HEADER); - - if (sessionId == null) { - sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Session ID required"); - return; - } - - // Terminate session - MCPSession session = sessions.remove(sessionId); - if (session != null) { - log("Session terminated: " + sessionId); - } - - response.setStatus(HttpServletResponse.SC_OK); - } - - private void handleSingleRequest( - HttpServletRequest request, - HttpServletResponse response, - JsonNode jsonRequest, - String sessionId) - throws IOException { - - String method = jsonRequest.has("method") ? jsonRequest.get("method").asText() : null; - boolean hasId = jsonRequest.has("id"); - - // Handle initialization - if ("initialize".equals(method)) { - handleInitialize(response, jsonRequest); - return; - } - - // Validate session for non-initialization requests - if (sessionId != null && !sessions.containsKey(sessionId)) { - sendError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found"); - return; - } - - // Handle different message types - if (!hasId) { - // Notification - return 202 Accepted - processNotification(jsonRequest, sessionId); - response.setStatus(HttpServletResponse.SC_ACCEPTED); - } else { - // Request - may return JSON or start SSE stream - String acceptHeader = request.getHeader("Accept"); - boolean supportsSSE = acceptHeader != null && acceptHeader.contains(CONTENT_TYPE_SSE); - - if (supportsSSE && shouldUseSSE()) { - startSSEStream(request, response, jsonRequest, sessionId); - } else { - sendJSONResponse(response, jsonRequest, sessionId); - } - } - } - - private void handleInitialize(HttpServletResponse response, JsonNode request) throws IOException { - // Create new session - String sessionId = generateSessionId(); - MCPSession session = new MCPSession(this.objectMapper, sessionId, response.getWriter()); - sessions.put(sessionId, session); - - // Create initialize response - Map jsonResponse = new HashMap<>(); - jsonResponse.put("jsonrpc", "2.0"); - jsonResponse.put("id", request.get("id")); - - Map result = new HashMap<>(); - result.put("protocolVersion", "2024-11-05"); - result.put("capabilities", getServerCapabilities()); - result.put("serverInfo", getServerInfo()); - jsonResponse.put("result", result); - - // Send response with session ID - String responseJson = objectMapper.writeValueAsString(jsonResponse); - response.setContentType(CONTENT_TYPE_JSON); - response.setHeader(SESSION_HEADER, sessionId); - response.setStatus(HttpServletResponse.SC_OK); - response.getWriter().write(responseJson); - - log("New session initialized: " + sessionId); - } - - private void startSSEStream( - HttpServletRequest request, - HttpServletResponse response, - JsonNode jsonRequest, - String sessionId) - throws IOException { - - // Set up SSE response headers - response.setContentType(CONTENT_TYPE_SSE); - response.setHeader("Cache-Control", "no-cache"); - response.setHeader("Connection", "keep-alive"); - response.setHeader("Access-Control-Allow-Origin", "*"); - response.setStatus(HttpServletResponse.SC_OK); - - // Start async processing - AsyncContext asyncContext = request.startAsync(); - asyncContext.setTimeout(0); // 5 minutes timeout - - SSEConnection connection = new SSEConnection(response.getWriter(), sessionId); - String connectionId = UUID.randomUUID().toString(); - sseConnections.put(connectionId, connection); - - // Process request asynchronously - executorService.submit( - () -> { - try { - // Send any server-initiated messages first (if needed) - // sendServerInitiatedMessages(connection); - - // Process the actual request - Map jsonResponse = processRequest(jsonRequest, sessionId); - connection.sendEvent(objectMapper.writeValueAsString(jsonResponse)); - - // Close the stream after sending response - connection.close(); - asyncContext.complete(); - - } catch (Exception e) { - log("Error in SSE stream processing", e); - try { - connection.close(); - asyncContext.complete(); - } catch (Exception ex) { - log("Error closing SSE connection", ex); - } - } finally { - sseConnections.remove(connectionId); - } - }); - } - - private void startSSEStreamForGet( - HttpServletRequest request, HttpServletResponse response, String sessionId) - throws IOException { - - response.setContentType(CONTENT_TYPE_SSE); - response.setHeader("Cache-Control", "no-cache"); - response.setHeader("Connection", "keep-alive"); - response.setHeader("Access-Control-Allow-Origin", "*"); - response.setStatus(HttpServletResponse.SC_OK); - - AsyncContext asyncContext = request.startAsync(); - asyncContext.setTimeout(0); // No timeout for long-lived connections - - SSEConnection connection = new SSEConnection(response.getWriter(), sessionId); - String connectionId = UUID.randomUUID().toString(); - sseConnections.put(connectionId, connection); - - // Keep connection alive and handle server-initiated messages - executorService.submit( - () -> { - try { - while (!connection.isClosed()) { - // Send keepalive every 30 seconds - Thread.sleep(30000); - connection.sendComment("keepalive"); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } catch (IOException e) { - LOG.error("SSE connection error for connection ID: {}", connectionId, e); - } finally { - try { - connection.close(); - asyncContext.complete(); - } catch (Exception e) { - log("Error closing long-lived SSE connection", e); - } - sseConnections.remove(connectionId); - } - }); - } - - @Override - public void setSessionFactory(McpServerSession.Factory sessionFactory) { - this.sessionFactory = sessionFactory; - } - - @Override - public Mono notifyClients(String method, Map params) { - if (sessions.isEmpty()) { - LOG.debug("No active sessions to broadcast message to"); - return Mono.empty(); - } - - LOG.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - return Flux.fromIterable(sessions.values()) - .flatMap( - session -> - session - .sendNotification(method, params) - .doOnError( - e -> - LOG.error( - "Failed to send message to session {}: {}", - session.getSessionId(), - e.getMessage())) - .onErrorComplete()) - .then(); - } - - @Override - public Mono closeGracefully() { - LOG.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - return Flux.fromIterable(sessions.values()).flatMap(MCPSession::closeGracefully).then(); - } - - private static class SSEConnection { - private final PrintWriter writer; - private final String sessionId; - private volatile boolean closed = false; - private int eventId = 0; - - public SSEConnection(PrintWriter writer, String sessionId) { - this.writer = writer; - this.sessionId = sessionId; - } - - public synchronized void sendEvent(String data) throws IOException { - if (closed) return; - - writer.printf("id: %d%n", ++eventId); - writer.printf("data: %s%n%n", data); - writer.flush(); - - if (writer.checkError()) { - throw new IOException("SSE write error"); - } - } - - private void sendEvent(String eventType, String data) throws IOException { - writer.write("event: " + eventType + "\n"); - writer.write("data: " + data + "\n\n"); - writer.flush(); - if (writer.checkError()) { - throw new IOException("Client disconnected"); - } - } - - public synchronized void sendComment(String comment) throws IOException { - if (closed) return; - - writer.printf(": %s%n%n", comment); - writer.flush(); - - if (writer.checkError()) { - throw new IOException("SSE write error"); - } - } - - public void close() throws IOException { - closed = true; - if (writer != null) { - writer.close(); - } - } - - public boolean isClosed() { - return closed; - } - - public String getSessionId() { - return sessionId; - } - } - - private static class MCPSession { - @Getter private final String sessionId; - private final long createdAt; - private final Map state; - private final PrintWriter outputStream; - private final ObjectMapper objectMapper; - - public MCPSession(ObjectMapper objectMapper, String sessionId, PrintWriter outputStream) { - this.sessionId = sessionId; - this.createdAt = System.currentTimeMillis(); - this.state = new ConcurrentHashMap<>(); - this.objectMapper = objectMapper; - this.outputStream = outputStream; - } - - public String getSessionId() { - return sessionId; - } - - public long getCreatedAt() { - return createdAt; - } - - public Map getState() { - return state; - } - - public Mono sendNotification(String method, Map params) { - McpSchema.JSONRPCNotification jsonrpcNotification = - new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); - return Mono.fromRunnable( - () -> { - try { - String json = objectMapper.writeValueAsString(jsonrpcNotification); - outputStream.write(json); - outputStream.write('\n'); - outputStream.flush(); - } catch (IOException e) { - LOG.error("Failed to send message", e); - } - }); - } - - public Mono closeGracefully() { - return Mono.fromRunnable( - () -> { - outputStream.flush(); - outputStream.close(); - }); - } - } - - private Map processRequest(JsonNode request, String sessionId) { - Map response = new HashMap<>(); - response.put("jsonrpc", "2.0"); - response.put("id", request.get("id")); - - String method = request.get("method").asText(); - - try { - switch (method) { - case "ping": - response.put("result", "pong"); - break; - - case "resources/list": - // TODO: Implement resource reading logic - response.put("result", getResourcesList()); - break; - - case "resources/read": - // TODO: Implement resource reading logic - break; - - case "tools/list": - response.put("result", new McpSchema.ListToolsResult(tools, null)); - break; - - case "tools/call": - JsonNode toolParams = request.get("params"); - if (toolParams != null && toolParams.has("name")) { - String toolName = toolParams.get("name").asText(); - JsonNode arguments = toolParams.get("arguments"); - McpSchema.Content content = - new McpSchema.TextContent( - JsonUtils.pojoToJson( - callTool( - authorizer, jwtFilter, limits, toolName, JsonUtils.getMap(arguments)))); - response.put("result", new McpSchema.CallToolResult(List.of(content), false)); - } else { - response.put("error", createError(-32602, "Invalid params")); - } - break; - - default: - response.put("error", createError(-32601, "Method not found: " + method)); - } - } catch (Exception e) { - log("Error processing request: " + method, e); - response.put("error", createError(-32603, "Internal error: " + e.getMessage())); - } - - return response; - } - - private void processNotification(JsonNode notification, String sessionId) { - String method = notification.get("method").asText(); - log("Received notification: " + method + " (session: " + sessionId + ")"); - - // Handle specific notifications - switch (method) { - case "notifications/initialized": - LOG.info("Client initialized for session: {}", sessionId); - break; - case "notifications/cancelled": - LOG.info("Client sent a cancellation request for session: {}", sessionId); - // Handle cancellation - break; - default: - log("Unknown notification: " + method); - } - } - - // Utility methods - private String generateSessionId() { - byte[] bytes = new byte[32]; - secureRandom.nextBytes(bytes); - return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes); - } - - private boolean isValidOrigin(String origin) { - // Implement your origin validation logic - return origin.startsWith("http://localhost") - || origin.startsWith("https://localhost") - || origin.startsWith("https://yourdomain.com"); - } - - private boolean shouldUseSSE() { - // TODO: Decide when to use SSE vs direct JSON response - return false; - } - - private String readRequestBody(HttpServletRequest request) throws IOException { - StringBuilder body = new StringBuilder(); - try (BufferedReader reader = request.getReader()) { - String line; - while ((line = reader.readLine()) != null) { - body.append(line); - } - } - return body.toString(); - } - - private void sendError(HttpServletResponse response, int statusCode, String message) - throws IOException { - Map error = new HashMap<>(); - error.put("error", message); - - String errorJson = objectMapper.writeValueAsString(error); - response.setContentType(CONTENT_TYPE_JSON); - response.setStatus(statusCode); - response.getWriter().write(errorJson); - } - - private void sendJSONResponse(HttpServletResponse response, JsonNode request, String sessionId) - throws IOException { - Map jsonResponse = processRequest(request, sessionId); - String responseJson = objectMapper.writeValueAsString(jsonResponse); - - response.setContentType(CONTENT_TYPE_JSON); - response.setStatus(HttpServletResponse.SC_OK); - response.getWriter().write(responseJson); - response.getWriter().flush(); - } - - private void handleBatchRequest( - HttpServletRequest request, - HttpServletResponse response, - JsonNode batchRequest, - String sessionId) - throws IOException { - // TODO: Handle this - sendError(response, HttpServletResponse.SC_NOT_IMPLEMENTED, "Batch requests not implemented"); - } - - private Map createError(int code, String message) { - Map error = new HashMap<>(); - error.put("code", code); - error.put("message", message); - return error; - } - - private Map getServerCapabilities() { - Map capabilities = new HashMap<>(); - - // Resources capability - Map resources = new HashMap<>(); - resources.put("subscribe", true); - resources.put("listChanged", true); - capabilities.put("resources", resources); - - // Tools - Map tools = new HashMap<>(); - tools.put("listChanged", true); - capabilities.put("tools", tools); - - return capabilities; - } - - private Map getServerInfo() { - Map serverInfo = new HashMap<>(); - serverInfo.put("name", "OpenMetadata MCP Server - Streamable"); - serverInfo.put("version", "1.0.0"); - return serverInfo; - } - - private Map getResourcesList() { - return new HashMap<>(); - } - - private Map createResource( - String uri, String name, String mimeType, String description) { - Map resource = new HashMap<>(); - resource.put("uri", uri); - resource.put("name", name); - resource.put("mimeType", mimeType); - resource.put("description", description); - return resource; - } - - /** - * Send a server-initiated notification to a specific session - */ - public void sendNotificationToSession( - String sessionId, String method, Map params) { - // Find SSE connections for this session - for (SSEConnection connection : sseConnections.values()) { - if (sessionId.equals(connection.getSessionId())) { - try { - Map notification = new HashMap<>(); - notification.put("jsonrpc", "2.0"); - notification.put("method", method); - if (params != null) { - notification.put("params", params); - } - - connection.sendEvent(objectMapper.writeValueAsString(notification)); - } catch (IOException e) { - log("Error sending notification to session: " + sessionId, e); - } - } - } - } -} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpServer.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpServer.java index c059e480984..3decc827e5c 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpServer.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpServer.java @@ -1,24 +1,33 @@ package org.openmetadata.service.mcp; -import static org.openmetadata.service.mcp.McpUtils.callTool; -import static org.openmetadata.service.mcp.McpUtils.getToolProperties; +import static org.openmetadata.service.search.SearchUtil.searchMetadata; -import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.JsonNode; import io.dropwizard.core.setup.Environment; import io.dropwizard.jetty.MutableServletContextHandler; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.McpSyncServer; import io.modelcontextprotocol.spec.McpSchema; import jakarta.servlet.DispatcherType; +import jakarta.servlet.Filter; +import java.util.ArrayList; import java.util.EnumSet; import java.util.List; +import java.util.Map; import lombok.extern.slf4j.Slf4j; import org.eclipse.jetty.servlet.FilterHolder; import org.eclipse.jetty.servlet.ServletHolder; +import org.openmetadata.common.utils.CommonUtil; import org.openmetadata.service.OpenMetadataApplicationConfig; import org.openmetadata.service.limits.Limits; +import org.openmetadata.service.mcp.tools.GlossaryTermTool; +import org.openmetadata.service.mcp.tools.GlossaryTool; +import org.openmetadata.service.mcp.tools.PatchEntityTool; +import org.openmetadata.service.security.AuthorizationException; import org.openmetadata.service.security.Authorizer; import org.openmetadata.service.security.JwtFilter; +import org.openmetadata.service.security.auth.CatalogSecurityContext; +import org.openmetadata.service.util.EntityUtil; import org.openmetadata.service.util.JsonUtils; @Slf4j @@ -38,20 +47,6 @@ public class McpServer { new JwtFilter(config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration()); this.authorizer = authorizer; this.limits = limits; - MutableServletContextHandler contextHandler = environment.getApplicationContext(); - McpAuthFilter authFilter = - new McpAuthFilter( - new JwtFilter( - config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration())); - List tools = loadToolsDefinitionsFromJson(); - addSSETransport(contextHandler, authFilter, tools); - addStreamableHttpServlet(contextHandler, authFilter, tools); - } - - private void addSSETransport( - MutableServletContextHandler contextHandler, - McpAuthFilter authFilter, - List tools) { McpSchema.ServerCapabilities serverCapabilities = McpSchema.ServerCapabilities.builder() .tools(true) @@ -59,57 +54,152 @@ public class McpServer { .resources(true, true) .build(); - HttpServletSseServerTransportProvider sseTransport = - new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/messages", "/mcp/sse"); - + HttpServletSseServerTransportProvider transport = + new HttpServletSseServerTransportProvider("/mcp/messages", "/mcp/sse"); McpSyncServer server = - io.modelcontextprotocol.server.McpServer.sync(sseTransport) - .serverInfo("openmetadata-mcp-sse", "0.1.0") + io.modelcontextprotocol.server.McpServer.sync(transport) + .serverInfo("openmetadata-mcp", "0.1.0") .capabilities(serverCapabilities) .build(); - addToolsToServer(server, tools); - // SSE transport for MCP - ServletHolder servletHolderSSE = new ServletHolder(sseTransport); - contextHandler.addServlet(servletHolderSSE, "/mcp/*"); + // Add resources, prompts, and tools to the MCP server + addTools(server); + MutableServletContextHandler contextHandler = environment.getApplicationContext(); + ServletHolder servletHolder = new ServletHolder(transport); + contextHandler.addServlet(servletHolder, "/mcp/*"); + + McpAuthFilter authFilter = + new McpAuthFilter( + new JwtFilter( + config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration())); contextHandler.addFilter( - new FilterHolder(authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST)); + new FilterHolder((Filter) authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST)); } - private void addStreamableHttpServlet( - MutableServletContextHandler contextHandler, - McpAuthFilter authFilter, - List tools) { - // Streamable HTTP servlet for MCP - MCPStreamableHttpServlet streamableHttpServlet = - new MCPStreamableHttpServlet(jwtFilter, authorizer, limits, tools); - ServletHolder servletHolderStreamableHttp = new ServletHolder(streamableHttpServlet); - contextHandler.addServlet(servletHolderStreamableHttp, "/mcp"); + public void addTools(McpSyncServer server) { + try { + LOG.info("Loading tool definitions..."); + List> cachedTools = loadToolsDefinitionsFromJson(); + if (cachedTools == null || cachedTools.isEmpty()) { + LOG.error("No tool definitions were loaded!"); + throw new RuntimeException("Failed to load tool definitions"); + } + LOG.info("Successfully loaded {} tool definitions", cachedTools.size()); - contextHandler.addFilter( - new FilterHolder(authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST)); - } - - public void addToolsToServer(McpSyncServer server, List tools) { - for (McpSchema.Tool tool : tools) { - server.addTool(getTool(tool)); + for (Map toolDef : cachedTools) { + try { + String name = (String) toolDef.get("name"); + String description = (String) toolDef.get("description"); + Map schema = JsonUtils.getMap(toolDef.get("parameters")); + server.addTool(getTool(JsonUtils.pojoToJson(schema), name, description)); + } catch (Exception e) { + LOG.error("Error processing tool definition: {}", toolDef, e); + } + } + LOG.info("Initializing request handlers..."); + } catch (Exception e) { + LOG.error("Error during server startup", e); + throw new RuntimeException("Failed to start MCP server", e); } } - protected List loadToolsDefinitionsFromJson() { - return getToolProperties("json/data/mcp/tools.json"); + protected List> loadToolsDefinitionsFromJson() { + String json = getJsonFromFile("json/data/mcp/tools.json"); + return loadToolDefinitionsFromJson(json); } - private McpServerFeatures.SyncToolSpecification getTool(McpSchema.Tool tool) { + protected static String getJsonFromFile(String path) { + try { + return CommonUtil.getResourceAsStream(McpServer.class.getClassLoader(), path); + } catch (Exception ex) { + LOG.error("Error loading JSON file: {}", path, ex); + return null; + } + } + + @SuppressWarnings("unchecked") + public List> loadToolDefinitionsFromJson(String json) { + try { + LOG.info("Loaded tool definitions, content length: {}", json.length()); + LOG.info("Raw tools.json content: {}", json); + + JsonNode toolsJson = JsonUtils.readTree(json); + JsonNode toolsArray = toolsJson.get("tools"); + + if (toolsArray == null || !toolsArray.isArray()) { + LOG.error("Invalid MCP tools file format. Expected 'tools' array."); + return new ArrayList<>(); + } + + List> tools = new ArrayList<>(); + for (JsonNode toolNode : toolsArray) { + String name = toolNode.get("name").asText(); + Map toolDef = JsonUtils.convertValue(toolNode, Map.class); + tools.add(toolDef); + LOG.info("Tool found: {} with definition: {}", name, toolDef); + } + + LOG.info("Found {} tool definitions", tools.size()); + return tools; + } catch (Exception e) { + LOG.error("Error loading tool definitions: {}", e.getMessage(), e); + throw e; + } + } + + private McpServerFeatures.SyncToolSpecification getTool( + String schema, String toolName, String description) { + McpSchema.Tool tool = new McpSchema.Tool(toolName, description, schema); + return new McpServerFeatures.SyncToolSpecification( tool, (exchange, arguments) -> { McpSchema.Content content = - new McpSchema.TextContent( - JsonUtils.pojoToJson( - callTool(authorizer, jwtFilter, limits, tool.name(), arguments))); + new McpSchema.TextContent(JsonUtils.pojoToJson(runMethod(toolName, arguments))); return new McpSchema.CallToolResult(List.of(content), false); }); } + + protected Object runMethod(String toolName, Map params) { + CatalogSecurityContext securityContext = + jwtFilter.getCatalogSecurityContext((String) params.get("Authorization")); + LOG.info( + "Catalog Principal: {} is trying to call the tool: {}", + securityContext.getUserPrincipal().getName(), + toolName); + Object result; + try { + switch (toolName) { + case "search_metadata": + result = searchMetadata(params); + break; + case "get_entity_details": + result = EntityUtil.getEntityDetails(params); + break; + case "create_glossary": + result = new GlossaryTool().execute(authorizer, limits, securityContext, params); + break; + case "create_glossary_term": + result = new GlossaryTermTool().execute(authorizer, limits, securityContext, params); + break; + case "patch_entity": + result = new PatchEntityTool().execute(authorizer, securityContext, params); + break; + default: + result = Map.of("error", "Unknown function: " + toolName); + break; + } + + return result; + } catch (AuthorizationException ex) { + LOG.error("Authorization error: {}", ex.getMessage()); + return Map.of( + "error", String.format("Authorization error: %s", ex.getMessage()), "statusCode", 403); + } catch (Exception ex) { + LOG.error("Error executing tool: {}", ex.getMessage()); + return Map.of( + "error", String.format("Error executing tool: %s", ex.getMessage()), "statusCode", 500); + } + } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpUtils.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpUtils.java deleted file mode 100644 index be43a4e2981..00000000000 --- a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpUtils.java +++ /dev/null @@ -1,151 +0,0 @@ -package org.openmetadata.service.mcp; - -import static org.openmetadata.service.search.SearchUtil.searchMetadata; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpSchema; -import jakarta.servlet.http.HttpServletRequest; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import lombok.extern.slf4j.Slf4j; -import org.openmetadata.common.utils.CommonUtil; -import org.openmetadata.service.limits.Limits; -import org.openmetadata.service.mcp.tools.GlossaryTermTool; -import org.openmetadata.service.mcp.tools.GlossaryTool; -import org.openmetadata.service.mcp.tools.PatchEntityTool; -import org.openmetadata.service.security.AuthorizationException; -import org.openmetadata.service.security.Authorizer; -import org.openmetadata.service.security.JwtFilter; -import org.openmetadata.service.security.auth.CatalogSecurityContext; -import org.openmetadata.service.util.EntityUtil; -import org.openmetadata.service.util.JsonUtils; - -@Slf4j -public class McpUtils { - - public static McpSchema.JSONRPCMessage getJsonRpcMessageWithAuthorizationParam( - ObjectMapper objectMapper, HttpServletRequest request, String body) throws IOException { - Map requestMessage = JsonUtils.getMap(JsonUtils.readTree(body)); - Map params = (Map) requestMessage.get("params"); - if (params != null) { - Map arguments = (Map) params.get("arguments"); - if (arguments != null) { - arguments.put("Authorization", JwtFilter.extractToken(request.getHeader("Authorization"))); - } - } - return McpSchema.deserializeJsonRpcMessage(objectMapper, JsonUtils.pojoToJson(requestMessage)); - } - - @SuppressWarnings("unchecked") - public static List> loadToolDefinitionsFromJson(String json) { - try { - LOG.info("Loaded tool definitions, content length: {}", json.length()); - LOG.info("Raw tools.json content: {}", json); - - JsonNode toolsJson = JsonUtils.readTree(json); - JsonNode toolsArray = toolsJson.get("tools"); - - if (toolsArray == null || !toolsArray.isArray()) { - LOG.error("Invalid MCP tools file format. Expected 'tools' array."); - return new ArrayList<>(); - } - - List> tools = new ArrayList<>(); - for (JsonNode toolNode : toolsArray) { - String name = toolNode.get("name").asText(); - Map toolDef = JsonUtils.convertValue(toolNode, Map.class); - tools.add(toolDef); - LOG.info("Tool found: {} with definition: {}", name, toolDef); - } - - LOG.info("Found {} tool definitions", tools.size()); - return tools; - } catch (Exception e) { - LOG.error("Error loading tool definitions: {}", e.getMessage(), e); - throw e; - } - } - - public static String getJsonFromFile(String path) { - try { - return CommonUtil.getResourceAsStream(McpServer.class.getClassLoader(), path); - } catch (Exception ex) { - LOG.error("Error loading JSON file: {}", path, ex); - return null; - } - } - - public static Object callTool( - Authorizer authorizer, - JwtFilter jwtFilter, - Limits limits, - String toolName, - Map params) { - CatalogSecurityContext securityContext = - jwtFilter.getCatalogSecurityContext((String) params.get("Authorization")); - LOG.info( - "Catalog Principal: {} is trying to call the tool: {}", - securityContext.getUserPrincipal().getName(), - toolName); - Object result; - try { - switch (toolName) { - case "search_metadata": - result = searchMetadata(params); - break; - case "get_entity_details": - result = EntityUtil.getEntityDetails(params); - break; - case "create_glossary": - result = new GlossaryTool().execute(authorizer, limits, securityContext, params); - break; - case "create_glossary_term": - result = new GlossaryTermTool().execute(authorizer, limits, securityContext, params); - break; - case "patch_entity": - result = new PatchEntityTool().execute(authorizer, limits, securityContext, params); - break; - default: - result = Map.of("error", "Unknown function: " + toolName); - break; - } - - return result; - } catch (AuthorizationException ex) { - LOG.error("Authorization error: {}", ex.getMessage()); - return Map.of( - "error", String.format("Authorization error: %s", ex.getMessage()), "statusCode", 403); - } catch (Exception ex) { - LOG.error("Error executing tool: {}", ex.getMessage()); - return Map.of( - "error", String.format("Error executing tool: %s", ex.getMessage()), "statusCode", 500); - } - } - - public static List getToolProperties(String jsonFilePath) { - try { - List result = new ArrayList<>(); - String json = getJsonFromFile(jsonFilePath); - List> cachedTools = loadToolDefinitionsFromJson(json); - if (cachedTools == null || cachedTools.isEmpty()) { - LOG.error("No tool definitions were loaded!"); - throw new RuntimeException("Failed to load tool definitions"); - } - LOG.debug("Successfully loaded {} tool definitions", cachedTools.size()); - for (int i = 0; i < cachedTools.size(); i++) { - Map toolDef = cachedTools.get(i); - String name = (String) toolDef.get("name"); - String description = (String) toolDef.get("description"); - Map schema = JsonUtils.getMap(toolDef.get("parameters")); - result.add(new McpSchema.Tool(name, description, JsonUtils.pojoToJson(schema))); - } - return result; - } catch (Exception e) { - LOG.error("Error during server startup", e); - throw new RuntimeException("Failed to start MCP server", e); - } - } -} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchUtil.java b/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchUtil.java index ff7e5528b04..47dd5fd21ac 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchUtil.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchUtil.java @@ -141,10 +141,7 @@ public class SearchUtil { case "glossary_search_index", Entity.GLOSSARY -> Entity.GLOSSARY; case "domain_search_index", Entity.DOMAIN -> Entity.DOMAIN; case "data_product_search_index", Entity.DATA_PRODUCT -> Entity.DATA_PRODUCT; - case "team_search_index", Entity.TEAM -> Entity.TEAM; - case "user_Search_index", Entity.USER -> Entity.USER; - case "dataAsset" -> "dataAsset"; - default -> "dataAsset"; + default -> "default"; }; } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/indexes/TableIndex.java b/openmetadata-service/src/main/java/org/openmetadata/service/search/indexes/TableIndex.java index 31afa0aad6e..1bb01266e14 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/search/indexes/TableIndex.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/indexes/TableIndex.java @@ -108,7 +108,6 @@ public record TableIndex(Table table) implements ColumnIndex { doc.put("processedLineage", table.getProcessedLineage()); doc.put("entityRelationship", SearchIndex.populateEntityRelationshipData(table)); doc.put("databaseSchema", getEntityWithDisplayName(table.getDatabaseSchema())); - doc.put("tableQueries", table.getTableQueries()); doc.put( "changeSummary", Optional.ofNullable(table.getChangeDescription()) diff --git a/openmetadata-spec/src/main/resources/json/schema/entity/data/table.json b/openmetadata-spec/src/main/resources/json/schema/entity/data/table.json index 3fe1d444fb1..c0559cf884b 100644 --- a/openmetadata-spec/src/main/resources/json/schema/entity/data/table.json +++ b/openmetadata-spec/src/main/resources/json/schema/entity/data/table.json @@ -1162,14 +1162,6 @@ "description": "Processed lineage for the table", "type": "boolean", "default": false - }, - "tableQueries": { - "description": "List of queries that are used to create this table.", - "type": "array", - "items": { - "$ref": "../../type/basic.json#/definitions/sqlQuery" - }, - "default": null } }, "required": [ diff --git a/openmetadata-ui/src/main/resources/ui/src/generated/entity/data/table.ts b/openmetadata-ui/src/main/resources/ui/src/generated/entity/data/table.ts index 56e8bd4f52a..cbefaa21e5f 100644 --- a/openmetadata-ui/src/main/resources/ui/src/generated/entity/data/table.ts +++ b/openmetadata-ui/src/main/resources/ui/src/generated/entity/data/table.ts @@ -161,11 +161,7 @@ export interface Table { * Table Profiler Config to include or exclude columns from profiling. */ tableProfilerConfig?: TableProfilerConfig; - /** - * List of queries that are used to create this table. - */ - tableQueries?: string[]; - tableType?: TableType; + tableType?: TableType; /** * Tags for this table. */