Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions go/fory/fory.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,22 @@ const (

// Config holds configuration options for Fory instances
type Config struct {
TrackRef bool
MaxDepth int
IsXlang bool
Compatible bool // Schema evolution compatibility mode
TrackRef bool
MaxDepth int
IsXlang bool
Compatible bool // Schema evolution compatibility mode
MaxBinarySize int // Maximum byte length for a single deserialized binary payload (0 = no limit)
MaxCollectionSize int // Maximum element count for a single deserialized collection or map (0 = no limit)
}

// defaultConfig returns the default configuration
func defaultConfig() Config {
return Config{
TrackRef: false, // Match Java's default: reference tracking disabled
MaxDepth: 20,
IsXlang: false,
TrackRef: false,
MaxDepth: 20,
IsXlang: false,
MaxBinarySize: 64 * 1024 * 1024, // 64 MB
MaxCollectionSize: 1_000_000,
}
}

Expand Down Expand Up @@ -101,6 +105,20 @@ func WithCompatible(enabled bool) Option {
}
}

// WithMaxBinarySize sets the maximum byte length for a single deserialized binary payload.
func WithMaxBinarySize(n int) Option {
return func(f *Fory) {
f.config.MaxBinarySize = n
}
}

// WithMaxCollectionSize sets the maximum element count for a single deserialized collection or map.
func WithMaxCollectionSize(n int) Option {
return func(f *Fory) {
f.config.MaxCollectionSize = n
}
}

// ============================================================================
// Fory - Main serialization instance
// ============================================================================
Expand Down Expand Up @@ -156,6 +174,7 @@ func New(opts ...Option) *Fory {
f.readCtx.refResolver = f.refResolver
f.readCtx.compatible = f.config.Compatible
f.readCtx.xlang = f.config.IsXlang
f.readCtx.maxCollectionSize = f.config.MaxCollectionSize

return f
}
Expand Down
125 changes: 125 additions & 0 deletions go/fory/limits_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package fory

import (
"strings"
"testing"
)

// TestSizeGuardrails_SliceExceedsLimit verifies that deserializing a slice
// whose element count exceeds MaxCollectionSize returns an error.
func TestSizeGuardrails_SliceExceedsLimit(t *testing.T) {
// Serialize a string slice with 5 elements using no limit
f1 := New(WithXlang(true), WithMaxCollectionSize(0))
data := []string{"a", "b", "c", "d", "e"}
bytes, err := f1.Marshal(data)
if err != nil {
t.Fatalf("serialize failed: %v", err)
}

// Deserialize with a limit of 3 — should fail
f2 := New(WithXlang(true), WithMaxCollectionSize(3))
var result any
err2 := f2.Unmarshal(bytes, &result)
if err2 == nil {
t.Fatal("expected error when collection size exceeds limit, got nil")
}
if !strings.Contains(err2.Error(), "exceeds limit") {
t.Fatalf("expected 'exceeds limit' error, got: %v", err2)
}
}

// TestSizeGuardrails_SliceWithinLimit verifies that a slice within limits
// deserializes successfully.
func TestSizeGuardrails_SliceWithinLimit(t *testing.T) {
f := New(WithXlang(true), WithMaxCollectionSize(100))
data := []int32{1, 2, 3}
bytes, err := f.Marshal(data)
if err != nil {
t.Fatalf("serialize failed: %v", err)
}
var result any
err = f.Unmarshal(bytes, &result)
if err != nil {
t.Fatalf("deserialize should succeed within limit: %v", err)
}
}

// TestSizeGuardrails_MapExceedsLimit verifies that deserializing a map
// whose entry count exceeds MaxCollectionSize returns an error.
func TestSizeGuardrails_MapExceedsLimit(t *testing.T) {
f1 := New(WithXlang(true), WithMaxCollectionSize(0))
m := map[string]string{"a": "1", "b": "2", "c": "3", "d": "4", "e": "5"}
bytes, err := f1.Marshal(m)
if err != nil {
t.Fatalf("serialize failed: %v", err)
}

f2 := New(WithXlang(true), WithMaxCollectionSize(2))
var result any
err2 := f2.Unmarshal(bytes, &result)
if err2 == nil {
t.Fatal("expected error when map size exceeds limit, got nil")
}
if !strings.Contains(err2.Error(), "exceeds limit") {
t.Fatalf("expected 'exceeds limit' error, got: %v", err2)
}
}

// TestSizeGuardrails_MapWithinLimit verifies that a map within limits
// deserializes successfully.
func TestSizeGuardrails_MapWithinLimit(t *testing.T) {
f := New(WithXlang(true), WithMaxCollectionSize(100))
m := map[string]string{"a": "1", "b": "2"}
bytes, err := f.Marshal(m)
if err != nil {
t.Fatalf("serialize failed: %v", err)
}
var result any
err = f.Unmarshal(bytes, &result)
if err != nil {
t.Fatalf("deserialize should succeed within limit: %v", err)
}
}

// TestSizeGuardrails_DefaultConfig verifies that default limits are set.
func TestSizeGuardrails_DefaultConfig(t *testing.T) {
f := New()
if f.config.MaxBinarySize != 64*1024*1024 {
t.Fatalf("expected default MaxBinarySize=64MB, got %d", f.config.MaxBinarySize)
}
if f.config.MaxCollectionSize != 1_000_000 {
t.Fatalf("expected default MaxCollectionSize=1000000, got %d", f.config.MaxCollectionSize)
}
}

// TestSizeGuardrails_NoLimitWhenZero verifies that 0 means unlimited.
func TestSizeGuardrails_NoLimitWhenZero(t *testing.T) {
f := New(WithXlang(true), WithMaxCollectionSize(0))
data := []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
bytes, err := f.Marshal(data)
if err != nil {
t.Fatalf("serialize failed: %v", err)
}
var result any
err = f.Unmarshal(bytes, &result)
if err != nil {
t.Fatalf("deserialize with no limit should succeed: %v", err)
}
}
1 change: 1 addition & 0 deletions go/fory/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
refResolver.Reference(value)

size := int(buf.ReadVarUint32(ctxErr))
ctx.checkCollectionSize(size)
if size == 0 || ctx.HasError() {
return
}
Expand Down
79 changes: 62 additions & 17 deletions go/fory/map_primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,13 @@ func writeMapStringString(buf *ByteBuffer, m map[string]string, hasGenerics bool
}

// readMapStringString reads map[string]string using chunk protocol
func readMapStringString(buf *ByteBuffer, err *Error) map[string]string {
func readMapStringString(buf *ByteBuffer, err *Error, maxCollectionSize int) map[string]string {
size := int(buf.ReadVarUint32(err))
if maxCollectionSize > 0 && size > maxCollectionSize {
err.SetError(DeserializationErrorf(
"fory: map size %d exceeds limit %d", size, maxCollectionSize))
return nil
}
result := make(map[string]string, size)
if size == 0 {
return result
Expand Down Expand Up @@ -172,8 +177,13 @@ func writeMapStringInt64(buf *ByteBuffer, m map[string]int64, hasGenerics bool)
}

// readMapStringInt64 reads map[string]int64 using chunk protocol
func readMapStringInt64(buf *ByteBuffer, err *Error) map[string]int64 {
func readMapStringInt64(buf *ByteBuffer, err *Error, maxCollectionSize int) map[string]int64 {
size := int(buf.ReadVarUint32(err))
if maxCollectionSize > 0 && size > maxCollectionSize {
err.SetError(DeserializationErrorf(
"fory: map size %d exceeds limit %d", size, maxCollectionSize))
return nil
}
result := make(map[string]int64, size)
if size == 0 {
return result
Expand Down Expand Up @@ -246,8 +256,13 @@ func writeMapStringInt32(buf *ByteBuffer, m map[string]int32, hasGenerics bool)
}

// readMapStringInt32 reads map[string]int32 using chunk protocol
func readMapStringInt32(buf *ByteBuffer, err *Error) map[string]int32 {
func readMapStringInt32(buf *ByteBuffer, err *Error, maxCollectionSize int) map[string]int32 {
size := int(buf.ReadVarUint32(err))
if maxCollectionSize > 0 && size > maxCollectionSize {
err.SetError(DeserializationErrorf(
"fory: map size %d exceeds limit %d", size, maxCollectionSize))
return nil
}
result := make(map[string]int32, size)
if size == 0 {
return result
Expand Down Expand Up @@ -320,8 +335,13 @@ func writeMapStringInt(buf *ByteBuffer, m map[string]int, hasGenerics bool) {
}

// readMapStringInt reads map[string]int using chunk protocol
func readMapStringInt(buf *ByteBuffer, err *Error) map[string]int {
func readMapStringInt(buf *ByteBuffer, err *Error, maxCollectionSize int) map[string]int {
size := int(buf.ReadVarUint32(err))
if maxCollectionSize > 0 && size > maxCollectionSize {
err.SetError(DeserializationErrorf(
"fory: map size %d exceeds limit %d", size, maxCollectionSize))
return nil
}
result := make(map[string]int, size)
if size == 0 {
return result
Expand Down Expand Up @@ -394,8 +414,13 @@ func writeMapStringFloat64(buf *ByteBuffer, m map[string]float64, hasGenerics bo
}

// readMapStringFloat64 reads map[string]float64 using chunk protocol
func readMapStringFloat64(buf *ByteBuffer, err *Error) map[string]float64 {
func readMapStringFloat64(buf *ByteBuffer, err *Error, maxCollectionSize int) map[string]float64 {
size := int(buf.ReadVarUint32(err))
if maxCollectionSize > 0 && size > maxCollectionSize {
err.SetError(DeserializationErrorf(
"fory: map size %d exceeds limit %d", size, maxCollectionSize))
return nil
}
result := make(map[string]float64, size)
if size == 0 {
return result
Expand Down Expand Up @@ -468,8 +493,13 @@ func writeMapStringBool(buf *ByteBuffer, m map[string]bool, hasGenerics bool) {
}

// readMapStringBool reads map[string]bool using chunk protocol
func readMapStringBool(buf *ByteBuffer, err *Error) map[string]bool {
func readMapStringBool(buf *ByteBuffer, err *Error, maxCollectionSize int) map[string]bool {
size := int(buf.ReadVarUint32(err))
if maxCollectionSize > 0 && size > maxCollectionSize {
err.SetError(DeserializationErrorf(
"fory: map size %d exceeds limit %d", size, maxCollectionSize))
return nil
}
result := make(map[string]bool, size)
if size == 0 {
return result
Expand Down Expand Up @@ -547,8 +577,13 @@ func writeMapInt32Int32(buf *ByteBuffer, m map[int32]int32, hasGenerics bool) {
}

// readMapInt32Int32 reads map[int32]int32 using chunk protocol
func readMapInt32Int32(buf *ByteBuffer, err *Error) map[int32]int32 {
func readMapInt32Int32(buf *ByteBuffer, err *Error, maxCollectionSize int) map[int32]int32 {
size := int(buf.ReadVarUint32(err))
if maxCollectionSize > 0 && size > maxCollectionSize {
err.SetError(DeserializationErrorf(
"fory: map size %d exceeds limit %d", size, maxCollectionSize))
return nil
}
result := make(map[int32]int32, size)
if size == 0 {
return result
Expand Down Expand Up @@ -621,8 +656,13 @@ func writeMapInt64Int64(buf *ByteBuffer, m map[int64]int64, hasGenerics bool) {
}

// readMapInt64Int64 reads map[int64]int64 using chunk protocol
func readMapInt64Int64(buf *ByteBuffer, err *Error) map[int64]int64 {
func readMapInt64Int64(buf *ByteBuffer, err *Error, maxCollectionSize int) map[int64]int64 {
size := int(buf.ReadVarUint32(err))
if maxCollectionSize > 0 && size > maxCollectionSize {
err.SetError(DeserializationErrorf(
"fory: map size %d exceeds limit %d", size, maxCollectionSize))
return nil
}
result := make(map[int64]int64, size)
if size == 0 {
return result
Expand Down Expand Up @@ -695,8 +735,13 @@ func writeMapIntInt(buf *ByteBuffer, m map[int]int, hasGenerics bool) {
}

// readMapIntInt reads map[int]int using chunk protocol
func readMapIntInt(buf *ByteBuffer, err *Error) map[int]int {
func readMapIntInt(buf *ByteBuffer, err *Error, maxCollectionSize int) map[int]int {
size := int(buf.ReadVarUint32(err))
if maxCollectionSize > 0 && size > maxCollectionSize {
err.SetError(DeserializationErrorf(
"fory: map size %d exceeds limit %d", size, maxCollectionSize))
return nil
}
result := make(map[int]int, size)
if size == 0 {
return result
Expand Down Expand Up @@ -752,7 +797,7 @@ func (s stringStringMapSerializer) ReadData(ctx *ReadContext, value reflect.Valu
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
result := readMapStringString(ctx.buffer, ctx.Err())
result := readMapStringString(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
value.Set(reflect.ValueOf(result))
}

Expand Down Expand Up @@ -787,7 +832,7 @@ func (s stringInt64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
result := readMapStringInt64(ctx.buffer, ctx.Err())
result := readMapStringInt64(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
value.Set(reflect.ValueOf(result))
}

Expand Down Expand Up @@ -822,7 +867,7 @@ func (s stringIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value)
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
result := readMapStringInt(ctx.buffer, ctx.Err())
result := readMapStringInt(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
value.Set(reflect.ValueOf(result))
}

Expand Down Expand Up @@ -857,7 +902,7 @@ func (s stringFloat64MapSerializer) ReadData(ctx *ReadContext, value reflect.Val
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
result := readMapStringFloat64(ctx.buffer, ctx.Err())
result := readMapStringFloat64(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
value.Set(reflect.ValueOf(result))
}

Expand Down Expand Up @@ -892,7 +937,7 @@ func (s stringBoolMapSerializer) ReadData(ctx *ReadContext, value reflect.Value)
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
result := readMapStringBool(ctx.buffer, ctx.Err())
result := readMapStringBool(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
value.Set(reflect.ValueOf(result))
}

Expand Down Expand Up @@ -927,7 +972,7 @@ func (s int32Int32MapSerializer) ReadData(ctx *ReadContext, value reflect.Value)
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
result := readMapInt32Int32(ctx.buffer, ctx.Err())
result := readMapInt32Int32(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
value.Set(reflect.ValueOf(result))
}

Expand Down Expand Up @@ -962,7 +1007,7 @@ func (s int64Int64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value)
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
result := readMapInt64Int64(ctx.buffer, ctx.Err())
result := readMapInt64Int64(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
value.Set(reflect.ValueOf(result))
}

Expand Down Expand Up @@ -997,7 +1042,7 @@ func (s intIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
result := readMapIntInt(ctx.buffer, ctx.Err())
result := readMapIntInt(ctx.buffer, ctx.Err(), ctx.maxCollectionSize)
value.Set(reflect.ValueOf(result))
}

Expand Down
Loading
Loading