feat: add conn cancellation token

This commit is contained in:
Lucas Colombo 2024-06-02 09:08:41 -03:00
parent 9a747793a4
commit 2d16c6230f
Signed by: lucas
GPG Key ID: EF34786CFEFFAE35
2 changed files with 40 additions and 6 deletions

View File

@ -6,7 +6,9 @@ use {
rustler_core::{
bus::{self, SubscriberTrait},
rustlers::Quote,
socket::{self, event, Error, EventDispatcher, Outgoing, Request, Response},
socket::{
self, event, CancellationToken, Error, EventDispatcher, Outgoing, Request, Response,
},
},
std::sync::Arc,
tokio::{join, sync::Mutex},
@ -23,6 +25,7 @@ impl EventDispatcher for Dispatcher {
data: event::Data,
outgoing: Arc<Mutex<Outgoing>>,
conn_id: String,
_cancel_token: CancellationToken,
) -> Result<()> {
info!("Event: {}", event);
info!("Data: {:?}", data);

View File

@ -21,6 +21,7 @@ pub use {
},
WebSocketStream,
},
tokio_util::sync::CancellationToken,
};
pub type Outgoing = SplitSink<WebSocketStream<TcpStream>, Message>;
@ -33,6 +34,7 @@ pub trait EventDispatcher: Send {
data: event::Data,
outgoing: Arc<Mutex<Outgoing>>,
conn_id: String,
cancel_token: CancellationToken,
) -> Result<()>;
}
@ -105,9 +107,24 @@ where
let stats = stats.clone();
let conn_id = uuid::Uuid::new_v4();
// main cancellation token that will be used to cancel the connection
let cancel_tkn = CancellationToken::new();
tokio::spawn(async move {
match Server::handle_connection(ws_stream, dispatcher, conn_id).await {
Ok(_) => info!("Connection {} closed", conn_id),
match Server::handle_connection(
ws_stream,
dispatcher,
conn_id,
// each connection will have a child token (child tokens can't cancel parent
// tokens but are cancelled when the parent token is cancelled)
cancel_tkn.child_token(),
)
.await
{
Ok(_) => {
cancel_tkn.cancel();
info!("Connection {} closed", conn_id);
}
Err(e) => error!("Error handling connection: {:?}", e),
};
@ -126,12 +143,19 @@ where
stream: WebSocketStream<TcpStream>,
event_dispatcher: ED,
conn_id: uuid::Uuid,
cancel_sgn: CancellationToken,
) -> Result<()> {
let (outgoing, mut incoming) = stream.split();
let synced_outgoing = Arc::new(Mutex::new(outgoing));
while let Some(msg) = incoming.next().await {
Server::handle_message(msg?, &event_dispatcher, synced_outgoing.clone(), conn_id)
Server::handle_message(
msg?,
&event_dispatcher,
synced_outgoing.clone(),
conn_id,
cancel_sgn.clone(),
)
.await?;
}
@ -144,12 +168,19 @@ where
event_dispatcher: &ED,
outgoing: Arc<Mutex<Outgoing>>,
conn_id: uuid::Uuid,
cancel_sgn: CancellationToken,
) -> Result<HandlingResult> {
if msg.is_text() || msg.is_binary() {
if let Ok(event) = serde_json::from_str::<event::WsEvent>(&msg.to_string()) {
let outgoing = Arc::clone(&outgoing);
let result = event_dispatcher
.dispatch(event.event, event.data, outgoing, conn_id.into())
.dispatch(
event.event,
event.data,
outgoing,
conn_id.into(),
cancel_sgn,
)
.await;
match result {