2021-09-18 22:32:00 +08:00
|
|
|
use crate::{connect::WsConnection, errors::WsError, WsMessage};
|
|
|
|
use flowy_net::errors::ServerError;
|
2021-09-16 18:31:25 +08:00
|
|
|
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
|
2021-09-16 23:07:15 +08:00
|
|
|
use futures_core::{future::BoxFuture, ready, Stream};
|
2021-09-18 22:32:00 +08:00
|
|
|
use futures_util::{
|
|
|
|
future,
|
|
|
|
future::{Either, Select},
|
|
|
|
pin_mut,
|
|
|
|
FutureExt,
|
|
|
|
StreamExt,
|
|
|
|
};
|
2021-09-16 18:31:25 +08:00
|
|
|
use pin_project::pin_project;
|
|
|
|
use std::{
|
2021-09-18 22:32:00 +08:00
|
|
|
collections::HashMap,
|
2021-09-16 18:31:25 +08:00
|
|
|
future::Future,
|
|
|
|
pin::Pin,
|
|
|
|
sync::Arc,
|
|
|
|
task::{Context, Poll},
|
|
|
|
};
|
2021-09-18 22:32:00 +08:00
|
|
|
use tokio::{net::TcpStream, task::JoinHandle};
|
2021-09-16 23:07:15 +08:00
|
|
|
use tokio_tungstenite::{
|
|
|
|
connect_async,
|
2021-09-17 19:03:46 +08:00
|
|
|
tungstenite::{handshake::client::Response, http::StatusCode, Error, Message},
|
2021-09-16 23:07:15 +08:00
|
|
|
MaybeTlsStream,
|
|
|
|
WebSocketStream,
|
|
|
|
};
|
|
|
|
|
2021-09-16 18:31:25 +08:00
|
|
|
pub type MsgReceiver = UnboundedReceiver<Message>;
|
|
|
|
pub type MsgSender = UnboundedSender<Message>;
|
|
|
|
pub trait WsMessageHandler: Sync + Send + 'static {
|
2021-09-18 22:32:00 +08:00
|
|
|
fn source(&self) -> String;
|
|
|
|
fn receive_message(&self, msg: WsMessage);
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
pub struct WsController {
|
2021-09-16 23:07:15 +08:00
|
|
|
sender: Option<Arc<WsSender>>,
|
2021-09-18 22:32:00 +08:00
|
|
|
handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
impl WsController {
|
|
|
|
pub fn new() -> Self {
|
2021-09-16 23:07:15 +08:00
|
|
|
let controller = Self {
|
|
|
|
sender: None,
|
2021-09-18 22:32:00 +08:00
|
|
|
handlers: HashMap::new(),
|
2021-09-16 23:07:15 +08:00
|
|
|
};
|
|
|
|
controller
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
|
2021-09-18 22:32:00 +08:00
|
|
|
pub fn add_handler(&mut self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
|
|
|
|
let source = handler.source();
|
|
|
|
if self.handlers.contains_key(&source) {
|
|
|
|
return Err(WsError::duplicate_source());
|
|
|
|
}
|
|
|
|
self.handlers.insert(source, handler);
|
2021-09-17 19:03:46 +08:00
|
|
|
Ok(())
|
2021-09-16 23:07:15 +08:00
|
|
|
}
|
|
|
|
|
2021-09-18 22:32:00 +08:00
|
|
|
pub fn connect(&mut self, addr: String) -> Result<JoinHandle<()>, ServerError> {
|
|
|
|
log::debug!("🐴 Try to connect: {}", &addr);
|
|
|
|
let (connection, handlers) = self.make_connect(addr);
|
|
|
|
Ok(tokio::spawn(async {
|
|
|
|
tokio::select! {
|
|
|
|
result = connection => {
|
|
|
|
match result {
|
|
|
|
Ok(stream) => {
|
|
|
|
tokio::spawn(stream).await;
|
|
|
|
// stream.start().await;
|
|
|
|
},
|
|
|
|
Err(e) => {
|
|
|
|
// TODO: retry?
|
|
|
|
log::error!("ws connect failed {:?}", e);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
},
|
|
|
|
result = handlers => log::debug!("handlers completed {:?}", result),
|
|
|
|
};
|
|
|
|
}))
|
|
|
|
}
|
|
|
|
|
|
|
|
fn make_connect(&mut self, addr: String) -> (WsConnection, WsHandlers) {
|
2021-09-16 18:31:25 +08:00
|
|
|
// Stream User
|
|
|
|
// ┌───────────────┐ ┌──────────────┐
|
|
|
|
// ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
|
|
|
|
// │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │ │
|
|
|
|
// └──────┘ │ └─────────┘ │ └────────┘ │ └────────┘ │
|
|
|
|
// ▲ │ │ │ │
|
|
|
|
// │ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
|
|
|
|
// └─────────┼──│ws_write │◀─┼────│ ws_rx │◀──┼──│ ws_tx │ │
|
|
|
|
// │ └─────────┘ │ └────────┘ │ └────────┘ │
|
|
|
|
// └───────────────┘ └──────────────┘
|
|
|
|
let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
|
|
|
|
let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
|
2021-09-16 23:07:15 +08:00
|
|
|
let handlers = self.handlers.clone();
|
2021-09-18 22:32:00 +08:00
|
|
|
self.sender = Some(Arc::new(WsSender::new(ws_tx)));
|
2021-09-17 19:03:46 +08:00
|
|
|
(WsConnection::new(msg_tx, ws_rx, addr), WsHandlers::new(handlers, msg_rx))
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
|
2021-09-18 22:32:00 +08:00
|
|
|
pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
|
|
|
|
match self.sender.as_ref() {
|
|
|
|
None => Err(WsError::internal().context("Should call make_connect first")),
|
|
|
|
Some(sender) => sender.send(msg.into()),
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[pin_project]
|
2021-09-17 19:03:46 +08:00
|
|
|
pub struct WsHandlers {
|
2021-09-16 18:31:25 +08:00
|
|
|
#[pin]
|
|
|
|
msg_rx: MsgReceiver,
|
2021-09-18 22:32:00 +08:00
|
|
|
handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
|
2021-09-16 23:07:15 +08:00
|
|
|
impl WsHandlers {
|
2021-09-18 22:32:00 +08:00
|
|
|
fn new(handlers: HashMap<String, Arc<dyn WsMessageHandler>>, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
|
2021-09-16 18:31:25 +08:00
|
|
|
}
|
|
|
|
|
2021-09-16 23:07:15 +08:00
|
|
|
impl Future for WsHandlers {
|
2021-09-16 18:31:25 +08:00
|
|
|
type Output = ();
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
|
|
loop {
|
|
|
|
match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
|
2021-09-18 22:32:00 +08:00
|
|
|
None => {
|
|
|
|
// log::debug!("🐴 ws handler done");
|
|
|
|
return Poll::Pending;
|
2021-09-16 23:07:15 +08:00
|
|
|
},
|
2021-09-18 22:32:00 +08:00
|
|
|
Some(message) => {
|
|
|
|
let message = WsMessage::from(message);
|
|
|
|
match self.handlers.get(&message.source) {
|
|
|
|
None => log::error!("Can't find any handler for message: {:?}", message),
|
|
|
|
Some(handler) => handler.receive_message(message.clone()),
|
|
|
|
}
|
2021-09-16 23:07:15 +08:00
|
|
|
},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-09-18 22:32:00 +08:00
|
|
|
struct WsSender {
|
2021-09-16 23:07:15 +08:00
|
|
|
ws_tx: MsgSender,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl WsSender {
|
|
|
|
pub fn new(ws_tx: MsgSender) -> Self { Self { ws_tx } }
|
|
|
|
|
2021-09-18 22:32:00 +08:00
|
|
|
pub fn send(&self, msg: WsMessage) -> Result<(), WsError> {
|
|
|
|
let _ = self.ws_tx.unbounded_send(msg.into()).map_err(|e| WsError::internal().context(e))?;
|
2021-09-16 23:07:15 +08:00
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::WsController;

    /// Connects to `flowy_net::config::WS_ADDR` and runs until either the connection
    /// future or the dispatcher future finishes.
    ///
    /// NOTE(review): this needs a reachable websocket server at that address; it is
    /// an integration test rather than a unit test.
    #[tokio::test]
    async fn connect() {
        std::env::set_var("RUST_LOG", "Debug");
        // `init` panics if a global logger is already installed, e.g. by another
        // test in the same test binary; `try_init` tolerates that.
        let _ = env_logger::try_init();

        let mut controller = WsController::new();
        let addr = format!("{}/123", flowy_net::config::WS_ADDR.as_str());
        let (a, b) = controller.make_connect(addr);
        tokio::select! {
            r = a => println!("write completed {:?}", r),
            _ = b => println!("read completed"),
        };
    }
}
|