212 lines
6.1 KiB
Rust

use {
super::{event, stats::ServerStats},
async_trait::async_trait,
eyre::Result,
futures::{stream::SplitSink, StreamExt},
lool::logger::{error, info},
std::sync::Arc,
tokio::sync::Mutex,
tokio_tungstenite::{
accept_hdr_async,
tungstenite::{handshake::server::Callback, Message},
},
};
pub use {
tokio::net::{TcpListener, TcpStream},
tokio_tungstenite::{
tungstenite::{
handshake::server::{ErrorResponse, Request, Response},
Error,
},
WebSocketStream,
},
tokio_util::sync::CancellationToken,
};
/// Write half of an accepted websocket connection: the `SplitSink` obtained by splitting a
/// `WebSocketStream<TcpStream>`, through which tungstenite `Message`s are sent to the client.
pub type Outgoing = SplitSink<WebSocketStream<TcpStream>, Message>;
/// Application hook that receives the events parsed from incoming websocket messages.
///
/// The server keeps one dispatcher, clones it for every accepted connection and calls
/// `dispatch` once for each message that deserializes into an `event::WsEvent`.
#[async_trait]
pub trait EventDispatcher: Send {
/// Handles a single event received on a connection.
///
/// * `event` - the `event` field of the parsed `event::WsEvent`.
/// * `data` - the `data` field of the parsed `event::WsEvent`.
/// * `outgoing` - shared, mutex-guarded write half of the connection the event arrived on.
/// * `conn_id` - string form of the UUID the server generated for that connection.
/// * `cancel_token` - a clone of the child token derived from the connection's own
///   cancellation token.
///
/// Returning `Err` makes the server log the error; the connection keeps being served.
async fn dispatch(
&self,
event: String,
data: event::Data,
outgoing: Arc<Mutex<Outgoing>>,
conn_id: String,
cancel_token: CancellationToken,
) -> Result<()>;
}
/// 🐎 » Socket Server
/// --
///
/// A websocket gateway server that listens for incoming connections and dispatches events to the
/// appropriate event handlers by using the provided `EventDispatcher`.
///
/// The dispatcher type `ED` is cloned once per accepted connection, and each connection is
/// served on its own spawned tokio task.
///
/// ### Example
/// See `examples/socket.rs` for a complete example.
pub struct Server<ED>
where
ED: EventDispatcher + Clone + Send + Sync + 'static,
{
// shared counters; the current-clients count is incremented when a handshake succeeds and
// decremented when the connection task finishes
stats: Arc<ServerStats>,
// TCP listener bound in `new`, polled for new connections in `start`
listener: TcpListener,
// host and port as given to `new`; only used for the startup log line
host: String,
port: String,
// prototype dispatcher, cloned for every accepted connection
event_dispatcher: ED,
}
impl<ED> Server<ED>
where
ED: EventDispatcher + Clone + Send + Sync + 'static,
{
pub async fn new(host: &str, port: &str, event_dispatcher: ED) -> std::io::Result<Self> {
let listener = TcpListener::bind(format!("{}:{}", host, port)).await?;
Ok(Self {
listener,
event_dispatcher,
host: host.to_string(),
port: port.to_string(),
stats: Arc::new(ServerStats::new()),
})
}
/// #### 🐎 » start without setting a callback
///
/// will use a _noop_ callback that will do nothing.
pub async fn start_no_cb(&mut self) {
let noop_cb = |_: &Request, response: Response| Ok(response);
self.start(noop_cb).await;
}
/// #### 🐎 » start the server
///
/// Starts the server with a handshake callback. Usefull for customizing the
/// handshake process, e.g. checking headers, etc.
///
/// **Tip:** if you don't need to customize the handshake process, use
/// `start_no_cb` instead.
pub async fn start<HCb>(&mut self, cb: HCb)
where
HCb: Callback + Unpin + Clone,
{
info!("Started Rustler WS Server on {}:{}", self.host, self.port);
let stats = &self.stats.clone();
while let Ok((stream, peer)) = self.listener.accept().await {
let dispatcher = self.event_dispatcher.clone();
let cb = cb.clone();
info!("Incoming connection from: {}", peer);
// call the handshake callback
let ws_stream = accept_hdr_async(stream, cb).await;
if let Ok(ws_stream) = ws_stream {
stats.inc_current_clients();
let stats = stats.clone();
let conn_id = uuid::Uuid::new_v4();
// main cancellation token that will be used to cancel the connection
let cancel_tkn = CancellationToken::new();
tokio::spawn(async move {
match Server::handle_connection(
ws_stream,
dispatcher,
conn_id,
// each connection will have a child token (child tokens can't cancel parent
// tokens but are cancelled when the parent token is cancelled)
cancel_tkn.child_token(),
)
.await
{
Ok(_) => {
cancel_tkn.cancel();
info!("Connection {} closed", conn_id);
}
Err(e) => error!("Error handling connection: {:?}", e),
};
// decrement client count
stats.clone().dec_current_clients();
info!("{:?}", stats);
});
}
info!("{:?}", stats);
}
}
/// subscribe to incoming messages
async fn handle_connection(
stream: WebSocketStream<TcpStream>,
event_dispatcher: ED,
conn_id: uuid::Uuid,
cancel_sgn: CancellationToken,
) -> Result<()> {
let (outgoing, mut incoming) = stream.split();
let synced_outgoing = Arc::new(Mutex::new(outgoing));
while let Some(msg) = incoming.next().await {
Server::handle_message(
msg?,
&event_dispatcher,
synced_outgoing.clone(),
conn_id,
cancel_sgn.clone(),
)
.await?;
}
Ok(())
}
/// handle an incoming message
async fn handle_message(
msg: Message,
event_dispatcher: &ED,
outgoing: Arc<Mutex<Outgoing>>,
conn_id: uuid::Uuid,
cancel_sgn: CancellationToken,
) -> Result<HandlingResult> {
if msg.is_text() || msg.is_binary() {
if let Ok(event) = serde_json::from_str::<event::WsEvent>(&msg.to_string()) {
let outgoing = Arc::clone(&outgoing);
let result = event_dispatcher
.dispatch(
event.event,
event.data,
outgoing,
conn_id.into(),
cancel_sgn,
)
.await;
match result {
Ok(_) => {}
Err(e) => {
error!("Error dispatching event: {:?}", e);
}
};
}
}
if msg.is_close() {
return Ok(HandlingResult::Closed);
}
// TODO: should we handle ping/pong messages?
Ok(HandlingResult::Handled)
}
}
/// Outcome of processing a single incoming websocket message in `handle_message`.
///
/// NOTE(review): `handle_connection` currently discards this value.
#[derive(PartialEq)]
enum HandlingResult {
/// The message was processed and was not a close frame.
Handled,
/// The message was a close frame (`Message::is_close` returned `true`).
Closed,
}