169 lines
6.1 KiB
Rust
Raw Normal View History

2021-09-18 22:32:00 +08:00
use crate::{connect::WsConnection, errors::WsError, WsMessage};
use flowy_net::errors::ServerError;
2021-09-16 18:31:25 +08:00
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
2021-09-16 23:07:15 +08:00
use futures_core::{future::BoxFuture, ready, Stream};
2021-09-18 22:32:00 +08:00
use futures_util::{
future,
future::{Either, Select},
pin_mut,
FutureExt,
StreamExt,
};
2021-09-16 18:31:25 +08:00
use pin_project::pin_project;
use std::{
2021-09-18 22:32:00 +08:00
collections::HashMap,
2021-09-16 18:31:25 +08:00
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
2021-09-18 22:32:00 +08:00
use tokio::{net::TcpStream, task::JoinHandle};
2021-09-16 23:07:15 +08:00
use tokio_tungstenite::{
connect_async,
2021-09-17 19:03:46 +08:00
tungstenite::{handshake::client::Response, http::StatusCode, Error, Message},
2021-09-16 23:07:15 +08:00
MaybeTlsStream,
WebSocketStream,
};
2021-09-16 18:31:25 +08:00
/// Receiving half of an unbounded channel carrying raw tungstenite [`Message`]s.
pub type MsgReceiver = UnboundedReceiver<Message>;
/// Sending half of an unbounded channel carrying raw tungstenite [`Message`]s.
pub type MsgSender = UnboundedSender<Message>;
pub trait WsMessageHandler: Sync + Send + 'static {
2021-09-18 22:32:00 +08:00
fn source(&self) -> String;
fn receive_message(&self, msg: WsMessage);
2021-09-16 18:31:25 +08:00
}
pub struct WsController {
2021-09-16 23:07:15 +08:00
sender: Option<Arc<WsSender>>,
2021-09-18 22:32:00 +08:00
handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
2021-09-16 18:31:25 +08:00
}
impl WsController {
pub fn new() -> Self {
2021-09-16 23:07:15 +08:00
let controller = Self {
sender: None,
2021-09-18 22:32:00 +08:00
handlers: HashMap::new(),
2021-09-16 23:07:15 +08:00
};
controller
2021-09-16 18:31:25 +08:00
}
2021-09-18 22:32:00 +08:00
pub fn add_handler(&mut self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
let source = handler.source();
if self.handlers.contains_key(&source) {
return Err(WsError::duplicate_source());
}
self.handlers.insert(source, handler);
2021-09-17 19:03:46 +08:00
Ok(())
2021-09-16 23:07:15 +08:00
}
2021-09-18 22:32:00 +08:00
pub fn connect(&mut self, addr: String) -> Result<JoinHandle<()>, ServerError> {
log::debug!("🐴 Try to connect: {}", &addr);
let (connection, handlers) = self.make_connect(addr);
Ok(tokio::spawn(async {
tokio::select! {
result = connection => {
match result {
Ok(stream) => {
tokio::spawn(stream).await;
// stream.start().await;
},
Err(e) => {
// TODO: retry?
log::error!("ws connect failed {:?}", e);
}
}
},
result = handlers => log::debug!("handlers completed {:?}", result),
};
}))
}
fn make_connect(&mut self, addr: String) -> (WsConnection, WsHandlers) {
2021-09-16 18:31:25 +08:00
// Stream User
// ┌───────────────┐ ┌──────────────┐
// ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
// │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │ │
// └──────┘ │ └─────────┘ │ └────────┘ │ └────────┘ │
// ▲ │ │ │ │
// │ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
// └─────────┼──│ws_write │◀─┼────│ ws_rx │◀──┼──│ ws_tx │ │
// │ └─────────┘ │ └────────┘ │ └────────┘ │
// └───────────────┘ └──────────────┘
let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
2021-09-16 23:07:15 +08:00
let handlers = self.handlers.clone();
2021-09-18 22:32:00 +08:00
self.sender = Some(Arc::new(WsSender::new(ws_tx)));
2021-09-17 19:03:46 +08:00
(WsConnection::new(msg_tx, ws_rx, addr), WsHandlers::new(handlers, msg_rx))
2021-09-16 18:31:25 +08:00
}
2021-09-18 22:32:00 +08:00
pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
match self.sender.as_ref() {
None => Err(WsError::internal().context("Should call make_connect first")),
Some(sender) => sender.send(msg.into()),
2021-09-16 18:31:25 +08:00
}
}
}
#[pin_project]
2021-09-17 19:03:46 +08:00
pub struct WsHandlers {
2021-09-16 18:31:25 +08:00
#[pin]
msg_rx: MsgReceiver,
2021-09-18 22:32:00 +08:00
handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
2021-09-16 18:31:25 +08:00
}
2021-09-16 23:07:15 +08:00
impl WsHandlers {
2021-09-18 22:32:00 +08:00
fn new(handlers: HashMap<String, Arc<dyn WsMessageHandler>>, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
2021-09-16 18:31:25 +08:00
}
2021-09-16 23:07:15 +08:00
impl Future for WsHandlers {
2021-09-16 18:31:25 +08:00
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
2021-09-18 22:32:00 +08:00
None => {
// log::debug!("🐴 ws handler done");
return Poll::Pending;
2021-09-16 23:07:15 +08:00
},
2021-09-18 22:32:00 +08:00
Some(message) => {
let message = WsMessage::from(message);
match self.handlers.get(&message.source) {
None => log::error!("Can't find any handler for message: {:?}", message),
Some(handler) => handler.receive_message(message.clone()),
}
2021-09-16 23:07:15 +08:00
},
}
}
}
}
2021-09-18 22:32:00 +08:00
struct WsSender {
2021-09-16 23:07:15 +08:00
ws_tx: MsgSender,
}
impl WsSender {
pub fn new(ws_tx: MsgSender) -> Self { Self { ws_tx } }
2021-09-18 22:32:00 +08:00
pub fn send(&self, msg: WsMessage) -> Result<(), WsError> {
let _ = self.ws_tx.unbounded_send(msg.into()).map_err(|e| WsError::internal().context(e))?;
2021-09-16 23:07:15 +08:00
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::WsController;

    /// Smoke test: prepares a connection to `WS_ADDR` and races the connection
    /// future against the handler future, printing whichever finishes first.
    /// It prints rather than asserts the outcome.
    #[tokio::test]
    async fn connect() {
        std::env::set_var("RUST_LOG", "Debug");
        env_logger::init();

        let addr = format!("{}/123", flowy_net::config::WS_ADDR.as_str());
        let mut controller = WsController::new();
        let (connection, handlers) = controller.make_connect(addr);

        tokio::select! {
            r = connection => println!("write completed {:?}", r),
            _ = handlers => println!("read completed"),
        };
    }
}