diff --git a/src/ser.rs b/src/ser.rs index 673a41e..0012235 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -1,3 +1,4 @@ +use half::{bf16, f16}; use serde::{ ser::{ Error as _, Impossible, SerializeMap, SerializeSeq, SerializeStruct, @@ -7,7 +8,44 @@ use serde::{ }; use serde_json::error::Error; -use crate::{array::TryCollect, DestructuredRef, IArray, INumber, IObject, IString, IValue}; +use crate::{ + array::{ArraySliceRef, TryCollect}, + DestructuredRef, IArray, INumber, IObject, IString, IValue, +}; + +/// Finds an f64 value that, when formatted by ryu's f64 algorithm, produces +/// the shortest decimal string that still round-trips through the target +/// half-precision type (f16 or bf16). +/// +/// ryu only supports f32/f64, and serde has no `serialize_f16`. Since f16/bf16 +/// have far fewer distinct values than f32, there exist shorter representations +/// that uniquely identify the half value. For example, f16(0.3) = 0.300048828125, +/// and "0.3" parsed as f16 gives back the same bits — so "0.3" is valid. +/// +/// The approach: try increasing significant digits (in scientific notation) until +/// the formatted string round-trips through the half type. Then return the f64 +/// value of that string, so that `serialize_f64` (via ryu) reproduces it. +fn find_shortest_roundtrip_f64(f64_val: f64, roundtrips: impl Fn(f64) -> bool) -> f64 { + if !f64_val.is_finite() || f64_val.fract() == 0.0 { + return f64_val; + } + // What we do here: + // Looking for the shortest string where f16::from_f64(str.parse::()) == original_f16 + // With our usage(F16/BF16), the loop will need only ~4 iterations, since max significant digits needed is ~4 + // Example: f16(3.14159) stores 3.140625 + // sig_digits=1 → "3e0" → 3.0 → f16(3.0)=3.0 ≠ 3.140625 ❌ + // sig_digits=2 → "3.1e0"→ 3.1 → f16(3.1)=3.099.. ≠ 3.140625 ❌ + // sig_digits=3 → "3.14e0"→3.14 → f16(3.14)=3.140625 ✅ → returns 3.14 + for sig_digits in 1..=5 { + let s = format!("{:.prec$e}", f64_val, prec = sig_digits - 1); + if let Ok(parsed) = s.parse::() { + if roundtrips(parsed) { + return parsed; + } + } + } + f64_val +} impl Serialize for IValue { #[inline] @@ -55,11 +93,50 @@ impl Serialize for IArray { where S: Serializer, { - let mut s = serializer.serialize_seq(Some(self.len()))?; - for v in self { - s.serialize_element(&v)?; + match self.as_slice() { + // Serialize typed float arrays with the shortest representation that + // round-trips through the stored precision. Without this, all floats + // would be promoted to f64 via INumber, and ryu's f64 algorithm would + // emit unnecessarily long strings (e.g. "0.3" stored as f32 would + // serialize as "0.30000001192092896" instead of "0.3"). + // + // F32: serialize directly as f32 so ryu uses its f32 algorithm. + // F16/BF16: ryu has no f16 mode and serde has no serialize_f16, so we + // find the shortest decimal that round-trips through the half type and + // pass the corresponding f64 value to serialize_f64. + ArraySliceRef::F32(slice) => { + let mut s = serializer.serialize_seq(Some(slice.len()))?; + for &v in slice { + s.serialize_element(&v)?; + } + s.end() + } + ArraySliceRef::F16(slice) => { + let mut s = serializer.serialize_seq(Some(slice.len()))?; + for &v in slice { + let f64_val = f64::from(v); + let shortest = find_shortest_roundtrip_f64(f64_val, |p| f16::from_f64(p) == v); + s.serialize_element(&shortest)?; + } + s.end() + } + ArraySliceRef::BF16(slice) => { + let mut s = serializer.serialize_seq(Some(slice.len()))?; + for &v in slice { + let f64_val = f64::from(v); + let shortest = find_shortest_roundtrip_f64(f64_val, |p| bf16::from_f64(p) == v); + s.serialize_element(&shortest)?; + } + s.end() + } + _ => { + let mut s = serializer.serialize_seq(Some(self.len()))?; + for v in self { + s.serialize_element(&v)?; + } + s.end() + } } - s.end() } } @@ -635,3 +712,152 @@ where { value.serialize(ValueSerializer) } + +#[cfg(test)] +mod tests { + use crate::array::{ArraySliceRef, FloatType}; + use crate::{FPHAConfig, IArray, IValue, IValueDeserSeed}; + + #[test] + fn test_f32_array_serialization_preserves_short_representation() { + let mut arr = IArray::new(); + arr.push_with_fp_type(IValue::from(0.3), FloatType::F32) + .unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F32(_))); + + let json = serde_json::to_string(&arr).unwrap(); + assert_eq!( + json, "[0.3]", + "F32 array should serialize 0.3 as '0.3', not with extra f64 precision digits" + ); + } + + #[test] + fn test_f64_array_serialization_preserves_short_representation() { + let mut arr = IArray::new(); + arr.push_with_fp_type(IValue::from(0.3), FloatType::F64) + .unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F64(_))); + + let json = serde_json::to_string(&arr).unwrap(); + assert_eq!(json, "[0.3]"); + } + + #[test] + fn test_f16_array_serialization_preserves_short_representation() { + let mut arr = IArray::new(); + arr.push_with_fp_type(IValue::from(1.5), FloatType::F16) + .unwrap(); + assert_eq!(serde_json::to_string(&arr).unwrap(), "[1.5]"); + + let mut arr2 = IArray::new(); + arr2.push_with_fp_type(IValue::from(0.3), FloatType::F16) + .unwrap(); + assert_eq!( + serde_json::to_string(&arr2).unwrap(), + "[0.3]", + "F16 array should serialize 0.3 as '0.3', not '0.30004883' or '0.300048828125'" + ); + } + + #[test] + fn test_bf16_array_serialization_preserves_short_representation() { + let mut arr = IArray::new(); + arr.push_with_fp_type(IValue::from(1.5), FloatType::BF16) + .unwrap(); + assert_eq!(serde_json::to_string(&arr).unwrap(), "[1.5]"); + + let mut arr2 = IArray::new(); + arr2.push_with_fp_type(IValue::from(0.3), FloatType::BF16) + .unwrap(); + assert_eq!( + serde_json::to_string(&arr2).unwrap(), + "[0.3]", + "BF16 array should serialize 0.3 as '0.3'" + ); + } + + #[test] + fn test_typed_float_array_serialization_roundtrip() { + let input = "[0.3,0.1,0.7,1.0,2.5,100.0]"; + let fp_types = [ + FloatType::F16, + FloatType::BF16, + FloatType::F32, + FloatType::F64, + ]; + + let jsons: Vec = fp_types + .iter() + .map(|&fp_type| { + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(fp_type))); + let mut de = serde_json::Deserializer::from_str(input); + let arr = serde::de::DeserializeSeed::deserialize(seed, &mut de) + .unwrap() + .into_array() + .unwrap(); + let json_out = serde_json::to_string(&arr).unwrap(); + assert_eq!( + json_out, input, + "{fp_type} round-trip should preserve the original JSON string" + ); + json_out + }) + .collect(); + + for pair in jsons.windows(2) { + assert_eq!( + pair[0], pair[1], + "all float types should produce identical JSON" + ); + } + } + + #[test] + fn test_f16_precision_loss_produces_different_but_short_representation() { + // Values with more significant digits than f16 can represent (~3.3 digits). + // The stored f16 value differs from the original, so the serialized string + // must differ too — but it should still be the shortest string that + // round-trips through f16. + let cases: &[(&str, &str)] = &[ + ("3.14159", "3.14"), // pi truncated: f16 stores 3.140625 + ("42.42", "42.4"), // f16 stores 42.40625 + ("12.345", "12.34"), // f16 stores 12.34375 + ("0.5678", "0.568"), // f16 stores 0.56787109375 + ]; + + for &(input, expected_f16) in cases { + let json_input = format!("[{input}]"); + + let f16_arr: IArray = { + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F16))); + let mut de = serde_json::Deserializer::from_str(&json_input); + serde::de::DeserializeSeed::deserialize(seed, &mut de) + .unwrap() + .into_array() + .unwrap() + }; + let f16_json = serde_json::to_string(&f16_arr).unwrap(); + assert_eq!( + f16_json, + format!("[{expected_f16}]"), + "F16 of {input}: should serialize as shortest f16 representation" + ); + + // Same values through F32 should preserve the original (enough precision) + let f32_arr: IArray = { + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); + let mut de = serde_json::Deserializer::from_str(&json_input); + serde::de::DeserializeSeed::deserialize(seed, &mut de) + .unwrap() + .into_array() + .unwrap() + }; + let f32_json = serde_json::to_string(&f32_arr).unwrap(); + assert_eq!( + f32_json, json_input, + "F32 of {input}: should preserve the original representation" + ); + } + } +}