From 1912fd87bd15249cca770e27470cd26962805353 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Tue, 18 Feb 2025 23:49:16 +0800 Subject: [PATCH 01/35] refine code --- Cargo.toml | 2 +- examples/tun2.rs | 8 ++------ src/lib.rs | 19 +++++++------------ src/packet.rs | 8 ++++++++ src/stream/unknown.rs | 27 +++++++++++++-------------- 5 files changed, 31 insertions(+), 33 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index edbd11c..64d9971 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,8 +34,8 @@ udp-stream = { version = "0.0", default-features = false } # Benchmarks criterion = { version = "0.5" } -[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies] tun = { version = "0.7.13", features = ["async"], default-features = false } + [target.'cfg(target_os = "windows")'.dev-dependencies] wintun = { version = "0.5", default-features = false } diff --git a/examples/tun2.rs b/examples/tun2.rs index 3c17db3..535e02b 100644 --- a/examples/tun2.rs +++ b/examples/tun2.rs @@ -28,7 +28,7 @@ //! use clap::Parser; -use etherparse::{IcmpEchoHeader, Icmpv4Header}; +use etherparse::Icmpv4Header; use ipstack::{stream::IpStackStream, IpNumber}; use std::net::{Ipv4Addr, SocketAddr}; use tokio::net::TcpStream; @@ -154,12 +154,8 @@ async fn main() -> Result<(), Box> { let n = number; if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; - if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { + if let etherparse::Icmpv4Type::EchoRequest(echo) = icmp_header.icmp_type { log::info!("#{n} ICMPv4 echo"); - let echo = IcmpEchoHeader { - id: req.id, - seq: req.seq, - }; let mut resp = Icmpv4Header::new(etherparse::Icmpv4Type::EchoReply(echo)); resp.update_checksum(req_payload); let mut payload = resp.to_bytes().to_vec(); diff --git a/src/lib.rs b/src/lib.rs index 93ab830..b53b4a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,16 +93,14 @@ pub struct IpStack { } impl IpStack { - pub fn new(config: IpStackConfig, device: D) -> IpStack + pub fn new(config: IpStackConfig, device: Device) -> IpStack where - D: AsyncRead + AsyncWrite + Unpin + Send + 'static, + Device: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let (accept_sender, accept_receiver) = mpsc::unbounded_channel::(); - let handle = run(config, device, accept_sender); - IpStack { accept_receiver, - handle, + handle: run(config, device, accept_sender), } } @@ -114,14 +112,11 @@ impl IpStack { } } -fn run( +fn run( config: IpStackConfig, - mut device: D, + mut device: Device, accept_sender: UnboundedSender, -) -> JoinHandle> -where - D: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ +) -> JoinHandle> { let mut sessions: SessionCollection = AHashMap::new(); let pi = config.packet_information; let offset = if pi && cfg!(unix) { 4 } else { 0 }; @@ -182,7 +177,7 @@ fn process_device_read( match sessions.entry(packet.network_tuple()) { Occupied(mut entry) => { if let Err(e) = entry.get().send(packet) { - trace!("New stream because: {}", e); + trace!("New stream \"{}\" because: \"{}\"", e.0.network_tuple(), e); create_stream(e.0, config, pkt_sender).map(|s| { entry.insert(s.0); s.1 diff --git a/src/packet.rs b/src/packet.rs index 540d4ea..2906942 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -8,6 +8,14 @@ pub struct NetworkTuple { pub dst: SocketAddr, pub tcp: bool, } + +impl std::fmt::Display for NetworkTuple { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let tcp = if self.tcp { "TCP" } else { "UDP" }; + write!(f, "{}: {} <> {}", tcp, self.src, self.dst) + } +} + pub mod tcp_flags { pub const CWR: u8 = 0b10000000; pub const ECE: u8 = 0b01000000; diff --git a/src/stream/unknown.rs b/src/stream/unknown.rs index 838d93f..b9889aa 100644 --- a/src/stream/unknown.rs +++ b/src/stream/unknown.rs @@ -1,9 +1,9 @@ use crate::{ packet::{IpHeader, NetworkPacket, TransportHeader}, - PacketSender, TTL, + IpStackError, PacketSender, TTL, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header}; -use std::{io::Error, mem, net::IpAddr}; +use std::net::IpAddr; pub struct IpStackUnknownTransport { src_addr: IpAddr, @@ -48,32 +48,31 @@ impl IpStackUnknownTransport { pub fn ip_protocol(&self) -> IpNumber { self.protocol } - pub fn send(&self, mut payload: Vec) -> Result<(), Error> { + pub fn send(&self, mut payload: Vec) -> std::io::Result<()> { loop { let packet = self.create_rev_packet(&mut payload)?; - self.packet_sender - .send(packet) - .map_err(|_| Error::new(std::io::ErrorKind::Other, "send error"))?; + self.packet_sender.send(packet).map_err(|e| { + std::io::Error::new(std::io::ErrorKind::Other, format!("send error: {}", e)) + })?; if payload.is_empty() { return Ok(()); } } } - pub fn create_rev_packet(&self, payload: &mut Vec) -> Result { + pub fn create_rev_packet(&self, payload: &mut Vec) -> std::io::Result { match (self.dst_addr, self.src_addr) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets()) - .map_err(crate::IpStackError::from)?; + .map_err(IpStackError::from)?; let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16); let p = if payload.len() > line_buffer as usize { payload.drain(0..line_buffer as usize).collect::>() } else { - mem::take(payload) + std::mem::take(payload) }; - ip_h.set_payload_len(p.len()) - .map_err(crate::IpStackError::from)?; + ip_h.set_payload_len(p.len()).map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv4(ip_h), transport: TransportHeader::Unknown, @@ -91,13 +90,13 @@ impl IpStackUnknownTransport { destination: src.octets(), }; let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16); - payload.truncate(line_buffer as usize); - ip_h.payload_length = payload.len() as u16; let p = if payload.len() > line_buffer as usize { payload.drain(0..line_buffer as usize).collect::>() } else { - mem::take(payload) + std::mem::take(payload) }; + ip_h.set_payload_length(p.len()) + .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv6(ip_h), transport: TransportHeader::Unknown, From 5c2aa856695d870939ce28eecf77d11b58cb063e Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Thu, 20 Feb 2025 15:06:09 +0800 Subject: [PATCH 02/35] improve InvalidTcpPacket detail --- src/error.rs | 4 ++-- src/lib.rs | 13 ++++++++----- src/packet.rs | 41 ++++++++++++++++++++++++++++++++++++++++- src/stream/tcp.rs | 2 +- 4 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/error.rs b/src/error.rs index 360badd..319a439 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,8 +13,8 @@ pub enum IpStackError { #[error("ValueTooBigError {0}")] ValueTooBigErrorUsize(#[from] etherparse::err::ValueTooBigError), - #[error("Invalid Tcp packet")] - InvalidTcpPacket, + #[error("Invalid Tcp packet {0}")] + InvalidTcpPacket(crate::packet::TcpHeaderWrapper), #[error("IO error: {0}")] IoError(#[from] std::io::Error), diff --git a/src/lib.rs b/src/lib.rs index b53b4a3..ce3361a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,8 @@ mod packet; pub mod stream; pub use self::error::{IpStackError, Result}; -pub use etherparse::IpNumber; +pub use self::packet::TcpHeaderWrapper; +pub use ::etherparse::IpNumber; const DROP_TTL: u8 = 0; @@ -177,7 +178,7 @@ fn process_device_read( match sessions.entry(packet.network_tuple()) { Occupied(mut entry) => { if let Err(e) = entry.get().send(packet) { - trace!("New stream \"{}\" because: \"{}\"", e.0.network_tuple(), e); + log::debug!("New stream \"{}\" because: \"{}\"", e.0.network_tuple(), e); create_stream(e.0, config, pkt_sender).map(|s| { entry.insert(s.0); s.1 @@ -210,8 +211,8 @@ fn create_stream( ) { Ok(stream) => Some((stream.stream_sender(), IpStackStream::Tcp(stream))), Err(e) => { - if matches!(e, IpStackError::InvalidTcpPacket) { - trace!("Invalid TCP packet"); + if matches!(e, IpStackError::InvalidTcpPacket(_)) { + log::debug!("{e}"); } else { error!("IpStackTcpStream::new failed \"{}\"", e); } @@ -246,7 +247,9 @@ where D: AsyncWrite + Unpin + 'static, { if packet.ttl() == 0 { - sessions.remove(&packet.reverse_network_tuple()); + let network_tuple = packet.reverse_network_tuple(); + sessions.remove(&network_tuple); + log::trace!("session removed: {}", network_tuple); return Ok(()); } #[allow(unused_mut)] diff --git a/src/packet.rs b/src/packet.rs index 2906942..312dd4e 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -154,10 +154,49 @@ impl NetworkPacket { } #[derive(Debug, Clone)] -pub(super) struct TcpHeaderWrapper { +pub struct TcpHeaderWrapper { header: TcpHeader, } +impl std::fmt::Display for TcpHeaderWrapper { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut flags = String::new(); + if self.header.cwr { + flags.push_str("CWR "); + } + if self.header.ece { + flags.push_str("ECE "); + } + if self.header.urg { + flags.push_str("URG "); + } + if self.header.ack { + flags.push_str("ACK "); + } + if self.header.psh { + flags.push_str("PSH "); + } + if self.header.rst { + flags.push_str("RST "); + } + if self.header.syn { + flags.push_str("SYN "); + } + if self.header.fin { + flags.push_str("FIN "); + } + write!( + f, + "TcpHeader {{ src_port: {}, dst_port: {}, seq: {}, ack: {}, flags: {} }}", + self.header.source_port, + self.header.destination_port, + self.header.sequence_number, + self.header.acknowledgment_number, + flags.trim() + ) + } +} + impl TcpHeaderWrapper { pub fn inner(&self) -> &TcpHeader { &self.header diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 77e2547..cd0a0de 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -83,7 +83,7 @@ impl IpStackTcpStream { warn!("Error sending RST/ACK packet: {:?}", err); } } - Err(IpStackError::InvalidTcpPacket) + Err(IpStackError::InvalidTcpPacket(tcp.clone())) } fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { From 5be5321411500d843398c43c7b61a6ad53d78e3a Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Thu, 20 Feb 2025 23:11:05 +0800 Subject: [PATCH 03/35] refine code --- README.md | 4 ++-- src/lib.rs | 16 +++++++++------- src/packet.rs | 2 +- src/stream/tcp_wrapper.rs | 3 +++ 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 9517ed6..5ed93f8 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ An asynchronous lightweight userspace implementation of TCP/IP stack for Tun dev Unstable, under development. [![Crates.io](https://img.shields.io/crates/v/ipstack.svg)](https://crates.io/crates/ipstack) -![ipstack](https://docs.rs/ipstack/badge.svg) +[![ipstack](https://docs.rs/ipstack/badge.svg)](https://docs.rs/ipstack) [![Documentation](https://img.shields.io/badge/docs-release-brightgreen.svg?style=flat)](https://docs.rs/ipstack) [![Download](https://img.shields.io/crates/d/ipstack.svg)](https://crates.io/crates/ipstack) [![License](https://img.shields.io/crates/l/ipstack.svg?style=flat)](https://github.com/narrowlink/ipstack/blob/main/LICENSE) @@ -86,4 +86,4 @@ async fn main() { } ``` -We also suggest that you take a look at the complete [examples](examples). +We also suggest that you take a look at the complete [examples](./examples). diff --git a/src/lib.rs b/src/lib.rs index ce3361a..477ddfe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -179,18 +179,20 @@ fn process_device_read( Occupied(mut entry) => { if let Err(e) = entry.get().send(packet) { log::debug!("New stream \"{}\" because: \"{}\"", e.0.network_tuple(), e); - create_stream(e.0, config, pkt_sender).map(|s| { - entry.insert(s.0); - s.1 + create_stream(e.0, config, pkt_sender).map(|(packet_sender, ip_stack_stream)| { + entry.insert(packet_sender); + ip_stack_stream }) } else { None } } - Vacant(entry) => create_stream(packet, config, pkt_sender).map(|s| { - entry.insert(s.0); - s.1 - }), + Vacant(entry) => { + create_stream(packet, config, pkt_sender).map(|(packet_sender, ip_stack_stream)| { + entry.insert(packet_sender); + ip_stack_stream + }) + } } } diff --git a/src/packet.rs b/src/packet.rs index 312dd4e..02d77a7 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -12,7 +12,7 @@ pub struct NetworkTuple { impl std::fmt::Display for NetworkTuple { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let tcp = if self.tcp { "TCP" } else { "UDP" }; - write!(f, "{}: {} <> {}", tcp, self.src, self.dst) + write!(f, "{} {} -> {}", tcp, self.src, self.dst) } } diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs index e6653b9..7e8817b 100644 --- a/src/stream/tcp_wrapper.rs +++ b/src/stream/tcp_wrapper.rs @@ -105,10 +105,13 @@ impl tokio::io::AsyncWrite for IpStackTcpStream { impl Drop for IpStackTcpStream { fn drop(&mut self) { if let Some(mut inner) = self.inner.take() { + let local_addr = self.local_addr(); + let peer_addr = self.peer_addr(); tokio::spawn(async move { if let Err(err) = timeout(Duration::from_secs(2), inner.shutdown()).await { log::warn!("Error while dropping IpStackTcpStream: {:?}", err); } + log::trace!("TCP Stream closed: {} -> {}", local_addr, peer_addr); }); } } From ba10b6f8b6d1bc4a2f35d8139d35be39ce6ef6fa Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Fri, 21 Feb 2025 20:53:50 +0800 Subject: [PATCH 04/35] rustfmt max_width = 140 --- examples/tun_wintun.rs | 23 ++------ rustfmt.toml | 1 + src/lib.rs | 48 +++++----------- src/packet.rs | 28 ++-------- src/stream/mod.rs | 8 +-- src/stream/tcb.rs | 15 ++--- src/stream/tcp.rs | 115 ++++++++++---------------------------- src/stream/tcp_wrapper.rs | 47 +++++----------- src/stream/udp.rs | 42 +++----------- src/stream/unknown.rs | 21 ++----- 10 files changed, 89 insertions(+), 259 deletions(-) create mode 100644 rustfmt.toml diff --git a/examples/tun_wintun.rs b/examples/tun_wintun.rs index e55c8d2..ceaa96f 100644 --- a/examples/tun_wintun.rs +++ b/examples/tun_wintun.rs @@ -1,7 +1,7 @@ use std::net::{Ipv4Addr, SocketAddr}; use clap::Parser; -use etherparse::{IcmpEchoHeader, Icmpv4Header}; +use etherparse::Icmpv4Header; use ipstack::{stream::IpStackStream, IpNumber}; use tokio::net::TcpStream; use udp_stream::UdpStream; @@ -46,10 +46,7 @@ async fn main() -> Result<(), Box> { let mut ip_stack = ipstack::IpStack::new(ipstack_config, tun::create_as_async(&config)?); #[cfg(target_os = "windows")] - let mut ip_stack = ipstack::IpStack::new( - ipstack_config, - wintun::WinTunDevice::new(ipv4, Ipv4Addr::new(255, 255, 255, 0)), - ); + let mut ip_stack = ipstack::IpStack::new(ipstack_config, wintun::WinTunDevice::new(ipv4, Ipv4Addr::new(255, 255, 255, 0))); let server_addr = args.server_addr; @@ -86,12 +83,8 @@ async fn main() -> Result<(), Box> { IpStackStream::UnknownTransport(u) => { if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; - if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { + if let etherparse::Icmpv4Type::EchoRequest(echo) = icmp_header.icmp_type { println!("ICMPv4 echo"); - let echo = IcmpEchoHeader { - id: req.id, - seq: req.seq, - }; let mut resp = Icmpv4Header::new(etherparse::Icmpv4Type::EchoReply(echo)); resp.update_checksum(req_payload); let mut payload = resp.to_bytes().to_vec(); @@ -178,17 +171,11 @@ mod wintun { std::task::Poll::Ready(Ok(buf.len())) } - fn poll_flush( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_shutdown(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } } diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..8449be0 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +max_width = 140 diff --git a/src/lib.rs b/src/lib.rs index 477ddfe..2108e68 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -106,10 +106,7 @@ impl IpStack { } pub async fn accept(&mut self) -> Result { - self.accept_receiver - .recv() - .await - .ok_or(IpStackError::AcceptError) + self.accept_receiver.recv().await.ok_or(IpStackError::AcceptError) } } @@ -163,16 +160,14 @@ fn process_device_read( }; if let IpStackPacketProtocol::Unknown = packet.transport_protocol() { - return Some(IpStackStream::UnknownTransport( - IpStackUnknownTransport::new( - packet.src_addr().ip(), - packet.dst_addr().ip(), - packet.payload, - &packet.ip, - config.mtu, - pkt_sender, - ), - )); + return Some(IpStackStream::UnknownTransport(IpStackUnknownTransport::new( + packet.src_addr().ip(), + packet.dst_addr().ip(), + packet.payload, + &packet.ip, + config.mtu, + pkt_sender, + ))); } match sessions.entry(packet.network_tuple()) { @@ -187,30 +182,17 @@ fn process_device_read( None } } - Vacant(entry) => { - create_stream(packet, config, pkt_sender).map(|(packet_sender, ip_stack_stream)| { - entry.insert(packet_sender); - ip_stack_stream - }) - } + Vacant(entry) => create_stream(packet, config, pkt_sender).map(|(packet_sender, ip_stack_stream)| { + entry.insert(packet_sender); + ip_stack_stream + }), } } -fn create_stream( - packet: NetworkPacket, - config: &IpStackConfig, - pkt_sender: PacketSender, -) -> Option<(PacketSender, IpStackStream)> { +fn create_stream(packet: NetworkPacket, config: &IpStackConfig, pkt_sender: PacketSender) -> Option<(PacketSender, IpStackStream)> { match packet.transport_protocol() { IpStackPacketProtocol::Tcp(h) => { - match IpStackTcpStream::new( - packet.src_addr(), - packet.dst_addr(), - h, - pkt_sender, - config.mtu, - config.tcp_timeout, - ) { + match IpStackTcpStream::new(packet.src_addr(), packet.dst_addr(), h, pkt_sender, config.mtu, config.tcp_timeout) { Ok(stream) => Some((stream.stream_sender(), IpStackStream::Tcp(stream))), Err(e) => { if matches!(e, IpStackError::InvalidTcpPacket(_)) { diff --git a/src/packet.rs b/src/packet.rs index 02d77a7..6e6fa60 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -61,32 +61,18 @@ impl NetworkPacket { let ip = p.net.ok_or(IpStackError::InvalidPacket)?; let (ip, ip_payload) = match ip { - NetSlice::Ipv4(ip) => ( - IpHeader::Ipv4(ip.header().to_header()), - ip.payload().payload, - ), - NetSlice::Ipv6(ip) => ( - IpHeader::Ipv6(ip.header().to_header()), - ip.payload().payload, - ), + NetSlice::Ipv4(ip) => (IpHeader::Ipv4(ip.header().to_header()), ip.payload().payload), + NetSlice::Ipv6(ip) => (IpHeader::Ipv6(ip.header().to_header()), ip.payload().payload), NetSlice::Arp(_) => return Err(IpStackError::UnsupportedTransportProtocol), }; let (transport, payload) = match p.transport { - Some(etherparse::TransportSlice::Tcp(h)) => { - (TransportHeader::Tcp(h.to_header()), h.payload()) - } - Some(etherparse::TransportSlice::Udp(u)) => { - (TransportHeader::Udp(u.to_header()), u.payload()) - } + Some(etherparse::TransportSlice::Tcp(h)) => (TransportHeader::Tcp(h.to_header()), h.payload()), + Some(etherparse::TransportSlice::Udp(u)) => (TransportHeader::Udp(u.to_header()), u.payload()), _ => (TransportHeader::Unknown, ip_payload), }; let payload = payload.to_vec(); - Ok(NetworkPacket { - ip, - transport, - payload, - }) + Ok(NetworkPacket { ip, transport, payload }) } pub(crate) fn transport_protocol(&self) -> IpStackPacketProtocol { match self.transport { @@ -235,9 +221,7 @@ impl TcpHeaderWrapper { impl From<&TcpHeader> for TcpHeaderWrapper { fn from(header: &TcpHeader) -> Self { - TcpHeaderWrapper { - header: header.clone(), - } + TcpHeaderWrapper { header: header.clone() } } } diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 42632f4..944e98a 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -22,9 +22,7 @@ impl IpStackStream { match self { IpStackStream::Tcp(tcp) => tcp.local_addr(), IpStackStream::Udp(udp) => udp.local_addr(), - IpStackStream::UnknownNetwork(_) => { - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)) - } + IpStackStream::UnknownNetwork(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), IpStackStream::UnknownTransport(unknown) => match unknown.src_addr() { IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), @@ -35,9 +33,7 @@ impl IpStackStream { match self { IpStackStream::Tcp(tcp) => tcp.peer_addr(), IpStackStream::Udp(udp) => udp.peer_addr(), - IpStackStream::UnknownNetwork(_) => { - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)) - } + IpStackStream::UnknownNetwork(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), IpStackStream::UnknownTransport(unknown) => match unknown.dst_addr() { IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index ec0a671..9aad3fe 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -71,15 +71,10 @@ impl Tcb { if seq < self.ack { return; } - self.unordered_packets - .insert(seq, UnorderedPacket::new(buf)); + self.unordered_packets.insert(seq, UnorderedPacket::new(buf)); } pub(super) fn get_available_read_buffer_size(&self) -> usize { - READ_BUFFER_SIZE.saturating_sub( - self.unordered_packets - .iter() - .fold(0, |acc, (_, p)| acc + p.payload.len()), - ) + READ_BUFFER_SIZE.saturating_sub(self.unordered_packets.iter().fold(0, |acc, (_, p)| acc + p.payload.len())) } pub(super) fn get_unordered_packets(&mut self) -> Option> { // dbg!(self.ack); @@ -110,8 +105,7 @@ impl Tcb { self.state } pub(super) fn change_send_window(&mut self, window: u16) { - let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64) - / (self.avg_send_window.1 + 1); + let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64) / (self.avg_send_window.1 + 1); self.avg_send_window.0 = avg_send_window; self.avg_send_window.1 += 1; self.send_window = window; @@ -147,8 +141,7 @@ impl Tcb { let current_ack_distance = self.seq.wrapping_sub(self.last_ack); if received_ack_distance > current_ack_distance - || (tcp_header.acknowledgment_number != self.seq - && self.seq.saturating_sub(tcp_header.acknowledgment_number) == 0) + || (tcp_header.acknowledgment_number != self.seq && self.seq.saturating_sub(tcp_header.acknowledgment_number) == 0) { PacketStatus::Invalid } else if self.last_ack == tcp_header.acknowledgment_number { diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index cd0a0de..bc5b7e8 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -93,13 +93,7 @@ impl IpStackTcpStream { ) } - fn create_rev_packet( - &self, - flags: u8, - ttl: u8, - seq: impl Into>, - mut payload: Vec, - ) -> Result { + fn create_rev_packet(&self, flags: u8, ttl: u8, seq: impl Into>, mut payload: Vec) -> Result { let mut tcp_header = etherparse::TcpHeader::new( self.dst_addr.port(), self.src_addr.port(), @@ -116,12 +110,8 @@ impl IpStackTcpStream { let ip_header = match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::TCP, dst.octets(), src.octets()) - .map_err(IpStackError::from)?; - let payload_len = self.calculate_payload_len( - ip_h.header_len() as u16, - tcp_header.header_len() as u16, - ); + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::TCP, dst.octets(), src.octets()).map_err(IpStackError::from)?; + let payload_len = self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len() as u16); payload.truncate(payload_len as usize); ip_h.set_payload_len(payload.len() + tcp_header.header_len()) .map_err(IpStackError::from)?; @@ -138,10 +128,7 @@ impl IpStackTcpStream { source: dst.octets(), destination: src.octets(), }; - let payload_len = self.calculate_payload_len( - ip_h.header_len() as u16, - tcp_header.header_len() as u16, - ); + let payload_len = self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len() as u16); payload.truncate(payload_len as usize); let len = payload.len() + tcp_header.header_len(); ip_h.set_payload_length(len).map_err(IpStackError::from)?; @@ -172,11 +159,7 @@ impl IpStackTcpStream { } impl AsyncRead for IpStackTcpStream { - fn poll_read( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { + fn poll_read(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll> { loop { if self.tcb.retransmission.is_some() { self.write_notify = Some(cx.waker().clone()); @@ -186,9 +169,7 @@ impl AsyncRead for IpStackTcpStream { } if let Some(packet) = self.packet_to_send.take() { - self.packet_sender - .send(packet) - .or(Err(ErrorKind::UnexpectedEof))?; + self.packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; } if self.tcb.get_state() == TcpState::Closed { self.shutdown.ready(); @@ -196,8 +177,7 @@ impl AsyncRead for IpStackTcpStream { } if self.tcb.get_state() == TcpState::FinWait2(false) { - self.packet_to_send = - Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::ConnectionAborted))); @@ -218,18 +198,13 @@ impl AsyncRead for IpStackTcpStream { self.tcb.reset_timeout(); if self.tcb.get_state() == TcpState::SynReceived(false) { - self.packet_to_send = - Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::SynReceived(true)); continue; } - if let Some(b) = self - .tcb - .get_unordered_packets() - .filter(|_| matches!(self.shutdown, Shutdown::None)) - { + if let Some(b) = self.tcb.get_unordered_packets().filter(|_| matches!(self.shutdown, Shutdown::None)) { self.tcb.add_ack(b.len() as u32); buf.put_slice(&b); self.packet_sender @@ -238,8 +213,7 @@ impl AsyncRead for IpStackTcpStream { return Poll::Ready(Ok(())); } if self.tcb.get_state() == TcpState::FinWait1(true) { - self.packet_to_send = - Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.add_ack(1); self.tcb.change_state(TcpState::FinWait2(true)); @@ -248,8 +222,7 @@ impl AsyncRead for IpStackTcpStream { && self.tcb.get_state() == TcpState::Established && self.tcb.get_last_ack() == self.tcb.get_seq() { - self.packet_to_send = - Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait1(false)); continue; @@ -260,8 +233,7 @@ impl AsyncRead for IpStackTcpStream { unreachable!() }; if t.flags() & RST != 0 { - self.packet_to_send = - Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset))); @@ -291,8 +263,7 @@ impl AsyncRead for IpStackTcpStream { PacketStatus::KeepAlive => { self.tcb.change_last_ack(t.inner().acknowledgment_number); self.tcb.change_send_window(t.inner().window_size); - self.packet_to_send = - Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); continue; } PacketStatus::RetransmissionRequest => { @@ -316,8 +287,7 @@ impl AsyncRead for IpStackTcpStream { // } self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb - .add_unordered_packet(t.inner().sequence_number, p.payload); + self.tcb.add_unordered_packet(t.inner().sequence_number, p.payload); self.tcb.change_send_window(t.inner().window_size); if let Some(ref n) = self.write_notify { @@ -339,30 +309,23 @@ impl AsyncRead for IpStackTcpStream { } if t.flags() == (FIN | ACK) { self.tcb.add_ack(1); - self.packet_to_send = - Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_state(TcpState::FinWait1(true)); continue; } if t.flags() == (PSH | ACK) { - if !matches!( - self.tcb.check_pkt_type(&t, &p.payload), - PacketStatus::NewPacket - ) { + if !matches!(self.tcb.check_pkt_type(&t, &p.payload), PacketStatus::NewPacket) { continue; } self.tcb.change_last_ack(t.inner().acknowledgment_number); - if p.payload.is_empty() - || self.tcb.get_ack() != t.inner().sequence_number - { + if p.payload.is_empty() || self.tcb.get_ack() != t.inner().sequence_number { continue; } self.tcb.change_send_window(t.inner().window_size); - self.tcb - .add_unordered_packet(t.inner().sequence_number, p.payload); + self.tcb.add_unordered_packet(t.inner().sequence_number, p.payload); continue; } } else if self.tcb.get_state() == TcpState::FinWait1(false) { @@ -373,8 +336,7 @@ impl AsyncRead for IpStackTcpStream { continue; } else if t.flags() == (FIN | ACK) { self.tcb.add_ack(1); - self.packet_to_send = - Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_send_window(t.inner().window_size); self.tcb.change_state(TcpState::FinWait2(true)); continue; @@ -383,8 +345,7 @@ impl AsyncRead for IpStackTcpStream { if t.flags() == ACK { self.tcb.change_state(TcpState::FinWait2(false)); } else if t.flags() == (FIN | ACK) { - self.packet_to_send = - Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_state(TcpState::FinWait2(false)); } } @@ -397,19 +358,13 @@ impl AsyncRead for IpStackTcpStream { } impl AsyncWrite for IpStackTcpStream { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { + fn poll_write(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } self.tcb.reset_timeout(); - if (self.tcb.get_send_window() as u64) < self.tcb.get_avg_send_window() / 2 - || self.tcb.is_send_buffer_full() - { + if (self.tcb.get_send_window() as u64) < self.tcb.get_avg_send_window() / 2 || self.tcb.is_send_buffer_full() { self.write_notify = Some(cx.waker().clone()); return Poll::Pending; } @@ -425,30 +380,22 @@ impl AsyncWrite for IpStackTcpStream { let seq = self.tcb.get_seq(); let payload_len = packet.payload.len(); let payload = packet.payload.clone(); - self.packet_sender - .send(packet) - .or(Err(ErrorKind::UnexpectedEof))?; + self.packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; self.tcb.add_inflight_packet(seq, payload); Poll::Ready(Ok(payload_len)) } - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { + fn poll_flush(mut self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } if let Some(s) = self.tcb.retransmission.take() { if let Some(packet) = self.tcb.inflight_packets.iter().find(|p| p.seq == s) { - let rev_packet = - self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; + let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; - self.packet_sender - .send(rev_packet) - .or(Err(ErrorKind::UnexpectedEof))?; + self.packet_sender.send(rev_packet).or(Err(ErrorKind::UnexpectedEof))?; } else { error!("Packet {} not found in inflight_packets", s); error!("seq: {}", self.tcb.get_seq()); @@ -465,19 +412,13 @@ impl AsyncWrite for IpStackTcpStream { Poll::Ready(Ok(())) } - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_shutdown(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if matches!(self.shutdown, Shutdown::Ready) { return Poll::Ready(Ok(())); } else if matches!(self.shutdown, Shutdown::None) { self.shutdown.pending(cx.waker().clone()); } - self.poll_read( - cx, - &mut tokio::io::ReadBuf::uninit(&mut [MaybeUninit::::uninit()]), - ) + self.poll_read(cx, &mut tokio::io::ReadBuf::uninit(&mut [MaybeUninit::::uninit()])) } } diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs index 7e8817b..a3695f5 100644 --- a/src/stream/tcp_wrapper.rs +++ b/src/stream/tcp_wrapper.rs @@ -23,20 +23,13 @@ impl IpStackTcpStream { tcp_timeout: Duration, ) -> Result { let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - IpStackTcpStreamInner::new( - local_addr, - peer_addr, - tcp, - pkt_sender, - stream_receiver, - mtu, - tcp_timeout, - ) - .map(|inner| IpStackTcpStream { - inner: Some(Box::new(inner)), - peer_addr, - local_addr, - stream_sender, + IpStackTcpStreamInner::new(local_addr, peer_addr, tcp, pkt_sender, stream_receiver, mtu, tcp_timeout).map(|inner| { + IpStackTcpStream { + inner: Some(Box::new(inner)), + peer_addr, + local_addr, + stream_sender, + } }) } pub fn local_addr(&self) -> SocketAddr { @@ -58,9 +51,7 @@ impl tokio::io::AsyncRead for IpStackTcpStream { ) -> std::task::Poll> { match self.inner.as_mut() { Some(mut inner) => Pin::new(&mut inner).poll_read(cx, buf), - None => { - std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) - } + None => std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))), } } } @@ -73,31 +64,19 @@ impl tokio::io::AsyncWrite for IpStackTcpStream { ) -> std::task::Poll> { match self.inner.as_mut() { Some(mut inner) => Pin::new(&mut inner).poll_write(cx, buf), - None => { - std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) - } + None => std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))), } } - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_flush(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { match self.inner.as_mut() { Some(mut inner) => Pin::new(&mut inner).poll_flush(cx), - None => { - std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) - } + None => std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))), } } - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_shutdown(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { match self.inner.as_mut() { Some(mut inner) => Pin::new(&mut inner).poll_shutdown(cx), - None => { - std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) - } + None => std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))), } } } diff --git a/src/stream/udp.rs b/src/stream/udp.rs index ad8086c..4a8bce1 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -55,19 +55,12 @@ impl IpStackUdpStream { const UHS: usize = 8; // udp header size is 8 match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets()) - .map_err(IpStackError::from)?; + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets()).map_err(IpStackError::from)?; let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16); payload.truncate(line_buffer as usize); - ip_h.set_payload_len(payload.len() + UHS) + ip_h.set_payload_len(payload.len() + UHS).map_err(IpStackError::from)?; + let udp_header = UdpHeader::with_ipv4_checksum(self.dst_addr.port(), self.src_addr.port(), &ip_h, &payload) .map_err(IpStackError::from)?; - let udp_header = UdpHeader::with_ipv4_checksum( - self.dst_addr.port(), - self.src_addr.port(), - &ip_h, - &payload, - ) - .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv4(ip_h), transport: TransportHeader::Udp(udp_header), @@ -89,13 +82,8 @@ impl IpStackUdpStream { payload.truncate(line_buffer as usize); ip_h.payload_length = (payload.len() + UHS) as u16; - let udp_header = UdpHeader::with_ipv6_checksum( - self.dst_addr.port(), - self.src_addr.port(), - &ip_h, - &payload, - ) - .map_err(IpStackError::from)?; + let udp_header = UdpHeader::with_ipv6_checksum(self.dst_addr.port(), self.src_addr.port(), &ip_h, &payload) + .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv6(ip_h), transport: TransportHeader::Udp(udp_header), @@ -148,31 +136,19 @@ impl AsyncRead for IpStackUdpStream { } impl AsyncWrite for IpStackUdpStream { - fn poll_write( - mut self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { + fn poll_write(mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, buf: &[u8]) -> std::task::Poll> { self.reset_timeout(); let packet = self.create_rev_packet(TTL, buf.to_vec())?; let payload_len = packet.payload.len(); - self.pkt_sender - .send(packet) - .or(Err(std::io::ErrorKind::UnexpectedEof))?; + self.pkt_sender.send(packet).or(Err(std::io::ErrorKind::UnexpectedEof))?; std::task::Poll::Ready(Ok(payload_len)) } - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_flush(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } } diff --git a/src/stream/unknown.rs b/src/stream/unknown.rs index b9889aa..173dfce 100644 --- a/src/stream/unknown.rs +++ b/src/stream/unknown.rs @@ -15,14 +15,7 @@ pub struct IpStackUnknownTransport { } impl IpStackUnknownTransport { - pub(crate) fn new( - src_addr: IpAddr, - dst_addr: IpAddr, - payload: Vec, - ip: &IpHeader, - mtu: u16, - packet_sender: PacketSender, - ) -> Self { + pub(crate) fn new(src_addr: IpAddr, dst_addr: IpAddr, payload: Vec, ip: &IpHeader, mtu: u16, packet_sender: PacketSender) -> Self { let protocol = match ip { IpHeader::Ipv4(ip) => ip.protocol, IpHeader::Ipv6(ip) => ip.next_header, @@ -51,9 +44,9 @@ impl IpStackUnknownTransport { pub fn send(&self, mut payload: Vec) -> std::io::Result<()> { loop { let packet = self.create_rev_packet(&mut payload)?; - self.packet_sender.send(packet).map_err(|e| { - std::io::Error::new(std::io::ErrorKind::Other, format!("send error: {}", e)) - })?; + self.packet_sender + .send(packet) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("send error: {}", e)))?; if payload.is_empty() { return Ok(()); } @@ -63,8 +56,7 @@ impl IpStackUnknownTransport { pub fn create_rev_packet(&self, payload: &mut Vec) -> std::io::Result { match (self.dst_addr, self.src_addr) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets()) - .map_err(IpStackError::from)?; + let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets()).map_err(IpStackError::from)?; let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16); let p = if payload.len() > line_buffer as usize { @@ -95,8 +87,7 @@ impl IpStackUnknownTransport { } else { std::mem::take(payload) }; - ip_h.set_payload_length(p.len()) - .map_err(IpStackError::from)?; + ip_h.set_payload_length(p.len()).map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv6(ip_h), transport: TransportHeader::Unknown, From 6cd38501e48077ac068e3670cbe5a520889d1265 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Fri, 21 Feb 2025 23:46:15 +0800 Subject: [PATCH 05/35] refine code --- README.md | 10 ++--- examples/tun2.rs | 2 +- src/lib.rs | 104 +++++++++++++++++++++------------------------- src/stream/tcp.rs | 20 ++++----- src/stream/udp.rs | 8 ++-- 5 files changed, 65 insertions(+), 79 deletions(-) diff --git a/README.md b/README.md index 5ed93f8..084a950 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ async fn main() { #[cfg(target_os = "windows")] config.platform_config(|config| { - config.device_guid(Some(12324323423423434234_u128)); + config.device_guid(12324323423423434234_u128); }); let mut ipstack_config = ipstack::IpStackConfig::default(); @@ -51,7 +51,7 @@ async fn main() { }); } IpStackStream::Udp(mut udp) => { - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53); + let addr: SocketAddr = "1.1.1.1:53".parse().unwrap(); let mut rhs = UdpStream::connect(addr).await.unwrap(); tokio::spawn(async move { let _ = tokio::io::copy_bidirectional(&mut udp, &mut rhs).await; @@ -60,12 +60,8 @@ async fn main() { IpStackStream::UnknownTransport(u) => { if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload()).unwrap(); - if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { + if let etherparse::Icmpv4Type::EchoRequest(echo) = icmp_header.icmp_type { println!("ICMPv4 echo"); - let echo = IcmpEchoHeader { - id: req.id, - seq: req.seq, - }; let mut resp = Icmpv4Header::new(etherparse::Icmpv4Type::EchoReply(echo)); resp.update_checksum(req_payload); let mut payload = resp.to_bytes().to_vec(); diff --git a/examples/tun2.rs b/examples/tun2.rs index 535e02b..eb26417 100644 --- a/examples/tun2.rs +++ b/examples/tun2.rs @@ -170,7 +170,7 @@ async fn main() -> Result<(), Box> { continue; } IpStackStream::UnknownNetwork(pkt) => { - log::info!("#{number} unknown transport - {} bytes", pkt.len()); + log::info!("#{number} unknown network - {} bytes", pkt.len()); continue; } }; diff --git a/src/lib.rs b/src/lib.rs index 2108e68..7bf9e74 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,6 @@ use crate::{ stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport}, }; use ahash::AHashMap; -use log::{error, trace}; use packet::{NetworkPacket, NetworkTuple}; use std::{ collections::hash_map::Entry::{Occupied, Vacant}, @@ -119,22 +118,23 @@ fn run( let pi = config.packet_information; let offset = if pi && cfg!(unix) { 4 } else { 0 }; let mut buffer = [0_u8; u16::MAX as usize + 4]; - let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); + let (up_pkt_sender, mut up_pkt_receiver) = mpsc::unbounded_channel::(); tokio::spawn(async move { loop { select! { Ok(n) = device.read(&mut buffer) => { - if let Some(stream) = process_device_read( + if let Err(e) = process_device_read( &buffer[offset..n], &mut sessions, - pkt_sender.clone(), + up_pkt_sender.clone(), &config, - ) { - accept_sender.send(stream)?; + &accept_sender, + ) { + log::debug!("process_device_read error: {}", e); } } - Some(packet) = pkt_receiver.recv() => { + Some(packet) = up_pkt_receiver.recv() => { process_upstream_recv( packet, &mut sessions, @@ -152,68 +152,61 @@ fn run( fn process_device_read( data: &[u8], sessions: &mut SessionCollection, - pkt_sender: PacketSender, + up_pkt_sender: PacketSender, config: &IpStackConfig, -) -> Option { + accept_sender: &UnboundedSender, +) -> Result<()> { let Ok(packet) = NetworkPacket::parse(data) else { - return Some(IpStackStream::UnknownNetwork(data.to_owned())); + let stream = IpStackStream::UnknownNetwork(data.to_owned()); + accept_sender.send(stream)?; + return Ok(()); }; if let IpStackPacketProtocol::Unknown = packet.transport_protocol() { - return Some(IpStackStream::UnknownTransport(IpStackUnknownTransport::new( + let stream = IpStackStream::UnknownTransport(IpStackUnknownTransport::new( packet.src_addr().ip(), packet.dst_addr().ip(), packet.payload, &packet.ip, config.mtu, - pkt_sender, - ))); + up_pkt_sender, + )); + accept_sender.send(stream)?; + return Ok(()); } - match sessions.entry(packet.network_tuple()) { + let network_tuple = packet.network_tuple(); + match sessions.entry(network_tuple) { Occupied(mut entry) => { if let Err(e) = entry.get().send(packet) { - log::debug!("New stream \"{}\" because: \"{}\"", e.0.network_tuple(), e); - create_stream(e.0, config, pkt_sender).map(|(packet_sender, ip_stack_stream)| { - entry.insert(packet_sender); - ip_stack_stream - }) + log::debug!("New stream \"{}\" because: \"{}\"", network_tuple, e); + let (packet_sender, ip_stack_stream) = create_stream(e.0, config, up_pkt_sender)?; + entry.insert(packet_sender); + accept_sender.send(ip_stack_stream)?; } else { - None + log::trace!("packet sent to stream: {}", network_tuple); } } - Vacant(entry) => create_stream(packet, config, pkt_sender).map(|(packet_sender, ip_stack_stream)| { + Vacant(entry) => { + let (packet_sender, ip_stack_stream) = create_stream(packet, config, up_pkt_sender)?; entry.insert(packet_sender); - ip_stack_stream - }), + accept_sender.send(ip_stack_stream)?; + } } + Ok(()) } -fn create_stream(packet: NetworkPacket, config: &IpStackConfig, pkt_sender: PacketSender) -> Option<(PacketSender, IpStackStream)> { +fn create_stream(packet: NetworkPacket, cfg: &IpStackConfig, up_pkt_sender: PacketSender) -> Result<(PacketSender, IpStackStream)> { + let src_addr = packet.src_addr(); + let dst_addr = packet.dst_addr(); match packet.transport_protocol() { IpStackPacketProtocol::Tcp(h) => { - match IpStackTcpStream::new(packet.src_addr(), packet.dst_addr(), h, pkt_sender, config.mtu, config.tcp_timeout) { - Ok(stream) => Some((stream.stream_sender(), IpStackStream::Tcp(stream))), - Err(e) => { - if matches!(e, IpStackError::InvalidTcpPacket(_)) { - log::debug!("{e}"); - } else { - error!("IpStackTcpStream::new failed \"{}\"", e); - } - None - } - } + let stream = IpStackTcpStream::new(src_addr, dst_addr, h, up_pkt_sender, cfg.mtu, cfg.tcp_timeout)?; + Ok((stream.stream_sender(), IpStackStream::Tcp(stream))) } IpStackPacketProtocol::Udp => { - let stream = IpStackUdpStream::new( - packet.src_addr(), - packet.dst_addr(), - packet.payload, - pkt_sender, - config.mtu, - config.udp_timeout, - ); - Some((stream.stream_sender(), IpStackStream::Udp(stream))) + let stream = IpStackUdpStream::new(src_addr, dst_addr, packet.payload, up_pkt_sender, cfg.mtu, cfg.udp_timeout); + Ok((stream.stream_sender(), IpStackStream::Udp(stream))) } IpStackPacketProtocol::Unknown => { unreachable!() @@ -221,36 +214,33 @@ fn create_stream(packet: NetworkPacket, config: &IpStackConfig, pkt_sender: Pack } } -async fn process_upstream_recv( - packet: NetworkPacket, +async fn process_upstream_recv( + up_packet: NetworkPacket, sessions: &mut SessionCollection, - device: &mut D, + device: &mut Device, #[cfg(unix)] packet_information: bool, -) -> Result<()> -where - D: AsyncWrite + Unpin + 'static, -{ - if packet.ttl() == 0 { - let network_tuple = packet.reverse_network_tuple(); +) -> Result<()> { + if up_packet.ttl() == DROP_TTL { + let network_tuple = up_packet.reverse_network_tuple(); sessions.remove(&network_tuple); log::trace!("session removed: {}", network_tuple); return Ok(()); } #[allow(unused_mut)] - let Ok(mut packet_bytes) = packet.to_bytes() else { - trace!("to_bytes error"); + let Ok(mut packet_bytes) = up_packet.to_bytes() else { + log::trace!("to_bytes error"); return Ok(()); }; #[cfg(unix)] if packet_information { - if packet.src_addr().is_ipv4() { + if up_packet.src_addr().is_ipv4() { packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat()); } else { packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat()); } } device.write_all(&packet_bytes).await?; - // device.flush().await.unwrap(); + // device.flush().await?; Ok(()) } diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index bc5b7e8..0c69bc2 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -45,7 +45,7 @@ pub(crate) struct IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, stream_receiver: PacketReceiver, - packet_sender: PacketSender, + up_packet_sender: PacketSender, packet_to_send: Option, tcb: Tcb, mtu: u16, @@ -58,7 +58,7 @@ impl IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, tcp: TcpHeaderWrapper, - packet_sender: PacketSender, + up_packet_sender: PacketSender, stream_receiver: PacketReceiver, mtu: u16, tcp_timeout: Duration, @@ -67,7 +67,7 @@ impl IpStackTcpStream { src_addr, dst_addr, stream_receiver, - packet_sender, + up_packet_sender, packet_to_send: None, tcb: Tcb::new(tcp.inner().sequence_number + 1, tcp_timeout), mtu, @@ -79,7 +79,7 @@ impl IpStackTcpStream { } if !tcp.inner().rst { let pkt = stream.create_rev_packet(RST | ACK, TTL, None, Vec::new())?; - if let Err(err) = stream.packet_sender.send(pkt) { + if let Err(err) = stream.up_packet_sender.send(pkt) { warn!("Error sending RST/ACK packet: {:?}", err); } } @@ -169,7 +169,7 @@ impl AsyncRead for IpStackTcpStream { } if let Some(packet) = self.packet_to_send.take() { - self.packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; } if self.tcb.get_state() == TcpState::Closed { self.shutdown.ready(); @@ -188,7 +188,7 @@ impl AsyncRead for IpStackTcpStream { if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) { trace!("timeout reached for {:?}", self.dst_addr); - self.packet_sender + self.up_packet_sender .send(self.create_rev_packet(RST | ACK, TTL, None, Vec::new())?) .or(Err(ErrorKind::UnexpectedEof))?; self.tcb.change_state(TcpState::Closed); @@ -207,7 +207,7 @@ impl AsyncRead for IpStackTcpStream { if let Some(b) = self.tcb.get_unordered_packets().filter(|_| matches!(self.shutdown, Shutdown::None)) { self.tcb.add_ack(b.len() as u32); buf.put_slice(&b); - self.packet_sender + self.up_packet_sender .send(self.create_rev_packet(ACK, TTL, None, Vec::new())?) .or(Err(ErrorKind::UnexpectedEof))?; return Poll::Ready(Ok(())); @@ -380,7 +380,7 @@ impl AsyncWrite for IpStackTcpStream { let seq = self.tcb.get_seq(); let payload_len = packet.payload.len(); let payload = packet.payload.clone(); - self.packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; self.tcb.add_inflight_packet(seq, payload); Poll::Ready(Ok(payload_len)) @@ -395,7 +395,7 @@ impl AsyncWrite for IpStackTcpStream { if let Some(packet) = self.tcb.inflight_packets.iter().find(|p| p.seq == s) { let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; - self.packet_sender.send(rev_packet).or(Err(ErrorKind::UnexpectedEof))?; + self.up_packet_sender.send(rev_packet).or(Err(ErrorKind::UnexpectedEof))?; } else { error!("Packet {} not found in inflight_packets", s); error!("seq: {}", self.tcb.get_seq()); @@ -425,7 +425,7 @@ impl AsyncWrite for IpStackTcpStream { impl Drop for IpStackTcpStream { fn drop(&mut self) { if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) { - if let Err(err) = self.packet_sender.send(p) { + if let Err(err) = self.up_packet_sender.send(p) { trace!("Error sending NON packet: {:?}", err); } } diff --git a/src/stream/udp.rs b/src/stream/udp.rs index 4a8bce1..e665817 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -16,7 +16,7 @@ pub struct IpStackUdpStream { dst_addr: SocketAddr, stream_sender: PacketSender, stream_receiver: PacketReceiver, - pkt_sender: PacketSender, + up_pkt_sender: PacketSender, first_payload: Option>, timeout: Pin>, udp_timeout: Duration, @@ -28,7 +28,7 @@ impl IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, payload: Vec, - pkt_sender: PacketSender, + up_pkt_sender: PacketSender, mtu: u16, udp_timeout: Duration, ) -> Self { @@ -39,7 +39,7 @@ impl IpStackUdpStream { dst_addr, stream_sender, stream_receiver, - pkt_sender, + up_pkt_sender, first_payload: Some(payload), timeout: Box::pin(tokio::time::sleep_until(deadline)), udp_timeout, @@ -140,7 +140,7 @@ impl AsyncWrite for IpStackUdpStream { self.reset_timeout(); let packet = self.create_rev_packet(TTL, buf.to_vec())?; let payload_len = packet.payload.len(); - self.pkt_sender.send(packet).or(Err(std::io::ErrorKind::UnexpectedEof))?; + self.up_pkt_sender.send(packet).or(Err(std::io::ErrorKind::UnexpectedEof))?; std::task::Poll::Ready(Ok(payload_len)) } From 32486e4d3af50e90655a35e810100000f74a556b Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Mon, 24 Feb 2025 00:23:04 +0800 Subject: [PATCH 06/35] remove useless IpStackPacketProtocol --- src/lib.rs | 18 ++++++++---------- src/packet.rs | 15 ++------------- src/stream/tcp.rs | 5 +++-- 3 files changed, 13 insertions(+), 25 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 7bf9e74..091f571 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,8 @@ #![doc = include_str!("../README.md")] -use crate::{ - packet::IpStackPacketProtocol, - stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport}, -}; +use crate::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport}; use ahash::AHashMap; -use packet::{NetworkPacket, NetworkTuple}; +use packet::{NetworkPacket, NetworkTuple, TransportHeader}; use std::{ collections::hash_map::Entry::{Occupied, Vacant}, time::Duration, @@ -162,7 +159,7 @@ fn process_device_read( return Ok(()); }; - if let IpStackPacketProtocol::Unknown = packet.transport_protocol() { + if let TransportHeader::Unknown = packet.transport_header() { let stream = IpStackStream::UnknownTransport(IpStackUnknownTransport::new( packet.src_addr().ip(), packet.dst_addr().ip(), @@ -199,16 +196,17 @@ fn process_device_read( fn create_stream(packet: NetworkPacket, cfg: &IpStackConfig, up_pkt_sender: PacketSender) -> Result<(PacketSender, IpStackStream)> { let src_addr = packet.src_addr(); let dst_addr = packet.dst_addr(); - match packet.transport_protocol() { - IpStackPacketProtocol::Tcp(h) => { + match packet.transport_header() { + TransportHeader::Tcp(h) => { + let h: TcpHeaderWrapper = h.into(); let stream = IpStackTcpStream::new(src_addr, dst_addr, h, up_pkt_sender, cfg.mtu, cfg.tcp_timeout)?; Ok((stream.stream_sender(), IpStackStream::Tcp(stream))) } - IpStackPacketProtocol::Udp => { + TransportHeader::Udp(_) => { let stream = IpStackUdpStream::new(src_addr, dst_addr, packet.payload, up_pkt_sender, cfg.mtu, cfg.udp_timeout); Ok((stream.stream_sender(), IpStackStream::Udp(stream))) } - IpStackPacketProtocol::Unknown => { + TransportHeader::Unknown => { unreachable!() } } diff --git a/src/packet.rs b/src/packet.rs index 6e6fa60..380475b 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -28,13 +28,6 @@ pub mod tcp_flags { pub const NON: u8 = 0b00000000; } -#[derive(Debug, Clone)] -pub(crate) enum IpStackPacketProtocol { - Tcp(TcpHeaderWrapper), - Unknown, - Udp, -} - #[derive(Debug, Clone)] pub(crate) enum IpHeader { Ipv4(Ipv4Header), @@ -74,12 +67,8 @@ impl NetworkPacket { Ok(NetworkPacket { ip, transport, payload }) } - pub(crate) fn transport_protocol(&self) -> IpStackPacketProtocol { - match self.transport { - TransportHeader::Udp(_) => IpStackPacketProtocol::Udp, - TransportHeader::Tcp(ref h) => IpStackPacketProtocol::Tcp(h.into()), - _ => IpStackPacketProtocol::Unknown, - } + pub(crate) fn transport_header(&self) -> &TransportHeader { + &self.transport } pub fn src_addr(&self) -> SocketAddr { let port = match &self.transport { diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 0c69bc2..a8fceb9 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -2,7 +2,7 @@ use crate::{ error::IpStackError, packet::{ tcp_flags::{ACK, FIN, NON, PSH, RST, SYN}, - IpHeader, IpStackPacketProtocol, NetworkPacket, TcpHeaderWrapper, TransportHeader, + IpHeader, NetworkPacket, TcpHeaderWrapper, TransportHeader, }, stream::tcb::{PacketStatus, Tcb, TcpState}, PacketReceiver, PacketSender, DROP_TTL, TTL, @@ -229,9 +229,10 @@ impl AsyncRead for IpStackTcpStream { } match self.stream_receiver.poll_recv(cx) { Poll::Ready(Some(p)) => { - let IpStackPacketProtocol::Tcp(t) = p.transport_protocol() else { + let TransportHeader::Tcp(tcp_header) = p.transport_header() else { unreachable!() }; + let t: TcpHeaderWrapper = tcp_header.into(); if t.flags() & RST != 0 { self.packet_to_send = Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); From 240959ff0ebafd91ad2120ccbac8e225aef22c66 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Mon, 24 Feb 2025 19:18:42 +0800 Subject: [PATCH 07/35] Make the terminated UDP Session removed smoothly --- src/lib.rs | 45 +++++++++++++++++++++++++-------------------- src/stream/udp.rs | 17 +++++++++++++++-- 2 files changed, 40 insertions(+), 22 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 091f571..c3d51e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,7 +16,7 @@ use tokio::{ pub(crate) type PacketSender = UnboundedSender; pub(crate) type PacketReceiver = UnboundedReceiver; -pub(crate) type SessionCollection = AHashMap; +pub(crate) type SessionCollection = std::sync::Arc>>; mod error; mod packet; @@ -111,7 +111,7 @@ fn run( mut device: Device, accept_sender: UnboundedSender, ) -> JoinHandle> { - let mut sessions: SessionCollection = AHashMap::new(); + let sessions: SessionCollection = std::sync::Arc::new(tokio::sync::Mutex::new(AHashMap::new())); let pi = config.packet_information; let offset = if pi && cfg!(unix) { 4 } else { 0 }; let mut buffer = [0_u8; u16::MAX as usize + 4]; @@ -123,18 +123,18 @@ fn run( Ok(n) = device.read(&mut buffer) => { if let Err(e) = process_device_read( &buffer[offset..n], - &mut sessions, + sessions.clone(), up_pkt_sender.clone(), &config, &accept_sender, - ) { + ).await { log::debug!("process_device_read error: {}", e); } } Some(packet) = up_pkt_receiver.recv() => { process_upstream_recv( packet, - &mut sessions, + sessions.clone(), &mut device, #[cfg(unix)] pi, @@ -146,9 +146,9 @@ fn run( }) } -fn process_device_read( +async fn process_device_read( data: &[u8], - sessions: &mut SessionCollection, + sessions: SessionCollection, up_pkt_sender: PacketSender, config: &IpStackConfig, accept_sender: &UnboundedSender, @@ -172,20 +172,25 @@ fn process_device_read( return Ok(()); } + let sessions_clone = sessions.clone(); let network_tuple = packet.network_tuple(); - match sessions.entry(network_tuple) { - Occupied(mut entry) => { - if let Err(e) = entry.get().send(packet) { - log::debug!("New stream \"{}\" because: \"{}\"", network_tuple, e); - let (packet_sender, ip_stack_stream) = create_stream(e.0, config, up_pkt_sender)?; - entry.insert(packet_sender); - accept_sender.send(ip_stack_stream)?; - } else { - log::trace!("packet sent to stream: {}", network_tuple); - } + match sessions.lock().await.entry(network_tuple) { + Occupied(entry) => { + use std::io::{Error, ErrorKind::Other}; + entry.get().send(packet).map_err(|e| Error::new(Other, e))?; + log::trace!("packet sent to stream: {}", network_tuple); } Vacant(entry) => { - let (packet_sender, ip_stack_stream) = create_stream(packet, config, up_pkt_sender)?; + let (packet_sender, mut ip_stack_stream) = create_stream(packet, config, up_pkt_sender)?; + if let IpStackStream::Udp(ref mut stream) = ip_stack_stream { + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + stream.set_destroy_messenger(tx); + tokio::spawn(async move { + rx.await.ok(); + sessions_clone.lock().await.remove(&network_tuple); + log::trace!("session removed: {}", network_tuple); + }); + } entry.insert(packet_sender); accept_sender.send(ip_stack_stream)?; } @@ -214,13 +219,13 @@ fn create_stream(packet: NetworkPacket, cfg: &IpStackConfig, up_pkt_sender: Pack async fn process_upstream_recv( up_packet: NetworkPacket, - sessions: &mut SessionCollection, + sessions: SessionCollection, device: &mut Device, #[cfg(unix)] packet_information: bool, ) -> Result<()> { if up_packet.ttl() == DROP_TTL { let network_tuple = up_packet.reverse_network_tuple(); - sessions.remove(&network_tuple); + sessions.lock().await.remove(&network_tuple); log::trace!("session removed: {}", network_tuple); return Ok(()); } diff --git a/src/stream/udp.rs b/src/stream/udp.rs index e665817..c2cedb1 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -6,11 +6,10 @@ use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header, UdpHeader}; use std::{future::Future, net::SocketAddr, pin::Pin, time::Duration}; use tokio::{ io::{AsyncRead, AsyncWrite}, - sync::mpsc, + sync::{mpsc, oneshot}, time::Sleep, }; -#[derive(Debug)] pub struct IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, @@ -21,6 +20,7 @@ pub struct IpStackUdpStream { timeout: Pin>, udp_timeout: Duration, mtu: u16, + destroy_messenger: Option>, } impl IpStackUdpStream { @@ -44,9 +44,14 @@ impl IpStackUdpStream { timeout: Box::pin(tokio::time::sleep_until(deadline)), udp_timeout, mtu, + destroy_messenger: None, } } + pub(crate) fn set_destroy_messenger(&mut self, messenger: oneshot::Sender<()>) { + self.destroy_messenger = Some(messenger); + } + pub(crate) fn stream_sender(&self) -> PacketSender { self.stream_sender.clone() } @@ -152,3 +157,11 @@ impl AsyncWrite for IpStackUdpStream { std::task::Poll::Ready(Ok(())) } } + +impl Drop for IpStackUdpStream { + fn drop(&mut self) { + if let Some(messenger) = self.destroy_messenger.take() { + let _ = messenger.send(()); + } + } +} From 4995594720bec536fa2fb4317acca5c61039af3d Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Tue, 25 Feb 2025 09:40:43 +0800 Subject: [PATCH 08/35] rename tcp_timeout to timeout_interval --- src/stream/tcb.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 9aad3fe..ca5656e 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -31,7 +31,7 @@ pub(super) struct Tcb { ack: u32, last_ack: u32, pub(super) timeout: Pin>, - tcp_timeout: Duration, + timeout_interval: Duration, recv_window: u16, send_window: u16, state: TcpState, @@ -41,18 +41,18 @@ pub(super) struct Tcb { } impl Tcb { - pub(super) fn new(ack: u32, tcp_timeout: Duration) -> Tcb { + pub(super) fn new(ack: u32, timeout_interval: Duration) -> Tcb { #[cfg(debug_assertions)] let seq = 100; #[cfg(not(debug_assertions))] let seq = rand::random::(); - let deadline = tokio::time::Instant::now() + tcp_timeout; + let deadline = tokio::time::Instant::now() + timeout_interval; Tcb { seq, retransmission: None, ack, last_ack: seq, - tcp_timeout, + timeout_interval, timeout: Box::pin(tokio::time::sleep_until(deadline)), send_window: u16::MAX, recv_window: 0, @@ -189,7 +189,7 @@ impl Tcb { } pub(crate) fn reset_timeout(&mut self) { - let deadline = tokio::time::Instant::now() + self.tcp_timeout; + let deadline = tokio::time::Instant::now() + self.timeout_interval; self.timeout.as_mut().reset(deadline); } } From 3f448e0966fe336f21d63c4403a4bf72ec8acf29 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Tue, 25 Feb 2025 14:33:18 +0800 Subject: [PATCH 09/35] use SeqNum data type --- src/stream/mod.rs | 1 + src/stream/seqnum.rs | 174 +++++++++++++++++++++++++++++++++++++++++++ src/stream/tcb.rs | 87 +++++++++++++--------- src/stream/tcp.rs | 38 +++++----- 4 files changed, 245 insertions(+), 55 deletions(-) create mode 100644 src/stream/seqnum.rs diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 944e98a..931a3ad 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -4,6 +4,7 @@ pub use self::tcp_wrapper::IpStackTcpStream; pub use self::udp::IpStackUdpStream; pub use self::unknown::IpStackUnknownTransport; +mod seqnum; mod tcb; mod tcp; mod tcp_wrapper; diff --git a/src/stream/seqnum.rs b/src/stream/seqnum.rs new file mode 100644 index 0000000..131c6a7 --- /dev/null +++ b/src/stream/seqnum.rs @@ -0,0 +1,174 @@ +use std::ops::{Add, AddAssign, Sub, SubAssign}; + +const MAX_DIFF: u32 = u32::MAX / 2; + +/// A TCP sequence number that persents a 32-bit unsigned integer, suppport overflow comparison and arithmetic. +#[derive(Eq, PartialEq, Debug, Copy, Clone, Hash, Default)] +#[repr(transparent)] +pub struct SeqNum(pub u32); + +impl std::fmt::Display for SeqNum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for SeqNum { + fn from(value: u32) -> Self { + Self(value) + } +} + +impl From for u32 { + fn from(value: SeqNum) -> Self { + value.0 + } +} + +impl From for usize { + fn from(value: SeqNum) -> Self { + value.0 as usize + } +} + +impl TryFrom for SeqNum { + type Error = std::num::TryFromIntError; + fn try_from(value: usize) -> Result { + Ok(Self(value.try_into()?)) + } +} + +impl PartialEq for SeqNum { + fn eq(&self, other: &u32) -> bool { + self.0 == *other + } +} + +impl PartialOrd for SeqNum { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialOrd for SeqNum { + fn partial_cmp(&self, other: &u32) -> Option { + Some(self.cmp(&SeqNum(*other))) + } +} + +impl Ord for SeqNum { + #[inline] + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + let diff = self.0.wrapping_sub(other.0); + if diff == 0 { + std::cmp::Ordering::Equal + } else if diff < MAX_DIFF { + std::cmp::Ordering::Greater + } else { + std::cmp::Ordering::Less + } + } +} + +impl Add for SeqNum { + type Output = SeqNum; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + SeqNum(self.0.wrapping_add(rhs.0)) + } +} + +impl Add for SeqNum { + type Output = SeqNum; + #[inline] + fn add(self, rhs: u32) -> Self::Output { + SeqNum(self.0.wrapping_add(rhs)) + } +} + +impl AddAssign for SeqNum { + fn add_assign(&mut self, rhs: Self) { + self.0 = self.0.wrapping_add(rhs.0) + } +} + +impl AddAssign for SeqNum { + fn add_assign(&mut self, rhs: u32) { + self.0 = self.0.wrapping_add(rhs) + } +} + +impl Sub for SeqNum { + type Output = SeqNum; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + SeqNum(self.0.wrapping_sub(rhs.0)) + } +} + +impl Sub for SeqNum { + type Output = SeqNum; + #[inline] + fn sub(self, rhs: u32) -> Self::Output { + SeqNum(self.0.wrapping_sub(rhs)) + } +} + +impl SubAssign for SeqNum { + fn sub_assign(&mut self, rhs: Self) { + self.0 = self.0.wrapping_sub(rhs.0) + } +} + +impl SubAssign for SeqNum { + fn sub_assign(&mut self, rhs: u32) { + self.0 = self.0.wrapping_sub(rhs) + } +} + +impl SeqNum { + pub fn distance(&self, other: Self) -> u32 { + let diff = self.0.wrapping_sub(other.0); + if diff <= MAX_DIFF { + diff + } else { + u32::MAX - diff + 1 + } + } +} + +#[test] +fn test_seq_num_near_overflow() { + let a: SeqNum = (u32::MAX - 3).into(); + let b = a + 8; + + assert_eq!(a, SeqNum(4294967292)); + assert_eq!(b, SeqNum(4)); + + assert!(a < b); + assert!(b > a); + assert!(a <= b); + assert!(b >= a); + assert!(a != b); + + assert_eq!(a.distance(b), 8); + assert_eq!(b.distance(a), 8); +} + +#[test] +fn test_seq_num_near_max_diff() { + let a = SeqNum(MAX_DIFF - 1); + let mut b = SeqNum(MAX_DIFF + 1); + + assert!(a < b); + assert!(b > a); + assert_eq!(a.distance(b), 2); + + b += 3; + assert_eq!(b.distance(a), 5); + + b -= 10; + assert_eq!(b.distance(a), 5); + + assert_eq!(b, SeqNum(MAX_DIFF - 6)); +} diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index ca5656e..65cece5 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -1,3 +1,4 @@ +use super::seqnum::SeqNum; use crate::packet::TcpHeaderWrapper; use std::{collections::BTreeMap, pin::Pin, time::Duration}; use tokio::time::Sleep; @@ -26,10 +27,10 @@ pub(super) enum PacketStatus { #[derive(Debug)] pub(super) struct Tcb { - seq: u32, - pub(super) retransmission: Option, - ack: u32, - last_ack: u32, + seq: SeqNum, + pub(super) retransmission: Option, + ack: SeqNum, + last_ack: SeqNum, pub(super) timeout: Pin>, timeout_interval: Duration, recv_window: u16, @@ -37,21 +38,21 @@ pub(super) struct Tcb { state: TcpState, avg_send_window: (u64, u64), // (avg, count) pub(super) inflight_packets: Vec, - unordered_packets: BTreeMap, + unordered_packets: BTreeMap, } impl Tcb { - pub(super) fn new(ack: u32, timeout_interval: Duration) -> Tcb { + pub(super) fn new(ack: SeqNum, timeout_interval: Duration) -> Tcb { #[cfg(debug_assertions)] let seq = 100; #[cfg(not(debug_assertions))] let seq = rand::random::(); let deadline = tokio::time::Instant::now() + timeout_interval; Tcb { - seq, + seq: seq.into(), retransmission: None, ack, - last_ack: seq, + last_ack: seq.into(), timeout_interval, timeout: Box::pin(tokio::time::sleep_until(deadline)), send_window: u16::MAX, @@ -62,12 +63,12 @@ impl Tcb { unordered_packets: BTreeMap::new(), } } - pub(super) fn add_inflight_packet(&mut self, seq: u32, buf: Vec) { + pub(super) fn add_inflight_packet(&mut self, seq: SeqNum, buf: Vec) { let buf_len = buf.len() as u32; self.inflight_packets.push(InflightPacket::new(seq, buf)); - self.seq = self.seq.wrapping_add(buf_len); + self.seq += buf_len; } - pub(super) fn add_unordered_packet(&mut self, seq: u32, buf: Vec) { + pub(super) fn add_unordered_packet(&mut self, seq: SeqNum, buf: Vec) { if seq < self.ack { return; } @@ -84,18 +85,18 @@ impl Tcb { self.unordered_packets.remove(&self.ack).map(|p| p.payload) } pub(super) fn add_seq_one(&mut self) { - self.seq = self.seq.wrapping_add(1); + self.seq += 1; } - pub(super) fn get_seq(&self) -> u32 { + pub(super) fn get_seq(&self) -> SeqNum { self.seq } - pub(super) fn add_ack(&mut self, add: u32) { - self.ack = self.ack.wrapping_add(add); + pub(super) fn add_ack(&mut self, add: SeqNum) { + self.ack += add; } - pub(super) fn get_ack(&self) -> u32 { + pub(super) fn get_ack(&self) -> SeqNum { self.ack } - pub(super) fn get_last_ack(&self) -> u32 { + pub(super) fn get_last_ack(&self) -> SeqNum { self.last_ack } pub(super) fn change_state(&mut self, state: TcpState) { @@ -137,24 +138,23 @@ impl Tcb { pub(super) fn check_pkt_type(&self, header: &TcpHeaderWrapper, p: &[u8]) -> PacketStatus { let tcp_header = header.inner(); - let received_ack_distance = self.seq.wrapping_sub(tcp_header.acknowledgment_number); + let received_ack = SeqNum(tcp_header.acknowledgment_number); + let received_ack_distance = self.seq - received_ack; - let current_ack_distance = self.seq.wrapping_sub(self.last_ack); - if received_ack_distance > current_ack_distance - || (tcp_header.acknowledgment_number != self.seq && self.seq.saturating_sub(tcp_header.acknowledgment_number) == 0) - { + let current_ack_distance = self.seq - self.last_ack; + if received_ack_distance > current_ack_distance || (self.seq != received_ack && self.seq.0.saturating_sub(received_ack.0) == 0) { PacketStatus::Invalid - } else if self.last_ack == tcp_header.acknowledgment_number { + } else if self.last_ack == received_ack { if !p.is_empty() { PacketStatus::NewPacket } else if self.send_window == tcp_header.window_size && self.seq != self.last_ack { PacketStatus::RetransmissionRequest - } else if self.ack.wrapping_sub(1) == tcp_header.sequence_number { + } else if self.ack - 1 == tcp_header.sequence_number { PacketStatus::KeepAlive } else { PacketStatus::WindowUpdate } - } else if self.last_ack < tcp_header.acknowledgment_number { + } else if self.last_ack < received_ack { if !p.is_empty() { PacketStatus::NewPacket } else { @@ -164,14 +164,14 @@ impl Tcb { PacketStatus::Invalid } } - pub(super) fn change_last_ack(&mut self, ack: u32) { - let distance = ack.wrapping_sub(self.last_ack); - self.last_ack = self.last_ack.wrapping_add(distance); + + pub(super) fn change_last_ack(&mut self, ack: SeqNum) { + self.last_ack = ack; if self.state == TcpState::Established { - if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) { + if let Some(i) = self.inflight_packets.iter().position(|p| p.contains_seq_num(ack - 1)) { let mut inflight_packet = self.inflight_packets.remove(i); - let distance = ack.wrapping_sub(inflight_packet.seq) as usize; + let distance = (ack - inflight_packet.seq).0 as usize; if distance < inflight_packet.payload.len() { inflight_packet.payload.drain(0..distance); inflight_packet.seq = ack; @@ -179,13 +179,13 @@ impl Tcb { } } self.inflight_packets.retain(|p| { - let last_byte = p.seq.wrapping_add(p.payload.len() as u32); - last_byte.saturating_sub(self.last_ack) > 0 + let last_byte = p.seq + (p.payload.len() as u32); + last_byte > self.last_ack }); } } pub fn is_send_buffer_full(&self) -> bool { - self.seq.wrapping_sub(self.last_ack) >= MAX_UNACK + (self.seq - self.last_ack).0 >= MAX_UNACK } pub(crate) fn reset_timeout(&mut self) { @@ -196,24 +196,37 @@ impl Tcb { #[derive(Debug)] pub struct InflightPacket { - pub seq: u32, + pub seq: SeqNum, pub payload: Vec, // pub send_time: SystemTime, // todo } impl InflightPacket { - fn new(seq: u32, payload: Vec) -> Self { + fn new(seq: SeqNum, payload: Vec) -> Self { Self { seq, payload, // send_time: SystemTime::now(), // todo } } - pub(crate) fn contains(&self, seq: u32) -> bool { - self.seq < seq && seq <= self.seq + self.payload.len() as u32 + pub(crate) fn contains_seq_num(&self, seq: SeqNum) -> bool { + self.seq <= seq && seq < self.seq + self.payload.len() as u32 } } +#[test] +fn test_in_flight_packet() { + let p = InflightPacket::new((u32::MAX - 1).into(), vec![10, 20, 30, 40, 50]); + + assert!(p.contains_seq_num((u32::MAX - 1).into())); + assert!(p.contains_seq_num(u32::MAX.into())); + assert!(p.contains_seq_num(0.into())); + assert!(p.contains_seq_num(1.into())); + assert!(p.contains_seq_num(2.into())); + + assert!(!p.contains_seq_num(3.into())); +} + #[derive(Debug)] struct UnorderedPacket { payload: Vec, diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index a8fceb9..faf2ec7 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -1,3 +1,4 @@ +use super::seqnum::SeqNum; use crate::{ error::IpStackError, packet::{ @@ -69,7 +70,7 @@ impl IpStackTcpStream { stream_receiver, up_packet_sender, packet_to_send: None, - tcb: Tcb::new(tcp.inner().sequence_number + 1, tcp_timeout), + tcb: Tcb::new(SeqNum(tcp.inner().sequence_number) + 1, tcp_timeout), mtu, shutdown: Shutdown::None, write_notify: None, @@ -97,11 +98,11 @@ impl IpStackTcpStream { let mut tcp_header = etherparse::TcpHeader::new( self.dst_addr.port(), self.src_addr.port(), - seq.into().unwrap_or(self.tcb.get_seq()), + seq.into().unwrap_or(self.tcb.get_seq().0), self.tcb.get_recv_window(), ); - tcp_header.acknowledgment_number = self.tcb.get_ack(); + tcp_header.acknowledgment_number = self.tcb.get_ack().0; tcp_header.syn = flags & SYN != 0; tcp_header.ack = flags & ACK != 0; tcp_header.rst = flags & RST != 0; @@ -205,7 +206,8 @@ impl AsyncRead for IpStackTcpStream { } if let Some(b) = self.tcb.get_unordered_packets().filter(|_| matches!(self.shutdown, Shutdown::None)) { - self.tcb.add_ack(b.len() as u32); + use std::io::{Error, ErrorKind::Other}; + self.tcb.add_ack(b.len().try_into().map_err(|e| Error::new(Other, e))?); buf.put_slice(&b); self.up_packet_sender .send(self.create_rev_packet(ACK, TTL, None, Vec::new())?) @@ -215,7 +217,7 @@ impl AsyncRead for IpStackTcpStream { if self.tcb.get_state() == TcpState::FinWait1(true) { self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); - self.tcb.add_ack(1); + self.tcb.add_ack(1.into()); self.tcb.change_state(TcpState::FinWait2(true)); continue; } else if matches!(self.shutdown, Shutdown::Pending(_)) @@ -245,7 +247,7 @@ impl AsyncRead for IpStackTcpStream { if self.tcb.get_state() == TcpState::SynReceived(true) { if t.flags() == ACK { - self.tcb.change_last_ack(t.inner().acknowledgment_number); + self.tcb.change_last_ack(t.inner().acknowledgment_number.into()); self.tcb.change_send_window(t.inner().window_size); self.tcb.change_state(TcpState::Established); } @@ -262,14 +264,14 @@ impl AsyncRead for IpStackTcpStream { } PacketStatus::Invalid => continue, PacketStatus::KeepAlive => { - self.tcb.change_last_ack(t.inner().acknowledgment_number); + self.tcb.change_last_ack(t.inner().acknowledgment_number.into()); self.tcb.change_send_window(t.inner().window_size); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); continue; } PacketStatus::RetransmissionRequest => { self.tcb.change_send_window(t.inner().window_size); - self.tcb.retransmission = Some(t.inner().acknowledgment_number); + self.tcb.retransmission = Some(t.inner().acknowledgment_number.into()); if matches!(self.as_mut().poll_flush(cx), Poll::Pending) { return Poll::Pending; } @@ -287,8 +289,8 @@ impl AsyncRead for IpStackTcpStream { // continue; // } - self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb.add_unordered_packet(t.inner().sequence_number, p.payload); + self.tcb.change_last_ack(t.inner().acknowledgment_number.into()); + self.tcb.add_unordered_packet(t.inner().sequence_number.into(), p.payload); self.tcb.change_send_window(t.inner().window_size); if let Some(ref n) = self.write_notify { @@ -298,7 +300,7 @@ impl AsyncRead for IpStackTcpStream { continue; } PacketStatus::Ack => { - self.tcb.change_last_ack(t.inner().acknowledgment_number); + self.tcb.change_last_ack(t.inner().acknowledgment_number.into()); self.tcb.change_send_window(t.inner().window_size); if let Some(ref n) = self.write_notify { n.wake_by_ref(); @@ -309,7 +311,7 @@ impl AsyncRead for IpStackTcpStream { }; } if t.flags() == (FIN | ACK) { - self.tcb.add_ack(1); + self.tcb.add_ack(1.into()); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_state(TcpState::FinWait1(true)); continue; @@ -318,7 +320,7 @@ impl AsyncRead for IpStackTcpStream { if !matches!(self.tcb.check_pkt_type(&t, &p.payload), PacketStatus::NewPacket) { continue; } - self.tcb.change_last_ack(t.inner().acknowledgment_number); + self.tcb.change_last_ack(t.inner().acknowledgment_number.into()); if p.payload.is_empty() || self.tcb.get_ack() != t.inner().sequence_number { continue; @@ -326,17 +328,17 @@ impl AsyncRead for IpStackTcpStream { self.tcb.change_send_window(t.inner().window_size); - self.tcb.add_unordered_packet(t.inner().sequence_number, p.payload); + self.tcb.add_unordered_packet(t.inner().sequence_number.into(), p.payload); continue; } } else if self.tcb.get_state() == TcpState::FinWait1(false) { if t.flags() == ACK { - self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb.add_ack(1); + self.tcb.change_last_ack(t.inner().acknowledgment_number.into()); + self.tcb.add_ack(1.into()); self.tcb.change_state(TcpState::FinWait2(true)); continue; } else if t.flags() == (FIN | ACK) { - self.tcb.add_ack(1); + self.tcb.add_ack(1.into()); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_send_window(t.inner().window_size); self.tcb.change_state(TcpState::FinWait2(true)); @@ -394,7 +396,7 @@ impl AsyncWrite for IpStackTcpStream { if let Some(s) = self.tcb.retransmission.take() { if let Some(packet) = self.tcb.inflight_packets.iter().find(|p| p.seq == s) { - let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; + let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq.0, packet.payload.clone())?; self.up_packet_sender.send(rev_packet).or(Err(ErrorKind::UnexpectedEof))?; } else { From afa00cf511c37eaf475e4b4fb63c129599c291be Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Tue, 25 Feb 2025 21:30:23 +0800 Subject: [PATCH 10/35] rename udp_timeout to timeout_interval --- src/stream/tcp.rs | 4 ++-- src/stream/tcp_wrapper.rs | 4 ++-- src/stream/udp.rs | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index faf2ec7..5f26737 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -62,7 +62,7 @@ impl IpStackTcpStream { up_packet_sender: PacketSender, stream_receiver: PacketReceiver, mtu: u16, - tcp_timeout: Duration, + timeout_interval: Duration, ) -> Result { let stream = IpStackTcpStream { src_addr, @@ -70,7 +70,7 @@ impl IpStackTcpStream { stream_receiver, up_packet_sender, packet_to_send: None, - tcb: Tcb::new(SeqNum(tcp.inner().sequence_number) + 1, tcp_timeout), + tcb: Tcb::new(SeqNum(tcp.inner().sequence_number) + 1, timeout_interval), mtu, shutdown: Shutdown::None, write_notify: None, diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs index a3695f5..88698e3 100644 --- a/src/stream/tcp_wrapper.rs +++ b/src/stream/tcp_wrapper.rs @@ -20,10 +20,10 @@ impl IpStackTcpStream { tcp: TcpHeaderWrapper, pkt_sender: PacketSender, mtu: u16, - tcp_timeout: Duration, + timeout_interval: Duration, ) -> Result { let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - IpStackTcpStreamInner::new(local_addr, peer_addr, tcp, pkt_sender, stream_receiver, mtu, tcp_timeout).map(|inner| { + IpStackTcpStreamInner::new(local_addr, peer_addr, tcp, pkt_sender, stream_receiver, mtu, timeout_interval).map(|inner| { IpStackTcpStream { inner: Some(Box::new(inner)), peer_addr, diff --git a/src/stream/udp.rs b/src/stream/udp.rs index c2cedb1..fbe9b94 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -18,7 +18,7 @@ pub struct IpStackUdpStream { up_pkt_sender: PacketSender, first_payload: Option>, timeout: Pin>, - udp_timeout: Duration, + timeout_interval: Duration, mtu: u16, destroy_messenger: Option>, } @@ -30,10 +30,10 @@ impl IpStackUdpStream { payload: Vec, up_pkt_sender: PacketSender, mtu: u16, - udp_timeout: Duration, + timeout_interval: Duration, ) -> Self { let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - let deadline = tokio::time::Instant::now() + udp_timeout; + let deadline = tokio::time::Instant::now() + timeout_interval; IpStackUdpStream { src_addr, dst_addr, @@ -42,7 +42,7 @@ impl IpStackUdpStream { up_pkt_sender, first_payload: Some(payload), timeout: Box::pin(tokio::time::sleep_until(deadline)), - udp_timeout, + timeout_interval, mtu, destroy_messenger: None, } @@ -108,7 +108,7 @@ impl IpStackUdpStream { } fn reset_timeout(&mut self) { - let deadline = tokio::time::Instant::now() + self.udp_timeout; + let deadline = tokio::time::Instant::now() + self.timeout_interval; self.timeout.as_mut().reset(deadline); } } From 95ec6f6e417a64d0503282cf8788534c797e6624 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Wed, 26 Feb 2025 20:17:42 +0800 Subject: [PATCH 11/35] use SeqNum in create_rev_packet --- src/stream/tcp.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 5f26737..056db56 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -94,11 +94,11 @@ impl IpStackTcpStream { ) } - fn create_rev_packet(&self, flags: u8, ttl: u8, seq: impl Into>, mut payload: Vec) -> Result { + fn create_rev_packet(&self, flags: u8, ttl: u8, seq: impl Into>, mut payload: Vec) -> Result { let mut tcp_header = etherparse::TcpHeader::new( self.dst_addr.port(), self.src_addr.port(), - seq.into().unwrap_or(self.tcb.get_seq().0), + seq.into().unwrap_or(self.tcb.get_seq()).0, self.tcb.get_recv_window(), ); @@ -396,7 +396,7 @@ impl AsyncWrite for IpStackTcpStream { if let Some(s) = self.tcb.retransmission.take() { if let Some(packet) = self.tcb.inflight_packets.iter().find(|p| p.seq == s) { - let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq.0, packet.payload.clone())?; + let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; self.up_packet_sender.send(rev_packet).or(Err(ErrorKind::UnexpectedEof))?; } else { From b55293a436457b679d28bdeb1cbcee08536e3e8e Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Wed, 26 Feb 2025 21:05:12 +0800 Subject: [PATCH 12/35] rename calculate_payload_len to calculate_payload_max_len --- src/stream/tcp.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 056db56..e6802b5 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -87,7 +87,7 @@ impl IpStackTcpStream { Err(IpStackError::InvalidTcpPacket(tcp.clone())) } - fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { + fn calculate_payload_max_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { cmp::min( self.tcb.get_send_window(), self.mtu.saturating_sub(ip_header_size + tcp_header_size), @@ -111,11 +111,12 @@ impl IpStackTcpStream { let ip_header = match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::TCP, dst.octets(), src.octets()).map_err(IpStackError::from)?; - let payload_len = self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len() as u16); + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::TCP, dst.octets(), src.octets()) + .map_err(|e| Error::new(ErrorKind::InvalidInput, e))?; + let payload_len = self.calculate_payload_max_len(ip_h.header_len() as u16, tcp_header.header_len() as u16); payload.truncate(payload_len as usize); ip_h.set_payload_len(payload.len() + tcp_header.header_len()) - .map_err(IpStackError::from)?; + .map_err(|e| Error::new(ErrorKind::InvalidInput, e))?; ip_h.dont_fragment = true; IpHeader::Ipv4(ip_h) } @@ -129,10 +130,10 @@ impl IpStackTcpStream { source: dst.octets(), destination: src.octets(), }; - let payload_len = self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len() as u16); + let payload_len = self.calculate_payload_max_len(ip_h.header_len() as u16, tcp_header.header_len() as u16); payload.truncate(payload_len as usize); let len = payload.len() + tcp_header.header_len(); - ip_h.set_payload_length(len).map_err(IpStackError::from)?; + ip_h.set_payload_length(len).map_err(|e| Error::new(ErrorKind::InvalidInput, e))?; IpHeader::Ipv6(ip_h) } @@ -143,12 +144,12 @@ impl IpStackTcpStream { IpHeader::Ipv4(ref ip_header) => { tcp_header.checksum = tcp_header .calc_checksum_ipv4(ip_header, &payload) - .or(Err(ErrorKind::InvalidInput))?; + .map_err(|e| Error::new(ErrorKind::InvalidInput, e))?; } IpHeader::Ipv6(ref ip_header) => { tcp_header.checksum = tcp_header .calc_checksum_ipv6(ip_header, &payload) - .or(Err(ErrorKind::InvalidInput))?; + .map_err(|e| Error::new(ErrorKind::InvalidInput, e))?; } } Ok(NetworkPacket { From d1e17bbef2ec4034abb1d5c089c28961f4228bef Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Fri, 28 Feb 2025 18:24:20 +0800 Subject: [PATCH 13/35] re-format code --- Cargo.toml | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 64d9971..823dca0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,31 +11,28 @@ readme = "README.md" [dependencies] ahash = "0.8" -tokio = { version = "1.43", features = [ +etherparse = { version = "0.17", default-features = false, features = ["std"] } +log = { version = "0.4", default-features = false } +rand = { version = "0.9", default-features = false, features = ["thread_rng"] } +thiserror = { version = "2.0", default-features = false } +tokio = { version = "1.43", default-features = false, features = [ "sync", "rt", "time", "io-util", "macros", -], default-features = false } -etherparse = { version = "0.17", default-features = false, features = ["std"] } -thiserror = { version = "2.0", default-features = false } -log = { version = "0.4", default-features = false } -rand = { version = "0.9", default-features = false, features = ["thread_rng"] } +] } [dev-dependencies] -tokio = { version = "1.43", features = [ - "rt-multi-thread", -], default-features = false } clap = { version = "4.5", features = ["derive"] } +criterion = { version = "0.5" } # Benchmarks env_logger = "0.11" +tokio = { version = "1.43", default-features = false, features = [ + "rt-multi-thread", +] } +tun = { version = "0.7.13", default-features = false, features = ["async"] } udp-stream = { version = "0.0", default-features = false } -# Benchmarks -criterion = { version = "0.5" } - -tun = { version = "0.7.13", features = ["async"], default-features = false } - [target.'cfg(target_os = "windows")'.dev-dependencies] wintun = { version = "0.5", default-features = false } From 8beadd0b9730cee917ca910e2394ee377f2d40a7 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Fri, 28 Feb 2025 18:54:35 +0800 Subject: [PATCH 14/35] refactor TCP logic --- src/error.rs | 4 +- src/lib.rs | 4 +- src/packet.rs | 6 +++ src/stream/tcb.rs | 20 ++++++--- src/stream/tcp.rs | 89 +++++++++++++++++++-------------------- src/stream/tcp_wrapper.rs | 8 ++-- src/stream/udp.rs | 1 + 7 files changed, 70 insertions(+), 62 deletions(-) diff --git a/src/error.rs b/src/error.rs index 319a439..360badd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,8 +13,8 @@ pub enum IpStackError { #[error("ValueTooBigError {0}")] ValueTooBigErrorUsize(#[from] etherparse::err::ValueTooBigError), - #[error("Invalid Tcp packet {0}")] - InvalidTcpPacket(crate::packet::TcpHeaderWrapper), + #[error("Invalid Tcp packet")] + InvalidTcpPacket, #[error("IO error: {0}")] IoError(#[from] std::io::Error), diff --git a/src/lib.rs b/src/lib.rs index c3d51e6..181f20d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,7 +23,6 @@ mod packet; pub mod stream; pub use self::error::{IpStackError, Result}; -pub use self::packet::TcpHeaderWrapper; pub use ::etherparse::IpNumber; const DROP_TTL: u8 = 0; @@ -203,8 +202,7 @@ fn create_stream(packet: NetworkPacket, cfg: &IpStackConfig, up_pkt_sender: Pack let dst_addr = packet.dst_addr(); match packet.transport_header() { TransportHeader::Tcp(h) => { - let h: TcpHeaderWrapper = h.into(); - let stream = IpStackTcpStream::new(src_addr, dst_addr, h, up_pkt_sender, cfg.mtu, cfg.tcp_timeout)?; + let stream = IpStackTcpStream::new(src_addr, dst_addr, h.clone(), up_pkt_sender, cfg.mtu, cfg.tcp_timeout)?; Ok((stream.stream_sender(), IpStackStream::Tcp(stream))) } TransportHeader::Udp(_) => { diff --git a/src/packet.rs b/src/packet.rs index 380475b..d283f2b 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -208,6 +208,12 @@ impl TcpHeaderWrapper { } } +impl From for TcpHeaderWrapper { + fn from(header: TcpHeader) -> Self { + TcpHeaderWrapper { header } + } +} + impl From<&TcpHeader> for TcpHeaderWrapper { fn from(header: &TcpHeader) -> Self { TcpHeaderWrapper { header: header.clone() } diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 65cece5..a8c6073 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -1,17 +1,24 @@ use super::seqnum::SeqNum; -use crate::packet::TcpHeaderWrapper; +use etherparse::TcpHeader; use std::{collections::BTreeMap, pin::Pin, time::Duration}; use tokio::time::Sleep; const MAX_UNACK: u32 = 1024 * 16; // 16KB const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB +#[allow(dead_code)] #[derive(Debug, PartialEq, Clone, Copy)] -pub enum TcpState { - SynReceived(bool), // bool means if syn/ack is sent +pub(crate) enum TcpState { + Init, /* since we always act as a server beginning from Listen, so we needn't the state Init & SynSent */ + SynSent, + Listen, + SynReceived, Established, - FinWait1(bool), + FinWait1(bool), // act as a client, followed with FinWait2, TimeWait, Closed FinWait2(bool), // bool means waiting for ack + TimeWait, + CloseWait, // act as a server, followed with LastAck, Closed + LastAck, Closed, } @@ -57,7 +64,7 @@ impl Tcb { timeout: Box::pin(tokio::time::sleep_until(deadline)), send_window: u16::MAX, recv_window: 0, - state: TcpState::SynReceived(false), + state: TcpState::Listen, avg_send_window: (1, 1), inflight_packets: Vec::new(), unordered_packets: BTreeMap::new(), @@ -136,8 +143,7 @@ impl Tcb { // } // } - pub(super) fn check_pkt_type(&self, header: &TcpHeaderWrapper, p: &[u8]) -> PacketStatus { - let tcp_header = header.inner(); + pub(super) fn check_pkt_type(&self, tcp_header: &TcpHeader, p: &[u8]) -> PacketStatus { let received_ack = SeqNum(tcp_header.acknowledgment_number); let received_ack_distance = self.seq - received_ack; diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index e6802b5..3a70593 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -8,7 +8,7 @@ use crate::{ stream::tcb::{PacketStatus, Tcb, TcpState}, PacketReceiver, PacketSender, DROP_TTL, TTL, }; -use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel}; +use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, TcpHeader}; use log::{error, trace, warn}; use std::{ cmp, @@ -58,7 +58,7 @@ impl IpStackTcpStream { pub(crate) fn new( src_addr: SocketAddr, dst_addr: SocketAddr, - tcp: TcpHeaderWrapper, + tcp: TcpHeader, up_packet_sender: PacketSender, stream_receiver: PacketReceiver, mtu: u16, @@ -70,21 +70,22 @@ impl IpStackTcpStream { stream_receiver, up_packet_sender, packet_to_send: None, - tcb: Tcb::new(SeqNum(tcp.inner().sequence_number) + 1, timeout_interval), + tcb: Tcb::new(SeqNum(tcp.sequence_number) + 1, timeout_interval), mtu, shutdown: Shutdown::None, write_notify: None, }; - if tcp.inner().syn { + if tcp.syn { return Ok(stream); } - if !tcp.inner().rst { + if !tcp.rst { let pkt = stream.create_rev_packet(RST | ACK, TTL, None, Vec::new())?; if let Err(err) = stream.up_packet_sender.send(pkt) { warn!("Error sending RST/ACK packet: {:?}", err); } } - Err(IpStackError::InvalidTcpPacket(tcp.clone())) + let info = format!("Invalid TCP packet: {:?}", TcpHeaderWrapper::from(tcp)); + Err(IpStackError::IoError(Error::new(ErrorKind::ConnectionRefused, info))) } fn calculate_payload_max_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { @@ -199,10 +200,10 @@ impl AsyncRead for IpStackTcpStream { } self.tcb.reset_timeout(); - if self.tcb.get_state() == TcpState::SynReceived(false) { + if self.tcb.get_state() == TcpState::Listen { self.packet_to_send = Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); - self.tcb.change_state(TcpState::SynReceived(true)); + self.tcb.change_state(TcpState::SynReceived); continue; } @@ -236,51 +237,52 @@ impl AsyncRead for IpStackTcpStream { unreachable!() }; let t: TcpHeaderWrapper = tcp_header.into(); + let tcp_header = t.inner(); + let incoming_ack: SeqNum = tcp_header.acknowledgment_number.into(); if t.flags() & RST != 0 { self.packet_to_send = Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset))); } - if self.tcb.check_pkt_type(&t, &p.payload) == PacketStatus::Invalid { + if self.tcb.check_pkt_type(tcp_header, &p.payload) == PacketStatus::Invalid { continue; } - if self.tcb.get_state() == TcpState::SynReceived(true) { + if self.tcb.get_state() == TcpState::SynReceived { if t.flags() == ACK { - self.tcb.change_last_ack(t.inner().acknowledgment_number.into()); - self.tcb.change_send_window(t.inner().window_size); + self.tcb.change_last_ack(incoming_ack); + self.tcb.change_send_window(tcp_header.window_size); self.tcb.change_state(TcpState::Established); } } else if self.tcb.get_state() == TcpState::Established { if t.flags() == ACK { - match self.tcb.check_pkt_type(&t, &p.payload) { + match self.tcb.check_pkt_type(tcp_header, &p.payload) { PacketStatus::WindowUpdate => { - self.tcb.change_send_window(t.inner().window_size); - if let Some(ref n) = self.write_notify { - n.wake_by_ref(); - self.write_notify = None; - }; + self.tcb.change_send_window(tcp_header.window_size); + if let Some(waker) = self.write_notify.take() { + waker.wake_by_ref(); + } continue; } PacketStatus::Invalid => continue, PacketStatus::KeepAlive => { - self.tcb.change_last_ack(t.inner().acknowledgment_number.into()); - self.tcb.change_send_window(t.inner().window_size); + self.tcb.change_last_ack(incoming_ack); + self.tcb.change_send_window(tcp_header.window_size); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); continue; } PacketStatus::RetransmissionRequest => { - self.tcb.change_send_window(t.inner().window_size); - self.tcb.retransmission = Some(t.inner().acknowledgment_number.into()); + self.tcb.change_send_window(tcp_header.window_size); + self.tcb.retransmission = Some(incoming_ack); if matches!(self.as_mut().poll_flush(cx), Poll::Pending) { return Poll::Pending; } continue; } PacketStatus::NewPacket => { - // if t.inner().sequence_number != self.tcb.get_ack() { - // dbg!(t.inner().sequence_number); + // if tcp_header.sequence_number != self.tcb.get_ack() { + // dbg!(tcp_header.sequence_number); // self.packet_to_send = Some(self.create_rev_packet( // ACK, // TTL, @@ -290,23 +292,21 @@ impl AsyncRead for IpStackTcpStream { // continue; // } - self.tcb.change_last_ack(t.inner().acknowledgment_number.into()); - self.tcb.add_unordered_packet(t.inner().sequence_number.into(), p.payload); + self.tcb.change_last_ack(incoming_ack); + self.tcb.add_unordered_packet(tcp_header.sequence_number.into(), p.payload); - self.tcb.change_send_window(t.inner().window_size); - if let Some(ref n) = self.write_notify { - n.wake_by_ref(); - self.write_notify = None; - }; + self.tcb.change_send_window(tcp_header.window_size); + if let Some(waker) = self.write_notify.take() { + waker.wake_by_ref(); + } continue; } PacketStatus::Ack => { - self.tcb.change_last_ack(t.inner().acknowledgment_number.into()); - self.tcb.change_send_window(t.inner().window_size); - if let Some(ref n) = self.write_notify { - n.wake_by_ref(); - self.write_notify = None; - }; + self.tcb.change_last_ack(incoming_ack); + self.tcb.change_send_window(tcp_header.window_size); + if let Some(waker) = self.write_notify.take() { + waker.wake_by_ref(); + } continue; } }; @@ -318,30 +318,30 @@ impl AsyncRead for IpStackTcpStream { continue; } if t.flags() == (PSH | ACK) { - if !matches!(self.tcb.check_pkt_type(&t, &p.payload), PacketStatus::NewPacket) { + if !matches!(self.tcb.check_pkt_type(tcp_header, &p.payload), PacketStatus::NewPacket) { continue; } - self.tcb.change_last_ack(t.inner().acknowledgment_number.into()); + self.tcb.change_last_ack(incoming_ack); - if p.payload.is_empty() || self.tcb.get_ack() != t.inner().sequence_number { + if p.payload.is_empty() || self.tcb.get_ack() != tcp_header.sequence_number { continue; } - self.tcb.change_send_window(t.inner().window_size); + self.tcb.change_send_window(tcp_header.window_size); - self.tcb.add_unordered_packet(t.inner().sequence_number.into(), p.payload); + self.tcb.add_unordered_packet(tcp_header.sequence_number.into(), p.payload); continue; } } else if self.tcb.get_state() == TcpState::FinWait1(false) { if t.flags() == ACK { - self.tcb.change_last_ack(t.inner().acknowledgment_number.into()); + self.tcb.change_last_ack(incoming_ack); self.tcb.add_ack(1.into()); self.tcb.change_state(TcpState::FinWait2(true)); continue; } else if t.flags() == (FIN | ACK) { self.tcb.add_ack(1.into()); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); - self.tcb.change_send_window(t.inner().window_size); + self.tcb.change_send_window(tcp_header.window_size); self.tcb.change_state(TcpState::FinWait2(true)); continue; } @@ -398,7 +398,6 @@ impl AsyncWrite for IpStackTcpStream { if let Some(s) = self.tcb.retransmission.take() { if let Some(packet) = self.tcb.inflight_packets.iter().find(|p| p.seq == s) { let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; - self.up_packet_sender.send(rev_packet).or(Err(ErrorKind::UnexpectedEof))?; } else { error!("Packet {} not found in inflight_packets", s); diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs index 88698e3..282841e 100644 --- a/src/stream/tcp_wrapper.rs +++ b/src/stream/tcp_wrapper.rs @@ -1,8 +1,6 @@ use super::tcp::IpStackTcpStream as IpStackTcpStreamInner; -use crate::{ - packet::{NetworkPacket, TcpHeaderWrapper}, - IpStackError, PacketSender, -}; +use crate::{packet::NetworkPacket, IpStackError, PacketSender}; +use etherparse::TcpHeader; use std::{net::SocketAddr, pin::Pin, time::Duration}; use tokio::{io::AsyncWriteExt, sync::mpsc, time::timeout}; @@ -17,7 +15,7 @@ impl IpStackTcpStream { pub(crate) fn new( local_addr: SocketAddr, peer_addr: SocketAddr, - tcp: TcpHeaderWrapper, + tcp: TcpHeader, pkt_sender: PacketSender, mtu: u16, timeout_interval: Duration, diff --git a/src/stream/udp.rs b/src/stream/udp.rs index fbe9b94..e615e7e 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -10,6 +10,7 @@ use tokio::{ time::Sleep, }; +#[derive(Debug)] pub struct IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, From e8ba6f7ef58c438e9e81c78f833fc0a125e89aea Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sun, 2 Mar 2025 15:15:32 +0800 Subject: [PATCH 15/35] remove the logic of DROP_TTL --- src/lib.rs | 20 ++++++++++---------- src/packet.rs | 1 - src/stream/tcp.rs | 18 ++++++++++-------- src/stream/tcp_wrapper.rs | 7 +++++++ 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 181f20d..e6d9bd3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,8 +25,6 @@ pub mod stream; pub use self::error::{IpStackError, Result}; pub use ::etherparse::IpNumber; -const DROP_TTL: u8 = 0; - #[cfg(unix)] const TTL: u8 = 64; @@ -133,7 +131,6 @@ fn run( Some(packet) = up_pkt_receiver.recv() => { process_upstream_recv( packet, - sessions.clone(), &mut device, #[cfg(unix)] pi, @@ -181,6 +178,16 @@ async fn process_device_read( } Vacant(entry) => { let (packet_sender, mut ip_stack_stream) = create_stream(packet, config, up_pkt_sender)?; + if let IpStackStream::Tcp(ref mut stream) = ip_stack_stream { + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + stream.set_destroy_messenger(tx); + let sessions_clone = sessions_clone.clone(); + tokio::spawn(async move { + rx.await.ok(); + sessions_clone.lock().await.remove(&network_tuple); + log::trace!("session removed: {}", network_tuple); + }); + } if let IpStackStream::Udp(ref mut stream) = ip_stack_stream { let (tx, rx) = tokio::sync::oneshot::channel::<()>(); stream.set_destroy_messenger(tx); @@ -217,16 +224,9 @@ fn create_stream(packet: NetworkPacket, cfg: &IpStackConfig, up_pkt_sender: Pack async fn process_upstream_recv( up_packet: NetworkPacket, - sessions: SessionCollection, device: &mut Device, #[cfg(unix)] packet_information: bool, ) -> Result<()> { - if up_packet.ttl() == DROP_TTL { - let network_tuple = up_packet.reverse_network_tuple(); - sessions.lock().await.remove(&network_tuple); - log::trace!("session removed: {}", network_tuple); - return Ok(()); - } #[allow(unused_mut)] let Ok(mut packet_bytes) = up_packet.to_bytes() else { log::trace!("to_bytes error"); diff --git a/src/packet.rs b/src/packet.rs index d283f2b..8572bb8 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -25,7 +25,6 @@ pub mod tcp_flags { pub const RST: u8 = 0b00000100; pub const SYN: u8 = 0b00000010; pub const FIN: u8 = 0b00000001; - pub const NON: u8 = 0b00000000; } #[derive(Debug, Clone)] diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 3a70593..25505a4 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -2,11 +2,11 @@ use super::seqnum::SeqNum; use crate::{ error::IpStackError, packet::{ - tcp_flags::{ACK, FIN, NON, PSH, RST, SYN}, + tcp_flags::{ACK, FIN, PSH, RST, SYN}, IpHeader, NetworkPacket, TcpHeaderWrapper, TransportHeader, }, stream::tcb::{PacketStatus, Tcb, TcpState}, - PacketReceiver, PacketSender, DROP_TTL, TTL, + PacketReceiver, PacketSender, TTL, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, TcpHeader}; use log::{error, trace, warn}; @@ -52,6 +52,7 @@ pub(crate) struct IpStackTcpStream { mtu: u16, shutdown: Shutdown, write_notify: Option, + destroy_messenger: Option>, } impl IpStackTcpStream { @@ -74,6 +75,7 @@ impl IpStackTcpStream { mtu, shutdown: Shutdown::None, write_notify: None, + destroy_messenger: None, }; if tcp.syn { return Ok(stream); @@ -88,6 +90,10 @@ impl IpStackTcpStream { Err(IpStackError::IoError(Error::new(ErrorKind::ConnectionRefused, info))) } + pub(crate) fn set_destroy_messenger(&mut self, messenger: tokio::sync::oneshot::Sender<()>) { + self.destroy_messenger = Some(messenger); + } + fn calculate_payload_max_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { cmp::min( self.tcb.get_send_window(), @@ -180,7 +186,6 @@ impl AsyncRead for IpStackTcpStream { } if self.tcb.get_state() == TcpState::FinWait2(false) { - self.packet_to_send = Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::ConnectionAborted))); @@ -240,7 +245,6 @@ impl AsyncRead for IpStackTcpStream { let tcp_header = t.inner(); let incoming_ack: SeqNum = tcp_header.acknowledgment_number.into(); if t.flags() & RST != 0 { - self.packet_to_send = Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset))); @@ -427,10 +431,8 @@ impl AsyncWrite for IpStackTcpStream { impl Drop for IpStackTcpStream { fn drop(&mut self) { - if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) { - if let Err(err) = self.up_packet_sender.send(p) { - trace!("Error sending NON packet: {:?}", err); - } + if let Some(messenger) = self.destroy_messenger.take() { + let _ = messenger.send(()); } } } diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs index 282841e..c7854f2 100644 --- a/src/stream/tcp_wrapper.rs +++ b/src/stream/tcp_wrapper.rs @@ -30,6 +30,13 @@ impl IpStackTcpStream { } }) } + + pub(crate) fn set_destroy_messenger(&mut self, messenger: tokio::sync::oneshot::Sender<()>) { + if let Some(inner) = self.inner.as_mut() { + inner.set_destroy_messenger(messenger); + } + } + pub fn local_addr(&self) -> SocketAddr { self.local_addr } From 5f10ae5875a7529d38d7b3e2ab41134a7096b2ad Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sun, 2 Mar 2025 15:49:20 +0800 Subject: [PATCH 16/35] mod tcp_wrapper removed --- src/stream/mod.rs | 3 +- src/stream/tcp.rs | 18 +++++-- src/stream/tcp_wrapper.rs | 102 -------------------------------------- 3 files changed, 16 insertions(+), 107 deletions(-) delete mode 100644 src/stream/tcp_wrapper.rs diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 931a3ad..7207e99 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -1,13 +1,12 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -pub use self::tcp_wrapper::IpStackTcpStream; +pub use self::tcp::IpStackTcpStream; pub use self::udp::IpStackUdpStream; pub use self::unknown::IpStackUnknownTransport; mod seqnum; mod tcb; mod tcp; -mod tcp_wrapper; mod udp; mod unknown; diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 25505a4..f4b239a 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -42,9 +42,10 @@ impl Shutdown { } #[derive(Debug)] -pub(crate) struct IpStackTcpStream { +pub struct IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, + stream_sender: PacketSender, stream_receiver: PacketReceiver, up_packet_sender: PacketSender, packet_to_send: Option, @@ -61,13 +62,14 @@ impl IpStackTcpStream { dst_addr: SocketAddr, tcp: TcpHeader, up_packet_sender: PacketSender, - stream_receiver: PacketReceiver, mtu: u16, timeout_interval: Duration, ) -> Result { + let (stream_sender, stream_receiver) = tokio::sync::mpsc::unbounded_channel::(); let stream = IpStackTcpStream { src_addr, dst_addr, + stream_sender, stream_receiver, up_packet_sender, packet_to_send: None, @@ -86,10 +88,20 @@ impl IpStackTcpStream { warn!("Error sending RST/ACK packet: {:?}", err); } } - let info = format!("Invalid TCP packet: {:?}", TcpHeaderWrapper::from(tcp)); + let info = format!("Invalid TCP packet: {}", TcpHeaderWrapper::from(tcp)); Err(IpStackError::IoError(Error::new(ErrorKind::ConnectionRefused, info))) } + pub fn local_addr(&self) -> SocketAddr { + self.src_addr + } + pub fn peer_addr(&self) -> SocketAddr { + self.dst_addr + } + pub fn stream_sender(&self) -> PacketSender { + self.stream_sender.clone() + } + pub(crate) fn set_destroy_messenger(&mut self, messenger: tokio::sync::oneshot::Sender<()>) { self.destroy_messenger = Some(messenger); } diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs deleted file mode 100644 index c7854f2..0000000 --- a/src/stream/tcp_wrapper.rs +++ /dev/null @@ -1,102 +0,0 @@ -use super::tcp::IpStackTcpStream as IpStackTcpStreamInner; -use crate::{packet::NetworkPacket, IpStackError, PacketSender}; -use etherparse::TcpHeader; -use std::{net::SocketAddr, pin::Pin, time::Duration}; -use tokio::{io::AsyncWriteExt, sync::mpsc, time::timeout}; - -pub struct IpStackTcpStream { - inner: Option>, - peer_addr: SocketAddr, - local_addr: SocketAddr, - stream_sender: PacketSender, -} - -impl IpStackTcpStream { - pub(crate) fn new( - local_addr: SocketAddr, - peer_addr: SocketAddr, - tcp: TcpHeader, - pkt_sender: PacketSender, - mtu: u16, - timeout_interval: Duration, - ) -> Result { - let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - IpStackTcpStreamInner::new(local_addr, peer_addr, tcp, pkt_sender, stream_receiver, mtu, timeout_interval).map(|inner| { - IpStackTcpStream { - inner: Some(Box::new(inner)), - peer_addr, - local_addr, - stream_sender, - } - }) - } - - pub(crate) fn set_destroy_messenger(&mut self, messenger: tokio::sync::oneshot::Sender<()>) { - if let Some(inner) = self.inner.as_mut() { - inner.set_destroy_messenger(messenger); - } - } - - pub fn local_addr(&self) -> SocketAddr { - self.local_addr - } - pub fn peer_addr(&self) -> SocketAddr { - self.peer_addr - } - pub fn stream_sender(&self) -> PacketSender { - self.stream_sender.clone() - } -} - -impl tokio::io::AsyncRead for IpStackTcpStream { - fn poll_read( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - match self.inner.as_mut() { - Some(mut inner) => Pin::new(&mut inner).poll_read(cx, buf), - None => std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))), - } - } -} - -impl tokio::io::AsyncWrite for IpStackTcpStream { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - match self.inner.as_mut() { - Some(mut inner) => Pin::new(&mut inner).poll_write(cx, buf), - None => std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))), - } - } - fn poll_flush(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { - match self.inner.as_mut() { - Some(mut inner) => Pin::new(&mut inner).poll_flush(cx), - None => std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))), - } - } - fn poll_shutdown(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { - match self.inner.as_mut() { - Some(mut inner) => Pin::new(&mut inner).poll_shutdown(cx), - None => std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))), - } - } -} - -impl Drop for IpStackTcpStream { - fn drop(&mut self) { - if let Some(mut inner) = self.inner.take() { - let local_addr = self.local_addr(); - let peer_addr = self.peer_addr(); - tokio::spawn(async move { - if let Err(err) = timeout(Duration::from_secs(2), inner.shutdown()).await { - log::warn!("Error while dropping IpStackTcpStream: {:?}", err); - } - log::trace!("TCP Stream closed: {} -> {}", local_addr, peer_addr); - }); - } - } -} From d81fd0d9c93747f22eb8347d5b613b3183cbb551 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sun, 2 Mar 2025 16:19:33 +0800 Subject: [PATCH 17/35] TcpHeaderWrapper removed --- src/packet.rs | 140 +++++++++++++++++++--------------------------- src/stream/tcp.rs | 31 +++++----- 2 files changed, 72 insertions(+), 99 deletions(-) diff --git a/src/packet.rs b/src/packet.rs index 8572bb8..175a66a 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -127,96 +127,70 @@ impl NetworkPacket { } } -#[derive(Debug, Clone)] -pub struct TcpHeaderWrapper { - header: TcpHeader, -} - -impl std::fmt::Display for TcpHeaderWrapper { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut flags = String::new(); - if self.header.cwr { - flags.push_str("CWR "); - } - if self.header.ece { - flags.push_str("ECE "); - } - if self.header.urg { - flags.push_str("URG "); - } - if self.header.ack { - flags.push_str("ACK "); - } - if self.header.psh { - flags.push_str("PSH "); - } - if self.header.rst { - flags.push_str("RST "); - } - if self.header.syn { - flags.push_str("SYN "); - } - if self.header.fin { - flags.push_str("FIN "); - } - write!( - f, - "TcpHeader {{ src_port: {}, dst_port: {}, seq: {}, ack: {}, flags: {} }}", - self.header.source_port, - self.header.destination_port, - self.header.sequence_number, - self.header.acknowledgment_number, - flags.trim() - ) +pub fn tcp_header_fmt(header: &TcpHeader) -> String { + let mut flags = String::new(); + if header.cwr { + flags.push_str("CWR "); } -} - -impl TcpHeaderWrapper { - pub fn inner(&self) -> &TcpHeader { - &self.header + if header.ece { + flags.push_str("ECE "); } - pub fn flags(&self) -> u8 { - let inner = self.inner(); - let mut flags = 0; - if inner.cwr { - flags |= tcp_flags::CWR; - } - if inner.ece { - flags |= tcp_flags::ECE; - } - if inner.urg { - flags |= tcp_flags::URG; - } - if inner.ack { - flags |= tcp_flags::ACK; - } - if inner.psh { - flags |= tcp_flags::PSH; - } - if inner.rst { - flags |= tcp_flags::RST; - } - if inner.syn { - flags |= tcp_flags::SYN; - } - if inner.fin { - flags |= tcp_flags::FIN; - } - - flags + if header.urg { + flags.push_str("URG "); } -} - -impl From for TcpHeaderWrapper { - fn from(header: TcpHeader) -> Self { - TcpHeaderWrapper { header } + if header.ack { + flags.push_str("ACK "); + } + if header.psh { + flags.push_str("PSH "); + } + if header.rst { + flags.push_str("RST "); + } + if header.syn { + flags.push_str("SYN "); } + if header.fin { + flags.push_str("FIN "); + } + format!( + "TcpHeader {{ src_port: {}, dst_port: {}, seq: {}, ack: {}, flags: {} }}", + header.source_port, + header.destination_port, + header.sequence_number, + header.acknowledgment_number, + flags.trim() + ) } -impl From<&TcpHeader> for TcpHeaderWrapper { - fn from(header: &TcpHeader) -> Self { - TcpHeaderWrapper { header: header.clone() } +pub fn tcp_header_flags(inner: &TcpHeader) -> u8 { + let mut flags = 0; + if inner.cwr { + flags |= tcp_flags::CWR; + } + if inner.ece { + flags |= tcp_flags::ECE; + } + if inner.urg { + flags |= tcp_flags::URG; + } + if inner.ack { + flags |= tcp_flags::ACK; + } + if inner.psh { + flags |= tcp_flags::PSH; } + if inner.rst { + flags |= tcp_flags::RST; + } + if inner.syn { + flags |= tcp_flags::SYN; + } + if inner.fin { + flags |= tcp_flags::FIN; + } + + flags } // pub struct UdpPacket { diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index f4b239a..9e30f42 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -3,13 +3,13 @@ use crate::{ error::IpStackError, packet::{ tcp_flags::{ACK, FIN, PSH, RST, SYN}, - IpHeader, NetworkPacket, TcpHeaderWrapper, TransportHeader, + tcp_header_flags, tcp_header_fmt, IpHeader, NetworkPacket, TransportHeader, }, stream::tcb::{PacketStatus, Tcb, TcpState}, PacketReceiver, PacketSender, TTL, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, TcpHeader}; -use log::{error, trace, warn}; +use log::{error, warn}; use std::{ cmp, future::Future, @@ -88,7 +88,7 @@ impl IpStackTcpStream { warn!("Error sending RST/ACK packet: {:?}", err); } } - let info = format!("Invalid TCP packet: {}", TcpHeaderWrapper::from(tcp)); + let info = format!("Invalid TCP packet: {}", tcp_header_fmt(&tcp)); Err(IpStackError::IoError(Error::new(ErrorKind::ConnectionRefused, info))) } @@ -207,7 +207,7 @@ impl AsyncRead for IpStackTcpStream { self.tcb.change_recv_window(min); if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) { - trace!("timeout reached for {:?}", self.dst_addr); + log::trace!("timeout reached for {:?}", self.dst_addr); self.up_packet_sender .send(self.create_rev_packet(RST | ACK, TTL, None, Vec::new())?) .or(Err(ErrorKind::UnexpectedEof))?; @@ -253,10 +253,9 @@ impl AsyncRead for IpStackTcpStream { let TransportHeader::Tcp(tcp_header) = p.transport_header() else { unreachable!() }; - let t: TcpHeaderWrapper = tcp_header.into(); - let tcp_header = t.inner(); + let flags = tcp_header_flags(tcp_header); let incoming_ack: SeqNum = tcp_header.acknowledgment_number.into(); - if t.flags() & RST != 0 { + if flags & RST != 0 { self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset))); @@ -266,13 +265,13 @@ impl AsyncRead for IpStackTcpStream { } if self.tcb.get_state() == TcpState::SynReceived { - if t.flags() == ACK { + if flags == ACK { self.tcb.change_last_ack(incoming_ack); self.tcb.change_send_window(tcp_header.window_size); self.tcb.change_state(TcpState::Established); } } else if self.tcb.get_state() == TcpState::Established { - if t.flags() == ACK { + if flags == ACK { match self.tcb.check_pkt_type(tcp_header, &p.payload) { PacketStatus::WindowUpdate => { self.tcb.change_send_window(tcp_header.window_size); @@ -309,7 +308,7 @@ impl AsyncRead for IpStackTcpStream { // } self.tcb.change_last_ack(incoming_ack); - self.tcb.add_unordered_packet(tcp_header.sequence_number.into(), p.payload); + self.tcb.add_unordered_packet(tcp_header.sequence_number.into(), p.payload.clone()); self.tcb.change_send_window(tcp_header.window_size); if let Some(waker) = self.write_notify.take() { @@ -327,13 +326,13 @@ impl AsyncRead for IpStackTcpStream { } }; } - if t.flags() == (FIN | ACK) { + if flags == (FIN | ACK) { self.tcb.add_ack(1.into()); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_state(TcpState::FinWait1(true)); continue; } - if t.flags() == (PSH | ACK) { + if flags == (PSH | ACK) { if !matches!(self.tcb.check_pkt_type(tcp_header, &p.payload), PacketStatus::NewPacket) { continue; } @@ -349,12 +348,12 @@ impl AsyncRead for IpStackTcpStream { continue; } } else if self.tcb.get_state() == TcpState::FinWait1(false) { - if t.flags() == ACK { + if flags == ACK { self.tcb.change_last_ack(incoming_ack); self.tcb.add_ack(1.into()); self.tcb.change_state(TcpState::FinWait2(true)); continue; - } else if t.flags() == (FIN | ACK) { + } else if flags == (FIN | ACK) { self.tcb.add_ack(1.into()); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_send_window(tcp_header.window_size); @@ -362,9 +361,9 @@ impl AsyncRead for IpStackTcpStream { continue; } } else if self.tcb.get_state() == TcpState::FinWait2(true) { - if t.flags() == ACK { + if flags == ACK { self.tcb.change_state(TcpState::FinWait2(false)); - } else if t.flags() == (FIN | ACK) { + } else if flags == (FIN | ACK) { self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_state(TcpState::FinWait2(false)); } From 6ff3a900f864457b6045722c9b9773d02d25c2b4 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sun, 2 Mar 2025 18:30:52 +0800 Subject: [PATCH 18/35] find_inflight_packet & get_all_inflight_packets --- src/stream/tcb.rs | 11 ++++++++++- src/stream/tcp.rs | 4 ++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index a8c6073..6810abb 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -44,7 +44,7 @@ pub(super) struct Tcb { send_window: u16, state: TcpState, avg_send_window: (u64, u64), // (avg, count) - pub(super) inflight_packets: Vec, + inflight_packets: Vec, unordered_packets: BTreeMap, } @@ -190,6 +190,15 @@ impl Tcb { }); } } + + pub(crate) fn find_inflight_packet(&self, seq: SeqNum) -> Option<&InflightPacket> { + self.inflight_packets.iter().find(|p| p.seq == seq) + } + + pub(crate) fn get_all_inflight_packets(&self) -> &Vec { + &self.inflight_packets + } + pub fn is_send_buffer_full(&self) -> bool { (self.seq - self.last_ack).0 >= MAX_UNACK } diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 9e30f42..847b3a1 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -411,7 +411,7 @@ impl AsyncWrite for IpStackTcpStream { } if let Some(s) = self.tcb.retransmission.take() { - if let Some(packet) = self.tcb.inflight_packets.iter().find(|p| p.seq == s) { + if let Some(packet) = self.tcb.find_inflight_packet(s) { let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; self.up_packet_sender.send(rev_packet).or(Err(ErrorKind::UnexpectedEof))?; } else { @@ -420,7 +420,7 @@ impl AsyncWrite for IpStackTcpStream { error!("last_ack: {}", self.tcb.get_last_ack()); error!("ack: {}", self.tcb.get_ack()); error!("inflight_packets:"); - for p in self.tcb.inflight_packets.iter() { + for p in self.tcb.get_all_inflight_packets().iter() { error!("seq: {}", p.seq); error!("payload len: {}", p.payload.len()); } From dba2683ca092c672324a4a99a19a5eb7d5c0fbcb Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sun, 2 Mar 2025 20:15:01 +0800 Subject: [PATCH 19/35] incoming_seq --- src/stream/tcb.rs | 2 +- src/stream/tcp.rs | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 6810abb..18de66d 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -9,7 +9,7 @@ const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB #[allow(dead_code)] #[derive(Debug, PartialEq, Clone, Copy)] pub(crate) enum TcpState { - Init, /* since we always act as a server beginning from Listen, so we needn't the state Init & SynSent */ + Init, /* Since we always act as a server, it starts from `Listen`, so we don't use states Init & SynSent. */ SynSent, Listen, SynReceived, diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 847b3a1..a7ee3cd 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -228,9 +228,8 @@ impl AsyncRead for IpStackTcpStream { use std::io::{Error, ErrorKind::Other}; self.tcb.add_ack(b.len().try_into().map_err(|e| Error::new(Other, e))?); buf.put_slice(&b); - self.up_packet_sender - .send(self.create_rev_packet(ACK, TTL, None, Vec::new())?) - .or(Err(ErrorKind::UnexpectedEof))?; + let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; return Poll::Ready(Ok(())); } if self.tcb.get_state() == TcpState::FinWait1(true) { @@ -255,6 +254,7 @@ impl AsyncRead for IpStackTcpStream { }; let flags = tcp_header_flags(tcp_header); let incoming_ack: SeqNum = tcp_header.acknowledgment_number.into(); + let incoming_seq: SeqNum = tcp_header.sequence_number.into(); if flags & RST != 0 { self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); @@ -296,8 +296,8 @@ impl AsyncRead for IpStackTcpStream { continue; } PacketStatus::NewPacket => { - // if tcp_header.sequence_number != self.tcb.get_ack() { - // dbg!(tcp_header.sequence_number); + // if incoming_seq != self.tcb.get_ack() { + // dbg!(incoming_seq); // self.packet_to_send = Some(self.create_rev_packet( // ACK, // TTL, @@ -308,7 +308,7 @@ impl AsyncRead for IpStackTcpStream { // } self.tcb.change_last_ack(incoming_ack); - self.tcb.add_unordered_packet(tcp_header.sequence_number.into(), p.payload.clone()); + self.tcb.add_unordered_packet(incoming_seq, p.payload.clone()); self.tcb.change_send_window(tcp_header.window_size); if let Some(waker) = self.write_notify.take() { @@ -338,13 +338,13 @@ impl AsyncRead for IpStackTcpStream { } self.tcb.change_last_ack(incoming_ack); - if p.payload.is_empty() || self.tcb.get_ack() != tcp_header.sequence_number { + if p.payload.is_empty() || self.tcb.get_ack() != incoming_seq { continue; } self.tcb.change_send_window(tcp_header.window_size); - self.tcb.add_unordered_packet(tcp_header.sequence_number.into(), p.payload); + self.tcb.add_unordered_packet(incoming_seq, p.payload); continue; } } else if self.tcb.get_state() == TcpState::FinWait1(false) { From 3c60021b4491caf318da2b32cf460fec4515811a Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sun, 2 Mar 2025 23:50:07 +0800 Subject: [PATCH 20/35] hide some log::trace --- src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e6d9bd3..662e3dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -174,7 +174,7 @@ async fn process_device_read( Occupied(entry) => { use std::io::{Error, ErrorKind::Other}; entry.get().send(packet).map_err(|e| Error::new(Other, e))?; - log::trace!("packet sent to stream: {}", network_tuple); + // log::trace!("packet sent to stream: {}", network_tuple); } Vacant(entry) => { let (packet_sender, mut ip_stack_stream) = create_stream(packet, config, up_pkt_sender)?; @@ -185,7 +185,7 @@ async fn process_device_read( tokio::spawn(async move { rx.await.ok(); sessions_clone.lock().await.remove(&network_tuple); - log::trace!("session removed: {}", network_tuple); + // log::trace!("session removed: {}", network_tuple); }); } if let IpStackStream::Udp(ref mut stream) = ip_stack_stream { @@ -194,7 +194,7 @@ async fn process_device_read( tokio::spawn(async move { rx.await.ok(); sessions_clone.lock().await.remove(&network_tuple); - log::trace!("session removed: {}", network_tuple); + // log::trace!("session removed: {}", network_tuple); }); } entry.insert(packet_sender); From 98bf25f6c9961f76d9c9fc40ee0fbf05954b7cdc Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Mon, 3 Mar 2025 13:05:17 +0800 Subject: [PATCH 21/35] refactor retransmission logic --- src/stream/tcb.rs | 3 +-- src/stream/tcp.rs | 55 +++++++++++++++-------------------------------- 2 files changed, 18 insertions(+), 40 deletions(-) diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 18de66d..5a87e68 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -35,7 +35,6 @@ pub(super) enum PacketStatus { #[derive(Debug)] pub(super) struct Tcb { seq: SeqNum, - pub(super) retransmission: Option, ack: SeqNum, last_ack: SeqNum, pub(super) timeout: Pin>, @@ -57,7 +56,6 @@ impl Tcb { let deadline = tokio::time::Instant::now() + timeout_interval; Tcb { seq: seq.into(), - retransmission: None, ack, last_ack: seq.into(), timeout_interval, @@ -195,6 +193,7 @@ impl Tcb { self.inflight_packets.iter().find(|p| p.seq == seq) } + #[allow(dead_code)] pub(crate) fn get_all_inflight_packets(&self) -> &Vec { &self.inflight_packets } diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index a7ee3cd..99a3a71 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -9,7 +9,6 @@ use crate::{ PacketReceiver, PacketSender, TTL, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, TcpHeader}; -use log::{error, warn}; use std::{ cmp, future::Future, @@ -85,7 +84,7 @@ impl IpStackTcpStream { if !tcp.rst { let pkt = stream.create_rev_packet(RST | ACK, TTL, None, Vec::new())?; if let Err(err) = stream.up_packet_sender.send(pkt) { - warn!("Error sending RST/ACK packet: {:?}", err); + log::warn!("Error sending RST/ACK packet: {:?}", err); } } let info = format!("Invalid TCP packet: {}", tcp_header_fmt(&tcp)); @@ -182,13 +181,6 @@ impl IpStackTcpStream { impl AsyncRead for IpStackTcpStream { fn poll_read(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll> { loop { - if self.tcb.retransmission.is_some() { - self.write_notify = Some(cx.waker().clone()); - if matches!(self.as_mut().poll_flush(cx), Poll::Pending) { - return Poll::Pending; - } - } - if let Some(packet) = self.packet_to_send.take() { self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; } @@ -288,10 +280,22 @@ impl AsyncRead for IpStackTcpStream { continue; } PacketStatus::RetransmissionRequest => { + log::trace!("Retransmission request {}", tcp_header_fmt(tcp_header)); self.tcb.change_send_window(tcp_header.window_size); - self.tcb.retransmission = Some(incoming_ack); - if matches!(self.as_mut().poll_flush(cx), Poll::Pending) { - return Poll::Pending; + if let Some(packet) = self.tcb.find_inflight_packet(incoming_ack) { + let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; + self.up_packet_sender.send(rev_packet).or(Err(ErrorKind::UnexpectedEof))?; + } else { + log::error!("Packet {} not found in inflight_packets", incoming_ack); + log::error!("seq: {}", self.tcb.get_seq()); + log::error!("last_ack: {}", self.tcb.get_last_ack()); + log::error!("ack: {}", self.tcb.get_ack()); + log::error!("inflight_packets:"); + for p in self.tcb.get_all_inflight_packets().iter() { + log::error!("seq: {}", p.seq); + log::error!("payload len: {}", p.payload.len()); + } + panic!("Please report these values at: https://github.com/narrowlink/ipstack/"); } continue; } @@ -388,13 +392,6 @@ impl AsyncWrite for IpStackTcpStream { return Poll::Pending; } - if self.tcb.retransmission.is_some() { - self.write_notify = Some(cx.waker().clone()); - if matches!(self.as_mut().poll_flush(cx), Poll::Pending) { - return Poll::Pending; - } - } - let packet = self.create_rev_packet(PSH | ACK, TTL, None, buf.to_vec())?; let seq = self.tcb.get_seq(); let payload_len = packet.payload.len(); @@ -405,28 +402,10 @@ impl AsyncWrite for IpStackTcpStream { Poll::Ready(Ok(payload_len)) } - fn poll_flush(mut self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } - - if let Some(s) = self.tcb.retransmission.take() { - if let Some(packet) = self.tcb.find_inflight_packet(s) { - let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; - self.up_packet_sender.send(rev_packet).or(Err(ErrorKind::UnexpectedEof))?; - } else { - error!("Packet {} not found in inflight_packets", s); - error!("seq: {}", self.tcb.get_seq()); - error!("last_ack: {}", self.tcb.get_last_ack()); - error!("ack: {}", self.tcb.get_ack()); - error!("inflight_packets:"); - for p in self.tcb.get_all_inflight_packets().iter() { - error!("seq: {}", p.seq); - error!("payload len: {}", p.payload.len()); - } - panic!("Please report these values at: https://github.com/narrowlink/ipstack/"); - } - } Poll::Ready(Ok(())) } From 21a5a9ce9013e4f56bd666b83af1aaf74c6f774e Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Mon, 3 Mar 2025 13:59:40 +0800 Subject: [PATCH 22/35] packet_to_send: Option removed --- src/stream/tcp.rs | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 99a3a71..fce6287 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -47,7 +47,6 @@ pub struct IpStackTcpStream { stream_sender: PacketSender, stream_receiver: PacketReceiver, up_packet_sender: PacketSender, - packet_to_send: Option, tcb: Tcb, mtu: u16, shutdown: Shutdown, @@ -71,7 +70,6 @@ impl IpStackTcpStream { stream_sender, stream_receiver, up_packet_sender, - packet_to_send: None, tcb: Tcb::new(SeqNum(tcp.sequence_number) + 1, timeout_interval), mtu, shutdown: Shutdown::None, @@ -181,9 +179,6 @@ impl IpStackTcpStream { impl AsyncRead for IpStackTcpStream { fn poll_read(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll> { loop { - if let Some(packet) = self.packet_to_send.take() { - self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; - } if self.tcb.get_state() == TcpState::Closed { self.shutdown.ready(); return Poll::Ready(Ok(())); @@ -200,9 +195,8 @@ impl AsyncRead for IpStackTcpStream { if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) { log::trace!("timeout reached for {:?}", self.dst_addr); - self.up_packet_sender - .send(self.create_rev_packet(RST | ACK, TTL, None, Vec::new())?) - .or(Err(ErrorKind::UnexpectedEof))?; + let packet = self.create_rev_packet(RST | ACK, TTL, None, Vec::new())?; + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); @@ -210,9 +204,10 @@ impl AsyncRead for IpStackTcpStream { self.tcb.reset_timeout(); if self.tcb.get_state() == TcpState::Listen { - self.packet_to_send = Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?); + let packet = self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?; self.tcb.add_seq_one(); self.tcb.change_state(TcpState::SynReceived); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } @@ -225,21 +220,25 @@ impl AsyncRead for IpStackTcpStream { return Poll::Ready(Ok(())); } if self.tcb.get_state() == TcpState::FinWait1(true) { - self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); + let packet = self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?; self.tcb.add_seq_one(); self.tcb.add_ack(1.into()); self.tcb.change_state(TcpState::FinWait2(true)); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } else if matches!(self.shutdown, Shutdown::Pending(_)) && self.tcb.get_state() == TcpState::Established && self.tcb.get_last_ack() == self.tcb.get_seq() { - self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); + let packet = self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?; self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait1(false)); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } match self.stream_receiver.poll_recv(cx) { + Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Pending => return Poll::Pending, Poll::Ready(Some(p)) => { let TransportHeader::Tcp(tcp_header) = p.transport_header() else { unreachable!() @@ -276,7 +275,8 @@ impl AsyncRead for IpStackTcpStream { PacketStatus::KeepAlive => { self.tcb.change_last_ack(incoming_ack); self.tcb.change_send_window(tcp_header.window_size); - self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } PacketStatus::RetransmissionRequest => { @@ -332,8 +332,9 @@ impl AsyncRead for IpStackTcpStream { } if flags == (FIN | ACK) { self.tcb.add_ack(1.into()); - self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; self.tcb.change_state(TcpState::FinWait1(true)); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } if flags == (PSH | ACK) { @@ -359,22 +360,22 @@ impl AsyncRead for IpStackTcpStream { continue; } else if flags == (FIN | ACK) { self.tcb.add_ack(1.into()); - self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; self.tcb.change_send_window(tcp_header.window_size); self.tcb.change_state(TcpState::FinWait2(true)); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } } else if self.tcb.get_state() == TcpState::FinWait2(true) { if flags == ACK { self.tcb.change_state(TcpState::FinWait2(false)); } else if flags == (FIN | ACK) { - self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; self.tcb.change_state(TcpState::FinWait2(false)); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; } } } - Poll::Ready(None) => return Poll::Ready(Ok(())), - Poll::Pending => return Poll::Pending, } } } From bfd468bbdd7bb957e0000bfc93950ed8b5fb6d86 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Mon, 3 Mar 2025 17:43:05 +0800 Subject: [PATCH 23/35] minor changes in SeqNum --- src/stream/seqnum.rs | 10 ++++++++-- src/stream/tcp.rs | 3 +-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/stream/seqnum.rs b/src/stream/seqnum.rs index 131c6a7..33d4762 100644 --- a/src/stream/seqnum.rs +++ b/src/stream/seqnum.rs @@ -32,9 +32,15 @@ impl From for usize { } impl TryFrom for SeqNum { - type Error = std::num::TryFromIntError; + type Error = std::io::Error; fn try_from(value: usize) -> Result { - Ok(Self(value.try_into()?)) + if value > u32::MAX as usize { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("value 0x{:X} is too large to convert to SeqNum", value), + )); + } + Ok(Self(value as u32)) } } diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index fce6287..8a7f5ed 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -212,8 +212,7 @@ impl AsyncRead for IpStackTcpStream { } if let Some(b) = self.tcb.get_unordered_packets().filter(|_| matches!(self.shutdown, Shutdown::None)) { - use std::io::{Error, ErrorKind::Other}; - self.tcb.add_ack(b.len().try_into().map_err(|e| Error::new(Other, e))?); + self.tcb.add_ack(b.len().try_into()?); buf.put_slice(&b); let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; From 227f76efa820ea890955a7fa88afebfe18a6fe81 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Mon, 3 Mar 2025 18:09:11 +0800 Subject: [PATCH 24/35] payload alias --- src/stream/tcp.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 8a7f5ed..9016031 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -238,10 +238,11 @@ impl AsyncRead for IpStackTcpStream { match self.stream_receiver.poll_recv(cx) { Poll::Ready(None) => return Poll::Ready(Ok(())), Poll::Pending => return Poll::Pending, - Poll::Ready(Some(p)) => { - let TransportHeader::Tcp(tcp_header) = p.transport_header() else { + Poll::Ready(Some(network_packet)) => { + let TransportHeader::Tcp(tcp_header) = network_packet.transport_header() else { unreachable!() }; + let payload = &network_packet.payload; let flags = tcp_header_flags(tcp_header); let incoming_ack: SeqNum = tcp_header.acknowledgment_number.into(); let incoming_seq: SeqNum = tcp_header.sequence_number.into(); @@ -250,7 +251,7 @@ impl AsyncRead for IpStackTcpStream { self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset))); } - if self.tcb.check_pkt_type(tcp_header, &p.payload) == PacketStatus::Invalid { + if self.tcb.check_pkt_type(tcp_header, payload) == PacketStatus::Invalid { continue; } @@ -262,7 +263,7 @@ impl AsyncRead for IpStackTcpStream { } } else if self.tcb.get_state() == TcpState::Established { if flags == ACK { - match self.tcb.check_pkt_type(tcp_header, &p.payload) { + match self.tcb.check_pkt_type(tcp_header, payload) { PacketStatus::WindowUpdate => { self.tcb.change_send_window(tcp_header.window_size); if let Some(waker) = self.write_notify.take() { @@ -311,7 +312,7 @@ impl AsyncRead for IpStackTcpStream { // } self.tcb.change_last_ack(incoming_ack); - self.tcb.add_unordered_packet(incoming_seq, p.payload.clone()); + self.tcb.add_unordered_packet(incoming_seq, payload.clone()); self.tcb.change_send_window(tcp_header.window_size); if let Some(waker) = self.write_notify.take() { @@ -337,18 +338,18 @@ impl AsyncRead for IpStackTcpStream { continue; } if flags == (PSH | ACK) { - if !matches!(self.tcb.check_pkt_type(tcp_header, &p.payload), PacketStatus::NewPacket) { + if !matches!(self.tcb.check_pkt_type(tcp_header, payload), PacketStatus::NewPacket) { continue; } self.tcb.change_last_ack(incoming_ack); - if p.payload.is_empty() || self.tcb.get_ack() != incoming_seq { + if payload.is_empty() || self.tcb.get_ack() != incoming_seq { continue; } self.tcb.change_send_window(tcp_header.window_size); - self.tcb.add_unordered_packet(incoming_seq, p.payload); + self.tcb.add_unordered_packet(incoming_seq, payload.clone()); continue; } } else if self.tcb.get_state() == TcpState::FinWait1(false) { From 9cab122eec696594c68b6e28b197b5b9c8b852de Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Mon, 3 Mar 2025 20:02:38 +0800 Subject: [PATCH 25/35] Refactor completed --- src/packet.rs | 6 ++++++ src/stream/tcb.rs | 14 ++++++------- src/stream/tcp.rs | 51 +++++++++++++++++++++++++---------------------- 3 files changed, 40 insertions(+), 31 deletions(-) diff --git a/src/packet.rs b/src/packet.rs index 175a66a..39e2d6c 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -9,6 +9,12 @@ pub struct NetworkTuple { pub tcp: bool, } +impl NetworkTuple { + pub fn new(src: SocketAddr, dst: SocketAddr, tcp: bool) -> Self { + NetworkTuple { src, dst, tcp } + } +} + impl std::fmt::Display for NetworkTuple { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let tcp = if self.tcp { "TCP" } else { "UDP" }; diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 5a87e68..485536a 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -6,16 +6,15 @@ use tokio::time::Sleep; const MAX_UNACK: u32 = 1024 * 16; // 16KB const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB -#[allow(dead_code)] #[derive(Debug, PartialEq, Clone, Copy)] pub(crate) enum TcpState { - Init, /* Since we always act as a server, it starts from `Listen`, so we don't use states Init & SynSent. */ - SynSent, + // Init, /* Since we always act as a server, it starts from `Listen`, so we don't use states Init & SynSent. */ + // SynSent, Listen, SynReceived, Established, - FinWait1(bool), // act as a client, followed with FinWait2, TimeWait, Closed - FinWait2(bool), // bool means waiting for ack + FinWait1, // act as a client, actively send a farewell packet to the other side, followed with FinWait2, TimeWait, Closed + FinWait2, TimeWait, CloseWait, // act as a server, followed with LastAck, Closed LastAck, @@ -202,8 +201,9 @@ impl Tcb { (self.seq - self.last_ack).0 >= MAX_UNACK } - pub(crate) fn reset_timeout(&mut self) { - let deadline = tokio::time::Instant::now() + self.timeout_interval; + pub(crate) fn reset_timeout(&mut self, final_reset: bool) { + let two_msl = Duration::from_secs(2); + let deadline = tokio::time::Instant::now() + if final_reset { two_msl } else { self.timeout_interval }; self.timeout.as_mut().reset(deadline); } } diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 9016031..b2650d4 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -3,7 +3,7 @@ use crate::{ error::IpStackError, packet::{ tcp_flags::{ACK, FIN, PSH, RST, SYN}, - tcp_header_flags, tcp_header_fmt, IpHeader, NetworkPacket, TransportHeader, + tcp_header_flags, tcp_header_fmt, IpHeader, NetworkPacket, NetworkTuple, TransportHeader, }, stream::tcb::{PacketStatus, Tcb, TcpState}, PacketReceiver, PacketSender, TTL, @@ -184,24 +184,22 @@ impl AsyncRead for IpStackTcpStream { return Poll::Ready(Ok(())); } - if self.tcb.get_state() == TcpState::FinWait2(false) { - self.tcb.change_state(TcpState::Closed); - self.shutdown.ready(); - return Poll::Ready(Err(Error::from(ErrorKind::ConnectionAborted))); - } - let min = self.tcb.get_available_read_buffer_size() as u16; self.tcb.change_recv_window(min); + let final_reset = self.tcb.get_state() == TcpState::TimeWait; if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) { - log::trace!("timeout reached for {:?}", self.dst_addr); + if !final_reset { + let network_tuple = NetworkTuple::new(self.src_addr, self.dst_addr, true); + log::trace!("timeout reached for {}", network_tuple); + } let packet = self.create_rev_packet(RST | ACK, TTL, None, Vec::new())?; self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); } - self.tcb.reset_timeout(); + self.tcb.reset_timeout(final_reset); if self.tcb.get_state() == TcpState::Listen { let packet = self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?; @@ -218,20 +216,21 @@ impl AsyncRead for IpStackTcpStream { self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; return Poll::Ready(Ok(())); } - if self.tcb.get_state() == TcpState::FinWait1(true) { + if self.tcb.get_state() == TcpState::CloseWait { let packet = self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?; self.tcb.add_seq_one(); self.tcb.add_ack(1.into()); - self.tcb.change_state(TcpState::FinWait2(true)); + self.tcb.change_state(TcpState::LastAck); self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } else if matches!(self.shutdown, Shutdown::Pending(_)) && self.tcb.get_state() == TcpState::Established && self.tcb.get_last_ack() == self.tcb.get_seq() { + // Act as a client, actively send a farewell packet to the other side. let packet = self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?; self.tcb.add_seq_one(); - self.tcb.change_state(TcpState::FinWait1(false)); + self.tcb.change_state(TcpState::FinWait1); self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } @@ -333,7 +332,7 @@ impl AsyncRead for IpStackTcpStream { if flags == (FIN | ACK) { self.tcb.add_ack(1.into()); let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; - self.tcb.change_state(TcpState::FinWait1(true)); + self.tcb.change_state(TcpState::CloseWait); self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } @@ -352,28 +351,32 @@ impl AsyncRead for IpStackTcpStream { self.tcb.add_unordered_packet(incoming_seq, payload.clone()); continue; } - } else if self.tcb.get_state() == TcpState::FinWait1(false) { + } else if self.tcb.get_state() == TcpState::FinWait1 { if flags == ACK { self.tcb.change_last_ack(incoming_ack); self.tcb.add_ack(1.into()); - self.tcb.change_state(TcpState::FinWait2(true)); + self.tcb.change_state(TcpState::FinWait2); continue; - } else if flags == (FIN | ACK) { + } + } else if self.tcb.get_state() == TcpState::FinWait2 { + if flags == (FIN | ACK) { self.tcb.add_ack(1.into()); let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; self.tcb.change_send_window(tcp_header.window_size); - self.tcb.change_state(TcpState::FinWait2(true)); + self.tcb.change_state(TcpState::TimeWait); self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } - } else if self.tcb.get_state() == TcpState::FinWait2(true) { + } else if self.tcb.get_state() == TcpState::LastAck { if flags == ACK { - self.tcb.change_state(TcpState::FinWait2(false)); - } else if flags == (FIN | ACK) { - let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; - self.tcb.change_state(TcpState::FinWait2(false)); - self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; + self.tcb.change_state(TcpState::Closed); } + } else if self.tcb.get_state() == TcpState::TimeWait && flags == (FIN | ACK) { + let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; + // wait to timeout, can't change state here + // self.tcb.change_state(TcpState::Closed); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; + // now we need to wait for the timeout to reach... } } } @@ -386,7 +389,7 @@ impl AsyncWrite for IpStackTcpStream { if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } - self.tcb.reset_timeout(); + self.tcb.reset_timeout(false); if (self.tcb.get_send_window() as u64) < self.tcb.get_avg_send_window() / 2 || self.tcb.is_send_buffer_full() { self.write_notify = Some(cx.waker().clone()); From 3736ec22514230db25ed52befdabd7c7cf0896e1 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Tue, 4 Mar 2025 06:09:09 +0800 Subject: [PATCH 26/35] .env file support in examples --- .gitignore | 1 + Cargo.toml | 1 + examples/tun2.rs | 1 + examples/tun_wintun.rs | 1 + 4 files changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 5fd3f0f..b1ecb27 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.env .vscode/ .VSCodeCounter/ /target/ diff --git a/Cargo.toml b/Cargo.toml index 823dca0..0d3881f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ tokio = { version = "1.43", default-features = false, features = [ [dev-dependencies] clap = { version = "4.5", features = ["derive"] } criterion = { version = "0.5" } # Benchmarks +dotenvy = "0.15" env_logger = "0.11" tokio = { version = "1.43", default-features = false, features = [ "rt-multi-thread", diff --git a/examples/tun2.rs b/examples/tun2.rs index eb26417..8126260 100644 --- a/examples/tun2.rs +++ b/examples/tun2.rs @@ -71,6 +71,7 @@ struct Args { #[tokio::main] async fn main() -> Result<(), Box> { + dotenvy::dotenv().ok(); let args = Args::parse(); let default = format!("{:?}", args.verbosity); diff --git a/examples/tun_wintun.rs b/examples/tun_wintun.rs index ceaa96f..c7254a3 100644 --- a/examples/tun_wintun.rs +++ b/examples/tun_wintun.rs @@ -19,6 +19,7 @@ struct Args { #[tokio::main] async fn main() -> Result<(), Box> { + dotenvy::dotenv().ok(); let args = Args::parse(); env_logger::init(); From 0d9a4677cc1e42761fcf4bebf98135ccec500b86 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Tue, 4 Mar 2025 06:33:12 +0800 Subject: [PATCH 27/35] network_tuple method for IpStackTcpStream --- src/packet.rs | 7 +++---- src/stream/tcp.rs | 11 +++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/packet.rs b/src/packet.rs index 39e2d6c..22dd286 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -133,7 +133,7 @@ impl NetworkPacket { } } -pub fn tcp_header_fmt(header: &TcpHeader) -> String { +pub fn tcp_header_fmt(network_tuple: NetworkTuple, header: &TcpHeader) -> String { let mut flags = String::new(); if header.cwr { flags.push_str("CWR "); @@ -160,9 +160,8 @@ pub fn tcp_header_fmt(header: &TcpHeader) -> String { flags.push_str("FIN "); } format!( - "TcpHeader {{ src_port: {}, dst_port: {}, seq: {}, ack: {}, flags: {} }}", - header.source_port, - header.destination_port, + "{} TcpHeader {{ seq: {}, ack: {}, flags: {} }}", + network_tuple, header.sequence_number, header.acknowledgment_number, flags.trim() diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index b2650d4..9705af9 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -85,10 +85,14 @@ impl IpStackTcpStream { log::warn!("Error sending RST/ACK packet: {:?}", err); } } - let info = format!("Invalid TCP packet: {}", tcp_header_fmt(&tcp)); + let info = format!("Invalid TCP packet: {}", tcp_header_fmt(stream.network_tuple(), &tcp)); Err(IpStackError::IoError(Error::new(ErrorKind::ConnectionRefused, info))) } + pub(crate) fn network_tuple(&self) -> NetworkTuple { + NetworkTuple::new(self.src_addr, self.dst_addr, true) + } + pub fn local_addr(&self) -> SocketAddr { self.src_addr } @@ -190,8 +194,7 @@ impl AsyncRead for IpStackTcpStream { let final_reset = self.tcb.get_state() == TcpState::TimeWait; if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) { if !final_reset { - let network_tuple = NetworkTuple::new(self.src_addr, self.dst_addr, true); - log::trace!("timeout reached for {}", network_tuple); + log::trace!("timeout reached for {}", self.network_tuple()); } let packet = self.create_rev_packet(RST | ACK, TTL, None, Vec::new())?; self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; @@ -279,7 +282,7 @@ impl AsyncRead for IpStackTcpStream { continue; } PacketStatus::RetransmissionRequest => { - log::trace!("Retransmission request {}", tcp_header_fmt(tcp_header)); + log::trace!("Retransmission request {}", tcp_header_fmt(self.network_tuple(), tcp_header)); self.tcb.change_send_window(tcp_header.window_size); if let Some(packet) = self.tcb.find_inflight_packet(incoming_ack) { let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; From 2b94d4f70359150055e237c1926178023877805d Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Tue, 4 Mar 2025 12:39:26 +0800 Subject: [PATCH 28/35] refine code --- src/lib.rs | 24 +++++++----------------- src/stream/tcb.rs | 12 +++++++----- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 662e3dd..7532212 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -111,31 +111,20 @@ fn run( let sessions: SessionCollection = std::sync::Arc::new(tokio::sync::Mutex::new(AHashMap::new())); let pi = config.packet_information; let offset = if pi && cfg!(unix) { 4 } else { 0 }; - let mut buffer = [0_u8; u16::MAX as usize + 4]; + let mut buffer = vec![0_u8; u16::MAX as usize + offset]; let (up_pkt_sender, mut up_pkt_receiver) = mpsc::unbounded_channel::(); tokio::spawn(async move { loop { select! { Ok(n) = device.read(&mut buffer) => { - if let Err(e) = process_device_read( - &buffer[offset..n], - sessions.clone(), - up_pkt_sender.clone(), - &config, - &accept_sender, - ).await { + let u = up_pkt_sender.clone(); + if let Err(e) = process_device_read(&buffer[offset..n], sessions.clone(), u, &config, &accept_sender).await { log::debug!("process_device_read error: {}", e); } } Some(packet) = up_pkt_receiver.recv() => { - process_upstream_recv( - packet, - &mut device, - #[cfg(unix)] - pi, - ) - .await?; + process_upstream_recv(packet, &mut device, #[cfg(unix)]pi).await?; } } } @@ -177,6 +166,7 @@ async fn process_device_read( // log::trace!("packet sent to stream: {}", network_tuple); } Vacant(entry) => { + // log::trace!("new session: {}", network_tuple); let (packet_sender, mut ip_stack_stream) = create_stream(packet, config, up_pkt_sender)?; if let IpStackStream::Tcp(ref mut stream) = ip_stack_stream { let (tx, rx) = tokio::sync::oneshot::channel::<()>(); @@ -185,7 +175,7 @@ async fn process_device_read( tokio::spawn(async move { rx.await.ok(); sessions_clone.lock().await.remove(&network_tuple); - // log::trace!("session removed: {}", network_tuple); + // log::trace!("session destroyed: {}", network_tuple); }); } if let IpStackStream::Udp(ref mut stream) = ip_stack_stream { @@ -194,7 +184,7 @@ async fn process_device_read( tokio::spawn(async move { rx.await.ok(); sessions_clone.lock().await.remove(&network_tuple); - // log::trace!("session removed: {}", network_tuple); + // log::trace!("session destroyed: {}", network_tuple); }); } entry.insert(packet_sender); diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 485536a..8f254a9 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -67,11 +67,7 @@ impl Tcb { unordered_packets: BTreeMap::new(), } } - pub(super) fn add_inflight_packet(&mut self, seq: SeqNum, buf: Vec) { - let buf_len = buf.len() as u32; - self.inflight_packets.push(InflightPacket::new(seq, buf)); - self.seq += buf_len; - } + pub(super) fn add_unordered_packet(&mut self, seq: SeqNum, buf: Vec) { if seq < self.ack { return; @@ -168,6 +164,12 @@ impl Tcb { } } + pub(super) fn add_inflight_packet(&mut self, seq: SeqNum, buf: Vec) { + let buf_len = buf.len() as u32; + self.inflight_packets.push(InflightPacket::new(seq, buf)); + self.seq += buf_len; + } + pub(super) fn change_last_ack(&mut self, ack: SeqNum) { self.last_ack = ack; From dda57b8ba9879b1502c3a77df33a0b82f59d2db8 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Tue, 4 Mar 2025 13:50:44 +0800 Subject: [PATCH 29/35] move timeout timer from tcb to tcp --- src/stream/tcb.rs | 16 ++-------------- src/stream/tcp.rs | 19 +++++++++++++++---- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 8f254a9..5c413c9 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -1,7 +1,6 @@ use super::seqnum::SeqNum; use etherparse::TcpHeader; -use std::{collections::BTreeMap, pin::Pin, time::Duration}; -use tokio::time::Sleep; +use std::collections::BTreeMap; const MAX_UNACK: u32 = 1024 * 16; // 16KB const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB @@ -36,8 +35,6 @@ pub(super) struct Tcb { seq: SeqNum, ack: SeqNum, last_ack: SeqNum, - pub(super) timeout: Pin>, - timeout_interval: Duration, recv_window: u16, send_window: u16, state: TcpState, @@ -47,18 +44,15 @@ pub(super) struct Tcb { } impl Tcb { - pub(super) fn new(ack: SeqNum, timeout_interval: Duration) -> Tcb { + pub(super) fn new(ack: SeqNum) -> Tcb { #[cfg(debug_assertions)] let seq = 100; #[cfg(not(debug_assertions))] let seq = rand::random::(); - let deadline = tokio::time::Instant::now() + timeout_interval; Tcb { seq: seq.into(), ack, last_ack: seq.into(), - timeout_interval, - timeout: Box::pin(tokio::time::sleep_until(deadline)), send_window: u16::MAX, recv_window: 0, state: TcpState::Listen, @@ -202,12 +196,6 @@ impl Tcb { pub fn is_send_buffer_full(&self) -> bool { (self.seq - self.last_ack).0 >= MAX_UNACK } - - pub(crate) fn reset_timeout(&mut self, final_reset: bool) { - let two_msl = Duration::from_secs(2); - let deadline = tokio::time::Instant::now() + if final_reset { two_msl } else { self.timeout_interval }; - self.timeout.as_mut().reset(deadline); - } } #[derive(Debug)] diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 9705af9..252da72 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -52,6 +52,8 @@ pub struct IpStackTcpStream { shutdown: Shutdown, write_notify: Option, destroy_messenger: Option>, + timeout: Pin>, + timeout_interval: Duration, } impl IpStackTcpStream { @@ -64,17 +66,20 @@ impl IpStackTcpStream { timeout_interval: Duration, ) -> Result { let (stream_sender, stream_receiver) = tokio::sync::mpsc::unbounded_channel::(); + let deadline = tokio::time::Instant::now() + timeout_interval; let stream = IpStackTcpStream { src_addr, dst_addr, stream_sender, stream_receiver, up_packet_sender, - tcb: Tcb::new(SeqNum(tcp.sequence_number) + 1, timeout_interval), + tcb: Tcb::new(SeqNum(tcp.sequence_number) + 1), mtu, shutdown: Shutdown::None, write_notify: None, destroy_messenger: None, + timeout: Box::pin(tokio::time::sleep_until(deadline)), + timeout_interval, }; if tcp.syn { return Ok(stream); @@ -89,6 +94,12 @@ impl IpStackTcpStream { Err(IpStackError::IoError(Error::new(ErrorKind::ConnectionRefused, info))) } + fn reset_timeout(&mut self, final_reset: bool) { + let two_msl = Duration::from_secs(2); + let deadline = tokio::time::Instant::now() + if final_reset { two_msl } else { self.timeout_interval }; + self.timeout.as_mut().reset(deadline); + } + pub(crate) fn network_tuple(&self) -> NetworkTuple { NetworkTuple::new(self.src_addr, self.dst_addr, true) } @@ -192,7 +203,7 @@ impl AsyncRead for IpStackTcpStream { self.tcb.change_recv_window(min); let final_reset = self.tcb.get_state() == TcpState::TimeWait; - if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) { + if matches!(Pin::new(&mut self.timeout).poll(cx), Poll::Ready(_)) { if !final_reset { log::trace!("timeout reached for {}", self.network_tuple()); } @@ -202,7 +213,7 @@ impl AsyncRead for IpStackTcpStream { self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); } - self.tcb.reset_timeout(final_reset); + self.reset_timeout(final_reset); if self.tcb.get_state() == TcpState::Listen { let packet = self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?; @@ -392,7 +403,7 @@ impl AsyncWrite for IpStackTcpStream { if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } - self.tcb.reset_timeout(false); + self.reset_timeout(false); if (self.tcb.get_send_window() as u64) < self.tcb.get_avg_send_window() / 2 || self.tcb.is_send_buffer_full() { self.write_notify = Some(cx.waker().clone()); From 9f48ba85bcd9cdd22da8b0c3b3099cb51e6e9844 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Tue, 4 Mar 2025 16:01:19 +0800 Subject: [PATCH 30/35] refine code --- src/stream/tcb.rs | 3 ++- src/stream/tcp.rs | 17 +++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 5c413c9..bbea012 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -64,6 +64,7 @@ impl Tcb { pub(super) fn add_unordered_packet(&mut self, seq: SeqNum, buf: Vec) { if seq < self.ack { + log::debug!("Received packet with seq < ack: seq = {}, ack = {}", seq, self.ack); return; } self.unordered_packets.insert(seq, UnorderedPacket::new(buf)); @@ -170,7 +171,7 @@ impl Tcb { if self.state == TcpState::Established { if let Some(i) = self.inflight_packets.iter().position(|p| p.contains_seq_num(ack - 1)) { let mut inflight_packet = self.inflight_packets.remove(i); - let distance = (ack - inflight_packet.seq).0 as usize; + let distance = ack.distance(inflight_packet.seq) as usize; if distance < inflight_packet.payload.len() { inflight_packet.payload.drain(0..distance); inflight_packet.seq = ack; diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 252da72..94840d0 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -259,6 +259,7 @@ impl AsyncRead for IpStackTcpStream { let flags = tcp_header_flags(tcp_header); let incoming_ack: SeqNum = tcp_header.acknowledgment_number.into(); let incoming_seq: SeqNum = tcp_header.sequence_number.into(); + let window_size = tcp_header.window_size; if flags & RST != 0 { self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); @@ -271,14 +272,14 @@ impl AsyncRead for IpStackTcpStream { if self.tcb.get_state() == TcpState::SynReceived { if flags == ACK { self.tcb.change_last_ack(incoming_ack); - self.tcb.change_send_window(tcp_header.window_size); + self.tcb.change_send_window(window_size); self.tcb.change_state(TcpState::Established); } } else if self.tcb.get_state() == TcpState::Established { if flags == ACK { match self.tcb.check_pkt_type(tcp_header, payload) { PacketStatus::WindowUpdate => { - self.tcb.change_send_window(tcp_header.window_size); + self.tcb.change_send_window(window_size); if let Some(waker) = self.write_notify.take() { waker.wake_by_ref(); } @@ -287,14 +288,14 @@ impl AsyncRead for IpStackTcpStream { PacketStatus::Invalid => continue, PacketStatus::KeepAlive => { self.tcb.change_last_ack(incoming_ack); - self.tcb.change_send_window(tcp_header.window_size); + self.tcb.change_send_window(window_size); let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } PacketStatus::RetransmissionRequest => { log::trace!("Retransmission request {}", tcp_header_fmt(self.network_tuple(), tcp_header)); - self.tcb.change_send_window(tcp_header.window_size); + self.tcb.change_send_window(window_size); if let Some(packet) = self.tcb.find_inflight_packet(incoming_ack) { let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; self.up_packet_sender.send(rev_packet).or(Err(ErrorKind::UnexpectedEof))?; @@ -327,7 +328,7 @@ impl AsyncRead for IpStackTcpStream { self.tcb.change_last_ack(incoming_ack); self.tcb.add_unordered_packet(incoming_seq, payload.clone()); - self.tcb.change_send_window(tcp_header.window_size); + self.tcb.change_send_window(window_size); if let Some(waker) = self.write_notify.take() { waker.wake_by_ref(); } @@ -335,7 +336,7 @@ impl AsyncRead for IpStackTcpStream { } PacketStatus::Ack => { self.tcb.change_last_ack(incoming_ack); - self.tcb.change_send_window(tcp_header.window_size); + self.tcb.change_send_window(window_size); if let Some(waker) = self.write_notify.take() { waker.wake_by_ref(); } @@ -360,7 +361,7 @@ impl AsyncRead for IpStackTcpStream { continue; } - self.tcb.change_send_window(tcp_header.window_size); + self.tcb.change_send_window(window_size); self.tcb.add_unordered_packet(incoming_seq, payload.clone()); continue; @@ -376,7 +377,7 @@ impl AsyncRead for IpStackTcpStream { if flags == (FIN | ACK) { self.tcb.add_ack(1.into()); let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; - self.tcb.change_send_window(tcp_header.window_size); + self.tcb.change_send_window(window_size); self.tcb.change_state(TcpState::TimeWait); self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; From fb1b1fe41df57a3706045c64d934952b35f94849 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sun, 9 Mar 2025 17:06:24 +0800 Subject: [PATCH 31/35] refine log info --- src/lib.rs | 12 ++++++------ src/stream/tcp.rs | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 7532212..cddcc19 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -120,7 +120,7 @@ fn run( Ok(n) = device.read(&mut buffer) => { let u = up_pkt_sender.clone(); if let Err(e) = process_device_read(&buffer[offset..n], sessions.clone(), u, &config, &accept_sender).await { - log::debug!("process_device_read error: {}", e); + log::warn!("process_device_read error: {}", e); } } Some(packet) = up_pkt_receiver.recv() => { @@ -163,10 +163,10 @@ async fn process_device_read( Occupied(entry) => { use std::io::{Error, ErrorKind::Other}; entry.get().send(packet).map_err(|e| Error::new(Other, e))?; - // log::trace!("packet sent to stream: {}", network_tuple); + log::trace!("packet sent to stream: {}", network_tuple); } Vacant(entry) => { - // log::trace!("new session: {}", network_tuple); + log::debug!("session created: {}", network_tuple); let (packet_sender, mut ip_stack_stream) = create_stream(packet, config, up_pkt_sender)?; if let IpStackStream::Tcp(ref mut stream) = ip_stack_stream { let (tx, rx) = tokio::sync::oneshot::channel::<()>(); @@ -175,7 +175,7 @@ async fn process_device_read( tokio::spawn(async move { rx.await.ok(); sessions_clone.lock().await.remove(&network_tuple); - // log::trace!("session destroyed: {}", network_tuple); + log::debug!("session destroyed: {}", network_tuple); }); } if let IpStackStream::Udp(ref mut stream) = ip_stack_stream { @@ -184,7 +184,7 @@ async fn process_device_read( tokio::spawn(async move { rx.await.ok(); sessions_clone.lock().await.remove(&network_tuple); - // log::trace!("session destroyed: {}", network_tuple); + log::debug!("session destroyed: {}", network_tuple); }); } entry.insert(packet_sender); @@ -219,7 +219,7 @@ async fn process_upstream_recv( ) -> Result<()> { #[allow(unused_mut)] let Ok(mut packet_bytes) = up_packet.to_bytes() else { - log::trace!("to_bytes error"); + log::warn!("to_bytes error"); return Ok(()); }; #[cfg(unix)] diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 94840d0..bf93c85 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -205,7 +205,7 @@ impl AsyncRead for IpStackTcpStream { let final_reset = self.tcb.get_state() == TcpState::TimeWait; if matches!(Pin::new(&mut self.timeout).poll(cx), Poll::Ready(_)) { if !final_reset { - log::trace!("timeout reached for {}", self.network_tuple()); + log::warn!("timeout reached for {}", self.network_tuple()); } let packet = self.create_rev_packet(RST | ACK, TTL, None, Vec::new())?; self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; @@ -294,7 +294,7 @@ impl AsyncRead for IpStackTcpStream { continue; } PacketStatus::RetransmissionRequest => { - log::trace!("Retransmission request {}", tcp_header_fmt(self.network_tuple(), tcp_header)); + log::debug!("Retransmission request {}", tcp_header_fmt(self.network_tuple(), tcp_header)); self.tcb.change_send_window(window_size); if let Some(packet) = self.tcb.find_inflight_packet(incoming_ack) { let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; From 6c90e6f5f35e2035ac235b71fe0d3ab8c9a888a8 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sun, 9 Mar 2025 23:51:07 +0800 Subject: [PATCH 32/35] refine code --- src/lib.rs | 41 +++++++++++++++++------------------------ 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index cddcc19..75290ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,10 +3,7 @@ use crate::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport}; use ahash::AHashMap; use packet::{NetworkPacket, NetworkTuple, TransportHeader}; -use std::{ - collections::hash_map::Entry::{Occupied, Vacant}, - time::Duration, -}; +use std::time::Duration; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, select, @@ -160,33 +157,29 @@ async fn process_device_read( let sessions_clone = sessions.clone(); let network_tuple = packet.network_tuple(); match sessions.lock().await.entry(network_tuple) { - Occupied(entry) => { + std::collections::hash_map::Entry::Occupied(entry) => { use std::io::{Error, ErrorKind::Other}; entry.get().send(packet).map_err(|e| Error::new(Other, e))?; log::trace!("packet sent to stream: {}", network_tuple); } - Vacant(entry) => { + std::collections::hash_map::Entry::Vacant(entry) => { log::debug!("session created: {}", network_tuple); let (packet_sender, mut ip_stack_stream) = create_stream(packet, config, up_pkt_sender)?; - if let IpStackStream::Tcp(ref mut stream) = ip_stack_stream { - let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - stream.set_destroy_messenger(tx); - let sessions_clone = sessions_clone.clone(); - tokio::spawn(async move { - rx.await.ok(); - sessions_clone.lock().await.remove(&network_tuple); - log::debug!("session destroyed: {}", network_tuple); - }); - } - if let IpStackStream::Udp(ref mut stream) = ip_stack_stream { - let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - stream.set_destroy_messenger(tx); - tokio::spawn(async move { - rx.await.ok(); - sessions_clone.lock().await.remove(&network_tuple); - log::debug!("session destroyed: {}", network_tuple); - }); + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + match ip_stack_stream { + IpStackStream::Tcp(ref mut stream) => { + stream.set_destroy_messenger(tx); + } + IpStackStream::Udp(ref mut stream) => { + stream.set_destroy_messenger(tx); + } + _ => unreachable!(), } + tokio::spawn(async move { + rx.await.ok(); + sessions_clone.lock().await.remove(&network_tuple); + log::debug!("session destroyed: {}", network_tuple); + }); entry.insert(packet_sender); accept_sender.send(ip_stack_stream)?; } From f02436b03df66731b3613931eaa13ffa6dbd7166 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Mon, 10 Mar 2025 06:44:03 +0800 Subject: [PATCH 33/35] rename example tun2 to tun --- examples/{tun2.rs => tun.rs} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename examples/{tun2.rs => tun.rs} (98%) diff --git a/examples/tun2.rs b/examples/tun.rs similarity index 98% rename from examples/tun2.rs rename to examples/tun.rs index 8126260..a800341 100644 --- a/examples/tun2.rs +++ b/examples/tun.rs @@ -5,7 +5,7 @@ //! //! This example must be run as root or administrator privileges. //! ``` -//! sudo target/debug/examples/tun2 --server-addr 127.0.0.1:8080 # Linux or macOS +//! sudo target/debug/examples/tun --server-addr 127.0.0.1:8080 # Linux or macOS //! ``` //! Then please run the `echo` example server, which listens on TCP & UDP ports 127.0.0.1:8080. //! ``` From dda03693b38ea6ababc9923b36ffe9d762669e5b Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Mon, 10 Mar 2025 08:37:51 +0800 Subject: [PATCH 34/35] refactor add_inflight_packet --- src/stream/tcb.rs | 4 ++-- src/stream/tcp.rs | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index bbea012..f904bf2 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -159,9 +159,9 @@ impl Tcb { } } - pub(super) fn add_inflight_packet(&mut self, seq: SeqNum, buf: Vec) { + pub(super) fn add_inflight_packet(&mut self, buf: Vec) { let buf_len = buf.len() as u32; - self.inflight_packets.push(InflightPacket::new(seq, buf)); + self.inflight_packets.push(InflightPacket::new(self.seq, buf)); self.seq += buf_len; } diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index bf93c85..f14ef77 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -412,11 +412,10 @@ impl AsyncWrite for IpStackTcpStream { } let packet = self.create_rev_packet(PSH | ACK, TTL, None, buf.to_vec())?; - let seq = self.tcb.get_seq(); let payload_len = packet.payload.len(); let payload = packet.payload.clone(); self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; - self.tcb.add_inflight_packet(seq, payload); + self.tcb.add_inflight_packet(payload); Poll::Ready(Ok(payload_len)) } From d5648d7888c76431fa8a1048dd30030d06821049 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Mon, 10 Mar 2025 11:35:59 +0800 Subject: [PATCH 35/35] minor changes --- src/lib.rs | 2 +- src/stream/tcp.rs | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 75290ff..281c752 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -158,9 +158,9 @@ async fn process_device_read( let network_tuple = packet.network_tuple(); match sessions.lock().await.entry(network_tuple) { std::collections::hash_map::Entry::Occupied(entry) => { + log::trace!("packet sent to stream: {} len {}", network_tuple, packet.payload.len()); use std::io::{Error, ErrorKind::Other}; entry.get().send(packet).map_err(|e| Error::new(Other, e))?; - log::trace!("packet sent to stream: {}", network_tuple); } std::collections::hash_map::Entry::Vacant(entry) => { log::debug!("session created: {}", network_tuple); diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index f14ef77..a8cb91d 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -316,12 +316,8 @@ impl AsyncRead for IpStackTcpStream { PacketStatus::NewPacket => { // if incoming_seq != self.tcb.get_ack() { // dbg!(incoming_seq); - // self.packet_to_send = Some(self.create_rev_packet( - // ACK, - // TTL, - // None, - // Vec::new(), - // )?); + // let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; + // self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; // continue; // }