diff --git a/go/fory/fory.go b/go/fory/fory.go
index 09a0e3c6d2..8c97ee2e58 100644
--- a/go/fory/fory.go
+++ b/go/fory/fory.go
@@ -50,18 +50,22 @@ const (
 
 // Config holds configuration options for Fory instances
 type Config struct {
-	TrackRef   bool
-	MaxDepth   int
-	IsXlang    bool
-	Compatible bool // Schema evolution compatibility mode
+	TrackRef          bool
+	MaxDepth          int
+	IsXlang           bool
+	Compatible        bool // Schema evolution compatibility mode
+	MaxBinarySize     int  // Maximum byte length for a single deserialized binary payload (0 = no limit)
+	MaxCollectionSize int  // Maximum element count for a single deserialized collection or map (0 = no limit)
 }
 
 // defaultConfig returns the default configuration
 func defaultConfig() Config {
 	return Config{
-		TrackRef: false, // Match Java's default: reference tracking disabled
-		MaxDepth: 20,
-		IsXlang:  false,
+		TrackRef:          false, // Match Java's default: reference tracking disabled
+		MaxDepth:          20,
+		IsXlang:           false,
+		MaxBinarySize:     64 * 1024 * 1024, // 64 MB
+		MaxCollectionSize: 1_000_000,
 	}
 }
 
@@ -101,6 +105,22 @@ func WithCompatible(enabled bool) Option {
 	}
 }
 
+// WithMaxBinarySize sets the maximum byte length for a single deserialized binary payload.
+// A non-positive value disables the limit.
+func WithMaxBinarySize(n int) Option {
+	return func(f *Fory) {
+		f.config.MaxBinarySize = n
+	}
+}
+
+// WithMaxCollectionSize sets the maximum element count for a single deserialized collection or map.
+// A non-positive value disables the limit.
+func WithMaxCollectionSize(n int) Option {
+	return func(f *Fory) {
+		f.config.MaxCollectionSize = n
+	}
+}
+
 // ============================================================================
 // Fory - Main serialization instance
 // ============================================================================
@@ -156,6 +176,8 @@ func New(opts ...Option) *Fory {
 	f.readCtx.refResolver = f.refResolver
 	f.readCtx.compatible = f.config.Compatible
 	f.readCtx.xlang = f.config.IsXlang
+	f.readCtx.maxCollectionSize = f.config.MaxCollectionSize
+	f.readCtx.maxBinarySize = f.config.MaxBinarySize
 
 	return f
 }
diff --git a/go/fory/limits_test.go b/go/fory/limits_test.go
new file mode 100644
index 0000000000..1a705cce5d
--- /dev/null
+++ b/go/fory/limits_test.go
@@ -0,0 +1,139 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package fory
+
+import (
+	"strings"
+	"testing"
+)
+
+// TestSizeGuardrails_SliceExceedsLimit verifies that deserializing a slice
+// whose element count exceeds MaxCollectionSize returns an error.
+func TestSizeGuardrails_SliceExceedsLimit(t *testing.T) {
+	// Serialize a string slice with 5 elements using no limit
+	f1 := New(WithXlang(true), WithMaxCollectionSize(0))
+	data := []string{"a", "b", "c", "d", "e"}
+	bytes, err := f1.Marshal(data)
+	if err != nil {
+		t.Fatalf("serialize failed: %v", err)
+	}
+
+	// Deserialize with a limit of 3 — should fail
+	f2 := New(WithXlang(true), WithMaxCollectionSize(3))
+	var result any
+	err2 := f2.Unmarshal(bytes, &result)
+	if err2 == nil {
+		t.Fatal("expected error when collection size exceeds limit, got nil")
+	}
+	if !strings.Contains(err2.Error(), "exceeds limit") {
+		t.Fatalf("expected 'exceeds limit' error, got: %v", err2)
+	}
+}
+
+// TestSizeGuardrails_SliceWithinLimit verifies that a slice within limits
+// deserializes successfully.
+func TestSizeGuardrails_SliceWithinLimit(t *testing.T) {
+	f := New(WithXlang(true), WithMaxCollectionSize(100))
+	data := []int32{1, 2, 3}
+	bytes, err := f.Marshal(data)
+	if err != nil {
+		t.Fatalf("serialize failed: %v", err)
+	}
+	var result any
+	err = f.Unmarshal(bytes, &result)
+	if err != nil {
+		t.Fatalf("deserialize should succeed within limit: %v", err)
+	}
+}
+
+// TestSizeGuardrails_MapExceedsLimit verifies that deserializing a map
+// whose entry count exceeds MaxCollectionSize returns an error.
+func TestSizeGuardrails_MapExceedsLimit(t *testing.T) {
+	f1 := New(WithXlang(true), WithMaxCollectionSize(0))
+	m := map[string]string{"a": "1", "b": "2", "c": "3", "d": "4", "e": "5"}
+	bytes, err := f1.Marshal(m)
+	if err != nil {
+		t.Fatalf("serialize failed: %v", err)
+	}
+
+	f2 := New(WithXlang(true), WithMaxCollectionSize(2))
+	var result any
+	err2 := f2.Unmarshal(bytes, &result)
+	if err2 == nil {
+		t.Fatal("expected error when map size exceeds limit, got nil")
+	}
+	if !strings.Contains(err2.Error(), "exceeds limit") {
+		t.Fatalf("expected 'exceeds limit' error, got: %v", err2)
+	}
+}
+
+// TestSizeGuardrails_MapWithinLimit verifies that a map within limits
+// deserializes successfully.
+func TestSizeGuardrails_MapWithinLimit(t *testing.T) {
+	f := New(WithXlang(true), WithMaxCollectionSize(100))
+	m := map[string]string{"a": "1", "b": "2"}
+	bytes, err := f.Marshal(m)
+	if err != nil {
+		t.Fatalf("serialize failed: %v", err)
+	}
+	var result any
+	err = f.Unmarshal(bytes, &result)
+	if err != nil {
+		t.Fatalf("deserialize should succeed within limit: %v", err)
+	}
+}
+
+// TestSizeGuardrails_DefaultConfig verifies that default limits are set.
+func TestSizeGuardrails_DefaultConfig(t *testing.T) {
+	f := New()
+	if f.config.MaxBinarySize != 64*1024*1024 {
+		t.Fatalf("expected default MaxBinarySize=64MB, got %d", f.config.MaxBinarySize)
+	}
+	if f.config.MaxCollectionSize != 1_000_000 {
+		t.Fatalf("expected default MaxCollectionSize=1000000, got %d", f.config.MaxCollectionSize)
+	}
+}
+
+// TestSizeGuardrails_NoLimitWhenZero verifies that 0 means unlimited.
+func TestSizeGuardrails_NoLimitWhenZero(t *testing.T) {
+	f := New(WithXlang(true), WithMaxCollectionSize(0))
+	data := []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
+	bytes, err := f.Marshal(data)
+	if err != nil {
+		t.Fatalf("serialize failed: %v", err)
+	}
+	var result any
+	err = f.Unmarshal(bytes, &result)
+	if err != nil {
+		t.Fatalf("deserialize with no limit should succeed: %v", err)
+	}
+}
+
+// TestSizeGuardrails_BinaryLimit verifies that binary payloads are bounded by
+// MaxBinarySize independently of MaxCollectionSize.
+func TestSizeGuardrails_BinaryLimit(t *testing.T) {
+	f := New(WithMaxCollectionSize(2), WithMaxBinarySize(16))
+	f.readCtx.checkBinarySize(16)
+	if f.readCtx.HasError() {
+		t.Fatalf("binary size within limit should not fail: %v", f.readCtx.CheckError())
+	}
+	f.readCtx.checkBinarySize(17)
+	if !f.readCtx.HasError() {
+		t.Fatal("expected error when binary size exceeds limit, got nil")
+	}
+}
diff --git a/go/fory/map.go b/go/fory/map.go
index f2489601f3..cf299d5bee 100644
--- a/go/fory/map.go
+++ b/go/fory/map.go
@@ -306,6 +306,7 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
 	refResolver.Reference(value)
 
 	size := int(buf.ReadVarUint32(ctxErr))
+	ctx.checkCollectionSize(size)
 	if size == 0 || ctx.HasError() {
 		return
 	}
diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go
index 21a4bd7b5d..b3c301fe54 100644
--- a/go/fory/map_primitive.go
+++ b/go/fory/map_primitive.go
@@ -69,8 +69,13 @@ func writeMapStringString(buf *ByteBuffer, m map[string]string, hasGenerics bool
 }
 
 // readMapStringString reads map[string]string using chunk protocol
-func readMapStringString(buf *ByteBuffer, err *Error) map[string]string {
+func readMapStringString(buf *ByteBuffer, err *Error, maxCollectionSize int) map[string]string {
 	size := int(buf.ReadVarUint32(err))
+	if maxCollectionSize > 0 && size > maxCollectionSize {
+		err.SetError(DeserializationErrorf(
+			"fory: map size %d exceeds limit %d", size, maxCollectionSize))
+		return nil
+	}
 	result := make(map[string]string, size)
 	if size == 0 {
 		return result
@@ -172,8 +177,13 @@ func writeMapStringInt64(buf *ByteBuffer, m map[string]int64, hasGenerics bool)
 }
 
 // readMapStringInt64 reads map[string]int64 using chunk protocol
-func readMapStringInt64(buf *ByteBuffer, err *Error) map[string]int64 {
+func readMapStringInt64(buf *ByteBuffer, err *Error, maxCollectionSize int) map[string]int64 {
 	size := int(buf.ReadVarUint32(err))
+	if maxCollectionSize > 0 && size > maxCollectionSize {
+		err.SetError(DeserializationErrorf(
+			"fory: map size %d exceeds limit %d", size, maxCollectionSize))
+		return nil
+	}
 	result := make(map[string]int64, size)
 	if size == 0 {
 		return result
@@ -246,8 +256,13 @@ func writeMapStringInt32(buf *ByteBuffer, m map[string]int32, hasGenerics bool)
 }
 
 // readMapStringInt32 reads map[string]int32 using chunk protocol
-func readMapStringInt32(buf *ByteBuffer, err *Error) map[string]int32 {
+func readMapStringInt32(buf *ByteBuffer, err *Error, maxCollectionSize int) map[string]int32 {
 	size := int(buf.ReadVarUint32(err))
+	if maxCollectionSize > 0 && size > maxCollectionSize {
+		err.SetError(DeserializationErrorf(
+			"fory: map size %d exceeds limit %d", size, maxCollectionSize))
+		return nil
+	}
 	result := make(map[string]int32, size)
 	if size == 0 {
 		return result
@@ -320,8 +335,13 @@ func writeMapStringInt(buf *ByteBuffer, m map[string]int, hasGenerics bool) {
 }
 
 // readMapStringInt reads map[string]int using chunk protocol
-func readMapStringInt(buf *ByteBuffer, err *Error) map[string]int {
+func readMapStringInt(buf *ByteBuffer, err *Error, maxCollectionSize int) map[string]int {
 	size := int(buf.ReadVarUint32(err))
+	if maxCollectionSize > 0 && size > maxCollectionSize {
+		err.SetError(DeserializationErrorf(
+			"fory: map size %d exceeds limit %d", size, maxCollectionSize))
+		return nil
+	}
 	result := make(map[string]int, size)
 	if size == 0 {
 		return result
@@ -394,8 +414,13 @@ func writeMapStringFloat64(buf *ByteBuffer, m map[string]float64, hasGenerics bo
 }
 
 // readMapStringFloat64 reads map[string]float64 using chunk protocol
-func readMapStringFloat64(buf *ByteBuffer, err *Error) map[string]float64 {
+func readMapStringFloat64(buf *ByteBuffer, err *Error, maxCollectionSize int) map[string]float64 {
 	size := int(buf.ReadVarUint32(err))
+	if maxCollectionSize > 0 && size > maxCollectionSize {
+		err.SetError(DeserializationErrorf(
+			"fory: map size %d exceeds limit %d", size, maxCollectionSize))
+		return nil
+	}
 	result := make(map[string]float64, size)
 	if size == 0 {
 		return result
@@ -468,8 +493,13 @@ func writeMapStringBool(buf *ByteBuffer, m map[string]bool, hasGenerics bool) {
 }
 
 // readMapStringBool reads map[string]bool using chunk protocol
-func readMapStringBool(buf *ByteBuffer, err *Error) map[string]bool {
+func readMapStringBool(buf *ByteBuffer, err *Error, maxCollectionSize int) map[string]bool {
 	size := int(buf.ReadVarUint32(err))
+	if maxCollectionSize > 0 && size > maxCollectionSize {
+		err.SetError(DeserializationErrorf(
+			"fory: map size %d exceeds limit %d", size, maxCollectionSize))
+		return nil
+	}
 	result := make(map[string]bool, size)
 	if size == 0 {
 		return result
@@ -547,8 +577,13 @@ func writeMapInt32Int32(buf *ByteBuffer, m map[int32]int32, hasGenerics bool) {
 }
 
 // readMapInt32Int32 reads map[int32]int32 using chunk protocol
-func readMapInt32Int32(buf *ByteBuffer, err *Error) map[int32]int32 {
+func readMapInt32Int32(buf *ByteBuffer, err *Error, maxCollectionSize int) map[int32]int32 {
 	size := int(buf.ReadVarUint32(err))
+	if maxCollectionSize > 0 && size > maxCollectionSize {
+		err.SetError(DeserializationErrorf(
+			"fory: map size %d exceeds limit %d", size, maxCollectionSize))
+		return nil
+	}
 	result := make(map[int32]int32, size)
 	if size == 0 {
 		return result
@@ -621,8 +656,13 @@ func writeMapInt64Int64(buf *ByteBuffer, m map[int64]int64, hasGenerics bool) {
 }
 
 // readMapInt64Int64 reads map[int64]int64 using chunk protocol
-func readMapInt64Int64(buf *ByteBuffer, err *Error) map[int64]int64 {
+func readMapInt64Int64(buf *ByteBuffer, err *Error, maxCollectionSize int) map[int64]int64 {
 	size := int(buf.ReadVarUint32(err))
+	if maxCollectionSize > 0 && size > maxCollectionSize {
+		err.SetError(DeserializationErrorf(
+			"fory: map size %d exceeds limit %d", size, maxCollectionSize))
+		return nil
+	}
 	result := make(map[int64]int64, size)
 	if size == 0 {
 		return result
@@ -695,8 +735,13 @@ func writeMapIntInt(buf *ByteBuffer, m map[int]int, hasGenerics bool) {
 }
 
 // readMapIntInt reads map[int]int using chunk protocol
-func readMapIntInt(buf *ByteBuffer, err *Error) map[int]int {
+func readMapIntInt(buf *ByteBuffer, err *Error, maxCollectionSize int) map[int]int {
 	size := int(buf.ReadVarUint32(err))
+	if maxCollectionSize > 0 && size > maxCollectionSize {
+		err.SetError(DeserializationErrorf(
+			"fory: map size %d exceeds limit %d", size, maxCollectionSize))
+		return nil
+	}
 	result := make(map[int]int, size)
 	if size == 0 {
 		return result
@@ -752,7 +797,7 @@ func (s stringStringMapSerializer) ReadData(ctx *ReadContext, value reflect.Valu
 		value.Set(reflect.MakeMap(value.Type()))
 	}
 	ctx.RefResolver().Reference(value)
-	result := readMapStringString(ctx.buffer, ctx.Err())
+	result := readMapStringString(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
 	value.Set(reflect.ValueOf(result))
 }
 
@@ -787,7 +832,7 @@ func (s stringInt64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value
 		value.Set(reflect.MakeMap(value.Type()))
 	}
 	ctx.RefResolver().Reference(value)
-	result := readMapStringInt64(ctx.buffer, ctx.Err())
+	result := readMapStringInt64(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
 	value.Set(reflect.ValueOf(result))
 }
 
@@ -822,7 +867,7 @@ func (s stringIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value)
 		value.Set(reflect.MakeMap(value.Type()))
 	}
 	ctx.RefResolver().Reference(value)
-	result := readMapStringInt(ctx.buffer, ctx.Err())
+	result := readMapStringInt(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
 	value.Set(reflect.ValueOf(result))
 }
 
@@ -857,7 +902,7 @@ func (s stringFloat64MapSerializer) ReadData(ctx *ReadContext, value reflect.Val
 		value.Set(reflect.MakeMap(value.Type()))
 	}
 	ctx.RefResolver().Reference(value)
-	result := readMapStringFloat64(ctx.buffer, ctx.Err())
+	result := readMapStringFloat64(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
 	value.Set(reflect.ValueOf(result))
 }
 
@@ -892,7 +937,7 @@ func (s stringBoolMapSerializer) ReadData(ctx *ReadContext, value reflect.Value)
 		value.Set(reflect.MakeMap(value.Type()))
 	}
 	ctx.RefResolver().Reference(value)
-	result := readMapStringBool(ctx.buffer, ctx.Err())
+	result := readMapStringBool(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
 	value.Set(reflect.ValueOf(result))
 }
 
@@ -927,7 +972,7 @@ func (s int32Int32MapSerializer) ReadData(ctx *ReadContext, value reflect.Value)
 		value.Set(reflect.MakeMap(value.Type()))
 	}
 	ctx.RefResolver().Reference(value)
-	result := readMapInt32Int32(ctx.buffer, ctx.Err())
+	result := readMapInt32Int32(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
 	value.Set(reflect.ValueOf(result))
 }
 
@@ -962,7 +1007,7 @@ func (s int64Int64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value)
 		value.Set(reflect.MakeMap(value.Type()))
 	}
 	ctx.RefResolver().Reference(value)
-	result := readMapInt64Int64(ctx.buffer, ctx.Err())
+	result := readMapInt64Int64(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
 	value.Set(reflect.ValueOf(result))
 }
 
@@ -997,7 +1042,7 @@ func (s intIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
 		value.Set(reflect.MakeMap(value.Type()))
 	}
 	ctx.RefResolver().Reference(value)
-	result := readMapIntInt(ctx.buffer, ctx.Err())
+	result := readMapIntInt(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
 	value.Set(reflect.ValueOf(result))
 }
 
diff --git a/go/fory/reader.go b/go/fory/reader.go
index e7a1df1710..b3d10b1142 100644
--- a/go/fory/reader.go
+++ b/go/fory/reader.go
@@ -43,6 +43,9 @@ type ReadContext struct {
 	err          Error // Accumulated error state for deferred checking
 	lastTypePtr  uintptr
 	lastTypeInfo *TypeInfo
+
+	maxCollectionSize int // Maximum element count for a single collection or map (0 = no limit)
+	maxBinarySize     int // Maximum byte length for a single binary payload (0 = no limit)
 }
 
 // IsXlang returns whether cross-language serialization mode is enabled
@@ -148,6 +151,23 @@ func (c *ReadContext) CheckError() error {
 	return nil
 }
 
+// checkCollectionSize validates that a collection/map element count does not exceed the configured limit.
+func (c *ReadContext) checkCollectionSize(size int) {
+	if c.maxCollectionSize > 0 && size > c.maxCollectionSize {
+		c.SetError(DeserializationErrorf(
+			"fory: collection/map size %d exceeds limit %d", size, c.maxCollectionSize))
+	}
+}
+
+// checkBinarySize validates that a binary payload's byte length does not exceed
+// the configured limit; a non-positive limit disables the check.
+func (c *ReadContext) checkBinarySize(size int) {
+	if c.maxBinarySize > 0 && size > c.maxBinarySize {
+		c.SetError(DeserializationErrorf(
+			"fory: binary size %d exceeds limit %d", size, c.maxBinarySize))
+	}
+}
+
 // Inline primitive reads
 func (c *ReadContext) RawBool() bool { return c.buffer.ReadBool(c.Err()) }
 func (c *ReadContext) RawInt8() int8 { return int8(c.buffer.ReadByte(c.Err())) }
@@ -462,7 +482,7 @@ func (c *ReadContext) ReadStringStringMap(refMode RefMode, readType bool) map[st
 	if readType {
 		_ = c.buffer.ReadUint8(err)
 	}
-	return readMapStringString(c.buffer, err)
+	return readMapStringString(c.buffer, err, c.maxCollectionSize)
 }
 
 // ReadStringInt64Map reads map[string]int64 with optional ref/type info
@@ -476,7 +496,7 @@ func (c *ReadContext) ReadStringInt64Map(refMode RefMode, readType bool) map[str
 	if readType {
 		_ = c.buffer.ReadUint8(err)
 	}
-	return readMapStringInt64(c.buffer, err)
+	return readMapStringInt64(c.buffer, err, c.maxCollectionSize)
 }
 
 // ReadStringInt32Map reads map[string]int32 with optional ref/type info
@@ -490,7 +510,7 @@ func (c *ReadContext) ReadStringInt32Map(refMode RefMode, readType bool) map[str
 	if readType {
 		_ = c.buffer.ReadUint8(err)
 	}
-	return readMapStringInt32(c.buffer, err)
+	return readMapStringInt32(c.buffer, err, c.maxCollectionSize)
 }
 
 // ReadStringIntMap reads map[string]int with optional ref/type info
@@ -504,7 +524,7 @@ func (c *ReadContext) ReadStringIntMap(refMode RefMode, readType bool) map[strin
 	if readType {
 		_ = c.buffer.ReadUint8(err)
 	}
-	return readMapStringInt(c.buffer, err)
+	return readMapStringInt(c.buffer, err, c.maxCollectionSize)
 }
 
 // ReadStringFloat64Map reads map[string]float64 with optional ref/type info
@@ -518,7 +538,7 @@ func (c *ReadContext) ReadStringFloat64Map(refMode RefMode, readType bool) map[s
 	if readType {
 		_ = c.buffer.ReadUint8(err)
 	}
-	return readMapStringFloat64(c.buffer, err)
+	return readMapStringFloat64(c.buffer, err, c.maxCollectionSize)
 }
 
 // ReadStringBoolMap reads map[string]bool with optional ref/type info
@@ -532,7 +552,7 @@ func (c *ReadContext) ReadStringBoolMap(refMode RefMode, readType bool) map[stri
 	if readType {
 		_ = c.buffer.ReadUint8(err)
 	}
-	return readMapStringBool(c.buffer, err)
+	return readMapStringBool(c.buffer, err, c.maxCollectionSize)
 }
 
 // ReadInt32Int32Map reads map[int32]int32 with optional ref/type info
@@ -546,7 +566,7 @@ func (c *ReadContext) ReadInt32Int32Map(refMode RefMode, readType bool) map[int3
 	if readType {
 		_ = c.buffer.ReadUint8(err)
 	}
-	return readMapInt32Int32(c.buffer, err)
+	return readMapInt32Int32(c.buffer, err, c.maxCollectionSize)
 }
 
 // ReadInt64Int64Map reads map[int64]int64 with optional ref/type info
@@ -560,7 +580,7 @@ func (c *ReadContext) ReadInt64Int64Map(refMode RefMode, readType bool) map[int6
 	if readType {
 		_ = c.buffer.ReadUint8(err)
 	}
-	return readMapInt64Int64(c.buffer, err)
+	return readMapInt64Int64(c.buffer, err, c.maxCollectionSize)
 }
 
 // ReadIntIntMap reads map[int]int with optional ref/type info
@@ -574,7 +594,7 @@ func (c *ReadContext) ReadIntIntMap(refMode RefMode, readType bool) map[int]int
 	if readType {
 		_ = c.buffer.ReadUint8(err)
 	}
-	return readMapIntInt(c.buffer, err)
+	return readMapIntInt(c.buffer, err, c.maxCollectionSize)
 }
 
 // ReadBufferObject reads a buffer object
diff --git a/go/fory/slice.go b/go/fory/slice.go
index bd3a9aa7ee..b3960a5c4c 100644
--- a/go/fory/slice.go
+++ b/go/fory/slice.go
@@ -265,6 +265,10 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
 	buf := ctx.Buffer()
 	ctxErr := ctx.Err()
 	length := int(buf.ReadVarUint32(ctxErr))
+	ctx.checkCollectionSize(length)
+	if ctx.HasError() {
+		return
+	}
 	isArrayType := value.Type().Kind() == reflect.Array
 
 	if length == 0 {
diff --git a/go/fory/slice_dyn.go b/go/fory/slice_dyn.go
index 3393d4b22b..efdae2ce4a 100644
--- a/go/fory/slice_dyn.go
+++ b/go/fory/slice_dyn.go
@@ -262,6 +262,10 @@ func (s sliceDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
 	buf := ctx.Buffer()
 	ctxErr := ctx.Err()
 	length := int(buf.ReadVarUint32(ctxErr))
+	ctx.checkCollectionSize(length)
+	if ctx.HasError() {
+		return
+	}
 	sliceType := value.Type()
 	value.Set(reflect.MakeSlice(sliceType, length, length))
 	if length == 0 {
diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go
index e4daf990be..ce0c924fe6 100644
--- a/go/fory/slice_primitive.go
+++ b/go/fory/slice_primitive.go
@@ -75,6 +75,10 @@ func (s byteSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
 	buf := ctx.Buffer()
 	ctxErr := ctx.Err()
 	length := buf.ReadLength(ctxErr)
+	ctx.checkBinarySize(length)
+	if ctx.HasError() {
+		return
+	}
 	ptr := (*[]byte)(value.Addr().UnsafePointer())
 	if length == 0 {
 		*ptr = make([]byte, 0)
@@ -643,6 +647,10 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
 	buf := ctx.Buffer()
 	ctxErr := ctx.Err()
 	length := int(buf.ReadVarUint32(ctxErr))
+	ctx.checkCollectionSize(length)
+	if ctx.HasError() {
+		return
+	}
 	ptr := (*[]string)(value.Addr().UnsafePointer())
 	if length == 0 {
 		*ptr = make([]string, 0)