diff --git a/catalog-rest-service/src/main/java/org/openmetadata/catalog/socket/WebSocketManager.java b/catalog-rest-service/src/main/java/org/openmetadata/catalog/socket/WebSocketManager.java index f1c0a070369..61c4a35c11d 100644 --- a/catalog-rest-service/src/main/java/org/openmetadata/catalog/socket/WebSocketManager.java +++ b/catalog-rest-service/src/main/java/org/openmetadata/catalog/socket/WebSocketManager.java @@ -5,9 +5,7 @@ import io.socket.engineio.server.EngineIoServerOptions; import io.socket.socketio.server.SocketIoNamespace; import io.socket.socketio.server.SocketIoServer; import io.socket.socketio.server.SocketIoSocket; -import java.util.List; -import java.util.Map; -import java.util.UUID; +import java.util.*; import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -19,8 +17,7 @@ public class WebSocketManager { private final SocketIoServer mSocketIoServer; public static final String feedBroadcastChannel = "activityFeed"; public static final String taskBroadcastChannel = "taskChannel"; - - private final Map activityFeedEndpoints = new ConcurrentHashMap<>(); + private final Map> activityFeedEndpoints = new ConcurrentHashMap<>(); private WebSocketManager(EngineIoServerOptions eiOptions) { mEngineIoServer = new EngineIoServer(eiOptions); @@ -50,14 +47,22 @@ public class WebSocketManager { "disconnect", args1 -> { LOG.info( - "Client :" + "Client from:" + userId + "with Remote Address :" + socket.getInitialHeaders().get("RemoteAddress") + " disconnected."); - activityFeedEndpoints.remove(UUID.fromString(userId)); + UUID id = UUID.fromString(userId); + Map allUserConnection = activityFeedEndpoints.get(id); + allUserConnection.remove(socket.getId()); + activityFeedEndpoints.put(id, allUserConnection); }); - activityFeedEndpoints.put(UUID.fromString(userId), socket); + UUID id = UUID.fromString(userId); + Map userSocketConnections; + userSocketConnections = + activityFeedEndpoints.containsKey(id) ? activityFeedEndpoints.get(id) : new HashMap<>(); + userSocketConnections.put(socket.getId(), socket); + activityFeedEndpoints.put(id, userSocketConnections); } }); } @@ -74,26 +79,27 @@ public class WebSocketManager { return mEngineIoServer; } - public Map getActivityFeedEndpoints() { + public Map> getActivityFeedEndpoints() { return activityFeedEndpoints; } public void broadCastMessageToAll(String event, String message) { - activityFeedEndpoints.forEach((key, value) -> value.send(event, message)); + activityFeedEndpoints.forEach( + (key, value) -> { + value.forEach((key1, value1) -> value1.send(event, message)); + }); } public void sendToOne(UUID receiver, String event, String message) { if (activityFeedEndpoints.containsKey(receiver)) { - activityFeedEndpoints.get(receiver).send(event, message); + activityFeedEndpoints.get(receiver).forEach((key, value) -> value.send(event, message)); } } public void sendToManyWithUUID(List receivers, String event, String message) { receivers.forEach( (e) -> { - if (activityFeedEndpoints.containsKey(e)) { - activityFeedEndpoints.get(e).send(event, message); - } + sendToOne(e, event, message); }); } @@ -101,9 +107,7 @@ public class WebSocketManager { receivers.forEach( (e) -> { UUID key = UUID.fromString(e); - if (activityFeedEndpoints.containsKey(key)) { - activityFeedEndpoints.get(key).send(event, message); - } + sendToOne(key, event, message); }); }