diff --git a/Cargo.lock b/Cargo.lock
index a8330328a..b6b03ba4a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -848,7 +848,7 @@ version = "3.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e"
 dependencies = [
- "windows-sys 0.48.0",
+ "windows-sys 0.59.0",
 ]

 [[package]]
@@ -1806,7 +1806,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
 dependencies = [
  "libc",
- "windows-sys 0.60.2",
+ "windows-sys 0.61.2",
 ]

 [[package]]
@@ -2681,7 +2681,7 @@ dependencies = [
  "libc",
  "percent-encoding",
  "pin-project-lite",
- "socket2 0.5.10",
+ "socket2 0.6.0",
  "system-configuration",
  "tokio",
  "tower-layer",
@@ -3257,9 +3257,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"

 [[package]]
 name = "libc"
-version = "0.2.176"
+version = "0.2.180"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174"
+checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc"

 [[package]]
 name = "libdlpi-sys"
@@ -3858,6 +3858,18 @@ dependencies = [
  "libc",
 ]

+[[package]]
+name = "nix"
+version = "0.31.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "225e7cfe711e0ba79a68baeddb2982723e4235247aefce1482f2f16c27865b66"
+dependencies = [
+ "bitflags 2.9.4",
+ "cfg-if",
+ "cfg_aliases 0.2.1",
+ "libc",
+]
+
 [[package]]
 name = "nom"
 version = "5.1.3"
@@ -5350,11 +5362,13 @@ dependencies = [
  "dlpi 0.2.0 (git+https://github.com/oxidecomputer/dlpi-sys?branch=main)",
  "erased-serde 0.4.5",
  "futures",
+ "iddqd",
  "ispf",
  "lazy_static",
  "libc",
  "libloading 0.7.4",
  "nexus-client",
+ "nix 0.31.1",
  "oximeter",
  "p9ds",
  "paste",
@@ -6141,7 +6155,7 @@ dependencies = [
  "errno 0.3.14",
  "libc",
  "linux-raw-sys 0.11.0",
- "windows-sys 0.60.2",
+ "windows-sys 0.61.2",
 ]

 [[package]]
@@ -7286,7 +7300,7 @@ dependencies = [
  "getrandom 0.3.2",
  "once_cell",
  "rustix 1.1.2",
- "windows-sys 0.60.2",
+ "windows-sys 0.61.2",
 ]

 [[package]]
@@ -7295,7 +7309,7 @@ version = "1.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2111ef44dae28680ae9752bb89409e7310ca33a8c621ebe7b106cf5c928b3ac0"
 dependencies = [
- "windows-sys 0.60.2",
+ "windows-sys 0.61.2",
 ]

 [[package]]
diff --git a/Cargo.toml b/Cargo.toml
index a3c16a5ed..1bc01b342 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -132,6 +132,7 @@ hex = "0.4.3"
 http = "1.1.0"
 hyper = "1.0"
 linkme = "0.3.33"
+iddqd = "0.3"
 itertools = "0.13.0"
 kstat-rs = "0.2.4"
 lazy_static = "1.4"
@@ -139,6 +140,7 @@ libc = "0.2"
 mockall = "0.12"
 newtype_derive = "0.1.6"
 newtype-uuid = { version = "1.0.1", features = [ "v4" ] }
+nix = { version = "0.31", features = [ "poll" ] }
 owo-colors = "4"
 oxide-tokio-rt = "0.1.2"
 paste = "1.0.15"
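For context, the standalone config changes below would let a TOML stanza like the following attach the new device. This is a sketch: the stanza layout follows propolis-standalone's existing `[dev.NAME]` convention, but the shape of each `port_mappings` entry is defined by `VsockPortMapping`'s `Deserialize` impl, which is not part of this diff, so those keys are hypothetical.

```toml
[dev.vsock0]
driver = "pci-virtio-vsock"
pci-path = "0.5.0"
guest_cid = 3
# Hypothetical: proxy guest vsock port 22 to a host socket
port_mappings = [ { port = 22, addr = "/tmp/vm-ssh.sock" } ]
```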
diff --git a/bin/propolis-standalone/src/config.rs b/bin/propolis-standalone/src/config.rs
index 58b485fd1..3aac56a73 100644
--- a/bin/propolis-standalone/src/config.rs
+++ b/bin/propolis-standalone/src/config.rs
@@ -10,6 +10,7 @@ use std::sync::Arc;

 use anyhow::Context;
 use cpuid_utils::CpuidSet;
+use propolis::vsock::proxy::VsockPortMapping;
 use propolis_types::CpuidIdent;
 use propolis_types::CpuidValues;
 use propolis_types::CpuidVendor;
@@ -170,6 +171,20 @@ impl VionaDeviceParams {
     }
 }

+#[derive(Deserialize)]
+pub struct VsockDevice {
+    pub guest_cid: u32,
+    pub port_mappings: Vec<VsockPortMapping>,
+}
+
+impl VsockDevice {
+    pub fn from_opts(
+        opts: &BTreeMap<String, toml::Value>,
+    ) -> Result<Self, anyhow::Error> {
+        opt_deser(opts)
+    }
+}
+
 // Try to turn unmatched flattened options into a config struct
 fn opt_deser<'de, T: Deserialize<'de>>(
     value: &BTreeMap<String, toml::Value>,
diff --git a/bin/propolis-standalone/src/main.rs b/bin/propolis-standalone/src/main.rs
index 6c5e7e9a8..90135d80e 100644
--- a/bin/propolis-standalone/src/main.rs
+++ b/bin/propolis-standalone/src/main.rs
@@ -1319,6 +1319,18 @@ fn setup_instance(
                 guard.inventory.register(&pvpanic);
             }
         }
+        "pci-virtio-vsock" => {
+            let config = config::VsockDevice::from_opts(&dev.options)?;
+            let bdf = bdf.unwrap();
+            let vsock = hw::virtio::PciVirtioSock::new(
+                512,
+                config.guest_cid,
+                log.new(slog::o!("dev" => "vsock")),
+                config.port_mappings,
+            );
+            guard.inventory.register(&vsock);
+            chipset_pci_attach(bdf, vsock);
+        }
         _ => {
             slog::error!(log, "unrecognized driver {driver}"; "name" => name);
             return Err(Error::new(
diff --git a/lib/propolis/Cargo.toml b/lib/propolis/Cargo.toml
index 0139a804c..5f409dd09 100644
--- a/lib/propolis/Cargo.toml
+++ b/lib/propolis/Cargo.toml
@@ -39,6 +39,8 @@ crucible = { workspace = true, optional = true }
 oximeter = { workspace = true, optional = true }
 nexus-client = { workspace = true, optional = true }
 async-trait.workspace = true
+iddqd.workspace = true
+nix.workspace = true

 # falcon
 libloading = { workspace = true, optional = true }
diff --git a/lib/propolis/src/hw/virtio/mod.rs b/lib/propolis/src/hw/virtio/mod.rs
index 94172740a..0e8e789dc 100644
--- a/lib/propolis/src/hw/virtio/mod.rs
+++ b/lib/propolis/src/hw/virtio/mod.rs
@@ -21,6 +21,10 @@ mod queue;
 #[cfg(feature = "falcon")]
 pub mod softnpu;
 pub mod viona;
+pub mod vsock;
+
+#[cfg(test)]
+pub mod testutil;

 use crate::common::RWOp;
 use crate::hw::pci as pci_hw;
@@ -29,6 +33,7 @@ use queue::VirtQueue;

 pub use block::PciVirtioBlock;
 pub use viona::PciVirtioViona;
+pub use vsock::PciVirtioSock;

 bitflags! {
     pub struct LegacyFeatures: u64 {
@@ -165,6 +170,7 @@ impl DeviceId {
         match self {
             Self::Network => Ok(pci_hw::bits::CLASS_NETWORK),
             Self::Block | Self::NineP => Ok(pci_hw::bits::CLASS_STORAGE),
+            Self::Socket => Ok(pci_hw::bits::CLASS_UNCLASSIFIED),
             _ => Err(self),
         }
     }
@@ -228,6 +234,7 @@ pub trait VirtioIntr: Send + 'static {
     fn read(&self) -> VqIntr;
 }

+#[derive(Debug)]
 pub enum VqChange {
     /// Underlying virtio device has been reset
     Reset,
diff --git a/lib/propolis/src/hw/virtio/queue.rs b/lib/propolis/src/hw/virtio/queue.rs
index c6eeaa739..68270795d 100644
--- a/lib/propolis/src/hw/virtio/queue.rs
+++ b/lib/propolis/src/hw/virtio/queue.rs
@@ -94,7 +94,7 @@ impl VqAvail {
         }
         if let Some(idx) = mem.read::<u16>(self.gpa_idx) {
             let ndesc = Wrapping(*idx) - self.cur_avail_idx;
-            if ndesc.0 != 0 && ndesc.0 < rsize {
+            if ndesc.0 != 0 && ndesc.0 <= rsize {
                 let avail_idx = self.cur_avail_idx.0 & (rsize - 1);
                 self.cur_avail_idx += Wrapping(1);
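The queue.rs change above fixes an off-by-one in available-ring validation: the guest's `idx` is a free-running u16 counter, so when the guest has published every descriptor in the ring, `idx - cur_avail_idx` equals the ring size exactly — a state the old `< rsize` test rejected as if the index were corrupt. A sketch of the arithmetic:

```rust
use std::num::Wrapping;

fn main() {
    let rsize = 4u16; // a 4-entry ring the guest has filled completely
    let (idx, cur_avail_idx) = (Wrapping(4u16), Wrapping(0u16));
    let ndesc = idx - cur_avail_idx;
    // ndesc == rsize is legal; only ndesc > rsize indicates a bogus index.
    assert!(ndesc.0 != 0 && ndesc.0 <= rsize);
}
```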
diff --git a/lib/propolis/src/hw/virtio/testutil.rs b/lib/propolis/src/hw/virtio/testutil.rs
new file mode 100644
index 000000000..42c8821e1
--- /dev/null
+++ b/lib/propolis/src/hw/virtio/testutil.rs
@@ -0,0 +1,629 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+//! Test utilities for constructing fake virtqueues backed by real guest
+//! memory.
+//!
+//! This module provides [`TestVirtQueue`] for single-queue tests and
+//! [`TestVirtQueues`] for multi-queue devices. Both allocate guest memory via
+//! a tempfile-backed [`PhysMap`], lay out virtio ring structures, and provide
+//! helpers to enqueue descriptor chains — simulating a guest driver writing
+//! to the available ring.
+
+use std::sync::Arc;
+
+use zerocopy::FromBytes;
+
+use crate::accessors::MemAccessor;
+use crate::common::GuestAddr;
+use crate::vmm::mem::PhysMap;
+use crate::vmm::MemCtx;
+
+// Re-export queue types so tests outside this module can access them
+// without requiring `queue` to be pub(crate).
+pub use super::queue::{Chain, DescFlag, VirtQueue, VirtQueues, VqSize};
+
+/// Page size for alignment (4 KiB).
+const PAGE_SIZE: u64 = 0x1000;
+
+/// Size in bytes of a virtio descriptor (addr: u64, len: u32, flags: u16,
+/// next: u16).
+const DESC_SIZE: u64 = 16;
+
+/// Size in bytes of a used ring element (id: u32, len: u32).
+const USED_ELEM_SIZE: u64 = 8;
+
+/// Size in bytes of an available ring entry (descriptor index: u16).
+const AVAIL_ELEM_SIZE: u64 = 2;
+
+/// Size in bytes of the ring header (flags: u16, idx: u16).
+const RING_HEADER_SIZE: u64 = 4;
+
+/// Number of pages to allocate for the data area in tests.
+const DATA_AREA_PAGES: u64 = 64;
+
+/// Align `val` up to the next multiple of `align` (must be a power of 2).
+pub const fn align_up(val: u64, align: u64) -> u64 {
+    (val + align - 1) & !(align - 1)
+}
+
+/// 16-byte virtio descriptor, matching the on-wire/in-memory layout.
+#[repr(C)]
+#[derive(Copy, Clone, Default, FromBytes)]
+pub struct RawDesc {
+    pub addr: u64,
+    pub len: u32,
+    pub flags: u16,
+    pub next: u16,
+}
+
+/// 8-byte used ring element.
+#[repr(C)]
+#[derive(Copy, Clone, Default, FromBytes)]
+pub struct RawUsedElem {
+    pub id: u32,
+    pub len: u32,
+}
+
+/// Guest physical address layout for a single virtqueue's ring structures.
+#[derive(Copy, Clone, Debug)]
+pub struct QueueLayout {
+    pub desc_base: u64,
+    pub avail_base: u64,
+    pub used_base: u64,
+    /// First GPA after this queue's structures.
+    pub end: u64,
+}
+
+impl QueueLayout {
+    /// Compute the ring layout for a queue of `size` entries starting at
+    /// `base`.
+    ///
+    /// Layout follows the virtio 1.0 split virtqueue format:
+    /// - Descriptor table: `size * DESC_SIZE` bytes
+    /// - Available ring: header (4 bytes) + `size * 2` bytes for entries
+    /// - Used ring: page-aligned, header (4 bytes) + `size * 8` bytes
+    pub fn new(base: u64, size: u16) -> Self {
+        let qsz = size as u64;
+        let desc_base = base;
+        let avail_base = desc_base + DESC_SIZE * qsz;
+        let used_base = align_up(
+            avail_base + RING_HEADER_SIZE + AVAIL_ELEM_SIZE * qsz,
+            PAGE_SIZE,
+        );
+        let end = align_up(
+            used_base + RING_HEADER_SIZE + USED_ELEM_SIZE * qsz,
+            PAGE_SIZE,
+        );
+        Self { desc_base, avail_base, used_base, end }
+    }
+}
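Working through `QueueLayout::new` for a 16-entry queue at base 0 makes the constants above concrete:

```rust
#[test]
fn layout_example() {
    // desc table:  16 * 16 B             -> desc_base  = 0x000
    // avail ring:  4 B header + 16 * 2 B -> avail_base = 0x100 (ends 0x124)
    // used ring:   page-aligned          -> used_base  = 0x1000 (ends 0x1084)
    // end:         next page boundary    -> end        = 0x2000
    let l = QueueLayout::new(0, 16);
    assert_eq!(
        (l.desc_base, l.avail_base, l.used_base, l.end),
        (0x0, 0x100, 0x1000, 0x2000),
    );
}
```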
+/// Per-queue writer for injecting descriptors into a virtqueue's rings.
+pub struct QueueWriter {
+    layout: QueueLayout,
+    size: u16,
+    /// Next free descriptor index.
+    next_desc: u16,
+    /// Start of data area for this queue.
+    data_start: u64,
+    /// Next free data area offset (GPA).
+    data_cursor: u64,
+    /// Avail ring index we've published up to.
+    avail_idx: u16,
+}
+
+impl QueueWriter {
+    /// Create a new QueueWriter for a queue with the given layout.
+    pub fn new(layout: QueueLayout, size: u16, data_start: u64) -> Self {
+        Self {
+            layout,
+            size,
+            next_desc: 0,
+            data_start,
+            data_cursor: data_start,
+            avail_idx: 0,
+        }
+    }
+
+    /// Reset descriptor and data cursors to allow reusing slots.
+    pub fn reset_cursors(&mut self) {
+        self.next_desc = 0;
+        self.data_cursor = self.data_start;
+    }
+
+    /// Write a descriptor and return its index.
+    pub fn write_desc(
+        &mut self,
+        mem_acc: &MemAccessor,
+        addr: u64,
+        len: u32,
+        flags: u16,
+        next: u16,
+    ) -> u16 {
+        let idx = self.next_desc;
+        assert!(idx < self.size, "descriptor table exhausted");
+        self.next_desc += 1;
+
+        let desc = RawDesc { addr, len, flags, next };
+        let gpa = self.layout.desc_base + u64::from(idx) * DESC_SIZE;
+        let mem = mem_acc.access().unwrap();
+        mem.write(GuestAddr(gpa), &desc);
+        idx
+    }
+
+    /// Allocate data space and write bytes into it. Returns the GPA.
+    pub fn write_data(&mut self, mem_acc: &MemAccessor, data: &[u8]) -> u64 {
+        let gpa = self.data_cursor;
+        self.data_cursor += data.len() as u64;
+        let mem = mem_acc.access().unwrap();
+        mem.write_from(GuestAddr(gpa), data, data.len());
+        gpa
+    }
+
+    /// Allocate data space without writing. Returns the GPA.
+    pub fn alloc_data(&mut self, len: u32) -> u64 {
+        let gpa = self.data_cursor;
+        self.data_cursor += u64::from(len);
+        gpa
+    }
+
+    /// Add a readable descriptor with the given data.
+    pub fn add_readable(&mut self, mem_acc: &MemAccessor, data: &[u8]) -> u16 {
+        let gpa = self.write_data(mem_acc, data);
+        self.write_desc(mem_acc, gpa, data.len() as u32, 0, 0)
+    }
+
+    /// Add a writable descriptor of the given size.
+    pub fn add_writable(&mut self, mem_acc: &MemAccessor, len: u32) -> u16 {
+        let gpa = self.alloc_data(len);
+        self.write_desc(mem_acc, gpa, len, DescFlag::WRITE.bits(), 0)
+    }
+
+    /// Chain two descriptors together via NEXT flag.
+    pub fn chain(&self, mem_acc: &MemAccessor, from: u16, to: u16) {
+        let gpa = self.layout.desc_base + u64::from(from) * DESC_SIZE;
+        let mem = mem_acc.access().unwrap();
+        let mut raw: RawDesc = *mem.read(GuestAddr(gpa)).unwrap();
+        raw.flags |= DescFlag::NEXT.bits();
+        raw.next = to;
+        mem.write(GuestAddr(gpa), &raw);
+    }
+
+    /// Publish a descriptor chain head on the available ring.
+    pub fn publish_avail(&mut self, mem_acc: &MemAccessor, head: u16) {
+        // Available ring layout:
+        // flags (u16) | idx (u16) | ring[size] (u16 each)
+        let slot = self.layout.avail_base
+            + RING_HEADER_SIZE
+            + u64::from(self.avail_idx % self.size) * AVAIL_ELEM_SIZE;
+        self.avail_idx += 1;
+        let new_idx = self.avail_idx;
+        let mem = mem_acc.access().unwrap();
+        mem.write(GuestAddr(slot), &head);
+        // Write new index at offset 2 (after flags u16)
+        mem.write(GuestAddr(self.layout.avail_base + 2), &new_idx);
+    }
+
+    /// Read the used ring index.
+    pub fn used_idx(&self, mem_acc: &MemAccessor) -> u16 {
+        let mem = mem_acc.access().unwrap();
+        // Used ring idx is at offset 2 (after flags u16)
+        *mem.read(GuestAddr(self.layout.used_base + 2)).unwrap()
+    }
+
+    /// Read a used ring entry by index, returning (desc_id, len).
+    pub fn read_used_elem(
+        &self,
+        mem_acc: &MemAccessor,
+        used_index: u16,
+    ) -> RawUsedElem {
+        let mem = mem_acc.access().unwrap();
+        // Used ring layout:
+        // flags (u16) | idx (u16) | ring[size] (RawUsedElem each)
+        let entry_gpa = self.layout.used_base
+            + RING_HEADER_SIZE
+            + u64::from(used_index % self.size) * USED_ELEM_SIZE;
+        *mem.read(GuestAddr(entry_gpa)).unwrap()
+    }
+    /// Read raw bytes from the buffer of a descriptor.
+    pub fn read_desc_data(
+        &self,
+        mem_acc: &MemAccessor,
+        desc_id: u16,
+        len: usize,
+    ) -> Vec<u8> {
+        let mem = mem_acc.access().unwrap();
+        let desc_gpa = self.layout.desc_base + u64::from(desc_id) * DESC_SIZE;
+        let raw_desc: RawDesc = *mem.read(GuestAddr(desc_gpa)).unwrap();
+
+        let mut data = vec![0u8; len];
+        mem.read_into(
+            GuestAddr(raw_desc.addr),
+            &mut crate::common::GuestData::from(data.as_mut_slice()),
+            len,
+        );
+        data
+    }
+}
+
+/// Multi-queue test harness for virtio devices that use multiple queues.
+pub struct TestVirtQueues {
+    /// Must stay alive to keep memory mappings valid.
+    _phys: PhysMap,
+    mem_acc: MemAccessor,
+    queues: VirtQueues,
+    layouts: Vec<QueueLayout>,
+    sizes: Vec<u16>,
+    /// Start of data area (after all queue structures).
+    data_start: u64,
+}
+
+impl TestVirtQueues {
+    /// Create a new multi-queue test harness.
+    ///
+    /// `sizes` specifies the size of each queue (must be powers of 2).
+    pub fn new(sizes: &[VqSize]) -> Self {
+        // Compute layouts for all queues sequentially
+        let mut layouts = Vec::with_capacity(sizes.len());
+        let mut size_vals = Vec::with_capacity(sizes.len());
+        let mut offset = 0u64;
+        for &size in sizes {
+            let size_u16: u16 = size.into();
+            let layout = QueueLayout::new(offset, size_u16);
+            offset = layout.end;
+            layouts.push(layout);
+            size_vals.push(size_u16);
+        }
+
+        // Data area after all rings
+        let data_start = offset;
+        let data_area_size = PAGE_SIZE * DATA_AREA_PAGES;
+        let total_size =
+            align_up(data_start + data_area_size, PAGE_SIZE) as usize;
+
+        let mut phys = PhysMap::new_test(total_size);
+        phys.add_test_mem("test-vqs".to_string(), 0, total_size)
+            .expect("add test mem");
+        let mem_acc = phys.finalize();
+
+        // Create VirtQueues
+        let queues = VirtQueues::new(sizes);
+
+        // Initialize each queue
+        for (i, layout) in layouts.iter().enumerate() {
+            let vq = queues.get(i as u16).unwrap();
+            mem_acc.adopt(&vq.acc_mem, Some(format!("test-vq-{i}")));
+            vq.map_virtqueue(
+                layout.desc_base,
+                layout.avail_base,
+                layout.used_base,
+            );
+            vq.live.store(true, std::sync::atomic::Ordering::Release);
+            vq.enabled.store(true, std::sync::atomic::Ordering::Release);
+
+            // Zero out avail and used ring headers
+            let mem = mem_acc.access().unwrap();
+            mem.write(GuestAddr(layout.avail_base), &0u16);
+            mem.write(GuestAddr(layout.avail_base + 2), &0u16);
+            mem.write(GuestAddr(layout.used_base), &0u16);
+            mem.write(GuestAddr(layout.used_base + 2), &0u16);
+        }
+
+        Self {
+            _phys: phys,
+            mem_acc,
+            queues,
+            layouts,
+            sizes: size_vals,
+            data_start,
+        }
+    }
+
+    /// Get the memory accessor.
+    pub fn mem_acc(&self) -> &MemAccessor {
+        &self.mem_acc
+    }
+
+    /// Get the underlying VirtQueues.
+    pub fn queues(&self) -> &VirtQueues {
+        &self.queues
+    }
+
+    /// Get the VirtQueue at the given index.
+    pub fn vq(&self, idx: u16) -> &Arc<VirtQueue> {
+        self.queues.get(idx).unwrap()
+    }
+
+    /// Create a QueueWriter for the given queue index.
+    ///
+    /// `data_offset` is an offset from the shared data area start,
+    /// allowing different queues to use different regions.
+    pub fn writer(&self, queue_idx: usize, data_offset: u64) -> QueueWriter {
+        let layout = self.layouts[queue_idx];
+        let size = self.sizes[queue_idx];
+        QueueWriter::new(layout, size, self.data_start + data_offset)
+    }
+
+    /// Get the layout for a queue.
+    pub fn layout(&self, queue_idx: usize) -> QueueLayout {
+        self.layouts[queue_idx]
+    }
+}
+/// A test harness wrapping guest memory and a single virtqueue.
+///
+/// For multi-queue tests, use [`TestVirtQueues`] instead.
+pub struct TestVirtQueue {
+    inner: TestVirtQueues,
+    writer: QueueWriter,
+}
+
+impl TestVirtQueue {
+    /// Create a new test virtqueue.
+    ///
+    /// `queue_size` must be a power of 2.
+    pub fn new(queue_size: u16) -> Self {
+        let inner = TestVirtQueues::new(&[VqSize::new(queue_size)]);
+        let writer = inner.writer(0, 0);
+        Self { inner, writer }
+    }
+
+    /// Get the underlying `VirtQueue`.
+    pub fn vq(&self) -> &Arc<VirtQueue> {
+        self.inner.vq(0)
+    }
+
+    /// Get a `MemCtx` guard for directly reading/writing guest memory.
+    pub fn mem(&self) -> impl std::ops::Deref<Target = MemCtx> + '_ {
+        self.inner.mem_acc().access().expect("test mem accessible")
+    }
+
+    /// Add a readable descriptor containing `data`.
+    ///
+    /// Returns the descriptor index.
+    pub fn add_readable(&mut self, data: &[u8]) -> u16 {
+        self.writer.add_readable(self.inner.mem_acc(), data)
+    }
+
+    /// Add a writable descriptor of `len` bytes.
+    ///
+    /// Returns the descriptor index.
+    pub fn add_writable(&mut self, len: u32) -> u16 {
+        self.writer.add_writable(self.inner.mem_acc(), len)
+    }
+
+    /// Link descriptors into a chain by setting NEXT flags.
+    ///
+    /// `descs` should be in order: `[head, ..., tail]`.
+    pub fn chain_descriptors(&mut self, descs: &[u16]) {
+        for i in 0..descs.len().saturating_sub(1) {
+            self.writer.chain(self.inner.mem_acc(), descs[i], descs[i + 1]);
+        }
+    }
+
+    /// Publish a descriptor chain head on the available ring.
+    pub fn publish_avail(&mut self, head: u16) {
+        self.writer.publish_avail(self.inner.mem_acc(), head);
+    }
+
+    /// Read all entries from the used ring.
+    ///
+    /// Returns `(descriptor_id, bytes_written)` pairs.
+    pub fn read_used(&self) -> Vec<(u32, u32)> {
+        let used_idx = self.writer.used_idx(self.inner.mem_acc());
+        (0..used_idx)
+            .map(|i| {
+                let elem = self.writer.read_used_elem(self.inner.mem_acc(), i);
+                (elem.id, elem.len)
+            })
+            .collect()
+    }
+
+    /// Pop a chain from the available ring and return it.
+    pub fn pop_chain(&self) -> Option<(Chain, u16, u32)> {
+        let mem = self.inner.mem_acc().access()?;
+        let mut chain = Chain::with_capacity(64);
+        let (avail_idx, len) = self.vq().pop_avail(&mut chain, &mem)?;
+        Some((chain, avail_idx, len))
+    }
+
+    /// Push a chain back to the used ring.
+    pub fn push_used(&self, chain: &mut Chain) {
+        let mem = self.inner.mem_acc().access().unwrap();
+        self.vq().push_used(chain, &mem);
+    }
+
+    /// Get the GPA of a descriptor's buffer.
+    pub fn desc_addr(&self, idx: u16) -> u64 {
+        let mem = self.inner.mem_acc().access().unwrap();
+        let desc_gpa =
+            self.inner.layout(0).desc_base + u64::from(idx) * DESC_SIZE;
+        let raw: RawDesc = *mem.read(GuestAddr(desc_gpa)).unwrap();
+        raw.addr
+    }
+    /// Read raw bytes from guest memory at a given GPA.
+    pub fn read_guest_mem(&self, addr: u64, len: usize) -> Vec<u8> {
+        let mem = self.inner.mem_acc().access().unwrap();
+        let mut buf = vec![0u8; len];
+        let mut guest_buf = crate::common::GuestData::from(buf.as_mut_slice());
+        mem.read_into(GuestAddr(addr), &mut guest_buf, len);
+        buf
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn smoke_pop_avail_readable() {
+        let mut tvq = TestVirtQueue::new(16);
+
+        let data = b"hello virtqueue";
+        let d0 = tvq.add_readable(data);
+        tvq.publish_avail(d0);
+
+        let (mut chain, _avail_idx, total_len) = tvq.pop_chain().unwrap();
+        assert_eq!(total_len, data.len() as u32);
+
+        let mem = tvq.mem();
+        let mut buf = [0u8; 15];
+        assert!(chain.read(&mut buf, &mem));
+        assert_eq!(&buf, data);
+    }
+
+    #[test]
+    fn smoke_pop_avail_writable() {
+        let mut tvq = TestVirtQueue::new(16);
+
+        let d0 = tvq.add_writable(64);
+        tvq.publish_avail(d0);
+
+        let (mut chain, _avail_idx, total_len) = tvq.pop_chain().unwrap();
+        assert_eq!(total_len, 64);
+
+        let mem = tvq.mem();
+        let payload = b"written by device";
+        assert!(chain.write(payload, &mem));
+        drop(mem);
+
+        tvq.push_used(&mut chain);
+
+        let used = tvq.read_used();
+        assert_eq!(used.len(), 1);
+        assert_eq!(used[0].0, d0 as u32);
+        assert_eq!(used[0].1, payload.len() as u32);
+
+        let addr = tvq.desc_addr(d0);
+        let read_back = tvq.read_guest_mem(addr, payload.len());
+        assert_eq!(read_back, payload);
+    }
+
+    #[test]
+    fn smoke_chained_descriptors() {
+        let mut tvq = TestVirtQueue::new(16);
+
+        let header_data = [0xAA; 8];
+        let body_data = [0xBB; 32];
+        let d0 = tvq.add_readable(&header_data);
+        let d1 = tvq.add_readable(&body_data);
+        tvq.chain_descriptors(&[d0, d1]);
+        tvq.publish_avail(d0);
+
+        let (mut chain, _avail_idx, total_len) = tvq.pop_chain().unwrap();
+        assert_eq!(total_len, 40);
+
+        let mem = tvq.mem();
+        let mut hdr = [0u8; 8];
+        assert!(chain.read(&mut hdr, &mem));
+        assert_eq!(hdr, header_data);
+
+        let mut body = [0u8; 32];
+        assert!(chain.read(&mut body, &mem));
+        assert_eq!(body, body_data);
+    }
+
+    #[test]
+    fn smoke_mixed_chain() {
+        let mut tvq = TestVirtQueue::new(16);
+
+        let req_data = [0x01, 0x02, 0x03, 0x04];
+        let d0 = tvq.add_readable(&req_data);
+        let d1 = tvq.add_writable(128);
+        tvq.chain_descriptors(&[d0, d1]);
+        tvq.publish_avail(d0);
+
+        let (mut chain, _, total_len) = tvq.pop_chain().unwrap();
+        assert_eq!(total_len, 4 + 128);
+
+        let mem = tvq.mem();
+
+        let mut req = [0u8; 4];
+        assert!(chain.read(&mut req, &mem));
+        assert_eq!(req, req_data);
+
+        let resp = [0xFF; 16];
+        assert!(chain.write(&resp, &mem));
+        drop(mem);
+
+        tvq.push_used(&mut chain);
+
+        let addr = tvq.desc_addr(d1);
+        let read_back = tvq.read_guest_mem(addr, 16);
+        assert_eq!(read_back, &resp);
+    }
+
+    #[test]
+    fn empty_avail_ring_returns_none() {
+        let tvq = TestVirtQueue::new(16);
+        assert!(tvq.pop_chain().is_none());
+    }
+
+    #[test]
+    fn multiple_chains() {
+        let mut tvq = TestVirtQueue::new(16);
+
+        let d0 = tvq.add_readable(b"first");
+        tvq.publish_avail(d0);
+
+        let d1 = tvq.add_readable(b"second");
+        tvq.publish_avail(d1);
+
+        let (chain0, _, _) = tvq.pop_chain().unwrap();
+        let (chain1, _, _) = tvq.pop_chain().unwrap();
+        assert!(tvq.pop_chain().is_none());
+
+        assert_ne!(chain0.remain_read_bytes(), chain1.remain_read_bytes());
+    }
+
+    #[test]
+    fn multi_queue_smoke() {
+        let tvqs = TestVirtQueues::new(&[
+            VqSize::new(64),
+            VqSize::new(64),
+            VqSize::new(1),
+        ]);
+
+        let mut writer0 = tvqs.writer(0, 0);
+        let mut writer1 = tvqs.writer(1, PAGE_SIZE);
+        let d0 = writer0.add_readable(tvqs.mem_acc(), b"queue0");
+        writer0.publish_avail(tvqs.mem_acc(), d0);
+
+        let d1 = writer1.add_readable(tvqs.mem_acc(), b"queue1");
+        writer1.publish_avail(tvqs.mem_acc(), d1);
+
+        // Pop from each queue
+        let mem = tvqs.mem_acc().access().unwrap();
+        let mut chain0 = Chain::with_capacity(64);
+        let mut chain1 = Chain::with_capacity(64);
+
+        assert!(tvqs.vq(0).pop_avail(&mut chain0, &mem).is_some());
+        assert!(tvqs.vq(1).pop_avail(&mut chain1, &mem).is_some());
+
+        assert_eq!(chain0.remain_read_bytes(), 6);
+        assert_eq!(chain1.remain_read_bytes(), 6);
+    }
+
+    #[test]
+    fn queue_writer_reset_cursors() {
+        let tvqs = TestVirtQueues::new(&[VqSize::new(16)]);
+        let mut writer = tvqs.writer(0, 0);
+
+        // Add some descriptors
+        let d0 = writer.add_readable(tvqs.mem_acc(), b"first");
+        writer.publish_avail(tvqs.mem_acc(), d0);
+
+        // Reset and reuse
+        writer.reset_cursors();
+
+        let d1 = writer.add_readable(tvqs.mem_acc(), b"second");
+        assert_eq!(d1, 0, "descriptor index should reset to 0");
+        writer.publish_avail(tvqs.mem_acc(), d1);
+
+        // Both publishes should have worked
+        assert_eq!(writer.used_idx(tvqs.mem_acc()), 0); // Nothing consumed yet
+    }
+}
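A direction note for the vsock device that follows: queue 0 (rx) carries device-to-guest traffic, so a guest driver posts writable buffers there, while queue 1 (tx) carries guest-to-host packets in readable buffers. Emulating that with the harness above might look like this sketch (it assumes the harness's private `PAGE_SIZE` and a 44-byte vsock header, per the packet definitions later in this diff):

```rust
let tvqs = TestVirtQueues::new(&[
    VqSize::new(4), // rx
    VqSize::new(4), // tx
    VqSize::new(1), // event
]);
let mut rx = tvqs.writer(0, 0);
let mut tx = tvqs.writer(1, PAGE_SIZE);

// Guest offers an empty, writable buffer for host -> guest packets...
let d = rx.add_writable(tvqs.mem_acc(), 4096);
rx.publish_avail(tvqs.mem_acc(), d);

// ...and submits a (here, header-only) packet on the tx queue.
let hdr = [0u8; 44];
let d = tx.add_readable(tvqs.mem_acc(), &hdr);
tx.publish_avail(tvqs.mem_acc(), d);
```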
diff --git a/lib/propolis/src/hw/virtio/vsock.rs b/lib/propolis/src/hw/virtio/vsock.rs
new file mode 100644
index 000000000..6aad9d598
--- /dev/null
+++ b/lib/propolis/src/hw/virtio/vsock.rs
@@ -0,0 +1,366 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+use lazy_static::lazy_static;
+use slog::Logger;
+use std::sync::Arc;
+
+use crate::accessors::MemAccessor;
+use crate::common::*;
+use crate::hw::pci;
+use crate::hw::virtio;
+use crate::hw::virtio::queue::Chain;
+use crate::hw::virtio::queue::VirtQueue;
+use crate::hw::virtio::queue::VqSize;
+use crate::migrate::*;
+use crate::util::regmap::RegMap;
+use crate::vmm::MemCtx;
+use crate::vsock::packet::VsockPacket;
+use crate::vsock::packet::VsockPacketError;
+use crate::vsock::packet::VsockPacketHeader;
+use crate::vsock::proxy::VsockPortMapping;
+use crate::vsock::VsockBackend;
+use crate::vsock::VsockProxy;
+
+use super::pci::PciVirtio;
+use super::pci::PciVirtioState;
+use super::queue::VirtQueues;
+use super::VirtioDevice;
+
+// virtio queue index numbers for virtio socket devices
+pub const VSOCK_RX_QUEUE: u16 = 0x0;
+pub const VSOCK_TX_QUEUE: u16 = 0x1;
+pub const VSOCK_EVENT_QUEUE: u16 = 0x2;
+
+/// A permit representing a reserved rx queue descriptor chain.
+///
+/// This guarantees we have space to send a packet to the guest before reading
+/// data from a host socket, preventing data loss if the queue is full.
+///
+/// The permit holds a mutable reference to `VsockVq`, ensuring only one permit
+/// can exist at a time (enforced at compile time). If dropped without calling
+/// `write`, the chain is retained in `VsockVq` for reuse.
+pub struct RxPermit<'a> {
+    vq: &'a mut VsockVq,
+}
+
+impl RxPermit<'_> {
+    /// Returns the maximum data payload that can fit in this descriptor
+    /// chain.
+    pub fn available_data_space(&self) -> usize {
+        let header_size = std::mem::size_of::<VsockPacketHeader>();
+        self.vq
+            .rx_chain
+            .as_ref()
+            .expect("has chain")
+            .remain_write_bytes()
+            .saturating_sub(header_size)
+    }
+
+    pub fn write(self, header: &VsockPacketHeader, data: &[u8]) {
+        // TODO: cannot access memory?
+        let mem = self.vq.acc_mem.access().expect("mem access for write");
+        let queue =
+            self.vq.queues.get(VSOCK_RX_QUEUE as usize).expect("rx queue");
+
+        // SAFETY: `RxPermit` should only be created if the owning `VsockVq`
+        // actually has a `Some(Chain)`. Unfortunately there doesn't seem to be
+        // a way to enforce this at compile time.
+        let mut chain = self.vq.rx_chain.take().expect("has chain");
+        chain.write(header, &mem);
+
+        if !data.is_empty() {
+            let mut done = 0;
+            chain.for_remaining_type(false, |addr, len| {
+                let to_write = &data[done..];
+                if let Some(copied) = mem.write_from(addr, to_write, len) {
+                    let need_more = copied != to_write.len();
+                    done += copied;
+                    (copied, need_more)
+                } else {
+                    (0, false)
+                }
+            });
+        }
+
+        queue.push_used(&mut chain, &mem);
+    }
+}
+
+pub struct VsockVq {
+    queues: Vec<Arc<VirtQueue>>,
+    acc_mem: MemAccessor,
+    /// Cached rx chain for permit reuse when dropped without write
+    rx_chain: Option<Chain>,
+}
+
+impl VsockVq {
+    pub(crate) fn new(
+        queues: Vec<Arc<VirtQueue>>,
+        acc_mem: MemAccessor,
+    ) -> Self {
+        Self { queues, acc_mem, rx_chain: None }
+    }
+
+    /// Try to acquire a permit for sending a packet to the guest.
+    ///
+    /// Returns `Some(RxPermit)` if a descriptor chain is available,
+    /// `None` if the rx queue is full.
+    pub fn try_rx_permit(&mut self) -> Option<RxPermit<'_>> {
+        // Reuse cached chain or pop a new one
+        if self.rx_chain.is_none() {
+            // TODO: cannot access memory?
+            let mem = self.acc_mem.access().expect("mem access for write");
+            let vq = self.queues.get(VSOCK_RX_QUEUE as usize)?;
+            let mut chain = Chain::with_capacity(10);
+            if let Some(_) = vq.pop_avail(&mut chain, &mem) {
+                self.rx_chain = Some(chain);
+            }
+        }
+
+        // We only return a permit iff we know that we are holding onto a valid
+        // descriptor chain that can be used by the borrowing `RxPermit`
+        match self.rx_chain {
+            Some(_) => Some(RxPermit { vq: self }),
+            None => None,
+        }
+    }
+
+    /// Receive all available packets from the TX queue.
+    ///
+    /// Returns a Vec of parsed packets. In the future this may be refactored
+    /// to return an iterator over GuestRegions to avoid copying packet data.
+    pub fn recv_packet(&self) -> Option<Result<VsockPacket, VsockPacketError>> {
+        // TODO: cannot access memory?
+        let mem = self.acc_mem.access().expect("mem access for read");
+        let vq = self
+            .queues
+            .get(VSOCK_TX_QUEUE as usize)
+            .expect("vsock has tx queue");
+
+        let mut chain = Chain::with_capacity(10);
+        let Some((_idx, _clen)) = vq.pop_avail(&mut chain, &mem) else {
+            return None;
+        };
+
+        let packet = VsockPacket::parse(&mut chain, &mem);
+        vq.push_used(&mut chain, &mem);
+
+        Some(packet)
+    }
+}
+
+pub struct PciVirtioSock {
+    cid: u32,
+    backend: VsockProxy,
+    virtio_state: PciVirtioState,
+    pci_state: pci::DeviceState,
+}
+
+impl PciVirtioSock {
+    pub fn new(
+        queue_size: u16,
+        cid: u32,
+        log: Logger,
+        port_mappings: Vec<VsockPortMapping>,
+    ) -> Arc<Self> {
+        let queues = VirtQueues::new(&[
+            // VSOCK_RX_QUEUE
+            VqSize::new(queue_size),
+            // VSOCK_TX_QUEUE
+            VqSize::new(queue_size),
+            // VSOCK_EVENT_QUEUE
+            VqSize::new(1),
+        ]);
+
+        // One for rx, tx, event
+        let msix_count = Some(3);
+        let (virtio_state, pci_state) = PciVirtioState::new(
+            virtio::Mode::Transitional,
+            queues,
+            msix_count,
+            virtio::DeviceId::Socket,
+            VIRTIO_VSOCK_CFG_SIZE,
+        );
+
+        let vvq = VsockVq::new(
+            virtio_state.queues.iter().map(Clone::clone).collect(),
+            pci_state.acc_mem.child(Some("vsock rx queue".to_string())),
+        );
+        let port_mappings = port_mappings.into_iter().collect();
+
+        let backend = VsockProxy::new(cid, vvq, log, port_mappings);
+
+        Arc::new(Self { cid, backend, virtio_state, pci_state })
+    }
+}
+
+impl VirtioDevice for PciVirtioSock {
+    fn rw_dev_config(&self, mut rwo: crate::common::RWOp) {
+        VSOCK_DEV_REGS.process(&mut rwo, |id, rwo| match rwo {
+            RWOp::Read(ro) => match id {
+                VsockReg::GuestCid => {
+                    ro.write_u32(self.cid);
+                    // The upper 32 bits are reserved and zeroed.
+                    ro.fill(0);
+                }
+            },
+            RWOp::Write(_) => {}
+        })
+    }
+
+    fn features(&self) -> u64 {
+        // We support VIRTIO_VSOCK_F_STREAM
+        //
+        // virtio spec 1.3:
+        // The device SHOULD offer the VIRTIO_VSOCK_F_NO_IMPLIED_STREAM
+        // feature.
+        (VsockFeatures::NO_IMPLIED_STREAM | VsockFeatures::STREAM).bits()
+    }
+
+    fn set_features(&self, feat: u64) -> Result<(), ()> {
+        // We only care about the vsock specific bits so grab just those
+        match VsockFeatures::from_bits_truncate(feat) {
+            // If no feature bit has been negotiated, the device SHOULD act as
+            // if VIRTIO_VSOCK_F_STREAM has been negotiated.
+            f if f.is_empty() => Ok(()),
+            f if f == VsockFeatures::STREAM => Ok(()),
+            // We have not advertised SEQPACKET so we don't expect it to show
+            // up here.
+            _ => Err(()),
+        }
+    }
+
+    fn mode(&self) -> virtio::Mode {
+        virtio::Mode::Transitional
+    }
+
+    fn queue_notify(&self, vq: &VirtQueue) {
+        let _ = self.backend.queue_notify(vq.id);
+    }
+}
+
+impl PciVirtio for PciVirtioSock {
+    fn virtio_state(&self) -> &PciVirtioState {
+        &self.virtio_state
+    }
+    fn pci_state(&self) -> &pci::DeviceState {
+        &self.pci_state
+    }
+}
+
+impl Lifecycle for PciVirtioSock {
+    fn type_name(&self) -> &'static str {
+        "pci-virtio-vsock"
+    }
+    fn reset(&self) {
+        self.virtio_state.reset(self);
+    }
+    fn migrate(&'_ self) -> Migrator<'_> {
+        Migrator::NonMigratable
+    }
+}
+
+#[derive(Copy, Clone, Eq, PartialEq, Debug)]
+enum VsockReg {
+    GuestCid,
+}
+
+lazy_static! {
+    static ref VSOCK_DEV_REGS: RegMap<VsockReg> = {
+        let layout = [(VsockReg::GuestCid, 8)];
+        RegMap::create_packed(VIRTIO_VSOCK_CFG_SIZE, &layout, None)
+    };
+}
+
+mod bits {
+    pub const VIRTIO_VSOCK_CFG_SIZE: usize = 0x8;
+
+    bitflags! {
+        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+        pub struct VsockFeatures: u64 {
+            const STREAM = 1 << 0;
+            const SEQPACKET = 1 << 1;
+            const NO_IMPLIED_STREAM = 1 << 2;
+        }
+    }
+
+    #[allow(unused)]
+    pub const VIRTIO_VSOCK_EVENT_TRANSPORT_RESET: u32 = 0;
+}
+use bits::*;
+
+impl VsockPacket {
+    // TODO: We may want to consider operating on `Vec<GuestRegion>` to avoid
+    // double copying the packet contents. For now we are reading all of the
+    // packet data at once because it's convenient.
+    fn parse(
+        chain: &mut Chain,
+        mem: &MemCtx,
+    ) -> Result<VsockPacket, VsockPacketError> {
+        let mut packet = VsockPacket::default();
+
+        // Attempt to read the vsock packet header from the descriptor chain
+        // before we can process the full packet.
+        if !chain.read(&mut packet.header, mem) {
+            return Err(VsockPacketError::ChainHeaderRead);
+        }
+
+        // If the packet header indicates there is no data in this packet, then
+        // there's no point in attempting to continue reading from the chain.
+        if packet.header.len() == 0 {
+            return Ok(packet);
+        }
+
+        let hdr_len = usize::try_from(packet.header.len())
+            .expect("running on a 64bit platform");
+        let chain_len = chain.remain_read_bytes();
+
+        // Ensure that the vsock packet header length matches the reality of
+        // the desc chain.
+        if hdr_len > chain_len {
+            return Err(VsockPacketError::InvalidPacketLen {
+                hdr_len,
+                chain_len,
+            });
+        }
+        let mut data = vec![0; hdr_len];
+
+        // While we are here we should validate that the packet's cid fields do
+        // not contain reserved bits
+        if packet.header.src_cid() >> 32 != 0 {
+            return Err(VsockPacketError::InvalidSrcCid {
+                src_cid: packet.header.src_cid(),
+            });
+        }
+        if packet.header.dst_cid() >> 32 != 0 {
+            return Err(VsockPacketError::InvalidDstCid {
+                dst_cid: packet.header.dst_cid(),
+            });
+        }
+
+        let mut done = 0;
+        let copied = chain.for_remaining_type(true, |addr, len| {
+            let mut remain = GuestData::from(&mut data[done..]);
+            if let Some(copied) = mem.read_into(addr, &mut remain, len) {
+                let need_more = copied != remain.len();
+                done += copied;
+                (copied, need_more)
+            } else {
+                (0, false)
+            }
+        });
+
+        // If we fail to copy the correct amount of bytes from the desc chain
+        // something is clearly wrong.
+        if copied != hdr_len {
+            return Err(VsockPacketError::InsufficientBytes {
+                expected: hdr_len,
+                remaining: copied,
+            });
+        }
+
+        packet.data = data.into();
+
+        Ok(packet)
+    }
+}
diff --git a/lib/propolis/src/lib.rs b/lib/propolis/src/lib.rs
index c608a816c..5f07fa1bd 100644
--- a/lib/propolis/src/lib.rs
+++ b/lib/propolis/src/lib.rs
@@ -34,6 +34,7 @@ pub mod tasks;
 pub mod util;
 pub mod vcpu;
 pub mod vmm;
+pub mod vsock;

 pub use exits::{VmEntry, VmExit};
 pub use vmm::Machine;
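The `RxPermit` type above exists so the poller can reserve guest buffer space before reading from a host socket. A typical call sequence might look like the following sketch, where `vq` is a `VsockVq` and `pending` is a hypothetical buffer of host data; `VsockPacketHeader` and its setters are defined in packet.rs later in this diff:

```rust
if let Some(permit) = vq.try_rx_permit() {
    // Size the read to what the reserved descriptor chain can hold
    // beyond the 44-byte header.
    let n = pending.len().min(permit.available_data_space());
    let mut header = VsockPacketHeader::new();
    header.set_len(n as u32);
    permit.write(&header, &pending[..n]);
} // else: rx ring is full; retry once the guest returns descriptors
```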
diff --git a/lib/propolis/src/vsock/buffer.rs b/lib/propolis/src/vsock/buffer.rs
new file mode 100644
index 000000000..6d0fd8d27
--- /dev/null
+++ b/lib/propolis/src/vsock/buffer.rs
@@ -0,0 +1,215 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+use std::num::NonZeroUsize;
+use std::num::Wrapping;
+
+#[derive(Debug, thiserror::Error)]
+pub enum VsockBufError {
+    #[error(
+        "VsockBuf has {remaining} bytes available but tried to push {pushed}"
+    )]
+    InsufficientSpace { pushed: usize, remaining: usize },
+}
+
+/// A ring buffer used to store guest -> host data
+pub struct VsockBuf {
+    buf: Box<[u8]>,
+    head: Wrapping<usize>,
+    tail: Wrapping<usize>,
+}
+
+impl std::fmt::Debug for VsockBuf {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("VsockBuf")
+            .field("capacity", &self.capacity())
+            .field("head", &self.head)
+            .field("tail", &self.tail)
+            .field("in_use", &self.len())
+            .field("free", &self.free())
+            .finish()
+    }
+}
+
+impl VsockBuf {
+    /// Create a new `VsockBuf`
+    pub fn new(capacity: NonZeroUsize) -> Self {
+        let capacity = capacity.get();
+        Self {
+            buf: vec![0; capacity].into_boxed_slice(),
+            head: Wrapping(0),
+            tail: Wrapping(0),
+        }
+    }
+
+    pub fn capacity(&self) -> usize {
+        self.buf.len()
+    }
+
+    pub fn len(&self) -> usize {
+        (self.head - self.tail).0
+    }
+
+    fn free(&self) -> usize {
+        self.capacity() - self.len()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.head == self.tail
+    }
+
+    pub fn push(
+        &mut self,
+        data: impl AsRef<[u8]>,
+    ) -> Result<(), VsockBufError> {
+        let data = data.as_ref();
+
+        if data.len() > self.free() {
+            return Err(VsockBufError::InsufficientSpace {
+                pushed: data.len(),
+                remaining: self.free(),
+            });
+        }
+
+        let head_offset = self.head.0 % self.buf.len();
+        let available_len = self.buf.len() - head_offset;
+
+        // If the data can fit in the remaining space of the ring buffer, copy
+        // it in one go.
+        if data.len() <= available_len {
+            self.buf[head_offset..head_offset + data.len()]
+                .copy_from_slice(data);
+        } else {
+            // Otherwise, split it and write the remaining data to the front.
+            let (fits, wrapped) = data.split_at(available_len);
+            self.buf[head_offset..].copy_from_slice(fits);
+            self.buf[..wrapped.len()].copy_from_slice(wrapped);
+        }
+
+        self.head += Wrapping(data.len());
+        Ok(())
+    }
+
+    pub fn write_to<W: std::io::Write>(
+        &mut self,
+        writer: &mut W,
+    ) -> std::io::Result<usize> {
+        // If we have no data to write, bail early
+        if self.is_empty() {
+            return Ok(0);
+        }
+
+        let tail_offset = self.tail.0 % self.buf.len();
+        let head_offset = self.head.0 % self.buf.len();
+
+        // If the data is contiguous, write it in one go
+        let nwritten = if tail_offset < head_offset {
+            writer.write(&self.buf[tail_offset..head_offset])?
+        } else {
+            // Data wraps around, so try to write it in batches
+            let available_len = self.buf.len() - tail_offset;
+            let nwritten = writer.write(&self.buf[tail_offset..])?;
+
+            // If we failed to write the entire first segment, return early
+            if nwritten < available_len {
+                self.tail += Wrapping(nwritten);
+                return Ok(nwritten);
+            }
+
+            // If we were successful, attempt to continue writing the wrapped
+            // around segment
+            let second_nwritten = writer.write(&self.buf[..head_offset])?;
+            nwritten + second_nwritten
+        };
+
+        self.tail += Wrapping(nwritten);
+        Ok(nwritten)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::{io::Cursor, num::NonZeroUsize};
+
+    use crate::vsock::buffer::VsockBuf;
+
+    #[test]
+    fn test_capacity_and_len() {
+        let mut vb = VsockBuf::new(NonZeroUsize::new(10).unwrap());
+        assert_eq!(vb.capacity(), 10);
+        assert!(vb.is_empty());
+
+        let data = vec![1; 8];
+        let data_len = data.len();
+        assert!(vb.push(data).is_ok());
+        assert!(!vb.is_empty());
+        assert_eq!(vb.capacity(), 10);
+        assert_eq!(vb.len(), data_len);
+    }
+
+    #[test]
+    fn test_push_less_than_capacity() {
+        let mut vb = VsockBuf::new(NonZeroUsize::new(10).unwrap());
+        let data = vec![1; 8];
+        assert!(vb.push(data).is_ok());
+    }
+
+    #[test]
+    fn test_push_more_than_capacity() {
+        let mut vb = VsockBuf::new(NonZeroUsize::new(10).unwrap());
+        let data = vec![1; 8];
+        assert!(vb.push(data).is_ok());
+
+        let data = vec![1; 8];
+        assert!(vb.push(data).is_err());
+    }
+
+    #[test]
+    fn test_write_to() {
+        let mut vb = VsockBuf::new(NonZeroUsize::new(10).unwrap());
+        let data = vec![1; 10];
+        assert!(vb.push(data).is_ok());
+
+        let mut some_socket = [1; 10];
+        let mut cursor = Cursor::new(&mut some_socket[..]);
+        assert!(vb.write_to(&mut cursor).is_ok_and(|n| n == 10));
+    }
+
+    #[test]
+    fn test_partial_write_to() {
+        let mut vb = VsockBuf::new(NonZeroUsize::new(10).unwrap());
+        let data = vec![1; 10];
+        assert!(vb.push(data).is_ok());
+
+        let mut some_socket = [1; 5];
+        let mut cursor = Cursor::new(&mut some_socket[..]);
+        assert!(vb.write_to(&mut cursor).is_ok_and(|n| n == 5));
+        assert_eq!(vb.len(), 5, "5 bytes remain");
+
+        // reset the cursor and read another chunk
+        cursor.set_position(0);
+        assert!(vb.write_to(&mut cursor).is_ok_and(|n| n == 5));
+        assert!(vb.is_empty());
+    }
+
+    #[test]
+    fn test_wrap_around() {
+        let mut vb = VsockBuf::new(NonZeroUsize::new(10).unwrap());
+        let data = vec![1; 8];
+        assert!(vb.push(data).is_ok());
+
+        let mut some_socket = [1; 4];
+        let mut cursor = Cursor::new(&mut some_socket[..]);
+        assert!(vb.write_to(&mut cursor).is_ok_and(|n| n == 4));
+        assert_eq!(some_socket, [1u8; 4]);
+
+        let data = vec![2; 4];
+        assert!(vb.push(data).is_ok());
+
+        let mut some_socket = [1; 8];
+        let mut cursor = Cursor::new(&mut some_socket[..]);
+        assert!(vb.write_to(&mut cursor).is_ok_and(|n| n == 8));
+        assert_eq!(some_socket, [1, 1, 1, 1, 2, 2, 2, 2]);
+    }
+}
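Taken together, `push` and `write_to` make `VsockBuf` a byte-oriented ring: `head` and `tail` are free-running counters, `len = head - tail`, and offsets are reduced modulo capacity. A quick illustration of the wrap-around path (assuming the module above; `Vec<u8>` implements `std::io::Write`):

```rust
use std::num::NonZeroUsize;

let mut vb = VsockBuf::new(NonZeroUsize::new(8).unwrap());
vb.push(b"abcdef").unwrap(); // head = 6, tail = 0
let mut out = Vec::new();
assert_eq!(vb.write_to(&mut out).unwrap(), 6); // contiguous drain
vb.push(b"ghijkl").unwrap(); // head = 12: 2 bytes at the end, 4 wrapped
```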
diff --git a/lib/propolis/src/vsock/mod.rs b/lib/propolis/src/vsock/mod.rs
new file mode 100644
index 000000000..709aad407
--- /dev/null
+++ b/lib/propolis/src/vsock/mod.rs
@@ -0,0 +1,29 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+pub mod buffer;
+pub mod packet;
+
+#[cfg(target_os = "illumos")]
+pub mod poller;
+
+#[cfg(not(target_os = "illumos"))]
+#[path = "poller_stub.rs"]
+pub mod poller;
+
+pub mod proxy;
+pub use proxy::VsockProxy;
+
+/// Well-known CID for the host
+pub(crate) const VSOCK_HOST_CID: u64 = 2;
+
+#[derive(Debug, thiserror::Error)]
+pub enum VsockError {
+    #[error("failed to send virt queue notification for queue {}", queue)]
+    QueueNotify { queue: u16 },
+}
+
+pub trait VsockBackend: Send + Sync + 'static {
+    fn queue_notify(&self, queue_id: u16) -> Result<(), VsockError>;
+}
diff --git a/lib/propolis/src/vsock/packet.rs b/lib/propolis/src/vsock/packet.rs
new file mode 100644
index 000000000..45c2de7a7
--- /dev/null
+++ b/lib/propolis/src/vsock/packet.rs
@@ -0,0 +1,290 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+use strum::FromRepr;
+use zerocopy::byteorder::little_endian::{U16, U32, U64};
+use zerocopy::{FromBytes, Immutable, IntoBytes};
+
+use crate::vsock::proxy::CONN_TX_BUF_SIZE;
+use crate::vsock::VSOCK_HOST_CID;
+
+bitflags! {
+    /// Shutdown flags for VIRTIO_VSOCK_OP_SHUTDOWN
+    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
+    #[repr(transparent)]
+    pub struct VsockPacketFlags: u32 {
+        const VIRTIO_VSOCK_SHUTDOWN_F_RECEIVE = 1;
+        const VIRTIO_VSOCK_SHUTDOWN_F_SEND = 2;
+    }
+}
+
+#[derive(Debug, Clone, Copy, FromRepr, PartialEq, Eq)]
+#[repr(u16)]
+pub enum VsockSocketType {
+    Stream = 1,
+    SeqPacket = 2,
+    #[cfg(test)]
+    InvalidTestValue = 0x1de,
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum VsockPacketError {
+    #[error("failed to read packet header from descriptor chain")]
+    ChainHeaderRead,
+
+    #[error("vsock packet header reported {hdr_len} bytes but the \
+        descriptor chain contains {chain_len}")]
+    InvalidPacketLen { hdr_len: usize, chain_len: usize },
+
+    #[error("descriptor chain only yielded {remaining} bytes out of \
+        {expected} bytes")]
+    InsufficientBytes { expected: usize, remaining: usize },
+
+    #[error("src_cid {src_cid} contains reserved bits")]
+    InvalidSrcCid { src_cid: u64 },
+
+    #[error("dst_cid {dst_cid} contains reserved bits")]
+    InvalidDstCid { dst_cid: u64 },
+}
+
+#[derive(Clone, Copy, Debug, FromRepr, Eq, PartialEq)]
+#[repr(u16)]
+pub enum VsockPacketOp {
+    Request = 1,
+    Response = 2,
+    Reset = 3,
+    Shutdown = 4,
+    ReadWrite = 5,
+    CreditUpdate = 6,
+    CreditRequest = 7,
+}
+
+#[repr(C, packed)]
+#[derive(Copy, Clone, Default, Debug, FromBytes, IntoBytes, Immutable)]
+pub struct VsockPacketHeader {
+    src_cid: U64,
+    dst_cid: U64,
+    src_port: U32,
+    dst_port: U32,
+    len: U32,
+    // Note this is "type" in the spec
+    socket_type: U16,
+    op: U16,
+    flags: U32,
+    buf_alloc: U32,
+    fwd_cnt: U32,
+}
+
+impl VsockPacketHeader {
+    pub fn src_cid(&self) -> u64 {
+        self.src_cid.get()
+    }
+
+    pub fn dst_cid(&self) -> u64 {
+        self.dst_cid.get()
+    }
+
+    pub fn src_port(&self) -> u32 {
+        self.src_port.get()
+    }
+
+    pub fn dst_port(&self) -> u32 {
+        self.dst_port.get()
+    }
+
+    pub fn len(&self) -> u32 {
+        self.len.get()
+    }
+
+    pub fn socket_type(&self) -> Option<VsockSocketType> {
+        VsockSocketType::from_repr(self.socket_type.get())
+    }
+
+    pub fn op(&self) -> Option<VsockPacketOp> {
+        VsockPacketOp::from_repr(self.op.get())
+    }
+
+    pub fn flags(&self) -> VsockPacketFlags {
+        VsockPacketFlags::from_bits_retain(self.flags.get())
+    }
+
+    pub fn buf_alloc(&self) -> u32 {
+        self.buf_alloc.get()
+    }
+
+    pub fn fwd_cnt(&self) -> u32 {
+        self.fwd_cnt.get()
+    }
+
+    pub const fn new() -> Self {
+        Self {
+            src_cid: U64::new(0),
+            dst_cid: U64::new(0),
+            src_port: U32::new(0),
+            dst_port: U32::new(0),
+            len: U32::new(0),
+            socket_type: U16::new(VsockSocketType::Stream as u16),
+            op: U16::new(0),
+            flags: U32::new(0),
+            buf_alloc: U32::new(CONN_TX_BUF_SIZE as u32),
+            fwd_cnt: U32::new(0),
+        }
+    }
+
+    pub const fn set_src_cid(&mut self, cid: u32) -> &mut Self {
+        // The spec states:
+        //
+        // The upper 32 bits of src_cid and dst_cid are reserved and zeroed.
+        self.src_cid = U64::new(cid as u64);
+        self
+    }
+
+    pub const fn set_dst_cid(&mut self, cid: u32) -> &mut Self {
+        // The spec states:
+        //
+        // The upper 32 bits of src_cid and dst_cid are reserved and zeroed.
+        self.dst_cid = U64::new(cid as u64);
+        self
+    }
+
+    pub const fn set_src_port(&mut self, port: u32) -> &mut Self {
+        self.src_port = U32::new(port);
+        self
+    }
+
+    pub const fn set_dst_port(&mut self, port: u32) -> &mut Self {
+        self.dst_port = U32::new(port);
+        self
+    }
+
+    pub const fn set_len(&mut self, len: u32) -> &mut Self {
+        self.len = U32::new(len);
+        self
+    }
+
+    pub const fn set_socket_type(
+        &mut self,
+        socket_type: VsockSocketType,
+    ) -> &mut Self {
+        self.socket_type = U16::new(socket_type as u16);
+        self
+    }
+
+    pub const fn set_op(&mut self, op: VsockPacketOp) -> &mut Self {
+        self.op = U16::new(op as u16);
+        self
+    }
+
+    pub const fn set_flags(&mut self, flags: VsockPacketFlags) -> &mut Self {
+        self.flags = U32::new(flags.bits());
+        self
+    }
+
+    pub const fn set_buf_alloc(&mut self, buf_alloc: u32) -> &mut Self {
+        self.buf_alloc = U32::new(buf_alloc);
+        self
+    }
+
+    pub const fn set_fwd_cnt(&mut self, fwd_cnt: u32) -> &mut Self {
+        self.fwd_cnt = U32::new(fwd_cnt);
+        self
+    }
+}
+
+#[derive(Default)]
+pub struct VsockPacket {
+    pub(crate) header: VsockPacketHeader,
+    pub(crate) data: Box<[u8]>,
+}
+
+impl std::fmt::Debug for VsockPacket {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("VsockPacket")
+            .field("header", &self.header)
+            .field("data_len", &self.data.len())
+            .finish()
+    }
+}
+
+impl VsockPacket {
+    fn new(
+        guest_cid: u32,
+        src_port: u32,
+        dst_port: u32,
+        op: VsockPacketOp,
+    ) -> Self {
+        let mut header = VsockPacketHeader::new();
+        header
+            .set_src_cid(VSOCK_HOST_CID as u32)
+            .set_dst_cid(guest_cid)
+            .set_src_port(src_port)
+            .set_dst_port(dst_port)
+            .set_op(op);
+
+        Self { header, data: [].into() }
+    }
+
+    pub fn new_reset(guest_cid: u32, src_port: u32, dst_port: u32) -> Self {
+        Self::new(guest_cid, src_port, dst_port, VsockPacketOp::Reset)
+    }
+
+    pub fn new_response(guest_cid: u32, src_port: u32, dst_port: u32) -> Self {
+        Self::new(guest_cid, src_port, dst_port, VsockPacketOp::Response)
+    }
+
+    /// Create a new RW packet that sets the len field to the size of the
+    /// data.
+    ///
+    /// Panics if the supplied data is `u32::MAX` bytes or longer, as anything
+    /// larger would not fit within the peer's buf_alloc, which is defined as
+    /// a u32.
+    pub fn new_rw(
+        guest_cid: u32,
+        src_port: u32,
+        dst_port: u32,
+        fwd_cnt: u32,
+        data: impl Into<Box<[u8]>>,
+    ) -> Self {
+        let data = data.into();
+        let len = data.len();
+        assert!(
+            len < u32::MAX as usize,
+            "vsock packets should not exceed u32::MAX"
+        );
+        let mut packet =
+            Self::new(guest_cid, src_port, dst_port, VsockPacketOp::ReadWrite);
+        packet.header.set_len(len as u32);
+        packet.header.set_fwd_cnt(fwd_cnt);
+        packet.data = data;
+        packet
+    }
+
+    pub fn new_credit_update(
+        guest_cid: u32,
+        src_port: u32,
+        dst_port: u32,
+        fwd_cnt: u32,
+    ) -> Self {
+        let mut packet = Self::new(
+            guest_cid,
+            src_port,
+            dst_port,
+            VsockPacketOp::CreditUpdate,
+        );
+        packet.header.set_fwd_cnt(fwd_cnt);
+        packet
+    }
+
+    pub fn new_shutdown(
+        guest_cid: u32,
+        src_port: u32,
+        dst_port: u32,
+        flags: VsockPacketFlags,
+        fwd_cnt: u32,
+    ) -> Self {
+        let mut packet =
+            Self::new(guest_cid, src_port, dst_port, VsockPacketOp::Shutdown);
+        packet.header.set_fwd_cnt(fwd_cnt);
+        packet.header.set_flags(flags);
+        packet
+    }
+}
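The packed header above is the 44-byte little-endian layout from the virtio spec. As an example of the constructors, the credit-update packet the poller sends back to the guest is header-only (the values here are hypothetical):

```rust
assert_eq!(std::mem::size_of::<VsockPacketHeader>(), 44);

// Host (CID 2) tells guest CID 3 that it has consumed 4096 bytes so far
// on the connection host port 5000 <-> guest port 49152:
let pkt = VsockPacket::new_credit_update(3, 5000, 49152, 4096);
assert_eq!(pkt.header.op(), Some(VsockPacketOp::CreditUpdate));
assert_eq!(pkt.header.len(), 0); // no payload
```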
diff --git a/lib/propolis/src/vsock/poller.rs b/lib/propolis/src/vsock/poller.rs
new file mode 100644
index 000000000..998e97417
--- /dev/null
+++ b/lib/propolis/src/vsock/poller.rs
@@ -0,0 +1,1736 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+use std::collections::hash_map::Entry;
+use std::collections::HashMap;
+use std::collections::VecDeque;
+use std::ffi::c_void;
+use std::io::ErrorKind;
+use std::io::Read;
+use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
+use std::sync::Arc;
+use std::thread::JoinHandle;
+
+use iddqd::IdHashMap;
+use nix::poll::PollFlags;
+use slog::{debug, error, info, warn, Logger};
+
+use crate::hw::virtio::vsock::VsockVq;
+use crate::hw::virtio::vsock::VSOCK_RX_QUEUE;
+use crate::hw::virtio::vsock::VSOCK_TX_QUEUE;
+use crate::vsock::packet::VsockPacket;
+use crate::vsock::packet::VsockPacketFlags;
+use crate::vsock::packet::VsockSocketType;
+use crate::vsock::proxy::ConnKey;
+use crate::vsock::proxy::VsockPortMapping;
+use crate::vsock::proxy::VsockProxyConn;
+use crate::vsock::VSOCK_HOST_CID;
+
+use super::packet::VsockPacketOp;
+
+#[repr(usize)]
+enum VsockEvent {
+    TxQueue = 0,
+    RxQueue,
+    Shutdown,
+}
+
+pub struct VsockPollerNotify {
+    port_fd: Arc<OwnedFd>,
+}
+
+impl VsockPollerNotify {
+    fn port_fd(&self) -> BorrowedFd<'_> {
+        self.port_fd.as_fd()
+    }
+
+    fn port_send(&self, event: VsockEvent) -> std::io::Result<()> {
+        let ret = unsafe {
+            libc::port_send(self.port_fd().as_raw_fd(), 0, event as usize as _)
+        };
+
+        if ret == 0 {
+            Ok(())
+        } else {
+            Err(std::io::Error::last_os_error())
+        }
+    }
+
+    pub fn queue_notify(&self, id: u16) -> std::io::Result<()> {
+        match id {
+            VSOCK_RX_QUEUE => self.port_send(VsockEvent::RxQueue),
+            VSOCK_TX_QUEUE => self.port_send(VsockEvent::TxQueue),
+            _ => Ok(()),
+        }
+    }
+
+    pub fn shutdown(&self) -> std::io::Result<()> {
+        self.port_send(VsockEvent::Shutdown)
+    }
+}
+
+/// Set of `PollFlags` that signifies a readable event.
+const fn is_readable(flags: PollFlags) -> bool {
+    const READABLE: PollFlags = PollFlags::from_bits_truncate(
+        PollFlags::POLLIN.bits()
+            | PollFlags::POLLHUP.bits()
+            | PollFlags::POLLERR.bits()
+            | PollFlags::POLLPRI.bits(),
+    );
+    READABLE.intersects(flags)
+}
+
+/// Set of `PollFlags` that signifies a writable event.
+const fn is_writable(flags: PollFlags) -> bool {
+    const WRITABLE: PollFlags = PollFlags::from_bits_truncate(
+        PollFlags::POLLOUT.bits()
+            | PollFlags::POLLHUP.bits()
+            | PollFlags::POLLERR.bits(),
+    );
+    WRITABLE.intersects(flags)
+}
+
+#[derive(Debug)]
+enum RxEvent {
+    /// Vsock RST packet
+    Reset(ConnKey),
+    /// Vsock RESPONSE packet
+    NewConnection(ConnKey),
+    /// Vsock CREDIT_UPDATE packet
+    CreditUpdate(ConnKey),
+}
+
+pub struct VsockPoller {
+    log: Logger,
+    /// The guest context id
+    guest_cid: u32,
+    /// Port mappings we are proxying packets to and from
+    port_mappings: IdHashMap<VsockPortMapping>,
+    /// The event port fd.
+    port_fd: Arc<OwnedFd>,
+    /// The virtqueues associated with the vsock device
+    queues: VsockVq,
+    /// The connection map of guest connected streams
+    connections: HashMap<ConnKey, VsockProxyConn>,
+    /// Queue of vsock packets that need to be sent to the guest
+    rx: VecDeque<RxEvent>,
+    /// Connections blocked waiting for rx queue descriptors
+    rx_blocked: Vec<ConnKey>,
+}
+
+impl VsockPoller {
+    /// Create a new `VsockPoller`.
+    ///
+    /// This poller is responsible for driving virtio-socket connections
+    /// between the guest VM and host sockets.
+    pub fn new(
+        cid: u32,
+        queues: VsockVq,
+        log: Logger,
+        port_mappings: IdHashMap<VsockPortMapping>,
+    ) -> std::io::Result<Self> {
+        let port_fd = unsafe {
+            let fd = match libc::port_create() {
+                -1 => return Err(std::io::Error::last_os_error()),
+                fd => fd,
+            };
+
+            // Set CLOEXEC on the event port fd
+            if libc::fcntl(
+                fd,
+                libc::F_SETFD,
+                libc::fcntl(fd, libc::F_GETFD) | libc::FD_CLOEXEC,
+            ) < 0
+            {
+                return Err(std::io::Error::last_os_error());
+            };
+
+            fd
+        };
+
+        info!(
+            &log,
+            "vsock poller configured with";
+            "mappings" => ?port_mappings,
+        );
+
+        Ok(Self {
+            log,
+            guest_cid: cid,
+            port_mappings,
+            port_fd: Arc::new(unsafe { OwnedFd::from_raw_fd(port_fd) }),
+            queues,
+            connections: Default::default(),
+            rx: Default::default(),
+            rx_blocked: Default::default(),
+        })
+    }
+
+    /// Get a handle to a `VsockPollerNotify`.
+    pub fn notify_handle(&self) -> VsockPollerNotify {
+        VsockPollerNotify { port_fd: Arc::clone(&self.port_fd) }
+    }
+
+    /// Start the event loop.
+    pub fn run(mut self) -> JoinHandle<()> {
+        std::thread::Builder::new()
+            .name("vsock-event-loop".to_string())
+            .spawn(move || self.handle_events())
+            .expect("failed to spawn vsock event loop")
+    }
+
+    /// Handle the guest's VIRTIO_VSOCK_OP_REQUEST packet.
+    fn handle_connection_request(&mut self, key: ConnKey, packet: VsockPacket) {
+        if self.connections.contains_key(&key) {
+            // Connection already exists
+            self.send_conn_rst(key);
+            return;
+        }
+
+        let Some(mapping) = self.port_mappings.get(&packet.header.dst_port())
+        else {
+            // Drop the unknown connection so that it times out in the guest.
+            debug!(
+                &self.log,
+                "dropping connect request to unknown mapping";
+                "packet" => ?packet,
+            );
+            return;
+        };
+
+        match VsockProxyConn::new(mapping.addr()) {
+            Ok(mut conn) => {
+                conn.update_peer_credit(&packet.header);
+                self.connections.insert(key, conn);
+                self.rx.push_back(RxEvent::NewConnection(key));
+            }
+            Err(e) => {
+                self.send_conn_rst(key);
+                error!(self.log, "{e}");
+            }
+        };
+    }
+    /// Handle the guest's VIRTIO_VSOCK_OP_SHUTDOWN packet.
+    fn handle_shutdown(&mut self, key: ConnKey, flags: VsockPacketFlags) {
+        if let Entry::Occupied(mut entry) = self.connections.entry(key) {
+            let conn = entry.get_mut();
+
+            // Guest won't receive more data
+            if flags.contains(VsockPacketFlags::VIRTIO_VSOCK_SHUTDOWN_F_RECEIVE)
+            {
+                if let Err(e) = conn.shutdown_guest_read() {
+                    error!(
+                        &self.log,
+                        "cannot transition vsock connection state: {e}";
+                        "conn" => ?conn,
+                    );
+                    entry.remove();
+                    self.send_conn_rst(key);
+                    return;
+                };
+            }
+            // Guest won't send more data
+            if flags.contains(VsockPacketFlags::VIRTIO_VSOCK_SHUTDOWN_F_SEND) {
+                if let Err(e) = conn.shutdown_guest_write() {
+                    error!(
+                        &self.log,
+                        "cannot transition vsock connection state: {e}";
+                        "conn" => ?conn,
+                    );
+                    entry.remove();
+                    self.send_conn_rst(key);
+                    return;
+                };
+            }
+            // XXX how do we register this for future cleanup if there is data
+            // we have not synced locally yet? We need a cleanup loop...
+            if conn.should_close() {
+                if !conn.has_buffered_data() {
+                    self.connections.remove(&key);
+                    // virtio spec states:
+                    //
+                    // Clean disconnect is achieved by one or more
+                    // VIRTIO_VSOCK_OP_SHUTDOWN packets that indicate no
+                    // more data will be sent and received, followed by a
+                    // VIRTIO_VSOCK_OP_RST response from the peer.
+                    self.send_conn_rst(key);
+                }
+            }
+        }
+    }
+
+    /// Handle the guest's VIRTIO_VSOCK_OP_RW packet.
+    fn handle_rw_packet(&mut self, key: ConnKey, packet: VsockPacket) {
+        if let Entry::Occupied(mut entry) = self.connections.entry(key) {
+            let conn = entry.get_mut();
+
+            // If we have a valid connection attempt to consume the guest's
+            // packet.
+            if let Err(e) = conn.recv_packet(packet) {
+                error!(
+                    &self.log,
+                    "failed to push vsock packet data into the conn vbuf: {e}";
+                    "conn" => ?conn,
+                );
+
+                entry.remove();
+                self.send_conn_rst(key);
+                return;
+            }
+
+            if let Some(interests) = conn.poll_interests() {
+                let fd = conn.get_fd();
+                self.associate_fd(key, fd, interests);
+            }
+        };
+    }
+
+    /// Handle the guest's tx virtqueue.
+    fn handle_tx_queue_event(&mut self) {
+        loop {
+            let packet = match self.queues.recv_packet().transpose() {
+                Ok(Some(packet)) => packet,
+                // No more packets on the guest's tx queue
+                Ok(None) => break,
+                Err(e) => {
+                    warn!(&self.log, "dropping invalid vsock packet: {e}");
+                    continue;
+                }
+            };
+
+            // If the packet is not destined for the host drop it.
+            if packet.header.dst_cid() != VSOCK_HOST_CID {
+                debug!(
+                    &self.log,
+                    "dropping vsock packet not destined for the host";
+                    "packet" => ?packet,
+                );
+                continue;
+            }
+
+            // If the packet is not coming from our guest drop it.
+            if packet.header.src_cid() != u64::from(self.guest_cid) {
+                // Note that we could send a RST here but technically we should
+                // not know how to address this guest cid as it's not the one
+                // we assigned to our guest.
+ debug!( + &self.log, + "droppping vsock packet not arriving from our guest cid"; + "packet" => ?packet, + ); + continue; + } + + let key = ConnKey { + host_port: packet.header.dst_port(), + guest_port: packet.header.src_port(), + }; + + // We only support stream connections + let Some(VsockSocketType::Stream) = packet.header.socket_type() + else { + self.send_conn_rst(key); + warn!(&self.log, + "received invalid vsock packet"; + "packet" => ?packet, + ); + continue; + }; + + let Some(packet_op) = packet.header.op() else { + warn!( + &self.log, + "received vsock packet with unknown op code"; + "packet" => ?packet, + ); + return; + }; + + if let Some(conn) = self.connections.get_mut(&key) { + // Regardless of the vsock operation we need to record the peers + // credit info + conn.update_peer_credit(&packet.header); + match packet_op { + VsockPacketOp::Reset => { + self.connections.remove(&key); + } + VsockPacketOp::Shutdown => { + self.handle_shutdown(key, packet.header.flags()); + } + VsockPacketOp::CreditUpdate => continue, + VsockPacketOp::CreditRequest => { + if self.connections.contains_key(&key) { + self.rx.push_back(RxEvent::CreditUpdate(key)); + } + } + VsockPacketOp::ReadWrite => { + self.handle_rw_packet(key, packet); + } + // We are operating on an existing connection either of + // these should not be received + // + // XXX: send a RST, but what about our orignal connection? + VsockPacketOp::Request | VsockPacketOp::Response => (), + } + } else { + match packet_op { + VsockPacketOp::Request => { + self.handle_connection_request(key, packet) + } + VsockPacketOp::Reset => {} + _ => { + warn!( + &self.log, + "received a vsock packet for an unknown connection \ + that was not a REQUEST or RST"; + "packet" => ?packet, + ); + } + } + } + } + } + + /// Process the rx virtqueue (host -> guest). + fn handle_rx_queue_event(&mut self) { + // Now that more descriptors have become available for sending vsock + // packets attempt to drain pending packets + self.process_pending_rx(); + + // Re-register connections that were blocked waiting for rx queue space. + // It would be nice if we had a hint of how many descriptors became + // available but that's not the case today. + for key in std::mem::take(&mut self.rx_blocked).drain(..) { + if let Some(conn) = self.connections.get(&key) { + if let Some(interests) = conn.poll_interests() { + let fd = conn.get_fd(); + self.associate_fd(key, fd, interests); + } + } + } + } + + // Attempt to send any queued rx packets destined for the guest. 
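+    //
+    // Each iteration pairs one queued `RxEvent` with one rx descriptor
+    // permit; if either runs out we stop, leaving the remaining events
+    // queued until the guest posts more descriptors.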
+    fn process_pending_rx(&mut self) {
+        while let Some(permit) = self.queues.try_rx_permit() {
+            let Some(rx_event) = self.rx.pop_front() else {
+                break;
+            };
+
+            match rx_event {
+                RxEvent::Reset(key) => {
+                    let packet = VsockPacket::new_reset(
+                        self.guest_cid,
+                        key.host_port,
+                        key.guest_port,
+                    );
+                    permit.write(&packet.header, &packet.data);
+                }
+                RxEvent::NewConnection(key) => {
+                    let packet = VsockPacket::new_response(
+                        self.guest_cid,
+                        key.host_port,
+                        key.guest_port,
+                    );
+                    permit.write(&packet.header, &packet.data);
+
+                    if let Entry::Occupied(mut entry) =
+                        self.connections.entry(key)
+                    {
+                        let conn = entry.get_mut();
+                        if let Err(e) = conn.set_established() {
+                            error!(
+                                &self.log,
+                                "cannot transition vsock connection state: {e}";
+                                "conn" => ?conn,
+                            );
+                            entry.remove();
+                            self.send_conn_rst(key);
+                            continue;
+                        };
+
+                        if let Some(interests) = conn.poll_interests() {
+                            let fd = conn.get_fd();
+                            self.associate_fd(key, fd, interests);
+                        }
+                    }
+                }
+                RxEvent::CreditUpdate(key) => {
+                    if let Some(conn) = self.connections.get_mut(&key) {
+                        let packet = VsockPacket::new_credit_update(
+                            self.guest_cid,
+                            key.host_port,
+                            key.guest_port,
+                            conn.fwd_cnt(),
+                        );
+                        permit.write(&packet.header, &packet.data);
+                        conn.mark_credit_sent();
+                    }
+                }
+            }
+        }
+    }
+
+    /// Handle a user event. Returns `true` if the event loop should shut down.
+    fn handle_user_event(&mut self, event: PortEvent) -> bool {
+        match event.user {
+            val if val == VsockEvent::TxQueue as usize => {
+                self.handle_tx_queue_event()
+            }
+            val if val == VsockEvent::RxQueue as usize => {
+                self.handle_rx_queue_event()
+            }
+            val if val == VsockEvent::Shutdown as usize => return true,
+            _ => (),
+        }
+        false
+    }
+
+    /// Handle an fd event by flushing data to the underlying socket from the
+    /// connection's [`VsockBuf`], and by reading data from the socket and
+    /// sending it to the guest as a `VIRTIO_VSOCK_OP_RW` packet.
+    fn handle_fd_event(&mut self, event: PortEvent, read_buf: &mut [u8]) {
+        let key = ConnKey::from_portev_user(event.user);
+        let events = PollFlags::from_bits_retain(event.events as i16);
+
+        if is_writable(events) {
+            self.handle_writable_fd(key);
+        }
+
+        if is_readable(events) {
+            self.handle_readable_fd(key, read_buf);
+        }
+    }
+
+    /// When an fd is writable, drain buffered guest data to the host socket.
+    fn handle_writable_fd(&mut self, key: ConnKey) {
+        let Some(conn) = self.connections.get_mut(&key) else {
+            return;
+        };
+
+        loop {
+            match conn.flush() {
+                Ok(0) => break,
+                Ok(nbytes) => {
+                    conn.update_fwd_cnt(nbytes as u32);
+                    if conn.needs_credit_update() {
+                        self.rx.push_back(RxEvent::CreditUpdate(key));
+                    }
+                }
+                Err(e) if e.kind() == ErrorKind::WouldBlock => break,
+                Err(e) => {
+                    error!(
+                        &self.log,
+                        "error writing to vsock backend socket: {e}";
+                        "key" => ?key,
+                    );
+                    break;
+                }
+            }
+        }
+
+        // We have finished draining our buffered data to the host, so check if
+        // we should remove ourselves from the active connections.
+        if conn.should_close() && !conn.has_buffered_data() {
+            self.connections.remove(&key);
+            self.send_conn_rst(key);
+            return;
+        }
+
+        if let Some(interests) = conn.poll_interests() {
+            let fd = conn.get_fd();
+            self.associate_fd(key, fd, interests);
+        }
+    }
+
+    /// When an fd is readable, read from host socket and send to guest.
+    fn handle_readable_fd(&mut self, key: ConnKey, read_buf: &mut [u8]) {
+        let VsockPoller { queues, connections, guest_cid, rx_blocked, .. } =
+            self;
+
+        let Some(conn) = connections.get_mut(&key) else {
+            return;
+        };
+
+        // The guest is no longer expecting any data
+        if !conn.guest_can_read() {
+            return;
+        }
+
+        loop {
+            let Some(permit) = queues.try_rx_permit() else {
+                rx_blocked.push(key);
+                break;
+            };
+
+            let credit = conn.peer_credit();
+            if credit == 0 {
+                // TODO: when this happens under sufficient load there's the
+                // possibility we wake up the event loop repeatedly, and we
+                // should defer associating this fd again until there's enough
+                // credit. This is similar to the `rx_blocked` queue but
+                // slightly different.
+                break;
+            }
+
+            let max_read = std::cmp::min(
+                permit.available_data_space(),
+                std::cmp::min(credit as usize, read_buf.len()),
+            );
+
+            match conn.socket.read(&mut read_buf[..max_read]) {
+                Ok(0) => {
+                    // TODO the guest is supposed to send us a RST to finalize
+                    // the shutdown. We need to put this on a quiesce queue so
+                    // that we don't leave a half-open connection laying around
+                    // in our connection map.
+                    let packet = VsockPacket::new_shutdown(
+                        *guest_cid,
+                        key.host_port,
+                        key.guest_port,
+                        VsockPacketFlags::VIRTIO_VSOCK_SHUTDOWN_F_SEND
+                            | VsockPacketFlags::VIRTIO_VSOCK_SHUTDOWN_F_RECEIVE,
+                        conn.fwd_cnt(),
+                    );
+                    permit.write(&packet.header, &packet.data);
+                    return;
+                }
+                Ok(nbytes) => {
+                    let read_u32: u32 = nbytes
+                        .try_into()
+                        .expect("max_read is <= u32::MAX by min() above");
+                    conn.update_tx_cnt(read_u32);
+                    let VsockPacket { header, data } = VsockPacket::new_rw(
+                        *guest_cid,
+                        key.host_port,
+                        key.guest_port,
+                        conn.fwd_cnt(),
+                        &read_buf[..nbytes],
+                    );
+                    permit.write(&header, &data);
+                }
+                Err(e) if e.kind() == ErrorKind::WouldBlock => break,
+                Err(e) => {
+                    error!(
+                        &self.log,
+                        "vsock backend socket read failed: {e}";
+                        "key" => ?key,
+                        "conn" => ?conn,
+                    );
+
+                    connections.remove(&key);
+                    let packet = VsockPacket::new_reset(
+                        *guest_cid,
+                        key.host_port,
+                        key.guest_port,
+                    );
+                    permit.write(&packet.header, &packet.data);
+                    return;
+                }
+            }
+        }
+
+        if let Some(interests) = conn.poll_interests() {
+            let fd = conn.get_fd();
+            self.associate_fd(key, fd, interests);
+        }
+    }
+
+    /// Associate a connection's underlying socket fd with our port fd.
+    fn associate_fd(&mut self, key: ConnKey, fd: RawFd, interests: PollFlags) {
+        let ret = unsafe {
+            libc::port_associate(
+                self.port_fd.as_raw_fd(),
+                libc::PORT_SOURCE_FD,
+                fd as usize,
+                interests.bits() as i32,
+                key.to_portev_user() as *mut c_void,
+            )
+        };
+
+        if ret < 0 {
+            let err = std::io::Error::last_os_error();
+            if let Some(conn) = self.connections.remove(&key) {
+                error!(
+                    &self.log,
+                    "vsock port_associate failed: {err}";
+                    "key" => ?key,
+                    "conn" => ?conn,
+                );
+                self.send_conn_rst(key);
+            }
+        }
+    }
+
+    /// Enqueue a RST packet for the provided [`ConnKey`]
+    fn send_conn_rst(&mut self, key: ConnKey) {
+        self.rx.push_back(RxEvent::Reset(key));
+    }
+
+    /// This is the vsock event loop. It's responsible for handling vsock
+    /// packets to and from the guest.
+    fn handle_events(&mut self) {
+        const MAX_EVENTS: u32 = 32;
+
+        let mut events = [const { unsafe {
+            std::mem::zeroed::<libc::port_event>()
+        } }; MAX_EVENTS as usize];
+        let mut read_buf: Box<[u8]> = vec![0u8; 1024 * 64].into();
+
+        loop {
+            let mut nget = 1;
+
+            let ret = unsafe {
+                libc::port_getn(
+                    self.port_fd.as_raw_fd(),
+                    events.as_mut_ptr(),
+                    MAX_EVENTS,
+                    &mut nget,
+                    // TODO currently we are not supplying a timeout because
+                    // there is no other work to do unless we are woken up. In
+                    // the near future we will likely periodically wake up to
+                    // service the shutdown quiesce queue.
+                    std::ptr::null_mut(),
+                )
+            };
+
+            if ret < 0 {
+                let err = std::io::Error::last_os_error();
+                // NOTE: `raw_os_error` always returns a `Some` variant when
+                // the error was obtained via `last_os_error`.
+                match err.raw_os_error().unwrap() {
+                    // A signal was caught, so process the loop again
+                    libc::EINTR => continue,
+                    libc::EBADF | libc::EBADFD => {
+                        // This means our event loop is effectively no
+                        // longer serviceable and the vsock device is useless.
+                        error!(
+                            &self.log,
+                            "vsock port fd is no longer valid: {err}"
+                        );
+                        return;
+                    }
+                    _ => {
+                        error!(&self.log, "vsock port_getn returned: {err}");
+                        continue;
+                    }
+                }
+            }
+
+            assert!(
+                nget as usize <= events.len(),
+                "event port returned more events than we asked for"
+            );
+            let events = unsafe {
+                std::slice::from_raw_parts(events.as_ptr(), nget as usize)
+            };
+            for event in events {
+                let event = PortEvent::from_raw(*event);
+
+                match event.source {
+                    EventSource::User => {
+                        let should_shutdown = self.handle_user_event(event);
+                        if should_shutdown {
+                            return;
+                        }
+                    }
+                    EventSource::Fd => {
+                        self.handle_fd_event(event, &mut read_buf);
+                    }
+                    _ => {}
+                };
+            }
+
+            // Process any pending rx events
+            self.process_pending_rx();
+        }
+    }
+}
+
+/// The source of a port event.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum EventSource {
+    /// User event, i.e. `port_send(3C)`
+    User,
+    /// File descriptor event
+    Fd,
+    /// Unknown source for the vsock backend
+    Unknown(u16),
+}
+
+impl EventSource {
+    fn from_raw(source: u16) -> Self {
+        match source as i32 {
+            libc::PORT_SOURCE_USER => EventSource::User,
+            libc::PORT_SOURCE_FD => EventSource::Fd,
+            _ => EventSource::Unknown(source),
+        }
+    }
+}
+
+/// A port event retrieved from an event port.
+///
+/// This represents an event from one of the various event sources (file
+/// descriptors, timers, user events, etc.).
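+///
+/// Only `PORT_SOURCE_USER` and `PORT_SOURCE_FD` are meaningful to the vsock
+/// poller; events from any other source map to `EventSource::Unknown` and
+/// are ignored by the event loop.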
+#[derive(Debug, Clone)]
+struct PortEvent {
+    /// The events that occurred (source-specific)
+    events: i32,
+    /// The source of the event
+    source: EventSource,
+    /// The object associated with the event (interpretation depends on source)
+    #[allow(dead_code)]
+    object: usize,
+    /// User-defined data provided during association
+    user: usize,
+}
+
+impl PortEvent {
+    fn from_raw(event: libc::port_event) -> Self {
+        PortEvent {
+            events: event.portev_events,
+            source: EventSource::from_raw(event.portev_source),
+            object: event.portev_object,
+            user: event.portev_user as usize,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::io::{Read, Write};
+    use std::net::TcpListener;
+    use std::sync::atomic::{AtomicUsize, Ordering};
+    use std::sync::Arc;
+    use std::time::Duration;
+
+    use iddqd::IdHashMap;
+
+    use zerocopy::{FromBytes, IntoBytes};
+
+    use crate::hw::virtio::testutil::{QueueWriter, TestVirtQueues, VqSize};
+    use crate::hw::virtio::vsock::{VsockVq, VSOCK_RX_QUEUE, VSOCK_TX_QUEUE};
+    use crate::vsock::packet::{
+        VsockPacketFlags, VsockPacketHeader, VsockPacketOp, VsockSocketType,
+    };
+    use crate::vsock::proxy::{VsockPortMapping, CONN_TX_BUF_SIZE};
+    use crate::vsock::VSOCK_HOST_CID;
+
+    use super::VsockPoller;
+
+    fn test_logger() -> slog::Logger {
+        use slog::Drain;
+        let decorator = slog_term::TermDecorator::new().stderr().build();
+        let drain = slog_term::FullFormat::new(decorator).build().fuse();
+        let drain = slog_async::Async::new(drain).build().fuse();
+        slog::Logger::root(drain, slog::o!("component" => "vsock-test"))
+    }
+
+    const QUEUE_SIZE: u16 = 64;
+    const PAGE_SIZE: u64 = 0x1000;
+
+    /// Bind a TCP listener on an ephemeral port and return it along with an
+    /// `IdHashMap` that maps `vsock_port` to the listener's actual address.
+    fn bind_test_backend(
+        vsock_port: u32,
+    ) -> (TcpListener, IdHashMap<VsockPortMapping>) {
+        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
+        let addr = listener.local_addr().unwrap();
+        let mut backends = IdHashMap::new();
+        backends.insert_overwrite(VsockPortMapping::new(vsock_port, addr));
+        (listener, backends)
+    }
+
+    /// Test harness for vsock poller tests using shared testutil
+    /// infrastructure.
+    struct VsockTestHarness {
+        tvqs: TestVirtQueues,
+        rx_writer: QueueWriter,
+        tx_writer: QueueWriter,
+    }
+
+    impl VsockTestHarness {
+        fn new() -> Self {
+            let tvqs = TestVirtQueues::new(&[
+                VqSize::new(QUEUE_SIZE), // RX
+                VqSize::new(QUEUE_SIZE), // TX
+                VqSize::new(1),          // Event
+            ]);
+
+            // RX and TX use separate data regions
+            let rx_writer = tvqs.writer(VSOCK_RX_QUEUE as usize, 0);
+            let tx_writer =
+                tvqs.writer(VSOCK_TX_QUEUE as usize, PAGE_SIZE * 16);
+
+            Self { tvqs, rx_writer, tx_writer }
+        }
+
+        fn make_vsock_vq(&self) -> VsockVq {
+            let queues: Vec<_> =
+                self.tvqs.queues().iter().map(|q| q.clone()).collect();
+            let acc = self.tvqs.mem_acc().child(Some("vsock-vq".to_string()));
+            VsockVq::new(queues, acc)
+        }
+
+        /// Add a writable descriptor to the RX queue and publish it.
+        fn add_rx_writable(&mut self, len: u32) -> u16 {
+            let d = self.rx_writer.add_writable(self.tvqs.mem_acc(), len);
+            self.rx_writer.publish_avail(self.tvqs.mem_acc(), d);
+            d
+        }
+
+        /// Add a readable descriptor to the TX queue.
+        fn add_tx_readable(&mut self, data: &[u8]) -> u16 {
+            self.tx_writer.add_readable(self.tvqs.mem_acc(), data)
+        }
+
+        /// Publish a descriptor on the TX queue.
+        fn publish_tx(&mut self, head: u16) {
+            self.tx_writer.publish_avail(self.tvqs.mem_acc(), head);
+        }
+
+        /// Chain two TX descriptors together.
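+        ///
+        /// The tests use this to model a guest submitting a packet as a
+        /// two-descriptor chain: a header descriptor linked to a payload
+        /// descriptor.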
+        fn chain_tx(&mut self, from: u16, to: u16) {
+            self.tx_writer.chain(self.tvqs.mem_acc(), from, to);
+        }
+
+        /// Reset TX writer cursors for reuse.
+        fn reset_tx_cursors(&mut self) {
+            self.tx_writer.reset_cursors();
+        }
+
+        /// Reset RX writer cursors for reuse.
+        fn reset_rx_cursors(&mut self) {
+            self.rx_writer.reset_cursors();
+        }
+
+        /// Read a vsock packet header and data from a used ring entry.
+        fn read_vsock_packet(
+            &self,
+            used_index: u16,
+        ) -> (VsockPacketHeader, Vec<u8>) {
+            let mem_acc = self.tvqs.mem_acc();
+            let elem = self.rx_writer.read_used_elem(mem_acc, used_index);
+            let desc_id = elem.id as u16;
+            let total_len = elem.len as usize;
+
+            // Read the entire buffer (header + data)
+            let buf =
+                self.rx_writer.read_desc_data(mem_acc, desc_id, total_len);
+
+            // Parse header from the first bytes
+            let hdr_size = std::mem::size_of::<VsockPacketHeader>();
+            let (hdr, data) = buf.split_at(hdr_size);
+            let hdr = VsockPacketHeader::read_from_bytes(hdr)
+                .expect("buffer should contain valid header");
+
+            (hdr, data.to_vec())
+        }
+
+        fn rx_used_idx(&self) -> u16 {
+            self.rx_writer.used_idx(self.tvqs.mem_acc())
+        }
+
+        fn tx_used_idx(&self) -> u16 {
+            self.tx_writer.used_idx(self.tvqs.mem_acc())
+        }
+    }
+
+    /// Helper: serialize a VsockPacketHeader to bytes.
+    fn hdr_as_bytes(hdr: &VsockPacketHeader) -> &[u8] {
+        hdr.as_bytes()
+    }
+
+    /// Spin until a condition is met, with a timeout.
+    fn wait_for_condition<F>(mut f: F, timeout_ms: u64)
+    where
+        F: FnMut() -> bool,
+    {
+        let start = std::time::Instant::now();
+        let timeout = Duration::from_millis(timeout_ms);
+        while !f() {
+            if start.elapsed() > timeout {
+                panic!("timed out waiting for condition");
+            }
+            std::thread::sleep(Duration::from_millis(1));
+        }
+    }
+
+    #[test]
+    fn request_receives_response() {
+        let vsock_port = 3000;
+        let guest_port = 1234;
+        let guest_cid: u32 = 50;
+        let (_listener, backends) = bind_test_backend(vsock_port);
+
+        let mut harness = VsockTestHarness::new();
+        let vq = harness.make_vsock_vq();
+        let log = test_logger();
+        let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap();
+
+        harness.add_rx_writable(256);
+
+        let notify = poller.notify_handle();
+        let handle = poller.run();
+
+        let mut hdr = VsockPacketHeader::new();
+        hdr.set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(0)
+            .set_socket_type(VsockSocketType::Stream)
+            .set_op(VsockPacketOp::Request)
+            .set_buf_alloc(65536)
+            .set_fwd_cnt(0);
+
+        let d_tx = harness.add_tx_readable(hdr_as_bytes(&hdr));
+        harness.publish_tx(d_tx);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        wait_for_condition(|| harness.rx_used_idx() >= 1, 5000);
+
+        let (resp_hdr, _) = harness.read_vsock_packet(0);
+        assert_eq!(resp_hdr.op(), Some(VsockPacketOp::Response));
+        assert_eq!(resp_hdr.src_cid(), VSOCK_HOST_CID);
+        assert_eq!(resp_hdr.dst_cid(), guest_cid as u64);
+        assert_eq!(resp_hdr.src_port(), vsock_port);
+        assert_eq!(resp_hdr.dst_port(), guest_port);
+        assert_eq!(resp_hdr.socket_type(), Some(VsockSocketType::Stream));
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+
+    #[test]
+    fn rw_with_invalid_socket_type_receives_rst() {
+        let guest_cid: u32 = 50;
+
+        let mut harness = VsockTestHarness::new();
+        let vq = harness.make_vsock_vq();
+        let log = test_logger();
+        let poller =
+            VsockPoller::new(guest_cid, vq, log, IdHashMap::new()).unwrap();
+
+        harness.add_rx_writable(256);
+
+        let notify = poller.notify_handle();
+        let handle = poller.run();
+
+        let mut hdr = VsockPacketHeader::new();
+        hdr.set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(5555)
+            .set_dst_port(8080)
+            .set_len(0)
+            .set_socket_type(VsockSocketType::InvalidTestValue)
+            .set_op(VsockPacketOp::ReadWrite)
+            .set_buf_alloc(65536)
+            .set_fwd_cnt(0);
+
+        let d_tx = harness.add_tx_readable(hdr_as_bytes(&hdr));
+        harness.publish_tx(d_tx);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        wait_for_condition(|| harness.rx_used_idx() >= 1, 5000);
+
+        let (resp_hdr, _) = harness.read_vsock_packet(0);
+        assert_eq!(resp_hdr.op(), Some(VsockPacketOp::Reset));
+        assert_eq!(resp_hdr.src_cid(), VSOCK_HOST_CID);
+        assert_eq!(resp_hdr.dst_cid(), guest_cid as u64);
+        assert_eq!(resp_hdr.src_port(), 8080);
+        assert_eq!(resp_hdr.dst_port(), 5555);
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+
+    #[test]
+    fn request_then_rw_delivers_data() {
+        let vsock_port = 3000;
+        let guest_port = 1234;
+        let guest_cid: u32 = 50;
+        let (listener, backends) = bind_test_backend(vsock_port);
+        listener.set_nonblocking(false).unwrap();
+
+        let mut harness = VsockTestHarness::new();
+        let vq = harness.make_vsock_vq();
+        let log = test_logger();
+        let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap();
+
+        for _ in 0..4 {
+            harness.add_rx_writable(4096);
+        }
+
+        let notify = poller.notify_handle();
+        let handle = poller.run();
+
+        // Send REQUEST
+        let mut req_hdr = VsockPacketHeader::new();
+        req_hdr
+            .set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(0)
+            .set_socket_type(VsockSocketType::Stream)
+            .set_op(VsockPacketOp::Request)
+            .set_buf_alloc(65536)
+            .set_fwd_cnt(0);
+
+        let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr));
+        harness.publish_tx(d_tx);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        // Accept TCP connection and wait for RESPONSE
+        let mut accepted = listener.accept().unwrap().0;
+        accepted.set_nonblocking(false).unwrap();
+        accepted.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
+        wait_for_condition(|| harness.rx_used_idx() >= 1, 5000);
+
+        // Send RW packet with data payload
+        let payload = b"hello from guest via vsock!";
+        let mut rw_hdr = VsockPacketHeader::new();
+        rw_hdr
+            .set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(payload.len() as u32)
+            .set_socket_type(VsockSocketType::Stream)
+            .set_op(VsockPacketOp::ReadWrite)
+            .set_buf_alloc(65536)
+            .set_fwd_cnt(0);
+
+        let d_hdr = harness.add_tx_readable(hdr_as_bytes(&rw_hdr));
+        let d_body = harness.add_tx_readable(payload);
+        harness.chain_tx(d_hdr, d_body);
+        harness.publish_tx(d_hdr);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        // Read from accepted TCP stream and verify
+        let mut buf = vec![0u8; payload.len()];
+        accepted.read_exact(&mut buf).unwrap();
+        assert_eq!(&buf, payload);
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+
+    #[test]
+    fn credit_update_sent_after_flushing_half_buffer() {
+        let vsock_port = 4000;
+        let guest_port = 2000;
+        let guest_cid: u32 = 50;
+        let (listener, backends) = bind_test_backend(vsock_port);
+        listener.set_nonblocking(false).unwrap();
+
+        let mut harness = VsockTestHarness::new();
+        let vq = harness.make_vsock_vq();
+        let log = test_logger();
+        let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap();
+
+        // Provide plenty of RX descriptors for RESPONSE + credit updates
+        for _ in 0..16 {
+            harness.add_rx_writable(4096);
+        }
+
+        let notify = poller.notify_handle();
+        let handle = poller.run();
+
+        // Establish connection
+        let mut req_hdr = VsockPacketHeader::new();
+        req_hdr
+            .set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(0)
+            .set_socket_type(VsockSocketType::Stream)
+            .set_op(VsockPacketOp::Request)
+            .set_buf_alloc(65536)
+            .set_fwd_cnt(0);
+
+        let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr));
+        harness.publish_tx(d_tx);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        let mut accepted = listener.accept().unwrap().0;
+        accepted.set_nonblocking(false).unwrap();
+        accepted.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
+        wait_for_condition(|| harness.rx_used_idx() >= 1, 5000);
+
+        // Send enough data to exceed half the buffer capacity (64 KiB).
+        let chunk_size = 8192;
+        let num_chunks = (CONN_TX_BUF_SIZE / 2) / chunk_size + 1;
+        let payload = vec![0xAB_u8; chunk_size];
+        let total_sent = num_chunks * chunk_size;
+        let mut tx_consumed = 1u16; // REQUEST was consumed
+
+        for _ in 0..num_chunks {
+            // Reuse descriptor slots each iteration
+            harness.reset_tx_cursors();
+
+            let mut rw_hdr = VsockPacketHeader::new();
+            rw_hdr
+                .set_src_cid(guest_cid)
+                .set_dst_cid(VSOCK_HOST_CID as u32)
+                .set_src_port(guest_port)
+                .set_dst_port(vsock_port)
+                .set_len(payload.len() as u32)
+                .set_socket_type(VsockSocketType::Stream)
+                .set_op(VsockPacketOp::ReadWrite)
+                .set_buf_alloc(65536)
+                .set_fwd_cnt(0);
+
+            let d_hdr = harness.add_tx_readable(hdr_as_bytes(&rw_hdr));
+            let d_body = harness.add_tx_readable(&payload);
+            harness.chain_tx(d_hdr, d_body);
+            harness.publish_tx(d_hdr);
+            notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+            tx_consumed += 1;
+            wait_for_condition(|| harness.tx_used_idx() >= tx_consumed, 5000);
+        }
+
+        // Drain the data from the accepted socket to confirm it arrived
+        let mut buf = vec![0u8; total_sent];
+        accepted.read_exact(&mut buf).unwrap();
+        assert!(buf.iter().all(|&b| b == 0xAB));
+
+        // Look for a CREDIT_UPDATE in the RX used entries
+        let rx_used = harness.rx_used_idx();
+        assert!(rx_used >= 2, "expected at least RESPONSE + CREDIT_UPDATE");
+
+        let mut found_credit_update = false;
+        for i in 1..rx_used {
+            let (hdr, _) = harness.read_vsock_packet(i);
+            if hdr.op() == Some(VsockPacketOp::CreditUpdate) {
+                assert_eq!(hdr.src_cid(), VSOCK_HOST_CID);
+                assert_eq!(hdr.dst_cid(), guest_cid as u64);
+                assert_eq!(hdr.src_port(), vsock_port);
+                assert_eq!(hdr.dst_port(), guest_port);
+                assert_eq!(hdr.buf_alloc(), CONN_TX_BUF_SIZE as u32);
+                found_credit_update = true;
+                break;
+            }
+        }
+        assert!(found_credit_update, "expected a CREDIT_UPDATE on RX queue");
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+
+    #[test]
+    fn rst_removes_established_connection() {
+        let vsock_port = 5000;
+        let guest_port = 3000;
+        let guest_cid: u32 = 50;
+        let (listener, backends) = bind_test_backend(vsock_port);
+        listener.set_nonblocking(false).unwrap();
+
+        let mut harness = VsockTestHarness::new();
+        let vq = harness.make_vsock_vq();
+        let log = test_logger();
+        let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap();
+
+        for _ in 0..4 {
+            harness.add_rx_writable(4096);
+        }
+
+        let notify = poller.notify_handle();
+        let handle = poller.run();
+
+        // Send REQUEST
+        let mut req_hdr = VsockPacketHeader::new();
+        req_hdr
+            .set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(0)
+            .set_socket_type(VsockSocketType::Stream)
+            .set_op(VsockPacketOp::Request)
+            .set_buf_alloc(65536)
+            .set_fwd_cnt(0);
+
+        let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr));
+        harness.publish_tx(d_tx);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        let mut accepted = listener.accept().unwrap().0;
+        accepted.set_nonblocking(false).unwrap();
+        accepted.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
+        wait_for_condition(|| harness.rx_used_idx() >= 1, 5000);
+
+        // Send RST
+        let mut rst_hdr = VsockPacketHeader::new();
+        rst_hdr
+            .set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(0)
+            .set_socket_type(VsockSocketType::Stream)
+            .set_op(VsockPacketOp::Reset)
+            .set_buf_alloc(0)
+            .set_fwd_cnt(0);
+
+        let d_rst = harness.add_tx_readable(hdr_as_bytes(&rst_hdr));
+        harness.publish_tx(d_rst);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        // Wait for the RST to be consumed
+        wait_for_condition(|| harness.tx_used_idx() >= 2, 5000);
+
+        // Verify the TCP connection was closed by reading from the
+        // accepted stream.
+        let mut buf = [0u8; 1];
+        let result = accepted.read(&mut buf);
+        match result {
+            Ok(0) => {}
+            Err(_) => {}
+            Ok(n) => panic!("expected EOF or error, got {n} bytes"),
+        }
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+
+    #[test]
+    fn end_to_end_guest_to_host() {
+        let vsock_port = 7000;
+        let guest_port = 5000;
+        let guest_cid: u32 = 50;
+        let (listener, backends) = bind_test_backend(vsock_port);
+        listener.set_nonblocking(false).unwrap();
+
+        let mut harness = VsockTestHarness::new();
+        let vq = harness.make_vsock_vq();
+        let log = test_logger();
+        let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap();
+
+        // Pre-populate RX queue with writable descriptors for RESPONSE + data
+        for _ in 0..8 {
+            harness.add_rx_writable(4096);
+        }
+
+        let notify = poller.notify_handle();
+        let handle = poller.run();
+
+        // Write REQUEST packet into TX queue
+        let mut req_hdr = VsockPacketHeader::new();
+        req_hdr
+            .set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(0)
+            .set_socket_type(VsockSocketType::Stream)
+            .set_op(VsockPacketOp::Request)
+            .set_buf_alloc(65536)
+            .set_fwd_cnt(0);
+
+        let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr));
+        harness.publish_tx(d_tx);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        // Accept the TCP connection (blocks until poller connects)
+        let mut accepted = listener.accept().unwrap().0;
+        accepted.set_nonblocking(false).unwrap();
+        accepted.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
+
+        // Wait for RESPONSE on RX queue
+        wait_for_condition(|| harness.rx_used_idx() >= 1, 5000);
+
+        // Guest->Host: send RW packet with payload
+        let payload = b"hello from guest via vsock end-to-end!";
+        let mut rw_hdr = VsockPacketHeader::new();
+        rw_hdr
+            .set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(payload.len() as u32)
+            .set_socket_type(VsockSocketType::Stream)
+            .set_op(VsockPacketOp::ReadWrite)
+            .set_buf_alloc(65536)
+            .set_fwd_cnt(0);
+
+        let d_hdr = harness.add_tx_readable(hdr_as_bytes(&rw_hdr));
+        let d_body = harness.add_tx_readable(payload);
+        harness.chain_tx(d_hdr, d_body);
+        harness.publish_tx(d_hdr);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        // Read from accepted TCP stream, and verify guest->host data
+        let mut buf = vec![0u8; payload.len()];
+        accepted.read_exact(&mut buf).unwrap();
+        assert_eq!(&buf, payload, "guest->host data mismatch");
+
+        // Host->Guest: write data into accepted TCP stream
+        let host_payload = b"reply from host via vsock!";
+        accepted.write_all(host_payload).unwrap();
+        accepted.flush().unwrap();
+
+        // Wait for RW packet on RX queue (RESPONSE was 1, now expect 2+)
+        wait_for_condition(|| harness.rx_used_idx() >= 2, 5000);
+
+        // Read back the RW packet from RX used ring entry 1
+        let (resp_hdr, host_buf) = harness.read_vsock_packet(1);
+
+        assert_eq!(resp_hdr.op(), Some(VsockPacketOp::ReadWrite));
+        assert_eq!(resp_hdr.src_port(), vsock_port);
+        assert_eq!(resp_hdr.dst_port(), guest_port);
+        assert_eq!(&host_buf, host_payload, "host->guest data mismatch");
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+
+    #[test]
+    fn rx_blocked_resumes_when_descriptors_available() {
+        let vsock_port = 6000;
+        let guest_port = 4000;
+        let guest_cid: u32 = 50;
+        let (listener, backends) = bind_test_backend(vsock_port);
+        listener.set_nonblocking(false).unwrap();
+
+        let mut harness = VsockTestHarness::new();
+        let vq = harness.make_vsock_vq();
+        let log = test_logger();
+        let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap();
+
+        // Provide only one RX descriptor, just enough for the RESPONSE.
+        harness.add_rx_writable(4096);
+
+        let notify = poller.notify_handle();
+        let handle = poller.run();
+
+        // Send REQUEST
+        let mut req_hdr = VsockPacketHeader::new();
+        req_hdr
+            .set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(0)
+            .set_socket_type(VsockSocketType::Stream)
+            .set_op(VsockPacketOp::Request)
+            .set_buf_alloc(65536)
+            .set_fwd_cnt(0);
+
+        let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr));
+        harness.publish_tx(d_tx);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        let mut accepted = listener.accept().unwrap().0;
+        accepted.set_nonblocking(false).unwrap();
+        wait_for_condition(|| harness.rx_used_idx() >= 1, 5000);
+
+        // The RESPONSE consumed the only RX descriptor. Write data from
+        // the host side.
+        let host_data = b"data from the host side";
+        accepted.write_all(host_data).unwrap();
+        accepted.flush().unwrap();
+
+        // Give the poller time to attempt delivery (and get blocked)
+        std::thread::sleep(Duration::from_millis(100));
+
+        // Verify no new used entries appeared (still just the RESPONSE)
+        assert_eq!(harness.rx_used_idx(), 1);
+
+        // Add new RX descriptors and notify
+        harness.reset_rx_cursors();
+        harness.add_rx_writable(4096);
+        notify.queue_notify(VSOCK_RX_QUEUE).unwrap();
+
+        // Wait for the data to be delivered
+        wait_for_condition(|| harness.rx_used_idx() >= 2, 5000);
+
+        let (rw_hdr, payload) = harness.read_vsock_packet(1);
+        assert_eq!(rw_hdr.op(), Some(VsockPacketOp::ReadWrite));
+        assert_eq!(rw_hdr.src_port(), vsock_port);
+        assert_eq!(rw_hdr.dst_port(), guest_port);
+        assert_eq!(&payload, host_data);
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+
+    /// End-to-end test with large data transfers in both directions,
+    /// exercising rx_blocked, credit updates, and descriptor replenishment
+    /// across many batches of reused descriptor slots.
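+    ///
+    /// The transfer runs in two phases: guest->host RW packets drained by a
+    /// reader thread, then host->guest data pushed by a writer thread while
+    /// the main thread replenishes RX descriptors.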
+    #[test]
+    fn end_to_end_large_data() {
+        let total_bytes: usize = 10 * 1024 * 1024;
+
+        let vsock_port = 8000;
+        let guest_port = 6000;
+        let guest_cid: u32 = 50;
+        let (listener, backends) = bind_test_backend(vsock_port);
+        listener.set_nonblocking(false).unwrap();
+
+        let mut harness = VsockTestHarness::new();
+        let vq = harness.make_vsock_vq();
+        let log = test_logger();
+        let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap();
+
+        // Provide initial RX descriptors for RESPONSE + credit updates
+        for _ in 0..8 {
+            harness.add_rx_writable(4096);
+        }
+
+        let notify = poller.notify_handle();
+        let handle = poller.run();
+
+        // Establish connection.
+        // Use a large buf_alloc so host->guest credit doesn't run out
+        // before we've transferred all the data.
+        let buf_alloc = total_bytes as u32 * 2;
+
+        let mut req_hdr = VsockPacketHeader::new();
+        req_hdr
+            .set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(0)
+            .set_socket_type(VsockSocketType::Stream)
+            .set_op(VsockPacketOp::Request)
+            .set_buf_alloc(buf_alloc)
+            .set_fwd_cnt(0);
+
+        let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr));
+        harness.publish_tx(d_tx);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        let accepted = listener.accept().unwrap().0;
+        accepted.set_nonblocking(false).unwrap();
+
+        wait_for_condition(|| harness.rx_used_idx() >= 1, 5000);
+
+        // A reader thread drains the TCP socket while the main thread
+        // injects RW packets in batches, reusing descriptor slots and
+        // guest memory between batches.
+        let guest_data: Vec<u8> =
+            (0..total_bytes).map(|i| (i % 251) as u8).collect();
+
+        // Track how many bytes the reader has consumed so we can apply
+        // backpressure and avoid overflowing the poller's VsockBuf.
+        let bytes_read = Arc::new(AtomicUsize::new(0));
+        let tcp_reader = {
+            let mut stream = accepted.try_clone().unwrap();
+            let len = total_bytes;
+            let progress = Arc::clone(&bytes_read);
+            std::thread::spawn(move || {
+                let mut result = Vec::with_capacity(len);
+                let mut chunk = vec![0u8; 65536];
+                let mut total = 0;
+                while total < len {
+                    let n = stream.read(&mut chunk).unwrap();
+                    assert!(n > 0, "unexpected EOF after {total}/{len}");
+                    result.extend_from_slice(&chunk[..n]);
+                    total += n;
+                    progress.store(total, Ordering::Release);
+                }
+                result
+            })
+        };
+
+        let chunk_size = 4096;
+        let batch_packets = 8; // 8 packets × 2 descs = 16 descs per batch
+        let mut guest_sent = 0usize;
+        // TX used_idx starts at 1 (the REQUEST was consumed)
+        let mut tx_consumed = 1u16;
+
+        while guest_sent < total_bytes {
+            let remaining = (total_bytes - guest_sent).div_ceil(chunk_size);
+            let this_batch = std::cmp::min(batch_packets, remaining);
+            // Backpressure: don't let in-flight data exceed VsockBuf
+            // capacity. The poller buffers TX data in VsockBuf (128 KiB)
+            // and flushes via POLLOUT. If we push faster than the
+            // flush rate, the buffer overflows and panics.
+            let after_send = guest_sent + this_batch * chunk_size;
+            loop {
+                let read = bytes_read.load(Ordering::Acquire);
+                if after_send <= read + CONN_TX_BUF_SIZE {
+                    break;
+                }
+                std::thread::sleep(Duration::from_millis(1));
+            }
+
+            // Reuse the same descriptor slots and data region each batch.
+            // Safe because we wait for the previous batch to be fully
+            // consumed before overwriting.
+            harness.reset_tx_cursors();
+
+            for i in 0..this_batch {
+                let offset = guest_sent + i * chunk_size;
+                let end = std::cmp::min(offset + chunk_size, total_bytes);
+                let payload = &guest_data[offset..end];
+
+                let mut rw_hdr = VsockPacketHeader::new();
+                rw_hdr
+                    .set_src_cid(guest_cid)
+                    .set_dst_cid(VSOCK_HOST_CID as u32)
+                    .set_src_port(guest_port)
+                    .set_dst_port(vsock_port)
+                    .set_len(payload.len() as u32)
+                    .set_socket_type(VsockSocketType::Stream)
+                    .set_op(VsockPacketOp::ReadWrite)
+                    .set_buf_alloc(buf_alloc)
+                    .set_fwd_cnt(0);
+
+                let d_hdr = harness.add_tx_readable(hdr_as_bytes(&rw_hdr));
+                let d_body = harness.add_tx_readable(payload);
+                harness.chain_tx(d_hdr, d_body);
+                harness.publish_tx(d_hdr);
+            }
+
+            notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+            // Wait for the poller to consume this entire batch before
+            // we overwrite the descriptor slots in the next iteration.
+            tx_consumed += this_batch as u16;
+            wait_for_condition(|| harness.tx_used_idx() >= tx_consumed, 10000);
+
+            guest_sent += this_batch * chunk_size;
+            if guest_sent > total_bytes {
+                guest_sent = total_bytes;
+            }
+        }
+
+        let received = tcp_reader.join().unwrap();
+        assert_eq!(received.len(), total_bytes);
+        assert!(received == guest_data, "guest->host data mismatch");
+
+        // A writer thread pushes data into the TCP socket while the
+        // main thread replenishes RX descriptors in batches, reads
+        // completed used entries, and reuses descriptor slots once
+        // the entire batch has been consumed.
+        let host_data: Vec<u8> =
+            (0..total_bytes).map(|i| ((i + 7) % 251) as u8).collect();
+
+        let tcp_writer = {
+            let mut stream = accepted.try_clone().unwrap();
+            let data = host_data.clone();
+            std::thread::spawn(move || {
+                stream.write_all(&data).unwrap();
+            })
+        };
+
+        let mut host_to_guest = Vec::with_capacity(total_bytes);
+
+        // Skip all used entries produced before this phase (RESPONSE +
+        // any credit updates from the guest->host phase).
+        let mut rx_next_used = harness.rx_used_idx();
+        let rx_batch = 16u16;
+        let mut descs_outstanding = 0u16;
+
+        while host_to_guest.len() < total_bytes {
+            // When all outstanding descriptors have been consumed we can
+            // safely reuse the descriptor slots and data region.
+            if descs_outstanding == 0 {
+                harness.reset_rx_cursors();
+
+                for _ in 0..rx_batch {
+                    harness.add_rx_writable(4096);
+                    descs_outstanding += 1;
+                }
+                notify.queue_notify(VSOCK_RX_QUEUE).unwrap();
+            }
+
+            // Wait for at least one new used entry.
+            wait_for_condition(|| harness.rx_used_idx() > rx_next_used, 10000);
+
+            // Drain all currently available used entries.
+            let current_used = harness.rx_used_idx();
+            while rx_next_used < current_used {
+                let (hdr, data) = harness.read_vsock_packet(rx_next_used);
+                rx_next_used += 1;
+                descs_outstanding -= 1;
+
+                if hdr.op() == Some(VsockPacketOp::ReadWrite) {
+                    host_to_guest.extend_from_slice(&data);
+                }
+                // Credit updates and other control packets are silently
+                // consumed; they're expected here.
+            }
+        }
+
+        tcp_writer.join().unwrap();
+        assert_eq!(host_to_guest.len(), total_bytes);
+        assert!(host_to_guest == host_data, "host->guest data mismatch");
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+
+    /// Closing the host-side TCP socket should cause the poller to send
+    /// a VIRTIO_VSOCK_OP_SHUTDOWN packet with VIRTIO_VSOCK_SHUTDOWN_F_SEND
+    /// to the guest, indicating the host will no longer send data.
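+    ///
+    /// (The poller currently raises both F_SEND and F_RECEIVE in the
+    /// SHUTDOWN packet, and the assertion below checks for both flags.)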
+    #[test]
+    fn host_socket_eof_sends_shutdown() {
+        let vsock_port = 9000;
+        let guest_port = 7000;
+        let guest_cid: u32 = 50;
+        let (listener, backends) = bind_test_backend(vsock_port);
+        listener.set_nonblocking(false).unwrap();
+
+        let mut harness = VsockTestHarness::new();
+        let vq = harness.make_vsock_vq();
+        let log = test_logger();
+        let poller = VsockPoller::new(guest_cid, vq, log, backends).unwrap();
+
+        // Provide RX descriptors for RESPONSE + SHUTDOWN
+        for _ in 0..4 {
+            harness.add_rx_writable(4096);
+        }
+
+        let notify = poller.notify_handle();
+        let handle = poller.run();
+
+        // Establish connection
+        let mut req_hdr = VsockPacketHeader::new();
+        req_hdr
+            .set_src_cid(guest_cid)
+            .set_dst_cid(VSOCK_HOST_CID as u32)
+            .set_src_port(guest_port)
+            .set_dst_port(vsock_port)
+            .set_len(0)
+            .set_socket_type(VsockSocketType::Stream)
+            .set_op(VsockPacketOp::Request)
+            .set_buf_alloc(65536)
+            .set_fwd_cnt(0);
+
+        let d_tx = harness.add_tx_readable(hdr_as_bytes(&req_hdr));
+        harness.publish_tx(d_tx);
+        notify.queue_notify(VSOCK_TX_QUEUE).unwrap();
+
+        // Accept the connection, wait for RESPONSE
+        let accepted = listener.accept().unwrap().0;
+        wait_for_condition(|| harness.rx_used_idx() >= 1, 5000);
+
+        // Close the host-side socket to produce EOF
+        drop(accepted);
+
+        // The poller should detect EOF on the next POLLIN and send
+        // a SHUTDOWN packet to the guest.
+        wait_for_condition(|| harness.rx_used_idx() >= 2, 5000);
+
+        // Read back the packet from RX used ring entry 1
+        let (hdr, _data) = harness.read_vsock_packet(1);
+
+        assert_eq!(hdr.op(), Some(VsockPacketOp::Shutdown));
+        assert_eq!(hdr.src_cid(), VSOCK_HOST_CID);
+        assert_eq!(hdr.dst_cid(), guest_cid as u64);
+        assert_eq!(hdr.src_port(), vsock_port);
+        assert_eq!(hdr.dst_port(), guest_port);
+        assert_eq!(
+            hdr.flags(),
+            VsockPacketFlags::VIRTIO_VSOCK_SHUTDOWN_F_SEND
+                | VsockPacketFlags::VIRTIO_VSOCK_SHUTDOWN_F_RECEIVE
+        );
+
+        notify.shutdown().unwrap();
+        handle.join().unwrap();
+    }
+}
diff --git a/lib/propolis/src/vsock/poller_stub.rs b/lib/propolis/src/vsock/poller_stub.rs
new file mode 100644
index 000000000..71bcdd1ee
--- /dev/null
+++ b/lib/propolis/src/vsock/poller_stub.rs
@@ -0,0 +1,54 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+use std::thread::JoinHandle;
+
+use iddqd::IdHashMap;
+use slog::Logger;
+
+use crate::hw::virtio::vsock::VsockVq;
+use crate::vsock::proxy::VsockPortMapping;
+
+bitflags! {
+    pub struct PollEvents: i32 {
+        const IN = libc::POLLIN as i32;
+        const OUT = libc::POLLOUT as i32;
+    }
+}
+
+pub struct VsockPollerNotify;
+
+impl VsockPollerNotify {
+    pub fn queue_notify(&self, _id: u16) -> std::io::Result<()> {
+        Err(std::io::Error::other(
+            "not available on non-illumos systems",
+        ))
+    }
+}
+
+pub struct VsockPoller;
+
+impl VsockPoller {
+    pub fn new(
+        _cid: u32,
+        _queues: VsockVq,
+        _log: Logger,
+        _port_mappings: IdHashMap<VsockPortMapping>,
+    ) -> std::io::Result<Self> {
+        Err(std::io::Error::other(
+            "VsockPoller is not available on non-illumos systems",
+        ))
+    }
+
+    pub fn notify_handle(&self) -> VsockPollerNotify {
+        VsockPollerNotify {}
+    }
+
+    pub fn run(self) -> JoinHandle<()> {
+        std::thread::Builder::new()
+            .name("vsock-event-loop".to_string())
+            .spawn(move || {})
+            .expect("failed to spawn vsock event loop")
+    }
+}
diff --git a/lib/propolis/src/vsock/proxy.rs b/lib/propolis/src/vsock/proxy.rs
new file mode 100644
index 000000000..e9e87459b
--- /dev/null
+++ b/lib/propolis/src/vsock/proxy.rs
@@ -0,0 +1,362 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+use std::net::SocketAddr;
+use std::net::TcpStream;
+use std::num::NonZeroUsize;
+use std::num::Wrapping;
+use std::os::fd::AsRawFd;
+use std::os::fd::RawFd;
+use std::thread::JoinHandle;
+use std::time::Duration;
+
+use iddqd::IdHashItem;
+use iddqd::IdHashMap;
+use nix::poll::PollFlags;
+use serde::Deserialize;
+use slog::error;
+use slog::Logger;
+
+use crate::hw::virtio::vsock::VsockVq;
+use crate::vsock::buffer::VsockBuf;
+use crate::vsock::buffer::VsockBufError;
+use crate::vsock::packet::VsockPacket;
+use crate::vsock::packet::VsockPacketHeader;
+use crate::vsock::poller::VsockPoller;
+use crate::vsock::poller::VsockPollerNotify;
+use crate::vsock::VsockBackend;
+use crate::vsock::VsockError;
+
+/// Default buffer size for guest->host data.
+pub const CONN_TX_BUF_SIZE: usize = 1024 * 128;
+
+/// Connection lifecycle state for a vsock connection.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ConnState {
+    /// The guest has sent us a VIRTIO_VSOCK_OP_REQUEST
+    Init,
+    /// We have sent VIRTIO_VSOCK_OP_RESPONSE - connection can send/recv data
+    Established,
+    /// The connection is in the process of closing - read and write halves
+    /// are tracked separately
+    Closing {
+        read: bool,
+        write: bool,
+    },
+}
+
+#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)]
+pub struct ConnKey {
+    /// The port the guest is transmitting to.
+    pub(crate) host_port: u32,
+    /// The port the guest is transmitting from.
+    pub(crate) guest_port: u32,
+}
+
+// This impl allows us to convert to and from a portev_user object (see
+// port_associate(3C)). The conversion to and from a usize allows us to encode
+// the key in the pointer value itself rather than allocating memory.
+//
+// NB: This object is defined as a `*mut c_void` and therefore will not be
+// 64 bits on all platforms, but we currently only support x86_64 hardware,
+// therefore we are leaving a static assertion behind as a future hint to
+// ourselves.
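+//
+// For example, host_port 0x0000_0BB8 (3000) and guest_port 0x0000_04D2
+// (1234) pack into the usize 0x0000_0BB8_0000_04D2, and unpack back to the
+// same pair.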
+impl ConnKey {
+    /// Pack the host and guest ports into a usize
+    pub fn to_portev_user(self) -> usize {
+        static_assertions::assert_eq_size!(u64, usize);
+        ((self.host_port as usize) << 32) | (self.guest_port as usize)
+    }
+
+    /// Unpack the host and guest ports from a usize
+    pub fn from_portev_user(val: usize) -> Self {
+        Self { host_port: (val >> 32) as u32, guest_port: val as u32 }
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum ProxyConnError {
+    #[error("Failed to connect to vsock backend {backend}: {source}")]
+    Socket {
+        backend: SocketAddr,
+        #[source]
+        source: std::io::Error,
+    },
+    #[error("Failed to put socket into nonblocking mode: {0}")]
+    NonBlocking(#[source] std::io::Error),
+    #[error("Cannot transition connection from {from:?} to {to:?}")]
+    InvalidStateTransition { from: ConnState, to: ConnState },
+}
+
+/// An established guest<=>host connection.
+///
+/// Note that the internal state of the proxy connection uses `Wrapping`
+/// because the virtio spec uses the following calculation to determine how
+/// much buffer space a guest has:
+///
+/// /* tx_cnt is the sender's free-running bytes transmitted counter */
+/// u32 peer_free = peer_buf_alloc - (tx_cnt - peer_fwd_cnt);
+///
+/// The lifetime of a connection can exceed u32::MAX bytes transmitted, so we
+/// rely on wrapping semantics to determine the difference.
+#[derive(Debug)]
+pub struct VsockProxyConn {
+    pub(crate) socket: TcpStream,
+    /// Current connection state.
+    state: ConnState,
+    /// Ring buffer used to receive packets from the guest tx virt queue.
+    vbuf: VsockBuf,
+    /// Bytes we've consumed from vbuf (forwarded to socket).
+    fwd_cnt: Wrapping<u32>,
+    /// The fwd_cnt value we last sent to the guest in a credit update.
+    last_fwd_cnt_sent: Wrapping<u32>,
+    /// Bytes we've sent to the guest from the socket.
+    tx_cnt: Wrapping<u32>,
+    /// Guest's buffer allocation.
+    peer_buf_alloc: u32,
+    /// Bytes the guest has consumed from their buffer.
+    peer_fwd_cnt: Wrapping<u32>,
+}
+
+impl VsockProxyConn {
+    /// Create a new `VsockProxyConn` connected to an underlying host socket.
+    pub fn new(addr: &SocketAddr) -> Result<Self, ProxyConnError> {
+        let socket =
+            TcpStream::connect_timeout(addr, Duration::from_millis(100))
+                .map_err(|e| ProxyConnError::Socket {
+                    backend: *addr,
+                    source: e,
+                })?;
+        socket.set_nonblocking(true).map_err(ProxyConnError::NonBlocking)?;
+
+        Ok(Self {
+            socket,
+            state: ConnState::Init,
+            vbuf: VsockBuf::new(NonZeroUsize::new(CONN_TX_BUF_SIZE).unwrap()),
+            fwd_cnt: Wrapping(0),
+            last_fwd_cnt_sent: Wrapping(0),
+            tx_cnt: Wrapping(0),
+            peer_buf_alloc: 0,
+            peer_fwd_cnt: Wrapping(0),
+        })
+    }
+
+    /// Set of `PollFlags` that this connection is interested in.
+    pub fn poll_interests(&self) -> Option<PollFlags> {
+        let mut interests = PollFlags::empty();
+        interests.set(PollFlags::POLLOUT, self.has_buffered_data());
+        interests.set(PollFlags::POLLIN, self.guest_can_read());
+
+        match interests.is_empty() {
+            true => None,
+            false => Some(interests),
+        }
+    }
+
+    /// Returns `true` if the connection has data pending in its ring buffer
+    /// that needs to be flushed to the underlying socket.
+    pub fn has_buffered_data(&self) -> bool {
+        !self.vbuf.is_empty()
+    }
+
+    /// Set the connection to established.
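+    ///
+    /// Only valid from `ConnState::Init`; any other starting state yields an
+    /// `InvalidStateTransition` error.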
+    pub fn set_established(&mut self) -> Result<(), ProxyConnError> {
+        match self.state {
+            ConnState::Init => self.state = ConnState::Established,
+            current => {
+                return Err(ProxyConnError::InvalidStateTransition {
+                    from: current,
+                    to: ConnState::Established,
+                })
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Check if the guest can still receive data read from the host socket.
+    pub fn guest_can_read(&self) -> bool {
+        matches!(
+            self.state,
+            ConnState::Established | ConnState::Closing { read: false, .. }
+        )
+    }
+
+    pub fn shutdown_guest_read(&mut self) -> Result<(), ProxyConnError> {
+        self.state = match self.state {
+            ConnState::Established => {
+                ConnState::Closing { read: true, write: false }
+            }
+            ConnState::Closing { write, .. } => {
+                ConnState::Closing { read: true, write }
+            }
+            current => {
+                return Err(ProxyConnError::InvalidStateTransition {
+                    from: current,
+                    to: ConnState::Closing { read: true, write: false },
+                })
+            }
+        };
+
+        Ok(())
+    }
+
+    pub fn shutdown_guest_write(&mut self) -> Result<(), ProxyConnError> {
+        self.state = match self.state {
+            ConnState::Established => {
+                ConnState::Closing { read: false, write: true }
+            }
+            ConnState::Closing { read, .. } => {
+                ConnState::Closing { read, write: true }
+            }
+            current => {
+                return Err(ProxyConnError::InvalidStateTransition {
+                    from: current,
+                    to: ConnState::Closing { read: false, write: true },
+                })
+            }
+        };
+
+        Ok(())
+    }
+
+    /// Check if the connection should be removed.
+    pub fn should_close(&self) -> bool {
+        matches!(self.state, ConnState::Closing { read: true, write: true })
+    }
+
+    /// Update peer credit info from a packet header.
+    pub fn update_peer_credit(&mut self, header: &VsockPacketHeader) {
+        self.peer_buf_alloc = header.buf_alloc();
+        self.peer_fwd_cnt = Wrapping(header.fwd_cnt());
+    }
+
+    /// Process a packet received from the guest tx queue.
+    pub fn recv_packet(
+        &mut self,
+        packet: VsockPacket,
+    ) -> Result<(), VsockBufError> {
+        self.vbuf.push(packet.data)
+    }
+
+    /// Flush buffered guest data to the host socket, returning the number of
+    /// bytes written.
+    pub fn flush(&mut self) -> std::io::Result<usize> {
+        self.vbuf.write_to(&mut self.socket)
+    }
+
+    /// Calculate how much data we can send to the guest based on their
+    /// credit.
+    pub fn peer_credit(&self) -> u32 {
+        let in_flight = (self.tx_cnt - self.peer_fwd_cnt).0;
+        self.peer_buf_alloc.saturating_sub(in_flight)
+    }
+
+    /// Update fwd_cnt after consuming data from vbuf.
+    pub fn update_fwd_cnt(&mut self, bytes: u32) {
+        self.fwd_cnt += Wrapping(bytes);
+    }
+
+    /// Update tx_cnt after sending data to guest.
+    pub fn update_tx_cnt(&mut self, bytes: u32) {
+        self.tx_cnt += Wrapping(bytes);
+    }
+
+    /// Get our current fwd_cnt to report to the guest.
+    pub fn fwd_cnt(&self) -> u32 {
+        self.fwd_cnt.0
+    }
+
+    /// Get our buffer allocation to report to the guest.
+    pub fn buf_alloc(&self) -> u32 {
+        self.vbuf.capacity() as u32
+    }
+
+    /// Check if we should send a credit update to the guest.
+    ///
+    /// Returns true if we've consumed more than half of our buffer capacity
+    /// since the last credit update was sent.
+    pub fn needs_credit_update(&self) -> bool {
+        let bytes_consumed_since_update =
+            (self.fwd_cnt - self.last_fwd_cnt_sent).0;
+        bytes_consumed_since_update > (self.vbuf.capacity() / 2) as u32
+    }
+
+    /// Mark that we've sent a credit update with the current fwd_cnt.
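+    ///
+    /// The poller calls this after writing a CREDIT_UPDATE packet to the rx
+    /// queue so that `needs_credit_update` measures from a fresh baseline.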
+    pub fn mark_credit_sent(&mut self) {
+        self.last_fwd_cnt_sent = self.fwd_cnt;
+    }
+
+    pub fn get_fd(&self) -> RawFd {
+        self.socket.as_raw_fd()
+    }
+}
+
+#[derive(Deserialize, Debug, Clone, Copy)]
+pub struct VsockPortMapping {
+    port: u32,
+    // TODO this could be extended to support Unix sockets as well.
+    addr: SocketAddr,
+}
+
+impl VsockPortMapping {
+    pub fn new(port: u32, addr: SocketAddr) -> Self {
+        Self { port, addr }
+    }
+
+    pub fn addr(&self) -> &SocketAddr {
+        &self.addr
+    }
+}
+
+impl IdHashItem for VsockPortMapping {
+    type Key<'a> = u32;
+
+    fn key(&self) -> Self::Key<'_> {
+        self.port
+    }
+
+    iddqd::id_upcast!();
+}
+
+/// virtio-socket backend that proxies between a guest and host TCP sockets.
+pub struct VsockProxy {
+    log: Logger,
+    poller: VsockPollerNotify,
+    _evloop_handle: JoinHandle<()>,
+}
+
+impl VsockProxy {
+    pub fn new(
+        cid: u32,
+        queues: VsockVq,
+        log: Logger,
+        port_mappings: IdHashMap<VsockPortMapping>,
+    ) -> Self {
+        let evloop =
+            VsockPoller::new(cid, queues, log.clone(), port_mappings).unwrap();
+        let poller = evloop.notify_handle();
+        let jh = evloop.run();
+
+        Self { log, poller, _evloop_handle: jh }
+    }
+
+    /// Notification from the vsock device that one of the queues has had an
+    /// event.
+    fn queue_notify(&self, vq_id: u16) -> std::io::Result<()> {
+        self.poller.queue_notify(vq_id)
+    }
+}
+
+impl VsockBackend for VsockProxy {
+    fn queue_notify(&self, queue_id: u16) -> Result<(), VsockError> {
+        self.queue_notify(queue_id)
+            // Log the raw error in addition to returning the top level
+            // `VsockError`
+            .inspect_err(|_e| {
+                error!(&self.log,
+                    "failed to send virtqueue notification";
+                    "queue" => %queue_id,
+                )
+            })
+            .map_err(|_| VsockError::QueueNotify { queue: queue_id })
+    }
+}