From a16fedc759bd14fce69d82cbc345a0e2fe6c7415 Mon Sep 17 00:00:00 2001 From: Shizuo Fujita Date: Sun, 28 Jun 2026 02:41:24 +0900 Subject: [PATCH] Decompress all concatenated frames in Zstd.decompress Co-Authored-By: Claude Opus 4.8 (1M context) --- ext/zstdruby/zstdruby.c | 46 ++++++++++++++++++++++++++++++++--------- spec/zstd-ruby_spec.rb | 13 ++++++++++++ 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/ext/zstdruby/zstdruby.c b/ext/zstdruby/zstdruby.c index 1649fe5..3e8454b 100644 --- a/ext/zstdruby/zstdruby.c +++ b/ext/zstdruby/zstdruby.c @@ -40,7 +40,7 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self) return output; } -static VALUE decode_one_frame(ZSTD_DCtx* dctx, const unsigned char* src, size_t size, VALUE kwargs) { +static VALUE decode_one_frame(ZSTD_DCtx* dctx, const unsigned char* src, size_t size, VALUE kwargs, size_t* consumed) { VALUE out = rb_str_buf_new(0); size_t cap = ZSTD_DStreamOutSize(); char *buf = ALLOC_N(char, cap); @@ -64,11 +64,14 @@ static VALUE decode_one_frame(ZSTD_DCtx* dctx, const unsigned char* src, size_t } } xfree(buf); + if (consumed) { + *consumed = in.pos; + } return out; } static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* data, size_t len) { - return decode_one_frame(dctx, (const unsigned char*)data, len, Qnil); + return decode_one_frame(dctx, (const unsigned char*)data, len, Qnil, NULL); } static VALUE rb_decompress(int argc, VALUE *argv, VALUE self) @@ -84,6 +87,9 @@ static VALUE rb_decompress(int argc, VALUE *argv, VALUE self) const uint32_t ZSTD_MAGIC = 0xFD2FB528U; const uint32_t SKIP_LO = 0x184D2A50U; /* ...5F */ + VALUE result = Qnil; + ZSTD_DCtx *dctx = NULL; + while (off + 4 <= in_size) { uint32_t magic = (uint32_t)in[off] | ((uint32_t)in[off+1] << 8) @@ -103,23 +109,43 @@ static VALUE rb_decompress(int argc, VALUE *argv, VALUE self) } if (magic == ZSTD_MAGIC) { - ZSTD_DCtx *dctx = ZSTD_createDCtx(); - if (!dctx) { - rb_raise(rb_eRuntimeError, "ZSTD_createDCtx failed"); + if (dctx == NULL) { + dctx = ZSTD_createDCtx(); + if (!dctx) { + rb_raise(rb_eRuntimeError, "ZSTD_createDCtx failed"); + } } - VALUE out = decode_one_frame(dctx, in + off, in_size - off, kwargs); + size_t consumed = 0; + VALUE out = decode_one_frame(dctx, in + off, in_size - off, kwargs, &consumed); + if (result == Qnil) { + /* First frame becomes the accumulator, avoiding a copy of its + (potentially large) output in the common single-frame case. */ + result = out; + } else { + rb_str_cat(result, RSTRING_PTR(out), RSTRING_LEN(out)); + } - ZSTD_freeDCtx(dctx); - RB_GC_GUARD(input_value); - return out; + if (consumed == 0) { + /* Guard against a non-advancing frame to avoid an infinite loop. */ + break; + } + off += consumed; + continue; } off += 1; } + if (dctx != NULL) { + ZSTD_freeDCtx(dctx); + } + RB_GC_GUARD(input_value); - rb_raise(rb_eRuntimeError, "not a zstd frame (magic not found)"); + if (result == Qnil) { + rb_raise(rb_eRuntimeError, "not a zstd frame (magic not found)"); + } + return result; } static void free_cdict(void *dict) diff --git a/spec/zstd-ruby_spec.rb b/spec/zstd-ruby_spec.rb index ddfae47..6ef02d9 100644 --- a/spec/zstd-ruby_spec.rb +++ b/spec/zstd-ruby_spec.rb @@ -99,6 +99,19 @@ def to_str expect(Zstd.decompress(res)).to eq(large_strings * 3) end + it 'should decompress concatenated frames' do + a = Zstd.compress("Hello, ") + b = Zstd.compress("World!") + expect(Zstd.decompress(a + b)).to eq("Hello, World!") + end + + it 'should decompress three or more concatenated frames' do + a = Zstd.compress("Hello, ") + b = Zstd.compress("World!") + c = Zstd.compress("!!!") + expect(Zstd.decompress(a + b + c)).to eq("Hello, World!!!!") + end + it 'should raise exception with unsupported object' do expect { Zstd.decompress(Object.new) }.to raise_error(TypeError) end