MCP Core Items Improvements (#21643)

* Search Util fix and added tableQueries

* some json input fix

* Add team and user

* WIP : Add Streamable HTTP

* - Add proper tools/list schema and tools/call

* - auth filter exact match

* - Add Tools Class to dynamically build tools

* Add Origin Validation Mandate

* Refactor MCP Stream

* comment

* Cleanups

* Typo

* Typo
This commit is contained in:
Mohit Yadav 2025-06-10 09:42:24 +05:30 committed by GitHub
parent 40aba1d906
commit dc25350ea2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 1728 additions and 502 deletions

View File

@ -443,6 +443,8 @@ operationalConfig:
mcpConfiguration:
enabled: ${MCP_ENABLED:-true}
path: ${MCP_PATH:-"/mcp"}
mcpServerVersion: ${MCP_SERVER_VERSION:-"1.0.0"}
originHeaderUri: ${OPENMETADATA_SERVER_URL:-"http://localhost"}

View File

@ -91,6 +91,7 @@ import org.openmetadata.service.jobs.JobHandlerRegistry;
import org.openmetadata.service.limits.DefaultLimits;
import org.openmetadata.service.limits.Limits;
import org.openmetadata.service.mcp.McpServer;
import org.openmetadata.service.mcp.tools.DefaultToolContext;
import org.openmetadata.service.migration.Migration;
import org.openmetadata.service.migration.MigrationValidationClient;
import org.openmetadata.service.migration.api.MigrationWorkflow;
@ -289,7 +290,7 @@ public class OpenMetadataApplication extends Application<OpenMetadataApplication
OpenMetadataApplicationConfig catalogConfig, Environment environment) {
if (catalogConfig.getMcpConfiguration() != null
&& catalogConfig.getMcpConfiguration().isEnabled()) {
McpServer mcpServer = new McpServer();
McpServer mcpServer = new McpServer(new DefaultToolContext());
mcpServer.initializeMcpServer(environment, authorizer, limits, catalogConfig);
}
}

View File

@ -18,4 +18,7 @@ public class MCPConfiguration {
@JsonProperty("path")
private String path = "/api/v1/mcp";
@JsonProperty("originHeaderUri")
private String originHeaderUri = "http://localhost";
}

View File

@ -26,8 +26,10 @@ import static org.openmetadata.schema.type.Include.NON_DELETED;
import static org.openmetadata.service.Entity.DATABASE_SCHEMA;
import static org.openmetadata.service.Entity.FIELD_OWNERS;
import static org.openmetadata.service.Entity.FIELD_TAGS;
import static org.openmetadata.service.Entity.QUERY;
import static org.openmetadata.service.Entity.TABLE;
import static org.openmetadata.service.Entity.TEST_SUITE;
import static org.openmetadata.service.Entity.getEntities;
import static org.openmetadata.service.Entity.populateEntityFieldTags;
import static org.openmetadata.service.util.EntityUtil.getLocalColumnName;
import static org.openmetadata.service.util.FullyQualifiedName.getColumnName;
@ -62,6 +64,7 @@ import org.openmetadata.schema.EntityInterface;
import org.openmetadata.schema.api.data.CreateTableProfile;
import org.openmetadata.schema.api.feed.ResolveTask;
import org.openmetadata.schema.entity.data.DatabaseSchema;
import org.openmetadata.schema.entity.data.Query;
import org.openmetadata.schema.entity.data.Table;
import org.openmetadata.schema.entity.feed.Suggestion;
import org.openmetadata.schema.tests.CustomMetric;
@ -179,6 +182,15 @@ public class TableRepository extends EntityRepository<Table> {
column.setCustomMetrics(getCustomMetrics(table, column.getName()));
}
}
if (fields.contains("tableQueries")) {
List<Query> queriesEntity =
getEntities(
listOrEmpty(findTo(table.getId(), TABLE, Relationship.MENTIONED_IN, QUERY)),
"id",
ALL);
List<String> queries = listOrEmpty(queriesEntity).stream().map(Query::getQuery).toList();
table.setTableQueries(queries);
}
}
@Override

View File

@ -1,9 +1,7 @@
/*
HttpServletSseServerTransportProvider - Jakarta servlet-based MCP server transport
*/
package org.openmetadata.service.mcp;
import static org.openmetadata.service.mcp.McpUtils.getJsonRpcMessageWithAuthorizationParam;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
@ -11,6 +9,7 @@ import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.WebServlet;
@ -23,9 +22,10 @@ import java.io.PrintWriter;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.openmetadata.service.security.JwtFilter;
import org.openmetadata.service.util.JsonUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
@ -42,171 +42,220 @@ public class HttpServletSseServerTransportProvider extends HttpServlet
public static final String DEFAULT_SSE_ENDPOINT = "/sse";
public static final String MESSAGE_EVENT_TYPE = "message";
public static final String ENDPOINT_EVENT_TYPE = "endpoint";
public static final String DEFAULT_BASE_URL = "";
private final ObjectMapper objectMapper;
private final String baseUrl;
private final String messageEndpoint;
private final String sseEndpoint;
private final Map<String, McpServerSession> sessions;
private final AtomicBoolean isClosing;
private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<>();
private final AtomicBoolean isClosing = new AtomicBoolean(false);
private McpServerSession.Factory sessionFactory;
private ExecutorService executorService =
Executors.newCachedThreadPool(
r -> {
Thread t = new Thread(r, "MCP-Worker-SSE");
t.setDaemon(true);
return t;
});
public HttpServletSseServerTransportProvider(String messageEndpoint, String sseEndpoint) {
this.sessions = new ConcurrentHashMap<>();
this.isClosing = new AtomicBoolean(false);
this.objectMapper = JsonUtils.getObjectMapper();
public HttpServletSseServerTransportProvider(
ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
}
public HttpServletSseServerTransportProvider(
ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) {
this.objectMapper = objectMapper;
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
}
public HttpServletSseServerTransportProvider(String messageEndpoint) {
this(messageEndpoint, "/sse");
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
}
@Override
public void setSessionFactory(McpServerSession.Factory sessionFactory) {
this.sessionFactory = sessionFactory;
}
@Override
public Mono<Void> notifyClients(String method, Map<String, Object> params) {
if (this.sessions.isEmpty()) {
if (sessions.isEmpty()) {
logger.debug("No active sessions to broadcast message to");
return Mono.empty();
} else {
logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
return Flux.fromIterable(this.sessions.values())
.flatMap(
(session) ->
session
.sendNotification(method, params)
.doOnError(
(e) ->
logger.error(
"Failed to send message to session {}: {}",
session.getId(),
e.getMessage()))
.onErrorComplete())
.then();
}
logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());
return Flux.fromIterable(sessions.values())
.flatMap(
session ->
session
.sendNotification(method, params)
.doOnError(
e ->
logger.error(
"Failed to send message to session {}: {}",
session.getId(),
e.getMessage()))
.onErrorComplete())
.then();
}
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
throws IOException {
handleSseEvent(request, response);
}
private void handleSseEvent(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
String pathInfo = request.getPathInfo();
if (!this.sseEndpoint.contains(pathInfo)) {
response.sendError(404);
} else if (this.isClosing.get()) {
response.sendError(503, "Server is shutting down");
} else {
response.setContentType("text/event-stream");
response.setCharacterEncoding("UTF-8");
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Connection", "keep-alive");
response.setHeader("Access-Control-Allow-Origin", "*");
String sessionId = UUID.randomUUID().toString();
AsyncContext asyncContext = request.startAsync();
asyncContext.setTimeout(0L);
PrintWriter writer = response.getWriter();
HttpServletMcpSessionTransport sessionTransport =
new HttpServletMcpSessionTransport(sessionId, asyncContext, writer);
McpServerSession session = this.sessionFactory.create(sessionTransport);
this.sessions.put(sessionId, session);
this.sendEvent(writer, "endpoint", this.messageEndpoint + "?sessionId=" + sessionId);
throws IOException {
String requestURI = request.getRequestURI();
if (!requestURI.endsWith(sseEndpoint)) {
response.sendError(HttpServletResponse.SC_NOT_FOUND);
return;
}
}
protected void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
if (this.isClosing.get()) {
response.sendError(503, "Server is shutting down");
} else {
String requestURI = request.getRequestURI();
if (!requestURI.endsWith(this.messageEndpoint)) {
response.sendError(404);
} else {
String sessionId = request.getParameter("sessionId");
if (sessionId == null) {
response.setContentType("application/json");
response.setCharacterEncoding("UTF-8");
response.setStatus(400);
String jsonError =
this.objectMapper.writeValueAsString(
new McpError("Session ID missing in message endpoint"));
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
} else {
McpServerSession session = (McpServerSession) this.sessions.get(sessionId);
if (session == null) {
response.setContentType("application/json");
response.setCharacterEncoding("UTF-8");
response.setStatus(404);
String jsonError =
this.objectMapper.writeValueAsString(
new McpError("Session not found: " + sessionId));
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
} else {
if (isClosing.get()) {
response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
return;
}
response.setContentType("text/event-stream");
response.setCharacterEncoding(UTF_8);
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Connection", "keep-alive");
response.setHeader("Access-Control-Allow-Origin", "*");
String sessionId = UUID.randomUUID().toString();
AsyncContext asyncContext = request.startAsync();
asyncContext.setTimeout(0);
PrintWriter writer = response.getWriter();
// Create a new session transport
HttpServletMcpSessionTransport sessionTransport =
new HttpServletMcpSessionTransport(sessionId, asyncContext, writer);
// Create a new session using the session factory
McpServerSession session = sessionFactory.create(sessionTransport);
this.sessions.put(sessionId, session);
executorService.submit(
() -> {
// TODO: Handle session lifecycle and keepalive
try {
while (sessions.containsKey(sessionId)) {
// Send keepalive every 30 seconds
Thread.sleep(30000);
writer.write(": keep-alive\n\n");
writer.flush();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (Exception e) {
log("SSE error", e);
} finally {
try {
BufferedReader reader = request.getReader();
StringBuilder body = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
body.append(line);
}
McpSchema.JSONRPCMessage message =
getJsonRpcMessageWithAuthorizationParam(request, body.toString());
session.handle(message).block();
response.setStatus(200);
} catch (Exception var11) {
Exception e = var11;
logger.error("Error processing message: {}", var11.getMessage());
try {
McpError mcpError = new McpError(e.getMessage());
response.setContentType("application/json");
response.setCharacterEncoding("UTF-8");
response.setStatus(500);
String jsonError = this.objectMapper.writeValueAsString(mcpError);
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
} catch (IOException ex) {
logger.error("Failed to send error response: {}", ex.getMessage());
response.sendError(500, "Error processing message");
}
session.closeGracefully();
asyncContext.complete();
} catch (Exception e) {
log("Error closing long-lived SSE connection", e);
}
}
}
});
// Send initial endpoint event
this.sendEvent(
writer,
ENDPOINT_EVENT_TYPE,
this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
}
@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
if (isClosing.get()) {
response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
return;
}
String requestURI = request.getRequestURI();
if (!requestURI.endsWith(messageEndpoint)) {
response.sendError(HttpServletResponse.SC_NOT_FOUND);
return;
}
// Get the session ID from the request parameter
String sessionId = request.getParameter("sessionId");
if (sessionId == null) {
response.setContentType(APPLICATION_JSON);
response.setCharacterEncoding(UTF_8);
response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
String jsonError =
objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint"));
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
return;
}
// Get the session from the sessions map
McpServerSession session = sessions.get(sessionId);
if (session == null) {
response.setContentType(APPLICATION_JSON);
response.setCharacterEncoding(UTF_8);
response.setStatus(HttpServletResponse.SC_NOT_FOUND);
String jsonError =
objectMapper.writeValueAsString(new McpError("Session not found: " + sessionId));
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
return;
}
try {
BufferedReader reader = request.getReader();
StringBuilder body = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
body.append(line);
}
McpSchema.JSONRPCMessage message =
getJsonRpcMessageWithAuthorizationParam(this.objectMapper, request, body.toString());
// Process the message through the session's handle method
session.handle(message).block(); // Block for Servlet compatibility
response.setStatus(HttpServletResponse.SC_OK);
} catch (Exception e) {
logger.error("Error processing message: {}", e.getMessage());
try {
McpError mcpError = new McpError(e.getMessage());
response.setContentType(APPLICATION_JSON);
response.setCharacterEncoding(UTF_8);
response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
String jsonError = objectMapper.writeValueAsString(mcpError);
PrintWriter writer = response.getWriter();
writer.write(jsonError);
writer.flush();
} catch (IOException ex) {
logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage());
response.sendError(
HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message");
}
}
}
private McpSchema.JSONRPCMessage getJsonRpcMessageWithAuthorizationParam(
HttpServletRequest request, String body) throws IOException {
Map<String, Object> requestMessage = JsonUtils.getMap(JsonUtils.readTree(body));
Map<String, Object> params = (Map<String, Object>) requestMessage.get("params");
if (params != null) {
Map<String, Object> arguments = (Map<String, Object>) params.get("arguments");
if (arguments != null) {
arguments.put("Authorization", JwtFilter.extractToken(request.getHeader("Authorization")));
}
}
return McpSchema.deserializeJsonRpcMessage(
this.objectMapper, JsonUtils.pojoToJson(requestMessage));
}
@Override
public Mono<Void> closeGracefully() {
this.isClosing.set(true);
logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
return Flux.fromIterable(this.sessions.values())
.flatMap(McpServerSession::closeGracefully)
.then();
isClosing.set(true);
logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then();
}
private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException {
@ -218,8 +267,19 @@ public class HttpServletSseServerTransportProvider extends HttpServlet
}
}
@Override
public void destroy() {
this.closeGracefully().block();
closeGracefully().block();
if (executorService != null) {
executorService.shutdown();
try {
if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) {
executorService.shutdownNow();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
super.destroy();
}
@ -233,65 +293,106 @@ public class HttpServletSseServerTransportProvider extends HttpServlet
this.sessionId = sessionId;
this.asyncContext = asyncContext;
this.writer = writer;
HttpServletSseServerTransportProvider.logger.debug(
"Session transport {} initialized with SSE writer", sessionId);
logger.debug("Session transport {} initialized with SSE writer", sessionId);
}
@Override
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
return Mono.fromRunnable(
() -> {
try {
String jsonText =
HttpServletSseServerTransportProvider.this.objectMapper.writeValueAsString(
message);
HttpServletSseServerTransportProvider.this.sendEvent(
this.writer, "message", jsonText);
HttpServletSseServerTransportProvider.logger.debug(
"Message sent to session {}", this.sessionId);
String jsonText = objectMapper.writeValueAsString(message);
sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText);
logger.debug("Message sent to session {}", sessionId);
} catch (Exception e) {
HttpServletSseServerTransportProvider.logger.error(
"Failed to send message to session {}: {}", this.sessionId, e.getMessage());
HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
this.asyncContext.complete();
logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage());
sessions.remove(sessionId);
asyncContext.complete();
}
});
}
@Override
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
return (T)
HttpServletSseServerTransportProvider.this.objectMapper.convertValue(data, typeRef);
return objectMapper.convertValue(data, typeRef);
}
@Override
public Mono<Void> closeGracefully() {
return Mono.fromRunnable(
() -> {
HttpServletSseServerTransportProvider.logger.debug(
"Closing session transport: {}", this.sessionId);
logger.debug("Closing session transport: {}", sessionId);
try {
HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
this.asyncContext.complete();
HttpServletSseServerTransportProvider.logger.debug(
"Successfully completed async context for session {}", this.sessionId);
sessions.remove(sessionId);
asyncContext.complete();
logger.debug("Successfully completed async context for session {}", sessionId);
} catch (Exception e) {
HttpServletSseServerTransportProvider.logger.warn(
"Failed to complete async context for session {}: {}",
this.sessionId,
e.getMessage());
logger.warn(
"Failed to complete async context for session {}: {}", sessionId, e.getMessage());
}
});
}
@Override
public void close() {
try {
HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
this.asyncContext.complete();
HttpServletSseServerTransportProvider.logger.debug(
"Successfully completed async context for session {}", this.sessionId);
sessions.remove(sessionId);
asyncContext.complete();
logger.debug("Successfully completed async context for session {}", sessionId);
} catch (Exception e) {
HttpServletSseServerTransportProvider.logger.warn(
"Failed to complete async context for session {}: {}", this.sessionId, e.getMessage());
logger.warn(
"Failed to complete async context for session {}: {}", sessionId, e.getMessage());
}
}
}
/** Returns a new {@link Builder} for configuring an {@code HttpServletSseServerTransportProvider}. */
public static Builder builder() {
  return new Builder();
}
/**
 * Fluent builder for the SSE transport provider.
 *
 * <p>Only the message endpoint is mandatory; the object mapper, base URL and SSE endpoint
 * fall back to defaults ({@code new ObjectMapper()}, {@code DEFAULT_BASE_URL},
 * {@code DEFAULT_SSE_ENDPOINT}).
 */
public static class Builder {
  private ObjectMapper objectMapper = new ObjectMapper();
  private String baseUrl = DEFAULT_BASE_URL;
  private String messageEndpoint;
  private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

  /** Overrides the default Jackson mapper used to (de)serialize JSON-RPC payloads. */
  public Builder objectMapper(ObjectMapper objectMapper) {
    Assert.notNull(objectMapper, "ObjectMapper must not be null");
    this.objectMapper = objectMapper;
    return this;
  }

  /** Sets the URL prefix prepended to the message endpoint sent in the SSE endpoint event. */
  public Builder baseUrl(String baseUrl) {
    Assert.notNull(baseUrl, "Base URL must not be null");
    this.baseUrl = baseUrl;
    return this;
  }

  /** Sets the endpoint clients POST JSON-RPC messages to (required). */
  public Builder messageEndpoint(String messageEndpoint) {
    Assert.hasText(messageEndpoint, "Message endpoint must not be empty");
    this.messageEndpoint = messageEndpoint;
    return this;
  }

  /** Sets the endpoint used to open the SSE stream. */
  public Builder sseEndpoint(String sseEndpoint) {
    Assert.hasText(sseEndpoint, "SSE endpoint must not be empty");
    this.sseEndpoint = sseEndpoint;
    return this;
  }

  /**
   * Validates required fields and builds the provider.
   *
   * @throws IllegalStateException if the mapper or message endpoint is unset
   */
  public HttpServletSseServerTransportProvider build() {
    if (objectMapper == null) {
      throw new IllegalStateException("ObjectMapper must be set");
    }
    if (messageEndpoint == null) {
      throw new IllegalStateException("MessageEndpoint must be set");
    }
    return new HttpServletSseServerTransportProvider(
        objectMapper, baseUrl, messageEndpoint, sseEndpoint);
  }
}
}

View File

@ -0,0 +1,971 @@
package org.openmetadata.service.mcp;
import static org.openmetadata.service.mcp.McpUtils.getJsonRpcMessageWithAuthorizationParam;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.WebServlet;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.service.config.MCPConfiguration;
import org.openmetadata.service.exception.BadRequestException;
import org.openmetadata.service.limits.Limits;
import org.openmetadata.service.mcp.tools.DefaultToolContext;
import org.openmetadata.service.security.Authorizer;
import org.openmetadata.service.security.JwtFilter;
import org.openmetadata.service.util.JsonUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
/**
* MCP (Model Context Protocol) Streamable HTTP Servlet
* This servlet implements the Streamable HTTP transport specification for MCP.
*/
@WebServlet(value = "/mcp", asyncSupported = true)
@Slf4j
public class MCPStreamableHttpServlet extends HttpServlet implements McpServerTransportProvider {
private static final String SESSION_HEADER = "Mcp-Session-Id";
private static final String CONTENT_TYPE_JSON = "application/json";
private static final String CONTENT_TYPE_SSE = "text/event-stream";
private final transient ObjectMapper objectMapper = new ObjectMapper();
private final transient Map<String, MCPSession> sessions = new ConcurrentHashMap<>();
private final transient Map<String, SSEConnection> sseConnections = new ConcurrentHashMap<>();
private final transient SecureRandom secureRandom = new SecureRandom();
private final transient ExecutorService executorService =
Executors.newCachedThreadPool(
r -> {
Thread t = new Thread(r, "MCP-Worker-Streamable");
t.setDaemon(true);
return t;
});
private transient McpServerSession.Factory sessionFactory;
private final transient JwtFilter jwtFilter;
private final transient Authorizer authorizer;
private final transient List<McpSchema.Tool> tools = new ArrayList<>();
private final transient Limits limits;
private final transient DefaultToolContext toolContext;
private final transient MCPConfiguration mcpConfiguration;
/**
 * Creates the servlet with its collaborators.
 *
 * @param mcpConfiguration MCP server configuration (origin validation, endpoints)
 * @param jwtFilter JWT filter used for bearer-token handling
 * @param authorizer OpenMetadata authorizer
 * @param limits limits enforcement
 * @param toolContext context used to resolve and execute MCP tools
 * @param tools tool definitions exposed by this server
 */
public MCPStreamableHttpServlet(
    MCPConfiguration mcpConfiguration,
    JwtFilter jwtFilter,
    Authorizer authorizer,
    Limits limits,
    DefaultToolContext toolContext,
    List<McpSchema.Tool> tools) {
  this.jwtFilter = jwtFilter;
  this.authorizer = authorizer;
  this.limits = limits;
  this.toolContext = toolContext;
  this.mcpConfiguration = mcpConfiguration;
  // Copy into the servlet-owned (final) list rather than aliasing the caller's list.
  this.tools.addAll(tools);
}
/** Servlet lifecycle hook: delegates to the superclass, then records that the transport is ready. */
@Override
public void init() throws ServletException {
  super.init();
  log("MCP Streamable HTTP Servlet initialized");
}
/**
 * Servlet shutdown hook.
 *
 * <p>Stops the worker pool (bounded 5s drain, then forced), closes every open SSE
 * connection, and finally delegates to {@link jakarta.servlet.http.HttpServlet#destroy()}.
 */
@Override
public void destroy() {
  if (executorService != null) {
    executorService.shutdown();
    try {
      if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) {
        executorService.shutdownNow();
      }
    } catch (InterruptedException e) {
      // Fix: force shutdown when interrupted (standard ExecutorService pattern),
      // then preserve the interrupt status for the container.
      executorService.shutdownNow();
      Thread.currentThread().interrupt();
    }
  }
  // Close all SSE connections
  for (SSEConnection connection : sseConnections.values()) {
    try {
      connection.close();
    } catch (IOException e) {
      log("Error closing SSE connection", e);
    }
  }
  super.destroy();
}
/**
 * Entry point for all client-to-server MCP traffic (Streamable HTTP transport).
 *
 * <p>Validates the Origin header and Accept negotiation, parses the JSON-RPC payload
 * (with the caller's Authorization token injected into tool arguments — see
 * {@code getJsonRpcMessageWithAuthorizationParam}), then dispatches to batch or
 * single-message handling.
 */
@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response)
    throws ServletException, IOException {
  // Reject requests from origins that fail validation.
  String origin = request.getHeader("Origin");
  if (origin != null && !isValidOrigin(origin)) {
    sendError(response, HttpServletResponse.SC_FORBIDDEN, "Invalid origin");
    return;
  }
  // The client must be able to consume both a JSON reply and an SSE stream.
  String accept = request.getHeader("Accept");
  boolean acceptsJson = accept != null && accept.contains(CONTENT_TYPE_JSON);
  boolean acceptsSse = accept != null && accept.contains(CONTENT_TYPE_SSE);
  if (!acceptsJson || !acceptsSse) {
    sendError(
        response,
        HttpServletResponse.SC_BAD_REQUEST,
        "Accept header must include both application/json and text/event-stream");
    return;
  }
  try {
    String body = readRequestBody(request);
    String sessionId = request.getHeader(SESSION_HEADER);
    // Parse JSON-RPC message(s)
    McpSchema.JSONRPCMessage message =
        getJsonRpcMessageWithAuthorizationParam(this.objectMapper, request, body);
    JsonNode parsed = objectMapper.valueToTree(message);
    if (parsed.isArray()) {
      handleBatchRequest(request, response, parsed, sessionId);
    } else {
      handleSingleMessage(request, response, parsed, sessionId);
    }
  } catch (Exception e) {
    log("Error handling POST request", e);
    sendError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Internal server error");
  }
}
/**
 * Handle a single JSON-RPC message according to MCP specification.
 *
 * <p>Classifies the message as initialize / response / notification / request and
 * routes it accordingly. Note: statement order matters — initialization must bypass
 * the session-existence check, since it is what creates the session.
 */
private void handleSingleMessage(
    HttpServletRequest request,
    HttpServletResponse response,
    JsonNode jsonMessage,
    String sessionId)
    throws IOException {
  String rpcMethod = jsonMessage.has("method") ? jsonMessage.get("method").asText() : null;
  boolean carriesId = jsonMessage.has("id");
  boolean looksLikeResponse = jsonMessage.has("result") || jsonMessage.has("error");
  // Handle initialization specially — it bootstraps the session.
  if ("initialize".equals(rpcMethod)) {
    handleInitialize(response, jsonMessage);
    return;
  }
  // A session id with no matching live session is fatal, except for client responses.
  if (sessionId != null && !looksLikeResponse && !sessions.containsKey(sessionId)) {
    sendError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found");
    return;
  }
  if (looksLikeResponse || (rpcMethod != null && !carriesId)) {
    // Either a response or a notification — acknowledged, no payload returned.
    handleResponseOrNotification(response, jsonMessage, sessionId, looksLikeResponse);
  } else if (rpcMethod != null) {
    // A request (method + id) — may answer with plain JSON or an SSE stream.
    handleRequest(request, response, jsonMessage, sessionId);
  } else {
    sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Invalid JSON-RPC message format");
  }
}
/**
 * Handle responses and notifications from client.
 *
 * <p>Both message kinds carry no reply payload: on success the servlet answers with an
 * empty 202 Accepted; on failure with 400.
 */
private void handleResponseOrNotification(
    HttpServletResponse response, JsonNode message, String sessionId, boolean isResponse)
    throws IOException {
  try {
    if (!isResponse) {
      processNotification(message, sessionId);
      log("Processed client notification for session: " + sessionId);
    } else {
      processClientResponse(message, sessionId);
      log("Processed client response for session: " + sessionId);
    }
    // Acknowledge with an empty body.
    response.setStatus(HttpServletResponse.SC_ACCEPTED);
    response.setContentLength(0);
  } catch (Exception e) {
    log("Error processing response/notification", e);
    sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Failed to process message");
  }
}
/**
 * Handle JSON-RPC requests from client.
 *
 * <p>Streams the answer over SSE only when the client advertises SSE support AND
 * {@code shouldUseSSEForRequest} opts in; otherwise replies with a single JSON document.
 */
private void handleRequest(
    HttpServletRequest request,
    HttpServletResponse response,
    JsonNode jsonRequest,
    String sessionId)
    throws IOException {
  String acceptHeader = request.getHeader("Accept");
  boolean clientAcceptsSse = acceptHeader != null && acceptHeader.contains(CONTENT_TYPE_SSE);
  if (clientAcceptsSse && shouldUseSSEForRequest(jsonRequest, sessionId)) {
    // Initiate SSE stream and process request asynchronously.
    startSSEStreamForRequest(request, response, jsonRequest, sessionId);
  } else {
    // Send direct JSON response.
    sendDirectJSONResponse(response, jsonRequest, sessionId);
  }
}
/**
 * Send direct JSON response for requests.
 *
 * <p>Processes the request synchronously and writes the serialized result as a single
 * {@code application/json} body, echoing the session id header when present.
 */
private void sendDirectJSONResponse(
    HttpServletResponse response, JsonNode request, String sessionId) throws IOException {
  Map<String, Object> result = processRequest(request, sessionId);
  String payload = objectMapper.writeValueAsString(result);
  // Include session ID in response if present.
  if (sessionId != null) {
    response.setHeader(SESSION_HEADER, sessionId);
  }
  response.setContentType(CONTENT_TYPE_JSON);
  response.setStatus(HttpServletResponse.SC_OK);
  PrintWriter writer = response.getWriter();
  writer.write(payload);
  writer.flush();
}
/**
 * Start SSE stream for processing requests.
 *
 * <p>Switches the response into {@code text/event-stream} mode, registers an
 * {@code SSEConnection}, and hands the actual request processing to the worker pool.
 * The stream is closed after a single response event (or a JSON-RPC error event on
 * failure), so this is a one-shot stream per request, not a long-lived channel.
 */
private void startSSEStreamForRequest(
    HttpServletRequest request,
    HttpServletResponse response,
    JsonNode jsonRequest,
    String sessionId)
    throws IOException {
  // Set up SSE response headers
  response.setContentType(CONTENT_TYPE_SSE);
  response.setHeader("Cache-Control", "no-cache");
  response.setHeader("Connection", "keep-alive");
  response.setHeader("Access-Control-Allow-Origin", "*");
  // Include session ID in response if present
  if (sessionId != null) {
    response.setHeader(SESSION_HEADER, sessionId);
  }
  response.setStatus(HttpServletResponse.SC_OK);
  // Start async processing; the writer must be obtained before leaving the request thread.
  AsyncContext asyncContext = request.startAsync();
  asyncContext.setTimeout(300000); // 5 minutes timeout
  SSEConnection connection = new SSEConnection(response.getWriter(), sessionId);
  String connectionId = UUID.randomUUID().toString();
  sseConnections.put(connectionId, connection);
  // Process request asynchronously
  executorService.submit(
      () -> {
        try {
          // Send server-initiated messages first
          // sendServerInitiatedMessages(connection);
          // Process the actual request and send response
          Map<String, Object> jsonResponse = processRequest(jsonRequest, sessionId);
          String eventId = generateEventId();
          connection.sendEventWithId(eventId, objectMapper.writeValueAsString(jsonResponse));
          // Close the stream after sending response
          connection.close();
        } catch (Exception e) {
          log("Error in SSE stream processing for request", e);
          try {
            // Send error response before closing (-32603 is JSON-RPC "internal error")
            Map<String, Object> errorResponse =
                createErrorResponse(
                    jsonRequest.get("id"), -32603, "Internal error: " + e.getMessage());
            connection.sendEvent(objectMapper.writeValueAsString(errorResponse));
            connection.close();
          } catch (Exception ex) {
            log("Error sending error response in SSE stream", ex);
          }
        } finally {
          // Always release the servlet async context and deregister the connection.
          try {
            asyncContext.complete();
          } catch (Exception e) {
            log("Error completing async context", e);
          }
          sseConnections.remove(connectionId);
        }
      });
}
/**
 * Handle batch requests according to MCP specification.
 *
 * <p>A JSON-RPC batch may contain either requests (which produce responses) or
 * responses/notifications (which are acknowledged with an empty 202 Accepted), but a
 * single batch must not mix the two categories.
 *
 * @param request incoming HTTP request (used for Accept-header negotiation)
 * @param response HTTP response to write to
 * @param batchRequest the JSON array of JSON-RPC messages
 * @param sessionId MCP session id from the {@code Mcp-Session-Id} header, may be null
 * @throws IOException if writing the response fails
 */
private void handleBatchRequest(
    HttpServletRequest request,
    HttpServletResponse response,
    JsonNode batchRequest,
    String sessionId)
    throws IOException {
  // isEmpty() instead of size() == 0 — same check, clearer intent.
  if (!batchRequest.isArray() || batchRequest.isEmpty()) {
    sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Invalid batch request format");
    return;
  }
  // Analyze the batch to determine message types
  boolean hasRequests = false;
  boolean hasResponsesOrNotifications = false;
  for (JsonNode message : batchRequest) {
    boolean hasId = message.has("id");
    boolean isResponse = message.has("result") || message.has("error");
    boolean hasMethod = message.has("method");
    if (hasId && hasMethod && !isResponse) {
      hasRequests = true;
    } else if (isResponse || (!hasId && hasMethod)) {
      hasResponsesOrNotifications = true;
    }
  }
  if (hasRequests && hasResponsesOrNotifications) {
    sendError(
        response,
        HttpServletResponse.SC_BAD_REQUEST,
        "Batch cannot mix requests with responses/notifications");
    return;
  }
  if (hasResponsesOrNotifications) {
    // Process responses and notifications, return 202 Accepted
    processBatchResponsesOrNotifications(batchRequest, sessionId);
    response.setStatus(HttpServletResponse.SC_ACCEPTED);
    response.setContentLength(0);
  } else if (hasRequests) {
    // Process batch requests - determine if SSE or direct JSON
    String acceptHeader = request.getHeader("Accept");
    boolean supportsSSE = acceptHeader != null && acceptHeader.contains(CONTENT_TYPE_SSE);
    if (supportsSSE && shouldUseSSEForBatch(batchRequest, sessionId)) {
      startSSEStreamForBatchRequests(request, response, batchRequest, sessionId);
    } else {
      sendBatchJSONResponse(response, batchRequest, sessionId);
    }
  } else {
    sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Invalid batch content");
  }
}
/**
 * Handle a response sent by the client for a server-initiated request.
 *
 * @param response the JSON-RPC response message received from the client
 * @param sessionId the MCP session the response belongs to (may be null)
 */
private void processClientResponse(JsonNode response, String sessionId) {
  // JsonNode.get already returns null for a missing field, so no has() pre-check is needed.
  JsonNode id = response.get("id");
  // Fix: parameterized SLF4J logging instead of eager string concatenation.
  LOG.info("Received client response for request ID: {} (session: {})", id, sessionId);
  // TODO: Match with pending server requests and complete them
  // This would involve maintaining a map of pending requests by ID
}
/**
 * Decide whether a single request should be answered over an SSE stream.
 * Currently any request that belongs to a known session is streamed.
 * NOTE(review): ConcurrentHashMap.get throws NPE for a null sessionId — confirm
 * callers never pass null here.
 */
private boolean shouldUseSSEForRequest(JsonNode request, String sessionId) {
  // TODO: This is good for now, but we can enhance this logic later, like tools/call can be long
  // running for our use case should be fine
  return sessions.get(sessionId) != null;
}
/** Returns true when at least one request in the batch qualifies for SSE delivery. */
private boolean shouldUseSSEForBatch(JsonNode batchRequest, String sessionId) {
  boolean useSSE = false;
  for (JsonNode req : batchRequest) {
    if (shouldUseSSEForRequest(req, sessionId)) {
      useSSE = true;
      break;
    }
  }
  return useSSE;
}
/** Dispatch every message in a responses/notifications batch to the matching handler. */
private void processBatchResponsesOrNotifications(JsonNode batch, String sessionId) {
  for (JsonNode message : batch) {
    // A message carrying "result" or "error" is a response; everything else here
    // is treated as a notification.
    if (message.has("result") || message.has("error")) {
      processClientResponse(message, sessionId);
    } else {
      processNotification(message, sessionId);
    }
  }
}
/**
 * Answer a batch of requests with a single JSON array — one element per request —
 * echoing the session header when a session is active.
 */
private void sendBatchJSONResponse(
    HttpServletResponse response, JsonNode batchRequest, String sessionId) throws IOException {
  List<Map<String, Object>> results = new ArrayList<>();
  for (JsonNode req : batchRequest) {
    results.add(processRequest(req, sessionId));
  }
  // Serialize before touching the response so a serialization failure surfaces
  // before headers are committed (same ordering as before).
  String payload = objectMapper.writeValueAsString(results);
  response.setContentType(CONTENT_TYPE_JSON);
  response.setStatus(HttpServletResponse.SC_OK);
  if (sessionId != null) {
    response.setHeader(SESSION_HEADER, sessionId);
  }
  response.getWriter().write(payload);
  response.getWriter().flush();
}
/**
 * Start SSE stream for batch requests.
 *
 * <p>Switches the response to text/event-stream, puts the request into async mode, and
 * processes the whole batch on the shared executor. The batch result is emitted as a
 * single SSE event (a JSON array of responses) and the stream is then closed.
 */
private void startSSEStreamForBatchRequests(
    HttpServletRequest request,
    HttpServletResponse response,
    JsonNode batchRequest,
    String sessionId)
    throws IOException {
  // Set up SSE response headers
  response.setContentType(CONTENT_TYPE_SSE);
  response.setHeader("Cache-Control", "no-cache");
  response.setHeader("Connection", "keep-alive");
  response.setHeader("Access-Control-Allow-Origin", "*");
  if (sessionId != null) {
    response.setHeader(SESSION_HEADER, sessionId);
  }
  response.setStatus(HttpServletResponse.SC_OK);
  AsyncContext asyncContext = request.startAsync();
  asyncContext.setTimeout(300000); // 5 minutes timeout
  // Register the connection so server-initiated messages can find it by id
  SSEConnection connection = new SSEConnection(response.getWriter(), sessionId);
  String connectionId = UUID.randomUUID().toString();
  sseConnections.put(connectionId, connection);
  // Process batch asynchronously
  executorService.submit(
      () -> {
        try {
          // Send server-initiated messages first
          // sendServerInitiatedMessages(connection);
          // Process each request in the batch
          List<Map<String, Object>> responses = new ArrayList<>();
          for (JsonNode req : batchRequest) {
            Map<String, Object> jsonResponse = processRequest(req, sessionId);
            responses.add(jsonResponse);
          }
          // Send batch response as one SSE event, then end the stream
          String eventId = generateEventId();
          connection.sendEventWithId(eventId, objectMapper.writeValueAsString(responses));
          connection.close();
        } catch (Exception e) {
          log("Error in SSE stream processing for batch", e);
          try {
            // Best effort: report the failure to the client before tearing down the stream
            Map<String, Object> errorResponse =
                createErrorResponse(null, -32603, "Internal error: " + e.getMessage());
            connection.sendEvent(objectMapper.writeValueAsString(errorResponse));
            connection.close();
          } catch (Exception ex) {
            log("Error sending error response in batch SSE stream", ex);
          }
        } finally {
          // Always complete the async context and drop the connection from the registry
          try {
            asyncContext.complete();
          } catch (Exception e) {
            log("Error completing async context for batch", e);
          }
          sseConnections.remove(connectionId);
        }
      });
}
/** Build a JSON-RPC 2.0 error envelope for the given request id. */
private Map<String, Object> createErrorResponse(JsonNode id, int code, String message) {
  Map<String, Object> envelope = new HashMap<>();
  envelope.put("jsonrpc", "2.0");
  envelope.put("id", id);
  envelope.put("error", createError(code, message));
  return envelope;
}
/** Generate a unique SSE event id (random UUID). */
private String generateEventId() {
  UUID eventId = UUID.randomUUID();
  return eventId.toString();
}
/**
 * Enhanced SSE connection class with event ID support.
 *
 * <p>Wraps the servlet response writer for one Server-Sent Events stream. All writes
 * are synchronized so events are never interleaved. Fix: {@code close()} is now also
 * synchronized — previously it was not, so a concurrent close could release the
 * writer in the middle of an in-flight {@code sendEvent} call.
 */
public static class SSEConnection {
  private final PrintWriter writer;
  private final String sessionId;
  private volatile boolean closed = false;
  private int eventCounter = 0; // auto-incrementing id used only by sendEvent()

  public SSEConnection(PrintWriter writer, String sessionId) {
    this.writer = writer;
    this.sessionId = sessionId;
  }

  /** Send a data event with an auto-generated numeric event id. */
  public synchronized void sendEvent(String data) throws IOException {
    if (closed) return;
    writer.printf("id: %d%n", ++eventCounter);
    writer.printf("data: %s%n%n", data);
    writer.flush();
    if (writer.checkError()) {
      throw new IOException("SSE write error");
    }
  }

  /** Send a data event with an explicit event id (useful for resumable streams). */
  public synchronized void sendEventWithId(String id, String data) throws IOException {
    if (closed) return;
    writer.printf("id: %s%n", id);
    writer.printf("data: %s%n%n", data);
    writer.flush();
    if (writer.checkError()) {
      throw new IOException("SSE write error");
    }
  }

  /** Send an SSE comment line (": ..."), typically used as a keepalive. */
  public synchronized void sendComment(String comment) throws IOException {
    if (closed) return;
    writer.printf(": %s%n%n", comment);
    writer.flush();
    if (writer.checkError()) {
      throw new IOException("SSE write error");
    }
  }

  /**
   * Mark the connection closed and release the underlying writer. Synchronized so it
   * cannot interleave with a concurrent send; safe to call more than once.
   */
  public synchronized void close() throws IOException {
    closed = true;
    if (writer != null) {
      writer.close();
    }
  }

  public boolean isClosed() {
    return closed;
  }

  public String getSessionId() {
    return sessionId;
  }
}
/**
 * GET opens a long-lived SSE stream for server-initiated messages. The client must
 * accept text/event-stream; a supplied session id must reference a live session.
 */
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response)
    throws ServletException, IOException {
  String acceptHeader = request.getHeader("Accept");
  boolean acceptsEventStream = acceptHeader != null && acceptHeader.contains(CONTENT_TYPE_SSE);
  if (!acceptsEventStream) {
    sendError(
        response,
        HttpServletResponse.SC_BAD_REQUEST,
        "Accept header must include text/event-stream");
    return;
  }
  String sessionId = request.getHeader(SESSION_HEADER);
  if (sessionId != null && !sessions.containsKey(sessionId)) {
    sendError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found");
    return;
  }
  startSSEStreamForGet(request, response, sessionId);
}
/** DELETE terminates the MCP session identified by the session header. */
@Override
protected void doDelete(HttpServletRequest request, HttpServletResponse response)
    throws ServletException, IOException {
  removeMCPSession(request.getHeader(SESSION_HEADER));
  response.setStatus(HttpServletResponse.SC_OK);
}
/**
 * Terminate and forget the given session.
 *
 * @param sessionId id of the session to remove; must not be null
 * @throws BadRequestException when no session id was supplied
 */
private void removeMCPSession(String sessionId) {
  if (sessionId == null) {
    throw BadRequestException.of("Session ID required");
  }
  // Terminate session
  MCPSession session = sessions.remove(sessionId);
  if (session != null) {
    // Fix: parameterized SLF4J logging instead of string concatenation.
    LOG.info("Session terminated: {}", sessionId);
  }
}
/**
 * Handle the MCP "initialize" request: create a new session, register it, and answer
 * with the protocol version, server capabilities, and server info. The generated
 * session id is returned in the session header for the client to reuse.
 */
private void handleInitialize(HttpServletResponse response, JsonNode request) throws IOException {
  // Create new session
  String sessionId = generateSessionId();
  // NOTE(review): the session captures this response's writer, which is only valid for
  // the lifetime of this request — confirm the session never writes to it afterwards.
  MCPSession session = new MCPSession(this.objectMapper, sessionId, response.getWriter());
  sessions.put(sessionId, session);
  // Create initialize response
  Map<String, Object> jsonResponse = new HashMap<>();
  jsonResponse.put("jsonrpc", "2.0");
  jsonResponse.put("id", request.get("id"));
  Map<String, Object> result = new HashMap<>();
  result.put("protocolVersion", "2024-11-05");
  result.put("capabilities", getServerCapabilities());
  result.put("serverInfo", getServerInfo());
  jsonResponse.put("result", result);
  // Send response with session ID
  String responseJson = objectMapper.writeValueAsString(jsonResponse);
  response.setContentType(CONTENT_TYPE_JSON);
  response.setHeader(SESSION_HEADER, sessionId);
  response.setStatus(HttpServletResponse.SC_OK);
  response.getWriter().write(responseJson);
  log("New session initialized: " + sessionId);
}
/**
 * Open the long-lived SSE stream used for server-initiated messages (HTTP GET).
 * The async context has no timeout; a keepalive comment is emitted every 30 seconds
 * until the connection is closed.
 */
private void startSSEStreamForGet(
    HttpServletRequest request, HttpServletResponse response, String sessionId)
    throws IOException {
  response.setContentType(CONTENT_TYPE_SSE);
  response.setHeader("Cache-Control", "no-cache");
  response.setHeader("Connection", "keep-alive");
  response.setHeader("Access-Control-Allow-Origin", "*");
  response.setStatus(HttpServletResponse.SC_OK);
  AsyncContext asyncContext = request.startAsync();
  asyncContext.setTimeout(0); // No timeout for long-lived connections
  // Register the connection so sendNotificationToSession() can find it
  SSEConnection connection = new SSEConnection(response.getWriter(), sessionId);
  String connectionId = UUID.randomUUID().toString();
  sseConnections.put(connectionId, connection);
  // Keep connection alive and handle server-initiated messages
  executorService.submit(
      () -> {
        try {
          while (!connection.isClosed()) {
            // Send keepalive every 30 seconds
            Thread.sleep(30000);
            connection.sendComment("keepalive");
          }
        } catch (InterruptedException e) {
          // Preserve interrupt status so the executor can shut down cleanly
          Thread.currentThread().interrupt();
        } catch (IOException e) {
          LOG.error("SSE connection error for connection ID: {}", connectionId, e);
        } finally {
          // Tear down: close the stream, complete async I/O, unregister the connection
          try {
            connection.close();
            asyncContext.complete();
          } catch (Exception e) {
            log("Error closing long-lived SSE connection", e);
          }
          sseConnections.remove(connectionId);
        }
      });
}
/** Inject the MCP session factory (required by the transport-provider contract). */
@Override
public void setSessionFactory(McpServerSession.Factory sessionFactory) {
  this.sessionFactory = sessionFactory;
}
/**
 * Broadcast a JSON-RPC notification to every active session.
 *
 * <p>Per-session failures are logged and suppressed (onErrorComplete) so one bad
 * session does not abort the broadcast to the others.
 *
 * @param method notification method name
 * @param params notification parameters (may be null)
 * @return Mono completing once every session has been attempted
 */
@Override
public Mono<Void> notifyClients(String method, Map<String, Object> params) {
  if (sessions.isEmpty()) {
    LOG.debug("No active sessions to broadcast message to");
    return Mono.empty();
  }
  LOG.debug("Attempting to broadcast message to {} active sessions", sessions.size());
  return Flux.fromIterable(sessions.values())
      .flatMap(
          session ->
              session
                  .sendNotification(method, params)
                  .doOnError(
                      e ->
                          LOG.error(
                              "Failed to send message to session {}: {}",
                              session.getSessionId(),
                              e.getMessage()))
                  .onErrorComplete())
      .then();
}
/** Flush and close every active session as part of server shutdown. */
@Override
public Mono<Void> closeGracefully() {
  LOG.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
  return Flux.fromIterable(sessions.values())
      .flatMap(session -> session.closeGracefully())
      .then();
}
/** Server-side state for one MCP session, including a writer for server-pushed messages. */
private static class MCPSession {
  // NOTE(review): @Getter plus the hand-written getSessionId() below is redundant —
  // Lombok skips generation when the method already exists; one of the two can go.
  @Getter private final String sessionId;
  private final long createdAt; // epoch millis at session creation
  private final Map<String, Object> state; // arbitrary per-session key/value state
  private final PrintWriter outputStream; // writer used for server-initiated messages
  private final ObjectMapper objectMapper;
  public MCPSession(ObjectMapper objectMapper, String sessionId, PrintWriter outputStream) {
    this.sessionId = sessionId;
    this.createdAt = System.currentTimeMillis();
    this.state = new ConcurrentHashMap<>();
    this.objectMapper = objectMapper;
    this.outputStream = outputStream;
  }
  public String getSessionId() {
    return sessionId;
  }
  public long getCreatedAt() {
    return createdAt;
  }
  public Map<String, Object> getState() {
    return state;
  }
  /** Serialize a JSON-RPC notification and write it to the session stream, newline-delimited. */
  public Mono<Void> sendNotification(String method, Map<String, Object> params) {
    McpSchema.JSONRPCNotification jsonrpcNotification =
        new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params);
    return Mono.fromRunnable(
        () -> {
          try {
            String json = objectMapper.writeValueAsString(jsonrpcNotification);
            outputStream.write(json);
            outputStream.write('\n');
            outputStream.flush();
          } catch (IOException e) {
            // Serialization failures are logged, not propagated, to keep the stream alive
            LOG.error("Failed to send message", e);
          }
        });
  }
  /** Flush pending output and close the session stream. */
  public Mono<Void> closeGracefully() {
    return Mono.fromRunnable(
        () -> {
          outputStream.flush();
          outputStream.close();
        });
  }
}
/**
 * Dispatch a single JSON-RPC request to the matching MCP method handler and build the
 * JSON-RPC 2.0 response envelope.
 *
 * @param request parsed JSON-RPC request
 * @param sessionId session the request belongs to (not used by the current handlers)
 * @return response map containing "jsonrpc", "id" and either "result" or "error"
 */
private Map<String, Object> processRequest(JsonNode request, String sessionId) {
  Map<String, Object> response = new HashMap<>();
  response.put("jsonrpc", "2.0");
  response.put("id", request.get("id"));
  // Robustness fix: a request without "method" used to throw an NPE here; per
  // JSON-RPC 2.0 this is an Invalid Request (-32600).
  JsonNode methodNode = request.get("method");
  if (methodNode == null) {
    response.put("error", createError(-32600, "Invalid request: missing method"));
    return response;
  }
  String method = methodNode.asText();
  try {
    switch (method) {
      case "ping":
        response.put("result", "pong");
        break;
      case "resources/list":
        response.put("result", new McpSchema.ListResourcesResult(new ArrayList<>(), null));
        break;
      case "resources/read":
        // TODO: Implement resource reading logic
        break;
      case "tools/list":
        response.put("result", new McpSchema.ListToolsResult(tools, null));
        break;
      case "tools/call":
        JsonNode toolParams = request.get("params");
        if (toolParams != null && toolParams.has("name")) {
          String toolName = toolParams.get("name").asText();
          JsonNode arguments = toolParams.get("arguments");
          // Delegate execution (including auth checks) to the tool context and wrap
          // the JSON result as MCP text content.
          McpSchema.Content content =
              new McpSchema.TextContent(
                  JsonUtils.pojoToJson(
                      toolContext.callTool(
                          authorizer, jwtFilter, limits, toolName, JsonUtils.getMap(arguments))));
          response.put("result", new McpSchema.CallToolResult(List.of(content), false));
        } else {
          response.put("error", createError(-32602, "Invalid params"));
        }
        break;
      default:
        response.put("error", createError(-32601, "Method not found: " + method));
    }
  } catch (Exception e) {
    log("Error processing request: " + method, e);
    response.put("error", createError(-32603, "Internal error: " + e.getMessage()));
  }
  return response;
}
/** Handle a JSON-RPC notification (no response expected) from the client. */
private void processNotification(JsonNode notification, String sessionId) {
  String method = notification.get("method").asText();
  log("Received notification: " + method + " (session: " + sessionId + ")");
  // Dispatch the well-known MCP lifecycle notifications
  if ("notifications/initialized".equals(method)) {
    LOG.info("Client initialized for session: {}", sessionId);
  } else if ("notifications/cancelled".equals(method)) {
    // TODO: Check this
    LOG.info("Client sent a cancellation request for session: {}", sessionId);
    removeMCPSession(sessionId);
  } else {
    log("Unknown notification: " + method);
  }
}
// Utility methods
/** Generate a cryptographically strong, URL-safe session id (256 bits of entropy). */
private String generateSessionId() {
  byte[] randomBytes = new byte[32];
  secureRandom.nextBytes(randomBytes);
  Base64.Encoder encoder = Base64.getUrlEncoder().withoutPadding();
  return encoder.encodeToString(randomBytes);
}
/**
 * Validate the Origin header against the configured allowed origin.
 *
 * <p>Security fix: a plain startsWith() check allowed prefix-bypass origins such as
 * "http://localhost.evil.com" when the allowed origin is "http://localhost". The
 * allowed prefix must now be followed by a port (":"), a path ("/"), or end-of-string.
 */
private boolean isValidOrigin(String origin) {
  if (null == origin || origin.isEmpty()) {
    return false;
  }
  String allowed = mcpConfiguration.getOriginHeaderUri();
  if (allowed == null || allowed.isEmpty() || !origin.startsWith(allowed)) {
    return false;
  }
  if (origin.length() == allowed.length()) {
    return true; // exact match
  }
  char next = origin.charAt(allowed.length());
  return next == ':' || next == '/';
}
/**
 * Read the full request body into a string. Note: line separators are dropped
 * (lines are concatenated), which is harmless for single-document JSON payloads.
 */
private String readRequestBody(HttpServletRequest request) throws IOException {
  StringBuilder bodyBuilder = new StringBuilder();
  try (BufferedReader reader = request.getReader()) {
    for (String line = reader.readLine(); line != null; line = reader.readLine()) {
      bodyBuilder.append(line);
    }
  }
  return bodyBuilder.toString();
}
/** Write a simple {"error": message} JSON body with the given HTTP status code. */
private void sendError(HttpServletResponse response, int statusCode, String message)
    throws IOException {
  String errorJson = objectMapper.writeValueAsString(Map.of("error", message));
  response.setContentType(CONTENT_TYPE_JSON);
  response.setStatus(statusCode);
  response.getWriter().write(errorJson);
}
/** Build a JSON-RPC error payload with the given code and message. */
private Map<String, Object> createError(int code, String message) {
  Map<String, Object> errorPayload = new HashMap<>();
  errorPayload.put("code", code);
  errorPayload.put("message", message);
  return errorPayload;
}
/**
 * Capabilities advertised during initialize: tools, prompts, and resources.
 * NOTE(review): the two boolean flags on resources() presumably enable
 * subscribe/listChanged — confirm against the MCP SDK builder signature.
 */
private McpSchema.ServerCapabilities getServerCapabilities() {
  return McpSchema.ServerCapabilities.builder()
      .tools(true)
      .prompts(true)
      .resources(true, true)
      .build();
}
/**
 * Server identification returned in the initialize response.
 * NOTE(review): the version is hard-coded to "1.0.0" while the configuration now
 * carries mcpServerVersion — consider sourcing it from MCPConfiguration (confirm a
 * getter exists).
 */
private McpSchema.Implementation getServerInfo() {
  return new McpSchema.Implementation("OpenMetadata MCP Server - Streamable", "1.0.0");
}
/**
 * Send a server-initiated notification to a specific session.
 *
 * <p>Scans all open SSE connections and emits the notification on each one bound to
 * the given session id; send failures are logged and do not stop the scan.
 * NOTE(review): throws NPE when sessionId is null — confirm callers always pass one.
 */
public void sendNotificationToSession(
    String sessionId, String method, Map<String, Object> params) {
  // Find SSE connections for this session
  for (SSEConnection connection : sseConnections.values()) {
    if (sessionId.equals(connection.getSessionId())) {
      try {
        Map<String, Object> notification = new HashMap<>();
        notification.put("jsonrpc", "2.0");
        notification.put("method", method);
        if (params != null) {
          notification.put("params", params);
        }
        connection.sendEvent(objectMapper.writeValueAsString(notification));
      } catch (IOException e) {
        log("Error sending notification to session: " + sessionId, e);
      }
    }
  }
}
}

View File

@ -1,33 +1,23 @@
package org.openmetadata.service.mcp;
import static org.openmetadata.service.search.SearchUtil.searchMetadata;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.jetty.MutableServletContextHandler;
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.server.McpSyncServer;
import io.modelcontextprotocol.spec.McpSchema;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.eclipse.jetty.servlet.FilterHolder;
import org.eclipse.jetty.servlet.ServletHolder;
import org.openmetadata.common.utils.CommonUtil;
import org.openmetadata.service.OpenMetadataApplicationConfig;
import org.openmetadata.service.config.MCPConfiguration;
import org.openmetadata.service.limits.Limits;
import org.openmetadata.service.mcp.tools.GlossaryTermTool;
import org.openmetadata.service.mcp.tools.GlossaryTool;
import org.openmetadata.service.mcp.tools.PatchEntityTool;
import org.openmetadata.service.security.AuthorizationException;
import org.openmetadata.service.mcp.tools.DefaultToolContext;
import org.openmetadata.service.security.Authorizer;
import org.openmetadata.service.security.JwtFilter;
import org.openmetadata.service.security.auth.CatalogSecurityContext;
import org.openmetadata.service.util.EntityUtil;
import org.openmetadata.service.util.JsonUtils;
@Slf4j
@ -35,8 +25,11 @@ public class McpServer {
private JwtFilter jwtFilter;
private Authorizer authorizer;
private Limits limits;
protected DefaultToolContext toolContext;
public McpServer() {}
public McpServer(DefaultToolContext toolContext) {
this.toolContext = toolContext;
}
public void initializeMcpServer(
Environment environment,
@ -47,6 +40,24 @@ public class McpServer {
new JwtFilter(config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration());
this.authorizer = authorizer;
this.limits = limits;
MutableServletContextHandler contextHandler = environment.getApplicationContext();
McpAuthFilter authFilter =
new McpAuthFilter(
new JwtFilter(
config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration()));
List<McpSchema.Tool> tools = getTools();
addSSETransport(contextHandler, authFilter, tools);
addStreamableHttpServlet(config.getMcpConfiguration(), contextHandler, authFilter, tools);
}
protected List<McpSchema.Tool> getTools() {
return toolContext.loadToolsDefinitionsFromJson("json/data/mcp/tools.json");
}
private void addSSETransport(
MutableServletContextHandler contextHandler,
McpAuthFilter authFilter,
List<McpSchema.Tool> tools) {
McpSchema.ServerCapabilities serverCapabilities =
McpSchema.ServerCapabilities.builder()
.tools(true)
@ -54,152 +65,55 @@ public class McpServer {
.resources(true, true)
.build();
HttpServletSseServerTransportProvider transport =
new HttpServletSseServerTransportProvider("/mcp/messages", "/mcp/sse");
HttpServletSseServerTransportProvider sseTransport =
new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/messages", "/mcp/sse");
McpSyncServer server =
io.modelcontextprotocol.server.McpServer.sync(transport)
.serverInfo("openmetadata-mcp", "0.1.0")
io.modelcontextprotocol.server.McpServer.sync(sseTransport)
.serverInfo("openmetadata-mcp-sse", "0.1.0")
.capabilities(serverCapabilities)
.build();
addToolsToServer(server, tools);
// Add resources, prompts, and tools to the MCP server
addTools(server);
// SSE transport for MCP
ServletHolder servletHolderSSE = new ServletHolder(sseTransport);
contextHandler.addServlet(servletHolderSSE, "/mcp/*");
MutableServletContextHandler contextHandler = environment.getApplicationContext();
ServletHolder servletHolder = new ServletHolder(transport);
contextHandler.addServlet(servletHolder, "/mcp/*");
McpAuthFilter authFilter =
new McpAuthFilter(
new JwtFilter(
config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration()));
contextHandler.addFilter(
new FilterHolder((Filter) authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST));
new FilterHolder(authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST));
}
public void addTools(McpSyncServer server) {
try {
LOG.info("Loading tool definitions...");
List<Map<String, Object>> cachedTools = loadToolsDefinitionsFromJson();
if (cachedTools == null || cachedTools.isEmpty()) {
LOG.error("No tool definitions were loaded!");
throw new RuntimeException("Failed to load tool definitions");
}
LOG.info("Successfully loaded {} tool definitions", cachedTools.size());
private void addStreamableHttpServlet(
MCPConfiguration configuration,
MutableServletContextHandler contextHandler,
McpAuthFilter authFilter,
List<McpSchema.Tool> tools) {
// Streamable HTTP servlet for MCP
MCPStreamableHttpServlet streamableHttpServlet =
new MCPStreamableHttpServlet(
configuration, jwtFilter, authorizer, limits, toolContext, tools);
ServletHolder servletHolderStreamableHttp = new ServletHolder(streamableHttpServlet);
contextHandler.addServlet(servletHolderStreamableHttp, "/mcp");
for (Map<String, Object> toolDef : cachedTools) {
try {
String name = (String) toolDef.get("name");
String description = (String) toolDef.get("description");
Map<String, Object> schema = JsonUtils.getMap(toolDef.get("parameters"));
server.addTool(getTool(JsonUtils.pojoToJson(schema), name, description));
} catch (Exception e) {
LOG.error("Error processing tool definition: {}", toolDef, e);
}
}
LOG.info("Initializing request handlers...");
} catch (Exception e) {
LOG.error("Error during server startup", e);
throw new RuntimeException("Failed to start MCP server", e);
contextHandler.addFilter(
new FilterHolder(authFilter), "/mcp", EnumSet.of(DispatcherType.REQUEST));
}
public void addToolsToServer(McpSyncServer server, List<McpSchema.Tool> tools) {
for (McpSchema.Tool tool : tools) {
server.addTool(getTool(tool));
}
}
protected List<Map<String, Object>> loadToolsDefinitionsFromJson() {
String json = getJsonFromFile("json/data/mcp/tools.json");
return loadToolDefinitionsFromJson(json);
}
protected static String getJsonFromFile(String path) {
try {
return CommonUtil.getResourceAsStream(McpServer.class.getClassLoader(), path);
} catch (Exception ex) {
LOG.error("Error loading JSON file: {}", path, ex);
return null;
}
}
@SuppressWarnings("unchecked")
public List<Map<String, Object>> loadToolDefinitionsFromJson(String json) {
try {
LOG.info("Loaded tool definitions, content length: {}", json.length());
LOG.info("Raw tools.json content: {}", json);
JsonNode toolsJson = JsonUtils.readTree(json);
JsonNode toolsArray = toolsJson.get("tools");
if (toolsArray == null || !toolsArray.isArray()) {
LOG.error("Invalid MCP tools file format. Expected 'tools' array.");
return new ArrayList<>();
}
List<Map<String, Object>> tools = new ArrayList<>();
for (JsonNode toolNode : toolsArray) {
String name = toolNode.get("name").asText();
Map<String, Object> toolDef = JsonUtils.convertValue(toolNode, Map.class);
tools.add(toolDef);
LOG.info("Tool found: {} with definition: {}", name, toolDef);
}
LOG.info("Found {} tool definitions", tools.size());
return tools;
} catch (Exception e) {
LOG.error("Error loading tool definitions: {}", e.getMessage(), e);
throw e;
}
}
private McpServerFeatures.SyncToolSpecification getTool(
String schema, String toolName, String description) {
McpSchema.Tool tool = new McpSchema.Tool(toolName, description, schema);
private McpServerFeatures.SyncToolSpecification getTool(McpSchema.Tool tool) {
return new McpServerFeatures.SyncToolSpecification(
tool,
(exchange, arguments) -> {
McpSchema.Content content =
new McpSchema.TextContent(JsonUtils.pojoToJson(runMethod(toolName, arguments)));
new McpSchema.TextContent(
JsonUtils.pojoToJson(
toolContext.callTool(authorizer, jwtFilter, limits, tool.name(), arguments)));
return new McpSchema.CallToolResult(List.of(content), false);
});
}
protected Object runMethod(String toolName, Map<String, Object> params) {
CatalogSecurityContext securityContext =
jwtFilter.getCatalogSecurityContext((String) params.get("Authorization"));
LOG.info(
"Catalog Principal: {} is trying to call the tool: {}",
securityContext.getUserPrincipal().getName(),
toolName);
Object result;
try {
switch (toolName) {
case "search_metadata":
result = searchMetadata(params);
break;
case "get_entity_details":
result = EntityUtil.getEntityDetails(params);
break;
case "create_glossary":
result = new GlossaryTool().execute(authorizer, limits, securityContext, params);
break;
case "create_glossary_term":
result = new GlossaryTermTool().execute(authorizer, limits, securityContext, params);
break;
case "patch_entity":
result = new PatchEntityTool().execute(authorizer, securityContext, params);
break;
default:
result = Map.of("error", "Unknown function: " + toolName);
break;
}
return result;
} catch (AuthorizationException ex) {
LOG.error("Authorization error: {}", ex.getMessage());
return Map.of(
"error", String.format("Authorization error: %s", ex.getMessage()), "statusCode", 403);
} catch (Exception ex) {
LOG.error("Error executing tool: {}", ex.getMessage());
return Map.of(
"error", String.format("Error executing tool: %s", ex.getMessage()), "statusCode", 500);
}
}
}

View File

@ -0,0 +1,94 @@
package org.openmetadata.service.mcp;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpSchema;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.common.utils.CommonUtil;
import org.openmetadata.service.security.JwtFilter;
import org.openmetadata.service.util.JsonUtils;
@Slf4j
public class McpUtils {
  /**
   * Parse the raw JSON-RPC body, inject the caller's bearer token into the tool-call
   * arguments under the "Authorization" key, and re-serialize into an MCP JSON-RPC
   * message so downstream tools can authenticate as the caller.
   */
  public static McpSchema.JSONRPCMessage getJsonRpcMessageWithAuthorizationParam(
      ObjectMapper objectMapper, HttpServletRequest request, String body) throws IOException {
    Map<String, Object> requestMessage = JsonUtils.getMap(JsonUtils.readTree(body));
    Map<String, Object> params = (Map<String, Object>) requestMessage.get("params");
    if (params != null) {
      Map<String, Object> arguments = (Map<String, Object>) params.get("arguments");
      if (arguments != null) {
        arguments.put("Authorization", JwtFilter.extractToken(request.getHeader("Authorization")));
      }
    }
    return McpSchema.deserializeJsonRpcMessage(objectMapper, JsonUtils.pojoToJson(requestMessage));
  }
  /**
   * Parse tool definitions out of a JSON document shaped as {"tools": [...]}.
   *
   * @return one map per tool definition; empty list when the "tools" array is missing
   *     or malformed
   * NOTE(review): logs the entire tools.json content and each definition at INFO —
   * consider DEBUG to reduce startup log noise.
   */
  @SuppressWarnings("unchecked")
  public static List<Map<String, Object>> loadToolDefinitionsFromJson(String json) {
    try {
      LOG.info("Loaded tool definitions, content length: {}", json.length());
      LOG.info("Raw tools.json content: {}", json);
      JsonNode toolsJson = JsonUtils.readTree(json);
      JsonNode toolsArray = toolsJson.get("tools");
      if (toolsArray == null || !toolsArray.isArray()) {
        LOG.error("Invalid MCP tools file format. Expected 'tools' array.");
        return new ArrayList<>();
      }
      List<Map<String, Object>> tools = new ArrayList<>();
      for (JsonNode toolNode : toolsArray) {
        String name = toolNode.get("name").asText();
        Map<String, Object> toolDef = JsonUtils.convertValue(toolNode, Map.class);
        tools.add(toolDef);
        LOG.info("Tool found: {} with definition: {}", name, toolDef);
      }
      LOG.info("Found {} tool definitions", tools.size());
      return tools;
    } catch (Exception e) {
      LOG.error("Error loading tool definitions: {}", e.getMessage(), e);
      throw e;
    }
  }
  /**
   * Read a classpath resource as a string.
   *
   * @return the file content, or null when the resource cannot be read (the error is
   *     logged); callers must handle the null.
   */
  public static String getJsonFromFile(String path) {
    try {
      return CommonUtil.getResourceAsStream(McpServer.class.getClassLoader(), path);
    } catch (Exception ex) {
      LOG.error("Error loading JSON file: {}", path, ex);
      return null;
    }
  }
  /**
   * Load and convert the tool definitions at the given classpath location into MCP
   * Tool objects (name, description, JSON schema for parameters).
   *
   * @throws RuntimeException when the file is missing, empty, or malformed
   */
  public static List<McpSchema.Tool> getToolProperties(String jsonFilePath) {
    try {
      List<McpSchema.Tool> result = new ArrayList<>();
      String json = getJsonFromFile(jsonFilePath);
      List<Map<String, Object>> cachedTools = loadToolDefinitionsFromJson(json);
      if (cachedTools == null || cachedTools.isEmpty()) {
        LOG.error("No tool definitions were loaded!");
        throw new RuntimeException("Failed to load tool definitions");
      }
      LOG.debug("Successfully loaded {} tool definitions", cachedTools.size());
      for (int i = 0; i < cachedTools.size(); i++) {
        Map<String, Object> toolDef = cachedTools.get(i);
        String name = (String) toolDef.get("name");
        String description = (String) toolDef.get("description");
        Map<String, Object> schema = JsonUtils.getMap(toolDef.get("parameters"));
        result.add(new McpSchema.Tool(name, description, JsonUtils.pojoToJson(schema)));
      }
      return result;
    } catch (Exception e) {
      LOG.error("Error during server startup", e);
      throw new RuntimeException("Failed to start MCP server", e);
    }
  }
}

View File

@ -0,0 +1,75 @@
package org.openmetadata.service.mcp.tools;
import static org.openmetadata.service.mcp.McpUtils.getToolProperties;
import io.modelcontextprotocol.spec.McpSchema;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.service.limits.Limits;
import org.openmetadata.service.security.AuthorizationException;
import org.openmetadata.service.security.Authorizer;
import org.openmetadata.service.security.JwtFilter;
import org.openmetadata.service.security.auth.CatalogSecurityContext;
@Slf4j
public class DefaultToolContext {
  public DefaultToolContext() {}
  /**
   * Loads tool definitions from a JSON file located at the specified path.
   * The JSON file should contain an array of tool definitions under the "tools" key.
   *
   * @return List of McpSchema.Tool objects loaded from the JSON file.
   */
  public List<McpSchema.Tool> loadToolsDefinitionsFromJson(String toolFilePath) {
    return getToolProperties(toolFilePath);
  }
  /**
   * Authenticate the caller from the "Authorization" entry in params, then dispatch to
   * the named tool. Failures are returned as error maps (with a "statusCode" key)
   * rather than thrown, so JSON-RPC handlers can relay them to the client.
   *
   * @param toolName one of the supported tool names; unknown names produce an error map
   * @param params tool arguments, including the caller's "Authorization" token
   */
  public Object callTool(
      Authorizer authorizer,
      JwtFilter jwtFilter,
      Limits limits,
      String toolName,
      Map<String, Object> params) {
    CatalogSecurityContext securityContext =
        jwtFilter.getCatalogSecurityContext((String) params.get("Authorization"));
    LOG.info(
        "Catalog Principal: {} is trying to call the tool: {}",
        securityContext.getUserPrincipal().getName(),
        toolName);
    Object result;
    try {
      switch (toolName) {
        case "search_metadata":
          result = new SearchMetadataTool().execute(authorizer, securityContext, params);
          break;
        case "get_entity_details":
          result = new GetEntityTool().execute(authorizer, securityContext, params);
          break;
        case "create_glossary":
          result = new GlossaryTool().execute(authorizer, limits, securityContext, params);
          break;
        case "create_glossary_term":
          result = new GlossaryTermTool().execute(authorizer, limits, securityContext, params);
          break;
        case "patch_entity":
          result = new PatchEntityTool().execute(authorizer, limits, securityContext, params);
          break;
        default:
          result = Map.of("error", "Unknown function: " + toolName);
          break;
      }
      return result;
    } catch (AuthorizationException ex) {
      LOG.error("Authorization error: {}", ex.getMessage());
      return Map.of(
          "error", String.format("Authorization error: %s", ex.getMessage()), "statusCode", 403);
    } catch (Exception ex) {
      LOG.error("Error executing tool: {}", ex.getMessage());
      return Map.of(
          "error", String.format("Error executing tool: %s", ex.getMessage()), "statusCode", 500);
    }
  }
}

View File

@ -0,0 +1,42 @@
package org.openmetadata.service.mcp.tools;
import static org.openmetadata.schema.type.MetadataOperation.VIEW_ALL;
import java.io.IOException;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.service.Entity;
import org.openmetadata.service.limits.Limits;
import org.openmetadata.service.security.Authorizer;
import org.openmetadata.service.security.auth.CatalogSecurityContext;
import org.openmetadata.service.security.policyevaluator.OperationContext;
import org.openmetadata.service.security.policyevaluator.ResourceContext;
import org.openmetadata.service.util.JsonUtils;
@Slf4j
public class GetEntityTool implements McpTool {
  /**
   * Fetches a single entity by fully-qualified name and returns it as a map.
   * Requires VIEW_ALL permission on the requested entity type.
   *
   * @param authorizer authorizer used to check the caller's VIEW_ALL permission
   * @param securityContext security context of the authenticated caller
   * @param params must contain "entity_type" (e.g. "table") and "fqn"
   * @return the entity serialized as a map of field name to value
   * @throws IOException if the entity lookup fails
   */
  @Override
  public Map<String, Object> execute(
      Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params)
      throws IOException {
    String entityType = (String) params.get("entity_type");
    String fqn = (String) params.get("fqn");
    authorizer.authorize(
        securityContext,
        new OperationContext(entityType, VIEW_ALL),
        new ResourceContext<>(entityType));
    LOG.info("Getting details for entity type: {}, FQN: {}", entityType, fqn);
    // "*" requests every field of the entity.
    String fields = "*";
    return JsonUtils.getMap(Entity.getEntityByName(entityType, fqn, fields, null));
  }

  /**
   * Not supported: this tool is read-only, so limit enforcement does not apply.
   */
  @Override
  public Map<String, Object> execute(
      Authorizer authorizer,
      Limits limits,
      CatalogSecurityContext securityContext,
      Map<String, Object> params)
      throws IOException {
    throw new UnsupportedOperationException("GetEntityTool does not require limit validation.");
  }
}

View File

@ -2,7 +2,6 @@ package org.openmetadata.service.mcp.tools;
import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty;
import java.util.HashMap;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.schema.entity.data.Glossary;
@ -59,25 +58,19 @@ public class GlossaryTermTool implements McpTool {
limits.enforceLimits(securityContext, createResourceContext, operationContext);
authorizer.authorize(securityContext, operationContext, createResourceContext);
try {
GlossaryRepository glossaryRepository =
(GlossaryRepository) Entity.getEntityRepository(Entity.GLOSSARY);
Glossary glossary =
glossaryRepository.findByNameOrNull(createGlossaryTerm.getGlossary(), Include.ALL);
GlossaryRepository glossaryRepository =
(GlossaryRepository) Entity.getEntityRepository(Entity.GLOSSARY);
Glossary glossary =
glossaryRepository.findByNameOrNull(createGlossaryTerm.getGlossary(), Include.ALL);
GlossaryTermRepository glossaryTermRepository =
(GlossaryTermRepository) Entity.getEntityRepository(Entity.GLOSSARY_TERM);
// TODO: Get the updatedBy from the tool request.
glossaryTermRepository.prepare(glossaryTerm, nullOrEmpty(glossary));
glossaryTermRepository.setFullyQualifiedName(glossaryTerm);
RestUtil.PutResponse<GlossaryTerm> response =
glossaryTermRepository.createOrUpdate(
null, glossaryTerm, securityContext.getUserPrincipal().getName());
return JsonUtils.convertValue(response.getEntity(), Map.class);
} catch (Exception e) {
Map<String, Object> error = new HashMap<>();
error.put("error", e.getMessage());
return error;
}
GlossaryTermRepository glossaryTermRepository =
(GlossaryTermRepository) Entity.getEntityRepository(Entity.GLOSSARY_TERM);
// TODO: Get the updatedBy from the tool request.
glossaryTermRepository.prepare(glossaryTerm, nullOrEmpty(glossary));
glossaryTermRepository.setFullyQualifiedName(glossaryTerm);
RestUtil.PutResponse<GlossaryTerm> response =
glossaryTermRepository.createOrUpdate(
null, glossaryTerm, securityContext.getUserPrincipal().getName());
return JsonUtils.getMap(response.getEntity());
}
}

View File

@ -1,6 +1,5 @@
package org.openmetadata.service.mcp.tools;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
@ -26,7 +25,7 @@ public class GlossaryTool implements McpTool {
@Override
public Map<String, Object> execute(
Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params) {
throw new UnsupportedOperationException("GlossaryTermTool requires limit validation.");
throw new UnsupportedOperationException("GlossaryTool requires limit validation.");
}
@Override
@ -56,21 +55,15 @@ public class GlossaryTool implements McpTool {
limits.enforceLimits(securityContext, createResourceContext, operationContext);
authorizer.authorize(securityContext, operationContext, createResourceContext);
try {
GlossaryRepository glossaryRepository =
(GlossaryRepository) Entity.getEntityRepository(Entity.GLOSSARY);
GlossaryRepository glossaryRepository =
(GlossaryRepository) Entity.getEntityRepository(Entity.GLOSSARY);
glossaryRepository.prepare(glossary, true);
glossaryRepository.setFullyQualifiedName(glossary);
RestUtil.PutResponse<Glossary> response =
glossaryRepository.createOrUpdate(
null, glossary, securityContext.getUserPrincipal().getName());
return JsonUtils.convertValue(response.getEntity(), Map.class);
} catch (Exception e) {
Map<String, Object> error = new HashMap<>();
error.put("error", e.getMessage());
return error;
}
glossaryRepository.prepare(glossary, true);
glossaryRepository.setFullyQualifiedName(glossary);
RestUtil.PutResponse<Glossary> response =
glossaryRepository.createOrUpdate(
null, glossary, securityContext.getUserPrincipal().getName());
return JsonUtils.convertValue(response.getEntity(), Map.class);
}
public static void setReviewers(CreateGlossary entity, Map<String, Object> params) {

View File

@ -1,5 +1,6 @@
package org.openmetadata.service.mcp.tools;
import java.io.IOException;
import java.util.Map;
import org.openmetadata.service.limits.Limits;
import org.openmetadata.service.security.Authorizer;
@ -7,11 +8,13 @@ import org.openmetadata.service.security.auth.CatalogSecurityContext;
public interface McpTool {
Map<String, Object> execute(
Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params);
Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params)
throws IOException;
Map<String, Object> execute(
Authorizer authorizer,
Limits limits,
CatalogSecurityContext securityContext,
Map<String, Object> params);
Map<String, Object> params)
throws IOException;
}

View File

@ -0,0 +1,154 @@
package org.openmetadata.service.mcp.tools;
import static org.openmetadata.service.search.SearchUtil.mapEntityTypesToIndexNames;
import static org.openmetadata.service.security.DefaultAuthorizer.getSubjectContext;
import com.fasterxml.jackson.databind.JsonNode;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.schema.search.SearchRequest;
import org.openmetadata.service.Entity;
import org.openmetadata.service.limits.Limits;
import org.openmetadata.service.security.Authorizer;
import org.openmetadata.service.security.auth.CatalogSecurityContext;
import org.openmetadata.service.security.policyevaluator.SubjectContext;
import org.openmetadata.service.util.JsonUtils;
@Slf4j
public class SearchMetadataTool implements McpTool {
  // Internal/bookkeeping fields stripped from each search hit before it is
  // returned, to keep the MCP tool response focused on useful metadata.
  private static final List<String> IGNORE_SEARCH_KEYS =
      List.of(
          "id",
          "version",
          "updatedAt",
          "updatedBy",
          "usageSummary",
          "followers",
          "deleted",
          "votes",
          "lifeCycle",
          "sourceHash",
          "processedLineage",
          "totalVotes",
          "fqnParts",
          "service_suggest",
          "column_suggest",
          "schema_suggest",
          "database_suggest",
          "upstreamLineage",
          "entityRelationship",
          "changeSummary",
          "fqnHash");

  /**
   * Runs a full-text search against the OpenMetadata search index and returns
   * the first matching document with internal fields removed.
   *
   * @param authorizer unused here; search visibility is enforced via the subject context
   * @param securityContext security context of the authenticated caller
   * @param params optional keys: "query" (default "*"), "limit" (default 10),
   *     "include_deleted" (default false), "entity_type" (default table index)
   * @return the cleaned "_source" of the first hit, or an empty map if no hits
   * @throws IOException if the search request fails
   */
  @Override
  public Map<String, Object> execute(
      Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params)
      throws IOException {
    LOG.info("Executing searchMetadata with params: {}", params);
    String query = params.containsKey("query") ? (String) params.get("query") : "*";
    int limit = parseLimit(params);
    boolean includeDeleted = parseIncludeDeleted(params);
    String entityType =
        params.containsKey("entity_type") ? (String) params.get("entity_type") : null;
    String index =
        (entityType != null && !entityType.isEmpty())
            ? mapEntityTypesToIndexNames(entityType)
            : Entity.TABLE;
    LOG.info(
        "Search query: {}, index: {}, limit: {}, includeDeleted: {}",
        query,
        index,
        limit,
        includeDeleted);
    SearchRequest searchRequest =
        new SearchRequest()
            .withQuery(query)
            .withIndex(index)
            .withSize(limit)
            .withFrom(0)
            .withFetchSource(true)
            .withDeleted(includeDeleted);
    // Search as the calling user so row-level visibility rules are applied.
    SubjectContext subjectContext = getSubjectContext(securityContext);
    Response response = Entity.getSearchRepository().search(searchRequest, subjectContext);
    Map<String, Object> searchResponse;
    if (response.getEntity() instanceof String responseStr) {
      LOG.info("Search returned string response");
      JsonNode jsonNode = JsonUtils.readTree(responseStr);
      searchResponse = JsonUtils.convertValue(jsonNode, Map.class);
    } else {
      LOG.info("Search returned object response: {}", response.getEntity().getClass().getName());
      searchResponse = JsonUtils.convertValue(response.getEntity(), Map.class);
    }
    return SearchMetadataTool.cleanSearchResponse(searchResponse);
  }

  /**
   * Not supported: search is read-only, so limit enforcement does not apply.
   */
  @Override
  public Map<String, Object> execute(
      Authorizer authorizer,
      Limits limits,
      CatalogSecurityContext securityContext,
      Map<String, Object> params) {
    throw new UnsupportedOperationException(
        "SearchMetadataTool does not support limits enforcement.");
  }

  /** Parses the optional "limit" parameter (Number or numeric String); defaults to 10. */
  private static int parseLimit(Map<String, Object> params) {
    int limit = 10;
    if (params.containsKey("limit")) {
      Object limitObj = params.get("limit");
      if (limitObj instanceof Number) {
        limit = ((Number) limitObj).intValue();
      } else if (limitObj instanceof String) {
        limit = Integer.parseInt((String) limitObj);
      }
    }
    return limit;
  }

  /** Parses the optional "include_deleted" parameter (Boolean or "true"/"false" String). */
  private static boolean parseIncludeDeleted(Map<String, Object> params) {
    boolean includeDeleted = false;
    if (params.containsKey("include_deleted")) {
      Object deletedObj = params.get("include_deleted");
      if (deletedObj instanceof Boolean) {
        includeDeleted = (Boolean) deletedObj;
      } else if (deletedObj instanceof String) {
        includeDeleted = "true".equals(deletedObj);
      }
    }
    return includeDeleted;
  }

  /**
   * Extracts the first valid hit's "_source" from the raw search response and
   * removes the keys in {@link #IGNORE_SEARCH_KEYS}. Returns an empty map when
   * the response has no usable hits.
   */
  private static Map<String, Object> cleanSearchResponse(Map<String, Object> searchResponse) {
    if (searchResponse == null) return Collections.emptyMap();
    Map<String, Object> topHits = safeGetMap(searchResponse.get("hits"));
    if (topHits == null) return Collections.emptyMap();
    List<Object> hits = safeGetList(topHits.get("hits"));
    if (hits == null || hits.isEmpty()) return Collections.emptyMap();
    for (Object hitObj : hits) {
      Map<String, Object> hit = safeGetMap(hitObj);
      if (hit == null) continue;
      Map<String, Object> source = safeGetMap(hit.get("_source"));
      if (source == null) continue;
      IGNORE_SEARCH_KEYS.forEach(source::remove);
      return source; // Return the first valid, cleaned _source
    }
    return Collections.emptyMap();
  }

  /** Returns the object as a Map, or null if it is not a Map. */
  @SuppressWarnings("unchecked")
  private static Map<String, Object> safeGetMap(Object obj) {
    return (obj instanceof Map) ? (Map<String, Object>) obj : null;
  }

  /** Returns the object as a List, or null if it is not a List. */
  @SuppressWarnings("unchecked")
  private static List<Object> safeGetList(Object obj) {
    return (obj instanceof List) ? (List<Object>) obj : null;
  }
}

View File

@ -1,44 +1,10 @@
package org.openmetadata.service.search;
import com.fasterxml.jackson.databind.JsonNode;
import jakarta.ws.rs.core.Response;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.schema.search.SearchRequest;
import org.openmetadata.service.Entity;
import org.openmetadata.service.util.JsonUtils;
@Slf4j
public class SearchUtil {
private static final List<String> IGNORE_SEARCH_KEYS =
List.of(
"id",
"version",
"updatedAt",
"updatedBy",
"usageSummary",
"followers",
"deleted",
"votes",
"lifeCycle",
"sourceHash",
"processedLineage",
"totalVotes",
"fqnParts",
"service_suggest",
"column_suggest",
"schema_suggest",
"database_suggest",
"upstreamLineage",
"entityRelationship",
"changeSummary",
"fqnHash");
/**
* Check if the index is a data asset index
* @param indexName name of the index to check
@ -141,105 +107,10 @@ public class SearchUtil {
case "glossary_search_index", Entity.GLOSSARY -> Entity.GLOSSARY;
case "domain_search_index", Entity.DOMAIN -> Entity.DOMAIN;
case "data_product_search_index", Entity.DATA_PRODUCT -> Entity.DATA_PRODUCT;
default -> "default";
case "team_search_index", Entity.TEAM -> Entity.TEAM;
case "user_Search_index", Entity.USER -> Entity.USER;
case "dataAsset" -> "dataAsset";
default -> "dataAsset";
};
}
public static List<Object> searchMetadata(Map<String, Object> params) {
try {
LOG.info("Executing searchMetadata with params: {}", params);
String query = params.containsKey("query") ? (String) params.get("query") : "*";
int limit = 10;
if (params.containsKey("limit")) {
Object limitObj = params.get("limit");
if (limitObj instanceof Number) {
limit = ((Number) limitObj).intValue();
} else if (limitObj instanceof String) {
limit = Integer.parseInt((String) limitObj);
}
}
boolean includeDeleted = false;
if (params.containsKey("include_deleted")) {
Object deletedObj = params.get("include_deleted");
if (deletedObj instanceof Boolean) {
includeDeleted = (Boolean) deletedObj;
} else if (deletedObj instanceof String) {
includeDeleted = "true".equals(deletedObj);
}
}
String entityType =
params.containsKey("entity_type") ? (String) params.get("entity_type") : null;
String index =
(entityType != null && !entityType.isEmpty())
? mapEntityTypesToIndexNames(entityType)
: Entity.TABLE;
LOG.info(
"Search query: {}, index: {}, limit: {}, includeDeleted: {}",
query,
index,
limit,
includeDeleted);
SearchRequest searchRequest =
new SearchRequest()
.withQuery(query)
.withIndex(index)
.withSize(limit)
.withFrom(0)
.withFetchSource(true)
.withDeleted(includeDeleted);
Response response = Entity.getSearchRepository().search(searchRequest, null);
Map<String, Object> searchResponse;
if (response.getEntity() instanceof String responseStr) {
LOG.info("Search returned string response");
JsonNode jsonNode = JsonUtils.readTree(responseStr);
searchResponse = JsonUtils.convertValue(jsonNode, Map.class);
} else {
LOG.info("Search returned object response: {}", response.getEntity().getClass().getName());
searchResponse = JsonUtils.convertValue(response.getEntity(), Map.class);
}
return cleanSearchResponse(searchResponse);
} catch (Exception e) {
LOG.error("Error in searchMetadata", e);
return Collections.emptyList();
}
}
public static List<Object> cleanSearchResponse(Map<String, Object> searchResponse) {
if (searchResponse == null) return Collections.emptyList();
Map<String, Object> topHits = safeGetMap(searchResponse.get("hits"));
if (topHits == null) return Collections.emptyList();
List<Object> hits = safeGetList(topHits.get("hits"));
if (hits == null) return Collections.emptyList();
return hits.stream()
.map(SearchUtil::safeGetMap)
.filter(Objects::nonNull)
.map(
hit -> {
Map<String, Object> source = safeGetMap(hit.get("_source"));
if (source == null) return null;
IGNORE_SEARCH_KEYS.forEach(source::remove);
return source;
})
.filter(Objects::nonNull)
.collect(Collectors.toList());
}
@SuppressWarnings("unchecked")
private static Map<String, Object> safeGetMap(Object obj) {
return (obj instanceof Map) ? (Map<String, Object>) obj : null;
}
@SuppressWarnings("unchecked")
private static List<Object> safeGetList(Object obj) {
return (obj instanceof List) ? (List<Object>) obj : null;
}
}

View File

@ -73,6 +73,7 @@ public record TableIndex(Table table) implements ColumnIndex {
doc.put("processedLineage", table.getProcessedLineage());
doc.put("entityRelationship", SearchIndex.populateEntityRelationshipData(table));
doc.put("databaseSchema", getEntityWithDisplayName(table.getDatabaseSchema()));
doc.put("tableQueries", table.getTableQueries());
doc.put(
"changeSummary",
Optional.ofNullable(table.getChangeDescription())

View File

@ -37,7 +37,6 @@ import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.UUID;
@ -799,19 +798,4 @@ public final class EntityUtil {
&& changeDescription.getFieldsUpdated().isEmpty()
&& changeDescription.getFieldsDeleted().isEmpty();
}
public static Object getEntityDetails(Map<String, Object> params) {
try {
String entityType = (String) params.get("entity_type");
String fqn = (String) params.get("fqn");
LOG.info("Getting details for entity type: {}, FQN: {}", entityType, fqn);
String fields = "*";
Object entity = Entity.getEntityByName(entityType, fqn, fields, null);
return entity;
} catch (Exception e) {
LOG.error("Error getting entity details", e);
return Map.of("error", e.getMessage());
}
}
}

View File

@ -2,7 +2,7 @@
"tools": [
{
"name": "search_metadata",
"description": "Find your data and business terms in OpenMetadata. For example if the user asks to 'find tables that contain customers information', then 'customers' should be the query, and the entity_type should be 'table'.",
      "description": "Find your data and business terms in OpenMetadata. For example if the user asks to 'find tables that contain customers information', then 'customers' should be the query, and the entity_type should be 'table'. If an 'href' field is available in the result, use it to create a hyperlink to the entity in OpenMetadata.",
"parameters": {
"description": "The search query to find metadata in the OpenMetadata catalog, entity type could be table, topic etc. Limit can be used to paginate on the data.",
"type": "object",

View File

@ -1162,6 +1162,14 @@
"description": "Processed lineage for the table",
"type": "boolean",
"default": false
},
"tableQueries": {
"description": "List of queries that are used to create this table.",
"type": "array",
"items": {
"$ref": "../../type/basic.json#/definitions/sqlQuery"
},
"default": null
}
},
"required": [

View File

@ -161,7 +161,11 @@ export interface Table {
* Table Profiler Config to include or exclude columns from profiling.
*/
tableProfilerConfig?: TableProfilerConfig;
tableType?: TableType;
/**
* List of queries that are used to create this table.
*/
tableQueries?: string[];
tableType?: TableType;
/**
* Tags for this table.
*/