mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2026-01-07 21:16:45 +00:00
MCP Core Items Improvements (#21643)
* Search Util fix and added tableQueries * some json input fix * Add team and user * WIP : Add Streamable HTTP * - Add proper tools/list schema and tools/call * - auth filter exact match * - Add Tools Class to dynamically build tools * Add Origin Validation Mandate * Refactor MCP Stream * comment * Cleanups * Typo * Typo
This commit is contained in:
parent
40aba1d906
commit
dc25350ea2
@ -443,6 +443,8 @@ operationalConfig:
|
||||
mcpConfiguration:
|
||||
enabled: ${MCP_ENABLED:-true}
|
||||
path: ${MCP_PATH:-"/mcp"}
|
||||
mcpServerVersion: ${MCP_SERVER_VERSION:-"1.0.0"}
|
||||
mcpServerVersion: ${MCP_SERVER_VERSION:-"1.0.0"}
|
||||
originHeaderUri: ${OPENMETADATA_SERVER_URL:-"http://localhost"}
|
||||
|
||||
|
||||
|
||||
|
||||
@ -91,6 +91,7 @@ import org.openmetadata.service.jobs.JobHandlerRegistry;
|
||||
import org.openmetadata.service.limits.DefaultLimits;
|
||||
import org.openmetadata.service.limits.Limits;
|
||||
import org.openmetadata.service.mcp.McpServer;
|
||||
import org.openmetadata.service.mcp.tools.DefaultToolContext;
|
||||
import org.openmetadata.service.migration.Migration;
|
||||
import org.openmetadata.service.migration.MigrationValidationClient;
|
||||
import org.openmetadata.service.migration.api.MigrationWorkflow;
|
||||
@ -289,7 +290,7 @@ public class OpenMetadataApplication extends Application<OpenMetadataApplication
|
||||
OpenMetadataApplicationConfig catalogConfig, Environment environment) {
|
||||
if (catalogConfig.getMcpConfiguration() != null
|
||||
&& catalogConfig.getMcpConfiguration().isEnabled()) {
|
||||
McpServer mcpServer = new McpServer();
|
||||
McpServer mcpServer = new McpServer(new DefaultToolContext());
|
||||
mcpServer.initializeMcpServer(environment, authorizer, limits, catalogConfig);
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,4 +18,7 @@ public class MCPConfiguration {
|
||||
|
||||
@JsonProperty("path")
|
||||
private String path = "/api/v1/mcp";
|
||||
|
||||
@JsonProperty("originHeaderUri")
|
||||
private String originHeaderUri = "http://localhost";
|
||||
}
|
||||
|
||||
@ -26,8 +26,10 @@ import static org.openmetadata.schema.type.Include.NON_DELETED;
|
||||
import static org.openmetadata.service.Entity.DATABASE_SCHEMA;
|
||||
import static org.openmetadata.service.Entity.FIELD_OWNERS;
|
||||
import static org.openmetadata.service.Entity.FIELD_TAGS;
|
||||
import static org.openmetadata.service.Entity.QUERY;
|
||||
import static org.openmetadata.service.Entity.TABLE;
|
||||
import static org.openmetadata.service.Entity.TEST_SUITE;
|
||||
import static org.openmetadata.service.Entity.getEntities;
|
||||
import static org.openmetadata.service.Entity.populateEntityFieldTags;
|
||||
import static org.openmetadata.service.util.EntityUtil.getLocalColumnName;
|
||||
import static org.openmetadata.service.util.FullyQualifiedName.getColumnName;
|
||||
@ -62,6 +64,7 @@ import org.openmetadata.schema.EntityInterface;
|
||||
import org.openmetadata.schema.api.data.CreateTableProfile;
|
||||
import org.openmetadata.schema.api.feed.ResolveTask;
|
||||
import org.openmetadata.schema.entity.data.DatabaseSchema;
|
||||
import org.openmetadata.schema.entity.data.Query;
|
||||
import org.openmetadata.schema.entity.data.Table;
|
||||
import org.openmetadata.schema.entity.feed.Suggestion;
|
||||
import org.openmetadata.schema.tests.CustomMetric;
|
||||
@ -179,6 +182,15 @@ public class TableRepository extends EntityRepository<Table> {
|
||||
column.setCustomMetrics(getCustomMetrics(table, column.getName()));
|
||||
}
|
||||
}
|
||||
if (fields.contains("tableQueries")) {
|
||||
List<Query> queriesEntity =
|
||||
getEntities(
|
||||
listOrEmpty(findTo(table.getId(), TABLE, Relationship.MENTIONED_IN, QUERY)),
|
||||
"id",
|
||||
ALL);
|
||||
List<String> queries = listOrEmpty(queriesEntity).stream().map(Query::getQuery).toList();
|
||||
table.setTableQueries(queries);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
/*
|
||||
HttpServletSseServerTransportProvider - Jakarta servlet-based MCP server transport
|
||||
*/
|
||||
|
||||
package org.openmetadata.service.mcp;
|
||||
|
||||
import static org.openmetadata.service.mcp.McpUtils.getJsonRpcMessageWithAuthorizationParam;
|
||||
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.modelcontextprotocol.spec.McpError;
|
||||
@ -11,6 +9,7 @@ import io.modelcontextprotocol.spec.McpSchema;
|
||||
import io.modelcontextprotocol.spec.McpServerSession;
|
||||
import io.modelcontextprotocol.spec.McpServerTransport;
|
||||
import io.modelcontextprotocol.spec.McpServerTransportProvider;
|
||||
import io.modelcontextprotocol.util.Assert;
|
||||
import jakarta.servlet.AsyncContext;
|
||||
import jakarta.servlet.ServletException;
|
||||
import jakarta.servlet.annotation.WebServlet;
|
||||
@ -23,9 +22,10 @@ import java.io.PrintWriter;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import org.openmetadata.service.security.JwtFilter;
|
||||
import org.openmetadata.service.util.JsonUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
@ -42,171 +42,220 @@ public class HttpServletSseServerTransportProvider extends HttpServlet
|
||||
public static final String DEFAULT_SSE_ENDPOINT = "/sse";
|
||||
public static final String MESSAGE_EVENT_TYPE = "message";
|
||||
public static final String ENDPOINT_EVENT_TYPE = "endpoint";
|
||||
public static final String DEFAULT_BASE_URL = "";
|
||||
private final ObjectMapper objectMapper;
|
||||
private final String baseUrl;
|
||||
private final String messageEndpoint;
|
||||
private final String sseEndpoint;
|
||||
private final Map<String, McpServerSession> sessions;
|
||||
private final AtomicBoolean isClosing;
|
||||
private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<>();
|
||||
private final AtomicBoolean isClosing = new AtomicBoolean(false);
|
||||
private McpServerSession.Factory sessionFactory;
|
||||
private ExecutorService executorService =
|
||||
Executors.newCachedThreadPool(
|
||||
r -> {
|
||||
Thread t = new Thread(r, "MCP-Worker-SSE");
|
||||
t.setDaemon(true);
|
||||
return t;
|
||||
});
|
||||
|
||||
public HttpServletSseServerTransportProvider(String messageEndpoint, String sseEndpoint) {
|
||||
this.sessions = new ConcurrentHashMap<>();
|
||||
this.isClosing = new AtomicBoolean(false);
|
||||
this.objectMapper = JsonUtils.getObjectMapper();
|
||||
public HttpServletSseServerTransportProvider(
|
||||
ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
|
||||
this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
|
||||
}
|
||||
|
||||
public HttpServletSseServerTransportProvider(
|
||||
ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) {
|
||||
this.objectMapper = objectMapper;
|
||||
this.baseUrl = baseUrl;
|
||||
this.messageEndpoint = messageEndpoint;
|
||||
this.sseEndpoint = sseEndpoint;
|
||||
}
|
||||
|
||||
public HttpServletSseServerTransportProvider(String messageEndpoint) {
|
||||
this(messageEndpoint, "/sse");
|
||||
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
|
||||
this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setSessionFactory(McpServerSession.Factory sessionFactory) {
|
||||
this.sessionFactory = sessionFactory;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> notifyClients(String method, Map<String, Object> params) {
|
||||
if (this.sessions.isEmpty()) {
|
||||
if (sessions.isEmpty()) {
|
||||
logger.debug("No active sessions to broadcast message to");
|
||||
return Mono.empty();
|
||||
} else {
|
||||
logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
|
||||
return Flux.fromIterable(this.sessions.values())
|
||||
.flatMap(
|
||||
(session) ->
|
||||
session
|
||||
.sendNotification(method, params)
|
||||
.doOnError(
|
||||
(e) ->
|
||||
logger.error(
|
||||
"Failed to send message to session {}: {}",
|
||||
session.getId(),
|
||||
e.getMessage()))
|
||||
.onErrorComplete())
|
||||
.then();
|
||||
}
|
||||
|
||||
logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());
|
||||
return Flux.fromIterable(sessions.values())
|
||||
.flatMap(
|
||||
session ->
|
||||
session
|
||||
.sendNotification(method, params)
|
||||
.doOnError(
|
||||
e ->
|
||||
logger.error(
|
||||
"Failed to send message to session {}: {}",
|
||||
session.getId(),
|
||||
e.getMessage()))
|
||||
.onErrorComplete())
|
||||
.then();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doGet(HttpServletRequest request, HttpServletResponse response)
|
||||
throws ServletException, IOException {
|
||||
throws IOException {
|
||||
handleSseEvent(request, response);
|
||||
}
|
||||
|
||||
private void handleSseEvent(HttpServletRequest request, HttpServletResponse response)
|
||||
throws ServletException, IOException {
|
||||
String pathInfo = request.getPathInfo();
|
||||
if (!this.sseEndpoint.contains(pathInfo)) {
|
||||
response.sendError(404);
|
||||
} else if (this.isClosing.get()) {
|
||||
response.sendError(503, "Server is shutting down");
|
||||
} else {
|
||||
response.setContentType("text/event-stream");
|
||||
response.setCharacterEncoding("UTF-8");
|
||||
response.setHeader("Cache-Control", "no-cache");
|
||||
response.setHeader("Connection", "keep-alive");
|
||||
response.setHeader("Access-Control-Allow-Origin", "*");
|
||||
String sessionId = UUID.randomUUID().toString();
|
||||
AsyncContext asyncContext = request.startAsync();
|
||||
asyncContext.setTimeout(0L);
|
||||
PrintWriter writer = response.getWriter();
|
||||
HttpServletMcpSessionTransport sessionTransport =
|
||||
new HttpServletMcpSessionTransport(sessionId, asyncContext, writer);
|
||||
McpServerSession session = this.sessionFactory.create(sessionTransport);
|
||||
this.sessions.put(sessionId, session);
|
||||
this.sendEvent(writer, "endpoint", this.messageEndpoint + "?sessionId=" + sessionId);
|
||||
throws IOException {
|
||||
String requestURI = request.getRequestURI();
|
||||
if (!requestURI.endsWith(sseEndpoint)) {
|
||||
response.sendError(HttpServletResponse.SC_NOT_FOUND);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
protected void doPost(HttpServletRequest request, HttpServletResponse response)
|
||||
throws ServletException, IOException {
|
||||
if (this.isClosing.get()) {
|
||||
response.sendError(503, "Server is shutting down");
|
||||
} else {
|
||||
String requestURI = request.getRequestURI();
|
||||
if (!requestURI.endsWith(this.messageEndpoint)) {
|
||||
response.sendError(404);
|
||||
} else {
|
||||
String sessionId = request.getParameter("sessionId");
|
||||
if (sessionId == null) {
|
||||
response.setContentType("application/json");
|
||||
response.setCharacterEncoding("UTF-8");
|
||||
response.setStatus(400);
|
||||
String jsonError =
|
||||
this.objectMapper.writeValueAsString(
|
||||
new McpError("Session ID missing in message endpoint"));
|
||||
PrintWriter writer = response.getWriter();
|
||||
writer.write(jsonError);
|
||||
writer.flush();
|
||||
} else {
|
||||
McpServerSession session = (McpServerSession) this.sessions.get(sessionId);
|
||||
if (session == null) {
|
||||
response.setContentType("application/json");
|
||||
response.setCharacterEncoding("UTF-8");
|
||||
response.setStatus(404);
|
||||
String jsonError =
|
||||
this.objectMapper.writeValueAsString(
|
||||
new McpError("Session not found: " + sessionId));
|
||||
PrintWriter writer = response.getWriter();
|
||||
writer.write(jsonError);
|
||||
writer.flush();
|
||||
} else {
|
||||
if (isClosing.get()) {
|
||||
response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
|
||||
return;
|
||||
}
|
||||
|
||||
response.setContentType("text/event-stream");
|
||||
response.setCharacterEncoding(UTF_8);
|
||||
response.setHeader("Cache-Control", "no-cache");
|
||||
response.setHeader("Connection", "keep-alive");
|
||||
response.setHeader("Access-Control-Allow-Origin", "*");
|
||||
|
||||
String sessionId = UUID.randomUUID().toString();
|
||||
AsyncContext asyncContext = request.startAsync();
|
||||
asyncContext.setTimeout(0);
|
||||
|
||||
PrintWriter writer = response.getWriter();
|
||||
|
||||
// Create a new session transport
|
||||
HttpServletMcpSessionTransport sessionTransport =
|
||||
new HttpServletMcpSessionTransport(sessionId, asyncContext, writer);
|
||||
|
||||
// Create a new session using the session factory
|
||||
McpServerSession session = sessionFactory.create(sessionTransport);
|
||||
this.sessions.put(sessionId, session);
|
||||
|
||||
executorService.submit(
|
||||
() -> {
|
||||
// TODO: Handle session lifecycle and keepalive
|
||||
try {
|
||||
while (sessions.containsKey(sessionId)) {
|
||||
// Send keepalive every 30 seconds
|
||||
Thread.sleep(30000);
|
||||
writer.write(": keep-alive\n\n");
|
||||
writer.flush();
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
} catch (Exception e) {
|
||||
log("SSE error", e);
|
||||
} finally {
|
||||
try {
|
||||
BufferedReader reader = request.getReader();
|
||||
StringBuilder body = new StringBuilder();
|
||||
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
body.append(line);
|
||||
}
|
||||
|
||||
McpSchema.JSONRPCMessage message =
|
||||
getJsonRpcMessageWithAuthorizationParam(request, body.toString());
|
||||
session.handle(message).block();
|
||||
response.setStatus(200);
|
||||
} catch (Exception var11) {
|
||||
Exception e = var11;
|
||||
logger.error("Error processing message: {}", var11.getMessage());
|
||||
|
||||
try {
|
||||
McpError mcpError = new McpError(e.getMessage());
|
||||
response.setContentType("application/json");
|
||||
response.setCharacterEncoding("UTF-8");
|
||||
response.setStatus(500);
|
||||
String jsonError = this.objectMapper.writeValueAsString(mcpError);
|
||||
PrintWriter writer = response.getWriter();
|
||||
writer.write(jsonError);
|
||||
writer.flush();
|
||||
} catch (IOException ex) {
|
||||
logger.error("Failed to send error response: {}", ex.getMessage());
|
||||
response.sendError(500, "Error processing message");
|
||||
}
|
||||
session.closeGracefully();
|
||||
asyncContext.complete();
|
||||
} catch (Exception e) {
|
||||
log("Error closing long-lived SSE connection", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Send initial endpoint event
|
||||
this.sendEvent(
|
||||
writer,
|
||||
ENDPOINT_EVENT_TYPE,
|
||||
this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doPost(HttpServletRequest request, HttpServletResponse response)
|
||||
throws ServletException, IOException {
|
||||
|
||||
if (isClosing.get()) {
|
||||
response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
|
||||
return;
|
||||
}
|
||||
|
||||
String requestURI = request.getRequestURI();
|
||||
if (!requestURI.endsWith(messageEndpoint)) {
|
||||
response.sendError(HttpServletResponse.SC_NOT_FOUND);
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the session ID from the request parameter
|
||||
String sessionId = request.getParameter("sessionId");
|
||||
if (sessionId == null) {
|
||||
response.setContentType(APPLICATION_JSON);
|
||||
response.setCharacterEncoding(UTF_8);
|
||||
response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
|
||||
String jsonError =
|
||||
objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint"));
|
||||
PrintWriter writer = response.getWriter();
|
||||
writer.write(jsonError);
|
||||
writer.flush();
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the session from the sessions map
|
||||
McpServerSession session = sessions.get(sessionId);
|
||||
if (session == null) {
|
||||
response.setContentType(APPLICATION_JSON);
|
||||
response.setCharacterEncoding(UTF_8);
|
||||
response.setStatus(HttpServletResponse.SC_NOT_FOUND);
|
||||
String jsonError =
|
||||
objectMapper.writeValueAsString(new McpError("Session not found: " + sessionId));
|
||||
PrintWriter writer = response.getWriter();
|
||||
writer.write(jsonError);
|
||||
writer.flush();
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
BufferedReader reader = request.getReader();
|
||||
StringBuilder body = new StringBuilder();
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
body.append(line);
|
||||
}
|
||||
|
||||
McpSchema.JSONRPCMessage message =
|
||||
getJsonRpcMessageWithAuthorizationParam(this.objectMapper, request, body.toString());
|
||||
|
||||
// Process the message through the session's handle method
|
||||
session.handle(message).block(); // Block for Servlet compatibility
|
||||
|
||||
response.setStatus(HttpServletResponse.SC_OK);
|
||||
} catch (Exception e) {
|
||||
logger.error("Error processing message: {}", e.getMessage());
|
||||
try {
|
||||
McpError mcpError = new McpError(e.getMessage());
|
||||
response.setContentType(APPLICATION_JSON);
|
||||
response.setCharacterEncoding(UTF_8);
|
||||
response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
|
||||
String jsonError = objectMapper.writeValueAsString(mcpError);
|
||||
PrintWriter writer = response.getWriter();
|
||||
writer.write(jsonError);
|
||||
writer.flush();
|
||||
} catch (IOException ex) {
|
||||
logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage());
|
||||
response.sendError(
|
||||
HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private McpSchema.JSONRPCMessage getJsonRpcMessageWithAuthorizationParam(
|
||||
HttpServletRequest request, String body) throws IOException {
|
||||
Map<String, Object> requestMessage = JsonUtils.getMap(JsonUtils.readTree(body));
|
||||
Map<String, Object> params = (Map<String, Object>) requestMessage.get("params");
|
||||
if (params != null) {
|
||||
Map<String, Object> arguments = (Map<String, Object>) params.get("arguments");
|
||||
if (arguments != null) {
|
||||
arguments.put("Authorization", JwtFilter.extractToken(request.getHeader("Authorization")));
|
||||
}
|
||||
}
|
||||
return McpSchema.deserializeJsonRpcMessage(
|
||||
this.objectMapper, JsonUtils.pojoToJson(requestMessage));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> closeGracefully() {
|
||||
this.isClosing.set(true);
|
||||
logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
|
||||
return Flux.fromIterable(this.sessions.values())
|
||||
.flatMap(McpServerSession::closeGracefully)
|
||||
.then();
|
||||
isClosing.set(true);
|
||||
logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
|
||||
|
||||
return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then();
|
||||
}
|
||||
|
||||
private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException {
|
||||
@ -218,8 +267,19 @@ public class HttpServletSseServerTransportProvider extends HttpServlet
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() {
|
||||
this.closeGracefully().block();
|
||||
closeGracefully().block();
|
||||
if (executorService != null) {
|
||||
executorService.shutdown();
|
||||
try {
|
||||
if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) {
|
||||
executorService.shutdownNow();
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
}
|
||||
super.destroy();
|
||||
}
|
||||
|
||||
@ -233,65 +293,106 @@ public class HttpServletSseServerTransportProvider extends HttpServlet
|
||||
this.sessionId = sessionId;
|
||||
this.asyncContext = asyncContext;
|
||||
this.writer = writer;
|
||||
HttpServletSseServerTransportProvider.logger.debug(
|
||||
"Session transport {} initialized with SSE writer", sessionId);
|
||||
logger.debug("Session transport {} initialized with SSE writer", sessionId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
|
||||
return Mono.fromRunnable(
|
||||
() -> {
|
||||
try {
|
||||
String jsonText =
|
||||
HttpServletSseServerTransportProvider.this.objectMapper.writeValueAsString(
|
||||
message);
|
||||
HttpServletSseServerTransportProvider.this.sendEvent(
|
||||
this.writer, "message", jsonText);
|
||||
HttpServletSseServerTransportProvider.logger.debug(
|
||||
"Message sent to session {}", this.sessionId);
|
||||
String jsonText = objectMapper.writeValueAsString(message);
|
||||
sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText);
|
||||
logger.debug("Message sent to session {}", sessionId);
|
||||
} catch (Exception e) {
|
||||
HttpServletSseServerTransportProvider.logger.error(
|
||||
"Failed to send message to session {}: {}", this.sessionId, e.getMessage());
|
||||
HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
|
||||
this.asyncContext.complete();
|
||||
logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage());
|
||||
sessions.remove(sessionId);
|
||||
asyncContext.complete();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
|
||||
return (T)
|
||||
HttpServletSseServerTransportProvider.this.objectMapper.convertValue(data, typeRef);
|
||||
return objectMapper.convertValue(data, typeRef);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> closeGracefully() {
|
||||
return Mono.fromRunnable(
|
||||
() -> {
|
||||
HttpServletSseServerTransportProvider.logger.debug(
|
||||
"Closing session transport: {}", this.sessionId);
|
||||
|
||||
logger.debug("Closing session transport: {}", sessionId);
|
||||
try {
|
||||
HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
|
||||
this.asyncContext.complete();
|
||||
HttpServletSseServerTransportProvider.logger.debug(
|
||||
"Successfully completed async context for session {}", this.sessionId);
|
||||
sessions.remove(sessionId);
|
||||
asyncContext.complete();
|
||||
logger.debug("Successfully completed async context for session {}", sessionId);
|
||||
} catch (Exception e) {
|
||||
HttpServletSseServerTransportProvider.logger.warn(
|
||||
"Failed to complete async context for session {}: {}",
|
||||
this.sessionId,
|
||||
e.getMessage());
|
||||
logger.warn(
|
||||
"Failed to complete async context for session {}: {}", sessionId, e.getMessage());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
|
||||
this.asyncContext.complete();
|
||||
HttpServletSseServerTransportProvider.logger.debug(
|
||||
"Successfully completed async context for session {}", this.sessionId);
|
||||
sessions.remove(sessionId);
|
||||
asyncContext.complete();
|
||||
logger.debug("Successfully completed async context for session {}", sessionId);
|
||||
} catch (Exception e) {
|
||||
HttpServletSseServerTransportProvider.logger.warn(
|
||||
"Failed to complete async context for session {}: {}", this.sessionId, e.getMessage());
|
||||
logger.warn(
|
||||
"Failed to complete async context for session {}: {}", sessionId, e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private ObjectMapper objectMapper = new ObjectMapper();
|
||||
|
||||
private String baseUrl = DEFAULT_BASE_URL;
|
||||
|
||||
private String messageEndpoint;
|
||||
|
||||
private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
|
||||
|
||||
public Builder objectMapper(ObjectMapper objectMapper) {
|
||||
Assert.notNull(objectMapper, "ObjectMapper must not be null");
|
||||
this.objectMapper = objectMapper;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder baseUrl(String baseUrl) {
|
||||
Assert.notNull(baseUrl, "Base URL must not be null");
|
||||
this.baseUrl = baseUrl;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder messageEndpoint(String messageEndpoint) {
|
||||
Assert.hasText(messageEndpoint, "Message endpoint must not be empty");
|
||||
this.messageEndpoint = messageEndpoint;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder sseEndpoint(String sseEndpoint) {
|
||||
Assert.hasText(sseEndpoint, "SSE endpoint must not be empty");
|
||||
this.sseEndpoint = sseEndpoint;
|
||||
return this;
|
||||
}
|
||||
|
||||
public HttpServletSseServerTransportProvider build() {
|
||||
if (objectMapper == null) {
|
||||
throw new IllegalStateException("ObjectMapper must be set");
|
||||
}
|
||||
if (messageEndpoint == null) {
|
||||
throw new IllegalStateException("MessageEndpoint must be set");
|
||||
}
|
||||
return new HttpServletSseServerTransportProvider(
|
||||
objectMapper, baseUrl, messageEndpoint, sseEndpoint);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,971 @@
|
||||
package org.openmetadata.service.mcp;
|
||||
|
||||
import static org.openmetadata.service.mcp.McpUtils.getJsonRpcMessageWithAuthorizationParam;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.modelcontextprotocol.spec.McpSchema;
|
||||
import io.modelcontextprotocol.spec.McpServerSession;
|
||||
import io.modelcontextprotocol.spec.McpServerTransportProvider;
|
||||
import jakarta.servlet.AsyncContext;
|
||||
import jakarta.servlet.ServletException;
|
||||
import jakarta.servlet.annotation.WebServlet;
|
||||
import jakarta.servlet.http.HttpServlet;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.PrintWriter;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.openmetadata.service.config.MCPConfiguration;
|
||||
import org.openmetadata.service.exception.BadRequestException;
|
||||
import org.openmetadata.service.limits.Limits;
|
||||
import org.openmetadata.service.mcp.tools.DefaultToolContext;
|
||||
import org.openmetadata.service.security.Authorizer;
|
||||
import org.openmetadata.service.security.JwtFilter;
|
||||
import org.openmetadata.service.util.JsonUtils;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
/**
|
||||
* MCP (Model Context Protocol) Streamable HTTP Servlet
|
||||
* This servlet implements the Streamable HTTP transport specification for MCP.
|
||||
*/
|
||||
@WebServlet(value = "/mcp", asyncSupported = true)
|
||||
@Slf4j
|
||||
public class MCPStreamableHttpServlet extends HttpServlet implements McpServerTransportProvider {
|
||||
|
||||
private static final String SESSION_HEADER = "Mcp-Session-Id";
|
||||
private static final String CONTENT_TYPE_JSON = "application/json";
|
||||
private static final String CONTENT_TYPE_SSE = "text/event-stream";
|
||||
|
||||
private final transient ObjectMapper objectMapper = new ObjectMapper();
|
||||
private final transient Map<String, MCPSession> sessions = new ConcurrentHashMap<>();
|
||||
private final transient Map<String, SSEConnection> sseConnections = new ConcurrentHashMap<>();
|
||||
private final transient SecureRandom secureRandom = new SecureRandom();
|
||||
private final transient ExecutorService executorService =
|
||||
Executors.newCachedThreadPool(
|
||||
r -> {
|
||||
Thread t = new Thread(r, "MCP-Worker-Streamable");
|
||||
t.setDaemon(true);
|
||||
return t;
|
||||
});
|
||||
private transient McpServerSession.Factory sessionFactory;
|
||||
private final transient JwtFilter jwtFilter;
|
||||
private final transient Authorizer authorizer;
|
||||
private final transient List<McpSchema.Tool> tools = new ArrayList<>();
|
||||
private final transient Limits limits;
|
||||
private final transient DefaultToolContext toolContext;
|
||||
private final transient MCPConfiguration mcpConfiguration;
|
||||
|
||||
public MCPStreamableHttpServlet(
|
||||
MCPConfiguration mcpConfiguration,
|
||||
JwtFilter jwtFilter,
|
||||
Authorizer authorizer,
|
||||
Limits limits,
|
||||
DefaultToolContext toolContext,
|
||||
List<McpSchema.Tool> tools) {
|
||||
this.mcpConfiguration = mcpConfiguration;
|
||||
this.jwtFilter = jwtFilter;
|
||||
this.authorizer = authorizer;
|
||||
this.limits = limits;
|
||||
this.toolContext = toolContext;
|
||||
this.tools.addAll(tools);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void init() throws ServletException {
|
||||
super.init();
|
||||
log("MCP Streamable HTTP Servlet initialized");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() {
|
||||
if (executorService != null) {
|
||||
executorService.shutdown();
|
||||
try {
|
||||
if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) {
|
||||
executorService.shutdownNow();
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
// Close all SSE connections
|
||||
for (SSEConnection connection : sseConnections.values()) {
|
||||
try {
|
||||
connection.close();
|
||||
} catch (IOException e) {
|
||||
log("Error closing SSE connection", e);
|
||||
}
|
||||
}
|
||||
|
||||
super.destroy();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doPost(HttpServletRequest request, HttpServletResponse response)
|
||||
throws ServletException, IOException {
|
||||
String origin = request.getHeader("Origin");
|
||||
if (origin != null && !isValidOrigin(origin)) {
|
||||
sendError(response, HttpServletResponse.SC_FORBIDDEN, "Invalid origin");
|
||||
return;
|
||||
}
|
||||
|
||||
String acceptHeader = request.getHeader("Accept");
|
||||
if (acceptHeader == null
|
||||
|| !acceptHeader.contains(CONTENT_TYPE_JSON)
|
||||
|| !acceptHeader.contains(CONTENT_TYPE_SSE)) {
|
||||
sendError(
|
||||
response,
|
||||
HttpServletResponse.SC_BAD_REQUEST,
|
||||
"Accept header must include both application/json and text/event-stream");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
String requestBody = readRequestBody(request);
|
||||
String sessionId = request.getHeader(SESSION_HEADER);
|
||||
|
||||
// Parse JSON-RPC message(s)
|
||||
McpSchema.JSONRPCMessage message =
|
||||
getJsonRpcMessageWithAuthorizationParam(this.objectMapper, request, requestBody);
|
||||
JsonNode jsonNode = objectMapper.valueToTree(message);
|
||||
|
||||
if (jsonNode.isArray()) {
|
||||
handleBatchRequest(request, response, jsonNode, sessionId);
|
||||
} else {
|
||||
handleSingleMessage(request, response, jsonNode, sessionId);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log("Error handling POST request", e);
|
||||
sendError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Internal server error");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle a single JSON-RPC message according to MCP specification
|
||||
*/
|
||||
private void handleSingleMessage(
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response,
|
||||
JsonNode jsonMessage,
|
||||
String sessionId)
|
||||
throws IOException {
|
||||
|
||||
String method = jsonMessage.has("method") ? jsonMessage.get("method").asText() : null;
|
||||
boolean hasId = jsonMessage.has("id");
|
||||
boolean isResponse = jsonMessage.has("result") || jsonMessage.has("error");
|
||||
|
||||
// Handle initialization specially
|
||||
if ("initialize".equals(method)) {
|
||||
handleInitialize(response, jsonMessage);
|
||||
return;
|
||||
}
|
||||
|
||||
// Validate session for non-initialization requests (except for responses/notifications)
|
||||
if (sessionId != null && !sessions.containsKey(sessionId) && !isResponse) {
|
||||
sendError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (isResponse || (!hasId && method != null)) {
|
||||
// This is either a response or a notification
|
||||
handleResponseOrNotification(response, jsonMessage, sessionId, isResponse);
|
||||
} else if (hasId && method != null) {
|
||||
// This is a request - may return JSON or start SSE stream
|
||||
handleRequest(request, response, jsonMessage, sessionId);
|
||||
} else {
|
||||
sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Invalid JSON-RPC message format");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle responses and notifications from client
|
||||
*/
|
||||
private void handleResponseOrNotification(
|
||||
HttpServletResponse response, JsonNode message, String sessionId, boolean isResponse)
|
||||
throws IOException {
|
||||
|
||||
try {
|
||||
if (isResponse) {
|
||||
processClientResponse(message, sessionId);
|
||||
log("Processed client response for session: " + sessionId);
|
||||
} else {
|
||||
processNotification(message, sessionId);
|
||||
log("Processed client notification for session: " + sessionId);
|
||||
}
|
||||
|
||||
response.setStatus(HttpServletResponse.SC_ACCEPTED);
|
||||
response.setContentLength(0);
|
||||
|
||||
} catch (Exception e) {
|
||||
log("Error processing response/notification", e);
|
||||
sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Failed to process message");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle JSON-RPC requests from client
|
||||
*/
|
||||
private void handleRequest(
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response,
|
||||
JsonNode jsonRequest,
|
||||
String sessionId)
|
||||
throws IOException {
|
||||
|
||||
String acceptHeader = request.getHeader("Accept");
|
||||
boolean supportsSSE = acceptHeader != null && acceptHeader.contains(CONTENT_TYPE_SSE);
|
||||
|
||||
// Determine whether to use SSE stream or direct JSON response
|
||||
boolean useSSE = supportsSSE && shouldUseSSEForRequest(jsonRequest, sessionId);
|
||||
|
||||
if (useSSE) {
|
||||
// Initiate SSE stream and process request asynchronously
|
||||
startSSEStreamForRequest(request, response, jsonRequest, sessionId);
|
||||
} else {
|
||||
// Send direct JSON response
|
||||
sendDirectJSONResponse(response, jsonRequest, sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send direct JSON response for requests
|
||||
*/
|
||||
private void sendDirectJSONResponse(
|
||||
HttpServletResponse response, JsonNode request, String sessionId) throws IOException {
|
||||
|
||||
Map<String, Object> jsonResponse = processRequest(request, sessionId);
|
||||
String responseJson = objectMapper.writeValueAsString(jsonResponse);
|
||||
|
||||
response.setContentType(CONTENT_TYPE_JSON);
|
||||
response.setStatus(HttpServletResponse.SC_OK);
|
||||
|
||||
// Include session ID in response if present
|
||||
if (sessionId != null) {
|
||||
response.setHeader(SESSION_HEADER, sessionId);
|
||||
}
|
||||
|
||||
response.getWriter().write(responseJson);
|
||||
response.getWriter().flush();
|
||||
}
|
||||
|
||||
/**
|
||||
* Start SSE stream for processing requests
|
||||
*/
|
||||
private void startSSEStreamForRequest(
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response,
|
||||
JsonNode jsonRequest,
|
||||
String sessionId)
|
||||
throws IOException {
|
||||
|
||||
// Set up SSE response headers
|
||||
response.setContentType(CONTENT_TYPE_SSE);
|
||||
response.setHeader("Cache-Control", "no-cache");
|
||||
response.setHeader("Connection", "keep-alive");
|
||||
response.setHeader("Access-Control-Allow-Origin", "*");
|
||||
|
||||
// Include session ID in response if present
|
||||
if (sessionId != null) {
|
||||
response.setHeader(SESSION_HEADER, sessionId);
|
||||
}
|
||||
|
||||
response.setStatus(HttpServletResponse.SC_OK);
|
||||
|
||||
// Start async processing
|
||||
AsyncContext asyncContext = request.startAsync();
|
||||
asyncContext.setTimeout(300000); // 5 minutes timeout
|
||||
|
||||
SSEConnection connection = new SSEConnection(response.getWriter(), sessionId);
|
||||
String connectionId = UUID.randomUUID().toString();
|
||||
sseConnections.put(connectionId, connection);
|
||||
|
||||
// Process request asynchronously
|
||||
executorService.submit(
|
||||
() -> {
|
||||
try {
|
||||
// Send server-initiated messages first
|
||||
// sendServerInitiatedMessages(connection);
|
||||
|
||||
// Process the actual request and send response
|
||||
Map<String, Object> jsonResponse = processRequest(jsonRequest, sessionId);
|
||||
String eventId = generateEventId();
|
||||
connection.sendEventWithId(eventId, objectMapper.writeValueAsString(jsonResponse));
|
||||
|
||||
// Close the stream after sending response
|
||||
connection.close();
|
||||
|
||||
} catch (Exception e) {
|
||||
log("Error in SSE stream processing for request", e);
|
||||
try {
|
||||
// Send error response before closing
|
||||
Map<String, Object> errorResponse =
|
||||
createErrorResponse(
|
||||
jsonRequest.get("id"), -32603, "Internal error: " + e.getMessage());
|
||||
connection.sendEvent(objectMapper.writeValueAsString(errorResponse));
|
||||
connection.close();
|
||||
} catch (Exception ex) {
|
||||
log("Error sending error response in SSE stream", ex);
|
||||
}
|
||||
} finally {
|
||||
try {
|
||||
asyncContext.complete();
|
||||
} catch (Exception e) {
|
||||
log("Error completing async context", e);
|
||||
}
|
||||
sseConnections.remove(connectionId);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle batch requests according to MCP specification
|
||||
*/
|
||||
private void handleBatchRequest(
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response,
|
||||
JsonNode batchRequest,
|
||||
String sessionId)
|
||||
throws IOException {
|
||||
|
||||
if (!batchRequest.isArray() || batchRequest.size() == 0) {
|
||||
sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Invalid batch request format");
|
||||
return;
|
||||
}
|
||||
|
||||
// Analyze the batch to determine message types
|
||||
boolean hasRequests = false;
|
||||
boolean hasResponsesOrNotifications = false;
|
||||
|
||||
for (JsonNode message : batchRequest) {
|
||||
boolean hasId = message.has("id");
|
||||
boolean isResponse = message.has("result") || message.has("error");
|
||||
boolean hasMethod = message.has("method");
|
||||
|
||||
if (hasId && hasMethod && !isResponse) {
|
||||
hasRequests = true;
|
||||
} else if (isResponse || (!hasId && hasMethod)) {
|
||||
hasResponsesOrNotifications = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasRequests && hasResponsesOrNotifications) {
|
||||
sendError(
|
||||
response,
|
||||
HttpServletResponse.SC_BAD_REQUEST,
|
||||
"Batch cannot mix requests with responses/notifications");
|
||||
return;
|
||||
}
|
||||
|
||||
if (hasResponsesOrNotifications) {
|
||||
// Process responses and notifications, return 202 Accepted
|
||||
processBatchResponsesOrNotifications(batchRequest, sessionId);
|
||||
response.setStatus(HttpServletResponse.SC_ACCEPTED);
|
||||
response.setContentLength(0);
|
||||
} else if (hasRequests) {
|
||||
// Process batch requests - determine if SSE or direct JSON
|
||||
String acceptHeader = request.getHeader("Accept");
|
||||
boolean supportsSSE = acceptHeader != null && acceptHeader.contains(CONTENT_TYPE_SSE);
|
||||
|
||||
if (supportsSSE && shouldUseSSEForBatch(batchRequest, sessionId)) {
|
||||
startSSEStreamForBatchRequests(request, response, batchRequest, sessionId);
|
||||
} else {
|
||||
sendBatchJSONResponse(response, batchRequest, sessionId);
|
||||
}
|
||||
} else {
|
||||
sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Invalid batch content");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process client responses (for server-initiated requests)
|
||||
*/
|
||||
private void processClientResponse(JsonNode response, String sessionId) {
|
||||
// Handle responses to server-initiated requests
|
||||
Object id = response.has("id") ? response.get("id") : null;
|
||||
LOG.info("Received client response for request ID: " + id + " (session: " + sessionId + ")");
|
||||
|
||||
// TODO: Match with pending server requests and complete them
|
||||
// This would involve maintaining a map of pending requests by ID
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine if SSE should be used for a specific request
|
||||
*/
|
||||
private boolean shouldUseSSEForRequest(JsonNode request, String sessionId) {
|
||||
// TODO: This is good for now, but we can enhance this logic later, like tools/call can be long
|
||||
// running for our use case should be fine
|
||||
// Use SSE for requests that are streaming operations or have specific methods
|
||||
MCPSession session = sessions.get(sessionId);
|
||||
return session != null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine if SSE should be used for batch requests
|
||||
*/
|
||||
private boolean shouldUseSSEForBatch(JsonNode batchRequest, String sessionId) {
|
||||
// Use SSE for batches containing streaming operations
|
||||
for (JsonNode request : batchRequest) {
|
||||
if (shouldUseSSEForRequest(request, sessionId)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Process batch responses and notifications
|
||||
*/
|
||||
private void processBatchResponsesOrNotifications(JsonNode batch, String sessionId) {
|
||||
for (JsonNode message : batch) {
|
||||
boolean isResponse = message.has("result") || message.has("error");
|
||||
if (isResponse) {
|
||||
processClientResponse(message, sessionId);
|
||||
} else {
|
||||
processNotification(message, sessionId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send batch JSON response
|
||||
*/
|
||||
private void sendBatchJSONResponse(
|
||||
HttpServletResponse response, JsonNode batchRequest, String sessionId) throws IOException {
|
||||
|
||||
List<Map<String, Object>> responses = new ArrayList<>();
|
||||
|
||||
for (JsonNode request : batchRequest) {
|
||||
Map<String, Object> jsonResponse = processRequest(request, sessionId);
|
||||
responses.add(jsonResponse);
|
||||
}
|
||||
|
||||
String responseJson = objectMapper.writeValueAsString(responses);
|
||||
|
||||
response.setContentType(CONTENT_TYPE_JSON);
|
||||
response.setStatus(HttpServletResponse.SC_OK);
|
||||
|
||||
if (sessionId != null) {
|
||||
response.setHeader(SESSION_HEADER, sessionId);
|
||||
}
|
||||
|
||||
response.getWriter().write(responseJson);
|
||||
response.getWriter().flush();
|
||||
}
|
||||
|
||||
/**
|
||||
* Start SSE stream for batch requests
|
||||
*/
|
||||
private void startSSEStreamForBatchRequests(
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response,
|
||||
JsonNode batchRequest,
|
||||
String sessionId)
|
||||
throws IOException {
|
||||
|
||||
// Set up SSE response headers
|
||||
response.setContentType(CONTENT_TYPE_SSE);
|
||||
response.setHeader("Cache-Control", "no-cache");
|
||||
response.setHeader("Connection", "keep-alive");
|
||||
response.setHeader("Access-Control-Allow-Origin", "*");
|
||||
|
||||
if (sessionId != null) {
|
||||
response.setHeader(SESSION_HEADER, sessionId);
|
||||
}
|
||||
|
||||
response.setStatus(HttpServletResponse.SC_OK);
|
||||
|
||||
AsyncContext asyncContext = request.startAsync();
|
||||
asyncContext.setTimeout(300000); // 5 minutes timeout
|
||||
|
||||
SSEConnection connection = new SSEConnection(response.getWriter(), sessionId);
|
||||
String connectionId = UUID.randomUUID().toString();
|
||||
sseConnections.put(connectionId, connection);
|
||||
|
||||
// Process batch asynchronously
|
||||
executorService.submit(
|
||||
() -> {
|
||||
try {
|
||||
// Send server-initiated messages first
|
||||
// sendServerInitiatedMessages(connection);
|
||||
|
||||
// Process each request in the batch
|
||||
List<Map<String, Object>> responses = new ArrayList<>();
|
||||
for (JsonNode req : batchRequest) {
|
||||
Map<String, Object> jsonResponse = processRequest(req, sessionId);
|
||||
responses.add(jsonResponse);
|
||||
}
|
||||
|
||||
// Send batch response
|
||||
String eventId = generateEventId();
|
||||
connection.sendEventWithId(eventId, objectMapper.writeValueAsString(responses));
|
||||
|
||||
connection.close();
|
||||
|
||||
} catch (Exception e) {
|
||||
log("Error in SSE stream processing for batch", e);
|
||||
try {
|
||||
Map<String, Object> errorResponse =
|
||||
createErrorResponse(null, -32603, "Internal error: " + e.getMessage());
|
||||
connection.sendEvent(objectMapper.writeValueAsString(errorResponse));
|
||||
connection.close();
|
||||
} catch (Exception ex) {
|
||||
log("Error sending error response in batch SSE stream", ex);
|
||||
}
|
||||
} finally {
|
||||
try {
|
||||
asyncContext.complete();
|
||||
} catch (Exception e) {
|
||||
log("Error completing async context for batch", e);
|
||||
}
|
||||
sseConnections.remove(connectionId);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Create error response object
|
||||
*/
|
||||
private Map<String, Object> createErrorResponse(JsonNode id, int code, String message) {
|
||||
Map<String, Object> response = new HashMap<>();
|
||||
response.put("jsonrpc", "2.0");
|
||||
response.put("id", id);
|
||||
response.put("error", createError(code, message));
|
||||
return response;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate unique event ID for SSE
|
||||
*/
|
||||
private String generateEventId() {
|
||||
return UUID.randomUUID().toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Enhanced SSE connection class with event ID support
|
||||
*/
|
||||
public static class SSEConnection {
|
||||
private final PrintWriter writer;
|
||||
private final String sessionId;
|
||||
private volatile boolean closed = false;
|
||||
private int eventCounter = 0;
|
||||
|
||||
public SSEConnection(PrintWriter writer, String sessionId) {
|
||||
this.writer = writer;
|
||||
this.sessionId = sessionId;
|
||||
}
|
||||
|
||||
public synchronized void sendEvent(String data) throws IOException {
|
||||
if (closed) return;
|
||||
|
||||
writer.printf("id: %d%n", ++eventCounter);
|
||||
writer.printf("data: %s%n%n", data);
|
||||
writer.flush();
|
||||
|
||||
if (writer.checkError()) {
|
||||
throw new IOException("SSE write error");
|
||||
}
|
||||
}
|
||||
|
||||
public synchronized void sendEventWithId(String id, String data) throws IOException {
|
||||
if (closed) return;
|
||||
|
||||
writer.printf("id: %s%n", id);
|
||||
writer.printf("data: %s%n%n", data);
|
||||
writer.flush();
|
||||
|
||||
if (writer.checkError()) {
|
||||
throw new IOException("SSE write error");
|
||||
}
|
||||
}
|
||||
|
||||
public synchronized void sendComment(String comment) throws IOException {
|
||||
if (closed) return;
|
||||
|
||||
writer.printf(": %s%n%n", comment);
|
||||
writer.flush();
|
||||
|
||||
if (writer.checkError()) {
|
||||
throw new IOException("SSE write error");
|
||||
}
|
||||
}
|
||||
|
||||
public void close() throws IOException {
|
||||
closed = true;
|
||||
if (writer != null) {
|
||||
writer.close();
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isClosed() {
|
||||
return closed;
|
||||
}
|
||||
|
||||
public String getSessionId() {
|
||||
return sessionId;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doGet(HttpServletRequest request, HttpServletResponse response)
|
||||
throws ServletException, IOException {
|
||||
|
||||
String sessionId = request.getHeader(SESSION_HEADER);
|
||||
String acceptHeader = request.getHeader("Accept");
|
||||
|
||||
if (acceptHeader == null || !acceptHeader.contains(CONTENT_TYPE_SSE)) {
|
||||
sendError(
|
||||
response,
|
||||
HttpServletResponse.SC_BAD_REQUEST,
|
||||
"Accept header must include text/event-stream");
|
||||
return;
|
||||
}
|
||||
|
||||
if (sessionId != null && !sessions.containsKey(sessionId)) {
|
||||
sendError(response, HttpServletResponse.SC_NOT_FOUND, "Session not found");
|
||||
return;
|
||||
}
|
||||
|
||||
startSSEStreamForGet(request, response, sessionId);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doDelete(HttpServletRequest request, HttpServletResponse response)
|
||||
throws ServletException, IOException {
|
||||
String sessionId = request.getHeader(SESSION_HEADER);
|
||||
removeMCPSession(sessionId);
|
||||
response.setStatus(HttpServletResponse.SC_OK);
|
||||
}
|
||||
|
||||
private void removeMCPSession(String sessionId) {
|
||||
if (sessionId == null) {
|
||||
throw BadRequestException.of("Session ID required");
|
||||
}
|
||||
|
||||
// Terminate session
|
||||
MCPSession session = sessions.remove(sessionId);
|
||||
if (session != null) {
|
||||
LOG.info("Session terminated: " + sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
private void handleInitialize(HttpServletResponse response, JsonNode request) throws IOException {
|
||||
// Create new session
|
||||
String sessionId = generateSessionId();
|
||||
MCPSession session = new MCPSession(this.objectMapper, sessionId, response.getWriter());
|
||||
sessions.put(sessionId, session);
|
||||
|
||||
// Create initialize response
|
||||
Map<String, Object> jsonResponse = new HashMap<>();
|
||||
jsonResponse.put("jsonrpc", "2.0");
|
||||
jsonResponse.put("id", request.get("id"));
|
||||
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("protocolVersion", "2024-11-05");
|
||||
result.put("capabilities", getServerCapabilities());
|
||||
result.put("serverInfo", getServerInfo());
|
||||
jsonResponse.put("result", result);
|
||||
|
||||
// Send response with session ID
|
||||
String responseJson = objectMapper.writeValueAsString(jsonResponse);
|
||||
response.setContentType(CONTENT_TYPE_JSON);
|
||||
response.setHeader(SESSION_HEADER, sessionId);
|
||||
response.setStatus(HttpServletResponse.SC_OK);
|
||||
response.getWriter().write(responseJson);
|
||||
|
||||
log("New session initialized: " + sessionId);
|
||||
}
|
||||
|
||||
private void startSSEStreamForGet(
|
||||
HttpServletRequest request, HttpServletResponse response, String sessionId)
|
||||
throws IOException {
|
||||
|
||||
response.setContentType(CONTENT_TYPE_SSE);
|
||||
response.setHeader("Cache-Control", "no-cache");
|
||||
response.setHeader("Connection", "keep-alive");
|
||||
response.setHeader("Access-Control-Allow-Origin", "*");
|
||||
response.setStatus(HttpServletResponse.SC_OK);
|
||||
|
||||
AsyncContext asyncContext = request.startAsync();
|
||||
asyncContext.setTimeout(0); // No timeout for long-lived connections
|
||||
|
||||
SSEConnection connection = new SSEConnection(response.getWriter(), sessionId);
|
||||
String connectionId = UUID.randomUUID().toString();
|
||||
sseConnections.put(connectionId, connection);
|
||||
|
||||
// Keep connection alive and handle server-initiated messages
|
||||
executorService.submit(
|
||||
() -> {
|
||||
try {
|
||||
while (!connection.isClosed()) {
|
||||
// Send keepalive every 30 seconds
|
||||
Thread.sleep(30000);
|
||||
connection.sendComment("keepalive");
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
} catch (IOException e) {
|
||||
LOG.error("SSE connection error for connection ID: {}", connectionId, e);
|
||||
} finally {
|
||||
try {
|
||||
connection.close();
|
||||
asyncContext.complete();
|
||||
} catch (Exception e) {
|
||||
log("Error closing long-lived SSE connection", e);
|
||||
}
|
||||
sseConnections.remove(connectionId);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setSessionFactory(McpServerSession.Factory sessionFactory) {
|
||||
this.sessionFactory = sessionFactory;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> notifyClients(String method, Map<String, Object> params) {
|
||||
if (sessions.isEmpty()) {
|
||||
LOG.debug("No active sessions to broadcast message to");
|
||||
return Mono.empty();
|
||||
}
|
||||
|
||||
LOG.debug("Attempting to broadcast message to {} active sessions", sessions.size());
|
||||
return Flux.fromIterable(sessions.values())
|
||||
.flatMap(
|
||||
session ->
|
||||
session
|
||||
.sendNotification(method, params)
|
||||
.doOnError(
|
||||
e ->
|
||||
LOG.error(
|
||||
"Failed to send message to session {}: {}",
|
||||
session.getSessionId(),
|
||||
e.getMessage()))
|
||||
.onErrorComplete())
|
||||
.then();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> closeGracefully() {
|
||||
LOG.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
|
||||
return Flux.fromIterable(sessions.values()).flatMap(MCPSession::closeGracefully).then();
|
||||
}
|
||||
|
||||
private static class MCPSession {
|
||||
@Getter private final String sessionId;
|
||||
private final long createdAt;
|
||||
private final Map<String, Object> state;
|
||||
private final PrintWriter outputStream;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
public MCPSession(ObjectMapper objectMapper, String sessionId, PrintWriter outputStream) {
|
||||
this.sessionId = sessionId;
|
||||
this.createdAt = System.currentTimeMillis();
|
||||
this.state = new ConcurrentHashMap<>();
|
||||
this.objectMapper = objectMapper;
|
||||
this.outputStream = outputStream;
|
||||
}
|
||||
|
||||
public String getSessionId() {
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
public long getCreatedAt() {
|
||||
return createdAt;
|
||||
}
|
||||
|
||||
public Map<String, Object> getState() {
|
||||
return state;
|
||||
}
|
||||
|
||||
public Mono<Void> sendNotification(String method, Map<String, Object> params) {
|
||||
McpSchema.JSONRPCNotification jsonrpcNotification =
|
||||
new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params);
|
||||
return Mono.fromRunnable(
|
||||
() -> {
|
||||
try {
|
||||
String json = objectMapper.writeValueAsString(jsonrpcNotification);
|
||||
outputStream.write(json);
|
||||
outputStream.write('\n');
|
||||
outputStream.flush();
|
||||
} catch (IOException e) {
|
||||
LOG.error("Failed to send message", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public Mono<Void> closeGracefully() {
|
||||
return Mono.fromRunnable(
|
||||
() -> {
|
||||
outputStream.flush();
|
||||
outputStream.close();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, Object> processRequest(JsonNode request, String sessionId) {
|
||||
Map<String, Object> response = new HashMap<>();
|
||||
response.put("jsonrpc", "2.0");
|
||||
response.put("id", request.get("id"));
|
||||
|
||||
String method = request.get("method").asText();
|
||||
|
||||
try {
|
||||
switch (method) {
|
||||
case "ping":
|
||||
response.put("result", "pong");
|
||||
break;
|
||||
|
||||
case "resources/list":
|
||||
response.put("result", new McpSchema.ListResourcesResult(new ArrayList<>(), null));
|
||||
break;
|
||||
|
||||
case "resources/read":
|
||||
// TODO: Implement resource reading logic
|
||||
break;
|
||||
|
||||
case "tools/list":
|
||||
response.put("result", new McpSchema.ListToolsResult(tools, null));
|
||||
break;
|
||||
|
||||
case "tools/call":
|
||||
JsonNode toolParams = request.get("params");
|
||||
if (toolParams != null && toolParams.has("name")) {
|
||||
String toolName = toolParams.get("name").asText();
|
||||
JsonNode arguments = toolParams.get("arguments");
|
||||
McpSchema.Content content =
|
||||
new McpSchema.TextContent(
|
||||
JsonUtils.pojoToJson(
|
||||
toolContext.callTool(
|
||||
authorizer, jwtFilter, limits, toolName, JsonUtils.getMap(arguments))));
|
||||
response.put("result", new McpSchema.CallToolResult(List.of(content), false));
|
||||
} else {
|
||||
response.put("error", createError(-32602, "Invalid params"));
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
response.put("error", createError(-32601, "Method not found: " + method));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log("Error processing request: " + method, e);
|
||||
response.put("error", createError(-32603, "Internal error: " + e.getMessage()));
|
||||
}
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
private void processNotification(JsonNode notification, String sessionId) {
|
||||
String method = notification.get("method").asText();
|
||||
log("Received notification: " + method + " (session: " + sessionId + ")");
|
||||
|
||||
// Handle specific notifications
|
||||
switch (method) {
|
||||
case "notifications/initialized":
|
||||
LOG.info("Client initialized for session: {}", sessionId);
|
||||
break;
|
||||
case "notifications/cancelled":
|
||||
// TODO: Check this
|
||||
LOG.info("Client sent a cancellation request for session: {}", sessionId);
|
||||
removeMCPSession(sessionId);
|
||||
break;
|
||||
default:
|
||||
log("Unknown notification: " + method);
|
||||
}
|
||||
}
|
||||
|
||||
// Utility methods
|
||||
private String generateSessionId() {
|
||||
byte[] bytes = new byte[32];
|
||||
secureRandom.nextBytes(bytes);
|
||||
return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
|
||||
}
|
||||
|
||||
private boolean isValidOrigin(String origin) {
|
||||
if (null == origin || origin.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
return origin.startsWith(mcpConfiguration.getOriginHeaderUri());
|
||||
}
|
||||
|
||||
private String readRequestBody(HttpServletRequest request) throws IOException {
|
||||
StringBuilder body = new StringBuilder();
|
||||
try (BufferedReader reader = request.getReader()) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
body.append(line);
|
||||
}
|
||||
}
|
||||
return body.toString();
|
||||
}
|
||||
|
||||
private void sendError(HttpServletResponse response, int statusCode, String message)
|
||||
throws IOException {
|
||||
Map<String, Object> error = new HashMap<>();
|
||||
error.put("error", message);
|
||||
|
||||
String errorJson = objectMapper.writeValueAsString(error);
|
||||
response.setContentType(CONTENT_TYPE_JSON);
|
||||
response.setStatus(statusCode);
|
||||
response.getWriter().write(errorJson);
|
||||
}
|
||||
|
||||
private Map<String, Object> createError(int code, String message) {
|
||||
Map<String, Object> error = new HashMap<>();
|
||||
error.put("code", code);
|
||||
error.put("message", message);
|
||||
return error;
|
||||
}
|
||||
|
||||
private McpSchema.ServerCapabilities getServerCapabilities() {
|
||||
return McpSchema.ServerCapabilities.builder()
|
||||
.tools(true)
|
||||
.prompts(true)
|
||||
.resources(true, true)
|
||||
.build();
|
||||
}
|
||||
|
||||
private McpSchema.Implementation getServerInfo() {
|
||||
return new McpSchema.Implementation("OpenMetadata MCP Server - Streamable", "1.0.0");
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a server-initiated notification to a specific session
|
||||
*/
|
||||
public void sendNotificationToSession(
|
||||
String sessionId, String method, Map<String, Object> params) {
|
||||
// Find SSE connections for this session
|
||||
for (SSEConnection connection : sseConnections.values()) {
|
||||
if (sessionId.equals(connection.getSessionId())) {
|
||||
try {
|
||||
Map<String, Object> notification = new HashMap<>();
|
||||
notification.put("jsonrpc", "2.0");
|
||||
notification.put("method", method);
|
||||
if (params != null) {
|
||||
notification.put("params", params);
|
||||
}
|
||||
|
||||
connection.sendEvent(objectMapper.writeValueAsString(notification));
|
||||
} catch (IOException e) {
|
||||
log("Error sending notification to session: " + sessionId, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,33 +1,23 @@
|
||||
package org.openmetadata.service.mcp;
|
||||
|
||||
import static org.openmetadata.service.search.SearchUtil.searchMetadata;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.dropwizard.core.setup.Environment;
|
||||
import io.dropwizard.jetty.MutableServletContextHandler;
|
||||
import io.modelcontextprotocol.server.McpServerFeatures;
|
||||
import io.modelcontextprotocol.server.McpSyncServer;
|
||||
import io.modelcontextprotocol.spec.McpSchema;
|
||||
import jakarta.servlet.DispatcherType;
|
||||
import jakarta.servlet.Filter;
|
||||
import java.util.ArrayList;
|
||||
import java.util.EnumSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.eclipse.jetty.servlet.FilterHolder;
|
||||
import org.eclipse.jetty.servlet.ServletHolder;
|
||||
import org.openmetadata.common.utils.CommonUtil;
|
||||
import org.openmetadata.service.OpenMetadataApplicationConfig;
|
||||
import org.openmetadata.service.config.MCPConfiguration;
|
||||
import org.openmetadata.service.limits.Limits;
|
||||
import org.openmetadata.service.mcp.tools.GlossaryTermTool;
|
||||
import org.openmetadata.service.mcp.tools.GlossaryTool;
|
||||
import org.openmetadata.service.mcp.tools.PatchEntityTool;
|
||||
import org.openmetadata.service.security.AuthorizationException;
|
||||
import org.openmetadata.service.mcp.tools.DefaultToolContext;
|
||||
import org.openmetadata.service.security.Authorizer;
|
||||
import org.openmetadata.service.security.JwtFilter;
|
||||
import org.openmetadata.service.security.auth.CatalogSecurityContext;
|
||||
import org.openmetadata.service.util.EntityUtil;
|
||||
import org.openmetadata.service.util.JsonUtils;
|
||||
|
||||
@Slf4j
|
||||
@ -35,8 +25,11 @@ public class McpServer {
|
||||
private JwtFilter jwtFilter;
|
||||
private Authorizer authorizer;
|
||||
private Limits limits;
|
||||
protected DefaultToolContext toolContext;
|
||||
|
||||
public McpServer() {}
|
||||
public McpServer(DefaultToolContext toolContext) {
|
||||
this.toolContext = toolContext;
|
||||
}
|
||||
|
||||
public void initializeMcpServer(
|
||||
Environment environment,
|
||||
@ -47,6 +40,24 @@ public class McpServer {
|
||||
new JwtFilter(config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration());
|
||||
this.authorizer = authorizer;
|
||||
this.limits = limits;
|
||||
MutableServletContextHandler contextHandler = environment.getApplicationContext();
|
||||
McpAuthFilter authFilter =
|
||||
new McpAuthFilter(
|
||||
new JwtFilter(
|
||||
config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration()));
|
||||
List<McpSchema.Tool> tools = getTools();
|
||||
addSSETransport(contextHandler, authFilter, tools);
|
||||
addStreamableHttpServlet(config.getMcpConfiguration(), contextHandler, authFilter, tools);
|
||||
}
|
||||
|
||||
protected List<McpSchema.Tool> getTools() {
|
||||
return toolContext.loadToolsDefinitionsFromJson("json/data/mcp/tools.json");
|
||||
}
|
||||
|
||||
private void addSSETransport(
|
||||
MutableServletContextHandler contextHandler,
|
||||
McpAuthFilter authFilter,
|
||||
List<McpSchema.Tool> tools) {
|
||||
McpSchema.ServerCapabilities serverCapabilities =
|
||||
McpSchema.ServerCapabilities.builder()
|
||||
.tools(true)
|
||||
@ -54,152 +65,55 @@ public class McpServer {
|
||||
.resources(true, true)
|
||||
.build();
|
||||
|
||||
HttpServletSseServerTransportProvider transport =
|
||||
new HttpServletSseServerTransportProvider("/mcp/messages", "/mcp/sse");
|
||||
HttpServletSseServerTransportProvider sseTransport =
|
||||
new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/messages", "/mcp/sse");
|
||||
|
||||
McpSyncServer server =
|
||||
io.modelcontextprotocol.server.McpServer.sync(transport)
|
||||
.serverInfo("openmetadata-mcp", "0.1.0")
|
||||
io.modelcontextprotocol.server.McpServer.sync(sseTransport)
|
||||
.serverInfo("openmetadata-mcp-sse", "0.1.0")
|
||||
.capabilities(serverCapabilities)
|
||||
.build();
|
||||
addToolsToServer(server, tools);
|
||||
|
||||
// Add resources, prompts, and tools to the MCP server
|
||||
addTools(server);
|
||||
// SSE transport for MCP
|
||||
ServletHolder servletHolderSSE = new ServletHolder(sseTransport);
|
||||
contextHandler.addServlet(servletHolderSSE, "/mcp/*");
|
||||
|
||||
MutableServletContextHandler contextHandler = environment.getApplicationContext();
|
||||
ServletHolder servletHolder = new ServletHolder(transport);
|
||||
contextHandler.addServlet(servletHolder, "/mcp/*");
|
||||
|
||||
McpAuthFilter authFilter =
|
||||
new McpAuthFilter(
|
||||
new JwtFilter(
|
||||
config.getAuthenticationConfiguration(), config.getAuthorizerConfiguration()));
|
||||
contextHandler.addFilter(
|
||||
new FilterHolder((Filter) authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST));
|
||||
new FilterHolder(authFilter), "/mcp/*", EnumSet.of(DispatcherType.REQUEST));
|
||||
}
|
||||
|
||||
public void addTools(McpSyncServer server) {
|
||||
try {
|
||||
LOG.info("Loading tool definitions...");
|
||||
List<Map<String, Object>> cachedTools = loadToolsDefinitionsFromJson();
|
||||
if (cachedTools == null || cachedTools.isEmpty()) {
|
||||
LOG.error("No tool definitions were loaded!");
|
||||
throw new RuntimeException("Failed to load tool definitions");
|
||||
}
|
||||
LOG.info("Successfully loaded {} tool definitions", cachedTools.size());
|
||||
private void addStreamableHttpServlet(
|
||||
MCPConfiguration configuration,
|
||||
MutableServletContextHandler contextHandler,
|
||||
McpAuthFilter authFilter,
|
||||
List<McpSchema.Tool> tools) {
|
||||
// Streamable HTTP servlet for MCP
|
||||
MCPStreamableHttpServlet streamableHttpServlet =
|
||||
new MCPStreamableHttpServlet(
|
||||
configuration, jwtFilter, authorizer, limits, toolContext, tools);
|
||||
ServletHolder servletHolderStreamableHttp = new ServletHolder(streamableHttpServlet);
|
||||
contextHandler.addServlet(servletHolderStreamableHttp, "/mcp");
|
||||
|
||||
for (Map<String, Object> toolDef : cachedTools) {
|
||||
try {
|
||||
String name = (String) toolDef.get("name");
|
||||
String description = (String) toolDef.get("description");
|
||||
Map<String, Object> schema = JsonUtils.getMap(toolDef.get("parameters"));
|
||||
server.addTool(getTool(JsonUtils.pojoToJson(schema), name, description));
|
||||
} catch (Exception e) {
|
||||
LOG.error("Error processing tool definition: {}", toolDef, e);
|
||||
}
|
||||
}
|
||||
LOG.info("Initializing request handlers...");
|
||||
} catch (Exception e) {
|
||||
LOG.error("Error during server startup", e);
|
||||
throw new RuntimeException("Failed to start MCP server", e);
|
||||
contextHandler.addFilter(
|
||||
new FilterHolder(authFilter), "/mcp", EnumSet.of(DispatcherType.REQUEST));
|
||||
}
|
||||
|
||||
public void addToolsToServer(McpSyncServer server, List<McpSchema.Tool> tools) {
|
||||
for (McpSchema.Tool tool : tools) {
|
||||
server.addTool(getTool(tool));
|
||||
}
|
||||
}
|
||||
|
||||
protected List<Map<String, Object>> loadToolsDefinitionsFromJson() {
|
||||
String json = getJsonFromFile("json/data/mcp/tools.json");
|
||||
return loadToolDefinitionsFromJson(json);
|
||||
}
|
||||
|
||||
protected static String getJsonFromFile(String path) {
|
||||
try {
|
||||
return CommonUtil.getResourceAsStream(McpServer.class.getClassLoader(), path);
|
||||
} catch (Exception ex) {
|
||||
LOG.error("Error loading JSON file: {}", path, ex);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public List<Map<String, Object>> loadToolDefinitionsFromJson(String json) {
|
||||
try {
|
||||
LOG.info("Loaded tool definitions, content length: {}", json.length());
|
||||
LOG.info("Raw tools.json content: {}", json);
|
||||
|
||||
JsonNode toolsJson = JsonUtils.readTree(json);
|
||||
JsonNode toolsArray = toolsJson.get("tools");
|
||||
|
||||
if (toolsArray == null || !toolsArray.isArray()) {
|
||||
LOG.error("Invalid MCP tools file format. Expected 'tools' array.");
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
List<Map<String, Object>> tools = new ArrayList<>();
|
||||
for (JsonNode toolNode : toolsArray) {
|
||||
String name = toolNode.get("name").asText();
|
||||
Map<String, Object> toolDef = JsonUtils.convertValue(toolNode, Map.class);
|
||||
tools.add(toolDef);
|
||||
LOG.info("Tool found: {} with definition: {}", name, toolDef);
|
||||
}
|
||||
|
||||
LOG.info("Found {} tool definitions", tools.size());
|
||||
return tools;
|
||||
} catch (Exception e) {
|
||||
LOG.error("Error loading tool definitions: {}", e.getMessage(), e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private McpServerFeatures.SyncToolSpecification getTool(
|
||||
String schema, String toolName, String description) {
|
||||
McpSchema.Tool tool = new McpSchema.Tool(toolName, description, schema);
|
||||
|
||||
private McpServerFeatures.SyncToolSpecification getTool(McpSchema.Tool tool) {
|
||||
return new McpServerFeatures.SyncToolSpecification(
|
||||
tool,
|
||||
(exchange, arguments) -> {
|
||||
McpSchema.Content content =
|
||||
new McpSchema.TextContent(JsonUtils.pojoToJson(runMethod(toolName, arguments)));
|
||||
new McpSchema.TextContent(
|
||||
JsonUtils.pojoToJson(
|
||||
toolContext.callTool(authorizer, jwtFilter, limits, tool.name(), arguments)));
|
||||
return new McpSchema.CallToolResult(List.of(content), false);
|
||||
});
|
||||
}
|
||||
|
||||
protected Object runMethod(String toolName, Map<String, Object> params) {
|
||||
CatalogSecurityContext securityContext =
|
||||
jwtFilter.getCatalogSecurityContext((String) params.get("Authorization"));
|
||||
LOG.info(
|
||||
"Catalog Principal: {} is trying to call the tool: {}",
|
||||
securityContext.getUserPrincipal().getName(),
|
||||
toolName);
|
||||
Object result;
|
||||
try {
|
||||
switch (toolName) {
|
||||
case "search_metadata":
|
||||
result = searchMetadata(params);
|
||||
break;
|
||||
case "get_entity_details":
|
||||
result = EntityUtil.getEntityDetails(params);
|
||||
break;
|
||||
case "create_glossary":
|
||||
result = new GlossaryTool().execute(authorizer, limits, securityContext, params);
|
||||
break;
|
||||
case "create_glossary_term":
|
||||
result = new GlossaryTermTool().execute(authorizer, limits, securityContext, params);
|
||||
break;
|
||||
case "patch_entity":
|
||||
result = new PatchEntityTool().execute(authorizer, securityContext, params);
|
||||
break;
|
||||
default:
|
||||
result = Map.of("error", "Unknown function: " + toolName);
|
||||
break;
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (AuthorizationException ex) {
|
||||
LOG.error("Authorization error: {}", ex.getMessage());
|
||||
return Map.of(
|
||||
"error", String.format("Authorization error: %s", ex.getMessage()), "statusCode", 403);
|
||||
} catch (Exception ex) {
|
||||
LOG.error("Error executing tool: {}", ex.getMessage());
|
||||
return Map.of(
|
||||
"error", String.format("Error executing tool: %s", ex.getMessage()), "statusCode", 500);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,94 @@
|
||||
package org.openmetadata.service.mcp;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.modelcontextprotocol.spec.McpSchema;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.openmetadata.common.utils.CommonUtil;
|
||||
import org.openmetadata.service.security.JwtFilter;
|
||||
import org.openmetadata.service.util.JsonUtils;
|
||||
|
||||
@Slf4j
|
||||
public class McpUtils {
|
||||
|
||||
public static McpSchema.JSONRPCMessage getJsonRpcMessageWithAuthorizationParam(
|
||||
ObjectMapper objectMapper, HttpServletRequest request, String body) throws IOException {
|
||||
Map<String, Object> requestMessage = JsonUtils.getMap(JsonUtils.readTree(body));
|
||||
Map<String, Object> params = (Map<String, Object>) requestMessage.get("params");
|
||||
if (params != null) {
|
||||
Map<String, Object> arguments = (Map<String, Object>) params.get("arguments");
|
||||
if (arguments != null) {
|
||||
arguments.put("Authorization", JwtFilter.extractToken(request.getHeader("Authorization")));
|
||||
}
|
||||
}
|
||||
return McpSchema.deserializeJsonRpcMessage(objectMapper, JsonUtils.pojoToJson(requestMessage));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static List<Map<String, Object>> loadToolDefinitionsFromJson(String json) {
|
||||
try {
|
||||
LOG.info("Loaded tool definitions, content length: {}", json.length());
|
||||
LOG.info("Raw tools.json content: {}", json);
|
||||
|
||||
JsonNode toolsJson = JsonUtils.readTree(json);
|
||||
JsonNode toolsArray = toolsJson.get("tools");
|
||||
|
||||
if (toolsArray == null || !toolsArray.isArray()) {
|
||||
LOG.error("Invalid MCP tools file format. Expected 'tools' array.");
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
List<Map<String, Object>> tools = new ArrayList<>();
|
||||
for (JsonNode toolNode : toolsArray) {
|
||||
String name = toolNode.get("name").asText();
|
||||
Map<String, Object> toolDef = JsonUtils.convertValue(toolNode, Map.class);
|
||||
tools.add(toolDef);
|
||||
LOG.info("Tool found: {} with definition: {}", name, toolDef);
|
||||
}
|
||||
|
||||
LOG.info("Found {} tool definitions", tools.size());
|
||||
return tools;
|
||||
} catch (Exception e) {
|
||||
LOG.error("Error loading tool definitions: {}", e.getMessage(), e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
public static String getJsonFromFile(String path) {
|
||||
try {
|
||||
return CommonUtil.getResourceAsStream(McpServer.class.getClassLoader(), path);
|
||||
} catch (Exception ex) {
|
||||
LOG.error("Error loading JSON file: {}", path, ex);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static List<McpSchema.Tool> getToolProperties(String jsonFilePath) {
|
||||
try {
|
||||
List<McpSchema.Tool> result = new ArrayList<>();
|
||||
String json = getJsonFromFile(jsonFilePath);
|
||||
List<Map<String, Object>> cachedTools = loadToolDefinitionsFromJson(json);
|
||||
if (cachedTools == null || cachedTools.isEmpty()) {
|
||||
LOG.error("No tool definitions were loaded!");
|
||||
throw new RuntimeException("Failed to load tool definitions");
|
||||
}
|
||||
LOG.debug("Successfully loaded {} tool definitions", cachedTools.size());
|
||||
for (int i = 0; i < cachedTools.size(); i++) {
|
||||
Map<String, Object> toolDef = cachedTools.get(i);
|
||||
String name = (String) toolDef.get("name");
|
||||
String description = (String) toolDef.get("description");
|
||||
Map<String, Object> schema = JsonUtils.getMap(toolDef.get("parameters"));
|
||||
result.add(new McpSchema.Tool(name, description, JsonUtils.pojoToJson(schema)));
|
||||
}
|
||||
return result;
|
||||
} catch (Exception e) {
|
||||
LOG.error("Error during server startup", e);
|
||||
throw new RuntimeException("Failed to start MCP server", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,75 @@
|
||||
package org.openmetadata.service.mcp.tools;
|
||||
|
||||
import static org.openmetadata.service.mcp.McpUtils.getToolProperties;
|
||||
|
||||
import io.modelcontextprotocol.spec.McpSchema;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.openmetadata.service.limits.Limits;
|
||||
import org.openmetadata.service.security.AuthorizationException;
|
||||
import org.openmetadata.service.security.Authorizer;
|
||||
import org.openmetadata.service.security.JwtFilter;
|
||||
import org.openmetadata.service.security.auth.CatalogSecurityContext;
|
||||
|
||||
@Slf4j
|
||||
public class DefaultToolContext {
|
||||
public DefaultToolContext() {}
|
||||
|
||||
/**
|
||||
* Loads tool definitions from a JSON file located at the specified path.
|
||||
* The JSON file should contain an array of tool definitions under the "tools" key.
|
||||
*
|
||||
* @return List of McpSchema.Tool objects loaded from the JSON file.
|
||||
*/
|
||||
public List<McpSchema.Tool> loadToolsDefinitionsFromJson(String toolFilePath) {
|
||||
return getToolProperties(toolFilePath);
|
||||
}
|
||||
|
||||
public Object callTool(
|
||||
Authorizer authorizer,
|
||||
JwtFilter jwtFilter,
|
||||
Limits limits,
|
||||
String toolName,
|
||||
Map<String, Object> params) {
|
||||
CatalogSecurityContext securityContext =
|
||||
jwtFilter.getCatalogSecurityContext((String) params.get("Authorization"));
|
||||
LOG.info(
|
||||
"Catalog Principal: {} is trying to call the tool: {}",
|
||||
securityContext.getUserPrincipal().getName(),
|
||||
toolName);
|
||||
Object result;
|
||||
try {
|
||||
switch (toolName) {
|
||||
case "search_metadata":
|
||||
result = new SearchMetadataTool().execute(authorizer, securityContext, params);
|
||||
break;
|
||||
case "get_entity_details":
|
||||
result = new GetEntityTool().execute(authorizer, securityContext, params);
|
||||
break;
|
||||
case "create_glossary":
|
||||
result = new GlossaryTool().execute(authorizer, limits, securityContext, params);
|
||||
break;
|
||||
case "create_glossary_term":
|
||||
result = new GlossaryTermTool().execute(authorizer, limits, securityContext, params);
|
||||
break;
|
||||
case "patch_entity":
|
||||
result = new PatchEntityTool().execute(authorizer, limits, securityContext, params);
|
||||
break;
|
||||
default:
|
||||
result = Map.of("error", "Unknown function: " + toolName);
|
||||
break;
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (AuthorizationException ex) {
|
||||
LOG.error("Authorization error: {}", ex.getMessage());
|
||||
return Map.of(
|
||||
"error", String.format("Authorization error: %s", ex.getMessage()), "statusCode", 403);
|
||||
} catch (Exception ex) {
|
||||
LOG.error("Error executing tool: {}", ex.getMessage());
|
||||
return Map.of(
|
||||
"error", String.format("Error executing tool: %s", ex.getMessage()), "statusCode", 500);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,42 @@
|
||||
package org.openmetadata.service.mcp.tools;
|
||||
|
||||
import static org.openmetadata.schema.type.MetadataOperation.VIEW_ALL;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.openmetadata.service.Entity;
|
||||
import org.openmetadata.service.limits.Limits;
|
||||
import org.openmetadata.service.security.Authorizer;
|
||||
import org.openmetadata.service.security.auth.CatalogSecurityContext;
|
||||
import org.openmetadata.service.security.policyevaluator.OperationContext;
|
||||
import org.openmetadata.service.security.policyevaluator.ResourceContext;
|
||||
import org.openmetadata.service.util.JsonUtils;
|
||||
|
||||
@Slf4j
|
||||
public class GetEntityTool implements McpTool {
|
||||
@Override
|
||||
public Map<String, Object> execute(
|
||||
Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params)
|
||||
throws IOException {
|
||||
String entityType = (String) params.get("entity_type");
|
||||
String fqn = (String) params.get("fqn");
|
||||
authorizer.authorize(
|
||||
securityContext,
|
||||
new OperationContext(entityType, VIEW_ALL),
|
||||
new ResourceContext<>(entityType));
|
||||
LOG.info("Getting details for entity type: {}, FQN: {}", entityType, fqn);
|
||||
String fields = "*";
|
||||
return JsonUtils.getMap(Entity.getEntityByName(entityType, fqn, fields, null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> execute(
|
||||
Authorizer authorizer,
|
||||
Limits limits,
|
||||
CatalogSecurityContext securityContext,
|
||||
Map<String, Object> params)
|
||||
throws IOException {
|
||||
throw new UnsupportedOperationException("GetEntityTool does not requires limit validation.");
|
||||
}
|
||||
}
|
||||
@ -2,7 +2,6 @@ package org.openmetadata.service.mcp.tools;
|
||||
|
||||
import static org.openmetadata.common.utils.CommonUtil.nullOrEmpty;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.openmetadata.schema.entity.data.Glossary;
|
||||
@ -59,25 +58,19 @@ public class GlossaryTermTool implements McpTool {
|
||||
limits.enforceLimits(securityContext, createResourceContext, operationContext);
|
||||
authorizer.authorize(securityContext, operationContext, createResourceContext);
|
||||
|
||||
try {
|
||||
GlossaryRepository glossaryRepository =
|
||||
(GlossaryRepository) Entity.getEntityRepository(Entity.GLOSSARY);
|
||||
Glossary glossary =
|
||||
glossaryRepository.findByNameOrNull(createGlossaryTerm.getGlossary(), Include.ALL);
|
||||
GlossaryRepository glossaryRepository =
|
||||
(GlossaryRepository) Entity.getEntityRepository(Entity.GLOSSARY);
|
||||
Glossary glossary =
|
||||
glossaryRepository.findByNameOrNull(createGlossaryTerm.getGlossary(), Include.ALL);
|
||||
|
||||
GlossaryTermRepository glossaryTermRepository =
|
||||
(GlossaryTermRepository) Entity.getEntityRepository(Entity.GLOSSARY_TERM);
|
||||
// TODO: Get the updatedBy from the tool request.
|
||||
glossaryTermRepository.prepare(glossaryTerm, nullOrEmpty(glossary));
|
||||
glossaryTermRepository.setFullyQualifiedName(glossaryTerm);
|
||||
RestUtil.PutResponse<GlossaryTerm> response =
|
||||
glossaryTermRepository.createOrUpdate(
|
||||
null, glossaryTerm, securityContext.getUserPrincipal().getName());
|
||||
return JsonUtils.convertValue(response.getEntity(), Map.class);
|
||||
} catch (Exception e) {
|
||||
Map<String, Object> error = new HashMap<>();
|
||||
error.put("error", e.getMessage());
|
||||
return error;
|
||||
}
|
||||
GlossaryTermRepository glossaryTermRepository =
|
||||
(GlossaryTermRepository) Entity.getEntityRepository(Entity.GLOSSARY_TERM);
|
||||
// TODO: Get the updatedBy from the tool request.
|
||||
glossaryTermRepository.prepare(glossaryTerm, nullOrEmpty(glossary));
|
||||
glossaryTermRepository.setFullyQualifiedName(glossaryTerm);
|
||||
RestUtil.PutResponse<GlossaryTerm> response =
|
||||
glossaryTermRepository.createOrUpdate(
|
||||
null, glossaryTerm, securityContext.getUserPrincipal().getName());
|
||||
return JsonUtils.getMap(response.getEntity());
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
package org.openmetadata.service.mcp.tools;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@ -26,7 +25,7 @@ public class GlossaryTool implements McpTool {
|
||||
@Override
|
||||
public Map<String, Object> execute(
|
||||
Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params) {
|
||||
throw new UnsupportedOperationException("GlossaryTermTool requires limit validation.");
|
||||
throw new UnsupportedOperationException("GlossaryTool requires limit validation.");
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -56,21 +55,15 @@ public class GlossaryTool implements McpTool {
|
||||
limits.enforceLimits(securityContext, createResourceContext, operationContext);
|
||||
authorizer.authorize(securityContext, operationContext, createResourceContext);
|
||||
|
||||
try {
|
||||
GlossaryRepository glossaryRepository =
|
||||
(GlossaryRepository) Entity.getEntityRepository(Entity.GLOSSARY);
|
||||
GlossaryRepository glossaryRepository =
|
||||
(GlossaryRepository) Entity.getEntityRepository(Entity.GLOSSARY);
|
||||
|
||||
glossaryRepository.prepare(glossary, true);
|
||||
glossaryRepository.setFullyQualifiedName(glossary);
|
||||
RestUtil.PutResponse<Glossary> response =
|
||||
glossaryRepository.createOrUpdate(
|
||||
null, glossary, securityContext.getUserPrincipal().getName());
|
||||
return JsonUtils.convertValue(response.getEntity(), Map.class);
|
||||
} catch (Exception e) {
|
||||
Map<String, Object> error = new HashMap<>();
|
||||
error.put("error", e.getMessage());
|
||||
return error;
|
||||
}
|
||||
glossaryRepository.prepare(glossary, true);
|
||||
glossaryRepository.setFullyQualifiedName(glossary);
|
||||
RestUtil.PutResponse<Glossary> response =
|
||||
glossaryRepository.createOrUpdate(
|
||||
null, glossary, securityContext.getUserPrincipal().getName());
|
||||
return JsonUtils.convertValue(response.getEntity(), Map.class);
|
||||
}
|
||||
|
||||
public static void setReviewers(CreateGlossary entity, Map<String, Object> params) {
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
package org.openmetadata.service.mcp.tools;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import org.openmetadata.service.limits.Limits;
|
||||
import org.openmetadata.service.security.Authorizer;
|
||||
@ -7,11 +8,13 @@ import org.openmetadata.service.security.auth.CatalogSecurityContext;
|
||||
|
||||
public interface McpTool {
|
||||
Map<String, Object> execute(
|
||||
Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params);
|
||||
Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params)
|
||||
throws IOException;
|
||||
|
||||
Map<String, Object> execute(
|
||||
Authorizer authorizer,
|
||||
Limits limits,
|
||||
CatalogSecurityContext securityContext,
|
||||
Map<String, Object> params);
|
||||
Map<String, Object> params)
|
||||
throws IOException;
|
||||
}
|
||||
|
||||
@ -0,0 +1,154 @@
|
||||
package org.openmetadata.service.mcp.tools;
|
||||
|
||||
import static org.openmetadata.service.search.SearchUtil.mapEntityTypesToIndexNames;
|
||||
import static org.openmetadata.service.security.DefaultAuthorizer.getSubjectContext;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import jakarta.ws.rs.core.Response;
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.openmetadata.schema.search.SearchRequest;
|
||||
import org.openmetadata.service.Entity;
|
||||
import org.openmetadata.service.limits.Limits;
|
||||
import org.openmetadata.service.security.Authorizer;
|
||||
import org.openmetadata.service.security.auth.CatalogSecurityContext;
|
||||
import org.openmetadata.service.security.policyevaluator.SubjectContext;
|
||||
import org.openmetadata.service.util.JsonUtils;
|
||||
|
||||
@Slf4j
|
||||
public class SearchMetadataTool implements McpTool {
|
||||
|
||||
private static final List<String> IGNORE_SEARCH_KEYS =
|
||||
List.of(
|
||||
"id",
|
||||
"version",
|
||||
"updatedAt",
|
||||
"updatedBy",
|
||||
"usageSummary",
|
||||
"followers",
|
||||
"deleted",
|
||||
"votes",
|
||||
"lifeCycle",
|
||||
"sourceHash",
|
||||
"processedLineage",
|
||||
"totalVotes",
|
||||
"fqnParts",
|
||||
"service_suggest",
|
||||
"column_suggest",
|
||||
"schema_suggest",
|
||||
"database_suggest",
|
||||
"upstreamLineage",
|
||||
"entityRelationship",
|
||||
"changeSummary",
|
||||
"fqnHash");
|
||||
|
||||
@Override
|
||||
public Map<String, Object> execute(
|
||||
Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params)
|
||||
throws IOException {
|
||||
LOG.info("Executing searchMetadata with params: {}", params);
|
||||
String query = params.containsKey("query") ? (String) params.get("query") : "*";
|
||||
int limit = 10;
|
||||
if (params.containsKey("limit")) {
|
||||
Object limitObj = params.get("limit");
|
||||
if (limitObj instanceof Number) {
|
||||
limit = ((Number) limitObj).intValue();
|
||||
} else if (limitObj instanceof String) {
|
||||
limit = Integer.parseInt((String) limitObj);
|
||||
}
|
||||
}
|
||||
|
||||
boolean includeDeleted = false;
|
||||
if (params.containsKey("include_deleted")) {
|
||||
Object deletedObj = params.get("include_deleted");
|
||||
if (deletedObj instanceof Boolean) {
|
||||
includeDeleted = (Boolean) deletedObj;
|
||||
} else if (deletedObj instanceof String) {
|
||||
includeDeleted = "true".equals(deletedObj);
|
||||
}
|
||||
}
|
||||
|
||||
String entityType =
|
||||
params.containsKey("entity_type") ? (String) params.get("entity_type") : null;
|
||||
String index =
|
||||
(entityType != null && !entityType.isEmpty())
|
||||
? mapEntityTypesToIndexNames(entityType)
|
||||
: Entity.TABLE;
|
||||
|
||||
LOG.info(
|
||||
"Search query: {}, index: {}, limit: {}, includeDeleted: {}",
|
||||
query,
|
||||
index,
|
||||
limit,
|
||||
includeDeleted);
|
||||
|
||||
SearchRequest searchRequest =
|
||||
new SearchRequest()
|
||||
.withQuery(query)
|
||||
.withIndex(index)
|
||||
.withSize(limit)
|
||||
.withFrom(0)
|
||||
.withFetchSource(true)
|
||||
.withDeleted(includeDeleted);
|
||||
|
||||
SubjectContext subjectContext = getSubjectContext(securityContext);
|
||||
Response response = Entity.getSearchRepository().search(searchRequest, subjectContext);
|
||||
|
||||
Map<String, Object> searchResponse;
|
||||
if (response.getEntity() instanceof String responseStr) {
|
||||
LOG.info("Search returned string response");
|
||||
JsonNode jsonNode = JsonUtils.readTree(responseStr);
|
||||
searchResponse = JsonUtils.convertValue(jsonNode, Map.class);
|
||||
} else {
|
||||
LOG.info("Search returned object response: {}", response.getEntity().getClass().getName());
|
||||
searchResponse = JsonUtils.convertValue(response.getEntity(), Map.class);
|
||||
}
|
||||
return SearchMetadataTool.cleanSearchResponse(searchResponse);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> execute(
|
||||
Authorizer authorizer,
|
||||
Limits limits,
|
||||
CatalogSecurityContext securityContext,
|
||||
Map<String, Object> params) {
|
||||
throw new UnsupportedOperationException(
|
||||
"SearchMetadataTool does not support limits enforcement.");
|
||||
}
|
||||
|
||||
private static Map<String, Object> cleanSearchResponse(Map<String, Object> searchResponse) {
|
||||
if (searchResponse == null) return Collections.emptyMap();
|
||||
|
||||
Map<String, Object> topHits = safeGetMap(searchResponse.get("hits"));
|
||||
if (topHits == null) return Collections.emptyMap();
|
||||
|
||||
List<Object> hits = safeGetList(topHits.get("hits"));
|
||||
if (hits == null || hits.isEmpty()) return Collections.emptyMap();
|
||||
|
||||
for (Object hitObj : hits) {
|
||||
Map<String, Object> hit = safeGetMap(hitObj);
|
||||
if (hit == null) continue;
|
||||
|
||||
Map<String, Object> source = safeGetMap(hit.get("_source"));
|
||||
if (source == null) continue;
|
||||
|
||||
IGNORE_SEARCH_KEYS.forEach(source::remove);
|
||||
return source; // Return the first valid, cleaned _source
|
||||
}
|
||||
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static Map<String, Object> safeGetMap(Object obj) {
|
||||
return (obj instanceof Map) ? (Map<String, Object>) obj : null;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static List<Object> safeGetList(Object obj) {
|
||||
return (obj instanceof List) ? (List<Object>) obj : null;
|
||||
}
|
||||
}
|
||||
@ -1,44 +1,10 @@
|
||||
package org.openmetadata.service.search;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import jakarta.ws.rs.core.Response;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.openmetadata.schema.search.SearchRequest;
|
||||
import org.openmetadata.service.Entity;
|
||||
import org.openmetadata.service.util.JsonUtils;
|
||||
|
||||
@Slf4j
|
||||
public class SearchUtil {
|
||||
|
||||
private static final List<String> IGNORE_SEARCH_KEYS =
|
||||
List.of(
|
||||
"id",
|
||||
"version",
|
||||
"updatedAt",
|
||||
"updatedBy",
|
||||
"usageSummary",
|
||||
"followers",
|
||||
"deleted",
|
||||
"votes",
|
||||
"lifeCycle",
|
||||
"sourceHash",
|
||||
"processedLineage",
|
||||
"totalVotes",
|
||||
"fqnParts",
|
||||
"service_suggest",
|
||||
"column_suggest",
|
||||
"schema_suggest",
|
||||
"database_suggest",
|
||||
"upstreamLineage",
|
||||
"entityRelationship",
|
||||
"changeSummary",
|
||||
"fqnHash");
|
||||
|
||||
/**
|
||||
* Check if the index is a data asset index
|
||||
* @param indexName name of the index to check
|
||||
@ -141,105 +107,10 @@ public class SearchUtil {
|
||||
case "glossary_search_index", Entity.GLOSSARY -> Entity.GLOSSARY;
|
||||
case "domain_search_index", Entity.DOMAIN -> Entity.DOMAIN;
|
||||
case "data_product_search_index", Entity.DATA_PRODUCT -> Entity.DATA_PRODUCT;
|
||||
default -> "default";
|
||||
case "team_search_index", Entity.TEAM -> Entity.TEAM;
|
||||
case "user_Search_index", Entity.USER -> Entity.USER;
|
||||
case "dataAsset" -> "dataAsset";
|
||||
default -> "dataAsset";
|
||||
};
|
||||
}
|
||||
|
||||
public static List<Object> searchMetadata(Map<String, Object> params) {
|
||||
try {
|
||||
LOG.info("Executing searchMetadata with params: {}", params);
|
||||
String query = params.containsKey("query") ? (String) params.get("query") : "*";
|
||||
int limit = 10;
|
||||
if (params.containsKey("limit")) {
|
||||
Object limitObj = params.get("limit");
|
||||
if (limitObj instanceof Number) {
|
||||
limit = ((Number) limitObj).intValue();
|
||||
} else if (limitObj instanceof String) {
|
||||
limit = Integer.parseInt((String) limitObj);
|
||||
}
|
||||
}
|
||||
|
||||
boolean includeDeleted = false;
|
||||
if (params.containsKey("include_deleted")) {
|
||||
Object deletedObj = params.get("include_deleted");
|
||||
if (deletedObj instanceof Boolean) {
|
||||
includeDeleted = (Boolean) deletedObj;
|
||||
} else if (deletedObj instanceof String) {
|
||||
includeDeleted = "true".equals(deletedObj);
|
||||
}
|
||||
}
|
||||
|
||||
String entityType =
|
||||
params.containsKey("entity_type") ? (String) params.get("entity_type") : null;
|
||||
String index =
|
||||
(entityType != null && !entityType.isEmpty())
|
||||
? mapEntityTypesToIndexNames(entityType)
|
||||
: Entity.TABLE;
|
||||
|
||||
LOG.info(
|
||||
"Search query: {}, index: {}, limit: {}, includeDeleted: {}",
|
||||
query,
|
||||
index,
|
||||
limit,
|
||||
includeDeleted);
|
||||
|
||||
SearchRequest searchRequest =
|
||||
new SearchRequest()
|
||||
.withQuery(query)
|
||||
.withIndex(index)
|
||||
.withSize(limit)
|
||||
.withFrom(0)
|
||||
.withFetchSource(true)
|
||||
.withDeleted(includeDeleted);
|
||||
|
||||
Response response = Entity.getSearchRepository().search(searchRequest, null);
|
||||
|
||||
Map<String, Object> searchResponse;
|
||||
if (response.getEntity() instanceof String responseStr) {
|
||||
LOG.info("Search returned string response");
|
||||
JsonNode jsonNode = JsonUtils.readTree(responseStr);
|
||||
searchResponse = JsonUtils.convertValue(jsonNode, Map.class);
|
||||
} else {
|
||||
LOG.info("Search returned object response: {}", response.getEntity().getClass().getName());
|
||||
searchResponse = JsonUtils.convertValue(response.getEntity(), Map.class);
|
||||
}
|
||||
return cleanSearchResponse(searchResponse);
|
||||
} catch (Exception e) {
|
||||
LOG.error("Error in searchMetadata", e);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
public static List<Object> cleanSearchResponse(Map<String, Object> searchResponse) {
|
||||
if (searchResponse == null) return Collections.emptyList();
|
||||
|
||||
Map<String, Object> topHits = safeGetMap(searchResponse.get("hits"));
|
||||
if (topHits == null) return Collections.emptyList();
|
||||
|
||||
List<Object> hits = safeGetList(topHits.get("hits"));
|
||||
if (hits == null) return Collections.emptyList();
|
||||
|
||||
return hits.stream()
|
||||
.map(SearchUtil::safeGetMap)
|
||||
.filter(Objects::nonNull)
|
||||
.map(
|
||||
hit -> {
|
||||
Map<String, Object> source = safeGetMap(hit.get("_source"));
|
||||
if (source == null) return null;
|
||||
IGNORE_SEARCH_KEYS.forEach(source::remove);
|
||||
return source;
|
||||
})
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static Map<String, Object> safeGetMap(Object obj) {
|
||||
return (obj instanceof Map) ? (Map<String, Object>) obj : null;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static List<Object> safeGetList(Object obj) {
|
||||
return (obj instanceof List) ? (List<Object>) obj : null;
|
||||
}
|
||||
}
|
||||
|
||||
@ -73,6 +73,7 @@ public record TableIndex(Table table) implements ColumnIndex {
|
||||
doc.put("processedLineage", table.getProcessedLineage());
|
||||
doc.put("entityRelationship", SearchIndex.populateEntityRelationshipData(table));
|
||||
doc.put("databaseSchema", getEntityWithDisplayName(table.getDatabaseSchema()));
|
||||
doc.put("tableQueries", table.getTableQueries());
|
||||
doc.put(
|
||||
"changeSummary",
|
||||
Optional.ofNullable(table.getChangeDescription())
|
||||
|
||||
@ -37,7 +37,6 @@ import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
import java.util.UUID;
|
||||
@ -799,19 +798,4 @@ public final class EntityUtil {
|
||||
&& changeDescription.getFieldsUpdated().isEmpty()
|
||||
&& changeDescription.getFieldsDeleted().isEmpty();
|
||||
}
|
||||
|
||||
public static Object getEntityDetails(Map<String, Object> params) {
|
||||
try {
|
||||
String entityType = (String) params.get("entity_type");
|
||||
String fqn = (String) params.get("fqn");
|
||||
|
||||
LOG.info("Getting details for entity type: {}, FQN: {}", entityType, fqn);
|
||||
String fields = "*";
|
||||
Object entity = Entity.getEntityByName(entityType, fqn, fields, null);
|
||||
return entity;
|
||||
} catch (Exception e) {
|
||||
LOG.error("Error getting entity details", e);
|
||||
return Map.of("error", e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
"tools": [
|
||||
{
|
||||
"name": "search_metadata",
|
||||
"description": "Find your data and business terms in OpenMetadata. For example if the user asks to 'find tables that contain customers information', then 'customers' should be the query, and the entity_type should be 'table'.",
|
||||
"description": "Find your data and business terms in OpenMetadata. For example if the user asks to 'find tables that contain customers information', then 'customers' should be the query, and the entity_type should be 'table'. Here make sure to use 'Href' is available in result to create a hyperlink to the entity in OpenMetadata.",
|
||||
"parameters": {
|
||||
"description": "The search query to find metadata in the OpenMetadata catalog, entity type could be table, topic etc. Limit can be used to paginate on the data.",
|
||||
"type": "object",
|
||||
|
||||
@ -1162,6 +1162,14 @@
|
||||
"description": "Processed lineage for the table",
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"tableQueries": {
|
||||
"description": "List of queries that are used to create this table.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "../../type/basic.json#/definitions/sqlQuery"
|
||||
},
|
||||
"default": null
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
|
||||
@ -161,7 +161,11 @@ export interface Table {
|
||||
* Table Profiler Config to include or exclude columns from profiling.
|
||||
*/
|
||||
tableProfilerConfig?: TableProfilerConfig;
|
||||
tableType?: TableType;
|
||||
/**
|
||||
* List of queries that are used to create this table.
|
||||
*/
|
||||
tableQueries?: string[];
|
||||
tableType?: TableType;
|
||||
/**
|
||||
* Tags for this table.
|
||||
*/
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user