From 19661447872fbf27fd8be70c74dffe54f904b0a5 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 2 Apr 2026 14:55:19 -0400 Subject: [PATCH] invert vortex tensor dependency Signed-off-by: Connor Tsui --- Cargo.lock | 7 +- vortex-file/Cargo.toml | 2 + vortex-file/src/lib.rs | 3 + vortex-tensor/Cargo.toml | 5 +- vortex-tensor/public-api.lock | 2 + vortex-tensor/src/fixed_shape/metadata.rs | 8 +-- vortex-tensor/src/fixed_shape/proto.rs | 6 +- vortex-tensor/src/fixed_shape/vtable.rs | 22 +++--- vortex-tensor/src/lib.rs | 17 +++++ vortex-tensor/src/matcher.rs | 4 +- .../src/scalar_fns/cosine_similarity.rs | 47 ++++++------- vortex-tensor/src/scalar_fns/l2_norm.rs | 45 +++++++------ vortex-tensor/src/utils.rs | 67 ++++++++++--------- vortex-tensor/src/vector/vtable.rs | 32 ++++----- vortex/Cargo.toml | 2 + 15 files changed, 153 insertions(+), 116 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4a4f4b14629..61c277226dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10097,6 +10097,7 @@ dependencies = [ "vortex-sequence", "vortex-session", "vortex-sparse", + "vortex-tensor", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10590,6 +10591,7 @@ dependencies = [ "vortex-sequence", "vortex-session", "vortex-sparse", + "vortex-tensor", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10962,7 +10964,10 @@ dependencies = [ "num-traits", "prost 0.14.3", "rstest", - "vortex", + "vortex-array", + "vortex-buffer", + "vortex-error", + "vortex-session", ] [[package]] diff --git a/vortex-file/Cargo.toml b/vortex-file/Cargo.toml index d568328bb52..e812ab228c6 100644 --- a/vortex-file/Cargo.toml +++ b/vortex-file/Cargo.toml @@ -54,6 +54,7 @@ vortex-scan = { workspace = true } vortex-sequence = { workspace = true } vortex-session = { workspace = true } vortex-sparse = { workspace = true } +vortex-tensor = { workspace = true, optional = true } vortex-utils = { workspace = true, features = ["dashmap"] } vortex-zigzag = { workspace = true } vortex-zstd = { workspace = true, optional = true } @@ -79,6 +80,7 @@ tokio = [ zstd = ["dep:vortex-zstd", "vortex-btrblocks/zstd", "vortex-btrblocks/pco"] # This feature enables unstable encodings for which we don't guarantee stability. unstable_encodings = [ + "dep:vortex-tensor", "vortex-zstd?/unstable_encodings", "vortex-btrblocks/unstable_encodings", ] diff --git a/vortex-file/src/lib.rs b/vortex-file/src/lib.rs index c8e1d7740d8..000df4a7bcb 100644 --- a/vortex-file/src/lib.rs +++ b/vortex-file/src/lib.rs @@ -178,4 +178,7 @@ pub fn register_default_encodings(session: &VortexSession) { vortex_fastlanes::initialize(session); vortex_runend::initialize(session); vortex_sequence::initialize(session); + + #[cfg(feature = "unstable_encodings")] + vortex_tensor::initialize(session); } diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index 6f4fe4511af..c80127bc7b3 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -17,7 +17,9 @@ version = { workspace = true } workspace = true [dependencies] -vortex = { workspace = true } +vortex-array = { workspace = true } +vortex-error = { workspace = true } +vortex-session = { workspace = true } itertools = { workspace = true } num-traits = { workspace = true } @@ -25,3 +27,4 @@ prost = { workspace = true } [dev-dependencies] rstest = { workspace = true } +vortex-buffer = { workspace = true } diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 6dd94c857d4..4a3b0d79fa4 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -259,3 +259,5 @@ pub fn vortex_tensor::vector::Vector::serialize_metadata(&self, _metadata: &Self pub fn vortex_tensor::vector::Vector::unpack_native<'a>(_ext_dtype: &'a vortex_array::dtype::extension::typed::ExtDType, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult pub fn vortex_tensor::vector::Vector::validate_dtype(ext_dtype: &vortex_array::dtype::extension::typed::ExtDType) -> vortex_error::VortexResult<()> + +pub fn vortex_tensor::initialize(session: &vortex_session::VortexSession) diff --git a/vortex-tensor/src/fixed_shape/metadata.rs b/vortex-tensor/src/fixed_shape/metadata.rs index fb46c67d213..264d18453c4 100644 --- a/vortex-tensor/src/fixed_shape/metadata.rs +++ b/vortex-tensor/src/fixed_shape/metadata.rs @@ -4,10 +4,10 @@ use std::fmt; use itertools::Either; -use vortex::error::VortexExpect; -use vortex::error::VortexResult; -use vortex::error::vortex_ensure; -use vortex::error::vortex_ensure_eq; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; /// Metadata for a `FixedShapeTensor` extension type. #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/vortex-tensor/src/fixed_shape/proto.rs b/vortex-tensor/src/fixed_shape/proto.rs index 06b4f45b726..9d89c56bcec 100644 --- a/vortex-tensor/src/fixed_shape/proto.rs +++ b/vortex-tensor/src/fixed_shape/proto.rs @@ -4,9 +4,9 @@ //! Protobuf serialization for [`FixedShapeTensorMetadata`]. use prost::Message; -use vortex::error::VortexExpect; -use vortex::error::VortexResult; -use vortex::error::vortex_err; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_err; use crate::fixed_shape::FixedShapeTensorMetadata; diff --git a/vortex-tensor/src/fixed_shape/vtable.rs b/vortex-tensor/src/fixed_shape/vtable.rs index 3c0b6512a65..21ab1ef5336 100644 --- a/vortex-tensor/src/fixed_shape/vtable.rs +++ b/vortex-tensor/src/fixed_shape/vtable.rs @@ -1,15 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex::dtype::DType; -use vortex::dtype::extension::ExtDType; -use vortex::dtype::extension::ExtId; -use vortex::dtype::extension::ExtVTable; -use vortex::error::VortexResult; -use vortex::error::vortex_bail; -use vortex::error::vortex_ensure; -use vortex::error::vortex_ensure_eq; -use vortex::scalar::ScalarValue; +use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::extension::ExtId; +use vortex_array::dtype::extension::ExtVTable; +use vortex_array::scalar::ScalarValue; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; @@ -77,8 +77,8 @@ impl ExtVTable for FixedShapeTensor { #[cfg(test)] mod tests { use rstest::rstest; - use vortex::dtype::extension::ExtVTable; - use vortex::error::VortexResult; + use vortex_array::dtype::extension::ExtVTable; + use vortex_error::VortexResult; use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index c036b9854b2..515c2f373ca 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -5,6 +5,15 @@ //! including unit vectors, spherical coordinates, and similarity measures such as cosine //! similarity. +use vortex_array::dtype::session::DTypeSessionExt; +use vortex_array::scalar_fn::session::ScalarFnSessionExt; +use vortex_session::VortexSession; + +use crate::fixed_shape::FixedShapeTensor; +use crate::scalar_fns::cosine_similarity::CosineSimilarity; +use crate::scalar_fns::l2_norm::L2Norm; +use crate::vector::Vector; + pub mod matcher; pub mod scalar_fns; @@ -14,3 +23,11 @@ pub mod vector; pub mod encodings; mod utils; + +/// Registers the tensor extension dtypes and scalar functions with the given session. +pub fn initialize(session: &VortexSession) { + session.dtypes().register(Vector); + session.dtypes().register(FixedShapeTensor); + session.scalar_fns().register(CosineSimilarity); + session.scalar_fns().register(L2Norm); +} diff --git a/vortex-tensor/src/matcher.rs b/vortex-tensor/src/matcher.rs index bb79ad7447e..16bca0bb043 100644 --- a/vortex-tensor/src/matcher.rs +++ b/vortex-tensor/src/matcher.rs @@ -3,8 +3,8 @@ //! Matcher for tensor-like extension types. -use vortex::dtype::extension::ExtDTypeRef; -use vortex::dtype::extension::Matcher; +use vortex_array::dtype::extension::ExtDTypeRef; +use vortex_array::dtype::extension::Matcher; use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 5155a7c8f08..539d4ca4a7c 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -8,23 +8,24 @@ use std::fmt::Formatter; use num_traits::Float; -use vortex::array::ArrayRef; -use vortex::array::ExecutionCtx; -use vortex::array::IntoArray; -use vortex::array::arrays::PrimitiveArray; -use vortex::array::match_each_float_ptype; -use vortex::dtype::DType; -use vortex::dtype::NativePType; -use vortex::dtype::Nullability; -use vortex::error::VortexResult; -use vortex::error::vortex_ensure; -use vortex::error::vortex_err; -use vortex::expr::Expression; -use vortex::scalar_fn::Arity; -use vortex::scalar_fn::ChildName; -use vortex::scalar_fn::ExecutionArgs; -use vortex::scalar_fn::ScalarFnId; -use vortex::scalar_fn::ScalarFnVTable; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; +use vortex_array::expr::Expression; +use vortex_array::expr::and; +use vortex_array::match_each_float_ptype; +use vortex_array::scalar_fn::Arity; +use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFnId; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; @@ -156,7 +157,7 @@ impl ScalarFnVTable for CosineSimilarity { let lhs_validity = expression.child(0).validity()?; let rhs_validity = expression.child(1).validity()?; - Ok(Some(vortex::expr::and(lhs_validity, rhs_validity))) + Ok(Some(and(lhs_validity, rhs_validity))) } fn is_null_sensitive(&self, _options: &Self::Options) -> bool { @@ -188,11 +189,11 @@ fn cosine_similarity_row(a: &[T], b: &[T]) -> T { #[cfg(test)] mod tests { use rstest::rstest; - use vortex::array::ArrayRef; - use vortex::array::ToCanonical; - use vortex::array::arrays::ScalarFnArray; - use vortex::error::VortexResult; - use vortex::scalar_fn::ScalarFn; + use vortex_array::ArrayRef; + use vortex_array::ToCanonical; + use vortex_array::arrays::ScalarFnArray; + use vortex_array::scalar_fn::ScalarFn; + use vortex_error::VortexResult; use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::cosine_similarity::CosineSimilarity; diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index b02035d4572..f319a0df0bb 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -8,23 +8,23 @@ use std::fmt::Formatter; use num_traits::Float; -use vortex::array::ArrayRef; -use vortex::array::ExecutionCtx; -use vortex::array::IntoArray; -use vortex::array::arrays::PrimitiveArray; -use vortex::array::match_each_float_ptype; -use vortex::dtype::DType; -use vortex::dtype::NativePType; -use vortex::dtype::Nullability; -use vortex::error::VortexResult; -use vortex::error::vortex_ensure; -use vortex::error::vortex_err; -use vortex::expr::Expression; -use vortex::scalar_fn::Arity; -use vortex::scalar_fn::ChildName; -use vortex::scalar_fn::ExecutionArgs; -use vortex::scalar_fn::ScalarFnId; -use vortex::scalar_fn::ScalarFnVTable; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; +use vortex_array::expr::Expression; +use vortex_array::match_each_float_ptype; +use vortex_array::scalar_fn::Arity; +use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFnId; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; @@ -156,10 +156,11 @@ fn l2_norm_row(v: &[T]) -> T { #[cfg(test)] mod tests { use rstest::rstest; - use vortex::array::ToCanonical; - use vortex::array::arrays::ScalarFnArray; - use vortex::error::VortexResult; - use vortex::scalar_fn::ScalarFn; + use vortex_array::ArrayRef; + use vortex_array::ToCanonical; + use vortex_array::arrays::ScalarFnArray; + use vortex_array::scalar_fn::ScalarFn; + use vortex_error::VortexResult; use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::l2_norm::L2Norm; @@ -168,7 +169,7 @@ mod tests { use crate::utils::test_helpers::vector_array; /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. - fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult> { + fn eval_l2_norm(input: ArrayRef, len: usize) -> VortexResult> { let scalar_fn = ScalarFn::new(L2Norm, ApproxOptions::Exact).erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![input], len)?; let prim = result.as_array().to_primitive(); diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 82e2d1f5b45..cd21cd65964 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -1,22 +1,22 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex::array::ArrayRef; -use vortex::array::ExecutionCtx; -use vortex::array::IntoArray; -use vortex::array::arrays::Constant; -use vortex::array::arrays::ConstantArray; -use vortex::array::arrays::Extension; -use vortex::array::arrays::FixedSizeListArray; -use vortex::array::arrays::PrimitiveArray; -use vortex::dtype::DType; -use vortex::dtype::NativePType; -use vortex::dtype::PType; -use vortex::dtype::extension::ExtDTypeRef; -use vortex::error::VortexResult; -use vortex::error::vortex_bail; -use vortex::error::vortex_ensure; -use vortex::error::vortex_err; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::Constant; +use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::Extension; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::PType; +use vortex_array::dtype::extension::ExtDTypeRef; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; /// Extracts the list size from a tensor-like extension dtype. /// @@ -123,21 +123,22 @@ pub fn extract_flat_elements( #[cfg(test)] pub mod test_helpers { - use vortex::array::ArrayRef; - use vortex::array::ExecutionCtx; - use vortex::array::IntoArray; - use vortex::array::arrays::ConstantArray; - use vortex::array::arrays::ExtensionArray; - use vortex::array::arrays::FixedSizeListArray; - use vortex::array::validity::Validity; - use vortex::buffer::Buffer; - use vortex::dtype::DType; - use vortex::dtype::Nullability; - use vortex::dtype::extension::ExtDType; - use vortex::error::VortexResult; - use vortex::error::vortex_err; - use vortex::extension::EmptyMetadata; - use vortex::scalar::Scalar; + use vortex_array::ArrayRef; + use vortex_array::ExecutionCtx; + use vortex_array::IntoArray; + use vortex_array::arrays::ConstantArray; + use vortex_array::arrays::ExtensionArray; + use vortex_array::arrays::FixedSizeListArray; + use vortex_array::dtype::DType; + use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::extension::EmptyMetadata; + use vortex_array::scalar::Scalar; + use vortex_array::validity::Validity; + use vortex_buffer::Buffer; + use vortex_error::VortexResult; + use vortex_error::vortex_err; use super::extension_list_size; use super::extension_storage; @@ -183,7 +184,7 @@ pub mod test_helpers { elements: &[f64], len: usize, ) -> VortexResult { - let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable); + let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); let children: Vec = elements .iter() @@ -204,7 +205,7 @@ pub mod test_helpers { /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`], representing a /// single query vector broadcast to `len` rows. pub fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult { - let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable); + let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); let children: Vec = elements .iter() diff --git a/vortex-tensor/src/vector/vtable.rs b/vortex-tensor/src/vector/vtable.rs index 61a0f35d9ff..2dda05b7363 100644 --- a/vortex-tensor/src/vector/vtable.rs +++ b/vortex-tensor/src/vector/vtable.rs @@ -1,15 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex::dtype::DType; -use vortex::dtype::extension::ExtDType; -use vortex::dtype::extension::ExtId; -use vortex::dtype::extension::ExtVTable; -use vortex::error::VortexResult; -use vortex::error::vortex_bail; -use vortex::error::vortex_ensure; -use vortex::extension::EmptyMetadata; -use vortex::scalar::ScalarValue; +use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::extension::ExtId; +use vortex_array::dtype::extension::ExtVTable; +use vortex_array::extension::EmptyMetadata; +use vortex_array::scalar::ScalarValue; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; use crate::vector::Vector; @@ -62,13 +62,13 @@ mod tests { use std::sync::Arc; use rstest::rstest; - use vortex::dtype::DType; - use vortex::dtype::Nullability; - use vortex::dtype::PType; - use vortex::dtype::extension::ExtDType; - use vortex::dtype::extension::ExtVTable; - use vortex::error::VortexResult; - use vortex::extension::EmptyMetadata; + use vortex_array::dtype::DType; + use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::dtype::extension::ExtVTable; + use vortex_array::extension::EmptyMetadata; + use vortex_error::VortexResult; use crate::vector::Vector; diff --git a/vortex/Cargo.toml b/vortex/Cargo.toml index d8dc89882b0..896ec139251 100644 --- a/vortex/Cargo.toml +++ b/vortex/Cargo.toml @@ -44,6 +44,7 @@ vortex-scan = { workspace = true } vortex-sequence = { workspace = true } vortex-session = { workspace = true } vortex-sparse = { workspace = true } +vortex-tensor = { workspace = true, optional = true } vortex-utils = { workspace = true } vortex-zigzag = { workspace = true } vortex-zstd = { workspace = true, optional = true } @@ -73,6 +74,7 @@ pretty = ["vortex-array/table-display"] serde = ["vortex-array/serde", "vortex-buffer/serde", "vortex-mask/serde"] # This feature enabled unstable encodings for which we don't guarantee stability. unstable_encodings = [ + "dep:vortex-tensor", "vortex-btrblocks/unstable_encodings", "vortex-file?/unstable_encodings", "vortex-zstd?/unstable_encodings",