Revert "[WIP] MCP Core Items Improvements (#21598)" (#21614)

This reverts commit 0b3bf4ac0d3a7ac74e39552ad49896d37e469516.
This commit is contained in:
Pere Miquel Brull 2025-06-06 07:32:20 +02:00 committed by GitHub
parent fc78dfd574
commit 635382dd1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 306 additions and 1197 deletions

View File

@ -26,10 +26,8 @@ import static org.openmetadata.schema.type.Include.NON_DELETED;
import static org.openmetadata.service.Entity.DATABASE_SCHEMA;
import static org.openmetadata.service.Entity.FIELD_OWNERS;
import static org.openmetadata.service.Entity.FIELD_TAGS;
import static org.openmetadata.service.Entity.QUERY;
import static org.openmetadata.service.Entity.TABLE;
import static org.openmetadata.service.Entity.TEST_SUITE;
import static org.openmetadata.service.Entity.getEntities;
import static org.openmetadata.service.Entity.populateEntityFieldTags;
import static org.openmetadata.service.util.EntityUtil.getLocalColumnName;
import static org.openmetadata.service.util.FullyQualifiedName.getColumnName;
@ -64,7 +62,6 @@ import org.openmetadata.schema.EntityInterface;
import org.openmetadata.schema.api.data.CreateTableProfile;
import org.openmetadata.schema.api.feed.ResolveTask;
import org.openmetadata.schema.entity.data.DatabaseSchema;
import org.openmetadata.schema.entity.data.Query;
import org.openmetadata.schema.entity.data.Table;
import org.openmetadata.schema.entity.feed.Suggestion;
import org.openmetadata.schema.tests.CustomMetric;
@ -182,15 +179,6 @@ public class TableRepository extends EntityRepository<Table> {
column.setCustomMetrics(getCustomMetrics(table, column.getName()));
}
}
if (fields.contains("tableQueries")) {
List<Query> queriesEntity =
getEntities(
listOrEmpty(findTo(table.getId(), TABLE, Relationship.MENTIONED_IN, QUERY)),
"id",
ALL);
List<String> queries = listOrEmpty(queriesEntity).stream().map(Query::getQuery).toList();
table.setTableQueries(queries);
}
}
@Override

View File

@ -1,6 +1,8 @@
package org.openmetadata.service.mcp;
/*
HttpServletSseServerTransportProvider - Jakarta servlet-based MCP server transport
*/
import static org.openmetadata.service.mcp.McpUtils.getJsonRpcMessageWithAuthorizationParam;
package org.openmetadata.service.mcp;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
@ -9,7 +11,6 @@ import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.WebServlet;
@ -22,10 +23,9 @@ import java.io.PrintWriter;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.openmetadata.service.security.JwtFilter;
import org.openmetadata.service.util.JsonUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
@ -42,220 +42,171 @@ public class HttpServletSseServerTransportProvider extends HttpServlet
public static final String DEFAULT_SSE_ENDPOINT = "/sse";
public static final String MESSAGE_EVENT_TYPE = "message";
public static final String ENDPOINT_EVENT_TYPE = "endpoint";
public static final String DEFAULT_BASE_URL = "";
private final ObjectMapper objectMapper;
private final String baseUrl;
private final String messageEndpoint;
private final String sseEndpoint;
private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<>();
private final AtomicBoolean isClosing = new AtomicBoolean(false);
private final Map<String, McpServerSession> sessions;
private final AtomicBoolean isClosing;
private McpServerSession.Factory sessionFactory;
private ExecutorService executorService =
Executors.newCachedThreadPool(
r -> {
Thread t = new Thread(r, "MCP-Worker-SSE");
t.setDaemon(true);
return t;
});
public HttpServletSseServerTransportProvider(
ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
}
public HttpServletSseServerTransportProvider(
ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) {
this.objectMapper = objectMapper;
this.baseUrl = baseUrl;
public HttpServletSseServerTransportProvider(String messageEndpoint, String sseEndpoint) {
this.sessions = new ConcurrentHashMap<>();
this.isClosing = new AtomicBoolean(false);
this.objectMapper = JsonUtils.getObjectMapper();
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
}
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
public HttpServletSseServerTransportProvider(String messageEndpoint) {
this(messageEndpoint, "/sse");
}
@Override
public void setSessionFactory(McpServerSession.Factory sessionFactory) {
this.sessionFactory = sessionFactory;
}
@Override
public Mono<Void> notifyClients(String method, Map<String, Object> params) {
if (sessions.isEmpty()) {
if (this.sessions.isEmpty()) {
logger.debug("No active sessions to broadcast message to");
return Mono.empty();
} else {
logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
return Flux.fromIterable(this.sessions.values())
.flatMap(
(session) ->
session
.sendNotification(method, params)
.doOnError(
(e) ->
logger.error(
"Failed to send message to session {}: {}",
session.getId(),
e.getMessage()))
.onErrorComplete())
.then();
}
logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());
return Flux.fromIterable(sessions.values())
.flatMap(
session ->
session
.sendNotification(method, params)
.doOnError(
e ->
logger.error(
"Failed to send message to session {}: {}",
session.getId(),
e.getMessage()))
.onErrorComplete())
.then();
}
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response)
throws IOException {
throws ServletException, IOException {
handleSseEvent(request, response);
}
private void handleSseEvent(HttpServletRequest request, HttpServletResponse response)
throws IOException {
String requestURI = request.getRequestURI();
if (!requestURI.endsWith(sseEndpoint)) {
response.sendError(HttpServletResponse.SC_NOT_FOUND);
return;
throws ServletException, IOException {
String pathInfo = request.getPathInfo();
if (!this.sseEndpoint.contains(pathInfo)) {
response.sendError(404);
} else if (this.isClosing.get()) {
response.sendError(503, "Server is shutting down");
} else {
response.setContentType("text/event-stream");
response.setCharacterEncoding("UTF-8");
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Connection", "keep-alive");
response.setHeader("Access-Control-Allow-Origin", "*");
String sessionId = UUID.randomUUID().toString();
AsyncContext asyncContext = request.startAsync();
asyncContext.setTimeout(0L);
PrintWriter writer = response.getWriter();
HttpServletMcpSessionTransport sessionTransport =
new HttpServletMcpSessionTransport(sessionId, asyncContext, writer);
McpServerSession session = this.sessionFactory.create(sessionTransport);
this.sessions.put(sessionId, session);
this.sendEvent(writer, "endpoint", this.messageEndpoint + "?sessionId=" + sessionId);
}
if (isClosing.get()) {
response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
return;
}
response.setContentType("text/event-stream");
response.setCharacterEncoding(UTF_8);
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Connection", "keep-alive");
response.setHeader("Access-Control-Allow-Origin", "*");
String sessionId = UUID.randomUUID().toString();
AsyncContext asyncContext = request.startAsync();
asyncContext.setTimeout(0);
PrintWriter writer = response.getWriter();
// Create a new session transport
HttpServletMcpSessionTransport sessionTransport =
new HttpServletMcpSessionTransport(sessionId, asyncContext, writer);
// Create a new session using the session factory
McpServerSession session = sessionFactory.create(sessionTransport);
this.sessions.put(sessionId, session);
executorService.submit(
() -> {
// TODO: Handle session lifecycle and keepalive
try {
while (sessions.containsKey(sessionId)) {
// Send keepalive every 30 seconds
Thread.sleep(30000);
writer.write(": keep-alive\n\n");
writer.flush();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (Exception e) {
log("SSE error", e);
} finally {
try {
session.closeGracefully();
asyncContext.complete();
} catch (Exception e) {
log("Error closing long-lived SSE connection", e);
}
}
});
// Send initial endpoint event
this.sendEvent(
writer,
ENDPOINT_EVENT_TYPE,
this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
}
@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
if (this.isClosing.get()) {
response.sendError(503, "Server is shutting down");
} else {
String requestURI = request.getRequestURI();
if (!requestURI.endsWith(this.messageEndpoint)) {
response.sendError(404);
} else {
String sessionId = request.getParameter("sessionId");
if (sessionId == null) {
response.setContentType("application/json");
response.setCharacterEncoding("UTF-8");
response.setStatus(400);
String jsonError =
this.objectMapper.writeValueAsString(
new McpError("Session ID missing in message endpoint"));
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
} else {
McpServerSession session = (McpServerSession) this.sessions.get(sessionId);
if (session == null) {
response.setContentType("application/json");
response.setCharacterEncoding("UTF-8");
response.setStatus(404);
String jsonError =
this.objectMapper.writeValueAsString(
new McpError("Session not found: " + sessionId));
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
} else {
try {
BufferedReader reader = request.getReader();
StringBuilder body = new StringBuilder();
if (isClosing.get()) {
response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
return;
}
String line;
while ((line = reader.readLine()) != null) {
body.append(line);
}
String requestURI = request.getRequestURI();
if (!requestURI.endsWith(messageEndpoint)) {
response.sendError(HttpServletResponse.SC_NOT_FOUND);
return;
}
McpSchema.JSONRPCMessage message =
getJsonRpcMessageWithAuthorizationParam(request, body.toString());
session.handle(message).block();
response.setStatus(200);
} catch (Exception var11) {
Exception e = var11;
logger.error("Error processing message: {}", var11.getMessage());
// Get the session ID from the request parameter
String sessionId = request.getParameter("sessionId");
if (sessionId == null) {
response.setContentType(APPLICATION_JSON);
response.setCharacterEncoding(UTF_8);
response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
String jsonError =
objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint"));
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
return;
}
// Get the session from the sessions map
McpServerSession session = sessions.get(sessionId);
if (session == null) {
response.setContentType(APPLICATION_JSON);
response.setCharacterEncoding(UTF_8);
response.setStatus(HttpServletResponse.SC_NOT_FOUND);
String jsonError =
objectMapper.writeValueAsString(new McpError("Session not found: " + sessionId));
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
return;
}
try {
BufferedReader reader = request.getReader();
StringBuilder body = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
body.append(line);
}
McpSchema.JSONRPCMessage message =
getJsonRpcMessageWithAuthorizationParam(this.objectMapper, request, body.toString());
// Process the message through the session's handle method
session.handle(message).block(); // Block for Servlet compatibility
response.setStatus(HttpServletResponse.SC_OK);
} catch (Exception e) {
logger.error("Error processing message: {}", e.getMessage());
try {
McpError mcpError = new McpError(e.getMessage());
response.setContentType(APPLICATION_JSON);
response.setCharacterEncoding(UTF_8);
response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
String jsonError = objectMapper.writeValueAsString(mcpError);
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
} catch (IOException ex) {
logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage());
response.sendError(
HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message");
try {
McpError mcpError = new McpError(e.getMessage());
response.setContentType("application/json");
response.setCharacterEncoding("UTF-8");
response.setStatus(500);
String jsonError = this.objectMapper.writeValueAsString(mcpError);
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
} catch (IOException ex) {
logger.error("Failed to send error response: {}", ex.getMessage());
response.sendError(500, "Error processing message");
}
}
}
}
}
}
}
@Override
public Mono<Void> closeGracefully() {
isClosing.set(true);
logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
private McpSchema.JSONRPCMessage getJsonRpcMessageWithAuthorizationParam(
HttpServletRequest request, String body) throws IOException {
Map<String, Object> requestMessage = JsonUtils.getMap(JsonUtils.readTree(body));
Map<String, Object> params = (Map<String, Object>) requestMessage.get("params");
if (params != null) {
Map<String, Object> arguments = (Map<String, Object>) params.get("arguments");
if (arguments != null) {
arguments.put("Authorization", JwtFilter.extractToken(request.getHeader("Authorization")));
}
}
return McpSchema.deserializeJsonRpcMessage(
this.objectMapper, JsonUtils.pojoToJson(requestMessage));
}
return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then();
public Mono<Void> closeGracefully() {
this.isClosing.set(true);
logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
return Flux.fromIterable(this.sessions.values())
.flatMap(McpServerSession::closeGracefully)
.then();
}
private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException {
@ -267,19 +218,8 @@ public class HttpServletSseServerTransportProvider extends HttpServlet
}
}
@Override
public void destroy() {
closeGracefully().block();
if (executorService != null) {
executorService.shutdown();
try {
if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) {
executorService.shutdownNow();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
this.closeGracefully().block();
super.destroy();
}
@ -293,106 +233,65 @@ public class HttpServletSseServerTransportProvider extends HttpServlet
this.sessionId = sessionId;
this.asyncContext = asyncContext;
this.writer = writer;
logger.debug("Session transport {} initialized with SSE writer", sessionId);
HttpServletSseServerTransportProvider.logger.debug(
"Session transport {} initialized with SSE writer", sessionId);
}
@Override
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
return Mono.fromRunnable(
() -> {
try {
String jsonText = objectMapper.writeValueAsString(message);
sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText);
logger.debug("Message sent to session {}", sessionId);
String jsonText =
HttpServletSseServerTransportProvider.this.objectMapper.writeValueAsString(
message);
HttpServletSseServerTransportProvider.this.sendEvent(
this.writer, "message", jsonText);
HttpServletSseServerTransportProvider.logger.debug(
"Message sent to session {}", this.sessionId);
} catch (Exception e) {
logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage());
sessions.remove(sessionId);
asyncContext.complete();
HttpServletSseServerTransportProvider.logger.error(
"Failed to send message to session {}: {}", this.sessionId, e.getMessage());
HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
this.asyncContext.complete();
}
});
}
@Override
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
return objectMapper.convertValue(data, typeRef);
return (T)
HttpServletSseServerTransportProvider.this.objectMapper.convertValue(data, typeRef);
}
@Override
public Mono<Void> closeGracefully() {
return Mono.fromRunnable(
() -> {
logger.debug("Closing session transport: {}", sessionId);
HttpServletSseServerTransportProvider.logger.debug(
"Closing session transport: {}", this.sessionId);
try {
sessions.remove(sessionId);
asyncContext.complete();
logger.debug("Successfully completed async context for session {}", sessionId);
HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
this.asyncContext.complete();
HttpServletSseServerTransportProvider.logger.debug(
"Successfully completed async context for session {}", this.sessionId);
} catch (Exception e) {
logger.warn(
"Failed to complete async context for session {}: {}", sessionId, e.getMessage());
HttpServletSseServerTransportProvider.logger.warn(
"Failed to complete async context for session {}: {}",
this.sessionId,
e.getMessage());
}
});
}
@Override
public void close() {
try {
sessions.remove(sessionId);
asyncContext.complete();
logger.debug("Successfully completed async context for session {}", sessionId);
HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
this.asyncContext.complete();
HttpServletSseServerTransportProvider.logger.debug(
"Successfully completed async context for session {}", this.sessionId);
} catch (Exception e) {
logger.warn(
"Failed to complete async context for session {}: {}", sessionId, e.getMessage());
HttpServletSseServerTransportProvider.logger.warn(
"Failed to complete async context for session {}: {}", this.sessionId, e.getMessage());
}
}
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private ObjectMapper objectMapper = new ObjectMapper();
private String baseUrl = DEFAULT_BASE_URL;
private String messageEndpoint;
private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
public Builder objectMapper(ObjectMapper objectMapper) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
this.objectMapper = objectMapper;
return this;
}
public Builder baseUrl(String baseUrl) {
Assert.notNull(baseUrl, "Base URL must not be null");
this.baseUrl = baseUrl;
return this;
}
public Builder messageEndpoint(String messageEndpoint) {
Assert.hasText(messageEndpoint, "Message endpoint must not be empty");
this.messageEndpoint = messageEndpoint;
return this;
}
public Builder sseEndpoint(String sseEndpoint) {
Assert.hasText(sseEndpoint, "SSE endpoint must not be empty");
this.sseEndpoint = sseEndpoint;
return this;
}
public HttpServletSseServerTransportProvider build() {
if (objectMapper == null) {
throw new IllegalStateException("ObjectMapper must be set");
}
if (messageEndpoint == null) {
throw new IllegalStateException("MessageEndpoint must be set");
}
return new HttpServletSseServerTransportProvider(
objectMapper, baseUrl, messageEndpoint, sseEndpoint);
}
}
}

View File

@ -1,701 +0,0 @@
package org.openmetadata.service.mcp;
import static org.openmetadata.service.mcp.McpUtils.callTool;
import static org.openmetadata.service.mcp.McpUtils.getJsonRpcMessageWithAuthorizationParam;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.WebServlet;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.service.limits.Limits;
import org.openmetadata.service.security.Authorizer;
import org.openmetadata.service.security.JwtFilter;
import org.openmetadata.service.util.JsonUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
/**
* MCP (Model Context Protocol) Streamable HTTP Servlet
* This servlet implements the Streamable HTTP transport specification for MCP.
*/
@WebServlet(value = "/mcp", asyncSupported = true)
@Slf4j
public class MCPStreamableHttpServlet extends HttpServlet implements McpServerTransportProvider {
private static final String SESSION_HEADER = "Mcp-Session-Id";
private static final String CONTENT_TYPE_JSON = "application/json";
private static final String CONTENT_TYPE_SSE = "text/event-stream";
private ObjectMapper objectMapper = new ObjectMapper();
private Map<String, MCPSession> sessions = new ConcurrentHashMap<>();
private Map<String, SSEConnection> sseConnections = new ConcurrentHashMap<>();
private SecureRandom secureRandom = new SecureRandom();
private ExecutorService executorService =
Executors.newCachedThreadPool(
r -> {
Thread t = new Thread(r, "MCP-Worker-Streamable");
t.setDaemon(true);
return t;
});
private McpServerSession.Factory sessionFactory;
private final JwtFilter jwtFilter;
private final Authorizer authorizer;
private final List<McpSchema.Tool> tools = new ArrayList<>();
private final Limits limits;
public MCPStreamableHttpServlet(
JwtFilter jwtFilter, Authorizer authorizer, Limits limits, List<McpSchema.Tool> tools) {
this.jwtFilter = jwtFilter;
this.authorizer = authorizer;
this.limits = limits;
this.tools.addAll(tools);
}
@Override
public void init() throws ServletException {
super.init();
log("MCP Streamable HTTP Servlet initialized");
}
@Override
public void destroy() {
if (executorService != null) {
executorService.shutdown();
try {
if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) {
executorService.shutdownNow();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
// Close all SSE connections
for (SSEConnection connection : sseConnections.values()) {
try {
connection.close();
} catch (IOException e) {
log("Error closing SSE connection", e);
}
}
super.destroy();
}
@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
// Security: Validate Origin header
String origin = request.getHeader("Origin");
if (origin != null && !isValidOrigin(origin)) {
sendError(response, HttpServletResponse.SC_FORBIDDEN, "Invalid origin");
return;
}
// Validate Accept header - MUST include both application/json and text/event-stream
String acceptHeader = request.getHeader("Accept");
if (acceptHeader == null
|| !acceptHeader.contains(CONTENT_TYPE_JSON)
|| !acceptHeader.contains(CONTENT_TYPE_SSE)) {
sendError(
response,
HttpServletResponse.SC_BAD_REQUEST,
"Accept header must include both application/json and text/event-stream");
return;
}
try {
String requestBody = readRequestBody(request);
String sessionId = request.getHeader(SESSION_HEADER);
// Parse JSON-RPC message(s)
McpSchema.JSONRPCMessage message =
getJsonRpcMessageWithAuthorizationParam(this.objectMapper, request, requestBody);
JsonNode jsonNode = objectMapper.valueToTree(message);
// TODO: here we need to see how to handle the batch request from the Spec
if (jsonNode.isArray()) {
handleBatchRequest(request, response, jsonNode, sessionId);
} else {
handleSingleRequest(request, response, jsonNode, sessionId);
}
} catch (Exception e) {
log("Error handling POST request", e);
sendError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Internal server error");
}
}
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
String sessionId = request.getHeader(SESSION_HEADER);
String acceptHeader = request.getHeader("Accept");
if (acceptHeader == null || !acceptHeader.contains(CONTENT_TYPE_SSE)) {
sendError(
response,
HttpServletResponse.SC_BAD_REQUEST,
"Accept header must include text/event-stream");
return;
}
if (sessionId != null && !sessions.containsKey(sessionId)) {
sendError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found");
return;
}
startSSEStreamForGet(request, response, sessionId);
}
@Override
protected void doDelete(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
String sessionId = request.getHeader(SESSION_HEADER);
if (sessionId == null) {
sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Session ID required");
return;
}
// Terminate session
MCPSession session = sessions.remove(sessionId);
if (session != null) {
log("Session terminated: " + sessionId);
}
response.setStatus(HttpServletResponse.SC_OK);
}
private void handleSingleRequest(
HttpServletRequest request,
HttpServletResponse response,
JsonNode jsonRequest,
String sessionId)
throws IOException {
String method = jsonRequest.has("method") ? jsonRequest.get("method").asText() : null;
boolean hasId = jsonRequest.has("id");
// Handle initialization
if ("initialize".equals(method)) {
handleInitialize(response, jsonRequest);
return;
}
// Validate session for non-initialization requests
if (sessionId != null && !sessions.containsKey(sessionId)) {
sendError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found");
return;
}
// Handle different message types
if (!hasId) {
// Notification - return 202 Accepted
processNotification(jsonRequest, sessionId);
response.setStatus(HttpServletResponse.SC_ACCEPTED);
} else {
// Request - may return JSON or start SSE stream
String acceptHeader = request.getHeader("Accept");
boolean supportsSSE = acceptHeader != null && acceptHeader.contains(CONTENT_TYPE_SSE);
if (supportsSSE && shouldUseSSE()) {
startSSEStream(request, response, jsonRequest, sessionId);
} else {
sendJSONResponse(response, jsonRequest, sessionId);
}
}
}
private void handleInitialize(HttpServletResponse response, JsonNode request) throws IOException {
// Create new session
String sessionId = generateSessionId();
MCPSession session = new MCPSession(this.objectMapper, sessionId, response.getWriter());
sessions.put(sessionId, session);
// Create initialize response
Map<String, Object> jsonResponse = new HashMap<>();
jsonResponse.put("jsonrpc", "2.0");
jsonResponse.put("id", request.get("id"));
Map<String, Object> result = new HashMap<>();
result.put("protocolVersion", "2024-11-05");
result.put("capabilities", getServerCapabilities());
result.put("serverInfo", getServerInfo());
jsonResponse.put("result", result);
// Send response with session ID
String responseJson = objectMapper.writeValueAsString(jsonResponse);
response.setContentType(CONTENT_TYPE_JSON);
response.setHeader(SESSION_HEADER, sessionId);
response.setStatus(HttpServletResponse.SC_OK);
response.getWriter().write(responseJson);
log("New session initialized: " + sessionId);
}
private void startSSEStream(
HttpServletRequest request,
HttpServletResponse response,
JsonNode jsonRequest,
String sessionId)
throws IOException {
// Set up SSE response headers
response.setContentType(CONTENT_TYPE_SSE);
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Connection", "keep-alive");
response.setHeader("Access-Control-Allow-Origin", "*");
response.setStatus(HttpServletResponse.SC_OK);
// Start async processing
AsyncContext asyncContext = request.startAsync();
asyncContext.setTimeout(0); // 5 minutes timeout
SSEConnection connection = new SSEConnection(response.getWriter(), sessionId);
String connectionId = UUID.randomUUID().toString();
sseConnections.put(connectionId, connection);
// Process request asynchronously
executorService.submit(
() -> {
try {
// Send any server-initiated messages first (if needed)
// sendServerInitiatedMessages(connection);
// Process the actual request
Map<String, Object> jsonResponse = processRequest(jsonRequest, sessionId);
connection.sendEvent(objectMapper.writeValueAsString(jsonResponse));
// Close the stream after sending response
connection.close();
asyncContext.complete();
} catch (Exception e) {
log("Error in SSE stream processing", e);
try {
connection.close();
asyncContext.complete();
} catch (Exception ex) {
log("Error closing SSE connection", ex);
}
} finally {
sseConnections.remove(connectionId);
}
});
}
private void startSSEStreamForGet(
HttpServletRequest request, HttpServletResponse response, String sessionId)
throws IOException {
response.setContentType(CONTENT_TYPE_SSE);
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Connection", "keep-alive");
response.setHeader("Access-Control-Allow-Origin", "*");
response.setStatus(HttpServletResponse.SC_OK);
AsyncContext asyncContext = request.startAsync();
asyncContext.setTimeout(0); // No timeout for long-lived connections
SSEConnection connection = new SSEConnection(response.getWriter(), sessionId);
String connectionId = UUID.randomUUID().toString();
sseConnections.put(connectionId, connection);
// Keep connection alive and handle server-initiated messages
executorService.submit(
() -> {
try {
while (!connection.isClosed()) {
// Send keepalive every 30 seconds
Thread.sleep(30000);
connection.sendComment("keepalive");
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (IOException e) {
LOG.error("SSE connection error for connection ID: {}", connectionId, e);
} finally {
try {
connection.close();
asyncContext.complete();
} catch (Exception e) {
log("Error closing long-lived SSE connection", e);
}
sseConnections.remove(connectionId);
}
});
}
@Override
public void setSessionFactory(McpServerSession.Factory sessionFactory) {
this.sessionFactory = sessionFactory;
}
@Override
public Mono<Void> notifyClients(String method, Map<String, Object> params) {
if (sessions.isEmpty()) {
LOG.debug("No active sessions to broadcast message to");
return Mono.empty();
}
LOG.debug("Attempting to broadcast message to {} active sessions", sessions.size());
return Flux.fromIterable(sessions.values())
.flatMap(
session ->
session
.sendNotification(method, params)
.doOnError(
e ->
LOG.error(
"Failed to send message to session {}: {}",
session.getSessionId(),
e.getMessage()))
.onErrorComplete())
.then();
}
@Override
public Mono<Void> closeGracefully() {
LOG.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
return Flux.fromIterable(sessions.values()).flatMap(MCPSession::closeGracefully).then();
}
private static class SSEConnection {
private final PrintWriter writer;
private final String sessionId;
private volatile boolean closed = false;
private int eventId = 0;
public SSEConnection(PrintWriter writer, String sessionId) {
this.writer = writer;
this.sessionId = sessionId;
}
public synchronized void sendEvent(String data) throws IOException {
if (closed) return;
writer.printf("id: %d%n", ++eventId);
writer.printf("data: %s%n%n", data);
writer.flush();
if (writer.checkError()) {
throw new IOException("SSE write error");
}
}
private void sendEvent(String eventType, String data) throws IOException {
writer.write("event: " + eventType + "\n");
writer.write("data: " + data + "\n\n");
writer.flush();
if (writer.checkError()) {
throw new IOException("Client disconnected");
}
}
public synchronized void sendComment(String comment) throws IOException {
if (closed) return;
writer.printf(": %s%n%n", comment);
writer.flush();
if (writer.checkError()) {
throw new IOException("SSE write error");
}
}
public void close() throws IOException {
closed = true;
if (writer != null) {
writer.close();
}
}
public boolean isClosed() {
return closed;
}
public String getSessionId() {
return sessionId;
}
}
private static class MCPSession {
@Getter private final String sessionId;
private final long createdAt;
private final Map<String, Object> state;
private final PrintWriter outputStream;
private final ObjectMapper objectMapper;
public MCPSession(ObjectMapper objectMapper, String sessionId, PrintWriter outputStream) {
this.sessionId = sessionId;
this.createdAt = System.currentTimeMillis();
this.state = new ConcurrentHashMap<>();
this.objectMapper = objectMapper;
this.outputStream = outputStream;
}
public String getSessionId() {
return sessionId;
}
public long getCreatedAt() {
return createdAt;
}
public Map<String, Object> getState() {
return state;
}
public Mono<Void> sendNotification(String method, Map<String, Object> params) {
McpSchema.JSONRPCNotification jsonrpcNotification =
new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params);
return Mono.fromRunnable(
() -> {
try {
String json = objectMapper.writeValueAsString(jsonrpcNotification);
outputStream.write(json);
outputStream.write('\n');
outputStream.flush();
} catch (IOException e) {
LOG.error("Failed to send message", e);
}
});
}
public Mono<Void> closeGracefully() {
return Mono.fromRunnable(
() -> {
outputStream.flush();
outputStream.close();
});
}
}
private Map<String, Object> processRequest(JsonNode request, String sessionId) {
Map<String, Object> response = new HashMap<>();
response.put("jsonrpc", "2.0");
response.put("id", request.get("id"));
String method = request.get("method").asText();
try {
switch (method) {
case "ping":
response.put("result", "pong");
break;
case "resources/list":
// TODO: Implement resource reading logic
response.put("result", getResourcesList());
break;
case "resources/read":
// TODO: Implement resource reading logic
break;
case "tools/list":
response.put("result", new McpSchema.ListToolsResult(tools, null));
break;
case "tools/call":
JsonNode toolParams = request.get("params");
if (toolParams != null && toolParams.has("name")) {
String toolName = toolParams.get("name").asText();
JsonNode arguments = toolParams.get("arguments");
McpSchema.Content content =
new McpSchema.TextContent(
JsonUtils.pojoToJson(
callTool(
authorizer, jwtFilter, limits, toolName, JsonUtils.getMap(arguments))));
response.put("result", new McpSchema.CallToolResult(List.of(content), false));
} else {
response.put("error", createError(-32602, "Invalid params"));
}
break;
default:
response.put("error", createError(-32601, "Method not found: " + method));
}
} catch (Exception e) {
log("Error processing request: " + method, e);
response.put("error", createError(-32603, "Internal error: " + e.getMessage()));
}
return response;
}
private void processNotification(JsonNode notification, String sessionId) {
String method = notification.get("method").asText();
log("Received notification: " + method + " (session: " + sessionId + ")");
// Handle specific notifications
switch (method) {
case "notifications/initialized":
LOG.info("Client initialized for session: {}", sessionId);
break;
case "notifications/cancelled":
LOG.info("Client sent a cancellation request for session: {}", sessionId);
// Handle cancellation
break;
default:
log("Unknown notification: " + method);
}
}
// Utility methods
private String generateSessionId() {
byte[] bytes = new byte[32];
secureRandom.nextBytes(bytes);
return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
}
private boolean isValidOrigin(String origin) {
// Implement your origin validation logic
return origin.startsWith("http://localhost")
|| origin.startsWith("https://localhost")
|| origin.startsWith("https://yourdomain.com");
}
private boolean shouldUseSSE() {
// TODO: Decide when to use SSE vs direct JSON response
return false;
}
private String readRequestBody(HttpServletRequest request) throws IOException {
StringBuilder body = new StringBuilder();
try (BufferedReader reader = request.getReader()) {
String line;
while ((line = reader.readLine()) != null) {
body.append(line);
}
}
return body.toString();
}
private void sendError(HttpServletResponse response, int statusCode, String message)
throws IOException {
Map<String, Object> error = new HashMap<>();
error.put("error", message);
String errorJson = objectMapper.writeValueAsString(error);
response.setContentType(CONTENT_TYPE_JSON);
response.setStatus(statusCode);
response.getWriter().write(errorJson);
}
private void sendJSONResponse(HttpServletResponse response, JsonNode request, String sessionId)
throws IOException {
Map<String, Object> jsonResponse = processRequest(request, sessionId);
String responseJson = objectMapper.writeValueAsString(jsonResponse);
response.setContentType(CONTENT_TYPE_JSON);
response.setStatus(HttpServletResponse.SC_OK);
response.getWriter().write(responseJson);
response.getWriter().flush();
}
private void handleBatchRequest(
HttpServletRequest request,
HttpServletResponse response,
JsonNode batchRequest,
String sessionId)
throws IOException {
// TODO: Handle this
sendError(response, HttpServletResponse.SC_NOT_IMPLEMENTED, "Batch requests not implemented");
}
private Map<String, Object> createError(int code, String message) {
Map<String, Object> error = new HashMap<>();
error.put("code", code);
error.put("message", message);
return error;
}
private Map<String, Object> getServerCapabilities() {
Map<String, Object> capabilities = new HashMap<>();
// Resources capability
Map<String, Object> resources = new HashMap<>();
resources.put("subscribe", true);
resources.put("listChanged", true);
capabilities.put("resources", resources);
// Tools
Map<String, Object> tools = new HashMap<>();
tools.put("listChanged", true);
capabilities.put("tools", tools);
return capabilities;
}
private Map<String, Object> getServerInfo() {
Map<String, Object> serverInfo = new HashMap<>();
serverInfo.put("name", "OpenMetadata MCP Server - Streamable");
serverInfo.put("version", "1.0.0");
return serverInfo;
}
private Map<String, Object> getResourcesList() {
return new HashMap<>();
}
private Map<String, Object> createResource(
String uri, String name, String mimeType, String description) {
Map<String, Object> resource = new HashMap<>();
resource.put("uri", uri);
resource.put("name", name);
resource.put("mimeType", mimeType);
resource.put("description", description);
return resource;
}
/**
* Send a server-initiated notification to a specific session
*/
public void sendNotificationToSession(
String sessionId, String method, Map<String, Object> params) {
// Find SSE connections for this session
for (SSEConnection connection : sseConnections.values()) {
if (sessionId.equals(connection.getSessionId())) {
try {
Map<String, Object> notification = new HashMap<>();
notification.put("jsonrpc", "2.0");
notification.put("method", method);
if (params != null) {
notification.put("params", params);
}
connection.sendEvent(objectMapper.writeValueAsString(notification));
} catch (IOException e) {
log("Error sending notification to session: " + sessionId, e);
}
}
}
}
}

View File

@ -1,24 +1,33 @@
package org.openmetadata.service.mcp;
import static org.openmetadata.service.mcp.McpUtils.callTool;
import static org.openmetadata.service.mcp.McpUtils.getToolProperties;
import static org.openmetadata.service.search.SearchUtil.searchMetadata;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.JsonNode;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.jetty.MutableServletContextHandler;
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.server.McpSyncServer;
import io.modelcontextprotocol.spec.McpSchema;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.eclipse.jetty.servlet.FilterHolder;
import org.eclipse.jetty.servlet.ServletHolder;
import org.openmetadata.common.utils.CommonUtil;
import org.openmetadata.service.OpenMetadataApplicationConfig;
import org.openmetadata.service.limits.Limits;
import org.openmetadata.service.mcp.tools.GlossaryTermTool;
import org.openmetadata.service.mcp.tools.GlossaryTool;
import org.openmetadata.service.mcp.tools.PatchEntityTool;
import org.openmetadata.service.security.AuthorizationException;
import org.openmetadata.service.security.Authorizer;
import org.openmetadata.service.security.JwtFilter;
import org.openmetadata.service.security.auth.CatalogSecurityContext;
import org.openmetadata.service.util.EntityUtil;
import org.openmetadata.service.util.JsonUtils;
@Slf4j
@ -38,20 +47,6 @@ public class McpServer {
new JwtFilter(config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration());
this.authorizer = authorizer;
this.limits = limits;
MutableServletContextHandler contextHandler = environment.getApplicationContext();
McpAuthFilter authFilter =
new McpAuthFilter(
new JwtFilter(
config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration()));
List<McpSchema.Tool> tools = loadToolsDefinitionsFromJson();
addSSETransport(contextHandler, authFilter, tools);
addStreamableHttpServlet(contextHandler, authFilter, tools);
}
private void addSSETransport(
MutableServletContextHandler contextHandler,
McpAuthFilter authFilter,
List<McpSchema.Tool> tools) {
McpSchema.ServerCapabilities serverCapabilities =
McpSchema.ServerCapabilities.builder()
.tools(true)
@ -59,57 +54,152 @@ public class McpServer {
.resources(true, true)
.build();
HttpServletSseServerTransportProvider sseTransport =
new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/messages", "/mcp/sse");
HttpServletSseServerTransportProvider transport =
new HttpServletSseServerTransportProvider("/mcp/messages", "/mcp/sse");
McpSyncServer server =
io.modelcontextprotocol.server.McpServer.sync(sseTransport)
.serverInfo("openmetadata-mcp-sse", "0.1.0")
io.modelcontextprotocol.server.McpServer.sync(transport)
.serverInfo("openmetadata-mcp", "0.1.0")
.capabilities(serverCapabilities)
.build();
addToolsToServer(server, tools);
// SSE transport for MCP
ServletHolder servletHolderSSE = new ServletHolder(sseTransport);
contextHandler.addServlet(servletHolderSSE, "/mcp/*");
// Add resources, prompts, and tools to the MCP server
addTools(server);
MutableServletContextHandler contextHandler = environment.getApplicationContext();
ServletHolder servletHolder = new ServletHolder(transport);
contextHandler.addServlet(servletHolder, "/mcp/*");
McpAuthFilter authFilter =
new McpAuthFilter(
new JwtFilter(
config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration()));
contextHandler.addFilter(
new FilterHolder(authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST));
new FilterHolder((Filter) authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST));
}
private void addStreamableHttpServlet(
MutableServletContextHandler contextHandler,
McpAuthFilter authFilter,
List<McpSchema.Tool> tools) {
// Streamable HTTP servlet for MCP
MCPStreamableHttpServlet streamableHttpServlet =
new MCPStreamableHttpServlet(jwtFilter, authorizer, limits, tools);
ServletHolder servletHolderStreamableHttp = new ServletHolder(streamableHttpServlet);
contextHandler.addServlet(servletHolderStreamableHttp, "/mcp");
public void addTools(McpSyncServer server) {
try {
LOG.info("Loading tool definitions...");
List<Map<String, Object>> cachedTools = loadToolsDefinitionsFromJson();
if (cachedTools == null || cachedTools.isEmpty()) {
LOG.error("No tool definitions were loaded!");
throw new RuntimeException("Failed to load tool definitions");
}
LOG.info("Successfully loaded {} tool definitions", cachedTools.size());
contextHandler.addFilter(
new FilterHolder(authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST));
}
public void addToolsToServer(McpSyncServer server, List<McpSchema.Tool> tools) {
for (McpSchema.Tool tool : tools) {
server.addTool(getTool(tool));
for (Map<String, Object> toolDef : cachedTools) {
try {
String name = (String) toolDef.get("name");
String description = (String) toolDef.get("description");
Map<String, Object> schema = JsonUtils.getMap(toolDef.get("parameters"));
server.addTool(getTool(JsonUtils.pojoToJson(schema), name, description));
} catch (Exception e) {
LOG.error("Error processing tool definition: {}", toolDef, e);
}
}
LOG.info("Initializing request handlers...");
} catch (Exception e) {
LOG.error("Error during server startup", e);
throw new RuntimeException("Failed to start MCP server", e);
}
}
protected List<McpSchema.Tool> loadToolsDefinitionsFromJson() {
return getToolProperties("json/data/mcp/tools.json");
protected List<Map<String, Object>> loadToolsDefinitionsFromJson() {
String json = getJsonFromFile("json/data/mcp/tools.json");
return loadToolDefinitionsFromJson(json);
}
private McpServerFeatures.SyncToolSpecification getTool(McpSchema.Tool tool) {
protected static String getJsonFromFile(String path) {
try {
return CommonUtil.getResourceAsStream(McpServer.class.getClassLoader(), path);
} catch (Exception ex) {
LOG.error("Error loading JSON file: {}", path, ex);
return null;
}
}
@SuppressWarnings("unchecked")
public List<Map<String, Object>> loadToolDefinitionsFromJson(String json) {
try {
LOG.info("Loaded tool definitions, content length: {}", json.length());
LOG.info("Raw tools.json content: {}", json);
JsonNode toolsJson = JsonUtils.readTree(json);
JsonNode toolsArray = toolsJson.get("tools");
if (toolsArray == null || !toolsArray.isArray()) {
LOG.error("Invalid MCP tools file format. Expected 'tools' array.");
return new ArrayList<>();
}
List<Map<String, Object>> tools = new ArrayList<>();
for (JsonNode toolNode : toolsArray) {
String name = toolNode.get("name").asText();
Map<String, Object> toolDef = JsonUtils.convertValue(toolNode, Map.class);
tools.add(toolDef);
LOG.info("Tool found: {} with definition: {}", name, toolDef);
}
LOG.info("Found {} tool definitions", tools.size());
return tools;
} catch (Exception e) {
LOG.error("Error loading tool definitions: {}", e.getMessage(), e);
throw e;
}
}
private McpServerFeatures.SyncToolSpecification getTool(
String schema, String toolName, String description) {
McpSchema.Tool tool = new McpSchema.Tool(toolName, description, schema);
return new McpServerFeatures.SyncToolSpecification(
tool,
(exchange, arguments) -> {
McpSchema.Content content =
new McpSchema.TextContent(
JsonUtils.pojoToJson(
callTool(authorizer, jwtFilter, limits, tool.name(), arguments)));
new McpSchema.TextContent(JsonUtils.pojoToJson(runMethod(toolName, arguments)));
return new McpSchema.CallToolResult(List.of(content), false);
});
}
protected Object runMethod(String toolName, Map<String, Object> params) {
CatalogSecurityContext securityContext =
jwtFilter.getCatalogSecurityContext((String) params.get("Authorization"));
LOG.info(
"Catalog Principal: {} is trying to call the tool: {}",
securityContext.getUserPrincipal().getName(),
toolName);
Object result;
try {
switch (toolName) {
case "search_metadata":
result = searchMetadata(params);
break;
case "get_entity_details":
result = EntityUtil.getEntityDetails(params);
break;
case "create_glossary":
result = new GlossaryTool().execute(authorizer, limits, securityContext, params);
break;
case "create_glossary_term":
result = new GlossaryTermTool().execute(authorizer, limits, securityContext, params);
break;
case "patch_entity":
result = new PatchEntityTool().execute(authorizer, securityContext, params);
break;
default:
result = Map.of("error", "Unknown function: " + toolName);
break;
}
return result;
} catch (AuthorizationException ex) {
LOG.error("Authorization error: {}", ex.getMessage());
return Map.of(
"error", String.format("Authorization error: %s", ex.getMessage()), "statusCode", 403);
} catch (Exception ex) {
LOG.error("Error executing tool: {}", ex.getMessage());
return Map.of(
"error", String.format("Error executing tool: %s", ex.getMessage()), "statusCode", 500);
}
}
}

View File

@ -1,151 +0,0 @@
package org.openmetadata.service.mcp;
import static org.openmetadata.service.search.SearchUtil.searchMetadata;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpSchema;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.common.utils.CommonUtil;
import org.openmetadata.service.limits.Limits;
import org.openmetadata.service.mcp.tools.GlossaryTermTool;
import org.openmetadata.service.mcp.tools.GlossaryTool;
import org.openmetadata.service.mcp.tools.PatchEntityTool;
import org.openmetadata.service.security.AuthorizationException;
import org.openmetadata.service.security.Authorizer;
import org.openmetadata.service.security.JwtFilter;
import org.openmetadata.service.security.auth.CatalogSecurityContext;
import org.openmetadata.service.util.EntityUtil;
import org.openmetadata.service.util.JsonUtils;
@Slf4j
public class McpUtils {
public static McpSchema.JSONRPCMessage getJsonRpcMessageWithAuthorizationParam(
ObjectMapper objectMapper, HttpServletRequest request, String body) throws IOException {
Map<String, Object> requestMessage = JsonUtils.getMap(JsonUtils.readTree(body));
Map<String, Object> params = (Map<String, Object>) requestMessage.get("params");
if (params != null) {
Map<String, Object> arguments = (Map<String, Object>) params.get("arguments");
if (arguments != null) {
arguments.put("Authorization", JwtFilter.extractToken(request.getHeader("Authorization")));
}
}
return McpSchema.deserializeJsonRpcMessage(objectMapper, JsonUtils.pojoToJson(requestMessage));
}
@SuppressWarnings("unchecked")
public static List<Map<String, Object>> loadToolDefinitionsFromJson(String json) {
try {
LOG.info("Loaded tool definitions, content length: {}", json.length());
LOG.info("Raw tools.json content: {}", json);
JsonNode toolsJson = JsonUtils.readTree(json);
JsonNode toolsArray = toolsJson.get("tools");
if (toolsArray == null || !toolsArray.isArray()) {
LOG.error("Invalid MCP tools file format. Expected 'tools' array.");
return new ArrayList<>();
}
List<Map<String, Object>> tools = new ArrayList<>();
for (JsonNode toolNode : toolsArray) {
String name = toolNode.get("name").asText();
Map<String, Object> toolDef = JsonUtils.convertValue(toolNode, Map.class);
tools.add(toolDef);
LOG.info("Tool found: {} with definition: {}", name, toolDef);
}
LOG.info("Found {} tool definitions", tools.size());
return tools;
} catch (Exception e) {
LOG.error("Error loading tool definitions: {}", e.getMessage(), e);
throw e;
}
}
public static String getJsonFromFile(String path) {
try {
return CommonUtil.getResourceAsStream(McpServer.class.getClassLoader(), path);
} catch (Exception ex) {
LOG.error("Error loading JSON file: {}", path, ex);
return null;
}
}
public static Object callTool(
Authorizer authorizer,
JwtFilter jwtFilter,
Limits limits,
String toolName,
Map<String, Object> params) {
CatalogSecurityContext securityContext =
jwtFilter.getCatalogSecurityContext((String) params.get("Authorization"));
LOG.info(
"Catalog Principal: {} is trying to call the tool: {}",
securityContext.getUserPrincipal().getName(),
toolName);
Object result;
try {
switch (toolName) {
case "search_metadata":
result = searchMetadata(params);
break;
case "get_entity_details":
result = EntityUtil.getEntityDetails(params);
break;
case "create_glossary":
result = new GlossaryTool().execute(authorizer, limits, securityContext, params);
break;
case "create_glossary_term":
result = new GlossaryTermTool().execute(authorizer, limits, securityContext, params);
break;
case "patch_entity":
result = new PatchEntityTool().execute(authorizer, limits, securityContext, params);
break;
default:
result = Map.of("error", "Unknown function: " + toolName);
break;
}
return result;
} catch (AuthorizationException ex) {
LOG.error("Authorization error: {}", ex.getMessage());
return Map.of(
"error", String.format("Authorization error: %s", ex.getMessage()), "statusCode", 403);
} catch (Exception ex) {
LOG.error("Error executing tool: {}", ex.getMessage());
return Map.of(
"error", String.format("Error executing tool: %s", ex.getMessage()), "statusCode", 500);
}
}
public static List<McpSchema.Tool> getToolProperties(String jsonFilePath) {
try {
List<McpSchema.Tool> result = new ArrayList<>();
String json = getJsonFromFile(jsonFilePath);
List<Map<String, Object>> cachedTools = loadToolDefinitionsFromJson(json);
if (cachedTools == null || cachedTools.isEmpty()) {
LOG.error("No tool definitions were loaded!");
throw new RuntimeException("Failed to load tool definitions");
}
LOG.debug("Successfully loaded {} tool definitions", cachedTools.size());
for (int i = 0; i < cachedTools.size(); i++) {
Map<String, Object> toolDef = cachedTools.get(i);
String name = (String) toolDef.get("name");
String description = (String) toolDef.get("description");
Map<String, Object> schema = JsonUtils.getMap(toolDef.get("parameters"));
result.add(new McpSchema.Tool(name, description, JsonUtils.pojoToJson(schema)));
}
return result;
} catch (Exception e) {
LOG.error("Error during server startup", e);
throw new RuntimeException("Failed to start MCP server", e);
}
}
}

View File

@ -141,10 +141,7 @@ public class SearchUtil {
case "glossary_search_index", Entity.GLOSSARY -> Entity.GLOSSARY;
case "domain_search_index", Entity.DOMAIN -> Entity.DOMAIN;
case "data_product_search_index", Entity.DATA_PRODUCT -> Entity.DATA_PRODUCT;
case "team_search_index", Entity.TEAM -> Entity.TEAM;
case "user_Search_index", Entity.USER -> Entity.USER;
case "dataAsset" -> "dataAsset";
default -> "dataAsset";
default -> "default";
};
}

View File

@ -108,7 +108,6 @@ public record TableIndex(Table table) implements ColumnIndex {
doc.put("processedLineage", table.getProcessedLineage());
doc.put("entityRelationship", SearchIndex.populateEntityRelationshipData(table));
doc.put("databaseSchema", getEntityWithDisplayName(table.getDatabaseSchema()));
doc.put("tableQueries", table.getTableQueries());
doc.put(
"changeSummary",
Optional.ofNullable(table.getChangeDescription())

View File

@ -1162,14 +1162,6 @@
"description": "Processed lineage for the table",
"type": "boolean",
"default": false
},
"tableQueries": {
"description": "List of queries that are used to create this table.",
"type": "array",
"items": {
"$ref": "../../type/basic.json#/definitions/sqlQuery"
},
"default": null
}
},
"required": [

View File

@ -161,11 +161,7 @@ export interface Table {
* Table Profiler Config to include or exclude columns from profiling.
*/
tableProfilerConfig?: TableProfilerConfig;
/**
* List of queries that are used to create this table.
*/
tableQueries?: string[];
tableType?: TableType;
tableType?: TableType;
/**
* Tags for this table.
*/