Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 231 additions & 5 deletions src/ser.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use half::{bf16, f16};
use serde::{
ser::{
Error as _, Impossible, SerializeMap, SerializeSeq, SerializeStruct,
Expand All @@ -7,7 +8,44 @@ use serde::{
};
use serde_json::error::Error;

use crate::{array::TryCollect, DestructuredRef, IArray, INumber, IObject, IString, IValue};
use crate::{
array::{ArraySliceRef, TryCollect},
DestructuredRef, IArray, INumber, IObject, IString, IValue,
};

/// Finds an f64 value that, when formatted by ryu's f64 algorithm, produces
/// a short decimal string that still round-trips through the target
/// half-precision type (f16 or bf16).
///
/// ryu only supports f32/f64, and serde has no `serialize_f16`. Since f16/bf16
/// have far fewer distinct values than f32, there exist shorter representations
/// that uniquely identify the half value. For example, f16(0.3) = 0.300048828125,
/// and "0.3" parsed as f16 gives back the same bits — so "0.3" is valid.
///
/// The approach: round `f64_val` to an increasing number of significant digits
/// (scientific notation) until the rounded value satisfies `roundtrips`, then
/// return the f64 value of that string so that `serialize_f64` (via ryu)
/// reproduces it.
///
/// # Parameters
///
/// * `f64_val` — the stored half-precision value, widened to f64.
/// * `roundtrips` — predicate that must return `true` when the candidate f64,
///   narrowed back to the half type, equals the original half value.
///
/// # Returns
///
/// The first candidate (fewest significant digits) accepted by `roundtrips`.
/// `f64_val` is returned unchanged, without calling `roundtrips`, when it is
/// non-finite or has no fractional part; it is also returned unchanged when no
/// candidate of up to five significant digits is accepted.
fn find_shortest_roundtrip_f64(f64_val: f64, roundtrips: impl Fn(f64) -> bool) -> f64 {
    use std::fmt::Write as _;

    // Upper bound on the significant digits tried. f16 carries an 11-bit
    // significand, which can require 5 decimal digits (e.g. 1000.5: every
    // 4-digit rounding lands on 1000 or 1001, both outside its rounding
    // interval); bf16 carries 8 bits and needs at most 4.
    const MAX_SIG_DIGITS: usize = 5;

    if !f64_val.is_finite() || f64_val.fract() == 0.0 {
        return f64_val;
    }

    // Example: f16(3.14159) stores 3.140625
    //   1 digit  → "3e0"    → 3.0  → f16(3.0)  = 3.0      ≠ 3.140625
    //   2 digits → "3.1e0"  → 3.1  → f16(3.1)  = 3.0996.. ≠ 3.140625
    //   3 digits → "3.14e0" → 3.14 → f16(3.14) = 3.140625 → returns 3.14
    //
    // This runs once per array element, so one buffer is reused for every
    // attempt instead of allocating a new `String` each time.
    let mut candidate = String::with_capacity(16);
    for sig_digits in 1..=MAX_SIG_DIGITS {
        candidate.clear();
        // `{:.N$e}` prints N digits after the leading one, i.e. N + 1
        // significant digits.
        if write!(candidate, "{:.prec$e}", f64_val, prec = sig_digits - 1).is_err() {
            continue;
        }
        match candidate.parse::<f64>() {
            Ok(parsed) if roundtrips(parsed) => return parsed,
            _ => {}
        }
    }

    // Nothing short enough round-trips: fall back to the exact widened value.
    f64_val
}

impl Serialize for IValue {
#[inline]
Expand Down Expand Up @@ -55,11 +93,50 @@ impl Serialize for IArray {
where
S: Serializer,
{
let mut s = serializer.serialize_seq(Some(self.len()))?;
for v in self {
s.serialize_element(&v)?;
match self.as_slice() {
// Serialize typed float arrays with the shortest representation that
// round-trips through the stored precision. Without this, all floats
// would be promoted to f64 via INumber, and ryu's f64 algorithm would
// emit unnecessarily long strings (e.g. "0.3" stored as f32 would
// serialize as "0.30000001192092896" instead of "0.3").
//
// F32: serialize directly as f32 so ryu uses its f32 algorithm.
// F16/BF16: ryu has no f16 mode and serde has no serialize_f16, so we
// find the shortest decimal that round-trips through the half type and
// pass the corresponding f64 value to serialize_f64.
ArraySliceRef::F32(slice) => {
let mut s = serializer.serialize_seq(Some(slice.len()))?;
for &v in slice {
s.serialize_element(&v)?;
}
s.end()
}
ArraySliceRef::F16(slice) => {
let mut s = serializer.serialize_seq(Some(slice.len()))?;
for &v in slice {
let f64_val = f64::from(v);
let shortest = find_shortest_roundtrip_f64(f64_val, |p| f16::from_f64(p) == v);
s.serialize_element(&shortest)?;
}
s.end()
}
ArraySliceRef::BF16(slice) => {
let mut s = serializer.serialize_seq(Some(slice.len()))?;
for &v in slice {
let f64_val = f64::from(v);
let shortest = find_shortest_roundtrip_f64(f64_val, |p| bf16::from_f64(p) == v);
s.serialize_element(&shortest)?;
}
s.end()
}
_ => {
let mut s = serializer.serialize_seq(Some(self.len()))?;
for v in self {
s.serialize_element(&v)?;
}
s.end()
}
}
s.end()
}
}

Expand Down Expand Up @@ -635,3 +712,152 @@ where
{
value.serialize(ValueSerializer)
}

#[cfg(test)]
mod tests {
    use crate::array::{ArraySliceRef, FloatType};
    use crate::{FPHAConfig, IArray, IValue, IValueDeserSeed};

    /// Builds an array holding the single float `value`, pushed with the
    /// requested floating-point type.
    fn single_float_array(value: f64, fp_type: FloatType) -> IArray {
        let mut arr = IArray::new();
        arr.push_with_fp_type(IValue::from(value), fp_type)
            .unwrap();
        arr
    }

    /// Deserializes `json` through an `IValueDeserSeed` configured with an
    /// `FPHAConfig` for `fp_type`, converts the result into an array and
    /// serializes it back to a JSON string.
    fn reserialize_as(json: &str, fp_type: FloatType) -> String {
        let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(fp_type)));
        let mut de = serde_json::Deserializer::from_str(json);
        let arr: IArray = serde::de::DeserializeSeed::deserialize(seed, &mut de)
            .unwrap()
            .into_array()
            .unwrap();
        serde_json::to_string(&arr).unwrap()
    }

    /// An F32-typed array must print 0.3 with f32 shortest-digits output.
    #[test]
    fn test_f32_array_serialization_preserves_short_representation() {
        let arr = single_float_array(0.3, FloatType::F32);
        assert!(matches!(arr.as_slice(), ArraySliceRef::F32(_)));

        let json = serde_json::to_string(&arr).unwrap();
        assert_eq!(
            json, "[0.3]",
            "F32 array should serialize 0.3 as '0.3', not with extra f64 precision digits"
        );
    }

    /// An F64-typed array prints 0.3 unchanged.
    #[test]
    fn test_f64_array_serialization_preserves_short_representation() {
        let arr = single_float_array(0.3, FloatType::F64);
        assert!(matches!(arr.as_slice(), ArraySliceRef::F64(_)));

        let json = serde_json::to_string(&arr).unwrap();
        assert_eq!(json, "[0.3]");
    }

    /// F16: an exactly representable value (1.5) and an inexact one (0.3)
    /// both serialize to their short decimal form.
    #[test]
    fn test_f16_array_serialization_preserves_short_representation() {
        let arr = single_float_array(1.5, FloatType::F16);
        assert_eq!(serde_json::to_string(&arr).unwrap(), "[1.5]");

        let arr2 = single_float_array(0.3, FloatType::F16);
        assert_eq!(
            serde_json::to_string(&arr2).unwrap(),
            "[0.3]",
            "F16 array should serialize 0.3 as '0.3', not '0.30004883' or '0.300048828125'"
        );
    }

    /// BF16: same expectations as the F16 case.
    #[test]
    fn test_bf16_array_serialization_preserves_short_representation() {
        let arr = single_float_array(1.5, FloatType::BF16);
        assert_eq!(serde_json::to_string(&arr).unwrap(), "[1.5]");

        let arr2 = single_float_array(0.3, FloatType::BF16);
        assert_eq!(
            serde_json::to_string(&arr2).unwrap(),
            "[0.3]",
            "BF16 array should serialize 0.3 as '0.3'"
        );
    }

    /// Parsing and re-serializing the same JSON must give back the input
    /// string for every float type, and therefore the same string across types.
    #[test]
    fn test_typed_float_array_serialization_roundtrip() {
        let input = "[0.3,0.1,0.7,1.0,2.5,100.0]";
        let fp_types = [
            FloatType::F16,
            FloatType::BF16,
            FloatType::F32,
            FloatType::F64,
        ];

        let jsons: Vec<String> = fp_types
            .iter()
            .map(|&fp_type| {
                let json_out = reserialize_as(input, fp_type);
                assert_eq!(
                    json_out, input,
                    "{fp_type} round-trip should preserve the original JSON string"
                );
                json_out
            })
            .collect();

        for pair in jsons.windows(2) {
            assert_eq!(
                pair[0], pair[1],
                "all float types should produce identical JSON"
            );
        }
    }

    #[test]
    fn test_f16_precision_loss_produces_different_but_short_representation() {
        // Values with more significant digits than f16 can represent (~3.3 digits).
        // The stored f16 value differs from the original, so the serialized string
        // must differ too — but it should still be the shortest string that
        // round-trips through f16.
        let cases: &[(&str, &str)] = &[
            ("3.14159", "3.14"), // pi truncated: f16 stores 3.140625
            ("42.42", "42.4"),   // f16 stores 42.40625
            ("12.345", "12.34"), // f16 stores 12.34375
            ("0.5678", "0.568"), // f16 stores 0.56787109375
        ];

        for &(input, expected_f16) in cases {
            let json_input = format!("[{input}]");

            let f16_json = reserialize_as(&json_input, FloatType::F16);
            assert_eq!(
                f16_json,
                format!("[{expected_f16}]"),
                "F16 of {input}: should serialize as shortest f16 representation"
            );

            // Same values through F32 should preserve the original (enough precision)
            let f32_json = reserialize_as(&json_input, FloatType::F32);
            assert_eq!(
                f32_json, json_input,
                "F32 of {input}: should preserve the original representation"
            );
        }
    }
}
Loading