diff --git a/encodings/parquet-variant/src/kernel.rs b/encodings/parquet-variant/src/kernel.rs index 08698b0f6b0..818aaf3b43f 100644 --- a/encodings/parquet-variant/src/kernel.rs +++ b/encodings/parquet-variant/src/kernel.rs @@ -19,11 +19,13 @@ use vortex_mask::Mask; use crate::ParquetVariant; use crate::array::ParquetVariantArray; +use crate::variant_get::VariantGetExecuteParent; pub(crate) static PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ ParentKernelSet::lift(&FilterExecuteAdaptor(ParquetVariant)), ParentKernelSet::lift(&SliceExecuteAdaptor(ParquetVariant)), ParentKernelSet::lift(&TakeExecuteAdaptor(ParquetVariant)), + ParentKernelSet::lift(&VariantGetExecuteParent), ]); impl SliceKernel for ParquetVariant { diff --git a/encodings/parquet-variant/src/lib.rs b/encodings/parquet-variant/src/lib.rs index 6188f8fe727..4cc3463f808 100644 --- a/encodings/parquet-variant/src/lib.rs +++ b/encodings/parquet-variant/src/lib.rs @@ -28,6 +28,7 @@ mod array; mod kernel; mod operations; mod validity; +mod variant_get; mod vtable; pub use array::ParquetVariantArray; diff --git a/encodings/parquet-variant/src/variant_get/mod.rs b/encodings/parquet-variant/src/variant_get/mod.rs new file mode 100644 index 00000000000..eb64bff3028 --- /dev/null +++ b/encodings/parquet-variant/src/variant_get/mod.rs @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Execute-parent kernel for `variant_get` on `ParquetVariantArray`. +//! +//! Delegates to `parquet_variant_compute::variant_get` after converting to Arrow. + +use std::sync::Arc; + +use parquet_variant::VariantPathElement; +use parquet_variant_compute::GetOptions; +use parquet_variant_compute::VariantArray as ArrowVariantArray; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::VariantArray; +use vortex_array::arrays::scalar_fn::ExactScalarFn; +use vortex_array::arrays::scalar_fn::ScalarFnArrayView; +use vortex_array::arrow::FromArrowArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::FieldName; +use vortex_array::dtype::Nullability; +use vortex_array::kernel::ExecuteParentKernel; +use vortex_array::scalar_fn::fns::variant_get::VariantGet; +use vortex_array::validity::Validity; +use vortex_buffer::BitBuffer; +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +use crate::ParquetVariant; +use crate::array::ParquetVariantArray; + +#[cfg(test)] +mod tests; + +#[derive(Debug)] +pub(crate) struct VariantGetExecuteParent; + +impl ExecuteParentKernel for VariantGetExecuteParent { + type Parent = ExactScalarFn; + + fn execute_parent( + &self, + array: &ParquetVariantArray, + parent: ScalarFnArrayView<'_, VariantGet>, + _child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let field_name: &FieldName = parent.options; + variant_get_impl(array, field_name, ctx).map(Some) + } +} + +fn variant_get_impl( + array: &ParquetVariantArray, + field_name: &FieldName, + ctx: &mut ExecutionCtx, +) -> VortexResult { + // Convert to Arrow VariantArray + let arrow_variant = array.to_arrow(ctx)?; + + // Build path for a single field access + let path_element = VariantPathElement::Field { + name: field_name.as_ref().into(), + }; + let options = GetOptions::new_with_path(vec![path_element].into()); + + // Delegate to the parquet-variant-compute kernel. + // With as_type = None, the result is itself a VariantArray. + let inner: Arc = Arc::new(arrow_variant.into_inner()); + let arrow_result = parquet_variant_compute::variant_get(&inner, options) + .map_err(|e| vortex_err!("variant_get failed: {e}"))?; + + // Convert back to Vortex. + let result_variant = ArrowVariantArray::try_new( + arrow_result + .as_any() + .downcast_ref::() + .ok_or_else(|| vortex_err!("variant_get did not return a StructArray"))?, + ) + .map_err(|e| vortex_err!("failed to create VariantArray from result: {e}"))?; + + // Ensure the result is always nullable (matching variant_get's return_dtype). + // Arrow may return a non-nullable result when no nulls are present. + let validity = result_variant + .nulls() + .map(|nulls| { + if nulls.null_count() == nulls.len() { + Validity::AllInvalid + } else { + Validity::from(BitBuffer::from(nulls.inner().clone())) + } + }) + .unwrap_or(Validity::AllValid); + + let metadata = ArrayRef::from_arrow( + result_variant.metadata_field() as &dyn arrow_array::Array, + false, + )?; + let value = result_variant + .value_field() + .map(|v| ArrayRef::from_arrow(v as &dyn arrow_array::Array, true)) + .transpose()?; + let typed_value = result_variant + .typed_value_field() + .map(|tv| ArrayRef::from_arrow(tv.as_ref(), true)) + .transpose()?; + + let pv = ParquetVariantArray::try_new(validity, metadata, value, typed_value)?; + debug_assert_eq!( + pv.dtype, + DType::Variant(Nullability::Nullable), + "variant_get result must be nullable" + ); + Ok(VariantArray::new(pv.into_array()).into_array()) +} diff --git a/encodings/parquet-variant/src/variant_get/tests.rs b/encodings/parquet-variant/src/variant_get/tests.rs new file mode 100644 index 00000000000..1911bf97b91 --- /dev/null +++ b/encodings/parquet-variant/src/variant_get/tests.rs @@ -0,0 +1,350 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use arrow_array::Array as ArrowArray; +use arrow_array::ArrayRef as ArrowArrayRef; +use arrow_array::StringArray; +use arrow_array::StructArray; +use arrow_buffer::NullBuffer; +use parquet_variant::Variant as PqVariant; +use parquet_variant::VariantBuilderExt; +use parquet_variant::VariantPath; +use parquet_variant_compute::GetOptions; +use parquet_variant_compute::VariantArray as ArrowVariantArray; +use parquet_variant_compute::VariantArrayBuilder; +use parquet_variant_compute::json_to_variant; +use rstest::rstest; +use vortex_array::ArrayRef; +use vortex_array::DynArray; +use vortex_array::LEGACY_SESSION; +use vortex_array::VortexSessionExecute; +use vortex_array::expr::root; +use vortex_array::expr::variant_get; +use vortex_error::VortexResult; + +use crate::ParquetVariant; +use crate::ParquetVariantArray; + +/// Apply variant_get and execute through the full pipeline (including execute_parent). +fn apply_variant_get(arr: &ArrayRef, field: &str) -> VortexResult { + let expr = variant_get(field, root()); + let array = arr.apply(&expr)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + array.execute::(&mut ctx) +} + +/// Convert a Vortex result back to an Arrow VariantArray for comparison. +fn vortex_to_arrow_variant(arr: &ArrayRef) -> ArrowVariantArray { + let pv = arr + .as_::() + .child() + .as_::(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + pv.to_arrow(&mut ctx).unwrap() +} + +/// Run variant_get through both Arrow and Vortex on the same input, and assert +/// the per-row results (value + validity) are identical by comparing at the Arrow level. +fn assert_matches_arrow(json_rows: &[&str], field: &str) { + // --- Arrow side --- + let arrow_strings: ArrowArrayRef = Arc::new(StringArray::from( + json_rows.iter().map(|s| Some(*s)).collect::>(), + )); + let arrow_variant = json_to_variant(&arrow_strings).unwrap(); + let path = VariantPath::try_from(field).unwrap(); + let arrow_result = parquet_variant_compute::variant_get( + &arrow_variant.clone().into(), + GetOptions::new_with_path(path), + ) + .unwrap(); + let arrow_result_variant = + ArrowVariantArray::try_new(arrow_result.as_any().downcast_ref::().unwrap()) + .unwrap(); + + // --- Vortex side --- + let vortex_input = ParquetVariantArray::from_arrow_variant(&arrow_variant).unwrap(); + let vortex_result = apply_variant_get(&vortex_input, field).unwrap(); + let vortex_as_arrow = vortex_to_arrow_variant(&vortex_result); + + // --- Compare row-by-row at Arrow Variant level --- + assert_eq!( + vortex_as_arrow.len(), + arrow_result_variant.len(), + "length mismatch" + ); + + for i in 0..arrow_result_variant.len() { + let arrow_is_null = arrow_result_variant.is_null(i); + let vortex_is_null = vortex_as_arrow.is_null(i); + + assert_eq!( + vortex_is_null, arrow_is_null, + "row {i}: null mismatch (vortex={vortex_is_null}, arrow={arrow_is_null})" + ); + + if !arrow_is_null { + let arrow_value = arrow_result_variant.value(i); + let vortex_value = vortex_as_arrow.value(i); + assert_eq!( + vortex_value, arrow_value, + "row {i}: value mismatch\n vortex: {vortex_value:?}\n arrow: {arrow_value:?}" + ); + } + } +} + +/// Run variant_get through both Arrow and Vortex on nullable input (with NullBuffer), +/// and assert the results match. +fn assert_matches_arrow_nullable(json_rows: &[&str], validity: &[bool], field: &str) { + // --- Arrow side --- + let arrow_strings: ArrowArrayRef = Arc::new(StringArray::from( + json_rows.iter().map(|s| Some(*s)).collect::>(), + )); + let base_variant = json_to_variant(&arrow_strings).unwrap(); + let inner = base_variant.into_inner(); + let null_struct = StructArray::try_new( + inner.fields().clone(), + inner.columns().to_vec(), + Some(NullBuffer::from(validity.to_vec())), + ) + .unwrap(); + let arrow_variant = ArrowVariantArray::try_new(&null_struct).unwrap(); + + let path = VariantPath::try_from(field).unwrap(); + let arrow_result = parquet_variant_compute::variant_get( + &ArrowArrayRef::from(arrow_variant.clone()), + GetOptions::new_with_path(path), + ) + .unwrap(); + let arrow_result_variant = + ArrowVariantArray::try_new(arrow_result.as_any().downcast_ref::().unwrap()) + .unwrap(); + + // --- Vortex side --- + let vortex_input = ParquetVariantArray::from_arrow_variant(&arrow_variant).unwrap(); + let vortex_result = apply_variant_get(&vortex_input, field).unwrap(); + let vortex_as_arrow = vortex_to_arrow_variant(&vortex_result); + + // --- Compare --- + assert_eq!( + vortex_as_arrow.len(), + arrow_result_variant.len(), + "length mismatch" + ); + + for i in 0..arrow_result_variant.len() { + let arrow_is_null = arrow_result_variant.is_null(i); + let vortex_is_null = vortex_as_arrow.is_null(i); + + assert_eq!( + vortex_is_null, arrow_is_null, + "row {i}: null mismatch (vortex={vortex_is_null}, arrow={arrow_is_null})" + ); + + if !arrow_is_null { + let arrow_value = arrow_result_variant.value(i); + let vortex_value = vortex_as_arrow.value(i); + assert_eq!( + vortex_value, arrow_value, + "row {i}: value mismatch\n vortex: {vortex_value:?}\n arrow: {arrow_value:?}" + ); + } + } +} + +// --------------------------------------------------------------------------- +// Tests that compare Vortex vs Arrow variant_get +// --------------------------------------------------------------------------- + +#[rstest] +#[case("some_field", &[r#"{"some_field": 1234}"#])] +#[case("a", &[r#"{"a": 1, "b": 2}"#, r#"{"a": "hello"}"#, r#"{"b": 99}"#])] +#[case("nested", &[r#"{"nested": {"x": 1, "y": 2}}"#])] +#[case("missing", &[r#"{"a": 1}"#, r#"{"b": 2}"#])] +#[case("x", &[r#"{"x": true}"#, r#"{"x": false}"#, r#"{"x": null}"#])] +#[case("arr", &[r#"{"arr": [1, 2, 3]}"#])] +#[case("s", &[r#"{"s": "hello world"}"#, r#"{"s": ""}"#])] +#[case("n", &[r#"{"n": 3.14}"#, r#"{"n": -0.0}"#])] +fn test_variant_get_matches_arrow(#[case] field: &str, #[case] json_rows: &[&str]) { + assert_matches_arrow(json_rows, field); +} + +#[test] +fn test_variant_get_matches_arrow_non_object() { + // Primitive variants (not objects) — accessing any field should give null + assert_matches_arrow(&["42", r#""hello""#, "true", "null"], "a"); +} + +#[test] +fn test_variant_get_matches_arrow_mixed_types() { + // Same field name, different value types across rows + assert_matches_arrow( + &[ + r#"{"v": 1}"#, + r#"{"v": "text"}"#, + r#"{"v": true}"#, + r#"{"v": [1,2]}"#, + r#"{"v": {"nested": 1}}"#, + ], + "v", + ); +} + +#[test] +fn test_variant_get_matches_arrow_nullable() { + assert_matches_arrow_nullable( + &[r#"{"a": 10}"#, r#"{"a": 20}"#, r#"{"a": 30}"#], + &[true, false, true], // row 1 is null + "a", + ); +} + +#[test] +fn test_variant_get_matches_arrow_all_null() { + assert_matches_arrow_nullable( + &[r#"{"a": 1}"#, r#"{"a": 2}"#, r#"{"a": 3}"#], + &[false, false, false], + "a", + ); +} + +#[test] +fn test_variant_get_matches_arrow_nested_object_result() { + // The result of variant_get is itself an object + assert_matches_arrow( + &[ + r#"{"outer": {"inner": 42}}"#, + r#"{"outer": {"a": 1, "b": 2}}"#, + ], + "outer", + ); +} + +// --------------------------------------------------------------------------- +// Original standalone tests +// --------------------------------------------------------------------------- + +#[test] +fn test_variant_get_basic() -> VortexResult<()> { + let arr = make_object_array()?; + let result = apply_variant_get(&arr, "a")?; + + assert_eq!(result.len(), 3); + + // Row 0: {"a": 1, ...} → variant(1) + let s0 = result.scalar_at(0)?; + assert!(!s0.is_null()); + let inner0 = s0.as_variant().value().unwrap(); + assert_eq!(*inner0, 1i32.into()); + + // Row 1: {"a": 2, ...} → variant(2) + let s1 = result.scalar_at(1)?; + assert!(!s1.is_null()); + let inner1 = s1.as_variant().value().unwrap(); + assert_eq!(*inner1, 2i32.into()); + + // Row 2: {"b": "y"} → null (field "a" missing) + let s2 = result.scalar_at(2)?; + assert!(s2.is_null()); + + Ok(()) +} + +#[test] +fn test_variant_get_missing_field() -> VortexResult<()> { + let arr = make_object_array()?; + let result = apply_variant_get(&arr, "nonexistent")?; + + assert_eq!(result.len(), 3); + for i in 0..3 { + assert!(result.scalar_at(i)?.is_null(), "row {i} should be null"); + } + + Ok(()) +} + +#[test] +fn test_variant_get_null_input() -> VortexResult<()> { + let arr = make_nullable_object_array()?; + let result = apply_variant_get(&arr, "a")?; + + assert_eq!(result.len(), 3); + assert!(!result.scalar_at(0)?.is_null()); + assert!(result.scalar_at(1)?.is_null()); + assert!(!result.scalar_at(2)?.is_null()); + + Ok(()) +} + +#[test] +fn test_variant_get_non_object() -> VortexResult<()> { + let mut builder = VariantArrayBuilder::new(2); + builder.append_variant(PqVariant::from(42i32)); + builder.append_variant(PqVariant::from("hello")); + let arr = ParquetVariantArray::from_arrow_variant(&builder.build())?; + + let result = apply_variant_get(&arr, "a")?; + + assert_eq!(result.len(), 2); + assert!(result.scalar_at(0)?.is_null()); + assert!(result.scalar_at(1)?.is_null()); + + Ok(()) +} + +#[test] +fn test_variant_get_different_field() -> VortexResult<()> { + let arr = make_object_array()?; + let result = apply_variant_get(&arr, "b")?; + + assert_eq!(result.len(), 3); + assert!(!result.scalar_at(0)?.is_null()); + assert!(result.scalar_at(1)?.is_null()); + assert!(!result.scalar_at(2)?.is_null()); + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Test data helpers +// --------------------------------------------------------------------------- + +fn make_object_array() -> VortexResult { + let mut builder = VariantArrayBuilder::new(3); + + builder + .new_object() + .with_field("a", 1i32) + .with_field("b", "x") + .finish(); + + builder + .new_object() + .with_field("a", 2i32) + .with_field("c", true) + .finish(); + + builder.new_object().with_field("b", "y").finish(); + + ParquetVariantArray::from_arrow_variant(&builder.build()) +} + +fn make_nullable_object_array() -> VortexResult { + let mut builder = VariantArrayBuilder::new(3); + + builder.new_object().with_field("a", 10i32).finish(); + builder.new_object().with_field("a", 20i32).finish(); + builder.new_object().with_field("a", 30i32).finish(); + + let inner = builder.build().into_inner(); + let null_struct = StructArray::try_new( + inner.fields().clone(), + inner.columns().to_vec(), + Some(NullBuffer::from(vec![true, false, true])), + ) + .unwrap(); + let arrow_variant = ArrowVariantArray::try_new(&null_struct).unwrap(); + ParquetVariantArray::from_arrow_variant(&arrow_variant) +} diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index c193f10da92..0f5271e16df 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -14480,6 +14480,8 @@ pub fn vortex_array::expr::select_exclude(fields: impl core::convert::Into alloc::vec::Vec +pub fn vortex_array::expr::variant_get(field: impl core::convert::Into, child: vortex_array::expr::Expression) -> vortex_array::expr::Expression + pub fn vortex_array::expr::zip_expr(mask: vortex_array::expr::Expression, if_true: vortex_array::expr::Expression, if_false: vortex_array::expr::Expression) -> vortex_array::expr::Expression pub type vortex_array::expr::Annotations<'a, A> = vortex_utils::aliases::hash_map::HashMap<&'a vortex_array::expr::Expression, vortex_utils::aliases::hash_set::HashSet> @@ -19098,6 +19100,52 @@ pub fn vortex_array::scalar_fn::fns::select::Select::stat_falsification(&self, o pub fn vortex_array::scalar_fn::fns::select::Select::validity(&self, options: &Self::Options, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult> +pub mod vortex_array::scalar_fn::fns::variant_get + +pub struct vortex_array::scalar_fn::fns::variant_get::VariantGet + +impl core::clone::Clone for vortex_array::scalar_fn::fns::variant_get::VariantGet + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::clone(&self) -> vortex_array::scalar_fn::fns::variant_get::VariantGet + +impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::variant_get::VariantGet + +pub type vortex_array::scalar_fn::fns::variant_get::VariantGet::Options = vortex_array::dtype::FieldName + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::arity(&self, _field_name: &vortex_array::dtype::FieldName) -> vortex_array::scalar_fn::Arity + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::deserialize(&self, metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::execute(&self, _field_name: &vortex_array::dtype::FieldName, _args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::fmt_sql(&self, field_name: &vortex_array::dtype::FieldName, expr: &vortex_array::expr::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::is_fallible(&self, _field_name: &vortex_array::dtype::FieldName) -> bool + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::is_null_sensitive(&self, _field_name: &vortex_array::dtype::FieldName) -> bool + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::reduce(&self, field_name: &vortex_array::dtype::FieldName, node: &dyn vortex_array::scalar_fn::ReduceNode, ctx: &dyn vortex_array::scalar_fn::ReduceCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::return_dtype(&self, _field_name: &vortex_array::dtype::FieldName, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::serialize(&self, instance: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::simplify(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, ctx: &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::simplify_untyped(&self, options: &Self::Options, expr: &vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::stat_expression(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, stat: vortex_array::expr::stats::Stat, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::stat_falsification(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::validity(&self, options: &Self::Options, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult> + pub mod vortex_array::scalar_fn::fns::zip pub struct vortex_array::scalar_fn::fns::zip::Zip @@ -20182,6 +20230,44 @@ pub fn vortex_array::scalar_fn::fns::select::Select::stat_falsification(&self, o pub fn vortex_array::scalar_fn::fns::select::Select::validity(&self, options: &Self::Options, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult> +impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::variant_get::VariantGet + +pub type vortex_array::scalar_fn::fns::variant_get::VariantGet::Options = vortex_array::dtype::FieldName + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::arity(&self, _field_name: &vortex_array::dtype::FieldName) -> vortex_array::scalar_fn::Arity + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::deserialize(&self, metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::execute(&self, _field_name: &vortex_array::dtype::FieldName, _args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::fmt_sql(&self, field_name: &vortex_array::dtype::FieldName, expr: &vortex_array::expr::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::is_fallible(&self, _field_name: &vortex_array::dtype::FieldName) -> bool + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::is_null_sensitive(&self, _field_name: &vortex_array::dtype::FieldName) -> bool + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::reduce(&self, field_name: &vortex_array::dtype::FieldName, node: &dyn vortex_array::scalar_fn::ReduceNode, ctx: &dyn vortex_array::scalar_fn::ReduceCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::return_dtype(&self, _field_name: &vortex_array::dtype::FieldName, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::serialize(&self, instance: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::simplify(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, ctx: &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::simplify_untyped(&self, options: &Self::Options, expr: &vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::stat_expression(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, stat: vortex_array::expr::stats::Stat, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::stat_falsification(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::variant_get::VariantGet::validity(&self, options: &Self::Options, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult> + impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::zip::Zip pub type vortex_array::scalar_fn::fns::zip::Zip::Options = vortex_array::scalar_fn::EmptyOptions diff --git a/vortex-array/src/expr/exprs.rs b/vortex-array/src/expr/exprs.rs index bc30ba86ec4..4de384a01a3 100644 --- a/vortex-array/src/expr/exprs.rs +++ b/vortex-array/src/expr/exprs.rs @@ -45,6 +45,7 @@ use crate::scalar_fn::fns::pack::PackOptions; use crate::scalar_fn::fns::root::Root; use crate::scalar_fn::fns::select::FieldSelection; use crate::scalar_fn::fns::select::Select; +use crate::scalar_fn::fns::variant_get::VariantGet; use crate::scalar_fn::fns::zip::Zip; // ---- Root ---- @@ -676,3 +677,17 @@ pub fn dynamic( pub fn list_contains(list: Expression, value: Expression) -> Expression { ListContains.new_expr(EmptyOptions, [list, value]) } + +// ---- VariantGet ---- + +/// Creates an expression that extracts a field from a variant object by name. +/// +/// Returns a new variant containing the field's value, or null if the field does not exist. +/// +/// ```rust +/// # use vortex_array::expr::{variant_get, root}; +/// let expr = variant_get("field_name", root()); +/// ``` +pub fn variant_get(field: impl Into, child: Expression) -> Expression { + VariantGet.new_expr(field.into(), vec![child]) +} diff --git a/vortex-array/src/scalar_fn/fns/mod.rs b/vortex-array/src/scalar_fn/fns/mod.rs index 94fc8fb0384..b0c7e7547e5 100644 --- a/vortex-array/src/scalar_fn/fns/mod.rs +++ b/vortex-array/src/scalar_fn/fns/mod.rs @@ -19,4 +19,5 @@ pub mod operators; pub mod pack; pub mod root; pub mod select; +pub mod variant_get; pub mod zip; diff --git a/vortex-array/src/scalar_fn/fns/variant_get.rs b/vortex-array/src/scalar_fn/fns/variant_get.rs new file mode 100644 index 00000000000..ee4dc50dca5 --- /dev/null +++ b/vortex-array/src/scalar_fn/fns/variant_get.rs @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt::Formatter; +use std::sync::Arc; + +use prost::Message; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_proto::expr as pb; +use vortex_session::VortexSession; + +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::arrays::Variant; +use crate::dtype::DType; +use crate::dtype::FieldName; +use crate::dtype::Nullability; +use crate::expr::Expression; +use crate::scalar_fn::Arity; +use crate::scalar_fn::ChildName; +use crate::scalar_fn::ExecutionArgs; +use crate::scalar_fn::ReduceCtx; +use crate::scalar_fn::ReduceNode; +use crate::scalar_fn::ReduceNodeRef; +use crate::scalar_fn::ScalarFnId; +use crate::scalar_fn::ScalarFnVTable; +use crate::scalar_fn::ScalarFnVTableExt; + +/// Extracts a field from a variant object by name, returning a new variant. +/// +/// This is analogous to [`GetItem`](super::get_item::GetItem) for structs, but operates on +/// semi-structured variant data. The result is always `DType::Variant(Nullable)` since the +/// requested field may not exist in every row. +/// +/// Execution is handled by variant encodings (e.g. `ParquetVariantArray`) via `execute_parent`. +/// The canonical `VariantArray` does not support direct execution; a `reduce` rule unwraps +/// the `VariantArray` wrapper to expose the underlying encoding. +#[derive(Clone)] +pub struct VariantGet; + +impl ScalarFnVTable for VariantGet { + type Options = FieldName; + + fn id(&self) -> ScalarFnId { + ScalarFnId::from("vortex.variant_get") + } + + fn serialize(&self, instance: &Self::Options) -> VortexResult>> { + Ok(Some( + pb::VariantGetOpts { + path: instance.to_string(), + } + .encode_to_vec(), + )) + } + + fn deserialize( + &self, + metadata: &[u8], + _session: &VortexSession, + ) -> VortexResult { + let opts = pb::VariantGetOpts::decode(metadata)?; + Ok(FieldName::from(opts.path)) + } + + fn arity(&self, _field_name: &FieldName) -> Arity { + Arity::Exact(1) + } + + fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("input"), + _ => unreachable!( + "Invalid child index {} for VariantGet expression", + child_idx + ), + } + } + + fn fmt_sql( + &self, + field_name: &FieldName, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "variant_get(")?; + expr.children()[0].fmt_sql(f)?; + write!(f, ", '{}')", field_name) + } + + fn return_dtype(&self, _field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult { + if !matches!(arg_dtypes[0], DType::Variant(_)) { + vortex_bail!( + "variant_get requires a Variant input, got {:?}", + arg_dtypes[0] + ); + } + // Always nullable: the field may not exist in every variant value. + Ok(DType::Variant(Nullability::Nullable)) + } + + fn execute( + &self, + _field_name: &FieldName, + _args: &dyn ExecutionArgs, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + vortex_bail!( + "variant_get cannot be executed directly; \ + it must be pushed down to a variant encoding via execute_parent" + ) + } + + fn reduce( + &self, + field_name: &FieldName, + node: &dyn ReduceNode, + ctx: &dyn ReduceCtx, + ) -> VortexResult> { + // If the child is a canonical VariantArray wrapper, unwrap it to expose the + // underlying encoding (e.g. ParquetVariantArray) so that execute_parent can + // handle the operation. + let child = node.child(0); + if let Some(child_array) = child.as_any().downcast_ref::() + && child_array.is::() + { + let inner = child_array.as_::().child().clone(); + return Ok(Some(ctx.new_node( + VariantGet.bind(field_name.clone()), + &[Arc::new(inner) as ReduceNodeRef], + )?)); + } + Ok(None) + } + + fn is_null_sensitive(&self, _field_name: &FieldName) -> bool { + true + } + + fn is_fallible(&self, _field_name: &FieldName) -> bool { + false + } +} diff --git a/vortex-array/src/scalar_fn/session.rs b/vortex-array/src/scalar_fn/session.rs index eef759bf8e3..6553cbd29e3 100644 --- a/vortex-array/src/scalar_fn/session.rs +++ b/vortex-array/src/scalar_fn/session.rs @@ -23,6 +23,7 @@ use crate::scalar_fn::fns::not::Not; use crate::scalar_fn::fns::pack::Pack; use crate::scalar_fn::fns::root::Root; use crate::scalar_fn::fns::select::Select; +use crate::scalar_fn::fns::variant_get::VariantGet; /// Registry of scalar function vtables. /// Registry of scalar function vtables. @@ -67,6 +68,7 @@ impl Default for ScalarFnSession { this.register(Pack); this.register(Root); this.register(Select); + this.register(VariantGet); this } diff --git a/vortex-proto/proto/expr.proto b/vortex-proto/proto/expr.proto index 73ba7209a15..3c062dffb79 100644 --- a/vortex-proto/proto/expr.proto +++ b/vortex-proto/proto/expr.proto @@ -87,6 +87,11 @@ message SelectOpts { } } +// Options for `vortex.variant_get` +message VariantGetOpts { + string path = 1; +} + // Options for `vortex.case_when` // Encodes num_when_then_pairs and has_else into a single u32 (num_children). // num_children = num_when_then_pairs * 2 + (has_else ? 1 : 0) diff --git a/vortex-proto/public-api.lock b/vortex-proto/public-api.lock index 045b53d17eb..d7ea01cb7bb 100644 --- a/vortex-proto/public-api.lock +++ b/vortex-proto/public-api.lock @@ -1192,6 +1192,40 @@ pub fn vortex_proto::expr::SelectOpts::clear(&mut self) pub fn vortex_proto::expr::SelectOpts::encoded_len(&self) -> usize +pub struct vortex_proto::expr::VariantGetOpts + +pub vortex_proto::expr::VariantGetOpts::path: alloc::string::String + +impl core::clone::Clone for vortex_proto::expr::VariantGetOpts + +pub fn vortex_proto::expr::VariantGetOpts::clone(&self) -> vortex_proto::expr::VariantGetOpts + +impl core::cmp::Eq for vortex_proto::expr::VariantGetOpts + +impl core::cmp::PartialEq for vortex_proto::expr::VariantGetOpts + +pub fn vortex_proto::expr::VariantGetOpts::eq(&self, other: &vortex_proto::expr::VariantGetOpts) -> bool + +impl core::default::Default for vortex_proto::expr::VariantGetOpts + +pub fn vortex_proto::expr::VariantGetOpts::default() -> Self + +impl core::fmt::Debug for vortex_proto::expr::VariantGetOpts + +pub fn vortex_proto::expr::VariantGetOpts::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_proto::expr::VariantGetOpts + +pub fn vortex_proto::expr::VariantGetOpts::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl core::marker::StructuralPartialEq for vortex_proto::expr::VariantGetOpts + +impl prost::message::Message for vortex_proto::expr::VariantGetOpts + +pub fn vortex_proto::expr::VariantGetOpts::clear(&mut self) + +pub fn vortex_proto::expr::VariantGetOpts::encoded_len(&self) -> usize + pub mod vortex_proto::scalar pub mod vortex_proto::scalar::scalar_value diff --git a/vortex-proto/src/generated/vortex.expr.rs b/vortex-proto/src/generated/vortex.expr.rs index 9c7ddb1d90c..607dc848d8a 100644 --- a/vortex-proto/src/generated/vortex.expr.rs +++ b/vortex-proto/src/generated/vortex.expr.rs @@ -153,6 +153,12 @@ pub mod select_opts { Exclude(super::FieldNames), } } +/// Options for `vortex.variant_get` +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct VariantGetOpts { + #[prost(string, tag = "1")] + pub path: ::prost::alloc::string::String, +} /// Options for `vortex.case_when` /// Encodes num_when_then_pairs and has_else into a single u32 (num_children). /// num_children = num_when_then_pairs * 2 + (has_else ? 1 : 0)