diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TableRepository.java b/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TableRepository.java
index 756f35732de..d52b2cb556e 100644
--- a/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TableRepository.java
+++ b/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TableRepository.java
@@ -26,10 +26,8 @@ import static org.openmetadata.schema.type.Include.NON_DELETED;
import static org.openmetadata.service.Entity.DATABASE_SCHEMA;
import static org.openmetadata.service.Entity.FIELD_OWNERS;
import static org.openmetadata.service.Entity.FIELD_TAGS;
-import static org.openmetadata.service.Entity.QUERY;
import static org.openmetadata.service.Entity.TABLE;
import static org.openmetadata.service.Entity.TEST_SUITE;
-import static org.openmetadata.service.Entity.getEntities;
import static org.openmetadata.service.Entity.populateEntityFieldTags;
import static org.openmetadata.service.util.EntityUtil.getLocalColumnName;
import static org.openmetadata.service.util.FullyQualifiedName.getColumnName;
@@ -64,7 +62,6 @@ import org.openmetadata.schema.EntityInterface;
import org.openmetadata.schema.api.data.CreateTableProfile;
import org.openmetadata.schema.api.feed.ResolveTask;
import org.openmetadata.schema.entity.data.DatabaseSchema;
-import org.openmetadata.schema.entity.data.Query;
import org.openmetadata.schema.entity.data.Table;
import org.openmetadata.schema.entity.feed.Suggestion;
import org.openmetadata.schema.tests.CustomMetric;
@@ -182,15 +179,6 @@ public class TableRepository extends EntityRepository
{
column.setCustomMetrics(getCustomMetrics(table, column.getName()));
}
}
- if (fields.contains("tableQueries")) {
- List queriesEntity =
- getEntities(
- listOrEmpty(findTo(table.getId(), TABLE, Relationship.MENTIONED_IN, QUERY)),
- "id",
- ALL);
- List queries = listOrEmpty(queriesEntity).stream().map(Query::getQuery).toList();
- table.setTableQueries(queries);
- }
}
@Override
diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/HttpServletSseServerTransportProvider.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/HttpServletSseServerTransportProvider.java
index 3a8e22932a2..3d931f7415b 100644
--- a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/HttpServletSseServerTransportProvider.java
+++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/HttpServletSseServerTransportProvider.java
@@ -1,6 +1,8 @@
-package org.openmetadata.service.mcp;
+/*
+HttpServletSseServerTransportProvider - Jakarta servlet-based MCP server transport
+*/
-import static org.openmetadata.service.mcp.McpUtils.getJsonRpcMessageWithAuthorizationParam;
+package org.openmetadata.service.mcp;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -9,7 +11,6 @@ import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
-import io.modelcontextprotocol.util.Assert;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.WebServlet;
@@ -22,10 +23,9 @@ import java.io.PrintWriter;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
+import org.openmetadata.service.security.JwtFilter;
+import org.openmetadata.service.util.JsonUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
@@ -42,220 +42,171 @@ public class HttpServletSseServerTransportProvider extends HttpServlet
public static final String DEFAULT_SSE_ENDPOINT = "/sse";
public static final String MESSAGE_EVENT_TYPE = "message";
public static final String ENDPOINT_EVENT_TYPE = "endpoint";
- public static final String DEFAULT_BASE_URL = "";
private final ObjectMapper objectMapper;
- private final String baseUrl;
private final String messageEndpoint;
private final String sseEndpoint;
- private final Map sessions = new ConcurrentHashMap<>();
- private final AtomicBoolean isClosing = new AtomicBoolean(false);
+ private final Map sessions;
+ private final AtomicBoolean isClosing;
private McpServerSession.Factory sessionFactory;
- private ExecutorService executorService =
- Executors.newCachedThreadPool(
- r -> {
- Thread t = new Thread(r, "MCP-Worker-SSE");
- t.setDaemon(true);
- return t;
- });
- public HttpServletSseServerTransportProvider(
- ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
- this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
- }
-
- public HttpServletSseServerTransportProvider(
- ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) {
- this.objectMapper = objectMapper;
- this.baseUrl = baseUrl;
+ public HttpServletSseServerTransportProvider(String messageEndpoint, String sseEndpoint) {
+ this.sessions = new ConcurrentHashMap<>();
+ this.isClosing = new AtomicBoolean(false);
+ this.objectMapper = JsonUtils.getObjectMapper();
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
}
- public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
- this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
+ public HttpServletSseServerTransportProvider(String messageEndpoint) {
+ this(messageEndpoint, "/sse");
}
- @Override
public void setSessionFactory(McpServerSession.Factory sessionFactory) {
this.sessionFactory = sessionFactory;
}
- @Override
public Mono notifyClients(String method, Map params) {
- if (sessions.isEmpty()) {
+ if (this.sessions.isEmpty()) {
logger.debug("No active sessions to broadcast message to");
return Mono.empty();
+ } else {
+ logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
+ return Flux.fromIterable(this.sessions.values())
+ .flatMap(
+ (session) ->
+ session
+ .sendNotification(method, params)
+ .doOnError(
+ (e) ->
+ logger.error(
+ "Failed to send message to session {}: {}",
+ session.getId(),
+ e.getMessage()))
+ .onErrorComplete())
+ .then();
}
-
- logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());
- return Flux.fromIterable(sessions.values())
- .flatMap(
- session ->
- session
- .sendNotification(method, params)
- .doOnError(
- e ->
- logger.error(
- "Failed to send message to session {}: {}",
- session.getId(),
- e.getMessage()))
- .onErrorComplete())
- .then();
}
- @Override
protected void doGet(HttpServletRequest request, HttpServletResponse response)
- throws IOException {
+ throws ServletException, IOException {
handleSseEvent(request, response);
}
private void handleSseEvent(HttpServletRequest request, HttpServletResponse response)
- throws IOException {
- String requestURI = request.getRequestURI();
- if (!requestURI.endsWith(sseEndpoint)) {
- response.sendError(HttpServletResponse.SC_NOT_FOUND);
- return;
+ throws ServletException, IOException {
+ String pathInfo = request.getPathInfo();
+ if (!this.sseEndpoint.contains(pathInfo)) {
+ response.sendError(404);
+ } else if (this.isClosing.get()) {
+ response.sendError(503, "Server is shutting down");
+ } else {
+ response.setContentType("text/event-stream");
+ response.setCharacterEncoding("UTF-8");
+ response.setHeader("Cache-Control", "no-cache");
+ response.setHeader("Connection", "keep-alive");
+ response.setHeader("Access-Control-Allow-Origin", "*");
+ String sessionId = UUID.randomUUID().toString();
+ AsyncContext asyncContext = request.startAsync();
+ asyncContext.setTimeout(0L);
+ PrintWriter writer = response.getWriter();
+ HttpServletMcpSessionTransport sessionTransport =
+ new HttpServletMcpSessionTransport(sessionId, asyncContext, writer);
+ McpServerSession session = this.sessionFactory.create(sessionTransport);
+ this.sessions.put(sessionId, session);
+ this.sendEvent(writer, "endpoint", this.messageEndpoint + "?sessionId=" + sessionId);
}
-
- if (isClosing.get()) {
- response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
- return;
- }
-
- response.setContentType("text/event-stream");
- response.setCharacterEncoding(UTF_8);
- response.setHeader("Cache-Control", "no-cache");
- response.setHeader("Connection", "keep-alive");
- response.setHeader("Access-Control-Allow-Origin", "*");
-
- String sessionId = UUID.randomUUID().toString();
- AsyncContext asyncContext = request.startAsync();
- asyncContext.setTimeout(0);
-
- PrintWriter writer = response.getWriter();
-
- // Create a new session transport
- HttpServletMcpSessionTransport sessionTransport =
- new HttpServletMcpSessionTransport(sessionId, asyncContext, writer);
-
- // Create a new session using the session factory
- McpServerSession session = sessionFactory.create(sessionTransport);
- this.sessions.put(sessionId, session);
-
- executorService.submit(
- () -> {
- // TODO: Handle session lifecycle and keepalive
- try {
- while (sessions.containsKey(sessionId)) {
- // Send keepalive every 30 seconds
- Thread.sleep(30000);
- writer.write(": keep-alive\n\n");
- writer.flush();
- }
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- } catch (Exception e) {
- log("SSE error", e);
- } finally {
- try {
- session.closeGracefully();
- asyncContext.complete();
- } catch (Exception e) {
- log("Error closing long-lived SSE connection", e);
- }
- }
- });
-
- // Send initial endpoint event
- this.sendEvent(
- writer,
- ENDPOINT_EVENT_TYPE,
- this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
}
- @Override
protected void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
+ if (this.isClosing.get()) {
+ response.sendError(503, "Server is shutting down");
+ } else {
+ String requestURI = request.getRequestURI();
+ if (!requestURI.endsWith(this.messageEndpoint)) {
+ response.sendError(404);
+ } else {
+ String sessionId = request.getParameter("sessionId");
+ if (sessionId == null) {
+ response.setContentType("application/json");
+ response.setCharacterEncoding("UTF-8");
+ response.setStatus(400);
+ String jsonError =
+ this.objectMapper.writeValueAsString(
+ new McpError("Session ID missing in message endpoint"));
+ PrintWriter writer = response.getWriter();
+ writer.write(jsonError);
+ writer.flush();
+ } else {
+ McpServerSession session = (McpServerSession) this.sessions.get(sessionId);
+ if (session == null) {
+ response.setContentType("application/json");
+ response.setCharacterEncoding("UTF-8");
+ response.setStatus(404);
+ String jsonError =
+ this.objectMapper.writeValueAsString(
+ new McpError("Session not found: " + sessionId));
+ PrintWriter writer = response.getWriter();
+ writer.write(jsonError);
+ writer.flush();
+ } else {
+ try {
+ BufferedReader reader = request.getReader();
+ StringBuilder body = new StringBuilder();
- if (isClosing.get()) {
- response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
- return;
- }
+ String line;
+ while ((line = reader.readLine()) != null) {
+ body.append(line);
+ }
- String requestURI = request.getRequestURI();
- if (!requestURI.endsWith(messageEndpoint)) {
- response.sendError(HttpServletResponse.SC_NOT_FOUND);
- return;
- }
+ McpSchema.JSONRPCMessage message =
+ getJsonRpcMessageWithAuthorizationParam(request, body.toString());
+ session.handle(message).block();
+ response.setStatus(200);
+ } catch (Exception var11) {
+ Exception e = var11;
+ logger.error("Error processing message: {}", var11.getMessage());
- // Get the session ID from the request parameter
- String sessionId = request.getParameter("sessionId");
- if (sessionId == null) {
- response.setContentType(APPLICATION_JSON);
- response.setCharacterEncoding(UTF_8);
- response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
- String jsonError =
- objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint"));
- PrintWriter writer = response.getWriter();
- writer.write(jsonError);
- writer.flush();
- return;
- }
-
- // Get the session from the sessions map
- McpServerSession session = sessions.get(sessionId);
- if (session == null) {
- response.setContentType(APPLICATION_JSON);
- response.setCharacterEncoding(UTF_8);
- response.setStatus(HttpServletResponse.SC_NOT_FOUND);
- String jsonError =
- objectMapper.writeValueAsString(new McpError("Session not found: " + sessionId));
- PrintWriter writer = response.getWriter();
- writer.write(jsonError);
- writer.flush();
- return;
- }
-
- try {
- BufferedReader reader = request.getReader();
- StringBuilder body = new StringBuilder();
- String line;
- while ((line = reader.readLine()) != null) {
- body.append(line);
- }
-
- McpSchema.JSONRPCMessage message =
- getJsonRpcMessageWithAuthorizationParam(this.objectMapper, request, body.toString());
-
- // Process the message through the session's handle method
- session.handle(message).block(); // Block for Servlet compatibility
-
- response.setStatus(HttpServletResponse.SC_OK);
- } catch (Exception e) {
- logger.error("Error processing message: {}", e.getMessage());
- try {
- McpError mcpError = new McpError(e.getMessage());
- response.setContentType(APPLICATION_JSON);
- response.setCharacterEncoding(UTF_8);
- response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
- String jsonError = objectMapper.writeValueAsString(mcpError);
- PrintWriter writer = response.getWriter();
- writer.write(jsonError);
- writer.flush();
- } catch (IOException ex) {
- logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage());
- response.sendError(
- HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message");
+ try {
+ McpError mcpError = new McpError(e.getMessage());
+ response.setContentType("application/json");
+ response.setCharacterEncoding("UTF-8");
+ response.setStatus(500);
+ String jsonError = this.objectMapper.writeValueAsString(mcpError);
+ PrintWriter writer = response.getWriter();
+ writer.write(jsonError);
+ writer.flush();
+ } catch (IOException ex) {
+ logger.error("Failed to send error response: {}", ex.getMessage());
+ response.sendError(500, "Error processing message");
+ }
+ }
+ }
+ }
}
}
}
- @Override
- public Mono closeGracefully() {
- isClosing.set(true);
- logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
+ private McpSchema.JSONRPCMessage getJsonRpcMessageWithAuthorizationParam(
+ HttpServletRequest request, String body) throws IOException {
+ Map requestMessage = JsonUtils.getMap(JsonUtils.readTree(body));
+ Map params = (Map) requestMessage.get("params");
+ if (params != null) {
+ Map arguments = (Map) params.get("arguments");
+ if (arguments != null) {
+ arguments.put("Authorization", JwtFilter.extractToken(request.getHeader("Authorization")));
+ }
+ }
+ return McpSchema.deserializeJsonRpcMessage(
+ this.objectMapper, JsonUtils.pojoToJson(requestMessage));
+ }
- return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then();
+ public Mono closeGracefully() {
+ this.isClosing.set(true);
+ logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
+ return Flux.fromIterable(this.sessions.values())
+ .flatMap(McpServerSession::closeGracefully)
+ .then();
}
private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException {
@@ -267,19 +218,8 @@ public class HttpServletSseServerTransportProvider extends HttpServlet
}
}
- @Override
public void destroy() {
- closeGracefully().block();
- if (executorService != null) {
- executorService.shutdown();
- try {
- if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) {
- executorService.shutdownNow();
- }
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- }
- }
+ this.closeGracefully().block();
super.destroy();
}
@@ -293,106 +233,65 @@ public class HttpServletSseServerTransportProvider extends HttpServlet
this.sessionId = sessionId;
this.asyncContext = asyncContext;
this.writer = writer;
- logger.debug("Session transport {} initialized with SSE writer", sessionId);
+ HttpServletSseServerTransportProvider.logger.debug(
+ "Session transport {} initialized with SSE writer", sessionId);
}
- @Override
public Mono sendMessage(McpSchema.JSONRPCMessage message) {
return Mono.fromRunnable(
() -> {
try {
- String jsonText = objectMapper.writeValueAsString(message);
- sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText);
- logger.debug("Message sent to session {}", sessionId);
+ String jsonText =
+ HttpServletSseServerTransportProvider.this.objectMapper.writeValueAsString(
+ message);
+ HttpServletSseServerTransportProvider.this.sendEvent(
+ this.writer, "message", jsonText);
+ HttpServletSseServerTransportProvider.logger.debug(
+ "Message sent to session {}", this.sessionId);
} catch (Exception e) {
- logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage());
- sessions.remove(sessionId);
- asyncContext.complete();
+ HttpServletSseServerTransportProvider.logger.error(
+ "Failed to send message to session {}: {}", this.sessionId, e.getMessage());
+ HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
+ this.asyncContext.complete();
}
});
}
- @Override
public T unmarshalFrom(Object data, TypeReference typeRef) {
- return objectMapper.convertValue(data, typeRef);
+ return (T)
+ HttpServletSseServerTransportProvider.this.objectMapper.convertValue(data, typeRef);
}
- @Override
public Mono closeGracefully() {
return Mono.fromRunnable(
() -> {
- logger.debug("Closing session transport: {}", sessionId);
+ HttpServletSseServerTransportProvider.logger.debug(
+ "Closing session transport: {}", this.sessionId);
+
try {
- sessions.remove(sessionId);
- asyncContext.complete();
- logger.debug("Successfully completed async context for session {}", sessionId);
+ HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
+ this.asyncContext.complete();
+ HttpServletSseServerTransportProvider.logger.debug(
+ "Successfully completed async context for session {}", this.sessionId);
} catch (Exception e) {
- logger.warn(
- "Failed to complete async context for session {}: {}", sessionId, e.getMessage());
+ HttpServletSseServerTransportProvider.logger.warn(
+ "Failed to complete async context for session {}: {}",
+ this.sessionId,
+ e.getMessage());
}
});
}
- @Override
public void close() {
try {
- sessions.remove(sessionId);
- asyncContext.complete();
- logger.debug("Successfully completed async context for session {}", sessionId);
+ HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
+ this.asyncContext.complete();
+ HttpServletSseServerTransportProvider.logger.debug(
+ "Successfully completed async context for session {}", this.sessionId);
} catch (Exception e) {
- logger.warn(
- "Failed to complete async context for session {}: {}", sessionId, e.getMessage());
+ HttpServletSseServerTransportProvider.logger.warn(
+ "Failed to complete async context for session {}: {}", this.sessionId, e.getMessage());
}
}
}
-
- public static Builder builder() {
- return new Builder();
- }
-
- public static class Builder {
-
- private ObjectMapper objectMapper = new ObjectMapper();
-
- private String baseUrl = DEFAULT_BASE_URL;
-
- private String messageEndpoint;
-
- private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
-
- public Builder objectMapper(ObjectMapper objectMapper) {
- Assert.notNull(objectMapper, "ObjectMapper must not be null");
- this.objectMapper = objectMapper;
- return this;
- }
-
- public Builder baseUrl(String baseUrl) {
- Assert.notNull(baseUrl, "Base URL must not be null");
- this.baseUrl = baseUrl;
- return this;
- }
-
- public Builder messageEndpoint(String messageEndpoint) {
- Assert.hasText(messageEndpoint, "Message endpoint must not be empty");
- this.messageEndpoint = messageEndpoint;
- return this;
- }
-
- public Builder sseEndpoint(String sseEndpoint) {
- Assert.hasText(sseEndpoint, "SSE endpoint must not be empty");
- this.sseEndpoint = sseEndpoint;
- return this;
- }
-
- public HttpServletSseServerTransportProvider build() {
- if (objectMapper == null) {
- throw new IllegalStateException("ObjectMapper must be set");
- }
- if (messageEndpoint == null) {
- throw new IllegalStateException("MessageEndpoint must be set");
- }
- return new HttpServletSseServerTransportProvider(
- objectMapper, baseUrl, messageEndpoint, sseEndpoint);
- }
- }
}
diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/MCPStreamableHttpServlet.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/MCPStreamableHttpServlet.java
deleted file mode 100644
index 8e35120122d..00000000000
--- a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/MCPStreamableHttpServlet.java
+++ /dev/null
@@ -1,701 +0,0 @@
-package org.openmetadata.service.mcp;
-
-import static org.openmetadata.service.mcp.McpUtils.callTool;
-import static org.openmetadata.service.mcp.McpUtils.getJsonRpcMessageWithAuthorizationParam;
-
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import io.modelcontextprotocol.spec.McpSchema;
-import io.modelcontextprotocol.spec.McpServerSession;
-import io.modelcontextprotocol.spec.McpServerTransportProvider;
-import jakarta.servlet.AsyncContext;
-import jakarta.servlet.ServletException;
-import jakarta.servlet.annotation.WebServlet;
-import jakarta.servlet.http.HttpServlet;
-import jakarta.servlet.http.HttpServletRequest;
-import jakarta.servlet.http.HttpServletResponse;
-import java.io.BufferedReader;
-import java.io.IOException;
-import java.io.PrintWriter;
-import java.security.SecureRandom;
-import java.util.ArrayList;
-import java.util.Base64;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.UUID;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
-import lombok.Getter;
-import lombok.extern.slf4j.Slf4j;
-import org.openmetadata.service.limits.Limits;
-import org.openmetadata.service.security.Authorizer;
-import org.openmetadata.service.security.JwtFilter;
-import org.openmetadata.service.util.JsonUtils;
-import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
-
-/**
- * MCP (Model Context Protocol) Streamable HTTP Servlet
- * This servlet implements the Streamable HTTP transport specification for MCP.
- */
-@WebServlet(value = "/mcp", asyncSupported = true)
-@Slf4j
-public class MCPStreamableHttpServlet extends HttpServlet implements McpServerTransportProvider {
-
- private static final String SESSION_HEADER = "Mcp-Session-Id";
- private static final String CONTENT_TYPE_JSON = "application/json";
- private static final String CONTENT_TYPE_SSE = "text/event-stream";
-
- private ObjectMapper objectMapper = new ObjectMapper();
- private Map sessions = new ConcurrentHashMap<>();
- private Map sseConnections = new ConcurrentHashMap<>();
- private SecureRandom secureRandom = new SecureRandom();
- private ExecutorService executorService =
- Executors.newCachedThreadPool(
- r -> {
- Thread t = new Thread(r, "MCP-Worker-Streamable");
- t.setDaemon(true);
- return t;
- });
- private McpServerSession.Factory sessionFactory;
- private final JwtFilter jwtFilter;
- private final Authorizer authorizer;
- private final List tools = new ArrayList<>();
- private final Limits limits;
-
- public MCPStreamableHttpServlet(
- JwtFilter jwtFilter, Authorizer authorizer, Limits limits, List tools) {
- this.jwtFilter = jwtFilter;
- this.authorizer = authorizer;
- this.limits = limits;
- this.tools.addAll(tools);
- }
-
- @Override
- public void init() throws ServletException {
- super.init();
- log("MCP Streamable HTTP Servlet initialized");
- }
-
- @Override
- public void destroy() {
- if (executorService != null) {
- executorService.shutdown();
- try {
- if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) {
- executorService.shutdownNow();
- }
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- }
- }
-
- // Close all SSE connections
- for (SSEConnection connection : sseConnections.values()) {
- try {
- connection.close();
- } catch (IOException e) {
- log("Error closing SSE connection", e);
- }
- }
-
- super.destroy();
- }
-
- @Override
- protected void doPost(HttpServletRequest request, HttpServletResponse response)
- throws ServletException, IOException {
-
- // Security: Validate Origin header
- String origin = request.getHeader("Origin");
- if (origin != null && !isValidOrigin(origin)) {
- sendError(response, HttpServletResponse.SC_FORBIDDEN, "Invalid origin");
- return;
- }
-
- // Validate Accept header - MUST include both application/json and text/event-stream
- String acceptHeader = request.getHeader("Accept");
- if (acceptHeader == null
- || !acceptHeader.contains(CONTENT_TYPE_JSON)
- || !acceptHeader.contains(CONTENT_TYPE_SSE)) {
- sendError(
- response,
- HttpServletResponse.SC_BAD_REQUEST,
- "Accept header must include both application/json and text/event-stream");
- return;
- }
-
- try {
- String requestBody = readRequestBody(request);
- String sessionId = request.getHeader(SESSION_HEADER);
-
- // Parse JSON-RPC message(s)
- McpSchema.JSONRPCMessage message =
- getJsonRpcMessageWithAuthorizationParam(this.objectMapper, request, requestBody);
- JsonNode jsonNode = objectMapper.valueToTree(message);
-
- // TODO: here we need to see how to handle the batch request from the Spec
- if (jsonNode.isArray()) {
- handleBatchRequest(request, response, jsonNode, sessionId);
- } else {
- handleSingleRequest(request, response, jsonNode, sessionId);
- }
- } catch (Exception e) {
- log("Error handling POST request", e);
- sendError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Internal server error");
- }
- }
-
- @Override
- protected void doGet(HttpServletRequest request, HttpServletResponse response)
- throws ServletException, IOException {
-
- String sessionId = request.getHeader(SESSION_HEADER);
- String acceptHeader = request.getHeader("Accept");
-
- if (acceptHeader == null || !acceptHeader.contains(CONTENT_TYPE_SSE)) {
- sendError(
- response,
- HttpServletResponse.SC_BAD_REQUEST,
- "Accept header must include text/event-stream");
- return;
- }
-
- if (sessionId != null && !sessions.containsKey(sessionId)) {
- sendError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found");
- return;
- }
-
- startSSEStreamForGet(request, response, sessionId);
- }
-
- @Override
- protected void doDelete(HttpServletRequest request, HttpServletResponse response)
- throws ServletException, IOException {
-
- String sessionId = request.getHeader(SESSION_HEADER);
-
- if (sessionId == null) {
- sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Session ID required");
- return;
- }
-
- // Terminate session
- MCPSession session = sessions.remove(sessionId);
- if (session != null) {
- log("Session terminated: " + sessionId);
- }
-
- response.setStatus(HttpServletResponse.SC_OK);
- }
-
- private void handleSingleRequest(
- HttpServletRequest request,
- HttpServletResponse response,
- JsonNode jsonRequest,
- String sessionId)
- throws IOException {
-
- String method = jsonRequest.has("method") ? jsonRequest.get("method").asText() : null;
- boolean hasId = jsonRequest.has("id");
-
- // Handle initialization
- if ("initialize".equals(method)) {
- handleInitialize(response, jsonRequest);
- return;
- }
-
- // Validate session for non-initialization requests
- if (sessionId != null && !sessions.containsKey(sessionId)) {
- sendError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found");
- return;
- }
-
- // Handle different message types
- if (!hasId) {
- // Notification - return 202 Accepted
- processNotification(jsonRequest, sessionId);
- response.setStatus(HttpServletResponse.SC_ACCEPTED);
- } else {
- // Request - may return JSON or start SSE stream
- String acceptHeader = request.getHeader("Accept");
- boolean supportsSSE = acceptHeader != null && acceptHeader.contains(CONTENT_TYPE_SSE);
-
- if (supportsSSE && shouldUseSSE()) {
- startSSEStream(request, response, jsonRequest, sessionId);
- } else {
- sendJSONResponse(response, jsonRequest, sessionId);
- }
- }
- }
-
- private void handleInitialize(HttpServletResponse response, JsonNode request) throws IOException {
- // Create new session
- String sessionId = generateSessionId();
- MCPSession session = new MCPSession(this.objectMapper, sessionId, response.getWriter());
- sessions.put(sessionId, session);
-
- // Create initialize response
- Map jsonResponse = new HashMap<>();
- jsonResponse.put("jsonrpc", "2.0");
- jsonResponse.put("id", request.get("id"));
-
- Map result = new HashMap<>();
- result.put("protocolVersion", "2024-11-05");
- result.put("capabilities", getServerCapabilities());
- result.put("serverInfo", getServerInfo());
- jsonResponse.put("result", result);
-
- // Send response with session ID
- String responseJson = objectMapper.writeValueAsString(jsonResponse);
- response.setContentType(CONTENT_TYPE_JSON);
- response.setHeader(SESSION_HEADER, sessionId);
- response.setStatus(HttpServletResponse.SC_OK);
- response.getWriter().write(responseJson);
-
- log("New session initialized: " + sessionId);
- }
-
- private void startSSEStream(
- HttpServletRequest request,
- HttpServletResponse response,
- JsonNode jsonRequest,
- String sessionId)
- throws IOException {
-
- // Set up SSE response headers
- response.setContentType(CONTENT_TYPE_SSE);
- response.setHeader("Cache-Control", "no-cache");
- response.setHeader("Connection", "keep-alive");
- response.setHeader("Access-Control-Allow-Origin", "*");
- response.setStatus(HttpServletResponse.SC_OK);
-
- // Start async processing
- AsyncContext asyncContext = request.startAsync();
- asyncContext.setTimeout(0); // 5 minutes timeout
-
- SSEConnection connection = new SSEConnection(response.getWriter(), sessionId);
- String connectionId = UUID.randomUUID().toString();
- sseConnections.put(connectionId, connection);
-
- // Process request asynchronously
- executorService.submit(
- () -> {
- try {
- // Send any server-initiated messages first (if needed)
- // sendServerInitiatedMessages(connection);
-
- // Process the actual request
- Map jsonResponse = processRequest(jsonRequest, sessionId);
- connection.sendEvent(objectMapper.writeValueAsString(jsonResponse));
-
- // Close the stream after sending response
- connection.close();
- asyncContext.complete();
-
- } catch (Exception e) {
- log("Error in SSE stream processing", e);
- try {
- connection.close();
- asyncContext.complete();
- } catch (Exception ex) {
- log("Error closing SSE connection", ex);
- }
- } finally {
- sseConnections.remove(connectionId);
- }
- });
- }
-
- private void startSSEStreamForGet(
- HttpServletRequest request, HttpServletResponse response, String sessionId)
- throws IOException {
-
- response.setContentType(CONTENT_TYPE_SSE);
- response.setHeader("Cache-Control", "no-cache");
- response.setHeader("Connection", "keep-alive");
- response.setHeader("Access-Control-Allow-Origin", "*");
- response.setStatus(HttpServletResponse.SC_OK);
-
- AsyncContext asyncContext = request.startAsync();
- asyncContext.setTimeout(0); // No timeout for long-lived connections
-
- SSEConnection connection = new SSEConnection(response.getWriter(), sessionId);
- String connectionId = UUID.randomUUID().toString();
- sseConnections.put(connectionId, connection);
-
- // Keep connection alive and handle server-initiated messages
- executorService.submit(
- () -> {
- try {
- while (!connection.isClosed()) {
- // Send keepalive every 30 seconds
- Thread.sleep(30000);
- connection.sendComment("keepalive");
- }
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- } catch (IOException e) {
- LOG.error("SSE connection error for connection ID: {}", connectionId, e);
- } finally {
- try {
- connection.close();
- asyncContext.complete();
- } catch (Exception e) {
- log("Error closing long-lived SSE connection", e);
- }
- sseConnections.remove(connectionId);
- }
- });
- }
-
- @Override
- public void setSessionFactory(McpServerSession.Factory sessionFactory) {
- this.sessionFactory = sessionFactory;
- }
-
- @Override
- public Mono notifyClients(String method, Map params) {
- if (sessions.isEmpty()) {
- LOG.debug("No active sessions to broadcast message to");
- return Mono.empty();
- }
-
- LOG.debug("Attempting to broadcast message to {} active sessions", sessions.size());
- return Flux.fromIterable(sessions.values())
- .flatMap(
- session ->
- session
- .sendNotification(method, params)
- .doOnError(
- e ->
- LOG.error(
- "Failed to send message to session {}: {}",
- session.getSessionId(),
- e.getMessage()))
- .onErrorComplete())
- .then();
- }
-
- @Override
- public Mono closeGracefully() {
- LOG.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
- return Flux.fromIterable(sessions.values()).flatMap(MCPSession::closeGracefully).then();
- }
-
- private static class SSEConnection {
- private final PrintWriter writer;
- private final String sessionId;
- private volatile boolean closed = false;
- private int eventId = 0;
-
- public SSEConnection(PrintWriter writer, String sessionId) {
- this.writer = writer;
- this.sessionId = sessionId;
- }
-
- public synchronized void sendEvent(String data) throws IOException {
- if (closed) return;
-
- writer.printf("id: %d%n", ++eventId);
- writer.printf("data: %s%n%n", data);
- writer.flush();
-
- if (writer.checkError()) {
- throw new IOException("SSE write error");
- }
- }
-
- private void sendEvent(String eventType, String data) throws IOException {
- writer.write("event: " + eventType + "\n");
- writer.write("data: " + data + "\n\n");
- writer.flush();
- if (writer.checkError()) {
- throw new IOException("Client disconnected");
- }
- }
-
- public synchronized void sendComment(String comment) throws IOException {
- if (closed) return;
-
- writer.printf(": %s%n%n", comment);
- writer.flush();
-
- if (writer.checkError()) {
- throw new IOException("SSE write error");
- }
- }
-
- public void close() throws IOException {
- closed = true;
- if (writer != null) {
- writer.close();
- }
- }
-
- public boolean isClosed() {
- return closed;
- }
-
- public String getSessionId() {
- return sessionId;
- }
- }
-
- private static class MCPSession {
- @Getter private final String sessionId;
- private final long createdAt;
- private final Map state;
- private final PrintWriter outputStream;
- private final ObjectMapper objectMapper;
-
- public MCPSession(ObjectMapper objectMapper, String sessionId, PrintWriter outputStream) {
- this.sessionId = sessionId;
- this.createdAt = System.currentTimeMillis();
- this.state = new ConcurrentHashMap<>();
- this.objectMapper = objectMapper;
- this.outputStream = outputStream;
- }
-
- public String getSessionId() {
- return sessionId;
- }
-
- public long getCreatedAt() {
- return createdAt;
- }
-
- public Map getState() {
- return state;
- }
-
- public Mono sendNotification(String method, Map params) {
- McpSchema.JSONRPCNotification jsonrpcNotification =
- new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params);
- return Mono.fromRunnable(
- () -> {
- try {
- String json = objectMapper.writeValueAsString(jsonrpcNotification);
- outputStream.write(json);
- outputStream.write('\n');
- outputStream.flush();
- } catch (IOException e) {
- LOG.error("Failed to send message", e);
- }
- });
- }
-
- public Mono closeGracefully() {
- return Mono.fromRunnable(
- () -> {
- outputStream.flush();
- outputStream.close();
- });
- }
- }
-
- private Map processRequest(JsonNode request, String sessionId) {
- Map response = new HashMap<>();
- response.put("jsonrpc", "2.0");
- response.put("id", request.get("id"));
-
- String method = request.get("method").asText();
-
- try {
- switch (method) {
- case "ping":
- response.put("result", "pong");
- break;
-
- case "resources/list":
- // TODO: Implement resource reading logic
- response.put("result", getResourcesList());
- break;
-
- case "resources/read":
- // TODO: Implement resource reading logic
- break;
-
- case "tools/list":
- response.put("result", new McpSchema.ListToolsResult(tools, null));
- break;
-
- case "tools/call":
- JsonNode toolParams = request.get("params");
- if (toolParams != null && toolParams.has("name")) {
- String toolName = toolParams.get("name").asText();
- JsonNode arguments = toolParams.get("arguments");
- McpSchema.Content content =
- new McpSchema.TextContent(
- JsonUtils.pojoToJson(
- callTool(
- authorizer, jwtFilter, limits, toolName, JsonUtils.getMap(arguments))));
- response.put("result", new McpSchema.CallToolResult(List.of(content), false));
- } else {
- response.put("error", createError(-32602, "Invalid params"));
- }
- break;
-
- default:
- response.put("error", createError(-32601, "Method not found: " + method));
- }
- } catch (Exception e) {
- log("Error processing request: " + method, e);
- response.put("error", createError(-32603, "Internal error: " + e.getMessage()));
- }
-
- return response;
- }
-
- private void processNotification(JsonNode notification, String sessionId) {
- String method = notification.get("method").asText();
- log("Received notification: " + method + " (session: " + sessionId + ")");
-
- // Handle specific notifications
- switch (method) {
- case "notifications/initialized":
- LOG.info("Client initialized for session: {}", sessionId);
- break;
- case "notifications/cancelled":
- LOG.info("Client sent a cancellation request for session: {}", sessionId);
- // Handle cancellation
- break;
- default:
- log("Unknown notification: " + method);
- }
- }
-
- // Utility methods
- private String generateSessionId() {
- byte[] bytes = new byte[32];
- secureRandom.nextBytes(bytes);
- return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
- }
-
- private boolean isValidOrigin(String origin) {
- // Implement your origin validation logic
- return origin.startsWith("http://localhost")
- || origin.startsWith("https://localhost")
- || origin.startsWith("https://yourdomain.com");
- }
-
- private boolean shouldUseSSE() {
- // TODO: Decide when to use SSE vs direct JSON response
- return false;
- }
-
- private String readRequestBody(HttpServletRequest request) throws IOException {
- StringBuilder body = new StringBuilder();
- try (BufferedReader reader = request.getReader()) {
- String line;
- while ((line = reader.readLine()) != null) {
- body.append(line);
- }
- }
- return body.toString();
- }
-
- private void sendError(HttpServletResponse response, int statusCode, String message)
- throws IOException {
- Map error = new HashMap<>();
- error.put("error", message);
-
- String errorJson = objectMapper.writeValueAsString(error);
- response.setContentType(CONTENT_TYPE_JSON);
- response.setStatus(statusCode);
- response.getWriter().write(errorJson);
- }
-
- private void sendJSONResponse(HttpServletResponse response, JsonNode request, String sessionId)
- throws IOException {
- Map jsonResponse = processRequest(request, sessionId);
- String responseJson = objectMapper.writeValueAsString(jsonResponse);
-
- response.setContentType(CONTENT_TYPE_JSON);
- response.setStatus(HttpServletResponse.SC_OK);
- response.getWriter().write(responseJson);
- response.getWriter().flush();
- }
-
- private void handleBatchRequest(
- HttpServletRequest request,
- HttpServletResponse response,
- JsonNode batchRequest,
- String sessionId)
- throws IOException {
- // TODO: Handle this
- sendError(response, HttpServletResponse.SC_NOT_IMPLEMENTED, "Batch requests not implemented");
- }
-
- private Map createError(int code, String message) {
- Map error = new HashMap<>();
- error.put("code", code);
- error.put("message", message);
- return error;
- }
-
- private Map getServerCapabilities() {
- Map capabilities = new HashMap<>();
-
- // Resources capability
- Map resources = new HashMap<>();
- resources.put("subscribe", true);
- resources.put("listChanged", true);
- capabilities.put("resources", resources);
-
- // Tools
- Map tools = new HashMap<>();
- tools.put("listChanged", true);
- capabilities.put("tools", tools);
-
- return capabilities;
- }
-
- private Map getServerInfo() {
- Map serverInfo = new HashMap<>();
- serverInfo.put("name", "OpenMetadata MCP Server - Streamable");
- serverInfo.put("version", "1.0.0");
- return serverInfo;
- }
-
- private Map getResourcesList() {
- return new HashMap<>();
- }
-
- private Map createResource(
- String uri, String name, String mimeType, String description) {
- Map resource = new HashMap<>();
- resource.put("uri", uri);
- resource.put("name", name);
- resource.put("mimeType", mimeType);
- resource.put("description", description);
- return resource;
- }
-
- /**
- * Send a server-initiated notification to a specific session
- */
- public void sendNotificationToSession(
- String sessionId, String method, Map params) {
- // Find SSE connections for this session
- for (SSEConnection connection : sseConnections.values()) {
- if (sessionId.equals(connection.getSessionId())) {
- try {
- Map notification = new HashMap<>();
- notification.put("jsonrpc", "2.0");
- notification.put("method", method);
- if (params != null) {
- notification.put("params", params);
- }
-
- connection.sendEvent(objectMapper.writeValueAsString(notification));
- } catch (IOException e) {
- log("Error sending notification to session: " + sessionId, e);
- }
- }
- }
- }
-}
diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpServer.java b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpServer.java
index c059e480984..3decc827e5c 100644
--- a/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpServer.java
+++ b/openmetadata-service/src/main/java/org/openmetadata/service/mcp/McpServer.java
@@ -1,24 +1,33 @@
package org.openmetadata.service.mcp;
-import static org.openmetadata.service.mcp.McpUtils.callTool;
-import static org.openmetadata.service.mcp.McpUtils.getToolProperties;
+import static org.openmetadata.service.search.SearchUtil.searchMetadata;
-import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.JsonNode;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.jetty.MutableServletContextHandler;
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.server.McpSyncServer;
import io.modelcontextprotocol.spec.McpSchema;
import jakarta.servlet.DispatcherType;
+import jakarta.servlet.Filter;
+import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
+import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.eclipse.jetty.servlet.FilterHolder;
import org.eclipse.jetty.servlet.ServletHolder;
+import org.openmetadata.common.utils.CommonUtil;
import org.openmetadata.service.OpenMetadataApplicationConfig;
import org.openmetadata.service.limits.Limits;
+import org.openmetadata.service.mcp.tools.GlossaryTermTool;
+import org.openmetadata.service.mcp.tools.GlossaryTool;
+import org.openmetadata.service.mcp.tools.PatchEntityTool;
+import org.openmetadata.service.security.AuthorizationException;
import org.openmetadata.service.security.Authorizer;
import org.openmetadata.service.security.JwtFilter;
+import org.openmetadata.service.security.auth.CatalogSecurityContext;
+import org.openmetadata.service.util.EntityUtil;
import org.openmetadata.service.util.JsonUtils;
@Slf4j
@@ -38,20 +47,6 @@ public class McpServer {
new JwtFilter(config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration());
this.authorizer = authorizer;
this.limits = limits;
- MutableServletContextHandler contextHandler = environment.getApplicationContext();
- McpAuthFilter authFilter =
- new McpAuthFilter(
- new JwtFilter(
- config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration()));
- List tools = loadToolsDefinitionsFromJson();
- addSSETransport(contextHandler, authFilter, tools);
- addStreamableHttpServlet(contextHandler, authFilter, tools);
- }
-
- private void addSSETransport(
- MutableServletContextHandler contextHandler,
- McpAuthFilter authFilter,
- List tools) {
McpSchema.ServerCapabilities serverCapabilities =
McpSchema.ServerCapabilities.builder()
.tools(true)
@@ -59,57 +54,152 @@ public class McpServer {
.resources(true, true)
.build();
- HttpServletSseServerTransportProvider sseTransport =
- new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/messages", "/mcp/sse");
-
+ HttpServletSseServerTransportProvider transport =
+ new HttpServletSseServerTransportProvider("/mcp/messages", "/mcp/sse");
McpSyncServer server =
- io.modelcontextprotocol.server.McpServer.sync(sseTransport)
- .serverInfo("openmetadata-mcp-sse", "0.1.0")
+ io.modelcontextprotocol.server.McpServer.sync(transport)
+ .serverInfo("openmetadata-mcp", "0.1.0")
.capabilities(serverCapabilities)
.build();
- addToolsToServer(server, tools);
- // SSE transport for MCP
- ServletHolder servletHolderSSE = new ServletHolder(sseTransport);
- contextHandler.addServlet(servletHolderSSE, "/mcp/*");
+ // Add resources, prompts, and tools to the MCP server
+ addTools(server);
+ MutableServletContextHandler contextHandler = environment.getApplicationContext();
+ ServletHolder servletHolder = new ServletHolder(transport);
+ contextHandler.addServlet(servletHolder, "/mcp/*");
+
+ McpAuthFilter authFilter =
+ new McpAuthFilter(
+ new JwtFilter(
+ config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration()));
contextHandler.addFilter(
- new FilterHolder(authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST));
+ new FilterHolder((Filter) authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST));
}
- private void addStreamableHttpServlet(
- MutableServletContextHandler contextHandler,
- McpAuthFilter authFilter,
- List tools) {
- // Streamable HTTP servlet for MCP
- MCPStreamableHttpServlet streamableHttpServlet =
- new MCPStreamableHttpServlet(jwtFilter, authorizer, limits, tools);
- ServletHolder servletHolderStreamableHttp = new ServletHolder(streamableHttpServlet);
- contextHandler.addServlet(servletHolderStreamableHttp, "/mcp");
+ public void addTools(McpSyncServer server) {
+ try {
+ LOG.info("Loading tool definitions...");
+ List