diff --git a/.gitignore b/.gitignore index 5fd3f0f..b1ecb27 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.env .vscode/ .VSCodeCounter/ /target/ diff --git a/Cargo.toml b/Cargo.toml index edbd11c..0d3881f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,31 +11,29 @@ readme = "README.md" [dependencies] ahash = "0.8" -tokio = { version = "1.43", features = [ +etherparse = { version = "0.17", default-features = false, features = ["std"] } +log = { version = "0.4", default-features = false } +rand = { version = "0.9", default-features = false, features = ["thread_rng"] } +thiserror = { version = "2.0", default-features = false } +tokio = { version = "1.43", default-features = false, features = [ "sync", "rt", "time", "io-util", "macros", -], default-features = false } -etherparse = { version = "0.17", default-features = false, features = ["std"] } -thiserror = { version = "2.0", default-features = false } -log = { version = "0.4", default-features = false } -rand = { version = "0.9", default-features = false, features = ["thread_rng"] } +] } [dev-dependencies] -tokio = { version = "1.43", features = [ - "rt-multi-thread", -], default-features = false } clap = { version = "4.5", features = ["derive"] } +criterion = { version = "0.5" } # Benchmarks +dotenvy = "0.15" env_logger = "0.11" +tokio = { version = "1.43", default-features = false, features = [ + "rt-multi-thread", +] } +tun = { version = "0.7.13", default-features = false, features = ["async"] } udp-stream = { version = "0.0", default-features = false } -# Benchmarks -criterion = { version = "0.5" } - -[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies] -tun = { version = "0.7.13", features = ["async"], default-features = false } [target.'cfg(target_os = "windows")'.dev-dependencies] wintun = { version = "0.5", default-features = false } diff --git a/README.md b/README.md index 9517ed6..084a950 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ An asynchronous lightweight userspace implementation of TCP/IP stack for Tun dev Unstable, under development. [![Crates.io](https://img.shields.io/crates/v/ipstack.svg)](https://crates.io/crates/ipstack) -![ipstack](https://docs.rs/ipstack/badge.svg) +[![ipstack](https://docs.rs/ipstack/badge.svg)](https://docs.rs/ipstack) [![Documentation](https://img.shields.io/badge/docs-release-brightgreen.svg?style=flat)](https://docs.rs/ipstack) [![Download](https://img.shields.io/crates/d/ipstack.svg)](https://crates.io/crates/ipstack) [![License](https://img.shields.io/crates/l/ipstack.svg?style=flat)](https://github.com/narrowlink/ipstack/blob/main/LICENSE) @@ -34,7 +34,7 @@ async fn main() { #[cfg(target_os = "windows")] config.platform_config(|config| { - config.device_guid(Some(12324323423423434234_u128)); + config.device_guid(12324323423423434234_u128); }); let mut ipstack_config = ipstack::IpStackConfig::default(); @@ -51,7 +51,7 @@ async fn main() { }); } IpStackStream::Udp(mut udp) => { - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53); + let addr: SocketAddr = "1.1.1.1:53".parse().unwrap(); let mut rhs = UdpStream::connect(addr).await.unwrap(); tokio::spawn(async move { let _ = tokio::io::copy_bidirectional(&mut udp, &mut rhs).await; @@ -60,12 +60,8 @@ async fn main() { IpStackStream::UnknownTransport(u) => { if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload()).unwrap(); - if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { + if let etherparse::Icmpv4Type::EchoRequest(echo) = icmp_header.icmp_type { println!("ICMPv4 echo"); - let echo = IcmpEchoHeader { - id: req.id, - seq: req.seq, - }; let mut resp = Icmpv4Header::new(etherparse::Icmpv4Type::EchoReply(echo)); resp.update_checksum(req_payload); let mut payload = resp.to_bytes().to_vec(); @@ -86,4 +82,4 @@ async fn main() { } ``` -We also suggest that you take a look at the complete [examples](examples). +We also suggest that you take a look at the complete [examples](./examples). diff --git a/examples/tun2.rs b/examples/tun.rs similarity index 93% rename from examples/tun2.rs rename to examples/tun.rs index 3c17db3..a800341 100644 --- a/examples/tun2.rs +++ b/examples/tun.rs @@ -5,7 +5,7 @@ //! //! This example must be run as root or administrator privileges. //! ``` -//! sudo target/debug/examples/tun2 --server-addr 127.0.0.1:8080 # Linux or macOS +//! sudo target/debug/examples/tun --server-addr 127.0.0.1:8080 # Linux or macOS //! ``` //! Then please run the `echo` example server, which listens on TCP & UDP ports 127.0.0.1:8080. //! ``` @@ -28,7 +28,7 @@ //! use clap::Parser; -use etherparse::{IcmpEchoHeader, Icmpv4Header}; +use etherparse::Icmpv4Header; use ipstack::{stream::IpStackStream, IpNumber}; use std::net::{Ipv4Addr, SocketAddr}; use tokio::net::TcpStream; @@ -71,6 +71,7 @@ struct Args { #[tokio::main] async fn main() -> Result<(), Box> { + dotenvy::dotenv().ok(); let args = Args::parse(); let default = format!("{:?}", args.verbosity); @@ -154,12 +155,8 @@ async fn main() -> Result<(), Box> { let n = number; if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; - if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { + if let etherparse::Icmpv4Type::EchoRequest(echo) = icmp_header.icmp_type { log::info!("#{n} ICMPv4 echo"); - let echo = IcmpEchoHeader { - id: req.id, - seq: req.seq, - }; let mut resp = Icmpv4Header::new(etherparse::Icmpv4Type::EchoReply(echo)); resp.update_checksum(req_payload); let mut payload = resp.to_bytes().to_vec(); @@ -174,7 +171,7 @@ async fn main() -> Result<(), Box> { continue; } IpStackStream::UnknownNetwork(pkt) => { - log::info!("#{number} unknown transport - {} bytes", pkt.len()); + log::info!("#{number} unknown network - {} bytes", pkt.len()); continue; } }; diff --git a/examples/tun_wintun.rs b/examples/tun_wintun.rs index e55c8d2..c7254a3 100644 --- a/examples/tun_wintun.rs +++ b/examples/tun_wintun.rs @@ -1,7 +1,7 @@ use std::net::{Ipv4Addr, SocketAddr}; use clap::Parser; -use etherparse::{IcmpEchoHeader, Icmpv4Header}; +use etherparse::Icmpv4Header; use ipstack::{stream::IpStackStream, IpNumber}; use tokio::net::TcpStream; use udp_stream::UdpStream; @@ -19,6 +19,7 @@ struct Args { #[tokio::main] async fn main() -> Result<(), Box> { + dotenvy::dotenv().ok(); let args = Args::parse(); env_logger::init(); @@ -46,10 +47,7 @@ async fn main() -> Result<(), Box> { let mut ip_stack = ipstack::IpStack::new(ipstack_config, tun::create_as_async(&config)?); #[cfg(target_os = "windows")] - let mut ip_stack = ipstack::IpStack::new( - ipstack_config, - wintun::WinTunDevice::new(ipv4, Ipv4Addr::new(255, 255, 255, 0)), - ); + let mut ip_stack = ipstack::IpStack::new(ipstack_config, wintun::WinTunDevice::new(ipv4, Ipv4Addr::new(255, 255, 255, 0))); let server_addr = args.server_addr; @@ -86,12 +84,8 @@ async fn main() -> Result<(), Box> { IpStackStream::UnknownTransport(u) => { if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; - if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { + if let etherparse::Icmpv4Type::EchoRequest(echo) = icmp_header.icmp_type { println!("ICMPv4 echo"); - let echo = IcmpEchoHeader { - id: req.id, - seq: req.seq, - }; let mut resp = Icmpv4Header::new(etherparse::Icmpv4Type::EchoReply(echo)); resp.update_checksum(req_payload); let mut payload = resp.to_bytes().to_vec(); @@ -178,17 +172,11 @@ mod wintun { std::task::Poll::Ready(Ok(buf.len())) } - fn poll_flush( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_shutdown(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } } diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..8449be0 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +max_width = 140 diff --git a/src/lib.rs b/src/lib.rs index 93ab830..281c752 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,9 @@ #![doc = include_str!("../README.md")] -use crate::{ - packet::IpStackPacketProtocol, - stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport}, -}; +use crate::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport}; use ahash::AHashMap; -use log::{error, trace}; -use packet::{NetworkPacket, NetworkTuple}; -use std::{ - collections::hash_map::Entry::{Occupied, Vacant}, - time::Duration, -}; +use packet::{NetworkPacket, NetworkTuple, TransportHeader}; +use std::time::Duration; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, select, @@ -20,16 +13,14 @@ use tokio::{ pub(crate) type PacketSender = UnboundedSender; pub(crate) type PacketReceiver = UnboundedReceiver; -pub(crate) type SessionCollection = AHashMap; +pub(crate) type SessionCollection = std::sync::Arc>>; mod error; mod packet; pub mod stream; pub use self::error::{IpStackError, Result}; -pub use etherparse::IpNumber; - -const DROP_TTL: u8 = 0; +pub use ::etherparse::IpNumber; #[cfg(unix)] const TTL: u8 = 64; @@ -93,182 +84,147 @@ pub struct IpStack { } impl IpStack { - pub fn new(config: IpStackConfig, device: D) -> IpStack + pub fn new(config: IpStackConfig, device: Device) -> IpStack where - D: AsyncRead + AsyncWrite + Unpin + Send + 'static, + Device: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let (accept_sender, accept_receiver) = mpsc::unbounded_channel::(); - let handle = run(config, device, accept_sender); - IpStack { accept_receiver, - handle, + handle: run(config, device, accept_sender), } } pub async fn accept(&mut self) -> Result { - self.accept_receiver - .recv() - .await - .ok_or(IpStackError::AcceptError) + self.accept_receiver.recv().await.ok_or(IpStackError::AcceptError) } } -fn run( +fn run( config: IpStackConfig, - mut device: D, + mut device: Device, accept_sender: UnboundedSender, -) -> JoinHandle> -where - D: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ - let mut sessions: SessionCollection = AHashMap::new(); +) -> JoinHandle> { + let sessions: SessionCollection = std::sync::Arc::new(tokio::sync::Mutex::new(AHashMap::new())); let pi = config.packet_information; let offset = if pi && cfg!(unix) { 4 } else { 0 }; - let mut buffer = [0_u8; u16::MAX as usize + 4]; - let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); + let mut buffer = vec![0_u8; u16::MAX as usize + offset]; + let (up_pkt_sender, mut up_pkt_receiver) = mpsc::unbounded_channel::(); tokio::spawn(async move { loop { select! { Ok(n) = device.read(&mut buffer) => { - if let Some(stream) = process_device_read( - &buffer[offset..n], - &mut sessions, - pkt_sender.clone(), - &config, - ) { - accept_sender.send(stream)?; + let u = up_pkt_sender.clone(); + if let Err(e) = process_device_read(&buffer[offset..n], sessions.clone(), u, &config, &accept_sender).await { + log::warn!("process_device_read error: {}", e); } } - Some(packet) = pkt_receiver.recv() => { - process_upstream_recv( - packet, - &mut sessions, - &mut device, - #[cfg(unix)] - pi, - ) - .await?; + Some(packet) = up_pkt_receiver.recv() => { + process_upstream_recv(packet, &mut device, #[cfg(unix)]pi).await?; } } } }) } -fn process_device_read( +async fn process_device_read( data: &[u8], - sessions: &mut SessionCollection, - pkt_sender: PacketSender, + sessions: SessionCollection, + up_pkt_sender: PacketSender, config: &IpStackConfig, -) -> Option { + accept_sender: &UnboundedSender, +) -> Result<()> { let Ok(packet) = NetworkPacket::parse(data) else { - return Some(IpStackStream::UnknownNetwork(data.to_owned())); + let stream = IpStackStream::UnknownNetwork(data.to_owned()); + accept_sender.send(stream)?; + return Ok(()); }; - if let IpStackPacketProtocol::Unknown = packet.transport_protocol() { - return Some(IpStackStream::UnknownTransport( - IpStackUnknownTransport::new( - packet.src_addr().ip(), - packet.dst_addr().ip(), - packet.payload, - &packet.ip, - config.mtu, - pkt_sender, - ), + if let TransportHeader::Unknown = packet.transport_header() { + let stream = IpStackStream::UnknownTransport(IpStackUnknownTransport::new( + packet.src_addr().ip(), + packet.dst_addr().ip(), + packet.payload, + &packet.ip, + config.mtu, + up_pkt_sender, )); + accept_sender.send(stream)?; + return Ok(()); } - match sessions.entry(packet.network_tuple()) { - Occupied(mut entry) => { - if let Err(e) = entry.get().send(packet) { - trace!("New stream because: {}", e); - create_stream(e.0, config, pkt_sender).map(|s| { - entry.insert(s.0); - s.1 - }) - } else { - None + let sessions_clone = sessions.clone(); + let network_tuple = packet.network_tuple(); + match sessions.lock().await.entry(network_tuple) { + std::collections::hash_map::Entry::Occupied(entry) => { + log::trace!("packet sent to stream: {} len {}", network_tuple, packet.payload.len()); + use std::io::{Error, ErrorKind::Other}; + entry.get().send(packet).map_err(|e| Error::new(Other, e))?; + } + std::collections::hash_map::Entry::Vacant(entry) => { + log::debug!("session created: {}", network_tuple); + let (packet_sender, mut ip_stack_stream) = create_stream(packet, config, up_pkt_sender)?; + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + match ip_stack_stream { + IpStackStream::Tcp(ref mut stream) => { + stream.set_destroy_messenger(tx); + } + IpStackStream::Udp(ref mut stream) => { + stream.set_destroy_messenger(tx); + } + _ => unreachable!(), } + tokio::spawn(async move { + rx.await.ok(); + sessions_clone.lock().await.remove(&network_tuple); + log::debug!("session destroyed: {}", network_tuple); + }); + entry.insert(packet_sender); + accept_sender.send(ip_stack_stream)?; } - Vacant(entry) => create_stream(packet, config, pkt_sender).map(|s| { - entry.insert(s.0); - s.1 - }), } + Ok(()) } -fn create_stream( - packet: NetworkPacket, - config: &IpStackConfig, - pkt_sender: PacketSender, -) -> Option<(PacketSender, IpStackStream)> { - match packet.transport_protocol() { - IpStackPacketProtocol::Tcp(h) => { - match IpStackTcpStream::new( - packet.src_addr(), - packet.dst_addr(), - h, - pkt_sender, - config.mtu, - config.tcp_timeout, - ) { - Ok(stream) => Some((stream.stream_sender(), IpStackStream::Tcp(stream))), - Err(e) => { - if matches!(e, IpStackError::InvalidTcpPacket) { - trace!("Invalid TCP packet"); - } else { - error!("IpStackTcpStream::new failed \"{}\"", e); - } - None - } - } +fn create_stream(packet: NetworkPacket, cfg: &IpStackConfig, up_pkt_sender: PacketSender) -> Result<(PacketSender, IpStackStream)> { + let src_addr = packet.src_addr(); + let dst_addr = packet.dst_addr(); + match packet.transport_header() { + TransportHeader::Tcp(h) => { + let stream = IpStackTcpStream::new(src_addr, dst_addr, h.clone(), up_pkt_sender, cfg.mtu, cfg.tcp_timeout)?; + Ok((stream.stream_sender(), IpStackStream::Tcp(stream))) } - IpStackPacketProtocol::Udp => { - let stream = IpStackUdpStream::new( - packet.src_addr(), - packet.dst_addr(), - packet.payload, - pkt_sender, - config.mtu, - config.udp_timeout, - ); - Some((stream.stream_sender(), IpStackStream::Udp(stream))) + TransportHeader::Udp(_) => { + let stream = IpStackUdpStream::new(src_addr, dst_addr, packet.payload, up_pkt_sender, cfg.mtu, cfg.udp_timeout); + Ok((stream.stream_sender(), IpStackStream::Udp(stream))) } - IpStackPacketProtocol::Unknown => { + TransportHeader::Unknown => { unreachable!() } } } -async fn process_upstream_recv( - packet: NetworkPacket, - sessions: &mut SessionCollection, - device: &mut D, +async fn process_upstream_recv( + up_packet: NetworkPacket, + device: &mut Device, #[cfg(unix)] packet_information: bool, -) -> Result<()> -where - D: AsyncWrite + Unpin + 'static, -{ - if packet.ttl() == 0 { - sessions.remove(&packet.reverse_network_tuple()); - return Ok(()); - } +) -> Result<()> { #[allow(unused_mut)] - let Ok(mut packet_bytes) = packet.to_bytes() else { - trace!("to_bytes error"); + let Ok(mut packet_bytes) = up_packet.to_bytes() else { + log::warn!("to_bytes error"); return Ok(()); }; #[cfg(unix)] if packet_information { - if packet.src_addr().is_ipv4() { + if up_packet.src_addr().is_ipv4() { packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat()); } else { packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat()); } } device.write_all(&packet_bytes).await?; - // device.flush().await.unwrap(); + // device.flush().await?; Ok(()) } diff --git a/src/packet.rs b/src/packet.rs index 540d4ea..22dd286 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -8,6 +8,20 @@ pub struct NetworkTuple { pub dst: SocketAddr, pub tcp: bool, } + +impl NetworkTuple { + pub fn new(src: SocketAddr, dst: SocketAddr, tcp: bool) -> Self { + NetworkTuple { src, dst, tcp } + } +} + +impl std::fmt::Display for NetworkTuple { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let tcp = if self.tcp { "TCP" } else { "UDP" }; + write!(f, "{} {} -> {}", tcp, self.src, self.dst) + } +} + pub mod tcp_flags { pub const CWR: u8 = 0b10000000; pub const ECE: u8 = 0b01000000; @@ -17,14 +31,6 @@ pub mod tcp_flags { pub const RST: u8 = 0b00000100; pub const SYN: u8 = 0b00000010; pub const FIN: u8 = 0b00000001; - pub const NON: u8 = 0b00000000; -} - -#[derive(Debug, Clone)] -pub(crate) enum IpStackPacketProtocol { - Tcp(TcpHeaderWrapper), - Unknown, - Udp, } #[derive(Debug, Clone)] @@ -53,39 +59,21 @@ impl NetworkPacket { let ip = p.net.ok_or(IpStackError::InvalidPacket)?; let (ip, ip_payload) = match ip { - NetSlice::Ipv4(ip) => ( - IpHeader::Ipv4(ip.header().to_header()), - ip.payload().payload, - ), - NetSlice::Ipv6(ip) => ( - IpHeader::Ipv6(ip.header().to_header()), - ip.payload().payload, - ), + NetSlice::Ipv4(ip) => (IpHeader::Ipv4(ip.header().to_header()), ip.payload().payload), + NetSlice::Ipv6(ip) => (IpHeader::Ipv6(ip.header().to_header()), ip.payload().payload), NetSlice::Arp(_) => return Err(IpStackError::UnsupportedTransportProtocol), }; let (transport, payload) = match p.transport { - Some(etherparse::TransportSlice::Tcp(h)) => { - (TransportHeader::Tcp(h.to_header()), h.payload()) - } - Some(etherparse::TransportSlice::Udp(u)) => { - (TransportHeader::Udp(u.to_header()), u.payload()) - } + Some(etherparse::TransportSlice::Tcp(h)) => (TransportHeader::Tcp(h.to_header()), h.payload()), + Some(etherparse::TransportSlice::Udp(u)) => (TransportHeader::Udp(u.to_header()), u.payload()), _ => (TransportHeader::Unknown, ip_payload), }; let payload = payload.to_vec(); - Ok(NetworkPacket { - ip, - transport, - payload, - }) + Ok(NetworkPacket { ip, transport, payload }) } - pub(crate) fn transport_protocol(&self) -> IpStackPacketProtocol { - match self.transport { - TransportHeader::Udp(_) => IpStackPacketProtocol::Udp, - TransportHeader::Tcp(ref h) => IpStackPacketProtocol::Tcp(h.into()), - _ => IpStackPacketProtocol::Unknown, - } + pub(crate) fn transport_header(&self) -> &TransportHeader { + &self.transport } pub fn src_addr(&self) -> SocketAddr { let port = match &self.transport { @@ -145,53 +133,69 @@ impl NetworkPacket { } } -#[derive(Debug, Clone)] -pub(super) struct TcpHeaderWrapper { - header: TcpHeader, -} - -impl TcpHeaderWrapper { - pub fn inner(&self) -> &TcpHeader { - &self.header +pub fn tcp_header_fmt(network_tuple: NetworkTuple, header: &TcpHeader) -> String { + let mut flags = String::new(); + if header.cwr { + flags.push_str("CWR "); } - pub fn flags(&self) -> u8 { - let inner = self.inner(); - let mut flags = 0; - if inner.cwr { - flags |= tcp_flags::CWR; - } - if inner.ece { - flags |= tcp_flags::ECE; - } - if inner.urg { - flags |= tcp_flags::URG; - } - if inner.ack { - flags |= tcp_flags::ACK; - } - if inner.psh { - flags |= tcp_flags::PSH; - } - if inner.rst { - flags |= tcp_flags::RST; - } - if inner.syn { - flags |= tcp_flags::SYN; - } - if inner.fin { - flags |= tcp_flags::FIN; - } - - flags + if header.ece { + flags.push_str("ECE "); + } + if header.urg { + flags.push_str("URG "); + } + if header.ack { + flags.push_str("ACK "); + } + if header.psh { + flags.push_str("PSH "); + } + if header.rst { + flags.push_str("RST "); + } + if header.syn { + flags.push_str("SYN "); } + if header.fin { + flags.push_str("FIN "); + } + format!( + "{} TcpHeader {{ seq: {}, ack: {}, flags: {} }}", + network_tuple, + header.sequence_number, + header.acknowledgment_number, + flags.trim() + ) } -impl From<&TcpHeader> for TcpHeaderWrapper { - fn from(header: &TcpHeader) -> Self { - TcpHeaderWrapper { - header: header.clone(), - } +pub fn tcp_header_flags(inner: &TcpHeader) -> u8 { + let mut flags = 0; + if inner.cwr { + flags |= tcp_flags::CWR; + } + if inner.ece { + flags |= tcp_flags::ECE; + } + if inner.urg { + flags |= tcp_flags::URG; + } + if inner.ack { + flags |= tcp_flags::ACK; } + if inner.psh { + flags |= tcp_flags::PSH; + } + if inner.rst { + flags |= tcp_flags::RST; + } + if inner.syn { + flags |= tcp_flags::SYN; + } + if inner.fin { + flags |= tcp_flags::FIN; + } + + flags } // pub struct UdpPacket { diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 42632f4..7207e99 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -1,12 +1,12 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -pub use self::tcp_wrapper::IpStackTcpStream; +pub use self::tcp::IpStackTcpStream; pub use self::udp::IpStackUdpStream; pub use self::unknown::IpStackUnknownTransport; +mod seqnum; mod tcb; mod tcp; -mod tcp_wrapper; mod udp; mod unknown; @@ -22,9 +22,7 @@ impl IpStackStream { match self { IpStackStream::Tcp(tcp) => tcp.local_addr(), IpStackStream::Udp(udp) => udp.local_addr(), - IpStackStream::UnknownNetwork(_) => { - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)) - } + IpStackStream::UnknownNetwork(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), IpStackStream::UnknownTransport(unknown) => match unknown.src_addr() { IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), @@ -35,9 +33,7 @@ impl IpStackStream { match self { IpStackStream::Tcp(tcp) => tcp.peer_addr(), IpStackStream::Udp(udp) => udp.peer_addr(), - IpStackStream::UnknownNetwork(_) => { - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)) - } + IpStackStream::UnknownNetwork(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), IpStackStream::UnknownTransport(unknown) => match unknown.dst_addr() { IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), diff --git a/src/stream/seqnum.rs b/src/stream/seqnum.rs new file mode 100644 index 0000000..33d4762 --- /dev/null +++ b/src/stream/seqnum.rs @@ -0,0 +1,180 @@ +use std::ops::{Add, AddAssign, Sub, SubAssign}; + +const MAX_DIFF: u32 = u32::MAX / 2; + +/// A TCP sequence number that persents a 32-bit unsigned integer, suppport overflow comparison and arithmetic. +#[derive(Eq, PartialEq, Debug, Copy, Clone, Hash, Default)] +#[repr(transparent)] +pub struct SeqNum(pub u32); + +impl std::fmt::Display for SeqNum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for SeqNum { + fn from(value: u32) -> Self { + Self(value) + } +} + +impl From for u32 { + fn from(value: SeqNum) -> Self { + value.0 + } +} + +impl From for usize { + fn from(value: SeqNum) -> Self { + value.0 as usize + } +} + +impl TryFrom for SeqNum { + type Error = std::io::Error; + fn try_from(value: usize) -> Result { + if value > u32::MAX as usize { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("value 0x{:X} is too large to convert to SeqNum", value), + )); + } + Ok(Self(value as u32)) + } +} + +impl PartialEq for SeqNum { + fn eq(&self, other: &u32) -> bool { + self.0 == *other + } +} + +impl PartialOrd for SeqNum { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialOrd for SeqNum { + fn partial_cmp(&self, other: &u32) -> Option { + Some(self.cmp(&SeqNum(*other))) + } +} + +impl Ord for SeqNum { + #[inline] + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + let diff = self.0.wrapping_sub(other.0); + if diff == 0 { + std::cmp::Ordering::Equal + } else if diff < MAX_DIFF { + std::cmp::Ordering::Greater + } else { + std::cmp::Ordering::Less + } + } +} + +impl Add for SeqNum { + type Output = SeqNum; + #[inline] + fn add(self, rhs: Self) -> Self::Output { + SeqNum(self.0.wrapping_add(rhs.0)) + } +} + +impl Add for SeqNum { + type Output = SeqNum; + #[inline] + fn add(self, rhs: u32) -> Self::Output { + SeqNum(self.0.wrapping_add(rhs)) + } +} + +impl AddAssign for SeqNum { + fn add_assign(&mut self, rhs: Self) { + self.0 = self.0.wrapping_add(rhs.0) + } +} + +impl AddAssign for SeqNum { + fn add_assign(&mut self, rhs: u32) { + self.0 = self.0.wrapping_add(rhs) + } +} + +impl Sub for SeqNum { + type Output = SeqNum; + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + SeqNum(self.0.wrapping_sub(rhs.0)) + } +} + +impl Sub for SeqNum { + type Output = SeqNum; + #[inline] + fn sub(self, rhs: u32) -> Self::Output { + SeqNum(self.0.wrapping_sub(rhs)) + } +} + +impl SubAssign for SeqNum { + fn sub_assign(&mut self, rhs: Self) { + self.0 = self.0.wrapping_sub(rhs.0) + } +} + +impl SubAssign for SeqNum { + fn sub_assign(&mut self, rhs: u32) { + self.0 = self.0.wrapping_sub(rhs) + } +} + +impl SeqNum { + pub fn distance(&self, other: Self) -> u32 { + let diff = self.0.wrapping_sub(other.0); + if diff <= MAX_DIFF { + diff + } else { + u32::MAX - diff + 1 + } + } +} + +#[test] +fn test_seq_num_near_overflow() { + let a: SeqNum = (u32::MAX - 3).into(); + let b = a + 8; + + assert_eq!(a, SeqNum(4294967292)); + assert_eq!(b, SeqNum(4)); + + assert!(a < b); + assert!(b > a); + assert!(a <= b); + assert!(b >= a); + assert!(a != b); + + assert_eq!(a.distance(b), 8); + assert_eq!(b.distance(a), 8); +} + +#[test] +fn test_seq_num_near_max_diff() { + let a = SeqNum(MAX_DIFF - 1); + let mut b = SeqNum(MAX_DIFF + 1); + + assert!(a < b); + assert!(b > a); + assert_eq!(a.distance(b), 2); + + b += 3; + assert_eq!(b.distance(a), 5); + + b -= 10; + assert_eq!(b.distance(a), 5); + + assert_eq!(b, SeqNum(MAX_DIFF - 6)); +} diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index ec0a671..f904bf2 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -1,16 +1,22 @@ -use crate::packet::TcpHeaderWrapper; -use std::{collections::BTreeMap, pin::Pin, time::Duration}; -use tokio::time::Sleep; +use super::seqnum::SeqNum; +use etherparse::TcpHeader; +use std::collections::BTreeMap; const MAX_UNACK: u32 = 1024 * 16; // 16KB const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB #[derive(Debug, PartialEq, Clone, Copy)] -pub enum TcpState { - SynReceived(bool), // bool means if syn/ack is sent +pub(crate) enum TcpState { + // Init, /* Since we always act as a server, it starts from `Listen`, so we don't use states Init & SynSent. */ + // SynSent, + Listen, + SynReceived, Established, - FinWait1(bool), - FinWait2(bool), // bool means waiting for ack + FinWait1, // act as a client, actively send a farewell packet to the other side, followed with FinWait2, TimeWait, Closed + FinWait2, + TimeWait, + CloseWait, // act as a server, followed with LastAck, Closed + LastAck, Closed, } @@ -26,60 +32,45 @@ pub(super) enum PacketStatus { #[derive(Debug)] pub(super) struct Tcb { - seq: u32, - pub(super) retransmission: Option, - ack: u32, - last_ack: u32, - pub(super) timeout: Pin>, - tcp_timeout: Duration, + seq: SeqNum, + ack: SeqNum, + last_ack: SeqNum, recv_window: u16, send_window: u16, state: TcpState, avg_send_window: (u64, u64), // (avg, count) - pub(super) inflight_packets: Vec, - unordered_packets: BTreeMap, + inflight_packets: Vec, + unordered_packets: BTreeMap, } impl Tcb { - pub(super) fn new(ack: u32, tcp_timeout: Duration) -> Tcb { + pub(super) fn new(ack: SeqNum) -> Tcb { #[cfg(debug_assertions)] let seq = 100; #[cfg(not(debug_assertions))] let seq = rand::random::(); - let deadline = tokio::time::Instant::now() + tcp_timeout; Tcb { - seq, - retransmission: None, + seq: seq.into(), ack, - last_ack: seq, - tcp_timeout, - timeout: Box::pin(tokio::time::sleep_until(deadline)), + last_ack: seq.into(), send_window: u16::MAX, recv_window: 0, - state: TcpState::SynReceived(false), + state: TcpState::Listen, avg_send_window: (1, 1), inflight_packets: Vec::new(), unordered_packets: BTreeMap::new(), } } - pub(super) fn add_inflight_packet(&mut self, seq: u32, buf: Vec) { - let buf_len = buf.len() as u32; - self.inflight_packets.push(InflightPacket::new(seq, buf)); - self.seq = self.seq.wrapping_add(buf_len); - } - pub(super) fn add_unordered_packet(&mut self, seq: u32, buf: Vec) { + + pub(super) fn add_unordered_packet(&mut self, seq: SeqNum, buf: Vec) { if seq < self.ack { + log::debug!("Received packet with seq < ack: seq = {}, ack = {}", seq, self.ack); return; } - self.unordered_packets - .insert(seq, UnorderedPacket::new(buf)); + self.unordered_packets.insert(seq, UnorderedPacket::new(buf)); } pub(super) fn get_available_read_buffer_size(&self) -> usize { - READ_BUFFER_SIZE.saturating_sub( - self.unordered_packets - .iter() - .fold(0, |acc, (_, p)| acc + p.payload.len()), - ) + READ_BUFFER_SIZE.saturating_sub(self.unordered_packets.iter().fold(0, |acc, (_, p)| acc + p.payload.len())) } pub(super) fn get_unordered_packets(&mut self) -> Option> { // dbg!(self.ack); @@ -89,18 +80,18 @@ impl Tcb { self.unordered_packets.remove(&self.ack).map(|p| p.payload) } pub(super) fn add_seq_one(&mut self) { - self.seq = self.seq.wrapping_add(1); + self.seq += 1; } - pub(super) fn get_seq(&self) -> u32 { + pub(super) fn get_seq(&self) -> SeqNum { self.seq } - pub(super) fn add_ack(&mut self, add: u32) { - self.ack = self.ack.wrapping_add(add); + pub(super) fn add_ack(&mut self, add: SeqNum) { + self.ack += add; } - pub(super) fn get_ack(&self) -> u32 { + pub(super) fn get_ack(&self) -> SeqNum { self.ack } - pub(super) fn get_last_ack(&self) -> u32 { + pub(super) fn get_last_ack(&self) -> SeqNum { self.last_ack } pub(super) fn change_state(&mut self, state: TcpState) { @@ -110,8 +101,7 @@ impl Tcb { self.state } pub(super) fn change_send_window(&mut self, window: u16) { - let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64) - / (self.avg_send_window.1 + 1); + let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64) / (self.avg_send_window.1 + 1); self.avg_send_window.0 = avg_send_window; self.avg_send_window.1 += 1; self.send_window = window; @@ -141,27 +131,24 @@ impl Tcb { // } // } - pub(super) fn check_pkt_type(&self, header: &TcpHeaderWrapper, p: &[u8]) -> PacketStatus { - let tcp_header = header.inner(); - let received_ack_distance = self.seq.wrapping_sub(tcp_header.acknowledgment_number); + pub(super) fn check_pkt_type(&self, tcp_header: &TcpHeader, p: &[u8]) -> PacketStatus { + let received_ack = SeqNum(tcp_header.acknowledgment_number); + let received_ack_distance = self.seq - received_ack; - let current_ack_distance = self.seq.wrapping_sub(self.last_ack); - if received_ack_distance > current_ack_distance - || (tcp_header.acknowledgment_number != self.seq - && self.seq.saturating_sub(tcp_header.acknowledgment_number) == 0) - { + let current_ack_distance = self.seq - self.last_ack; + if received_ack_distance > current_ack_distance || (self.seq != received_ack && self.seq.0.saturating_sub(received_ack.0) == 0) { PacketStatus::Invalid - } else if self.last_ack == tcp_header.acknowledgment_number { + } else if self.last_ack == received_ack { if !p.is_empty() { PacketStatus::NewPacket } else if self.send_window == tcp_header.window_size && self.seq != self.last_ack { PacketStatus::RetransmissionRequest - } else if self.ack.wrapping_sub(1) == tcp_header.sequence_number { + } else if self.ack - 1 == tcp_header.sequence_number { PacketStatus::KeepAlive } else { PacketStatus::WindowUpdate } - } else if self.last_ack < tcp_header.acknowledgment_number { + } else if self.last_ack < received_ack { if !p.is_empty() { PacketStatus::NewPacket } else { @@ -171,14 +158,20 @@ impl Tcb { PacketStatus::Invalid } } - pub(super) fn change_last_ack(&mut self, ack: u32) { - let distance = ack.wrapping_sub(self.last_ack); - self.last_ack = self.last_ack.wrapping_add(distance); + + pub(super) fn add_inflight_packet(&mut self, buf: Vec) { + let buf_len = buf.len() as u32; + self.inflight_packets.push(InflightPacket::new(self.seq, buf)); + self.seq += buf_len; + } + + pub(super) fn change_last_ack(&mut self, ack: SeqNum) { + self.last_ack = ack; if self.state == TcpState::Established { - if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) { + if let Some(i) = self.inflight_packets.iter().position(|p| p.contains_seq_num(ack - 1)) { let mut inflight_packet = self.inflight_packets.remove(i); - let distance = ack.wrapping_sub(inflight_packet.seq) as usize; + let distance = ack.distance(inflight_packet.seq) as usize; if distance < inflight_packet.payload.len() { inflight_packet.payload.drain(0..distance); inflight_packet.seq = ack; @@ -186,41 +179,59 @@ impl Tcb { } } self.inflight_packets.retain(|p| { - let last_byte = p.seq.wrapping_add(p.payload.len() as u32); - last_byte.saturating_sub(self.last_ack) > 0 + let last_byte = p.seq + (p.payload.len() as u32); + last_byte > self.last_ack }); } } - pub fn is_send_buffer_full(&self) -> bool { - self.seq.wrapping_sub(self.last_ack) >= MAX_UNACK + + pub(crate) fn find_inflight_packet(&self, seq: SeqNum) -> Option<&InflightPacket> { + self.inflight_packets.iter().find(|p| p.seq == seq) + } + + #[allow(dead_code)] + pub(crate) fn get_all_inflight_packets(&self) -> &Vec { + &self.inflight_packets } - pub(crate) fn reset_timeout(&mut self) { - let deadline = tokio::time::Instant::now() + self.tcp_timeout; - self.timeout.as_mut().reset(deadline); + pub fn is_send_buffer_full(&self) -> bool { + (self.seq - self.last_ack).0 >= MAX_UNACK } } #[derive(Debug)] pub struct InflightPacket { - pub seq: u32, + pub seq: SeqNum, pub payload: Vec, // pub send_time: SystemTime, // todo } impl InflightPacket { - fn new(seq: u32, payload: Vec) -> Self { + fn new(seq: SeqNum, payload: Vec) -> Self { Self { seq, payload, // send_time: SystemTime::now(), // todo } } - pub(crate) fn contains(&self, seq: u32) -> bool { - self.seq < seq && seq <= self.seq + self.payload.len() as u32 + pub(crate) fn contains_seq_num(&self, seq: SeqNum) -> bool { + self.seq <= seq && seq < self.seq + self.payload.len() as u32 } } +#[test] +fn test_in_flight_packet() { + let p = InflightPacket::new((u32::MAX - 1).into(), vec![10, 20, 30, 40, 50]); + + assert!(p.contains_seq_num((u32::MAX - 1).into())); + assert!(p.contains_seq_num(u32::MAX.into())); + assert!(p.contains_seq_num(0.into())); + assert!(p.contains_seq_num(1.into())); + assert!(p.contains_seq_num(2.into())); + + assert!(!p.contains_seq_num(3.into())); +} + #[derive(Debug)] struct UnorderedPacket { payload: Vec, diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 77e2547..a8cb91d 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -1,14 +1,14 @@ +use super::seqnum::SeqNum; use crate::{ error::IpStackError, packet::{ - tcp_flags::{ACK, FIN, NON, PSH, RST, SYN}, - IpHeader, IpStackPacketProtocol, NetworkPacket, TcpHeaderWrapper, TransportHeader, + tcp_flags::{ACK, FIN, PSH, RST, SYN}, + tcp_header_flags, tcp_header_fmt, IpHeader, NetworkPacket, NetworkTuple, TransportHeader, }, stream::tcb::{PacketStatus, Tcb, TcpState}, - PacketReceiver, PacketSender, DROP_TTL, TTL, + PacketReceiver, PacketSender, TTL, }; -use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel}; -use log::{error, trace, warn}; +use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, TcpHeader}; use std::{ cmp, future::Future, @@ -41,73 +41,99 @@ impl Shutdown { } #[derive(Debug)] -pub(crate) struct IpStackTcpStream { +pub struct IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, + stream_sender: PacketSender, stream_receiver: PacketReceiver, - packet_sender: PacketSender, - packet_to_send: Option, + up_packet_sender: PacketSender, tcb: Tcb, mtu: u16, shutdown: Shutdown, write_notify: Option, + destroy_messenger: Option>, + timeout: Pin>, + timeout_interval: Duration, } impl IpStackTcpStream { pub(crate) fn new( src_addr: SocketAddr, dst_addr: SocketAddr, - tcp: TcpHeaderWrapper, - packet_sender: PacketSender, - stream_receiver: PacketReceiver, + tcp: TcpHeader, + up_packet_sender: PacketSender, mtu: u16, - tcp_timeout: Duration, + timeout_interval: Duration, ) -> Result { + let (stream_sender, stream_receiver) = tokio::sync::mpsc::unbounded_channel::(); + let deadline = tokio::time::Instant::now() + timeout_interval; let stream = IpStackTcpStream { src_addr, dst_addr, + stream_sender, stream_receiver, - packet_sender, - packet_to_send: None, - tcb: Tcb::new(tcp.inner().sequence_number + 1, tcp_timeout), + up_packet_sender, + tcb: Tcb::new(SeqNum(tcp.sequence_number) + 1), mtu, shutdown: Shutdown::None, write_notify: None, + destroy_messenger: None, + timeout: Box::pin(tokio::time::sleep_until(deadline)), + timeout_interval, }; - if tcp.inner().syn { + if tcp.syn { return Ok(stream); } - if !tcp.inner().rst { + if !tcp.rst { let pkt = stream.create_rev_packet(RST | ACK, TTL, None, Vec::new())?; - if let Err(err) = stream.packet_sender.send(pkt) { - warn!("Error sending RST/ACK packet: {:?}", err); + if let Err(err) = stream.up_packet_sender.send(pkt) { + log::warn!("Error sending RST/ACK packet: {:?}", err); } } - Err(IpStackError::InvalidTcpPacket) + let info = format!("Invalid TCP packet: {}", tcp_header_fmt(stream.network_tuple(), &tcp)); + Err(IpStackError::IoError(Error::new(ErrorKind::ConnectionRefused, info))) } - fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { + fn reset_timeout(&mut self, final_reset: bool) { + let two_msl = Duration::from_secs(2); + let deadline = tokio::time::Instant::now() + if final_reset { two_msl } else { self.timeout_interval }; + self.timeout.as_mut().reset(deadline); + } + + pub(crate) fn network_tuple(&self) -> NetworkTuple { + NetworkTuple::new(self.src_addr, self.dst_addr, true) + } + + pub fn local_addr(&self) -> SocketAddr { + self.src_addr + } + pub fn peer_addr(&self) -> SocketAddr { + self.dst_addr + } + pub fn stream_sender(&self) -> PacketSender { + self.stream_sender.clone() + } + + pub(crate) fn set_destroy_messenger(&mut self, messenger: tokio::sync::oneshot::Sender<()>) { + self.destroy_messenger = Some(messenger); + } + + fn calculate_payload_max_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { cmp::min( self.tcb.get_send_window(), self.mtu.saturating_sub(ip_header_size + tcp_header_size), ) } - fn create_rev_packet( - &self, - flags: u8, - ttl: u8, - seq: impl Into>, - mut payload: Vec, - ) -> Result { + fn create_rev_packet(&self, flags: u8, ttl: u8, seq: impl Into>, mut payload: Vec) -> Result { let mut tcp_header = etherparse::TcpHeader::new( self.dst_addr.port(), self.src_addr.port(), - seq.into().unwrap_or(self.tcb.get_seq()), + seq.into().unwrap_or(self.tcb.get_seq()).0, self.tcb.get_recv_window(), ); - tcp_header.acknowledgment_number = self.tcb.get_ack(); + tcp_header.acknowledgment_number = self.tcb.get_ack().0; tcp_header.syn = flags & SYN != 0; tcp_header.ack = flags & ACK != 0; tcp_header.rst = flags & RST != 0; @@ -117,14 +143,11 @@ impl IpStackTcpStream { let ip_header = match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::TCP, dst.octets(), src.octets()) - .map_err(IpStackError::from)?; - let payload_len = self.calculate_payload_len( - ip_h.header_len() as u16, - tcp_header.header_len() as u16, - ); + .map_err(|e| Error::new(ErrorKind::InvalidInput, e))?; + let payload_len = self.calculate_payload_max_len(ip_h.header_len() as u16, tcp_header.header_len() as u16); payload.truncate(payload_len as usize); ip_h.set_payload_len(payload.len() + tcp_header.header_len()) - .map_err(IpStackError::from)?; + .map_err(|e| Error::new(ErrorKind::InvalidInput, e))?; ip_h.dont_fragment = true; IpHeader::Ipv4(ip_h) } @@ -138,13 +161,10 @@ impl IpStackTcpStream { source: dst.octets(), destination: src.octets(), }; - let payload_len = self.calculate_payload_len( - ip_h.header_len() as u16, - tcp_header.header_len() as u16, - ); + let payload_len = self.calculate_payload_max_len(ip_h.header_len() as u16, tcp_header.header_len() as u16); payload.truncate(payload_len as usize); let len = payload.len() + tcp_header.header_len(); - ip_h.set_payload_length(len).map_err(IpStackError::from)?; + ip_h.set_payload_length(len).map_err(|e| Error::new(ErrorKind::InvalidInput, e))?; IpHeader::Ipv6(ip_h) } @@ -155,12 +175,12 @@ impl IpStackTcpStream { IpHeader::Ipv4(ref ip_header) => { tcp_header.checksum = tcp_header .calc_checksum_ipv4(ip_header, &payload) - .or(Err(ErrorKind::InvalidInput))?; + .map_err(|e| Error::new(ErrorKind::InvalidInput, e))?; } IpHeader::Ipv6(ref ip_header) => { tcp_header.checksum = tcp_header .calc_checksum_ipv6(ip_header, &payload) - .or(Err(ErrorKind::InvalidInput))?; + .map_err(|e| Error::new(ErrorKind::InvalidInput, e))?; } } Ok(NetworkPacket { @@ -172,321 +192,251 @@ impl IpStackTcpStream { } impl AsyncRead for IpStackTcpStream { - fn poll_read( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { + fn poll_read(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll> { loop { - if self.tcb.retransmission.is_some() { - self.write_notify = Some(cx.waker().clone()); - if matches!(self.as_mut().poll_flush(cx), Poll::Pending) { - return Poll::Pending; - } - } - - if let Some(packet) = self.packet_to_send.take() { - self.packet_sender - .send(packet) - .or(Err(ErrorKind::UnexpectedEof))?; - } if self.tcb.get_state() == TcpState::Closed { self.shutdown.ready(); return Poll::Ready(Ok(())); } - if self.tcb.get_state() == TcpState::FinWait2(false) { - self.packet_to_send = - Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); - self.tcb.change_state(TcpState::Closed); - self.shutdown.ready(); - return Poll::Ready(Err(Error::from(ErrorKind::ConnectionAborted))); - } - let min = self.tcb.get_available_read_buffer_size() as u16; self.tcb.change_recv_window(min); - if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) { - trace!("timeout reached for {:?}", self.dst_addr); - self.packet_sender - .send(self.create_rev_packet(RST | ACK, TTL, None, Vec::new())?) - .or(Err(ErrorKind::UnexpectedEof))?; + let final_reset = self.tcb.get_state() == TcpState::TimeWait; + if matches!(Pin::new(&mut self.timeout).poll(cx), Poll::Ready(_)) { + if !final_reset { + log::warn!("timeout reached for {}", self.network_tuple()); + } + let packet = self.create_rev_packet(RST | ACK, TTL, None, Vec::new())?; + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); } - self.tcb.reset_timeout(); + self.reset_timeout(final_reset); - if self.tcb.get_state() == TcpState::SynReceived(false) { - self.packet_to_send = - Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?); + if self.tcb.get_state() == TcpState::Listen { + let packet = self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?; self.tcb.add_seq_one(); - self.tcb.change_state(TcpState::SynReceived(true)); + self.tcb.change_state(TcpState::SynReceived); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } - if let Some(b) = self - .tcb - .get_unordered_packets() - .filter(|_| matches!(self.shutdown, Shutdown::None)) - { - self.tcb.add_ack(b.len() as u32); + if let Some(b) = self.tcb.get_unordered_packets().filter(|_| matches!(self.shutdown, Shutdown::None)) { + self.tcb.add_ack(b.len().try_into()?); buf.put_slice(&b); - self.packet_sender - .send(self.create_rev_packet(ACK, TTL, None, Vec::new())?) - .or(Err(ErrorKind::UnexpectedEof))?; + let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; return Poll::Ready(Ok(())); } - if self.tcb.get_state() == TcpState::FinWait1(true) { - self.packet_to_send = - Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); + if self.tcb.get_state() == TcpState::CloseWait { + let packet = self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?; self.tcb.add_seq_one(); - self.tcb.add_ack(1); - self.tcb.change_state(TcpState::FinWait2(true)); + self.tcb.add_ack(1.into()); + self.tcb.change_state(TcpState::LastAck); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } else if matches!(self.shutdown, Shutdown::Pending(_)) && self.tcb.get_state() == TcpState::Established && self.tcb.get_last_ack() == self.tcb.get_seq() { - self.packet_to_send = - Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); + // Act as a client, actively send a farewell packet to the other side. + let packet = self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?; self.tcb.add_seq_one(); - self.tcb.change_state(TcpState::FinWait1(false)); + self.tcb.change_state(TcpState::FinWait1); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } match self.stream_receiver.poll_recv(cx) { - Poll::Ready(Some(p)) => { - let IpStackPacketProtocol::Tcp(t) = p.transport_protocol() else { + Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(network_packet)) => { + let TransportHeader::Tcp(tcp_header) = network_packet.transport_header() else { unreachable!() }; - if t.flags() & RST != 0 { - self.packet_to_send = - Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); + let payload = &network_packet.payload; + let flags = tcp_header_flags(tcp_header); + let incoming_ack: SeqNum = tcp_header.acknowledgment_number.into(); + let incoming_seq: SeqNum = tcp_header.sequence_number.into(); + let window_size = tcp_header.window_size; + if flags & RST != 0 { self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset))); } - if self.tcb.check_pkt_type(&t, &p.payload) == PacketStatus::Invalid { + if self.tcb.check_pkt_type(tcp_header, payload) == PacketStatus::Invalid { continue; } - if self.tcb.get_state() == TcpState::SynReceived(true) { - if t.flags() == ACK { - self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb.change_send_window(t.inner().window_size); + if self.tcb.get_state() == TcpState::SynReceived { + if flags == ACK { + self.tcb.change_last_ack(incoming_ack); + self.tcb.change_send_window(window_size); self.tcb.change_state(TcpState::Established); } } else if self.tcb.get_state() == TcpState::Established { - if t.flags() == ACK { - match self.tcb.check_pkt_type(&t, &p.payload) { + if flags == ACK { + match self.tcb.check_pkt_type(tcp_header, payload) { PacketStatus::WindowUpdate => { - self.tcb.change_send_window(t.inner().window_size); - if let Some(ref n) = self.write_notify { - n.wake_by_ref(); - self.write_notify = None; - }; + self.tcb.change_send_window(window_size); + if let Some(waker) = self.write_notify.take() { + waker.wake_by_ref(); + } continue; } PacketStatus::Invalid => continue, PacketStatus::KeepAlive => { - self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb.change_send_window(t.inner().window_size); - self.packet_to_send = - Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + self.tcb.change_last_ack(incoming_ack); + self.tcb.change_send_window(window_size); + let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } PacketStatus::RetransmissionRequest => { - self.tcb.change_send_window(t.inner().window_size); - self.tcb.retransmission = Some(t.inner().acknowledgment_number); - if matches!(self.as_mut().poll_flush(cx), Poll::Pending) { - return Poll::Pending; + log::debug!("Retransmission request {}", tcp_header_fmt(self.network_tuple(), tcp_header)); + self.tcb.change_send_window(window_size); + if let Some(packet) = self.tcb.find_inflight_packet(incoming_ack) { + let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; + self.up_packet_sender.send(rev_packet).or(Err(ErrorKind::UnexpectedEof))?; + } else { + log::error!("Packet {} not found in inflight_packets", incoming_ack); + log::error!("seq: {}", self.tcb.get_seq()); + log::error!("last_ack: {}", self.tcb.get_last_ack()); + log::error!("ack: {}", self.tcb.get_ack()); + log::error!("inflight_packets:"); + for p in self.tcb.get_all_inflight_packets().iter() { + log::error!("seq: {}", p.seq); + log::error!("payload len: {}", p.payload.len()); + } + panic!("Please report these values at: https://github.com/narrowlink/ipstack/"); } continue; } PacketStatus::NewPacket => { - // if t.inner().sequence_number != self.tcb.get_ack() { - // dbg!(t.inner().sequence_number); - // self.packet_to_send = Some(self.create_rev_packet( - // ACK, - // TTL, - // None, - // Vec::new(), - // )?); + // if incoming_seq != self.tcb.get_ack() { + // dbg!(incoming_seq); + // let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; + // self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; // continue; // } - self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb - .add_unordered_packet(t.inner().sequence_number, p.payload); + self.tcb.change_last_ack(incoming_ack); + self.tcb.add_unordered_packet(incoming_seq, payload.clone()); - self.tcb.change_send_window(t.inner().window_size); - if let Some(ref n) = self.write_notify { - n.wake_by_ref(); - self.write_notify = None; - }; + self.tcb.change_send_window(window_size); + if let Some(waker) = self.write_notify.take() { + waker.wake_by_ref(); + } continue; } PacketStatus::Ack => { - self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb.change_send_window(t.inner().window_size); - if let Some(ref n) = self.write_notify { - n.wake_by_ref(); - self.write_notify = None; - }; + self.tcb.change_last_ack(incoming_ack); + self.tcb.change_send_window(window_size); + if let Some(waker) = self.write_notify.take() { + waker.wake_by_ref(); + } continue; } }; } - if t.flags() == (FIN | ACK) { - self.tcb.add_ack(1); - self.packet_to_send = - Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); - self.tcb.change_state(TcpState::FinWait1(true)); + if flags == (FIN | ACK) { + self.tcb.add_ack(1.into()); + let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; + self.tcb.change_state(TcpState::CloseWait); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } - if t.flags() == (PSH | ACK) { - if !matches!( - self.tcb.check_pkt_type(&t, &p.payload), - PacketStatus::NewPacket - ) { + if flags == (PSH | ACK) { + if !matches!(self.tcb.check_pkt_type(tcp_header, payload), PacketStatus::NewPacket) { continue; } - self.tcb.change_last_ack(t.inner().acknowledgment_number); + self.tcb.change_last_ack(incoming_ack); - if p.payload.is_empty() - || self.tcb.get_ack() != t.inner().sequence_number - { + if payload.is_empty() || self.tcb.get_ack() != incoming_seq { continue; } - self.tcb.change_send_window(t.inner().window_size); + self.tcb.change_send_window(window_size); - self.tcb - .add_unordered_packet(t.inner().sequence_number, p.payload); + self.tcb.add_unordered_packet(incoming_seq, payload.clone()); continue; } - } else if self.tcb.get_state() == TcpState::FinWait1(false) { - if t.flags() == ACK { - self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb.add_ack(1); - self.tcb.change_state(TcpState::FinWait2(true)); + } else if self.tcb.get_state() == TcpState::FinWait1 { + if flags == ACK { + self.tcb.change_last_ack(incoming_ack); + self.tcb.add_ack(1.into()); + self.tcb.change_state(TcpState::FinWait2); continue; - } else if t.flags() == (FIN | ACK) { - self.tcb.add_ack(1); - self.packet_to_send = - Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); - self.tcb.change_send_window(t.inner().window_size); - self.tcb.change_state(TcpState::FinWait2(true)); + } + } else if self.tcb.get_state() == TcpState::FinWait2 { + if flags == (FIN | ACK) { + self.tcb.add_ack(1.into()); + let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; + self.tcb.change_send_window(window_size); + self.tcb.change_state(TcpState::TimeWait); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; continue; } - } else if self.tcb.get_state() == TcpState::FinWait2(true) { - if t.flags() == ACK { - self.tcb.change_state(TcpState::FinWait2(false)); - } else if t.flags() == (FIN | ACK) { - self.packet_to_send = - Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); - self.tcb.change_state(TcpState::FinWait2(false)); + } else if self.tcb.get_state() == TcpState::LastAck { + if flags == ACK { + self.tcb.change_state(TcpState::Closed); } + } else if self.tcb.get_state() == TcpState::TimeWait && flags == (FIN | ACK) { + let packet = self.create_rev_packet(ACK, TTL, None, Vec::new())?; + // wait to timeout, can't change state here + // self.tcb.change_state(TcpState::Closed); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; + // now we need to wait for the timeout to reach... } } - Poll::Ready(None) => return Poll::Ready(Ok(())), - Poll::Pending => return Poll::Pending, } } } } impl AsyncWrite for IpStackTcpStream { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { + fn poll_write(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } - self.tcb.reset_timeout(); + self.reset_timeout(false); - if (self.tcb.get_send_window() as u64) < self.tcb.get_avg_send_window() / 2 - || self.tcb.is_send_buffer_full() - { + if (self.tcb.get_send_window() as u64) < self.tcb.get_avg_send_window() / 2 || self.tcb.is_send_buffer_full() { self.write_notify = Some(cx.waker().clone()); return Poll::Pending; } - if self.tcb.retransmission.is_some() { - self.write_notify = Some(cx.waker().clone()); - if matches!(self.as_mut().poll_flush(cx), Poll::Pending) { - return Poll::Pending; - } - } - let packet = self.create_rev_packet(PSH | ACK, TTL, None, buf.to_vec())?; - let seq = self.tcb.get_seq(); let payload_len = packet.payload.len(); let payload = packet.payload.clone(); - self.packet_sender - .send(packet) - .or(Err(ErrorKind::UnexpectedEof))?; - self.tcb.add_inflight_packet(seq, payload); + self.up_packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; + self.tcb.add_inflight_packet(payload); Poll::Ready(Ok(payload_len)) } - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { + fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } - - if let Some(s) = self.tcb.retransmission.take() { - if let Some(packet) = self.tcb.inflight_packets.iter().find(|p| p.seq == s) { - let rev_packet = - self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; - - self.packet_sender - .send(rev_packet) - .or(Err(ErrorKind::UnexpectedEof))?; - } else { - error!("Packet {} not found in inflight_packets", s); - error!("seq: {}", self.tcb.get_seq()); - error!("last_ack: {}", self.tcb.get_last_ack()); - error!("ack: {}", self.tcb.get_ack()); - error!("inflight_packets:"); - for p in self.tcb.inflight_packets.iter() { - error!("seq: {}", p.seq); - error!("payload len: {}", p.payload.len()); - } - panic!("Please report these values at: https://github.com/narrowlink/ipstack/"); - } - } Poll::Ready(Ok(())) } - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_shutdown(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if matches!(self.shutdown, Shutdown::Ready) { return Poll::Ready(Ok(())); } else if matches!(self.shutdown, Shutdown::None) { self.shutdown.pending(cx.waker().clone()); } - self.poll_read( - cx, - &mut tokio::io::ReadBuf::uninit(&mut [MaybeUninit::::uninit()]), - ) + self.poll_read(cx, &mut tokio::io::ReadBuf::uninit(&mut [MaybeUninit::::uninit()])) } } impl Drop for IpStackTcpStream { fn drop(&mut self) { - if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) { - if let Err(err) = self.packet_sender.send(p) { - trace!("Error sending NON packet: {:?}", err); - } + if let Some(messenger) = self.destroy_messenger.take() { + let _ = messenger.send(()); } } } diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs deleted file mode 100644 index e6653b9..0000000 --- a/src/stream/tcp_wrapper.rs +++ /dev/null @@ -1,115 +0,0 @@ -use super::tcp::IpStackTcpStream as IpStackTcpStreamInner; -use crate::{ - packet::{NetworkPacket, TcpHeaderWrapper}, - IpStackError, PacketSender, -}; -use std::{net::SocketAddr, pin::Pin, time::Duration}; -use tokio::{io::AsyncWriteExt, sync::mpsc, time::timeout}; - -pub struct IpStackTcpStream { - inner: Option>, - peer_addr: SocketAddr, - local_addr: SocketAddr, - stream_sender: PacketSender, -} - -impl IpStackTcpStream { - pub(crate) fn new( - local_addr: SocketAddr, - peer_addr: SocketAddr, - tcp: TcpHeaderWrapper, - pkt_sender: PacketSender, - mtu: u16, - tcp_timeout: Duration, - ) -> Result { - let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - IpStackTcpStreamInner::new( - local_addr, - peer_addr, - tcp, - pkt_sender, - stream_receiver, - mtu, - tcp_timeout, - ) - .map(|inner| IpStackTcpStream { - inner: Some(Box::new(inner)), - peer_addr, - local_addr, - stream_sender, - }) - } - pub fn local_addr(&self) -> SocketAddr { - self.local_addr - } - pub fn peer_addr(&self) -> SocketAddr { - self.peer_addr - } - pub fn stream_sender(&self) -> PacketSender { - self.stream_sender.clone() - } -} - -impl tokio::io::AsyncRead for IpStackTcpStream { - fn poll_read( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - match self.inner.as_mut() { - Some(mut inner) => Pin::new(&mut inner).poll_read(cx, buf), - None => { - std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) - } - } - } -} - -impl tokio::io::AsyncWrite for IpStackTcpStream { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - match self.inner.as_mut() { - Some(mut inner) => Pin::new(&mut inner).poll_write(cx, buf), - None => { - std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) - } - } - } - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.inner.as_mut() { - Some(mut inner) => Pin::new(&mut inner).poll_flush(cx), - None => { - std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) - } - } - } - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.inner.as_mut() { - Some(mut inner) => Pin::new(&mut inner).poll_shutdown(cx), - None => { - std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) - } - } - } -} - -impl Drop for IpStackTcpStream { - fn drop(&mut self) { - if let Some(mut inner) = self.inner.take() { - tokio::spawn(async move { - if let Err(err) = timeout(Duration::from_secs(2), inner.shutdown()).await { - log::warn!("Error while dropping IpStackTcpStream: {:?}", err); - } - }); - } - } -} diff --git a/src/stream/udp.rs b/src/stream/udp.rs index ad8086c..e615e7e 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -6,7 +6,7 @@ use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header, UdpHeader}; use std::{future::Future, net::SocketAddr, pin::Pin, time::Duration}; use tokio::{ io::{AsyncRead, AsyncWrite}, - sync::mpsc, + sync::{mpsc, oneshot}, time::Sleep, }; @@ -16,11 +16,12 @@ pub struct IpStackUdpStream { dst_addr: SocketAddr, stream_sender: PacketSender, stream_receiver: PacketReceiver, - pkt_sender: PacketSender, + up_pkt_sender: PacketSender, first_payload: Option>, timeout: Pin>, - udp_timeout: Duration, + timeout_interval: Duration, mtu: u16, + destroy_messenger: Option>, } impl IpStackUdpStream { @@ -28,25 +29,30 @@ impl IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, payload: Vec, - pkt_sender: PacketSender, + up_pkt_sender: PacketSender, mtu: u16, - udp_timeout: Duration, + timeout_interval: Duration, ) -> Self { let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - let deadline = tokio::time::Instant::now() + udp_timeout; + let deadline = tokio::time::Instant::now() + timeout_interval; IpStackUdpStream { src_addr, dst_addr, stream_sender, stream_receiver, - pkt_sender, + up_pkt_sender, first_payload: Some(payload), timeout: Box::pin(tokio::time::sleep_until(deadline)), - udp_timeout, + timeout_interval, mtu, + destroy_messenger: None, } } + pub(crate) fn set_destroy_messenger(&mut self, messenger: oneshot::Sender<()>) { + self.destroy_messenger = Some(messenger); + } + pub(crate) fn stream_sender(&self) -> PacketSender { self.stream_sender.clone() } @@ -55,19 +61,12 @@ impl IpStackUdpStream { const UHS: usize = 8; // udp header size is 8 match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets()) - .map_err(IpStackError::from)?; + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets()).map_err(IpStackError::from)?; let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16); payload.truncate(line_buffer as usize); - ip_h.set_payload_len(payload.len() + UHS) + ip_h.set_payload_len(payload.len() + UHS).map_err(IpStackError::from)?; + let udp_header = UdpHeader::with_ipv4_checksum(self.dst_addr.port(), self.src_addr.port(), &ip_h, &payload) .map_err(IpStackError::from)?; - let udp_header = UdpHeader::with_ipv4_checksum( - self.dst_addr.port(), - self.src_addr.port(), - &ip_h, - &payload, - ) - .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv4(ip_h), transport: TransportHeader::Udp(udp_header), @@ -89,13 +88,8 @@ impl IpStackUdpStream { payload.truncate(line_buffer as usize); ip_h.payload_length = (payload.len() + UHS) as u16; - let udp_header = UdpHeader::with_ipv6_checksum( - self.dst_addr.port(), - self.src_addr.port(), - &ip_h, - &payload, - ) - .map_err(IpStackError::from)?; + let udp_header = UdpHeader::with_ipv6_checksum(self.dst_addr.port(), self.src_addr.port(), &ip_h, &payload) + .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv6(ip_h), transport: TransportHeader::Udp(udp_header), @@ -115,7 +109,7 @@ impl IpStackUdpStream { } fn reset_timeout(&mut self) { - let deadline = tokio::time::Instant::now() + self.udp_timeout; + let deadline = tokio::time::Instant::now() + self.timeout_interval; self.timeout.as_mut().reset(deadline); } } @@ -148,31 +142,27 @@ impl AsyncRead for IpStackUdpStream { } impl AsyncWrite for IpStackUdpStream { - fn poll_write( - mut self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { + fn poll_write(mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, buf: &[u8]) -> std::task::Poll> { self.reset_timeout(); let packet = self.create_rev_packet(TTL, buf.to_vec())?; let payload_len = packet.payload.len(); - self.pkt_sender - .send(packet) - .or(Err(std::io::ErrorKind::UnexpectedEof))?; + self.up_pkt_sender.send(packet).or(Err(std::io::ErrorKind::UnexpectedEof))?; std::task::Poll::Ready(Ok(payload_len)) } - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_flush(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } } + +impl Drop for IpStackUdpStream { + fn drop(&mut self) { + if let Some(messenger) = self.destroy_messenger.take() { + let _ = messenger.send(()); + } + } +} diff --git a/src/stream/unknown.rs b/src/stream/unknown.rs index 838d93f..173dfce 100644 --- a/src/stream/unknown.rs +++ b/src/stream/unknown.rs @@ -1,9 +1,9 @@ use crate::{ packet::{IpHeader, NetworkPacket, TransportHeader}, - PacketSender, TTL, + IpStackError, PacketSender, TTL, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header}; -use std::{io::Error, mem, net::IpAddr}; +use std::net::IpAddr; pub struct IpStackUnknownTransport { src_addr: IpAddr, @@ -15,14 +15,7 @@ pub struct IpStackUnknownTransport { } impl IpStackUnknownTransport { - pub(crate) fn new( - src_addr: IpAddr, - dst_addr: IpAddr, - payload: Vec, - ip: &IpHeader, - mtu: u16, - packet_sender: PacketSender, - ) -> Self { + pub(crate) fn new(src_addr: IpAddr, dst_addr: IpAddr, payload: Vec, ip: &IpHeader, mtu: u16, packet_sender: PacketSender) -> Self { let protocol = match ip { IpHeader::Ipv4(ip) => ip.protocol, IpHeader::Ipv6(ip) => ip.next_header, @@ -48,32 +41,30 @@ impl IpStackUnknownTransport { pub fn ip_protocol(&self) -> IpNumber { self.protocol } - pub fn send(&self, mut payload: Vec) -> Result<(), Error> { + pub fn send(&self, mut payload: Vec) -> std::io::Result<()> { loop { let packet = self.create_rev_packet(&mut payload)?; self.packet_sender .send(packet) - .map_err(|_| Error::new(std::io::ErrorKind::Other, "send error"))?; + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("send error: {}", e)))?; if payload.is_empty() { return Ok(()); } } } - pub fn create_rev_packet(&self, payload: &mut Vec) -> Result { + pub fn create_rev_packet(&self, payload: &mut Vec) -> std::io::Result { match (self.dst_addr, self.src_addr) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets()) - .map_err(crate::IpStackError::from)?; + let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets()).map_err(IpStackError::from)?; let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16); let p = if payload.len() > line_buffer as usize { payload.drain(0..line_buffer as usize).collect::>() } else { - mem::take(payload) + std::mem::take(payload) }; - ip_h.set_payload_len(p.len()) - .map_err(crate::IpStackError::from)?; + ip_h.set_payload_len(p.len()).map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv4(ip_h), transport: TransportHeader::Unknown, @@ -91,13 +82,12 @@ impl IpStackUnknownTransport { destination: src.octets(), }; let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16); - payload.truncate(line_buffer as usize); - ip_h.payload_length = payload.len() as u16; let p = if payload.len() > line_buffer as usize { payload.drain(0..line_buffer as usize).collect::>() } else { - mem::take(payload) + std::mem::take(payload) }; + ip_h.set_payload_length(p.len()).map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv6(ip_h), transport: TransportHeader::Unknown,