diff --git a/src/lib.rs b/src/lib.rs index ba01b71..1982d95 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ use tokio::{ pub(crate) type PacketSender = UnboundedSender; pub(crate) type PacketReceiver = UnboundedReceiver; -pub(crate) type SessionCollection = std::sync::Arc>>; +pub(crate) type SessionCollection = AHashMap; mod error; mod packet; @@ -105,7 +105,8 @@ fn run( mut device: Device, accept_sender: UnboundedSender, ) -> JoinHandle> { - let sessions: SessionCollection = std::sync::Arc::new(tokio::sync::Mutex::new(AHashMap::new())); + let mut sessions: SessionCollection = AHashMap::new(); + let (session_remove_tx, mut session_remove_rx) = mpsc::unbounded_channel::(); let pi = config.packet_information; let offset = if pi && cfg!(unix) { 4 } else { 0 }; let mut buffer = vec![0_u8; u16::MAX as usize + offset]; @@ -115,8 +116,7 @@ fn run( loop { select! { Ok(n) = device.read(&mut buffer) => { - let u = up_pkt_sender.clone(); - if let Err(e) = process_device_read(&buffer[offset..n], sessions.clone(), u, &config, &accept_sender).await { + if let Err(e) = process_device_read(&buffer[offset..n], &mut sessions, &session_remove_tx, &up_pkt_sender, &config, &accept_sender).await { let io_err: std::io::Error = e.into(); if io_err.kind() == std::io::ErrorKind::ConnectionRefused { log::trace!("Received junk data: {io_err}"); @@ -125,6 +125,10 @@ fn run( } } } + Some(network_tuple) = session_remove_rx.recv() => { + sessions.remove(&network_tuple); + log::debug!("session destroyed: {network_tuple}"); + } Some(packet) = up_pkt_receiver.recv() => { process_upstream_recv(packet, &mut device, #[cfg(unix)]pi).await?; } @@ -135,8 +139,9 @@ fn run( async fn process_device_read( data: &[u8], - sessions: SessionCollection, - up_pkt_sender: PacketSender, + sessions: &mut SessionCollection, + session_remove_tx: &UnboundedSender, + up_pkt_sender: &PacketSender, config: &IpStackConfig, accept_sender: &UnboundedSender, ) -> Result<()> { @@ -153,15 +158,14 @@ async fn process_device_read( packet.payload.unwrap_or_default(), &packet.ip, config.mtu, - up_pkt_sender, + up_pkt_sender.clone(), )); accept_sender.send(stream)?; return Ok(()); } - let sessions_clone = sessions.clone(); let network_tuple = packet.network_tuple(); - match sessions.lock().await.entry(network_tuple) { + match sessions.entry(network_tuple) { std::collections::hash_map::Entry::Occupied(entry) => { let len = packet.payload.as_ref().map(|p| p.len()).unwrap_or(0); log::trace!("packet sent to stream: {network_tuple} len {len}"); @@ -169,11 +173,13 @@ async fn process_device_read( } std::collections::hash_map::Entry::Vacant(entry) => { let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - let ip_stack_stream = create_stream(packet, config, up_pkt_sender, Some(tx))?; + let ip_stack_stream = create_stream(packet, config, up_pkt_sender.clone(), Some(tx))?; + let session_remove_tx = session_remove_tx.clone(); tokio::spawn(async move { rx.await.ok(); - sessions_clone.lock().await.remove(&network_tuple); - log::debug!("session destroyed: {network_tuple}"); + if let Err(e) = session_remove_tx.send(network_tuple) { + log::error!("Failed to send session removal for {network_tuple}: {e}"); + } }); let packet_sender = ip_stack_stream.stream_sender()?; accept_sender.send(ip_stack_stream)?;