feat: add conn cancellation token

This commit is contained in:
Lucas Colombo 2024-06-02 09:08:41 -03:00
parent 9a747793a4
commit 2d16c6230f
Signed by: lucas
GPG Key ID: EF34786CFEFFAE35
2 changed files with 40 additions and 6 deletions

View File

@ -6,7 +6,9 @@ use {
rustler_core::{ rustler_core::{
bus::{self, SubscriberTrait}, bus::{self, SubscriberTrait},
rustlers::Quote, rustlers::Quote,
socket::{self, event, Error, EventDispatcher, Outgoing, Request, Response}, socket::{
self, event, CancellationToken, Error, EventDispatcher, Outgoing, Request, Response,
},
}, },
std::sync::Arc, std::sync::Arc,
tokio::{join, sync::Mutex}, tokio::{join, sync::Mutex},
@ -23,6 +25,7 @@ impl EventDispatcher for Dispatcher {
data: event::Data, data: event::Data,
outgoing: Arc<Mutex<Outgoing>>, outgoing: Arc<Mutex<Outgoing>>,
conn_id: String, conn_id: String,
_cancel_token: CancellationToken,
) -> Result<()> { ) -> Result<()> {
info!("Event: {}", event); info!("Event: {}", event);
info!("Data: {:?}", data); info!("Data: {:?}", data);

View File

@ -21,6 +21,7 @@ pub use {
}, },
WebSocketStream, WebSocketStream,
}, },
tokio_util::sync::CancellationToken,
}; };
pub type Outgoing = SplitSink<WebSocketStream<TcpStream>, Message>; pub type Outgoing = SplitSink<WebSocketStream<TcpStream>, Message>;
@ -33,6 +34,7 @@ pub trait EventDispatcher: Send {
data: event::Data, data: event::Data,
outgoing: Arc<Mutex<Outgoing>>, outgoing: Arc<Mutex<Outgoing>>,
conn_id: String, conn_id: String,
cancel_token: CancellationToken,
) -> Result<()>; ) -> Result<()>;
} }
@ -105,9 +107,24 @@ where
let stats = stats.clone(); let stats = stats.clone();
let conn_id = uuid::Uuid::new_v4(); let conn_id = uuid::Uuid::new_v4();
// main cancellation token that will be used to cancel the connection
let cancel_tkn = CancellationToken::new();
tokio::spawn(async move { tokio::spawn(async move {
match Server::handle_connection(ws_stream, dispatcher, conn_id).await { match Server::handle_connection(
Ok(_) => info!("Connection {} closed", conn_id), ws_stream,
dispatcher,
conn_id,
// each connection will have a child token (child tokens can't cancel parent
// tokens but are cancelled when the parent token is cancelled)
cancel_tkn.child_token(),
)
.await
{
Ok(_) => {
cancel_tkn.cancel();
info!("Connection {} closed", conn_id);
}
Err(e) => error!("Error handling connection: {:?}", e), Err(e) => error!("Error handling connection: {:?}", e),
}; };
@ -126,13 +143,20 @@ where
stream: WebSocketStream<TcpStream>, stream: WebSocketStream<TcpStream>,
event_dispatcher: ED, event_dispatcher: ED,
conn_id: uuid::Uuid, conn_id: uuid::Uuid,
cancel_sgn: CancellationToken,
) -> Result<()> { ) -> Result<()> {
let (outgoing, mut incoming) = stream.split(); let (outgoing, mut incoming) = stream.split();
let synced_outgoing = Arc::new(Mutex::new(outgoing)); let synced_outgoing = Arc::new(Mutex::new(outgoing));
while let Some(msg) = incoming.next().await { while let Some(msg) = incoming.next().await {
Server::handle_message(msg?, &event_dispatcher, synced_outgoing.clone(), conn_id) Server::handle_message(
.await?; msg?,
&event_dispatcher,
synced_outgoing.clone(),
conn_id,
cancel_sgn.clone(),
)
.await?;
} }
Ok(()) Ok(())
@ -144,12 +168,19 @@ where
event_dispatcher: &ED, event_dispatcher: &ED,
outgoing: Arc<Mutex<Outgoing>>, outgoing: Arc<Mutex<Outgoing>>,
conn_id: uuid::Uuid, conn_id: uuid::Uuid,
cancel_sgn: CancellationToken,
) -> Result<HandlingResult> { ) -> Result<HandlingResult> {
if msg.is_text() || msg.is_binary() { if msg.is_text() || msg.is_binary() {
if let Ok(event) = serde_json::from_str::<event::WsEvent>(&msg.to_string()) { if let Ok(event) = serde_json::from_str::<event::WsEvent>(&msg.to_string()) {
let outgoing = Arc::clone(&outgoing); let outgoing = Arc::clone(&outgoing);
let result = event_dispatcher let result = event_dispatcher
.dispatch(event.event, event.data, outgoing, conn_id.into()) .dispatch(
event.event,
event.data,
outgoing,
conn_id.into(),
cancel_sgn,
)
.await; .await;
match result { match result {