diff --git a/Cargo.lock b/Cargo.lock index 61c277226dd..0c7b686cdf5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10067,7 +10067,9 @@ dependencies = [ "fastlanes", "mimalloc", "parquet 58.0.0", + "paste", "rand 0.10.0", + "rand_distr 0.6.0", "serde_json", "tokio", "tracing", @@ -10258,6 +10260,7 @@ dependencies = [ "vortex-runend", "vortex-sequence", "vortex-sparse", + "vortex-tensor", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10960,14 +10963,20 @@ dependencies = [ name = "vortex-tensor" version = "0.1.0" dependencies = [ + "half", "itertools 0.14.0", "num-traits", "prost 0.14.3", + "rand 0.10.0", + "rand_distr 0.6.0", "rstest", "vortex-array", "vortex-buffer", + "vortex-compressor", "vortex-error", + "vortex-fastlanes", "vortex-session", + "vortex-utils", ] [[package]] diff --git a/_typos.toml b/_typos.toml index e9cf23d68b7..62c3b0d6358 100644 --- a/_typos.toml +++ b/_typos.toml @@ -1,5 +1,5 @@ [default] -extend-ignore-identifiers-re = ["ffor", "FFOR", "FoR", "typ", "ratatui"] +extend-ignore-identifiers-re = ["ffor", "FFOR", "FoR", "typ", "ratatui", "wht", "WHT"] # We support a few common special comments to tell the checker to ignore sections of code extend-ignore-re = [ "(#|//)\\s*spellchecker:ignore-next-line\\n.*", # Ignore the next line diff --git a/vortex-btrblocks/Cargo.toml b/vortex-btrblocks/Cargo.toml index 9bbd2430f09..8906fd24d2e 100644 --- a/vortex-btrblocks/Cargo.toml +++ b/vortex-btrblocks/Cargo.toml @@ -35,6 +35,7 @@ vortex-pco = { workspace = true, optional = true } vortex-runend = { workspace = true } vortex-sequence = { workspace = true } vortex-sparse = { workspace = true } +vortex-tensor = { workspace = true, optional = true } vortex-utils = { workspace = true } vortex-zigzag = { workspace = true } vortex-zstd = { workspace = true, optional = true } @@ -47,7 +48,7 @@ vortex-array = { workspace = true, features = ["_test-harness"] } [features] # This feature enabled unstable encodings for which we don't guarantee stability. 
-unstable_encodings = ["vortex-zstd?/unstable_encodings"] +unstable_encodings = ["dep:vortex-tensor", "vortex-zstd?/unstable_encodings"] pco = ["dep:pco", "dep:vortex-pco"] zstd = ["dep:vortex-zstd"] diff --git a/vortex-btrblocks/src/builder.rs b/vortex-btrblocks/src/builder.rs index 3ff8e872b19..ee1707c8961 100644 --- a/vortex-btrblocks/src/builder.rs +++ b/vortex-btrblocks/src/builder.rs @@ -138,6 +138,22 @@ impl BtrBlocksCompressorBuilder { builder } + /// Adds the TurboQuant lossy vector quantization scheme. + /// + /// When enabled, [`Vector`] extension arrays are compressed using the TurboQuant algorithm + /// with MSE-optimal scalar quantization. + /// + /// # Panics + /// + /// Panics if the TurboQuant scheme is already present. + /// + /// [`Vector`]: vortex_tensor::vector::Vector + #[cfg(feature = "unstable_encodings")] + pub fn with_turboquant(self) -> Self { + use vortex_tensor::encodings::turboquant::TurboQuantScheme; + self.with_new_scheme(&TurboQuantScheme) + } + /// Excludes schemes without CUDA kernel support and adds Zstd for string compression. 
/// /// With the `unstable_encodings` feature, buffer-level Zstd compression is used which diff --git a/vortex-file/src/strategy.rs b/vortex-file/src/strategy.rs index 197efd9583f..9af6c1e9402 100644 --- a/vortex-file/src/strategy.rs +++ b/vortex-file/src/strategy.rs @@ -56,6 +56,8 @@ use vortex_pco::Pco; use vortex_runend::RunEnd; use vortex_sequence::Sequence; use vortex_sparse::Sparse; +#[cfg(feature = "unstable_encodings")] +use vortex_tensor::encodings::turboquant::TurboQuant; use vortex_utils::aliases::hash_map::HashMap; use vortex_zigzag::ZigZag; #[cfg(feature = "zstd")] @@ -104,6 +106,8 @@ pub static ALLOWED_ENCODINGS: LazyLock = LazyLock::new(|| { session.register(RunEnd); session.register(Sequence); session.register(Sparse); + #[cfg(feature = "unstable_encodings")] + session.register(TurboQuant); session.register(ZigZag); #[cfg(feature = "zstd")] diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index f0b6670cc51..9f94a0c2d3d 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -19,13 +19,18 @@ workspace = true [dependencies] vortex-array = { workspace = true } vortex-buffer = { workspace = true } +vortex-compressor = { workspace = true } vortex-error = { workspace = true } +vortex-fastlanes = { workspace = true } vortex-session = { workspace = true } +vortex-utils = { workspace = true } +half = { workspace = true } itertools = { workspace = true } num-traits = { workspace = true } prost = { workspace = true } +rand = { workspace = true } [dev-dependencies] +rand_distr = { workspace = true } rstest = { workspace = true } -vortex-buffer = { workspace = true } diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index cea02b69e38..5ab6d5c7e0d 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -2,6 +2,194 @@ pub mod vortex_tensor pub mod vortex_tensor::encodings +pub mod vortex_tensor::encodings::turboquant + +pub struct vortex_tensor::encodings::turboquant::TurboQuant + 
+impl vortex_tensor::encodings::turboquant::TurboQuant + +pub const vortex_tensor::encodings::turboquant::TurboQuant::ID: vortex_array::array::ArrayId + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<&vortex_array::dtype::extension::erased::ExtDTypeRef> + +impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuant + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuant + +impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuant + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::array::vtable::VTable for vortex_tensor::encodings::turboquant::TurboQuant + +pub type vortex_tensor::encodings::turboquant::TurboQuant::ArrayData = vortex_tensor::encodings::turboquant::TurboQuantData + +pub type vortex_tensor::encodings::turboquant::TurboQuant::Metadata = vortex_tensor::encodings::turboquant::TurboQuantMetadata + +pub type vortex_tensor::encodings::turboquant::TurboQuant::OperationsVTable = vortex_tensor::encodings::turboquant::TurboQuant + +pub type vortex_tensor::encodings::turboquant::TurboQuant::ValidityVTable = vortex_array::array::vtable::validity::ValidityVTableFromChild + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::array_eq(array: &vortex_tensor::encodings::turboquant::TurboQuantData, other: &vortex_tensor::encodings::turboquant::TurboQuantData, precision: vortex_array::hash::Precision) -> bool + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::array_hash(array: &vortex_tensor::encodings::turboquant::TurboQuantData, state: &mut H, precision: vortex_array::hash::Precision) + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::buffer(_array: vortex_array::array::view::ArrayView<'_, Self>, idx: usize) -> vortex_array::buffer::BufferHandle + +pub fn 
vortex_tensor::encodings::turboquant::TurboQuant::buffer_name(_array: vortex_array::array::view::ArrayView<'_, Self>, _idx: usize) -> core::option::Option + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::build(dtype: &vortex_array::dtype::DType, len: usize, metadata: &Self::Metadata, _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::deserialize(bytes: &[u8], _dtype: &vortex_array::dtype::DType, _len: usize, _buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::dtype(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> &vortex_array::dtype::DType + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::execute(array: vortex_array::array::typed::Array, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::execute_parent(array: vortex_array::array::view::ArrayView<'_, Self>, parent: &vortex_array::array::erased::ArrayRef, child_idx: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::id(&self) -> vortex_array::array::ArrayId + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::len(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> usize + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::metadata(array: vortex_array::array::view::ArrayView<'_, Self>) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::nbuffers(_array: vortex_array::array::view::ArrayView<'_, Self>) -> usize + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::reduce_parent(array: vortex_array::array::view::ArrayView<'_, Self>, parent: 
&vortex_array::array::erased::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::serialize(metadata: Self::Metadata) -> vortex_error::VortexResult>> + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::slot_name(_array: vortex_array::array::view::ArrayView<'_, Self>, idx: usize) -> alloc::string::String + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::slots(array: vortex_array::array::view::ArrayView<'_, Self>) -> &[core::option::Option] + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::stats(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> &vortex_array::stats::array::ArrayStats + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::vtable(_array: &Self::ArrayData) -> &Self + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::with_slots(array: &mut vortex_tensor::encodings::turboquant::TurboQuantData, slots: alloc::vec::Vec>) -> vortex_error::VortexResult<()> + +impl vortex_array::array::vtable::operations::OperationsVTable for vortex_tensor::encodings::turboquant::TurboQuant + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>, index: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +impl vortex_array::array::vtable::validity::ValidityChild for vortex_tensor::encodings::turboquant::TurboQuant + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::validity_child(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> &vortex_array::array::erased::ArrayRef + +impl vortex_array::arrays::dict::take::TakeExecute for vortex_tensor::encodings::turboquant::TurboQuant + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::take(array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>, indices: &vortex_array::array::erased::ArrayRef, _ctx: 
&mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + +impl vortex_array::arrays::slice::SliceReduce for vortex_tensor::encodings::turboquant::TurboQuant + +pub fn vortex_tensor::encodings::turboquant::TurboQuant::slice(array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>, range: core::ops::range::Range) -> vortex_error::VortexResult> + +pub struct vortex_tensor::encodings::turboquant::TurboQuantConfig + +pub vortex_tensor::encodings::turboquant::TurboQuantConfig::bit_width: u8 + +pub vortex_tensor::encodings::turboquant::TurboQuantConfig::seed: core::option::Option + +impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantConfig + +pub fn vortex_tensor::encodings::turboquant::TurboQuantConfig::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantConfig + +impl core::default::Default for vortex_tensor::encodings::turboquant::TurboQuantConfig + +pub fn vortex_tensor::encodings::turboquant::TurboQuantConfig::default() -> Self + +impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantConfig + +pub fn vortex_tensor::encodings::turboquant::TurboQuantConfig::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub struct vortex_tensor::encodings::turboquant::TurboQuantData + +impl vortex_tensor::encodings::turboquant::TurboQuantData + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::bit_width(&self) -> u8 + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::centroids(&self) -> &vortex_array::array::erased::ArrayRef + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::codes(&self) -> &vortex_array::array::erased::ArrayRef + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::dimension(&self) -> u32 + +pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: 
vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> Self + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::norms(&self) -> &vortex_array::array::erased::ArrayRef + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::padded_dim(&self) -> u32 + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::validate(dtype: &vortex_array::dtype::DType, codes: &vortex_array::array::erased::ArrayRef, norms: &vortex_array::array::erased::ArrayRef, centroids: &vortex_array::array::erased::ArrayRef, rotation_signs: &vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<()> + +impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantData + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantData + +impl core::convert::From for vortex_array::array::erased::ArrayRef + +pub fn vortex_array::array::erased::ArrayRef::from(value: vortex_tensor::encodings::turboquant::TurboQuantData) -> vortex_array::array::erased::ArrayRef + +impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantData + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::array::IntoArray for vortex_tensor::encodings::turboquant::TurboQuantData + +pub fn vortex_tensor::encodings::turboquant::TurboQuantData::into_array(self) 
-> vortex_array::array::erased::ArrayRef + +pub struct vortex_tensor::encodings::turboquant::TurboQuantMetadata + +pub vortex_tensor::encodings::turboquant::TurboQuantMetadata::bit_width: u8 + +impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantMetadata + +pub fn vortex_tensor::encodings::turboquant::TurboQuantMetadata::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantMetadata + +impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantMetadata + +pub fn vortex_tensor::encodings::turboquant::TurboQuantMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub struct vortex_tensor::encodings::turboquant::TurboQuantScheme + +impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantScheme + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantScheme + +impl core::cmp::Eq for vortex_tensor::encodings::turboquant::TurboQuantScheme + +impl core::cmp::PartialEq for vortex_tensor::encodings::turboquant::TurboQuantScheme + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::eq(&self, other: &vortex_tensor::encodings::turboquant::TurboQuantScheme) -> bool + +impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantScheme + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::marker::Copy for vortex_tensor::encodings::turboquant::TurboQuantScheme + +impl core::marker::StructuralPartialEq for vortex_tensor::encodings::turboquant::TurboQuantScheme + +impl vortex_compressor::scheme::Scheme for vortex_tensor::encodings::turboquant::TurboQuantScheme + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::compress(&self, compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: 
vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::expected_compression_ratio(&self, _compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::matches(&self, canonical: &vortex_array::canonical::Canonical) -> bool + +pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::scheme_name(&self) -> &'static str + +pub fn vortex_tensor::encodings::turboquant::initialize(session: &mut vortex_session::VortexSession) + +pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: &vortex_array::arrays::extension::vtable::ExtensionArray, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + pub mod vortex_tensor::fixed_shape pub struct vortex_tensor::fixed_shape::FixedShapeTensor @@ -180,7 +368,7 @@ pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::arity(&self, _opt pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: 
&mut core::fmt::Formatter<'_>) -> core::fmt::Result @@ -236,6 +424,12 @@ pub vortex_tensor::scalar_fns::ApproxOptions::Approximate pub vortex_tensor::scalar_fns::ApproxOptions::Exact +impl vortex_tensor::scalar_fns::ApproxOptions + +pub fn vortex_tensor::scalar_fns::ApproxOptions::is_approx(&self) -> bool + +pub fn vortex_tensor::scalar_fns::ApproxOptions::is_exact(&self) -> bool + impl core::clone::Clone for vortex_tensor::scalar_fns::ApproxOptions pub fn vortex_tensor::scalar_fns::ApproxOptions::clone(&self) -> vortex_tensor::scalar_fns::ApproxOptions diff --git a/vortex-tensor/src/encodings/mod.rs b/vortex-tensor/src/encodings/mod.rs index 090151e9226..7c75269b632 100644 --- a/vortex-tensor/src/encodings/mod.rs +++ b/vortex-tensor/src/encodings/mod.rs @@ -7,5 +7,4 @@ // pub mod norm; // Unit-normalized vectors. // pub mod spherical; // Spherical transform on unit-normalized vectors. -// TODO(will): -// pub mod turboquant; +pub mod turboquant; diff --git a/vortex-tensor/src/encodings/turboquant/array/centroids.rs b/vortex-tensor/src/encodings/turboquant/array/centroids.rs new file mode 100644 index 00000000000..85ea39fcc9e --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/array/centroids.rs @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Max-Lloyd centroid computation for TurboQuant scalar quantizers. +//! +//! Pre-computes optimal scalar quantizer centroids for the marginal distribution of coordinates +//! after random rotation of a unit-norm vector. In high dimensions, each coordinate of a randomly +//! rotated unit vector follows a distribution proportional to `(1 - x^2)^((d-3)/2)` on `[-1, 1]`, +//! which converges to `N(0, 1/d)`. The Max-Lloyd algorithm finds optimal quantization centroids +//! that minimize MSE for this distribution. 
+ +use std::sync::LazyLock; + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_utils::aliases::dash_map::DashMap; + +/// Number of numerical integration points for computing conditional expectations. +const INTEGRATION_POINTS: usize = 1000; + +/// Max-Lloyd convergence threshold. +const CONVERGENCE_EPSILON: f64 = 1e-12; + +/// Maximum iterations for Max-Lloyd algorithm. +const MAX_ITERATIONS: usize = 200; + +/// Global centroid cache keyed by (dimension, bit_width). +static CENTROID_CACHE: LazyLock>> = LazyLock::new(DashMap::default); + +/// Get or compute cached centroids for the given dimension and bit width. +/// +/// Returns `2^bit_width` centroids sorted in ascending order, representing +/// optimal scalar quantization levels for the coordinate distribution after +/// random rotation in `dimension`-dimensional space. +pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { + if !(1..=8).contains(&bit_width) { + vortex_bail!("TurboQuant bit_width must be 1-8, got {bit_width}"); + } + if dimension < 3 { + vortex_bail!("TurboQuant dimension must be >= 3, got {dimension}"); + } + + if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) { + return Ok(centroids.clone()); + } + + let centroids = max_lloyd_centroids(dimension, bit_width); + CENTROID_CACHE.insert((dimension, bit_width), centroids.clone()); + Ok(centroids) +} + +/// Half-integer exponent: represents `int_part + (if has_half { 0.5 } else { 0.0 })`. +/// +/// The marginal distribution exponent `(d-3)/2` is always an integer (when `d` is odd) +/// or a half-integer (when `d` is even). This type makes that invariant explicit and +/// avoids floating-point comparison in the hot path. +#[derive(Clone, Copy, Debug)] +struct HalfIntExponent { + int_part: i32, + has_half: bool, +} + +impl HalfIntExponent { + /// Compute `(numerator) / 2` as a half-integer exponent. + /// + /// `numerator` is `d - 3` where `d` is the dimension (>= 2), so it can be negative. 
+ fn from_numerator(numerator: i32) -> Self { + // Integer division truncates toward zero; for negative odd numerators + // (e.g., d=2 → num=-1) this gives int_part=0, has_half=true, + // representing -0.5 = 0 + (-0.5). The sign is handled by adjusting + // int_part: -1/2 = 0 with has_half, but we need the floor division. + // Rust's `/` truncates toward zero, so -1/2 = 0. We want floor: -1. + // Use divmod that rounds toward negative infinity. + let int_part = numerator.div_euclid(2); + let has_half = numerator.rem_euclid(2) != 0; + Self { int_part, has_half } + } +} + +/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm. +/// +/// Operates on the marginal distribution of a single coordinate of a randomly +/// rotated unit vector in d dimensions. The PDF is: +/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` +/// where `C_d` is the normalizing constant. +#[allow(clippy::cast_possible_truncation)] // f64→f32 centroid values are intentional +fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec { + let num_centroids = 1usize << bit_width; + + // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. + let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3); + + // Initialize centroids uniformly on [-1, 1]. + let mut centroids: Vec = (0..num_centroids) + .map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64)) + .collect(); + + let mut boundaries: Vec = vec![0.0; num_centroids + 1]; + for _ in 0..MAX_ITERATIONS { + // Compute decision boundaries (midpoints between adjacent centroids). + boundaries[0] = -1.0; + for idx in 0..num_centroids - 1 { + boundaries[idx + 1] = (centroids[idx] + centroids[idx + 1]) / 2.0; + } + boundaries[num_centroids] = 1.0; + + // Update each centroid to the conditional mean within its Voronoi cell. 
+ let mut max_change = 0.0f64; + for idx in 0..num_centroids { + let lo = boundaries[idx]; + let hi = boundaries[idx + 1]; + let new_centroid = conditional_mean(lo, hi, exponent); + max_change = max_change.max((new_centroid - centroids[idx]).abs()); + centroids[idx] = new_centroid; + } + + if max_change < CONVERGENCE_EPSILON { + break; + } + } + + centroids.into_iter().map(|val| val as f32).collect() +} + +/// Compute the conditional mean of the coordinate distribution on interval [lo, hi]. +/// +/// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent` +/// on [-1, 1]. +fn conditional_mean(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 { + if (hi - lo).abs() < 1e-15 { + return (lo + hi) / 2.0; + } + + let dx = (hi - lo) / INTEGRATION_POINTS as f64; + + let mut numerator = 0.0; + let mut denominator = 0.0; + + for step in 0..=INTEGRATION_POINTS { + let x_val = lo + (step as f64) * dx; + let weight = pdf_unnormalized(x_val, exponent); + + let trap_weight = if step == 0 || step == INTEGRATION_POINTS { + 0.5 + } else { + 1.0 + }; + + numerator += trap_weight * x_val * weight; + denominator += trap_weight * weight; + } + + if denominator.abs() < 1e-30 { + (lo + hi) / 2.0 + } else { + numerator / denominator + } +} + +/// Unnormalized PDF of the coordinate distribution: `(1 - x^2)^exponent`. +/// +/// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents +/// that arise from `(d-3)/2`. This is significantly faster than the general +/// `powf` which goes through `exp(exponent * ln(base))`. +#[inline] +fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 { + let base = (1.0 - x_val * x_val).max(0.0); + + if exponent.has_half { + // Half-integer exponent: base^(int_part) * sqrt(base). + base.powi(exponent.int_part) * base.sqrt() + } else { + // Integer exponent: use powi directly. + base.powi(exponent.int_part) + } +} + +/// Precompute decision boundaries (midpoints between adjacent centroids). 
+/// +/// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps +/// to centroid 0, a value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`, +/// and a value >= `boundaries[k-2]` maps to centroid `k-1`. +pub fn compute_boundaries(centroids: &[f32]) -> Vec { + centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect() +} + +/// Find the index of the nearest centroid using precomputed decision boundaries. +/// +/// `boundaries` must be the output of [`compute_boundaries`] for the corresponding +/// centroids. Uses binary search on the midpoints, avoiding distance comparisons +/// in the inner loop. +#[inline] +#[allow(clippy::cast_possible_truncation)] // bounded by num_centroids <= 256 +pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 { + debug_assert!( + boundaries.windows(2).all(|w| w[0] <= w[1]), + "boundaries must be sorted" + ); + + boundaries.partition_point(|&b| b < value) as u8 +} + +#[cfg(test)] +#[allow(clippy::cast_possible_truncation)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + #[rstest] + #[case(128, 1, 2)] + #[case(128, 2, 4)] + #[case(128, 3, 8)] + #[case(128, 4, 16)] + #[case(768, 2, 4)] + #[case(1536, 3, 8)] + fn centroids_have_correct_count( + #[case] dim: u32, + #[case] bits: u8, + #[case] expected: usize, + ) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + assert_eq!(centroids.len(), expected); + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(768, 2)] + fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + for window in centroids.windows(2) { + assert!( + window[0] < window[1], + "centroids not sorted: {:?}", + centroids + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(256, 2)] + #[case(768, 2)] + fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) 
-> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + let count = centroids.len(); + for idx in 0..count / 2 { + let diff = (centroids[idx] + centroids[count - 1 - idx]).abs(); + assert!( + diff < 1e-5, + "centroids not symmetric: c[{idx}]={}, c[{}]={}", + centroids[idx], + count - 1 - idx, + centroids[count - 1 - idx] + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 4)] + fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + for &val in ¢roids { + assert!( + (-1.0..=1.0).contains(&val), + "centroid out of [-1, 1]: {val}", + ); + } + Ok(()) + } + + #[test] + fn centroids_cached() -> VortexResult<()> { + let c1 = get_centroids(128, 2)?; + let c2 = get_centroids(128, 2)?; + assert_eq!(c1, c2); + Ok(()) + } + + #[test] + fn find_nearest_basic() -> VortexResult<()> { + let centroids = get_centroids(128, 2)?; + let boundaries = compute_boundaries(¢roids); + assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0); + + let last_idx = (centroids.len() - 1) as u8; + assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx); + for (idx, &cv) in centroids.iter().enumerate() { + let expected = idx as u8; + assert_eq!(find_nearest_centroid(cv, &boundaries), expected); + } + Ok(()) + } + + #[test] + fn rejects_invalid_params() { + assert!(get_centroids(128, 0).is_err()); + assert!(get_centroids(128, 9).is_err()); + assert!(get_centroids(1, 2).is_err()); + assert!(get_centroids(2, 2).is_err()); + } +} diff --git a/vortex-tensor/src/encodings/turboquant/array/data.rs b/vortex-tensor/src/encodings/turboquant/array/data.rs new file mode 100644 index 00000000000..43d84777c07 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/array/data.rs @@ -0,0 +1,287 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use vortex_array::ArrayRef; +use vortex_array::dtype::DType; +use 
vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; + +use crate::encodings::turboquant::array::slots::Slot; +use crate::encodings::turboquant::vtable::TurboQuant; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; + +/// TurboQuant array data. +/// +/// TurboQuant is a lossy vector quantization encoding for [`Vector`](crate::vector::Vector) +/// extension arrays. It stores quantized coordinate codes and per-vector norms, along with shared +/// codebook centroids and SRHT rotation signs. +/// +/// See the [module docs](crate::encodings::turboquant) for algorithmic details. +/// +/// A degenerate TurboQuant array has zero rows and `bit_width == 0`, with all slots empty. +#[derive(Clone, Debug)] +pub struct TurboQuantData { + /// Child arrays stored as slots. See [`Slot`] for positions: + /// + /// - [`Codes`](Slot::Codes): Non-nullable `FixedSizeListArray` with + /// `list_size == padded_dim`. Each row holds one u8 centroid index per padded coordinate. + /// Null vectors are represented by all-zero codes. The cascade compressor handles packing + /// to the actual `bit_width` on disk. + /// + /// - [`Norms`](Slot::Norms): Per-vector L2 norms, one per row. The dtype matches the element + /// type of the Vector (e.g., f64 norms for f64 vectors) and carries the nullability of the + /// parent dtype. Null vectors have null norms. This child determines the validity of the + /// entire TurboQuant array, enabling O(1) L2 norm readthrough without decompression. + /// + /// - [`Centroids`](Slot::Centroids): `PrimitiveArray` codebook with `2^bit_width` entries + /// that is shared across all rows. 
We always store these as f32 regardless of the input + /// element type because quantization itself introduces far more error than f32 precision + /// loss, and f16 inputs can be upcast to f32 before quantization. + /// + /// - [`RotationSigns`](Slot::RotationSigns): `BitPackedArray` of `3 * padded_dim` 1-bit sign + /// values for the 3-round SRHT rotation, stored in inverse application order, and shared + /// across all rows. + pub(crate) slots: Vec>, + + /// The vector dimension `d`, cached from the `FixedSizeList` storage dtype's list size. + /// + /// Stored as a convenience field to avoid repeatedly extracting it from `dtype`. + pub(crate) dimension: u32, + + /// The number of bits per coordinate (1-8), derived from `log2(centroids.len())`. + /// + /// This is 0 for degenerate empty arrays. + pub(crate) bit_width: u8, +} + +impl TurboQuantData { + /// Build a TurboQuant array with validation. + /// + /// The `dimension` and `bit_width` are derived from the inputs: + /// - `dimension` from the `dtype`'s `FixedSizeList` storage list size. + /// - `bit_width` from `log2(centroids.len())` (0 for degenerate empty arrays). + /// + /// # Errors + /// + /// Returns an error if the provided components do not satisfy the invariants documented + /// in [`new_unchecked`](Self::new_unchecked). + pub fn try_new( + dtype: &DType, + codes: ArrayRef, + norms: ArrayRef, + centroids: ArrayRef, + rotation_signs: ArrayRef, + ) -> VortexResult { + Self::validate(dtype, &codes, &norms, ¢roids, &rotation_signs)?; + + // SAFETY: we validate that the inputs are valid above. + Ok(unsafe { Self::new_unchecked(dtype, codes, norms, centroids, rotation_signs) }) + } + + /// Build a TurboQuant array without validation. + /// + /// # Safety + /// + /// The caller must ensure: + /// + /// - `dtype` is a [`Vector`](crate::vector::Vector) extension type whose storage list size + /// is >= 3. 
+ /// - `codes` is a non-nullable `FixedSizeListArray` with `list_size == padded_dim` and + /// `codes.len() == norms.len()`. Null vectors are represented by all-zero codes. + /// - `norms` is a primitive array whose ptype matches the element type of the Vector's storage + /// dtype. The nullability must match `dtype.nullability()`. Norms carry the validity of the + /// entire array, since null vectors have null norms. + /// - `centroids` is a non-nullable `PrimitiveArray` whose length is a power of 2 in + /// `[2, 256]` (i.e., `2^bit_width` for bit_width 1-8), or empty for degenerate arrays. + /// - `rotation_signs` has `3 * padded_dim` elements, or is empty for degenerate arrays. + /// - For degenerate (empty) arrays: all children must be empty. + /// + /// Violating these invariants may produce incorrect results during decompression. + pub unsafe fn new_unchecked( + dtype: &DType, + codes: ArrayRef, + norms: ArrayRef, + centroids: ArrayRef, + rotation_signs: ArrayRef, + ) -> Self { + #[cfg(debug_assertions)] + Self::validate(dtype, &codes, &norms, ¢roids, &rotation_signs) + .vortex_expect("[Debug Assertion]: Invalid TurboQuantData parameters"); + + let dimension = dtype + .as_extension_opt() + .and_then(|ext| extension_list_size(ext).ok()) + .vortex_expect("dtype must be a Vector extension type with FixedSizeList storage"); + + let bit_width = if centroids.is_empty() { + 0 + } else { + // Guaranteed to be 1-8 by validate(). + #[expect(clippy::cast_possible_truncation)] + { + centroids.len().trailing_zeros() as u8 + } + }; + + let mut slots = vec![None; Slot::COUNT]; + slots[Slot::Codes as usize] = Some(codes); + slots[Slot::Norms as usize] = Some(norms); + slots[Slot::Centroids as usize] = Some(centroids); + slots[Slot::RotationSigns as usize] = Some(rotation_signs); + + Self { + slots, + dimension, + bit_width, + } + } + + /// Validates the components that would be used to create a `TurboQuantData`. 
+ /// + /// This function checks all the invariants required by [`new_unchecked`](Self::new_unchecked). + pub fn validate( + dtype: &DType, + codes: &ArrayRef, + norms: &ArrayRef, + centroids: &ArrayRef, + rotation_signs: &ArrayRef, + ) -> VortexResult<()> { + let ext = TurboQuant::validate_dtype(dtype)?; + let dimension = extension_list_size(ext)?; + let padded_dim = dimension.next_power_of_two(); + + // Codes must be a non-nullable FixedSizeList with list_size == padded_dim. + // Null vectors are represented by all-zero codes since validity lives in the norms array. + let expected_codes_dtype = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), + padded_dim, + Nullability::NonNullable, + ); + vortex_ensure_eq!( + *codes.dtype(), + expected_codes_dtype, + "codes dtype does not match expected {expected_codes_dtype}", + ); + + let num_rows = codes.len(); + vortex_ensure_eq!( + norms.len(), + num_rows, + "norms length must match codes length", + ); + + // Degenerate (empty) case: all children must be empty, and bit_width is 0. + if num_rows == 0 { + vortex_ensure!( + centroids.is_empty(), + "degenerate TurboQuant must have empty centroids, got length {}", + centroids.len() + ); + vortex_ensure!( + rotation_signs.is_empty(), + "degenerate TurboQuant must have empty rotation_signs, got length {}", + rotation_signs.len() + ); + return Ok(()); + } + + // Non-degenerate: derive and validate bit_width from centroids. + let num_centroids = centroids.len(); + vortex_ensure!( + num_centroids.is_power_of_two() && (2..=256).contains(&num_centroids), + "centroids length must be a power of 2 in [2, 256], got {num_centroids}" + ); + + // Guaranteed to be 1-8 by the preceding power-of-2 and range checks. 
+ #[expect(clippy::cast_possible_truncation)] + let bit_width = num_centroids.trailing_zeros() as u8; + vortex_ensure!( + (1..=8).contains(&bit_width), + "derived bit_width must be 1-8, got {bit_width}" + ); + + // Norms dtype must match the element ptype of the Vector, with the parent's nullability. + // Norms carry the validity of the entire TurboQuant array. + let element_ptype = extension_element_ptype(ext)?; + let expected_norms_dtype = DType::Primitive(element_ptype, dtype.nullability()); + vortex_ensure_eq!( + *norms.dtype(), + expected_norms_dtype, + "norms dtype does not match expected {expected_norms_dtype}", + ); + + // Centroids are always f32 regardless of element type. + let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); + vortex_ensure_eq!( + *centroids.dtype(), + centroids_dtype, + "centroids dtype must be non-nullable f32", + ); + + // Rotation signs count must be 3 * padded_dim. + vortex_ensure_eq!( + rotation_signs.len(), + 3 * padded_dim as usize, + "rotation_signs length does not match expected 3 * {padded_dim}", + ); + + Ok(()) + } + + /// The vector dimension `d`, as stored in the [`Vector`](crate::vector::Vector) extension + /// dtype's `FixedSizeList` storage. + pub fn dimension(&self) -> u32 { + self.dimension + } + + /// MSE bits per coordinate (1-8 for non-empty arrays, 0 for degenerate empty arrays). + pub fn bit_width(&self) -> u8 { + self.bit_width + } + + /// Padded dimension (next power of 2 >= [`dimension`](Self::dimension)). + /// + /// The SRHT rotation requires power-of-2 input, so non-power-of-2 dimensions are + /// zero-padded to this value. + pub fn padded_dim(&self) -> u32 { + self.dimension.next_power_of_two() + } + + /// The quantized codes child (`FixedSizeListArray`, one row per vector). + pub fn codes(&self) -> &ArrayRef { + self.slot(Slot::Codes as usize) + } + + /// Per-vector L2 norms. The dtype matches the Vector's element type (f16, f32, or f64). 
+ pub fn norms(&self) -> &ArrayRef { + self.slot(Slot::Norms as usize) + } + + /// The codebook centroids (`PrimitiveArray`, length `2^bit_width`). + /// + /// Always f32 regardless of input element type: quantization noise dominates f32 + /// precision loss, and f16 inputs are upcast before quantization anyway. + pub fn centroids(&self) -> &ArrayRef { + self.slot(Slot::Centroids as usize) + } + + /// The SRHT rotation signs (`BitPackedArray`, `3 * padded_dim` 1-bit values). + /// + /// Stored in inverse application order for efficient decode. + pub fn rotation_signs(&self) -> &ArrayRef { + self.slot(Slot::RotationSigns as usize) + } + + fn slot(&self, idx: usize) -> &ArrayRef { + self.slots[idx] + .as_ref() + .vortex_expect("required slot is None") + } +} diff --git a/vortex-tensor/src/encodings/turboquant/array/mod.rs b/vortex-tensor/src/encodings/turboquant/array/mod.rs new file mode 100644 index 00000000000..0c98c974203 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/array/mod.rs @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant array definition: stores quantized coordinate codes, norms, centroids (codebook), +//! and rotation signs. + +pub(crate) mod data; +pub(crate) mod slots; + +pub(crate) mod centroids; +pub(crate) mod rotation; + +pub(crate) mod scheme; diff --git a/vortex-tensor/src/encodings/turboquant/array/rotation.rs b/vortex-tensor/src/encodings/turboquant/array/rotation.rs new file mode 100644 index 00000000000..2f654349778 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/array/rotation.rs @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Deterministic random rotation for TurboQuant. +//! +//! Uses a Structured Random Hadamard Transform (SRHT) for O(d log d) rotation +//! instead of a full d×d matrix multiply. The SRHT applies the sequence +//! 
D₃ · H · D₂ · H · D₁ where H is the Walsh-Hadamard Transform (WHT) and Dₖ are +//! random diagonal ±1 sign matrices. Three rounds of HD provide sufficient +//! randomness for near-uniform distribution on the sphere. +//! +//! For dimensions that are not powers of 2, the input is zero-padded to the +//! next power of 2 before the transform and truncated afterward. +//! +//! # Sign representation +//! +//! Signs are stored internally as `u32` XOR masks: `0x00000000` for +1 (no-op) +//! and `0x80000000` for -1 (flip IEEE 754 sign bit). The sign application +//! function uses integer XOR instead of floating-point multiply, which avoids +//! FP dependency chains and auto-vectorizes into `vpxor`/`veor`. + +use rand::RngExt; +use rand::SeedableRng; +use rand::rngs::StdRng; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +/// IEEE 754 sign bit mask for f32. +const F32_SIGN_BIT: u32 = 0x8000_0000; + +/// A structured random Hadamard transform for O(d log d) pseudo-random rotation. +pub struct RotationMatrix { + /// XOR masks for each of the 3 diagonal matrices, each of length `padded_dim`. + /// `0x00000000` = multiply by +1 (no-op), `0x80000000` = multiply by -1 (flip sign bit). + sign_masks: [Vec; 3], + /// The padded dimension (next power of 2 >= dimension). + padded_dim: usize, + /// Normalization factor: 1/(padded_dim * sqrt(padded_dim)), applied once at the end. + norm_factor: f32, +} + +impl RotationMatrix { + /// Create a new SRHT rotation from a deterministic seed. + pub fn try_new(seed: u64, dimension: usize) -> VortexResult { + let padded_dim = dimension.next_power_of_two(); + let mut rng = StdRng::seed_from_u64(seed); + + let sign_masks = std::array::from_fn(|_| gen_random_sign_masks(&mut rng, padded_dim)); + let norm_factor = 1.0 / (padded_dim as f32 * (padded_dim as f32).sqrt()); + + Ok(Self { + sign_masks, + padded_dim, + norm_factor, + }) + } + + /// Apply forward rotation: `output = SRHT(input)`. 
+ /// + /// Both `input` and `output` must have length `padded_dim()`. The caller + /// is responsible for zero-padding input beyond `dim` positions. + pub fn rotate(&self, input: &[f32], output: &mut [f32]) { + debug_assert_eq!(input.len(), self.padded_dim); + debug_assert_eq!(output.len(), self.padded_dim); + + output.copy_from_slice(input); + self.apply_srht(output); + } + + /// Apply inverse rotation: `output = SRHT⁻¹(input)`. + /// + /// Both `input` and `output` must have length `padded_dim()`. + pub fn inverse_rotate(&self, input: &[f32], output: &mut [f32]) { + debug_assert_eq!(input.len(), self.padded_dim); + debug_assert_eq!(output.len(), self.padded_dim); + + output.copy_from_slice(input); + self.apply_inverse_srht(output); + } + + /// Returns the padded dimension (next power of 2 >= dim). + /// + /// All rotate/inverse_rotate buffers must be this length. + pub fn padded_dim(&self) -> usize { + self.padded_dim + } + + /// Apply the SRHT: D₃ · H · D₂ · H · D₁ · x, with normalization. + fn apply_srht(&self, buf: &mut [f32]) { + apply_signs_xor(buf, &self.sign_masks[0]); + walsh_hadamard_transform(buf); + + apply_signs_xor(buf, &self.sign_masks[1]); + walsh_hadamard_transform(buf); + + apply_signs_xor(buf, &self.sign_masks[2]); + walsh_hadamard_transform(buf); + + let norm = self.norm_factor; + buf.iter_mut().for_each(|val| *val *= norm); + } + + /// Apply the inverse SRHT. + /// + /// Forward is: norm · H · D₃ · H · D₂ · H · D₁ + /// Inverse is: norm · D₁ · H · D₂ · H · D₃ · H + fn apply_inverse_srht(&self, buf: &mut [f32]) { + walsh_hadamard_transform(buf); + apply_signs_xor(buf, &self.sign_masks[2]); + + walsh_hadamard_transform(buf); + apply_signs_xor(buf, &self.sign_masks[1]); + + walsh_hadamard_transform(buf); + apply_signs_xor(buf, &self.sign_masks[0]); + + let norm = self.norm_factor; + buf.iter_mut().for_each(|val| *val *= norm); + } + + /// Export the 3 sign vectors as a flat `Vec` of 0/1 values in inverse + /// application order `[D₃ | D₂ | D₁]`. 
+ /// + /// Convention: `1` = positive (+1), `0` = negative (-1). + /// The output has length `3 * padded_dim` and is suitable for bitpacking + /// via FastLanes `bitpack_encode(..., 1, None)`. + pub fn export_inverse_signs_u8(&self) -> Vec { + let total = 3 * self.padded_dim; + let mut out = Vec::with_capacity(total); + + // Store in inverse order: sign_masks[2] (D₃), sign_masks[1] (D₂), sign_masks[0] (D₁) + for sign_idx in [2, 1, 0] { + for &mask in &self.sign_masks[sign_idx] { + out.push(if mask == 0 { 1u8 } else { 0u8 }); + } + } + out + } + + /// Reconstruct a `RotationMatrix` from unpacked `u8` 0/1 values. + /// + /// The input must have length `3 * padded_dim` with signs in inverse + /// application order `[D₃ | D₂ | D₁]` (as produced by [`export_inverse_signs_u8`]). + /// Convention: `1` = positive, `0` = negative. + /// + /// This is the decode-time reconstruction path: FastLanes SIMD-unpacks the + /// stored `BitPackedArray` into `&[u8]`, which is passed here. + pub fn from_u8_slice(signs_u8: &[u8], dimension: usize) -> VortexResult { + let padded_dim = dimension.next_power_of_two(); + vortex_ensure!( + signs_u8.len() == 3 * padded_dim, + "Expected {} sign bytes, got {}", + 3 * padded_dim, + signs_u8.len() + ); + + // Reconstruct in storage order (inverse): [D₃, D₂, D₁] → sign_masks[2], [1], [0] + let mut sign_masks: [Vec; 3] = std::array::from_fn(|_| Vec::with_capacity(padded_dim)); + + for (round, sign_idx) in [2, 1, 0].into_iter().enumerate() { + let offset = round * padded_dim; + sign_masks[sign_idx] = signs_u8[offset..offset + padded_dim] + .iter() + .map(|&v| if v != 0 { 0u32 } else { F32_SIGN_BIT }) + .collect(); + } + + let norm_factor = 1.0 / (padded_dim as f32 * (padded_dim as f32).sqrt()); + + Ok(Self { + sign_masks, + padded_dim, + norm_factor, + }) + } +} + +/// Generate a vector of random XOR sign masks. 
+fn gen_random_sign_masks(rng: &mut StdRng, len: usize) -> Vec { + (0..len) + .map(|_| { + if rng.random_bool(0.5) { + 0u32 // +1: no-op + } else { + F32_SIGN_BIT // -1: flip sign bit + } + }) + .collect() +} + +/// Apply sign masks via XOR on the IEEE 754 sign bit. +/// +/// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). +/// Equivalent to multiplying each element by ±1.0, but avoids FP dependency chains. +#[inline] +fn apply_signs_xor(buf: &mut [f32], masks: &[u32]) { + for (val, &mask) in buf.iter_mut().zip(masks.iter()) { + *val = f32::from_bits(val.to_bits() ^ mask); + } +} + +/// In-place Walsh-Hadamard Transform (unnormalized, iterative). +/// +/// Input length must be a power of 2. Runs in O(n log n). +/// +/// Uses a fixed-size chunk strategy: for each stage, the buffer is processed +/// in `CHUNK`-element blocks with a compile-time-known butterfly function. +/// This lets LLVM unroll and auto-vectorize the butterfly into NEON/AVX SIMD. +fn walsh_hadamard_transform(buf: &mut [f32]) { + let len = buf.len(); + debug_assert!(len.is_power_of_two()); + + let mut half = 1; + while half < len { + let stride = half * 2; + // Process in chunks of `stride` elements. Within each chunk, + // split into non-overlapping (lo, hi) halves for the butterfly. + for chunk in buf.chunks_exact_mut(stride) { + let (lo, hi) = chunk.split_at_mut(half); + butterfly(lo, hi); + } + half *= 2; + } +} + +/// Butterfly: `lo[i], hi[i] = lo[i] + hi[i], lo[i] - hi[i]`. +/// +/// Separate function so LLVM can see the slice lengths match and auto-vectorize. 
+#[inline(always)] +fn butterfly(lo: &mut [f32], hi: &mut [f32]) { + debug_assert_eq!(lo.len(), hi.len()); + for (a, b) in lo.iter_mut().zip(hi.iter_mut()) { + let sum = *a + *b; + let diff = *a - *b; + *a = sum; + *b = diff; + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + #[test] + fn deterministic_from_seed() -> VortexResult<()> { + let r1 = RotationMatrix::try_new(42, 64)?; + let r2 = RotationMatrix::try_new(42, 64)?; + let pd = r1.padded_dim(); + + let mut input = vec![0.0f32; pd]; + for i in 0..64 { + input[i] = i as f32; + } + let mut out1 = vec![0.0f32; pd]; + let mut out2 = vec![0.0f32; pd]; + + r1.rotate(&input, &mut out1); + r2.rotate(&input, &mut out2); + + assert_eq!(out1, out2); + Ok(()) + } + + /// Verify roundtrip is exact to f32 precision across many dimensions, + /// including non-power-of-two dimensions that require padding. + #[rstest] + #[case(32)] + #[case(64)] + #[case(100)] + #[case(128)] + #[case(256)] + #[case(512)] + #[case(768)] + #[case(1024)] + fn roundtrip_exact(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(42, dim)?; + let padded_dim = rot.padded_dim(); + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32 + 1.0) * 0.01; + } + let mut rotated = vec![0.0f32; padded_dim]; + let mut recovered = vec![0.0f32; padded_dim]; + + rot.rotate(&input, &mut rotated); + rot.inverse_rotate(&rotated, &mut recovered); + + let max_err: f32 = input + .iter() + .zip(recovered.iter()) + .map(|(a, b)| (a - b).abs()) + .fold(0.0f32, f32::max); + let max_val: f32 = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max); + let rel_err = max_err / max_val; + + // SRHT roundtrip should be exact up to f32 precision (~1e-6). + assert!( + rel_err < 1e-5, + "roundtrip relative error too large for dim={dim}: {rel_err:.2e}" + ); + Ok(()) + } + + /// Verify norm preservation across dimensions. 
+ #[rstest] + #[case(128)] + #[case(768)] + fn preserves_norm(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(7, dim)?; + let padded_dim = rot.padded_dim(); + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32) * 0.01; + } + let input_norm: f32 = input.iter().map(|x| x * x).sum::().sqrt(); + + let mut rotated = vec![0.0f32; padded_dim]; + rot.rotate(&input, &mut rotated); + let rotated_norm: f32 = rotated.iter().map(|x| x * x).sum::().sqrt(); + + assert!( + (input_norm - rotated_norm).abs() / input_norm < 1e-5, + "norm not preserved for dim={dim}: {} vs {} (rel err: {:.2e})", + input_norm, + rotated_norm, + (input_norm - rotated_norm).abs() / input_norm + ); + Ok(()) + } + + /// Verify that export → from_u8_slice produces identical rotation output. + #[rstest] + #[case(64)] + #[case(128)] + #[case(768)] + fn sign_export_import_roundtrip(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(42, dim)?; + let padded_dim = rot.padded_dim(); + + let signs_u8 = rot.export_inverse_signs_u8(); + let rot2 = RotationMatrix::from_u8_slice(&signs_u8, dim)?; + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32 + 1.0) * 0.01; + } + + let mut out1 = vec![0.0f32; padded_dim]; + let mut out2 = vec![0.0f32; padded_dim]; + rot.rotate(&input, &mut out1); + rot2.rotate(&input, &mut out2); + assert_eq!(out1, out2, "Forward rotation mismatch after export/import"); + + rot.inverse_rotate(&out1, &mut out2); + let mut out3 = vec![0.0f32; padded_dim]; + rot2.inverse_rotate(&out1, &mut out3); + assert_eq!(out2, out3, "Inverse rotation mismatch after export/import"); + + Ok(()) + } + + #[test] + fn wht_basic() { + // WHT of [1, 0, 0, 0] should be [1, 1, 1, 1] + let mut buf = vec![1.0f32, 0.0, 0.0, 0.0]; + walsh_hadamard_transform(&mut buf); + assert_eq!(buf, vec![1.0, 1.0, 1.0, 1.0]); + + // WHT is self-inverse (up to scaling by n) + walsh_hadamard_transform(&mut buf); 
+ // After two WHTs: each element multiplied by n=4 + assert_eq!(buf, vec![4.0, 0.0, 0.0, 0.0]); + } +} diff --git a/vortex-tensor/src/encodings/turboquant/array/scheme.rs b/vortex-tensor/src/encodings/turboquant/array/scheme.rs new file mode 100644 index 00000000000..2b954522676 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/array/scheme.rs @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant compression scheme for the pluggable compressor. + +use vortex_array::ArrayRef; +use vortex_array::Canonical; +use vortex_compressor::CascadingCompressor; +use vortex_compressor::ctx::CompressorContext; +use vortex_compressor::scheme::Scheme; +use vortex_compressor::stats::ArrayAndStats; +use vortex_error::VortexResult; + +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::TurboQuantConfig; +use crate::encodings::turboquant::turboquant_encode; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; + +/// TurboQuant compression scheme for [`Vector`] extension types. +/// +/// Applies lossy vector quantization to [`Vector`] extension arrays using the TurboQuant +/// algorithm with MSE-optimal encoding. 
+/// +/// Register this scheme with the compressor builder via `with_scheme`: +/// ```ignore +/// use vortex_btrblocks::BtrBlocksCompressorBuilder; +/// use vortex_tensor::encodings::turboquant::TurboQuantScheme; +/// +/// let compressor = BtrBlocksCompressorBuilder::default() +/// .with_new_scheme(&TurboQuantScheme) +/// .build(); +/// ``` +/// +/// [`Vector`]: crate::vector::Vector +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct TurboQuantScheme; + +impl Scheme for TurboQuantScheme { + fn scheme_name(&self) -> &'static str { + "vortex.tensor.turboquant" + } + + fn matches(&self, canonical: &Canonical) -> bool { + let Canonical::Extension(ext) = canonical else { + return false; + }; + + TurboQuant::validate_dtype(ext.dtype()).is_ok() + } + + fn expected_compression_ratio( + &self, + _compressor: &CascadingCompressor, + data: &mut ArrayAndStats, + _ctx: CompressorContext, + ) -> VortexResult { + let dtype = data.array().dtype(); + let len = data.array().len(); + + let ext = TurboQuant::validate_dtype(dtype)?; + let element_ptype = extension_element_ptype(ext)?; + let dimension = extension_list_size(ext)?; + + Ok(estimate_compression_ratio( + element_ptype.bit_width(), + dimension, + len, + )) + } + + fn compress( + &self, + compressor: &CascadingCompressor, + data: &mut ArrayAndStats, + _ctx: CompressorContext, + ) -> VortexResult { + // TODO(connor): Fix this once we ensure that the data array is always canonical. + let ext_array = data.array().to_canonical()?.into_extension(); + + let config = TurboQuantConfig::default(); + turboquant_encode(&ext_array, &config, &mut compressor.execution_ctx()) + } +} + +/// Estimate the compression ratio for TurboQuant MSE encoding with the default config. 
+fn estimate_compression_ratio(bits_per_element: usize, dimensions: u32, num_vectors: usize) -> f64 { + let config = TurboQuantConfig::default(); + let padded_dim = dimensions.next_power_of_two() as usize; + + // Per-vector: MSE codes per padded coordinate, plus one f32 norm. + let compressed_bits_per_vector = 32 // norm is always f32 + + (config.bit_width as usize) * padded_dim; // MSE codes + + // Shared overhead: codebook centroids (2^bit_width f32 values) and + // rotation signs (3 * padded_dim bits). + let num_centroids = 1usize << config.bit_width; + let overhead_bits = num_centroids * 32 // centroids are always f32 + + 3 * padded_dim; // rotation signs, 1 bit each + + let compressed_size_bits = compressed_bits_per_vector * num_vectors + overhead_bits; + let uncompressed_size_bits = bits_per_element * num_vectors * dimensions as usize; + uncompressed_size_bits as f64 / compressed_size_bits as f64 +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use super::*; + + /// Verify compression ratio for typical embedding dimensions. + /// + /// f32 input at 768-d (padded to 1024) with 1000 vectors should give ~4-6x. + /// f32 input at 1024-d (no padding) should give higher ratio since no waste. 
+ #[rstest] + #[case::f32_768d(32, 768, 1000, 3.5, 8.0)] + #[case::f32_1024d(32, 1024, 1000, 5.0, 9.0)] + #[case::f32_1536d(32, 1536, 1000, 3.0, 8.0)] + #[case::f32_128d(32, 128, 1000, 4.0, 8.0)] + #[case::f64_768d(64, 768, 1000, 7.0, 16.0)] + #[case::f16_768d(16, 768, 1000, 1.5, 4.5)] + fn compression_ratio_in_expected_range( + #[case] bits_per_element: usize, + #[case] dim: u32, + #[case] num_vectors: usize, + #[case] min_ratio: f64, + #[case] max_ratio: f64, + ) { + let ratio = estimate_compression_ratio(bits_per_element, dim, num_vectors); + assert!( + ratio > min_ratio && ratio < max_ratio, + "ratio {ratio:.2} not in [{min_ratio}, {max_ratio}] for \ + {bits_per_element}-bit elements, dim={dim}, n={num_vectors}" + ); + } + + /// Compression ratio must always be > 1 for reasonable inputs, + /// otherwise TurboQuant makes things bigger and should not be selected. + #[rstest] + #[case(32, 128, 100)] + #[case(32, 768, 10)] + #[case(64, 256, 50)] + fn ratio_always_greater_than_one( + #[case] bits_per_element: usize, + #[case] dim: u32, + #[case] num_vectors: usize, + ) { + let ratio = estimate_compression_ratio(bits_per_element, dim, num_vectors); + assert!( + ratio > 1.0, + "ratio {ratio:.4} <= 1.0 for {bits_per_element}-bit, dim={dim}, n={num_vectors}" + ); + } + + /// Power-of-2 dimensions should have better ratios than their non-power-of-2 + /// predecessors due to no padding waste. 
+ #[test] + fn power_of_two_has_better_ratio() { + let ratio_768 = estimate_compression_ratio(32, 768, 1000); + let ratio_1024 = estimate_compression_ratio(32, 1024, 1000); + assert!( + ratio_1024 > ratio_768, + "1024-d ratio ({ratio_1024:.2}) should exceed 768-d ({ratio_768:.2})" + ); + } +} diff --git a/vortex-tensor/src/encodings/turboquant/array/slots.rs b/vortex-tensor/src/encodings/turboquant/array/slots.rs new file mode 100644 index 00000000000..ff59db447d3 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/array/slots.rs @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +/// Slot positions for TurboQuantArray children. +#[repr(usize)] +#[derive(Clone, Copy, Debug)] +pub(crate) enum Slot { + Codes = 0, + Norms = 1, + Centroids = 2, + RotationSigns = 3, +} + +impl Slot { + pub(crate) const COUNT: usize = 4; + + pub(crate) fn name(self) -> &'static str { + match self { + Self::Codes => "codes", + Self::Norms => "norms", + Self::Centroids => "centroids", + Self::RotationSigns => "rotation_signs", + } + } + + pub(crate) fn from_index(idx: usize) -> Self { + match idx { + 0 => Self::Codes, + 1 => Self::Norms, + 2 => Self::Centroids, + 3 => Self::RotationSigns, + _ => vortex_error::vortex_panic!("invalid slot index {idx}"), + } + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs new file mode 100644 index 00000000000..241c7089b61 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant encoding (quantization) logic. 
+ +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::match_each_float_ptype; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_fastlanes::bitpack_compress::bitpack_encode; + +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::array::centroids::compute_boundaries; +use crate::encodings::turboquant::array::centroids::find_nearest_centroid; +use crate::encodings::turboquant::array::centroids::get_centroids; +use crate::encodings::turboquant::array::rotation::RotationMatrix; +use crate::encodings::turboquant::vtable::TurboQuantArray; +use crate::scalar_fns::ApproxOptions; +use crate::scalar_fns::l2_norm::L2Norm; + +/// Configuration for TurboQuant encoding. +#[derive(Clone, Debug)] +pub struct TurboQuantConfig { + /// Bits per coordinate (1-8). + pub bit_width: u8, + /// Optional seed for the rotation matrix. If None, the default seed is used. + pub seed: Option, +} + +impl Default for TurboQuantConfig { + fn default() -> Self { + Self { + bit_width: 4, + seed: Some(42), + } + } +} + +/// Extract elements from a FixedSizeListArray as a flat f32 PrimitiveArray for quantization. +/// +/// All quantization (rotation, centroid lookup) happens in f32. f16 is upcast; f64 is truncated. 
+#[allow(clippy::cast_possible_truncation)] +fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult { + let elements = fsl.elements(); + let primitive = elements.to_canonical()?.into_primitive(); + let ptype = primitive.ptype(); + + match ptype { + PType::F16 => Ok(primitive + .as_slice::() + .iter() + .map(|&v| f32::from(v)) + .collect()), + PType::F32 => Ok(primitive), + PType::F64 => Ok(primitive + .as_slice::() + .iter() + .map(|&v| v as f32) + .collect()), + _ => vortex_bail!("TurboQuant requires float elements, got {ptype:?}"), + } +} + +/// Shared intermediate results from the quantization loop. +struct QuantizationResult { + rotation: RotationMatrix, + centroids: Vec, + all_indices: BufferMut, + /// Native-precision norms (matching the Vector element type). Carries validity: null vectors + /// have null norms. + norms_array: ArrayRef, + padded_dim: usize, +} + +/// Core quantization: compute norms via [`L2Norm`], extract f32 elements, then +/// normalize/rotate/quantize all rows. +/// +/// Norms are computed in the native element precision via the [`L2Norm`] scalar function. +/// The rotation and centroid lookup happen in f32. Null rows (per the input validity) produce +/// all-zero codes. +#[allow(clippy::cast_possible_truncation)] +fn turboquant_quantize_core( + ext: &ExtensionArray, + fsl: &FixedSizeListArray, + seed: u64, + bit_width: u8, + validity: &Validity, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let dimension = fsl.list_size() as usize; + let num_rows = fsl.len(); + + // Compute native-precision norms via the L2Norm scalar fn. L2Norm propagates validity from + // the input, so null vectors get null norms automatically. + let norms_sfn = L2Norm::try_new_array(&ApproxOptions::Exact, ext.as_ref().clone(), num_rows)?; + let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?; + let norms_prim: PrimitiveArray = norms_array.to_canonical()?.into_primitive(); + + // Extract f32 norms for the internal quantization loop. 
+ let f32_norms: Vec = match_each_float_ptype!(norms_prim.ptype(), |T| { + norms_prim + .as_slice::() + .iter() + .map(|&v| num_traits::ToPrimitive::to_f32(&v).unwrap_or(0.0)) + .collect() + }); + + let rotation = RotationMatrix::try_new(seed, dimension)?; + let padded_dim = rotation.padded_dim(); + + let f32_elements = extract_f32_elements(fsl)?; + + let centroids = get_centroids(padded_dim as u32, bit_width)?; + let boundaries = compute_boundaries(¢roids); + + let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); + let mut padded = vec![0.0f32; padded_dim]; + let mut rotated = vec![0.0f32; padded_dim]; + + let f32_slice = f32_elements.as_slice::(); + for row in 0..num_rows { + // Null vectors get all-zero codes. + if !validity.is_valid(row)? { + all_indices.extend(std::iter::repeat_n(0u8, padded_dim)); + continue; + } + + let x = &f32_slice[row * dimension..(row + 1) * dimension]; + let norm = f32_norms[row]; + + if norm > 0.0 { + let inv_norm = 1.0 / norm; + for (dst, &src) in padded[..dimension].iter_mut().zip(x.iter()) { + *dst = src * inv_norm; + } + } else { + padded[..dimension].fill(0.0); + } + rotation.rotate(&padded, &mut rotated); + + for j in 0..padded_dim { + all_indices.push(find_nearest_centroid(rotated[j], &boundaries)); + } + } + + Ok(QuantizationResult { + rotation, + centroids, + all_indices, + norms_array, + padded_dim, + }) +} + +/// Build a `TurboQuantArray` from quantization results. +#[allow(clippy::cast_possible_truncation)] +fn build_turboquant( + fsl: &FixedSizeListArray, + core: QuantizationResult, + ext_dtype: DType, +) -> VortexResult { + let num_rows = fsl.len(); + let padded_dim = core.padded_dim; + let codes_elements = + PrimitiveArray::new::(core.all_indices.freeze(), Validity::NonNullable); + let codes = FixedSizeListArray::try_new( + codes_elements.into_array(), + padded_dim as u32, + Validity::NonNullable, + num_rows, + )? 
+ .into_array(); + + // TODO(perf): `get_centroids` returns Vec; could avoid the copy by + // supporting Buffer::from(Vec) or caching as Buffer directly. + let mut centroids_buf = BufferMut::::with_capacity(core.centroids.len()); + centroids_buf.extend_from_slice(&core.centroids); + let centroids_array = + PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable).into_array(); + + let rotation_signs = bitpack_rotation_signs(&core.rotation)?; + + TurboQuant::try_new_array( + ext_dtype, + codes, + core.norms_array, + centroids_array, + rotation_signs, + ) +} + +/// Encode a [`Vector`](crate::vector::Vector) extension array into a `TurboQuantArray`. +/// +/// Nullable inputs are supported: null vectors get all-zero codes and null norms. The validity +/// of the resulting TurboQuant array is carried by the norms child. +pub fn turboquant_encode( + ext: &ExtensionArray, + config: &TurboQuantConfig, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let ext_dtype = ext.dtype().clone(); + let storage = ext.storage_array(); + let fsl = storage.to_canonical()?.into_fixed_size_list(); + + vortex_ensure!( + config.bit_width >= 1 && config.bit_width <= 8, + "bit_width must be 1-8, got {}", + config.bit_width + ); + let dimension = fsl.list_size(); + vortex_ensure!( + dimension >= 3, + "TurboQuant requires dimension >= 3, got {dimension}" + ); + + if fsl.is_empty() { + let padded_dim = dimension.next_power_of_two(); + let empty_codes = FixedSizeListArray::try_new( + PrimitiveArray::empty::(Nullability::NonNullable).into_array(), + padded_dim, + Validity::NonNullable, + 0, + )?; + + // Norms dtype matches the element type and carries the parent's nullability. 
+ let element_ptype = fsl.elements().dtype().as_ptype(); + let norms_nullability = ext_dtype.nullability(); + let empty_norms: ArrayRef = match_each_float_ptype!(element_ptype, |T| { + PrimitiveArray::empty::(norms_nullability).into_array() + }); + + let empty_centroids = PrimitiveArray::empty::(Nullability::NonNullable); + let empty_signs = PrimitiveArray::empty::(Nullability::NonNullable); + return Ok(TurboQuant::try_new_array( + ext_dtype, + empty_codes.into_array(), + empty_norms, + empty_centroids.into_array(), + empty_signs.into_array(), + )? + .into_array()); + } + + let validity = ext.as_ref().validity()?; + let seed = config.seed.unwrap_or(42); + let core = turboquant_quantize_core(ext, &fsl, seed, config.bit_width, &validity, ctx)?; + + Ok(build_turboquant(&fsl, core, ext_dtype)?.into_array()) +} + +/// Export rotation signs as a 1-bit `BitPackedArray` for efficient storage. +/// +/// The rotation matrix's 3 x padded_dim sign values are exported as 0/1 u8 +/// values in inverse application order, then bitpacked to 1 bit per sign. +/// On decode, FastLanes SIMD-unpacks back to `&[u8]` of 0/1 values. +fn bitpack_rotation_signs(rotation: &RotationMatrix) -> VortexResult { + let signs_u8 = rotation.export_inverse_signs_u8(); + let mut buf = BufferMut::::with_capacity(signs_u8.len()); + buf.extend_from_slice(&signs_u8); + let prim = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + Ok(bitpack_encode(&prim, 1, None)?.into_array()) +} diff --git a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs new file mode 100644 index 00000000000..c6c6531119b --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Approximate cosine similarity in the quantized domain. +//! +//! 
Since the SRHT is orthogonal, inner products are preserved in the rotated +//! domain. For two vectors from the same TurboQuant column (same rotation and +//! centroids), we can compute the dot product of their quantized representations +//! without full decompression: +//! +//! ```text +//! cos_approx(a, b) = sum(centroids[code_a[j]] × centroids[code_b[j]]) +//! ``` +//! +//! where `code_a` and `code_b` are the quantized coordinate indices of the +//! unit-norm rotated vectors `â_rot` and `b̂_rot`. +//! +//! # Bias and error bounds +//! +//! This estimate is **biased**. The MSE quantizer minimizes reconstruction error +//! but does not guarantee unbiased inner products; the discrete centroid grid +//! introduces systematic bias in the dot product. +//! +//! The approximation error is bounded by the MSE quantization distortion. For +//! unit-norm vectors quantized at `b` bits, the per-coordinate MSE is bounded by +//! `(√3 · π / 2) / 4^b` (Theorem 1). The inner product error scales with this +//! distortion: at 4 bits the error is typically < 0.1, at 8 bits < 0.001. +//! +//! For approximate nearest neighbor (ANN) search, biased-but-accurate ranking is +//! usually sufficient -- the relative ordering of cosine similarities is preserved +//! even if the absolute values have bounded error. + +use num_traits::FromPrimitive; +use num_traits::Zero; +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::match_each_float_ptype; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure_eq; + +use crate::encodings::turboquant::TurboQuant; +use crate::utils::extension_element_ptype; + +/// Convert an f32 value to `T`, returning `T::zero()` if the conversion fails. 
+/// +/// This helper exists because `half::f16` has an inherent `from_f32` method that shadows +/// the [`FromPrimitive`] trait method, causing compilation errors when used inside +/// [`match_each_float_ptype!`]. +#[inline] +fn f32_to_t(v: f32) -> T { + // TODO(connor): Is this actually correct? How should we handle f64 overflow? + FromPrimitive::from_f32(v).unwrap_or_else(T::zero) +} + +/// Compute the per-row unit-norm dot products in f32 (centroids are always f32). +/// +/// Returns a `Vec` of length `num_rows`. +fn compute_unit_dots( + lhs: &ArrayView, + rhs: &ArrayView, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + let pd = lhs.padded_dim() as usize; + let num_rows = lhs.norms().len(); + + let lhs_codes_fsl: FixedSizeListArray = lhs.codes().clone().execute(ctx)?; + let rhs_codes_fsl: FixedSizeListArray = rhs.codes().clone().execute(ctx)?; + let lhs_codes: PrimitiveArray = lhs_codes_fsl.elements().clone().execute(ctx)?; + let rhs_codes: PrimitiveArray = rhs_codes_fsl.elements().clone().execute(ctx)?; + let ca = lhs_codes.as_slice::(); + let cb = rhs_codes.as_slice::(); + + // Read centroids from both arrays. They may have different codebooks (e.g., different bit + // widths). + let lhs_centroids: PrimitiveArray = lhs.centroids().clone().execute(ctx)?; + let rhs_centroids: PrimitiveArray = rhs.centroids().clone().execute(ctx)?; + let cl = lhs_centroids.as_slice::(); + let cr = rhs_centroids.as_slice::(); + + let mut dots = Vec::with_capacity(num_rows); + for row in 0..num_rows { + let row_ca = &ca[row * pd..(row + 1) * pd]; + let row_cb = &cb[row * pd..(row + 1) * pd]; + let dot: f32 = row_ca + .iter() + .zip(row_cb.iter()) + .map(|(&a, &b)| cl[a as usize] * cr[b as usize]) + .sum(); + dots.push(dot); + } + + Ok(dots) +} + +/// Compute approximate cosine similarity for all rows between two TurboQuant +/// arrays (same rotation matrix and codebook) without full decompression. 
+/// +/// Since TurboQuant stores unit-normalized rotated vectors, the dot product of the quantized +/// codes directly approximates cosine similarity without needing the stored norms. +/// +/// The output dtype matches the Vector's element type (f16, f32, or f64). +pub fn cosine_similarity_quantized_column( + lhs: ArrayView, + rhs: ArrayView, + ctx: &mut ExecutionCtx, +) -> VortexResult { + vortex_ensure_eq!( + lhs.dimension(), + rhs.dimension(), + "TurboQuant quantized dot product requires matching dimensions", + ); + + let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?; + let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?; + let dots = compute_unit_dots(&lhs, &rhs, ctx)?; + + // The unit-norm dot product IS the cosine similarity. Cast from f32 to the native type. + match_each_float_ptype!(element_ptype, |T| { + let mut result = BufferMut::::with_capacity(dots.len()); + for &dot in &dots { + // SAFETY: We allocated the correct amount. + unsafe { result.push_unchecked(f32_to_t(dot)) }; + } + + // SAFETY: `result` has the same length as the input arrays, matching `validity`. + Ok(unsafe { PrimitiveArray::new_unchecked(result.freeze(), validity) }.into_array()) + }) +} + +/// Compute approximate dot product for all rows between two TurboQuant +/// arrays (same rotation matrix and codebook) without full decompression. +/// +/// `dot_product(a, b) = ||a|| * ||b|| * sum(c[code_a[j]] * c[code_b[j]])` +/// +/// The output dtype matches the Vector's element type (f16, f32, or f64). 
+pub fn dot_product_quantized_column( + lhs: ArrayView, + rhs: ArrayView, + ctx: &mut ExecutionCtx, +) -> VortexResult { + vortex_ensure_eq!( + lhs.dimension(), + rhs.dimension(), + "TurboQuant quantized dot product requires matching dimensions", + ); + + let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?; + let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?; + let dots = compute_unit_dots(&lhs, &rhs, ctx)?; + let num_rows = lhs.norms().len(); + + let lhs_norms: PrimitiveArray = lhs.norms().clone().execute(ctx)?; + let rhs_norms: PrimitiveArray = rhs.norms().clone().execute(ctx)?; + + // Scale the f32 unit-norm dot product by native-precision norms. + match_each_float_ptype!(element_ptype, |T| { + let na = lhs_norms.as_slice::(); + let nb = rhs_norms.as_slice::(); + + let mut result = BufferMut::::with_capacity(num_rows); + for row in 0..num_rows { + let dot_t: T = f32_to_t(dots[row]); + // SAFETY: We allocated the correct amount. + unsafe { result.push_unchecked(na[row] * nb[row] * dot_t) }; + } + + // SAFETY: `result` has the same length as the input arrays, matching `validity`. + Ok(unsafe { PrimitiveArray::new_unchecked(result.freeze(), validity) }.into_array()) + }) +} diff --git a/vortex-tensor/src/encodings/turboquant/compute/mod.rs b/vortex-tensor/src/encodings/turboquant/compute/mod.rs new file mode 100644 index 00000000000..67b4d3efb7f --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/mod.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Compute pushdown implementations for TurboQuant. 
+ +pub(crate) mod cosine_similarity; +mod ops; +pub(crate) mod rules; +mod slice; +mod take; diff --git a/vortex-tensor/src/encodings/turboquant/compute/ops.rs b/vortex-tensor/src/encodings/turboquant/compute/ops.rs new file mode 100644 index 00000000000..4999816319b --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/ops.rs @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::ArrayView; +use vortex_array::ExecutionCtx; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::slice::SliceReduce; +use vortex_array::scalar::Scalar; +use vortex_array::vtable::OperationsVTable; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; + +use crate::encodings::turboquant::TurboQuant; + +impl OperationsVTable for TurboQuant { + fn scalar_at( + array: ArrayView<'_, TurboQuant>, + index: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + // Slice to single row, decompress that one row. + let Some(sliced) = ::slice(array, index..index + 1)? 
else { + vortex_bail!("slice returned None for index {index}") + }; + let decoded = sliced.execute::(ctx)?; + decoded.scalar_at(0) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compute/rules.rs b/vortex-tensor/src/encodings/turboquant/compute/rules.rs new file mode 100644 index 00000000000..39919a8c1ec --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/rules.rs @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::dict::TakeExecuteAdaptor; +use vortex_array::arrays::slice::SliceReduceAdaptor; +use vortex_array::kernel::ParentKernelSet; +use vortex_array::optimizer::rules::ParentRuleSet; + +use crate::encodings::turboquant::TurboQuant; + +pub(crate) static RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&SliceReduceAdaptor(TurboQuant))]); + +pub(crate) static PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(TurboQuant))]); diff --git a/vortex-tensor/src/encodings/turboquant/compute/slice.rs b/vortex-tensor/src/encodings/turboquant/compute/slice.rs new file mode 100644 index 00000000000..a8daef6466b --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/slice.rs @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::ops::Range; + +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::IntoArray; +use vortex_array::arrays::slice::SliceReduce; +use vortex_error::VortexResult; + +use crate::encodings::turboquant::TurboQuant; + +impl SliceReduce for TurboQuant { + fn slice( + array: ArrayView<'_, TurboQuant>, + range: Range, + ) -> VortexResult> { + let sliced_codes = array.codes().slice(range.clone())?; + let sliced_norms = array.norms().slice(range)?; + + Ok(Some( + TurboQuant::try_new_array( + array.dtype().clone(), + sliced_codes, + sliced_norms, + 
array.centroids().clone(), + array.rotation_signs().clone(), + )? + .into_array(), + )) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/compute/take.rs b/vortex-tensor/src/encodings/turboquant/compute/take.rs new file mode 100644 index 00000000000..7614f1577a7 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/compute/take.rs @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::dict::TakeExecute; +use vortex_error::VortexResult; + +use crate::encodings::turboquant::TurboQuant; + +impl TakeExecute for TurboQuant { + fn take( + array: ArrayView<'_, TurboQuant>, + indices: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + // FSL children handle per-row take natively. + let taken_codes = array.codes().take(indices.clone())?; + let taken_norms = array.norms().take(indices.clone())?; + + Ok(Some( + TurboQuant::try_new_array( + array.dtype().clone(), + taken_codes, + taken_norms, + array.centroids().clone(), + array.rotation_signs().clone(), + )? + .into_array(), + )) + } +} diff --git a/vortex-tensor/src/encodings/turboquant/decompress.rs b/vortex-tensor/src/encodings/turboquant/decompress.rs new file mode 100644 index 00000000000..362207913a3 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/decompress.rs @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant decoding (dequantization) logic. 
+ +use num_traits::FromPrimitive; +use num_traits::Zero; +use vortex_array::Array; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; +use vortex_array::match_each_float_ptype; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; + +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::array::rotation::RotationMatrix; +use crate::utils::extension_element_ptype; + +/// Decompress a `TurboQuantArray` into a [`Vector`] extension array. +/// +/// The returned array is an [`ExtensionArray`] with the original Vector dtype wrapping a +/// `FixedSizeListArray` of f32 elements. +/// +/// [`Vector`]: crate::vector::Vector +pub fn execute_decompress( + array: Array, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let dim = array.dimension() as usize; + let padded_dim = array.padded_dim() as usize; + let num_rows = array.norms().len(); + let ext_dtype = array.dtype().as_extension().clone(); + let element_ptype = extension_element_ptype(&ext_dtype)?; + + if num_rows == 0 { + let fsl_validity = Validity::from(ext_dtype.storage_dtype().nullability()); + + match_each_float_ptype!(element_ptype, |T| { + let elements = PrimitiveArray::empty::(Nullability::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + fsl_validity, + 0, + )?; + + return Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()); + }) + } + + // Read stored centroids (always f32). 
+ let centroids_prim = array.centroids().clone().execute::(ctx)?; + let centroids = centroids_prim.as_slice::(); + + // FastLanes SIMD-unpacks the 1-bit bitpacked rotation signs into u8 0/1 values, + // then we expand to u32 XOR masks once (amortized over all rows). This enables + // branchless XOR-based sign application in the per-row SRHT hot loop. + let signs_prim = array + .rotation_signs() + .clone() + .execute::(ctx)?; + let rotation = RotationMatrix::from_u8_slice(signs_prim.as_slice::(), dim)?; + + // Unpack codes from FixedSizeListArray -> flat u8 elements. + let codes_fsl = array.codes().clone().execute::(ctx)?; + let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); + let indices = codes_prim.as_slice::(); + + // Read norms in their native precision. Norms carry the validity of the array. + let norms_prim = array.norms().clone().execute::(ctx)?; + let output_validity = array.norms().validity()?; + + // MSE decode: dequantize (f32) -> inverse rotate (f32) -> scale by norm -> cast to T. + // The rotation and centroid lookup always happen in f32. The final output is cast to the + // Vector's element type to match the original storage dtype. + match_each_float_ptype!(element_ptype, |T| { + decompress_typed::( + &norms_prim, + centroids, + &rotation, + indices, + dim, + padded_dim, + num_rows, + ) + .and_then(|elements| { + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + output_validity, + num_rows, + )?; + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + }) + }) +} + +/// Typed decompress: reads norms as `T`, dequantizes in f32, and produces output as `T`. 
+fn decompress_typed( + norms_prim: &PrimitiveArray, + centroids: &[f32], + rotation: &RotationMatrix, + indices: &[u8], + dim: usize, + padded_dim: usize, + num_rows: usize, +) -> VortexResult { + let norms = norms_prim.as_slice::(); + + let mut output = BufferMut::::with_capacity(num_rows * dim); + let mut dequantized = vec![0.0f32; padded_dim]; + let mut unrotated = vec![0.0f32; padded_dim]; + + for row in 0..num_rows { + let row_indices = &indices[row * padded_dim..(row + 1) * padded_dim]; + let norm = norms[row]; + + for idx in 0..padded_dim { + dequantized[idx] = centroids[row_indices[idx] as usize]; + } + + rotation.inverse_rotate(&dequantized, &mut unrotated); + + for idx in 0..dim { + // Convert f32 dequantized value to T, then scale by the native-precision norm. + let val = T::from_f32(unrotated[idx]).unwrap_or_else(T::zero) * norm; + output.push(val); + } + } + + Ok(PrimitiveArray::new::( + output.freeze(), + Validity::NonNullable, + )) +} diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs new file mode 100644 index 00000000000..0bd6efff68b --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant vector quantization encoding for Vortex. +//! +//! Implements the TurboQuant algorithm ([arXiv:2504.19874]) for lossy compression of +//! high-dimensional vector data. The encoding operates on [`Vector`] extension arrays, +//! compressing their `FixedSizeList` storage into quantized codes with an SRHT rotation. +//! +//! [arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 +//! [`Vector`]: crate::vector::Vector +//! +//! # Overview +//! +//! TurboQuant minimizes mean-squared reconstruction error (1-8 bits per coordinate) +//! using MSE-optimal scalar quantization with an SRHT rotation for coordinate independence. +//! +//! # Theoretical error bounds +//! +//! 
For unit-norm vectors quantized at `b` bits per coordinate, the paper's Theorem 1 +//! guarantees normalized MSE distortion: +//! +//! > `E[||x - x_hat||^2 / ||x||^2] <= (sqrt(3) * pi / 2) / 4^b` +//! +//! | Bits | MSE bound | Quality | +//! |------|------------|-------------------| +//! | 1 | 6.80e-01 | Poor | +//! | 2 | 1.70e-01 | Usable for ANN | +//! | 3 | 4.25e-02 | Good | +//! | 4 | 1.06e-02 | Very good | +//! | 5 | 2.66e-03 | Excellent | +//! | 6 | 6.64e-04 | Near-lossless | +//! | 7 | 1.66e-04 | Near-lossless | +//! | 8 | 4.15e-05 | Near-lossless | +//! +//! # Compression ratios +//! +//! Each vector is stored as `padded_dim * bit_width / 8` bytes of quantized codes plus a +//! 4-byte f32 norm. Non-power-of-2 dimensions are padded to the next power of 2 for the +//! Walsh-Hadamard transform, which reduces the effective ratio for those dimensions. +//! +//! | dim | padded | bits | f32 bytes | TQ bytes | ratio | +//! |------|--------|------|-----------|----------|--------| +//! | 768 | 1024 | 2 | 3072 | 260 | 11.8x | +//! | 1024 | 1024 | 2 | 4096 | 260 | 15.8x | +//! | 768 | 1024 | 4 | 3072 | 516 | 6.0x | +//! | 1024 | 1024 | 4 | 4096 | 516 | 7.9x | +//! | 768 | 1024 | 8 | 3072 | 1028 | 3.0x | +//! | 1024 | 1024 | 8 | 4096 | 1028 | 4.0x | +//! +//! # Example +//! +//! ``` +//! use vortex_array::IntoArray; +//! use vortex_array::VortexSessionExecute; +//! use vortex_array::arrays::ExtensionArray; +//! use vortex_array::arrays::FixedSizeListArray; +//! use vortex_array::arrays::PrimitiveArray; +//! use vortex_array::dtype::extension::ExtDType; +//! use vortex_array::extension::EmptyMetadata; +//! use vortex_array::validity::Validity; +//! use vortex_buffer::BufferMut; +//! use vortex_array::session::ArraySession; +//! use vortex_session::VortexSession; +//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode}; +//! use vortex_tensor::vector::Vector; +//! +//! // Create a Vector extension array of 100 random 128-d vectors. +//! 
let num_rows = 100; +//! let dim = 128u32; +//! let mut buf = BufferMut::::with_capacity(num_rows * dim as usize); +//! for i in 0..(num_rows * dim as usize) { +//! buf.push((i as f32 * 0.001).sin()); +//! } +//! let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); +//! let fsl = FixedSizeListArray::try_new( +//! elements.into_array(), dim, Validity::NonNullable, num_rows, +//! ).unwrap(); +//! let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) +//! .unwrap().erased(); +//! let ext = ExtensionArray::new(ext_dtype, fsl.into_array()); +//! +//! // Quantize at 2 bits per coordinate. +//! let config = TurboQuantConfig { bit_width: 2, seed: Some(42) }; +//! let session = VortexSession::empty().with::(); +//! let mut ctx = session.create_execution_ctx(); +//! let encoded = turboquant_encode(&ext, &config, &mut ctx).unwrap(); +//! +//! // Verify compression: 100 vectors x 128 dims x 4 bytes = 51200 bytes input. +//! assert!(encoded.nbytes() < 51200); +//! 
``` + +mod array; +pub use array::data::TurboQuantData; +pub use array::scheme::TurboQuantScheme; + +pub(crate) mod compute; + +mod vtable; +pub use vtable::TurboQuant; + +mod compress; +pub use compress::TurboQuantConfig; +pub use compress::turboquant_encode; + +mod decompress; + +#[cfg(test)] +mod tests; diff --git a/vortex-tensor/src/encodings/turboquant/tests.rs b/vortex-tensor/src/encodings/turboquant/tests.rs new file mode 100644 index 00000000000..731789bea2c --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/tests.rs @@ -0,0 +1,911 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::LazyLock; + +use rand::SeedableRng; +use rand::rngs::StdRng; +use rand_distr::Distribution; +use rand_distr::Normal; +use rstest::rstest; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::extension::EmptyMetadata; +use vortex_array::session::ArraySession; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_session::VortexSession; + +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::TurboQuantConfig; +use crate::encodings::turboquant::array::rotation::RotationMatrix; +use crate::encodings::turboquant::turboquant_encode; +use crate::scalar_fns::ApproxOptions; +use crate::scalar_fns::l2_norm::L2Norm; +use crate::vector::Vector; + +static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + +/// Create a FixedSizeListArray of random f32 vectors (i.i.d. standard normal) with the given +/// validity. 
+fn make_fsl_with_validity( + num_rows: usize, + dim: usize, + seed: u64, + validity: Validity, +) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(seed); + let normal = Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + validity, + num_rows, + ) + .unwrap() +} + +/// Create a non-nullable FixedSizeListArray of random f32 vectors (i.i.d. standard normal). +fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(seed); + let normal = Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + Validity::NonNullable, + num_rows, + ) + .unwrap() +} + +/// Wrap a `FixedSizeListArray` in a `Vector` extension array. 
+fn make_vector_ext(fsl: &FixedSizeListArray) -> ExtensionArray { + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) + .unwrap() + .erased(); + ExtensionArray::new(ext_dtype, fsl.clone().into_array()) +} + +fn theoretical_mse_bound(bit_width: u8) -> f32 { + let sqrt3_pi_over_2 = (3.0f32).sqrt() * std::f32::consts::PI / 2.0; + sqrt3_pi_over_2 / (4.0f32).powi(bit_width as i32) +} + +fn per_vector_normalized_mse( + original: &[f32], + reconstructed: &[f32], + dim: usize, + num_rows: usize, +) -> f32 { + let mut total = 0.0f32; + for row in 0..num_rows { + let orig = &original[row * dim..(row + 1) * dim]; + let recon = &reconstructed[row * dim..(row + 1) * dim]; + let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); + if norm_sq < 1e-10 { + continue; + } + let err_sq: f32 = orig + .iter() + .zip(recon.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + total += err_sq / norm_sq; + } + total / num_rows as f32 +} + +/// Encode and decode, returning (original, decoded) flat f32 slices. 
+fn encode_decode( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, +) -> VortexResult<(Vec, Vec)> { + let original: Vec = { + let prim = fsl.elements().to_canonical().unwrap().into_primitive(); + prim.as_slice::().to_vec() + }; + let ext = make_vector_ext(fsl); + let config = config.clone(); + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let decoded_ext = encoded.execute::(&mut ctx)?; + let decoded_fsl = decoded_ext + .storage_array() + .to_canonical() + .unwrap() + .into_fixed_size_list(); + let decoded_elements: Vec = { + let prim = decoded_fsl + .elements() + .to_canonical() + .unwrap() + .into_primitive(); + prim.as_slice::().to_vec() + }; + Ok((original, decoded_elements)) +} + +// ----------------------------------------------------------------------- +// Roundtrip tests +// ----------------------------------------------------------------------- + +#[rstest] +#[case(32, 1)] +#[case(32, 2)] +#[case(32, 3)] +#[case(32, 4)] +#[case(128, 2)] +#[case(128, 4)] +#[case(128, 6)] +#[case(128, 8)] +#[case(256, 2)] +fn roundtrip(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let fsl = make_fsl(10, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + Ok(()) +} + +// ----------------------------------------------------------------------- +// MSE quality tests +// ----------------------------------------------------------------------- + +#[rstest] +#[case(128, 1)] +#[case(128, 2)] +#[case(128, 3)] +#[case(128, 4)] +#[case(256, 2)] +#[case(256, 4)] +fn mse_within_theoretical_bound(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + + let 
normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + let bound = theoretical_mse_bound(bit_width); + + assert!( + normalized_mse < bound, + "Normalized MSE {normalized_mse:.6} exceeds bound {bound:.6} \ + for dim={dim}, bits={bit_width}", + ); + Ok(()) +} + +#[rstest] +#[case(128, 6)] +#[case(128, 8)] +#[case(256, 6)] +#[case(256, 8)] +fn high_bitwidth_mse_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + + let config_4bit = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + }; + let (original_4, decoded_4) = encode_decode(&fsl, &config_4bit)?; + let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); + + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + assert!( + mse < mse_4bit, + "{bit_width}-bit MSE ({mse:.6}) should be < 4-bit MSE ({mse_4bit:.6})" + ); + assert!(mse < 0.01, "{bit_width}-bit MSE ({mse:.6}) should be < 1%"); + Ok(()) +} + +#[test] +fn mse_decreases_with_bits() -> VortexResult<()> { + let dim = 128; + let num_rows = 50; + let fsl = make_fsl(num_rows, dim, 99); + + let mut prev_mse = f32::MAX; + for bit_width in 1..=8u8 { + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + assert!( + mse <= prev_mse * 1.01, + "MSE should decrease: {bit_width}-bit={mse:.6} > prev={prev_mse:.6}" + ); + prev_mse = mse; + } + Ok(()) +} + +// ----------------------------------------------------------------------- +// Edge cases +// ----------------------------------------------------------------------- + +#[rstest] +#[case(0)] +#[case(1)] +fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { + let 
fsl = make_fsl(num_rows, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let decoded = encoded.execute::(&mut ctx)?; + assert_eq!(decoded.len(), num_rows); + Ok(()) +} + +#[rstest] +#[case(1)] +#[case(2)] +fn rejects_dimension_below_3(#[case] dim: usize) { + let fsl = make_fsl_small(dim); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(0), + }; + let mut ctx = SESSION.create_execution_ctx(); + assert!(turboquant_encode(&ext, &config, &mut ctx).is_err()); +} + +fn make_fsl_small(dim: usize) -> FixedSizeListArray { + let mut buf = BufferMut::::with_capacity(dim); + for i in 0..dim { + buf.push(i as f32 + 1.0); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + Validity::NonNullable, + 1, + ) + .unwrap() +} + +/// Verify that all-zero vectors roundtrip correctly (norm == 0 branch). +#[test] +fn all_zero_vectors_roundtrip() -> VortexResult<()> { + let num_rows = 10; + let dim = 128; + let buf = BufferMut::::full(0.0f32, num_rows * dim); + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + Validity::NonNullable, + num_rows, + )?; + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + // All-zero vectors should decode to all-zero (norm=0 -> 0 * anything = 0). 
+ for (i, (&o, &d)) in original.iter().zip(decoded.iter()).enumerate() { + assert_eq!(o, 0.0, "original[{i}] not zero"); + assert_eq!(d, 0.0, "decoded[{i}] not zero for all-zero input"); + } + Ok(()) +} + +/// Verify that f64 input is accepted and encoded (converted to f32 internally). +#[test] +fn f64_input_encodes_successfully() -> VortexResult<()> { + let num_rows = 10; + let dim = 64; + let mut rng = StdRng::seed_from_u64(99); + let normal = Normal::new(0.0f64, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + Validity::NonNullable, + num_rows, + )?; + + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + }; + // Verify encoding succeeds with f64 input (f64->f32 conversion). + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let encoded = encoded.as_opt::().unwrap(); + assert_eq!(encoded.norms().len(), num_rows); + assert_eq!(encoded.dimension() as usize, dim); + Ok(()) +} + +// ----------------------------------------------------------------------- +// Verification tests for stored metadata +// ----------------------------------------------------------------------- + +/// Verify that the centroids stored in the array match what `get_centroids()` computes. 
+#[test] +fn stored_centroids_match_computed() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let encoded = encoded.as_opt::().unwrap(); + + let mut ctx = SESSION.create_execution_ctx(); + let stored_centroids_prim = encoded + .centroids() + .clone() + .execute::(&mut ctx)?; + let stored = stored_centroids_prim.as_slice::(); + + let padded_dim = encoded.padded_dim(); + let computed = crate::encodings::turboquant::array::centroids::get_centroids(padded_dim, 3)?; + + assert_eq!(stored.len(), computed.len()); + for i in 0..stored.len() { + assert_eq!(stored[i], computed[i], "Centroid mismatch at {i}"); + } + Ok(()) +} + +/// Verify that stored rotation signs produce identical decode to seed-based decode. +#[test] +fn stored_rotation_signs_produce_correct_decode() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let encoded = encoded.as_opt::().unwrap(); + + // Decode via the stored-signs path (normal decode). + let mut ctx = SESSION.create_execution_ctx(); + let decoded_ext = encoded + .array() + .clone() + .execute::(&mut ctx)?; + let decoded_fsl = decoded_ext + .storage_array() + .to_canonical()? + .into_fixed_size_list(); + let decoded = decoded_fsl.elements().to_canonical()?.into_primitive(); + let decoded_slice = decoded.as_slice::(); + + // Verify stored signs match seed-derived signs. 
+ let rot_from_seed = RotationMatrix::try_new(123, 128)?; + let expected_u8 = rot_from_seed.export_inverse_signs_u8(); + let stored_signs = encoded + .rotation_signs() + .clone() + .execute::(&mut ctx)?; + let stored_u8 = stored_signs.as_slice::(); + + assert_eq!(expected_u8.len(), stored_u8.len()); + for i in 0..expected_u8.len() { + assert_eq!(expected_u8[i], stored_u8[i], "Sign mismatch at index {i}"); + } + + // Also verify decode output is non-empty and has expected size. + assert_eq!(decoded_slice.len(), 20 * 128); + Ok(()) +} + +// ----------------------------------------------------------------------- +// Compute pushdown tests +// ----------------------------------------------------------------------- + +#[test] +fn slice_preserves_data() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + // Full decompress then slice. + let mut ctx = SESSION.create_execution_ctx(); + let full_decoded = encoded.clone().execute::(&mut ctx)?; + let full_fsl = full_decoded + .storage_array() + .to_canonical()? + .into_fixed_size_list(); + let expected = full_fsl.slice(5..10)?; + let expected_prim = expected.to_canonical()?.into_fixed_size_list(); + let expected_elements = expected_prim.elements().to_canonical()?.into_primitive(); + + // Slice then decompress. + let sliced = encoded.slice(5..10)?; + let sliced_decoded = sliced.execute::(&mut ctx)?; + let sliced_fsl = sliced_decoded + .storage_array() + .to_canonical()? 
+ .into_fixed_size_list(); + let actual_elements = sliced_fsl.elements().to_canonical()?.into_primitive(); + + assert_eq!( + expected_elements.as_slice::(), + actual_elements.as_slice::() + ); + Ok(()) +} + +#[test] +fn scalar_at_matches_decompress() -> VortexResult<()> { + let fsl = make_fsl(10, 64, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + let full_decoded = encoded.clone().execute::(&mut ctx)?; + + for i in [0, 1, 5, 9] { + let expected = full_decoded.scalar_at(i)?; + let actual = encoded.scalar_at(i)?; + assert_eq!(expected, actual, "scalar_at mismatch at index {i}"); + } + Ok(()) +} + +#[test] +fn l2_norm_readthrough() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let tq = encoded.as_opt::().unwrap(); + + // Stored norms should match the actual L2 norms of the input. 
+ let norms_prim = tq.norms().to_canonical()?.into_primitive(); + let stored_norms = norms_prim.as_slice::(); + + let input_prim = fsl.elements().to_canonical()?.into_primitive(); + let input_f32 = input_prim.as_slice::(); + for row in 0..10 { + let vec = &input_f32[row * 128..(row + 1) * 128]; + let actual_norm: f32 = vec.iter().map(|&v| v * v).sum::().sqrt(); + assert!( + (stored_norms[row] - actual_norm).abs() < 1e-5, + "norm mismatch at row {row}: stored={}, actual={}", + stored_norms[row], + actual_norm + ); + } + Ok(()) +} + +#[test] +fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let tq = encoded.as_opt::().unwrap(); + + // Compute exact cosine similarity from original data. + let input_prim = fsl.elements().to_canonical()?.into_primitive(); + let input_f32 = input_prim.as_slice::(); + + // Read quantized codes, norms, and centroids for approximate computation. 
+ let mut ctx = SESSION.create_execution_ctx(); + let pd = tq.padded_dim() as usize; + let norms_prim = tq.norms().clone().execute::(&mut ctx)?; + let norms = norms_prim.as_slice::(); + let codes_fsl = tq.codes().clone().execute::(&mut ctx)?; + let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); + let all_codes = codes_prim.as_slice::(); + let centroids_prim = tq.centroids().clone().execute::(&mut ctx)?; + let centroid_vals = centroids_prim.as_slice::(); + + for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] { + let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128]; + let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128]; + + let dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum(); + let norm_a: f32 = vec_a.iter().map(|&v| v * v).sum::().sqrt(); + let norm_b: f32 = vec_b.iter().map(|&v| v * v).sum::().sqrt(); + let exact_cos = dot / (norm_a * norm_b); + + // Approximate cosine similarity in quantized domain. + let approx_cos = if norms[row_a] == 0.0 || norms[row_b] == 0.0 { + 0.0 + } else { + let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd]; + let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd]; + codes_a + .iter() + .zip(codes_b.iter()) + .map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize]) + .sum::() + }; + + // 4-bit quantization: expect reasonable accuracy. + let error = (exact_cos - approx_cos).abs(); + assert!( + error < 0.15, + "cosine similarity error too large for ({row_a}, {row_b}): \ + exact={exact_cos:.4}, approx={approx_cos:.4}, error={error:.4}" + ); + } + Ok(()) +} + +/// Verify that the encoded array's dtype is a Vector extension type. 
+#[test] +fn encoded_dtype_is_vector_extension() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + // The encoded TurboQuant array should claim a Vector extension dtype. + assert!( + encoded.dtype().is_extension(), + "TurboQuant dtype should be an extension type, got {}", + encoded.dtype() + ); + assert!( + encoded.dtype().as_extension().is::(), + "TurboQuant dtype should be a Vector extension type" + ); + Ok(()) +} + +// ----------------------------------------------------------------------- +// Nullable vector tests +// ----------------------------------------------------------------------- + +/// Encode a nullable Vector array and verify roundtrip preserves validity and non-null values. +#[test] +fn nullable_vectors_roundtrip() -> VortexResult<()> { + // Rows 2, 5, 7 are null. + let validity = Validity::from_iter([ + true, true, false, true, true, false, true, false, true, true, + ]); + let fsl = make_fsl_with_validity(10, 128, 42, validity); + let ext = make_vector_ext(&fsl); + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + assert_eq!(encoded.len(), 10); + assert!(encoded.dtype().is_nullable()); + + // Check validity of the encoded array. + let encoded_validity = encoded.validity()?; + for i in 0..10 { + let expected = ![2, 5, 7].contains(&i); + assert_eq!( + encoded_validity.is_valid(i)?, + expected, + "validity mismatch at row {i}" + ); + } + + // Decode and verify non-null rows have correct data. + let decoded_ext = encoded.execute::(&mut ctx)?; + assert_eq!(decoded_ext.len(), 10); + + let decoded_fsl = decoded_ext + .storage_array() + .to_canonical()? 
+ .into_fixed_size_list(); + let decoded_prim = decoded_fsl.elements().to_canonical()?.into_primitive(); + let decoded_f32 = decoded_prim.as_slice::(); + + // Original f32 elements for non-null row comparison. + let orig_prim = fsl.elements().to_canonical()?.into_primitive(); + let orig_f32 = orig_prim.as_slice::(); + + // Non-null rows should have reasonable reconstruction (within MSE bounds). + for row in [0, 1, 3, 4, 6, 8, 9] { + let orig_vec = &orig_f32[row * 128..(row + 1) * 128]; + let dec_vec = &decoded_f32[row * 128..(row + 1) * 128]; + let norm_sq: f32 = orig_vec.iter().map(|&v| v * v).sum(); + let err_sq: f32 = orig_vec + .iter() + .zip(dec_vec.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + // 3-bit normalized MSE should be well under the theoretical bound. + assert!( + err_sq / norm_sq < 0.1, + "non-null row {row} has excessive reconstruction error" + ); + } + Ok(()) +} + +/// Verify that norms carry the validity: null vectors have null norms. +#[test] +fn nullable_norms_match_validity() -> VortexResult<()> { + let validity = Validity::from_iter([true, false, true, false, true]); + let fsl = make_fsl_with_validity(5, 64, 42, validity); + let ext = make_vector_ext(&fsl); + + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + let tq = encoded.as_opt::().unwrap(); + + let norms_validity = tq.norms().validity()?; + for i in 0..5 { + let expected = i % 2 == 0; // rows 0, 2, 4 are valid + assert_eq!( + norms_validity.is_valid(i)?, + expected, + "norms validity mismatch at row {i}" + ); + } + Ok(()) +} + +/// Verify that L2Norm readthrough works correctly on nullable TurboQuant arrays. 
+#[test] +fn nullable_l2_norm_readthrough() -> VortexResult<()> { + let validity = Validity::from_iter([true, false, true, false, true]); + let fsl = make_fsl_with_validity(5, 64, 42, validity); + let ext = make_vector_ext(&fsl); + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + // Compute L2Norm on the encoded array. + let norm_sfn = L2Norm::try_new_array(&ApproxOptions::Exact, encoded, 5)?; + let norms: PrimitiveArray = norm_sfn.into_array().execute(&mut ctx)?; + + // Null rows should have null norms, valid rows should have correct norms. + let orig_prim = fsl.elements().to_canonical()?.into_primitive(); + let orig_f32 = orig_prim.as_slice::(); + for row in 0..5 { + if row % 2 == 0 { + assert!(norms.is_valid(row)?, "row {row} should be valid"); + let expected: f32 = orig_f32[row * 64..(row + 1) * 64] + .iter() + .map(|&v| v * v) + .sum::() + .sqrt(); + let actual = norms.as_slice::()[row]; + assert!( + (actual - expected).abs() < 1e-5, + "norm mismatch at valid row {row}: actual={actual}, expected={expected}" + ); + } else { + assert!(!norms.is_valid(row)?, "row {row} should be null"); + } + } + Ok(()) +} + +/// Verify that slicing a nullable TurboQuant array preserves validity. +#[test] +fn nullable_slice_preserves_validity() -> VortexResult<()> { + // Rows 2, 5, 7 are null. + let validity = Validity::from_iter([ + true, true, false, true, true, false, true, false, true, true, + ]); + let fsl = make_fsl_with_validity(10, 64, 42, validity); + let ext = make_vector_ext(&fsl); + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + // Slice rows 1..6 -> [true, false, true, true, false]. 
+ let sliced = encoded.slice(1..6)?; + assert_eq!(sliced.len(), 5); + + let sliced_validity = sliced.validity()?; + let expected = [true, false, true, true, false]; + for (i, &exp) in expected.iter().enumerate() { + assert_eq!( + sliced_validity.is_valid(i)?, + exp, + "sliced validity mismatch at index {i}" + ); + } + Ok(()) +} + +// ----------------------------------------------------------------------- +// Serde roundtrip tests +// ----------------------------------------------------------------------- + +/// Verify that a TurboQuant array survives serialize/deserialize. +#[test] +fn serde_roundtrip() -> VortexResult<()> { + use vortex_array::ArrayContext; + use vortex_array::ArrayEq; + use vortex_array::Precision; + use vortex_array::serde::SerializeOptions; + use vortex_array::serde::SerializedArray; + use vortex_array::session::ArraySessionExt; + use vortex_buffer::ByteBufferMut; + use vortex_fastlanes::BitPacked; + use vortex_session::registry::ReadContext; + + let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + + let dtype = encoded.dtype().clone(); + let len = encoded.len(); + + // Serialize. + let array_ctx = ArrayContext::empty(); + let serialized = encoded.serialize(&array_ctx, &SerializeOptions::default())?; + + let mut concat = ByteBufferMut::empty(); + for buf in serialized { + concat.extend_from_slice(buf.as_ref()); + } + + // Deserialize. The session needs TurboQuant and BitPacked (for rotation signs) registered. 
+ let serde_session = VortexSession::empty().with::(); + serde_session.arrays().register(TurboQuant); + serde_session.arrays().register(BitPacked); + + let parts = SerializedArray::try_from(concat.freeze())?; + let decoded = parts.decode( + &dtype, + len, + &ReadContext::new(array_ctx.to_ids()), + &serde_session, + )?; + + assert!( + decoded.array_eq(&encoded, Precision::Value), + "serde roundtrip did not preserve array equality" + ); + Ok(()) +} + +/// Verify that a degenerate (empty) TurboQuant array survives serialize/deserialize. +#[test] +fn serde_roundtrip_empty() -> VortexResult<()> { + use vortex_array::ArrayContext; + use vortex_array::ArrayEq; + use vortex_array::Precision; + use vortex_array::serde::SerializeOptions; + use vortex_array::serde::SerializedArray; + use vortex_array::session::ArraySessionExt; + use vortex_buffer::ByteBufferMut; + use vortex_fastlanes::BitPacked; + use vortex_session::registry::ReadContext; + + let fsl = make_fsl(0, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(123), + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = turboquant_encode(&ext, &config, &mut ctx)?; + assert_eq!(encoded.len(), 0); + + let dtype = encoded.dtype().clone(); + let len = encoded.len(); + + let array_ctx = ArrayContext::empty(); + let serialized = encoded.serialize(&array_ctx, &SerializeOptions::default())?; + + let mut concat = ByteBufferMut::empty(); + for buf in serialized { + concat.extend_from_slice(buf.as_ref()); + } + + let serde_session = VortexSession::empty().with::(); + serde_session.arrays().register(TurboQuant); + serde_session.arrays().register(BitPacked); + + let parts = SerializedArray::try_from(concat.freeze())?; + let decoded = parts.decode( + &dtype, + len, + &ReadContext::new(array_ctx.to_ids()), + &serde_session, + )?; + + assert!( + decoded.array_eq(&encoded, Precision::Value), + "serde roundtrip did not preserve array equality" + ); + Ok(()) +} diff 
--git a/vortex-tensor/src/encodings/turboquant/vtable.rs b/vortex-tensor/src/encodings/turboquant/vtable.rs new file mode 100644 index 00000000000..d6f5f998041 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/vtable.rs @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! VTable implementation for TurboQuant encoding. + +use std::hash::Hash; +use std::hash::Hasher; +use std::sync::Arc; + +use vortex_array::Array; +use vortex_array::ArrayEq; +use vortex_array::ArrayHash; +use vortex_array::ArrayId; +use vortex_array::ArrayParts; +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::ExecutionCtx; +use vortex_array::ExecutionResult; +use vortex_array::Precision; +use vortex_array::buffer::BufferHandle; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::dtype::extension::ExtDTypeRef; +use vortex_array::serde::ArrayChildren; +use vortex_array::vtable; +use vortex_array::vtable::VTable; +use vortex_array::vtable::ValidityChild; +use vortex_array::vtable::ValidityVTableFromChild; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_err; +use vortex_error::vortex_panic; +use vortex_session::VortexSession; + +use crate::encodings::turboquant::TurboQuantData; +use crate::encodings::turboquant::array::slots::Slot; +use crate::encodings::turboquant::compute::rules::PARENT_KERNELS; +use crate::encodings::turboquant::compute::rules::RULES; +use crate::encodings::turboquant::decompress::execute_decompress; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::vector::Vector; + +/// Encoding marker type for TurboQuant. 
+#[derive(Clone, Debug)] +pub struct TurboQuant; + +impl TurboQuant { + pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant"); + + /// Validates that `dtype` is a [`Vector`](crate::vector::Vector) extension type with + /// dimension >= 3. + /// + /// Returns the validated [`ExtDTypeRef`] on success, which can be used to extract the + /// element ptype and list size. + pub fn validate_dtype(dtype: &DType) -> VortexResult<&ExtDTypeRef> { + let ext = dtype + .as_extension_opt() + .filter(|e| e.is::()) + .ok_or_else(|| { + vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}") + })?; + + let dimension = extension_list_size(ext)?; + vortex_ensure!( + dimension >= 3, + "TurboQuant requires dimension >= 3, got {dimension}" + ); + + Ok(ext) + } + + /// Creates a new [`TurboQuantArray`]. + /// + /// Internally calls [`TurboQuantData::try_new`]. + pub fn try_new_array( + dtype: DType, + codes: ArrayRef, + norms: ArrayRef, + centroids: ArrayRef, + rotation_signs: ArrayRef, + ) -> VortexResult { + let data = TurboQuantData::try_new(&dtype, codes, norms, centroids, rotation_signs)?; + + let parts = ArrayParts::new(TurboQuant, dtype, data.norms().len(), data); + + Array::try_from_parts(parts) + } +} + +vtable!(TurboQuant, TurboQuant, TurboQuantData); + +impl VTable for TurboQuant { + type ArrayData = TurboQuantData; + type OperationsVTable = TurboQuant; + type ValidityVTable = ValidityVTableFromChild; + + fn id(&self) -> ArrayId { + Self::ID + } + + fn validate(&self, data: &Self::ArrayData, dtype: &DType, len: usize) -> VortexResult<()> { + let ext = dtype + .as_extension_opt() + .filter(|e| e.is::()) + .ok_or_else(|| { + vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}") + })?; + + let dimension = extension_list_size(ext)?; + vortex_ensure!( + dimension >= 3, + "TurboQuant requires dimension >= 3, got {dimension}" + ); + + vortex_ensure_eq!(data.dimension(), dimension); + + // TODO(connor): In the future, we may not need 
to validate `len` on the array data because + // the child arrays will be located somewhere else. + // bit_width == 0 is only valid for degenerate (empty) arrays. A non-empty array with + // bit_width == 0 would have zero centroids while codes reference centroid indices. + vortex_ensure!( + data.bit_width > 0 || len == 0, + "bit_width == 0 is only valid for empty arrays, got len={len}" + ); + + Ok(()) + } + + fn array_hash(array: &TurboQuantData, state: &mut H, precision: Precision) { + array.dimension.hash(state); + array.bit_width.hash(state); + for slot in &array.slots { + slot.is_some().hash(state); + if let Some(child) = slot { + child.array_hash(state, precision); + } + } + } + + fn array_eq(array: &TurboQuantData, other: &TurboQuantData, precision: Precision) -> bool { + array.dimension == other.dimension + && array.bit_width == other.bit_width + && array.slots.len() == other.slots.len() + && array + .slots + .iter() + .zip(other.slots.iter()) + .all(|(a, b)| match (a, b) { + (Some(a), Some(b)) => a.array_eq(b, precision), + (None, None) => true, + _ => false, + }) + } + + fn nbuffers(_array: ArrayView) -> usize { + 0 + } + + fn buffer(_array: ArrayView, idx: usize) -> BufferHandle { + vortex_panic!("TurboQuantArray buffer index {idx} out of bounds") + } + + fn buffer_name(_array: ArrayView, _idx: usize) -> Option { + None + } + + fn serialize(array: ArrayView<'_, Self>) -> VortexResult>> { + Ok(Some(vec![array.bit_width])) + } + + fn deserialize( + &self, + dtype: &DType, + len: usize, + metadata: &[u8], + _buffers: &[BufferHandle], + children: &dyn ArrayChildren, + _session: &VortexSession, + ) -> VortexResult { + vortex_ensure_eq!( + metadata.len(), + 1, + "TurboQuant metadata must be exactly 1 byte, got {}", + metadata.len() + ); + vortex_ensure!( + metadata[0] <= 8, + "bit_width is expected to be between 0 and 8, got {}", + metadata[0] + ); + + let bit_width = metadata[0]; + + // bit_width == 0 is only valid for degenerate (empty) arrays. 
A non-empty array with + // bit_width == 0 would have zero centroids while codes reference centroid indices. + vortex_ensure!( + bit_width > 0 || len == 0, + "bit_width == 0 is only valid for empty arrays, got len={len}" + ); + + // Validate and derive dimension and element ptype from the Vector extension dtype. + let ext = TurboQuant::validate_dtype(dtype)?; + let dimension = extension_list_size(ext)?; + let element_ptype = extension_element_ptype(ext)?; + + let padded_dim = dimension.next_power_of_two(); + + // Get the codes array (indices into the codebook). Codes are always non-nullable; + // null vectors are represented by all-zero codes with a null norm. + let codes_ptype = DType::Primitive(PType::U8, Nullability::NonNullable); + let codes_dtype = + DType::FixedSizeList(Arc::new(codes_ptype), padded_dim, Nullability::NonNullable); + let codes_array = children.get(0, &codes_dtype, len)?; + + // Get the L2 norms array. Norms carry the validity of the entire TurboQuant array: + // null vectors have null norms. + let norms_dtype = DType::Primitive(element_ptype, dtype.nullability()); + let norms_array = children.get(1, &norms_dtype, len)?; + + // Get the centroids array (codebook). + let num_centroids = if bit_width == 0 { + 0 // A degenerate TQ array. + } else { + 1usize << bit_width + }; + let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); + let centroids = children.get(2, ¢roids_dtype, num_centroids)?; + + // Get the rotation array. 
+ let signs_len = if len == 0 { 0 } else { 3 * padded_dim as usize }; + let signs_dtype = DType::Primitive(PType::U8, Nullability::NonNullable); + let rotation_signs = children.get(3, &signs_dtype, signs_len)?; + + Ok(TurboQuantData { + slots: vec![ + Some(codes_array), + Some(norms_array), + Some(centroids), + Some(rotation_signs), + ], + dimension, + bit_width, + }) + } + + fn slots(array: ArrayView<'_, Self>) -> &[Option] { + &array.data().slots + } + + fn slot_name(_array: ArrayView, idx: usize) -> String { + Slot::from_index(idx).name().to_string() + } + + fn with_slots(array: &mut TurboQuantData, slots: Vec>) -> VortexResult<()> { + vortex_ensure!( + slots.len() == Slot::COUNT, + "TurboQuantArray expects {} slots, got {}", + Slot::COUNT, + slots.len() + ); + array.slots = slots; + Ok(()) + } + + fn execute(array: Array, ctx: &mut ExecutionCtx) -> VortexResult { + Ok(ExecutionResult::done(execute_decompress(array, ctx)?)) + } + + fn execute_parent( + array: ArrayView, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } + + fn reduce_parent( + array: ArrayView, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + RULES.evaluate(array, parent, child_idx) + } +} + +impl ValidityChild for TurboQuant { + fn validity_child(array: &TurboQuantData) -> &ArrayRef { + array.norms() + } +} diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 6b55389d8c9..e17ec4c88f0 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -7,8 +7,10 @@ use vortex_array::dtype::session::DTypeSessionExt; use vortex_array::scalar_fn::session::ScalarFnSessionExt; +use vortex_array::session::ArraySessionExt; use vortex_session::VortexSession; +use crate::encodings::turboquant::TurboQuant; use crate::fixed_shape::FixedShapeTensor; use crate::scalar_fns::cosine_similarity::CosineSimilarity; use crate::scalar_fns::inner_product::InnerProduct; @@ -25,10 
+27,13 @@ pub mod encodings; mod utils; -/// Registers the tensor extension dtypes and scalar functions with the given session. +/// Initialize the Vortex tensor library with a Vortex session. pub fn initialize(session: &VortexSession) { session.dtypes().register(Vector); session.dtypes().register(FixedShapeTensor); + + session.arrays().register(TurboQuant); + session.scalar_fns().register(CosineSimilarity); session.scalar_fns().register(InnerProduct); session.scalar_fns().register(L2Norm); diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 22c51189380..f466ab91aff 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -28,6 +28,8 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::compute::cosine_similarity; use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::inner_product::InnerProduct; @@ -142,13 +144,29 @@ impl ScalarFnVTable for CosineSimilarity { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - let lhs = args.get(0)?.execute::(ctx)?.into_array(); - let rhs = args.get(1)?.execute::(ctx)?.into_array(); + let lhs_ref = args.get(0)?; + let rhs_ref = args.get(1)?; let len = args.row_count(); + // TurboQuant approximate path: check encoding before executing. + if options.is_approx() + && let (Some(lhs_tq), Some(rhs_tq)) = ( + lhs_ref.as_opt::(), + rhs_ref.as_opt::(), + ) + { + return cosine_similarity::cosine_similarity_quantized_column(lhs_tq, rhs_tq, ctx); + } + + let lhs = lhs_ref.execute::(ctx)?; + let rhs = rhs_ref.execute::(ctx)?; + // Compute combined validity. 
- let validity = lhs.validity()?.and(rhs.validity()?)?; + let validity = lhs.as_ref().validity()?.and(rhs.as_ref().validity()?)?; + + let lhs = lhs.into_array(); + let rhs = rhs.into_array(); // Compute inner product and norms as columnar operations, and propagate the options. let norm_lhs_arr = L2Norm::try_new_array(options, lhs.clone(), len)?; diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index d142649600d..4e1a3805d5a 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -29,6 +29,8 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::compute::cosine_similarity; use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; use crate::utils::extension_element_ptype; @@ -137,15 +139,28 @@ impl ScalarFnVTable for InnerProduct { fn execute( &self, - _options: &Self::Options, + options: &Self::Options, args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - let lhs: ExtensionArray = args.get(0)?.execute(ctx)?; - let rhs: ExtensionArray = args.get(1)?.execute(ctx)?; + let lhs_ref = args.get(0)?; + let rhs_ref = args.get(1)?; let row_count = args.row_count(); + // TurboQuant approximate path: check encoding before executing. + if options.is_approx() + && let (Some(lhs_tq), Some(rhs_tq)) = ( + lhs_ref.as_opt::(), + rhs_ref.as_opt::(), + ) + { + return cosine_similarity::dot_product_quantized_column(lhs_tq, rhs_tq, ctx); + } + + let lhs: ExtensionArray = lhs_ref.execute(ctx)?; + let rhs: ExtensionArray = rhs_ref.execute(ctx)?; + // Compute combined validity. 
let rhs_validity = rhs.as_ref().validity()?; let validity = lhs.as_ref().validity()?.and(rhs_validity)?; diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index ed29cc776b7..f573386768b 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -12,6 +12,7 @@ use vortex_array::IntoArray; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; +use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; @@ -28,6 +29,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use crate::encodings::turboquant::TurboQuant; use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; use crate::utils::extension_element_ptype; @@ -123,12 +125,24 @@ impl ScalarFnVTable for L2Norm { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - let input: ExtensionArray = args.get(0)?.execute(ctx)?; - + let input_ref = args.get(0)?; let row_count = args.row_count(); + + // TurboQuant stores exact precomputed norms -- no decompression needed. + // Norms are currently stored as f32; cast to the target dtype if needed + // (e.g., if the input extension has f64 elements). + if let Some(tq) = input_ref.as_opt::() { + let ext = input_ref.dtype().as_extension(); + let target_ptype = extension_element_ptype(ext)?; + let norms: PrimitiveArray = tq.norms().clone().execute(ctx)?; + let target_dtype = DType::Primitive(target_ptype, input_ref.dtype().nullability()); + return norms.into_array().cast(target_dtype); + } + + let input: ExtensionArray = input_ref.execute(ctx)?; let validity = input.as_ref().validity()?; - // Get list size (dimensions) from the dtype (validated by `return_dtype`). + // Get element ptype and list size from the dtype (validated by `return_dtype`). 
let ext = input.dtype().as_extension(); let list_size = extension_list_size(ext)? as usize; diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index b10fd335420..8fb1883b706 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -12,11 +12,25 @@ pub mod l2_norm; /// Options for tensor-related expressions that might have error. #[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] pub enum ApproxOptions { + /// Computes the exact result. #[default] Exact, + /// Allows approximate results. Approximate, } +impl ApproxOptions { + /// Returns `true` if the option is [`Exact`](Self::Exact). + pub fn is_exact(&self) -> bool { + matches!(self, Self::Exact) + } + + /// Returns `true` if the option is [`Approximate`](Self::Approximate). + pub fn is_approx(&self) -> bool { + matches!(self, Self::Approximate) + } +} + impl fmt::Display for ApproxOptions { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { diff --git a/vortex/Cargo.toml b/vortex/Cargo.toml index 896ec139251..f042f568c11 100644 --- a/vortex/Cargo.toml +++ b/vortex/Cargo.toml @@ -56,12 +56,15 @@ divan = { workspace = true } fastlanes = { workspace = true } mimalloc = { workspace = true } parquet = { workspace = true } +paste = { workspace = true } rand = { workspace = true } +rand_distr = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } tracing-subscriber = { workspace = true } vortex = { path = ".", features = ["tokio"] } +vortex-tensor = { workspace = true } [features] default = ["files", "zstd"] diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index 8a4cfa53f59..b5531beb767 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -17,7 +17,6 @@ use rand::prelude::IndexedRandom; use rand::rngs::StdRng; use 
vortex::array::IntoArray; use vortex::array::ToCanonical; -use vortex::array::VortexSessionExecute; use vortex::array::arrays::PrimitiveArray; use vortex::array::arrays::VarBinViewArray; use vortex::array::builders::dict::dict_encode; @@ -39,6 +38,7 @@ use vortex::encodings::sequence::sequence_encode; use vortex::encodings::zigzag::zigzag_encode; use vortex::encodings::zstd::Zstd; use vortex::encodings::zstd::ZstdData; +use vortex_array::VortexSessionExecute; use vortex_sequence::Sequence; use vortex_session::VortexSession; @@ -426,3 +426,113 @@ fn bench_zstd_decompress_string(bencher: Bencher) { .with_inputs(|| &compressed) .bench_refs(|a| a.to_canonical()); } + +// TurboQuant vector quantization benchmarks. +#[cfg(feature = "unstable_encodings")] +mod turboquant_benches { + use divan::Bencher; + use paste::paste; + use rand::SeedableRng; + use rand::rngs::StdRng; + use vortex::array::IntoArray; + use vortex::array::arrays::ExtensionArray; + use vortex::array::arrays::FixedSizeListArray; + use vortex::array::arrays::PrimitiveArray; + use vortex::array::dtype::extension::ExtDType; + use vortex::array::extension::EmptyMetadata; + use vortex::array::validity::Validity; + use vortex_array::VortexSessionExecute; + use vortex_buffer::BufferMut; + use vortex_tensor::encodings::turboquant::TurboQuantConfig; + use vortex_tensor::encodings::turboquant::turboquant_encode; + use vortex_tensor::vector::Vector; + + use super::SESSION; + use super::with_byte_counter; + + const NUM_VECTORS: usize = 1_000; + + /// Generate `num_vectors` random f32 Vector extension arrays of the given dimension + /// using i.i.d. standard normal components. This is a conservative test distribution: + /// real neural network embeddings typically have structure (clustered, anisotropic) + /// that the SRHT exploits for better quantization, so Gaussian i.i.d. is a + /// worst-case baseline for TurboQuant. 
+ fn setup_vector_ext(dim: usize) -> ExtensionArray { + let mut rng = StdRng::seed_from_u64(42); + let normal = rand_distr::Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::<f32>::with_capacity(NUM_VECTORS * dim); + for _ in 0..(NUM_VECTORS * dim) { + buf.push(rand_distr::Distribution::sample(&normal, &mut rng)); + } + + let elements = PrimitiveArray::new::<f32>(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + NUM_VECTORS, + ) + .unwrap(); + let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone()) + .unwrap() + .erased(); + ExtensionArray::new(ext_dtype, fsl.into_array()) + } + + fn turboquant_config(bit_width: u8) -> TurboQuantConfig { + TurboQuantConfig { + bit_width, + seed: Some(123), + } + } + + macro_rules! turboquant_bench { + (compress, $dim:literal, $bits:literal, $name:ident) => { + paste! { + #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit"))] + fn $name(bencher: Bencher) { + let ext = setup_vector_ext($dim); + let config = turboquant_config($bits); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &ext) + .bench_refs(|a| { + let mut ctx = SESSION.create_execution_ctx(); + turboquant_encode(a, &config, &mut ctx).unwrap() + }); + } + } + }; + (decompress, $dim:literal, $bits:literal, $name:ident) => { + paste! 
{ + #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit"))] + fn $name(bencher: Bencher) { + let ext = setup_vector_ext($dim); + let config = turboquant_config($bits); + let mut ctx = SESSION.create_execution_ctx(); + let compressed = turboquant_encode(&ext, &config, &mut ctx).unwrap(); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &compressed) + .bench_refs(|a| { + let mut ctx = SESSION.create_execution_ctx(); + a.clone() + .into_array() + .execute::<ExtensionArray>(&mut ctx) + .unwrap() + }); + } + } + }; + } + + turboquant_bench!(compress, 128, 4, bench_tq_compress_128_4); + turboquant_bench!(decompress, 128, 4, bench_tq_decompress_128_4); + turboquant_bench!(compress, 768, 4, bench_tq_compress_768_4); + turboquant_bench!(decompress, 768, 4, bench_tq_decompress_768_4); + turboquant_bench!(compress, 1024, 2, bench_tq_compress_1024_2); + turboquant_bench!(decompress, 1024, 2, bench_tq_decompress_1024_2); + turboquant_bench!(compress, 1024, 4, bench_tq_compress_1024_4); + turboquant_bench!(decompress, 1024, 4, bench_tq_decompress_1024_4); + turboquant_bench!(compress, 1024, 8, bench_tq_compress_1024_8); + turboquant_bench!(decompress, 1024, 8, bench_tq_decompress_1024_8); +}