280 lines
9.1 KiB
Rust
Raw Normal View History

2021-09-19 23:21:10 +08:00
use crate::{
2021-09-30 17:24:02 +08:00
connect::{WsConnectionFuture, WsStream},
2021-09-19 23:21:10 +08:00
errors::WsError,
WsMessage,
WsModule,
2021-09-19 23:21:10 +08:00
};
2021-09-24 16:02:17 +08:00
use bytes::Bytes;
2021-09-30 17:24:02 +08:00
use dashmap::DashMap;
use flowy_net::errors::{internal_error, ServerError};
2021-09-16 18:31:25 +08:00
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
2021-09-19 23:21:10 +08:00
use futures_core::{future::BoxFuture, ready, Stream};
2021-09-30 17:24:02 +08:00
use parking_lot::RwLock;
2021-09-16 18:31:25 +08:00
use pin_project::pin_project;
use std::{
2021-09-18 22:32:00 +08:00
collections::HashMap,
convert::TryFrom,
2021-09-16 18:31:25 +08:00
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
2021-09-30 17:24:02 +08:00
use tokio::{
sync::{broadcast, oneshot},
task::JoinHandle,
};
2021-09-19 23:21:10 +08:00
use tokio_tungstenite::tungstenite::{
protocol::{frame::coding::CloseCode, CloseFrame},
Message,
2021-09-19 18:39:56 +08:00
};
2021-09-16 23:07:15 +08:00
2021-09-16 18:31:25 +08:00
pub type MsgReceiver = UnboundedReceiver<Message>;
pub type MsgSender = UnboundedSender<Message>;
2021-09-30 17:24:02 +08:00
type Handlers = DashMap<WsModule, Arc<dyn WsMessageHandler>>;
2021-09-16 18:31:25 +08:00
pub trait WsMessageHandler: Sync + Send + 'static {
fn source(&self) -> WsModule;
2021-09-18 22:32:00 +08:00
fn receive_message(&self, msg: WsMessage);
2021-09-16 18:31:25 +08:00
}
2021-09-19 18:39:56 +08:00
/// Observer invoked with each new [`WsState`].
type NotifyCallback = Arc<dyn Fn(&WsState) + Send + Sync + 'static>;

/// Remembers the latest [`WsState`] and forwards every change to an optional callback.
///
/// NOTE(review): nothing in this file constructs this type; it may be dead code.
struct WsStateNotify {
    #[allow(dead_code)]
    state: WsState,
    callback: Option<NotifyCallback>,
}

impl WsStateNotify {
    /// Runs the callback (when one is set) with `state`, then records `state` as current.
    fn update_state(&mut self, state: WsState) {
        self.callback.iter().for_each(|notify| notify(&state));
        self.state = state;
    }
}
2021-09-30 17:24:02 +08:00
#[derive(Clone)]
2021-09-19 18:39:56 +08:00
pub enum WsState {
Init,
Connected(Arc<WsSender>),
Disconnected(WsError),
}
2021-09-16 18:31:25 +08:00
pub struct WsController {
2021-09-30 17:24:02 +08:00
handlers: Handlers,
state_notify: Arc<broadcast::Sender<WsState>>,
sender: RwLock<Option<Arc<WsSender>>>,
2021-09-16 18:31:25 +08:00
}
impl WsController {
pub fn new() -> Self {
2021-09-30 17:24:02 +08:00
let (state_notify, _) = broadcast::channel(16);
2021-09-16 23:07:15 +08:00
let controller = Self {
2021-09-30 17:24:02 +08:00
handlers: DashMap::new(),
sender: RwLock::new(None),
state_notify: Arc::new(state_notify),
2021-09-16 23:07:15 +08:00
};
controller
2021-09-16 18:31:25 +08:00
}
2021-09-30 17:24:02 +08:00
pub fn add_handler(&self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
2021-09-18 22:32:00 +08:00
let source = handler.source();
if self.handlers.contains_key(&source) {
2021-09-23 17:50:28 +08:00
log::error!("WsSource's {:?} is already registered", source);
2021-09-18 22:32:00 +08:00
}
self.handlers.insert(source, handler);
2021-09-17 19:03:46 +08:00
Ok(())
2021-09-16 23:07:15 +08:00
}
2021-09-30 17:24:02 +08:00
pub async fn connect(&self, addr: String) -> Result<(), ServerError> {
let (ret, rx) = oneshot::channel::<Result<(), ServerError>>();
self._connect(addr.clone(), ret);
rx.await?
2021-09-19 16:11:02 +08:00
}
2021-09-30 17:24:02 +08:00
#[allow(dead_code)]
pub fn state_subscribe(&self) -> broadcast::Receiver<WsState> { self.state_notify.subscribe() }
pub fn sender(&self) -> Result<Arc<WsSender>, WsError> {
match &*self.sender.read() {
2021-09-27 23:23:23 +08:00
None => Err(WsError::internal().context("WsSender is not initialized, should call connect first")),
2021-09-20 15:38:55 +08:00
Some(sender) => Ok(sender.clone()),
}
}
2021-09-30 17:24:02 +08:00
fn _connect(&self, addr: String, ret: oneshot::Sender<Result<(), ServerError>>) {
log::debug!("🐴 ws connect: {}", &addr);
2021-09-19 16:11:02 +08:00
let (connection, handlers) = self.make_connect(addr.clone());
2021-09-19 18:39:56 +08:00
let state_notify = self.state_notify.clone();
2021-09-27 23:23:23 +08:00
let sender = self
.sender
2021-09-30 17:24:02 +08:00
.read()
2021-09-27 23:23:23 +08:00
.clone()
.expect("Sender should be not empty after calling make_connect");
2021-09-30 17:24:02 +08:00
tokio::spawn(async move {
2021-09-19 16:11:02 +08:00
match connection.await {
Ok(stream) => {
2021-09-30 17:24:02 +08:00
state_notify.send(WsState::Connected(sender));
ret.send(Ok(()));
spawn_steam_and_handlers(stream, handlers, state_notify).await;
2021-09-19 16:11:02 +08:00
},
2021-09-19 18:39:56 +08:00
Err(e) => {
2021-09-30 17:24:02 +08:00
state_notify.send(WsState::Disconnected(e.clone()));
ret.send(Err(ServerError::internal().context(e)));
2021-09-18 22:32:00 +08:00
},
2021-09-19 16:11:02 +08:00
}
2021-09-30 17:24:02 +08:00
});
2021-09-18 22:32:00 +08:00
}
2021-09-30 17:24:02 +08:00
fn make_connect(&self, addr: String) -> (WsConnectionFuture, WsHandlerFuture) {
2021-09-16 18:31:25 +08:00
// Stream User
// ┌───────────────┐ ┌──────────────┐
// ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
// │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │ │
// └──────┘ │ └─────────┘ │ └────────┘ │ └────────┘ │
// ▲ │ │ │ │
// │ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
// └─────────┼──│ws_write │◀─┼────│ ws_rx │◀──┼──│ ws_tx │ │
// │ └─────────┘ │ └────────┘ │ └────────┘ │
// └───────────────┘ └──────────────┘
let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
2021-09-16 23:07:15 +08:00
let handlers = self.handlers.clone();
2021-09-30 17:24:02 +08:00
*self.sender.write() = Some(Arc::new(WsSender { ws_tx }));
2021-09-27 23:23:23 +08:00
(
WsConnectionFuture::new(msg_tx, ws_rx, addr),
WsHandlerFuture::new(handlers, msg_rx),
)
2021-09-16 18:31:25 +08:00
}
}
2021-09-30 17:24:02 +08:00
/// Drives an established connection until either the socket stream or the handler
/// dispatch finishes.
///
/// `tokio::select!` drops whichever future has not completed. If the stream ends with an
/// error, the error is logged and broadcast as [`WsState::Disconnected`].
async fn spawn_steam_and_handlers(
    stream: WsStream,
    handlers: WsHandlerFuture,
    state_notify: Arc<broadcast::Sender<WsState>>,
) {
    tokio::select! {
        result = stream => {
            if let Err(e) = result {
                // TODO: retry?
                log::error!("ws stream error {:?}", e);
                // `send` only fails when there are no subscribers, which is fine here.
                let _ = state_notify.send(WsState::Disconnected(e));
            }
        },
        result = handlers => log::debug!("handlers completed {:?}", result),
    };
}
2021-09-16 18:31:25 +08:00
#[pin_project]
2021-09-20 15:38:55 +08:00
pub struct WsHandlerFuture {
2021-09-16 18:31:25 +08:00
#[pin]
msg_rx: MsgReceiver,
2021-09-30 17:24:02 +08:00
// Opti: Hashmap would be better
handlers: Handlers,
2021-09-16 18:31:25 +08:00
}
2021-09-20 15:38:55 +08:00
impl WsHandlerFuture {
2021-09-30 17:24:02 +08:00
fn new(handlers: Handlers, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
2021-09-24 16:02:17 +08:00
fn handler_ws_message(&self, message: Message) {
match message {
Message::Binary(bytes) => self.handle_binary_message(bytes),
_ => {},
}
}
fn handle_binary_message(&self, bytes: Vec<u8>) {
let bytes = Bytes::from(bytes);
match WsMessage::try_from(bytes) {
Ok(message) => match self.handlers.get(&message.module) {
None => log::error!("Can't find any handler for message: {:?}", message),
Some(handler) => handler.receive_message(message.clone()),
},
Err(e) => {
log::error!("Deserialize binary ws message failed: {:?}", e);
},
}
}
2021-09-16 18:31:25 +08:00
}
2021-09-20 15:38:55 +08:00
impl Future for WsHandlerFuture {
2021-09-16 18:31:25 +08:00
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
2021-09-18 22:32:00 +08:00
None => {
2021-09-19 16:11:02 +08:00
return Poll::Ready(());
2021-09-16 23:07:15 +08:00
},
2021-09-24 16:02:17 +08:00
Some(message) => self.handler_ws_message(message),
2021-09-16 23:07:15 +08:00
}
}
}
}
2021-09-19 18:39:56 +08:00
#[derive(Debug, Clone)]
pub struct WsSender {
2021-09-16 23:07:15 +08:00
ws_tx: MsgSender,
}
impl WsSender {
2021-09-19 18:39:56 +08:00
pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
let msg = msg.into();
2021-09-27 23:23:23 +08:00
let _ = self
.ws_tx
.unbounded_send(msg.into())
.map_err(|e| WsError::internal().context(e))?;
2021-09-16 23:07:15 +08:00
Ok(())
}
2021-09-19 18:39:56 +08:00
pub fn send_text(&self, source: &WsModule, text: &str) -> Result<(), WsError> {
2021-09-19 18:39:56 +08:00
let msg = WsMessage {
module: source.clone(),
2021-09-19 18:39:56 +08:00
data: text.as_bytes().to_vec(),
};
self.send_msg(msg)
}
pub fn send_binary(&self, source: &WsModule, bytes: Vec<u8>) -> Result<(), WsError> {
let msg = WsMessage {
module: source.clone(),
data: bytes,
};
2021-09-19 18:39:56 +08:00
self.send_msg(msg)
}
pub fn send_disconnect(&self, reason: &str) -> Result<(), WsError> {
let frame = CloseFrame {
code: CloseCode::Normal,
reason: reason.to_owned().into(),
};
let msg = Message::Close(Some(frame));
2021-09-27 23:23:23 +08:00
let _ = self
.ws_tx
.unbounded_send(msg)
.map_err(|e| WsError::internal().context(e))?;
2021-09-19 18:39:56 +08:00
Ok(())
}
2021-09-16 23:07:15 +08:00
}
// #[cfg(test)]
// mod tests {
// use super::WsController;
//
// #[tokio::test]
// async fn connect() {
// std::env::set_var("RUST_LOG", "Debug");
// env_logger::init();
//
// let mut controller = WsController::new();
// let addr = format!("{}/123", flowy_net::config::WS_ADDR.as_str());
// let (a, b) = controller.make_connect(addr);
// tokio::select! {
// r = a => println!("write completed {:?}", r),
// _ = b => println!("read completed"),
// };
// }
// }