diff --git a/frontend/rust-lib/flowy-document/src/core/edit/queue.rs b/frontend/rust-lib/flowy-document/src/core/edit/queue.rs index c22d6ed29f..33087ee1c4 100644 --- a/frontend/rust-lib/flowy-document/src/core/edit/queue.rs +++ b/frontend/rust-lib/flowy-document/src/core/edit/queue.rs @@ -15,7 +15,7 @@ use lib_ot::{ core::{Interval, OperationTransformable}, rich_text::{RichTextAttribute, RichTextDelta}, }; -use std::{cell::Cell, sync::Arc}; +use std::sync::Arc; use tokio::sync::{oneshot, RwLock}; // The EditorCommandQueue executes each command that will alter the document in diff --git a/frontend/rust-lib/flowy-net/src/ws/local/local_ws.rs b/frontend/rust-lib/flowy-net/src/ws/local/local_ws.rs index 9608de4b03..e0578e13df 100644 --- a/frontend/rust-lib/flowy-net/src/ws/local/local_ws.rs +++ b/frontend/rust-lib/flowy-net/src/ws/local/local_ws.rs @@ -15,9 +15,13 @@ use tokio::sync::{broadcast, broadcast::Receiver, mpsc, mpsc::UnboundedReceiver} pub struct LocalWebSocket { receivers: Arc>>, state_sender: broadcast::Sender, + // LocalWSSender uses the mpsc::channel sender to simulate the web socket. It spawns a receiver that uses the + // LocalDocumentServer to handle the message. The server will send the WebSocketRawMessage messages that will + // be handled by the WebSocketRawMessage receivers. ws_sender: LocalWSSender, - server: Arc, - server_rx: RwLock>>, + local_server: Arc, + local_server_rx: RwLock>>, + local_server_stop_tx: RwLock>>, user_id: Arc>>, } @@ -28,67 +32,91 @@ impl std::default::Default for LocalWebSocket { let receivers = Arc::new(DashMap::new()); let (server_tx, server_rx) = mpsc::unbounded_channel(); - let server = Arc::new(LocalDocumentServer::new(server_tx)); - let server_rx = RwLock::new(Some(server_rx)); - let user_token = Arc::new(RwLock::new(None)); + let local_server = Arc::new(LocalDocumentServer::new(server_tx)); + let local_server_rx = RwLock::new(Some(server_rx)); + let local_server_stop_tx = RwLock::new(None); + let user_id = Arc::new(RwLock::new(None)); LocalWebSocket { receivers, state_sender, ws_sender, - server, - server_rx, - user_id: user_token, + local_server, + local_server_rx, + local_server_stop_tx, + user_id, } } } impl LocalWebSocket { - fn spawn_client(&self, _addr: String) { + fn restart_ws_receiver(&self) -> mpsc::Receiver<()> { + if let Some(stop_tx) = self.local_server_stop_tx.read().clone() { + tokio::spawn(async move { + let _ = stop_tx.send(()).await; + }); + } + let (stop_tx, stop_rx) = mpsc::channel::<()>(1); + *self.local_server_stop_tx.write() = Some(stop_tx); + stop_rx + } + + fn spawn_client_ws_receiver(&self, _addr: String) { let mut ws_receiver = self.ws_sender.subscribe(); - let local_server = self.server.clone(); + let local_server = self.local_server.clone(); let user_id = self.user_id.clone(); + let mut stop_rx = self.restart_ws_receiver(); tokio::spawn(async move { loop { - // Polling the web socket message sent by user - match ws_receiver.recv().await { - Ok(message) => { - let user_id = user_id.read().clone(); - if user_id.is_none() { - continue; - } - let user_id = user_id.unwrap(); - let server = local_server.clone(); - let fut = || async move { - let bytes = Bytes::from(message.data); - let client_data = DocumentClientWSData::try_from(bytes).map_err(internal_error)?; - let _ = server - .handle_client_data(client_data, user_id) - .await - .map_err(internal_error)?; - Ok::<(), FlowyError>(()) - }; - match fut().await { - Ok(_) => {}, - Err(e) => tracing::error!("[LocalWebSocket] error: {:?}", e), + tokio::select! { + result = ws_receiver.recv() => { + match result { + Ok(message) => { + let user_id = user_id.read().clone(); + handle_ws_raw_message(user_id, &local_server, message).await; + }, + Err(e) => tracing::error!("[LocalWebSocket] error: {}", e), } + } + _ = stop_rx.recv() => { + break }, - Err(e) => tracing::error!("[LocalWebSocket] error: {}", e), } } }); } } +async fn handle_ws_raw_message( + user_id: Option, + local_server: &Arc, + message: WebSocketRawMessage, +) { + let f = || async { + match user_id { + None => Ok(()), + Some(user_id) => { + let bytes = Bytes::from(message.data); + let client_data = DocumentClientWSData::try_from(bytes).map_err(internal_error)?; + let _ = local_server.handle_client_data(client_data, user_id).await?; + Ok::<(), FlowyError>(()) + }, + } + }; + if let Err(e) = f().await { + tracing::error!("[LocalWebSocket] error: {:?}", e); + } +} + impl FlowyRawWebSocket for LocalWebSocket { fn initialize(&self) -> FutureResult<(), FlowyError> { - let mut server_rx = self.server_rx.write().take().expect("Only take once"); + let mut server_rx = self.local_server_rx.write().take().expect("Only take once"); let receivers = self.receivers.clone(); tokio::spawn(async move { while let Some(message) = server_rx.recv().await { match receivers.get(&message.module) { None => tracing::error!("Can't find any handler for message: {:?}", message), - Some(handler) => handler.receive_message(message.clone()), + Some(receiver) => receiver.receive_message(message.clone()), } } }); @@ -97,7 +125,7 @@ impl FlowyRawWebSocket for LocalWebSocket { fn start_connect(&self, addr: String, user_id: String) -> FutureResult<(), FlowyError> { *self.user_id.write() = Some(user_id); - self.spawn_client(addr); + self.spawn_client_ws_receiver(addr); FutureResult::new(async { Ok(()) }) }