2021-09-19 23:21:10 +08:00
|
|
|
use crate::{
|
2021-09-20 15:38:55 +08:00
|
|
|
connect::{Retry, WsConnectionFuture},
|
2021-09-19 23:21:10 +08:00
|
|
|
errors::WsError,
|
|
|
|
WsMessage,
|
2021-09-23 19:59:58 +08:00
|
|
|
WsModule,
|
2021-09-19 23:21:10 +08:00
|
|
|
};
|
2021-09-24 16:02:17 +08:00
|
|
|
use bytes::Bytes;
|
2021-09-18 22:32:00 +08:00
|
|
|
use flowy_net::errors::ServerError;
|
2021-09-16 18:31:25 +08:00
|
|
|
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
|
2021-09-19 23:21:10 +08:00
|
|
|
use futures_core::{future::BoxFuture, ready, Stream};
|
2021-09-16 18:31:25 +08:00
|
|
|
use pin_project::pin_project;
|
|
|
|
use std::{
|
2021-09-18 22:32:00 +08:00
|
|
|
collections::HashMap,
|
2021-09-25 21:47:02 +08:00
|
|
|
convert::TryFrom,
|
2021-09-16 18:31:25 +08:00
|
|
|
future::Future,
|
|
|
|
pin::Pin,
|
|
|
|
sync::Arc,
|
|
|
|
task::{Context, Poll},
|
|
|
|
};
|
2021-09-19 18:39:56 +08:00
|
|
|
use tokio::{sync::RwLock, task::JoinHandle};
|
2021-09-19 23:21:10 +08:00
|
|
|
use tokio_tungstenite::tungstenite::{
|
|
|
|
protocol::{frame::coding::CloseCode, CloseFrame},
|
|
|
|
Message,
|
2021-09-19 18:39:56 +08:00
|
|
|
};
|
2021-09-16 23:07:15 +08:00
|
|
|
|
2021-09-16 18:31:25 +08:00
|
|
|
pub type MsgReceiver = UnboundedReceiver<Message>;
|
|
|
|
pub type MsgSender = UnboundedSender<Message>;
|
|
|
|
pub trait WsMessageHandler: Sync + Send + 'static {
|
2021-09-23 19:59:58 +08:00
|
|
|
fn source(&self) -> WsModule;
|
2021-09-18 22:32:00 +08:00
|
|
|
fn receive_message(&self, msg: WsMessage);
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
|
2021-09-19 18:39:56 +08:00
|
|
|
/// Shared, thread-safe observer that is invoked with each new [`WsState`].
type NotifyCallback = Arc<dyn Fn(&WsState) + Sync + Send + 'static>;
|
|
|
|
struct WsStateNotify {
|
|
|
|
#[allow(dead_code)]
|
|
|
|
state: WsState,
|
|
|
|
callback: Option<NotifyCallback>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl WsStateNotify {
|
|
|
|
fn update_state(&mut self, state: WsState) {
|
|
|
|
if let Some(f) = &self.callback {
|
|
|
|
f(&state);
|
|
|
|
}
|
|
|
|
self.state = state;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub enum WsState {
|
|
|
|
Init,
|
|
|
|
Connected(Arc<WsSender>),
|
|
|
|
Disconnected(WsError),
|
|
|
|
}
|
|
|
|
|
2021-09-16 18:31:25 +08:00
|
|
|
pub struct WsController {
|
2021-09-23 19:59:58 +08:00
|
|
|
handlers: HashMap<WsModule, Arc<dyn WsMessageHandler>>,
|
2021-09-19 18:39:56 +08:00
|
|
|
state_notify: Arc<RwLock<WsStateNotify>>,
|
2021-09-22 14:42:14 +08:00
|
|
|
#[allow(dead_code)]
|
2021-09-19 16:11:02 +08:00
|
|
|
addr: Option<String>,
|
2021-09-19 18:39:56 +08:00
|
|
|
sender: Option<Arc<WsSender>>,
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
impl WsController {
|
|
|
|
pub fn new() -> Self {
|
2021-09-19 18:39:56 +08:00
|
|
|
let state_notify = Arc::new(RwLock::new(WsStateNotify {
|
|
|
|
state: WsState::Init,
|
|
|
|
callback: None,
|
|
|
|
}));
|
|
|
|
|
2021-09-16 23:07:15 +08:00
|
|
|
let controller = Self {
|
2021-09-18 22:32:00 +08:00
|
|
|
handlers: HashMap::new(),
|
2021-09-19 18:39:56 +08:00
|
|
|
state_notify,
|
2021-09-19 16:11:02 +08:00
|
|
|
addr: None,
|
2021-09-19 18:39:56 +08:00
|
|
|
sender: None,
|
2021-09-16 23:07:15 +08:00
|
|
|
};
|
|
|
|
controller
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
|
2021-09-19 18:39:56 +08:00
|
|
|
pub async fn state_callback<SC>(&self, callback: SC)
|
|
|
|
where
|
|
|
|
SC: Fn(&WsState) + Send + Sync + 'static,
|
|
|
|
{
|
|
|
|
(self.state_notify.write().await).callback = Some(Arc::new(callback));
|
|
|
|
}
|
|
|
|
|
2021-09-18 22:32:00 +08:00
|
|
|
pub fn add_handler(&mut self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
|
|
|
|
let source = handler.source();
|
|
|
|
if self.handlers.contains_key(&source) {
|
2021-09-23 17:50:28 +08:00
|
|
|
log::error!("WsSource's {:?} is already registered", source);
|
2021-09-18 22:32:00 +08:00
|
|
|
}
|
|
|
|
self.handlers.insert(source, handler);
|
2021-09-17 19:03:46 +08:00
|
|
|
Ok(())
|
2021-09-16 23:07:15 +08:00
|
|
|
}
|
|
|
|
|
2021-09-19 16:11:02 +08:00
|
|
|
pub fn connect(&mut self, addr: String) -> Result<JoinHandle<()>, ServerError> { self._connect(addr.clone(), None) }
|
|
|
|
|
|
|
|
pub fn connect_with_retry<F>(&mut self, addr: String, retry: Retry<F>) -> Result<JoinHandle<()>, ServerError>
|
|
|
|
where
|
|
|
|
F: Fn(&str) + Send + Sync + 'static,
|
|
|
|
{
|
|
|
|
self._connect(addr, Some(Box::pin(async { retry.await })))
|
|
|
|
}
|
|
|
|
|
2021-09-20 15:38:55 +08:00
|
|
|
pub fn get_sender(&self) -> Result<Arc<WsSender>, WsError> {
|
|
|
|
match &self.sender {
|
2021-09-27 23:23:23 +08:00
|
|
|
None => Err(WsError::internal().context("WsSender is not initialized, should call connect first")),
|
2021-09-20 15:38:55 +08:00
|
|
|
Some(sender) => Ok(sender.clone()),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-09-19 16:11:02 +08:00
|
|
|
fn _connect(&mut self, addr: String, retry: Option<BoxFuture<'static, ()>>) -> Result<JoinHandle<()>, ServerError> {
|
2021-09-19 12:54:28 +08:00
|
|
|
log::debug!("🐴 ws connect: {}", &addr);
|
2021-09-19 16:11:02 +08:00
|
|
|
let (connection, handlers) = self.make_connect(addr.clone());
|
2021-09-19 18:39:56 +08:00
|
|
|
let state_notify = self.state_notify.clone();
|
2021-09-27 23:23:23 +08:00
|
|
|
let sender = self
|
|
|
|
.sender
|
|
|
|
.clone()
|
|
|
|
.expect("Sender should be not empty after calling make_connect");
|
2021-09-19 16:11:02 +08:00
|
|
|
Ok(tokio::spawn(async move {
|
|
|
|
match connection.await {
|
|
|
|
Ok(stream) => {
|
2021-09-19 18:39:56 +08:00
|
|
|
state_notify.write().await.update_state(WsState::Connected(sender));
|
2021-09-19 16:11:02 +08:00
|
|
|
tokio::select! {
|
|
|
|
result = stream => {
|
|
|
|
match result {
|
|
|
|
Ok(_) => {},
|
|
|
|
Err(e) => {
|
|
|
|
// TODO: retry?
|
|
|
|
log::error!("ws stream error {:?}", e);
|
2021-09-19 18:39:56 +08:00
|
|
|
state_notify.write().await.update_state(WsState::Disconnected(e));
|
2021-09-19 16:11:02 +08:00
|
|
|
}
|
|
|
|
}
|
2021-09-18 22:32:00 +08:00
|
|
|
},
|
2021-09-19 16:11:02 +08:00
|
|
|
result = handlers => log::debug!("handlers completed {:?}", result),
|
|
|
|
};
|
|
|
|
},
|
2021-09-19 18:39:56 +08:00
|
|
|
Err(e) => {
|
|
|
|
log::error!("ws connect {} failed {:?}", addr, e);
|
|
|
|
state_notify.write().await.update_state(WsState::Disconnected(e));
|
|
|
|
if let Some(retry) = retry {
|
2021-09-19 16:11:02 +08:00
|
|
|
tokio::spawn(retry);
|
2021-09-19 18:39:56 +08:00
|
|
|
}
|
2021-09-18 22:32:00 +08:00
|
|
|
},
|
2021-09-19 16:11:02 +08:00
|
|
|
}
|
2021-09-18 22:32:00 +08:00
|
|
|
}))
|
|
|
|
}
|
|
|
|
|
2021-09-20 15:38:55 +08:00
|
|
|
fn make_connect(&mut self, addr: String) -> (WsConnectionFuture, WsHandlerFuture) {
|
2021-09-16 18:31:25 +08:00
|
|
|
// Stream User
|
|
|
|
// ┌───────────────┐ ┌──────────────┐
|
|
|
|
// ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
|
|
|
|
// │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │ │
|
|
|
|
// └──────┘ │ └─────────┘ │ └────────┘ │ └────────┘ │
|
|
|
|
// ▲ │ │ │ │
|
|
|
|
// │ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
|
|
|
|
// └─────────┼──│ws_write │◀─┼────│ ws_rx │◀──┼──│ ws_tx │ │
|
|
|
|
// │ └─────────┘ │ └────────┘ │ └────────┘ │
|
|
|
|
// └───────────────┘ └──────────────┘
|
|
|
|
let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
|
|
|
|
let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
|
2021-09-16 23:07:15 +08:00
|
|
|
let handlers = self.handlers.clone();
|
2021-09-19 18:39:56 +08:00
|
|
|
self.sender = Some(Arc::new(WsSender { ws_tx }));
|
2021-09-19 16:11:02 +08:00
|
|
|
self.addr = Some(addr.clone());
|
2021-09-27 23:23:23 +08:00
|
|
|
(
|
|
|
|
WsConnectionFuture::new(msg_tx, ws_rx, addr),
|
|
|
|
WsHandlerFuture::new(handlers, msg_rx),
|
|
|
|
)
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[pin_project]
|
2021-09-20 15:38:55 +08:00
|
|
|
pub struct WsHandlerFuture {
|
2021-09-16 18:31:25 +08:00
|
|
|
#[pin]
|
|
|
|
msg_rx: MsgReceiver,
|
2021-09-23 19:59:58 +08:00
|
|
|
handlers: HashMap<WsModule, Arc<dyn WsMessageHandler>>,
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
|
2021-09-20 15:38:55 +08:00
|
|
|
impl WsHandlerFuture {
|
2021-09-27 23:23:23 +08:00
|
|
|
fn new(handlers: HashMap<WsModule, Arc<dyn WsMessageHandler>>, msg_rx: MsgReceiver) -> Self {
|
|
|
|
Self { msg_rx, handlers }
|
|
|
|
}
|
2021-09-24 16:02:17 +08:00
|
|
|
|
|
|
|
fn handler_ws_message(&self, message: Message) {
|
|
|
|
match message {
|
|
|
|
Message::Binary(bytes) => self.handle_binary_message(bytes),
|
|
|
|
_ => {},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn handle_binary_message(&self, bytes: Vec<u8>) {
|
|
|
|
let bytes = Bytes::from(bytes);
|
|
|
|
match WsMessage::try_from(bytes) {
|
|
|
|
Ok(message) => match self.handlers.get(&message.module) {
|
|
|
|
None => log::error!("Can't find any handler for message: {:?}", message),
|
|
|
|
Some(handler) => handler.receive_message(message.clone()),
|
|
|
|
},
|
|
|
|
Err(e) => {
|
|
|
|
log::error!("Deserialize binary ws message failed: {:?}", e);
|
|
|
|
},
|
|
|
|
}
|
|
|
|
}
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
|
2021-09-20 15:38:55 +08:00
|
|
|
impl Future for WsHandlerFuture {
|
2021-09-16 18:31:25 +08:00
|
|
|
type Output = ();
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
|
|
loop {
|
|
|
|
match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
|
2021-09-18 22:32:00 +08:00
|
|
|
None => {
|
2021-09-19 16:11:02 +08:00
|
|
|
return Poll::Ready(());
|
2021-09-16 23:07:15 +08:00
|
|
|
},
|
2021-09-24 16:02:17 +08:00
|
|
|
Some(message) => self.handler_ws_message(message),
|
2021-09-16 23:07:15 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-09-19 18:39:56 +08:00
|
|
|
#[derive(Debug, Clone)]
|
|
|
|
pub struct WsSender {
|
2021-09-16 23:07:15 +08:00
|
|
|
ws_tx: MsgSender,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl WsSender {
|
2021-09-19 18:39:56 +08:00
|
|
|
pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
|
|
|
|
let msg = msg.into();
|
2021-09-27 23:23:23 +08:00
|
|
|
let _ = self
|
|
|
|
.ws_tx
|
|
|
|
.unbounded_send(msg.into())
|
|
|
|
.map_err(|e| WsError::internal().context(e))?;
|
2021-09-16 23:07:15 +08:00
|
|
|
Ok(())
|
|
|
|
}
|
2021-09-19 18:39:56 +08:00
|
|
|
|
2021-09-23 19:59:58 +08:00
|
|
|
pub fn send_text(&self, source: &WsModule, text: &str) -> Result<(), WsError> {
|
2021-09-19 18:39:56 +08:00
|
|
|
let msg = WsMessage {
|
2021-09-23 19:59:58 +08:00
|
|
|
module: source.clone(),
|
2021-09-19 18:39:56 +08:00
|
|
|
data: text.as_bytes().to_vec(),
|
|
|
|
};
|
|
|
|
self.send_msg(msg)
|
|
|
|
}
|
|
|
|
|
2021-09-23 19:59:58 +08:00
|
|
|
pub fn send_binary(&self, source: &WsModule, bytes: Vec<u8>) -> Result<(), WsError> {
|
|
|
|
let msg = WsMessage {
|
|
|
|
module: source.clone(),
|
|
|
|
data: bytes,
|
|
|
|
};
|
2021-09-19 18:39:56 +08:00
|
|
|
self.send_msg(msg)
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn send_disconnect(&self, reason: &str) -> Result<(), WsError> {
|
|
|
|
let frame = CloseFrame {
|
|
|
|
code: CloseCode::Normal,
|
|
|
|
reason: reason.to_owned().into(),
|
|
|
|
};
|
|
|
|
let msg = Message::Close(Some(frame));
|
2021-09-27 23:23:23 +08:00
|
|
|
let _ = self
|
|
|
|
.ws_tx
|
|
|
|
.unbounded_send(msg)
|
|
|
|
.map_err(|e| WsError::internal().context(e))?;
|
2021-09-19 18:39:56 +08:00
|
|
|
Ok(())
|
|
|
|
}
|
2021-09-16 23:07:15 +08:00
|
|
|
}
|
|
|
|
|
2021-09-28 15:29:29 +08:00
|
|
|
// #[cfg(test)]
|
|
|
|
// mod tests {
|
|
|
|
// use super::WsController;
|
|
|
|
//
|
|
|
|
// #[tokio::test]
|
|
|
|
// async fn connect() {
|
|
|
|
// std::env::set_var("RUST_LOG", "Debug");
|
|
|
|
// env_logger::init();
|
|
|
|
//
|
|
|
|
// let mut controller = WsController::new();
|
|
|
|
// let addr = format!("{}/123", flowy_net::config::WS_ADDR.as_str());
|
|
|
|
// let (a, b) = controller.make_connect(addr);
|
|
|
|
// tokio::select! {
|
|
|
|
// r = a => println!("write completed {:?}", r),
|
|
|
|
// _ = b => println!("read completed"),
|
|
|
|
// };
|
|
|
|
// }
|
|
|
|
// }
|