389 lines
13 KiB
Rust
Raw Normal View History

2021-11-27 19:19:41 +08:00
#![allow(clippy::type_complexity)]
2021-09-19 23:21:10 +08:00
use crate::{
2021-12-16 22:24:05 +08:00
connect::{WSConnectionFuture, WSStream},
errors::WSError,
2022-01-22 18:48:43 +08:00
WSChannel,
2021-12-26 19:10:37 +08:00
WebSocketRawMessage,
2021-09-19 23:21:10 +08:00
};
2021-11-20 08:35:04 +08:00
use backend_service::errors::ServerError;
2021-09-24 16:02:17 +08:00
use bytes::Bytes;
2021-09-30 17:24:02 +08:00
use dashmap::DashMap;
2021-09-16 18:31:25 +08:00
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
use futures_core::{ready, Stream};
use lib_infra::retry::{Action, FixedInterval, Retry};
2021-09-30 17:24:02 +08:00
use parking_lot::RwLock;
2021-09-16 18:31:25 +08:00
use pin_project::pin_project;
use std::{
convert::TryFrom,
fmt::Formatter,
2021-09-16 18:31:25 +08:00
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
2021-10-05 17:54:11 +08:00
time::Duration,
2021-09-16 18:31:25 +08:00
};
use tokio::sync::{broadcast, oneshot};
2021-09-19 23:21:10 +08:00
use tokio_tungstenite::tungstenite::{
protocol::{frame::coding::CloseCode, CloseFrame},
Message,
2021-09-19 18:39:56 +08:00
};
2021-09-16 23:07:15 +08:00
2021-09-16 18:31:25 +08:00
pub type MsgReceiver = UnboundedReceiver<Message>;
pub type MsgSender = UnboundedSender<Message>;
2022-01-22 18:48:43 +08:00
type Handlers = DashMap<WSChannel, Arc<dyn WSMessageReceiver>>;
2021-09-30 17:24:02 +08:00
2021-12-16 22:24:05 +08:00
pub trait WSMessageReceiver: Sync + Send + 'static {
2022-01-22 18:48:43 +08:00
fn source(&self) -> WSChannel;
2021-12-26 19:10:37 +08:00
fn receive_message(&self, msg: WebSocketRawMessage);
2021-09-16 18:31:25 +08:00
}
2021-12-16 22:24:05 +08:00
pub struct WSController {
2021-09-30 17:24:02 +08:00
handlers: Handlers,
2021-12-16 22:24:05 +08:00
state_notify: Arc<broadcast::Sender<WSConnectState>>,
sender_ctrl: Arc<RwLock<WSSenderController>>,
2021-10-05 17:54:11 +08:00
addr: Arc<RwLock<Option<String>>>,
2021-09-16 18:31:25 +08:00
}
2021-12-16 22:24:05 +08:00
impl std::default::Default for WSController {
2021-11-27 19:19:41 +08:00
fn default() -> Self {
2021-09-30 17:24:02 +08:00
let (state_notify, _) = broadcast::channel(16);
2021-11-27 19:19:41 +08:00
Self {
2021-09-30 17:24:02 +08:00
handlers: DashMap::new(),
2021-12-16 22:24:05 +08:00
sender_ctrl: Arc::new(RwLock::new(WSSenderController::default())),
2021-09-30 17:24:02 +08:00
state_notify: Arc::new(state_notify),
2021-10-05 17:54:11 +08:00
addr: Arc::new(RwLock::new(None)),
2021-11-27 19:19:41 +08:00
}
2021-09-16 18:31:25 +08:00
}
2021-11-27 19:19:41 +08:00
}
2021-12-16 22:24:05 +08:00
impl WSController {
pub fn new() -> Self { WSController::default() }
2021-09-16 18:31:25 +08:00
2022-01-07 17:37:11 +08:00
pub fn add_ws_message_receiver(&self, handler: Arc<dyn WSMessageReceiver>) -> Result<(), WSError> {
2021-09-18 22:32:00 +08:00
let source = handler.source();
if self.handlers.contains_key(&source) {
2021-09-23 17:50:28 +08:00
log::error!("WsSource's {:?} is already registered", source);
2021-09-18 22:32:00 +08:00
}
self.handlers.insert(source, handler);
2021-09-17 19:03:46 +08:00
Ok(())
2021-09-16 23:07:15 +08:00
}
2021-12-10 22:18:44 +08:00
pub async fn start(&self, addr: String) -> Result<(), ServerError> {
2021-10-05 17:54:11 +08:00
*self.addr.write() = Some(addr.clone());
2021-10-05 19:32:58 +08:00
let strategy = FixedInterval::from_millis(5000).take(3);
2021-10-05 17:54:11 +08:00
self.connect(addr, strategy).await
}
2021-12-16 22:24:05 +08:00
pub async fn stop(&self) { self.sender_ctrl.write().set_state(WSConnectState::Disconnected); }
2021-12-14 20:50:07 +08:00
2021-10-05 17:54:11 +08:00
async fn connect<T, I>(&self, addr: String, strategy: T) -> Result<(), ServerError>
where
T: IntoIterator<IntoIter = I, Item = Duration>,
I: Iterator<Item = Duration> + Send + 'static,
{
2021-09-30 17:24:02 +08:00
let (ret, rx) = oneshot::channel::<Result<(), ServerError>>();
2021-10-05 17:54:11 +08:00
*self.addr.write() = Some(addr.clone());
2021-12-16 22:24:05 +08:00
let action = WSConnectAction {
addr,
handlers: self.handlers.clone(),
};
2021-10-05 17:54:11 +08:00
let retry = Retry::spawn(strategy, action);
2021-12-05 14:04:25 +08:00
let sender_ctrl = self.sender_ctrl.clone();
2021-12-16 22:24:05 +08:00
sender_ctrl.write().set_state(WSConnectState::Connecting);
2021-09-30 17:24:02 +08:00
tokio::spawn(async move {
match retry.await {
Ok(result) => {
2021-12-16 22:24:05 +08:00
let WSConnectResult {
stream,
handlers_fut,
sender,
} = result;
2021-12-05 14:04:25 +08:00
sender_ctrl.write().set_sender(sender);
2021-12-16 22:24:05 +08:00
sender_ctrl.write().set_state(WSConnectState::Connected);
let _ = ret.send(Ok(()));
2021-12-05 14:04:25 +08:00
spawn_stream_and_handlers(stream, handlers_fut, sender_ctrl.clone()).await;
2021-09-19 16:11:02 +08:00
},
2021-09-19 18:39:56 +08:00
Err(e) => {
2021-12-05 14:04:25 +08:00
sender_ctrl.write().set_error(e.clone());
let _ = ret.send(Err(ServerError::internal().context(e)));
2021-09-18 22:32:00 +08:00
},
2021-09-19 16:11:02 +08:00
}
2021-09-30 17:24:02 +08:00
});
rx.await?
2021-09-18 22:32:00 +08:00
}
2021-12-05 14:04:25 +08:00
pub async fn retry(&self, count: usize) -> Result<(), ServerError> {
2022-01-23 22:33:47 +08:00
if !self.sender_ctrl.read().is_disconnected() {
2021-12-05 14:04:25 +08:00
return Ok(());
}
let strategy = FixedInterval::from_millis(5000).take(count);
2021-10-05 17:54:11 +08:00
let addr = self
.addr
.read()
.as_ref()
2022-01-23 22:33:47 +08:00
.expect("Retry web socket connection failed, should call start_connect first")
2021-10-05 17:54:11 +08:00
.clone();
2021-12-05 14:04:25 +08:00
2021-10-05 17:54:11 +08:00
self.connect(addr, strategy).await
}
2021-12-16 22:24:05 +08:00
pub fn subscribe_state(&self) -> broadcast::Receiver<WSConnectState> { self.state_notify.subscribe() }
2022-01-07 17:37:11 +08:00
pub fn ws_message_sender(&self) -> Result<Arc<WSSender>, WSError> {
2021-12-05 14:04:25 +08:00
match self.sender_ctrl.read().sender() {
2022-01-23 22:33:47 +08:00
None => Err(WSError::internal().context("WebSocket is not initialized, should call connect first")),
2021-12-05 14:04:25 +08:00
Some(sender) => Ok(sender),
}
2021-09-16 18:31:25 +08:00
}
}
async fn spawn_stream_and_handlers(
2021-12-16 22:24:05 +08:00
stream: WSStream,
handlers: WSHandlerFuture,
sender_ctrl: Arc<RwLock<WSSenderController>>,
2021-09-30 17:24:02 +08:00
) {
tokio::select! {
result = stream => {
2021-12-05 14:04:25 +08:00
if let Err(e) = result {
sender_ctrl.write().set_error(e);
2021-09-30 17:24:02 +08:00
}
},
2021-11-03 15:37:38 +08:00
result = handlers => tracing::debug!("handlers completed {:?}", result),
2021-09-30 17:24:02 +08:00
};
}
2021-09-16 18:31:25 +08:00
#[pin_project]
2021-12-16 22:24:05 +08:00
pub struct WSHandlerFuture {
2021-09-16 18:31:25 +08:00
#[pin]
msg_rx: MsgReceiver,
2021-09-30 17:24:02 +08:00
handlers: Handlers,
2021-09-16 18:31:25 +08:00
}
2021-12-16 22:24:05 +08:00
impl WSHandlerFuture {
2021-09-30 17:24:02 +08:00
fn new(handlers: Handlers, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
2021-09-24 16:02:17 +08:00
fn handler_ws_message(&self, message: Message) {
2021-11-27 19:19:41 +08:00
if let Message::Binary(bytes) = message {
self.handle_binary_message(bytes)
2021-09-24 16:02:17 +08:00
}
}
fn handle_binary_message(&self, bytes: Vec<u8>) {
let bytes = Bytes::from(bytes);
2021-12-26 19:10:37 +08:00
match WebSocketRawMessage::try_from(bytes) {
2022-01-22 18:48:43 +08:00
Ok(message) => match self.handlers.get(&message.channel) {
2021-09-24 16:02:17 +08:00
None => log::error!("Can't find any handler for message: {:?}", message),
Some(handler) => handler.receive_message(message.clone()),
},
Err(e) => {
log::error!("Deserialize binary ws message failed: {:?}", e);
},
}
}
2021-09-16 18:31:25 +08:00
}
2021-12-16 22:24:05 +08:00
impl Future for WSHandlerFuture {
2021-09-16 18:31:25 +08:00
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
2021-09-18 22:32:00 +08:00
None => {
2021-09-19 16:11:02 +08:00
return Poll::Ready(());
2021-09-16 23:07:15 +08:00
},
2021-09-24 16:02:17 +08:00
Some(message) => self.handler_ws_message(message),
2021-09-16 23:07:15 +08:00
}
}
}
}
2021-09-19 18:39:56 +08:00
#[derive(Debug, Clone)]
2021-12-16 22:24:05 +08:00
pub struct WSSender {
2021-09-16 23:07:15 +08:00
ws_tx: MsgSender,
}
2021-12-16 22:24:05 +08:00
impl WSSender {
2021-12-26 19:10:37 +08:00
pub fn send_msg<T: Into<WebSocketRawMessage>>(&self, msg: T) -> Result<(), WSError> {
2021-09-19 18:39:56 +08:00
let msg = msg.into();
2021-09-27 23:23:23 +08:00
let _ = self
.ws_tx
.unbounded_send(msg.into())
2021-12-16 22:24:05 +08:00
.map_err(|e| WSError::internal().context(e))?;
2021-09-16 23:07:15 +08:00
Ok(())
}
2021-09-19 18:39:56 +08:00
2022-01-22 18:48:43 +08:00
pub fn send_text(&self, source: &WSChannel, text: &str) -> Result<(), WSError> {
2021-12-26 19:10:37 +08:00
let msg = WebSocketRawMessage {
2022-01-22 18:48:43 +08:00
channel: source.clone(),
2021-09-19 18:39:56 +08:00
data: text.as_bytes().to_vec(),
};
self.send_msg(msg)
}
2022-01-22 18:48:43 +08:00
pub fn send_binary(&self, source: &WSChannel, bytes: Vec<u8>) -> Result<(), WSError> {
2021-12-26 19:10:37 +08:00
let msg = WebSocketRawMessage {
2022-01-22 18:48:43 +08:00
channel: source.clone(),
data: bytes,
};
2021-09-19 18:39:56 +08:00
self.send_msg(msg)
}
2021-12-16 22:24:05 +08:00
pub fn send_disconnect(&self, reason: &str) -> Result<(), WSError> {
2021-09-19 18:39:56 +08:00
let frame = CloseFrame {
code: CloseCode::Normal,
reason: reason.to_owned().into(),
};
let msg = Message::Close(Some(frame));
2021-09-27 23:23:23 +08:00
let _ = self
.ws_tx
.unbounded_send(msg)
2021-12-16 22:24:05 +08:00
.map_err(|e| WSError::internal().context(e))?;
2021-09-19 18:39:56 +08:00
Ok(())
}
2021-09-16 23:07:15 +08:00
}
2021-12-16 22:24:05 +08:00
struct WSConnectAction {
addr: String,
handlers: Handlers,
}
2021-12-16 22:24:05 +08:00
impl Action for WSConnectAction {
2021-12-05 14:04:25 +08:00
type Future = Pin<Box<dyn Future<Output = Result<Self::Item, Self::Error>> + Send + Sync>>;
2021-12-16 22:24:05 +08:00
type Item = WSConnectResult;
type Error = WSError;
2021-12-05 14:04:25 +08:00
fn run(&mut self) -> Self::Future {
let addr = self.addr.clone();
let handlers = self.handlers.clone();
2021-12-16 22:24:05 +08:00
Box::pin(WSConnectActionFut::new(addr, handlers))
2021-12-05 14:04:25 +08:00
}
}
2021-12-16 22:24:05 +08:00
/// Everything a successful connection attempt hands back to the controller.
struct WSConnectResult {
    // The connected socket stream, driven by `spawn_stream_and_handlers`.
    stream: WSStream,
    // Dispatcher for inbound messages belonging to this connection.
    handlers_fut: WSHandlerFuture,
    // Handle for queuing outgoing frames on this connection.
    sender: WSSender,
}
#[pin_project]
2021-12-16 22:24:05 +08:00
struct WSConnectActionFut {
addr: String,
#[pin]
2021-12-16 22:24:05 +08:00
conn: WSConnectionFuture,
handlers_fut: Option<WSHandlerFuture>,
sender: Option<WSSender>,
}
2021-12-16 22:24:05 +08:00
impl WSConnectActionFut {
fn new(addr: String, handlers: Handlers) -> Self {
// Stream User
// ┌───────────────┐ ┌──────────────┐
// ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
// │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │ │
// └──────┘ │ └─────────┘ │ └────────┘ │ └────────┘ │
// ▲ │ │ │ │
// │ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
// └─────────┼──│ws_write │◀─┼────│ ws_rx │◀──┼──│ ws_tx │ │
// │ └─────────┘ │ └────────┘ │ └────────┘ │
// └───────────────┘ └──────────────┘
let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
2021-12-16 22:24:05 +08:00
let sender = WSSender { ws_tx };
let handlers_fut = WSHandlerFuture::new(handlers, msg_rx);
let conn = WSConnectionFuture::new(msg_tx, ws_rx, addr.clone());
Self {
addr,
conn,
handlers_fut: Some(handlers_fut),
sender: Some(sender),
}
}
}
2021-12-16 22:24:05 +08:00
impl Future for WSConnectActionFut {
type Output = Result<WSConnectResult, WSError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
match ready!(this.conn.as_mut().poll(cx)) {
Ok(stream) => {
let handlers_fut = this.handlers_fut.take().expect("Only take once");
let sender = this.sender.take().expect("Only take once");
2021-12-16 22:24:05 +08:00
Poll::Ready(Ok(WSConnectResult {
stream,
handlers_fut,
sender,
}))
},
2021-10-05 19:32:58 +08:00
Err(e) => Poll::Ready(Err(e)),
}
}
}
2021-12-05 14:04:25 +08:00
/// Lifecycle state of the web socket connection.
#[derive(Clone, Eq, PartialEq)]
pub enum WSConnectState {
    /// No connection attempt has been made yet.
    Init,
    /// A connection attempt is in progress.
    Connecting,
    /// The socket is connected.
    Connected,
    /// The connection was stopped or failed.
    Disconnected,
}

impl std::fmt::Display for WSConnectState {
    /// Writes the variant name. `Connected` and `Connecting` were previously swapped here.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            WSConnectState::Init => "Init",
            WSConnectState::Connecting => "Connecting",
            WSConnectState::Connected => "Connected",
            WSConnectState::Disconnected => "Disconnected",
        };
        f.write_str(name)
    }
}

impl std::fmt::Debug for WSConnectState {
    // Same text as `Display`; delegate directly instead of allocating via `format!`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) }
}
2021-12-16 22:24:05 +08:00
struct WSSenderController {
state: WSConnectState,
state_notify: Arc<broadcast::Sender<WSConnectState>>,
sender: Option<Arc<WSSender>>,
2021-12-05 14:04:25 +08:00
}
2021-12-16 22:24:05 +08:00
impl WSSenderController {
fn set_sender(&mut self, sender: WSSender) { self.sender = Some(Arc::new(sender)); }
2021-12-05 14:04:25 +08:00
2021-12-16 22:24:05 +08:00
fn set_state(&mut self, state: WSConnectState) {
if state != WSConnectState::Connected {
2021-12-05 14:04:25 +08:00
self.sender = None;
}
2021-12-14 20:50:07 +08:00
self.state = state;
let _ = self.state_notify.send(self.state.clone());
2021-12-05 14:04:25 +08:00
}
2021-12-16 22:24:05 +08:00
fn set_error(&mut self, error: WSError) {
2021-12-05 14:04:25 +08:00
log::error!("{:?}", error);
2021-12-16 22:24:05 +08:00
self.set_state(WSConnectState::Disconnected);
2021-12-05 14:04:25 +08:00
}
2021-12-16 22:24:05 +08:00
fn sender(&self) -> Option<Arc<WSSender>> { self.sender.clone() }
2021-12-05 14:04:25 +08:00
2022-01-23 22:33:47 +08:00
#[allow(dead_code)]
2021-12-16 22:24:05 +08:00
fn is_connecting(&self) -> bool { self.state == WSConnectState::Connecting }
2021-12-05 14:04:25 +08:00
2022-01-23 22:33:47 +08:00
fn is_disconnected(&self) -> bool { self.state == WSConnectState::Disconnected }
2021-12-05 14:04:25 +08:00
}
2021-12-16 22:24:05 +08:00
impl std::default::Default for WSSenderController {
2021-12-05 14:04:25 +08:00
fn default() -> Self {
let (state_notify, _) = broadcast::channel(16);
2021-12-16 22:24:05 +08:00
WSSenderController {
state: WSConnectState::Init,
2021-12-05 14:04:25 +08:00
state_notify: Arc::new(state_notify),
sender: None,
}
}
}