diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/CatalogApplication.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/CatalogApplication.java index a80b48ae193..371f17194aa 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/CatalogApplication.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/CatalogApplication.java @@ -90,8 +90,6 @@ import org.openmetadata.catalog.socket.WebSocketManager; public class CatalogApplication extends Application { private Authorizer authorizer; - private SocketAddressFilter socketAddressFilter = null; - @Override public void run(CatalogApplicationConfig catalogConfig, Environment environment) throws ClassNotFoundException, IllegalAccessException, InstantiationException, NoSuchMethodException, @@ -163,7 +161,7 @@ public class CatalogApplication extends Application { FilterRegistration.Dynamic micrometerFilter = environment.servlets().addFilter("MicrometerHttpFilter", new MicrometerHttpFilter()); micrometerFilter.addMappingForUrlPatterns(EnumSet.allOf(DispatcherType.class), true, "/*"); - intializeWebsockets(environment); + intializeWebsockets(catalogConfig, environment); } @SneakyThrows @@ -216,9 +214,6 @@ public class CatalogApplication extends Application { AuthenticationConfiguration authenticationConfiguration = catalogConfig.getAuthenticationConfiguration(); // to authenticate request while opening websocket connections if (authorizerConf != null) { - if (authorizerConf.getEnableSecureSocketConnection()) { - socketAddressFilter = new SocketAddressFilter(authenticationConfiguration, authorizerConf); - } authorizer = Class.forName(authorizerConf.getClassName()).asSubclass(Authorizer.class).getConstructor().newInstance(); String filterClazzName = authorizerConf.getContainerRequestFilter(); @@ -278,15 +273,23 @@ public class CatalogApplication extends Application { environment.getApplicationContext().setErrorHandler(eph); } - private void intializeWebsockets(Environment environment) { + private void intializeWebsockets(CatalogApplicationConfig catalogConfig, Environment environment) { + SocketAddressFilter socketAddressFilter; + if (catalogConfig.getAuthorizerConfiguration() != null) { + socketAddressFilter = + new SocketAddressFilter( + catalogConfig.getAuthenticationConfiguration(), catalogConfig.getAuthorizerConfiguration()); + } else { + socketAddressFilter = new SocketAddressFilter(); + } + EngineIoServerOptions eioOptions = EngineIoServerOptions.newFromDefault(); eioOptions.setAllowedCorsOrigins(null); WebSocketManager.WebSocketManagerBuilder.build(eioOptions); environment.getApplicationContext().setContextPath("/"); - if (socketAddressFilter != null) - environment - .getApplicationContext() - .addFilter(new FilterHolder(socketAddressFilter), "/api/v1/push/feed/*", EnumSet.of(DispatcherType.REQUEST)); + environment + .getApplicationContext() + .addFilter(new FilterHolder(socketAddressFilter), "/api/v1/push/feed/*", EnumSet.of(DispatcherType.REQUEST)); environment.getApplicationContext().addServlet(new ServletHolder(new FeedServlet()), "/api/v1/push/feed/*"); // Upgrade connection to websocket from Http try { @@ -297,7 +300,7 @@ public class CatalogApplication extends Application { (servletUpgradeRequest, servletUpgradeResponse) -> new JettyWebSocketHandler(WebSocketManager.getInstance().getEngineIoServer())); } catch (ServletException ex) { - LOG.error("Websocket Upgrade Filter error : ", ex.getMessage()); + LOG.error("Websocket Upgrade Filter error : " + ex.getMessage()); } } diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/events/ChangeEventHandler.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/events/ChangeEventHandler.java index 254c09da0cb..9076b776d18 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/events/ChangeEventHandler.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/events/ChangeEventHandler.java @@ -13,11 +13,13 @@ package org.openmetadata.catalog.events; +import static org.openmetadata.catalog.Entity.TEAM; import static org.openmetadata.catalog.type.EventType.ENTITY_DELETED; import static org.openmetadata.catalog.type.EventType.ENTITY_SOFT_DELETED; import static org.openmetadata.catalog.type.EventType.ENTITY_UPDATED; import static org.openmetadata.common.utils.CommonUtil.listOrEmpty; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import java.util.ArrayList; import java.util.Collections; @@ -38,10 +40,7 @@ import org.openmetadata.catalog.jdbi3.CollectionDAO; import org.openmetadata.catalog.jdbi3.FeedRepository; import org.openmetadata.catalog.resources.feeds.MessageParser.EntityLink; import org.openmetadata.catalog.socket.WebSocketManager; -import org.openmetadata.catalog.type.ChangeDescription; -import org.openmetadata.catalog.type.ChangeEvent; -import org.openmetadata.catalog.type.EntityReference; -import org.openmetadata.catalog.type.EventType; +import org.openmetadata.catalog.type.*; import org.openmetadata.catalog.util.ChangeEventParser; import org.openmetadata.catalog.util.JsonUtils; import org.openmetadata.catalog.util.RestUtil; @@ -63,6 +62,7 @@ public class ChangeEventHandler implements EventHandler { SecurityContext securityContext = requestContext.getSecurityContext(); String loggedInUserName = securityContext.getUserPrincipal().getName(); try { + handleWebSocket(responseContext); ChangeEvent changeEvent = getChangeEvent(method, responseContext); if (changeEvent == null) { return null; @@ -104,8 +104,8 @@ public class ChangeEventHandler implements EventHandler { } EntityLink about = EntityLink.parse(thread.getAbout()); feedDao.create(thread, entity.getId(), owner, about); - String json = mapper.writeValueAsString(thread); - WebSocketManager.getInstance().broadCastMessageToClients(json); + String jsonThread = mapper.writeValueAsString(thread); + WebSocketManager.getInstance().broadCastMessageToAll(WebSocketManager.feedBroadcastChannel, jsonThread); } } } @@ -115,6 +115,43 @@ public class ChangeEventHandler implements EventHandler { return null; } + private void handleWebSocket(ContainerResponseContext responseContext) { + int responseCode = responseContext.getStatus(); + if (responseCode == Status.CREATED.getStatusCode() && responseContext.getEntity().getClass().equals(Thread.class)) { + Thread thread = (Thread) responseContext.getEntity(); + try { + String jsonThread = mapper.writeValueAsString(thread); + switch (thread.getType()) { + case Task: + List assignees = thread.getTask().getAssignees(); + assignees.forEach( + (e) -> { + if (Entity.USER.equals(e.getType())) { + WebSocketManager.getInstance() + .sendToOne(e.getId(), WebSocketManager.taskBroadcastChannel, jsonThread); + } else if (Entity.TEAM.equals(e.getType())) { + // fetch all that are there in the team + List userIds = + dao.relationshipDAO() + .findTo(e.getId().toString(), TEAM, Relationship.HAS.ordinal(), Entity.USER); + WebSocketManager.getInstance() + .sendToManyWithString(userIds, WebSocketManager.taskBroadcastChannel, jsonThread); + } + }); + return; + case Conversation: + WebSocketManager.getInstance().broadCastMessageToAll(WebSocketManager.feedBroadcastChannel, jsonThread); + return; + case Announcement: + default: + return; + } + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + public static ChangeEvent getChangeEvent(String method, ContainerResponseContext responseContext) { // GET operations don't produce change events if (method.equals("GET")) { diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/TeamRepository.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/TeamRepository.java index 9b9de7e35b7..e74b7524cec 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/TeamRepository.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/jdbi3/TeamRepository.java @@ -113,6 +113,11 @@ public class TeamRepository extends EntityRepository { return EntityUtil.populateEntityReferences(userIds, Entity.USER); } + private List getUsers(UUID teamID) throws IOException { + List userIds = findTo(teamID, TEAM, Relationship.HAS, Entity.USER); + return EntityUtil.populateEntityReferences(userIds, Entity.USER); + } + private List getOwns(Team team) throws IOException { // Compile entities owned by the team return EntityUtil.getEntityReferences( diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/socket/SocketAddressFilter.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/socket/SocketAddressFilter.java index 352d4dfe2d1..117e3c36ced 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/socket/SocketAddressFilter.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/socket/SocketAddressFilter.java @@ -15,6 +15,7 @@ package org.openmetadata.catalog.socket; import com.auth0.jwt.interfaces.Claim; import com.auth0.jwt.interfaces.DecodedJWT; +import io.socket.engineio.server.utils.ParseQS; import java.io.IOException; import java.util.Map; import java.util.TreeMap; @@ -35,9 +36,18 @@ public class SocketAddressFilter implements Filter { private static final Logger LOG = LoggerFactory.getLogger(SocketAddressFilter.class); private JwtFilter jwtFilter; + private final boolean enableSecureSocketConnection; + public SocketAddressFilter( AuthenticationConfiguration authenticationConfiguration, AuthorizerConfiguration authorizerConf) { - jwtFilter = new JwtFilter(authenticationConfiguration, authorizerConf); + enableSecureSocketConnection = authorizerConf.getEnableSecureSocketConnection(); + if (enableSecureSocketConnection) { + jwtFilter = new JwtFilter(authenticationConfiguration, authorizerConf); + } + } + + public SocketAddressFilter() { + enableSecureSocketConnection = false; } @Override @@ -47,19 +57,23 @@ public class SocketAddressFilter implements Filter { public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { HttpServletRequest httpServletRequest = (HttpServletRequest) request; - String tokenWithType = httpServletRequest.getHeader("Authorization"); + Map query = ParseQS.decode(httpServletRequest.getQueryString()); HeaderRequestWrapper requestWrapper = new HeaderRequestWrapper(httpServletRequest); - requestWrapper.addHeader("RemoteAddress", request.getRemoteAddr()); - requestWrapper.addHeader("Authorization", tokenWithType); + requestWrapper.addHeader("RemoteAddress", httpServletRequest.getRemoteAddr()); + requestWrapper.addHeader("UserId", query.get("userId")); - String token = JwtFilter.extractToken(tokenWithType); - // validate token - DecodedJWT jwt = jwtFilter.validateAndReturnDecodedJwtToken(token); - // validate Domain and Username - Map claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); - claims.putAll(jwt.getClaims()); - jwtFilter.validateAndReturnUsername(claims); + if (enableSecureSocketConnection) { + String tokenWithType = httpServletRequest.getHeader("Authorization"); + requestWrapper.addHeader("Authorization", tokenWithType); + String token = JwtFilter.extractToken(tokenWithType); + // validate token + DecodedJWT jwt = jwtFilter.validateAndReturnDecodedJwtToken(token); + // validate Domain and Username + Map claims = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + claims.putAll(jwt.getClaims()); + jwtFilter.validateAndReturnUsername(claims); + } // Goes to default servlet. chain.doFilter(requestWrapper, response); } diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/socket/WebSocketManager.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/socket/WebSocketManager.java index 4dd260e6e9d..f1c0a070369 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/socket/WebSocketManager.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/socket/WebSocketManager.java @@ -5,7 +5,9 @@ import io.socket.engineio.server.EngineIoServerOptions; import io.socket.socketio.server.SocketIoNamespace; import io.socket.socketio.server.SocketIoServer; import io.socket.socketio.server.SocketIoSocket; +import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -15,8 +17,10 @@ public class WebSocketManager { private static WebSocketManager INSTANCE; private final EngineIoServer mEngineIoServer; private final SocketIoServer mSocketIoServer; - private final String feedBroadcastChannel = "activityFeed"; - private final Map activityFeedEndpoints = new ConcurrentHashMap<>(); + public static final String feedBroadcastChannel = "activityFeed"; + public static final String taskBroadcastChannel = "taskChannel"; + + private final Map activityFeedEndpoints = new ConcurrentHashMap<>(); private WebSocketManager(EngineIoServerOptions eiOptions) { mEngineIoServer = new EngineIoServer(eiOptions); @@ -31,13 +35,30 @@ public class WebSocketManager { "connection", args -> { SocketIoSocket socket = (SocketIoSocket) args[0]; - LOG.info( - "Client :" - + socket.getId() - + "with Remote Address :" - + socket.getInitialHeaders().get("RemoteAddress") - + "connected."); - activityFeedEndpoints.put(socket.getId(), socket); + String userId = socket.getInitialHeaders().get("UserId").get(0); + if (userId != null && !userId.equals("")) { + LOG.info( + "Client :" + + userId + + "with Remote Address :" + + socket.getInitialHeaders().get("RemoteAddress") + + "connected." + + socket.getInitialQuery()); + + // On Socket Disconnect + socket.on( + "disconnect", + args1 -> { + LOG.info( + "Client :" + + userId + + "with Remote Address :" + + socket.getInitialHeaders().get("RemoteAddress") + + " disconnected."); + activityFeedEndpoints.remove(UUID.fromString(userId)); + }); + activityFeedEndpoints.put(UUID.fromString(userId), socket); + } }); } @@ -53,16 +74,39 @@ public class WebSocketManager { return mEngineIoServer; } - public Map getActivityFeedEndpoints() { + public Map getActivityFeedEndpoints() { return activityFeedEndpoints; } - public void broadCastMessageToClients(String message) { - for (Map.Entry endpoints : activityFeedEndpoints.entrySet()) { - endpoints.getValue().send(feedBroadcastChannel, message); + public void broadCastMessageToAll(String event, String message) { + activityFeedEndpoints.forEach((key, value) -> value.send(event, message)); + } + + public void sendToOne(UUID receiver, String event, String message) { + if (activityFeedEndpoints.containsKey(receiver)) { + activityFeedEndpoints.get(receiver).send(event, message); } } + public void sendToManyWithUUID(List receivers, String event, String message) { + receivers.forEach( + (e) -> { + if (activityFeedEndpoints.containsKey(e)) { + activityFeedEndpoints.get(e).send(event, message); + } + }); + } + + public void sendToManyWithString(List receivers, String event, String message) { + receivers.forEach( + (e) -> { + UUID key = UUID.fromString(e); + if (activityFeedEndpoints.containsKey(key)) { + activityFeedEndpoints.get(key).send(event, message); + } + }); + } + public static class WebSocketManagerBuilder { public static void build(EngineIoServerOptions eiOptions) { INSTANCE = new WebSocketManager(eiOptions);