Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use tokio::{

pub(crate) type PacketSender = UnboundedSender<NetworkPacket>;
pub(crate) type PacketReceiver = UnboundedReceiver<NetworkPacket>;
pub(crate) type SessionCollection = std::sync::Arc<tokio::sync::Mutex<AHashMap<NetworkTuple, PacketSender>>>;
pub(crate) type SessionCollection = AHashMap<NetworkTuple, PacketSender>;

mod error;
mod packet;
Expand Down Expand Up @@ -105,7 +105,8 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
mut device: Device,
accept_sender: UnboundedSender<IpStackStream>,
) -> JoinHandle<Result<()>> {
let sessions: SessionCollection = std::sync::Arc::new(tokio::sync::Mutex::new(AHashMap::new()));
let mut sessions: SessionCollection = AHashMap::new();
let (session_remove_tx, mut session_remove_rx) = mpsc::unbounded_channel::<NetworkTuple>();
let pi = config.packet_information;
let offset = if pi && cfg!(unix) { 4 } else { 0 };
let mut buffer = vec![0_u8; u16::MAX as usize + offset];
Expand All @@ -115,8 +116,7 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
loop {
select! {
Ok(n) = device.read(&mut buffer) => {
let u = up_pkt_sender.clone();
if let Err(e) = process_device_read(&buffer[offset..n], sessions.clone(), u, &config, &accept_sender).await {
if let Err(e) = process_device_read(&buffer[offset..n], &mut sessions, &session_remove_tx, &up_pkt_sender, &config, &accept_sender).await {
let io_err: std::io::Error = e.into();
if io_err.kind() == std::io::ErrorKind::ConnectionRefused {
log::trace!("Received junk data: {io_err}");
Expand All @@ -125,6 +125,10 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
}
}
}
Some(network_tuple) = session_remove_rx.recv() => {
sessions.remove(&network_tuple);
log::debug!("session destroyed: {network_tuple}");
}
Some(packet) = up_pkt_receiver.recv() => {
process_upstream_recv(packet, &mut device, #[cfg(unix)]pi).await?;
}
Expand All @@ -135,8 +139,9 @@ fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(

async fn process_device_read(
data: &[u8],
sessions: SessionCollection,
up_pkt_sender: PacketSender,
sessions: &mut SessionCollection,
session_remove_tx: &UnboundedSender<NetworkTuple>,
up_pkt_sender: &PacketSender,
config: &IpStackConfig,
accept_sender: &UnboundedSender<IpStackStream>,
) -> Result<()> {
Expand All @@ -153,27 +158,28 @@ async fn process_device_read(
packet.payload.unwrap_or_default(),
&packet.ip,
config.mtu,
up_pkt_sender,
up_pkt_sender.clone(),
));
accept_sender.send(stream)?;
return Ok(());
}

let sessions_clone = sessions.clone();
let network_tuple = packet.network_tuple();
match sessions.lock().await.entry(network_tuple) {
match sessions.entry(network_tuple) {
std::collections::hash_map::Entry::Occupied(entry) => {
let len = packet.payload.as_ref().map(|p| p.len()).unwrap_or(0);
log::trace!("packet sent to stream: {network_tuple} len {len}");
entry.get().send(packet).map_err(std::io::Error::other)?;
}
std::collections::hash_map::Entry::Vacant(entry) => {
let (tx, rx) = tokio::sync::oneshot::channel::<()>();
let ip_stack_stream = create_stream(packet, config, up_pkt_sender, Some(tx))?;
let ip_stack_stream = create_stream(packet, config, up_pkt_sender.clone(), Some(tx))?;
let session_remove_tx = session_remove_tx.clone();
tokio::spawn(async move {
rx.await.ok();
sessions_clone.lock().await.remove(&network_tuple);
log::debug!("session destroyed: {network_tuple}");
if let Err(e) = session_remove_tx.send(network_tuple) {
log::error!("Failed to send session removal for {network_tuple}: {e}");
}
});
let packet_sender = ip_stack_stream.stream_sender()?;
accept_sender.send(ip_stack_stream)?;
Expand Down