diff --git a/docs/storage-select-flow.md b/docs/storage-select-flow.md
new file mode 100644
index 0000000..b0ab2fa
--- /dev/null
+++ b/docs/storage-select-flow.md
@@ -0,0 +1,421 @@
+# Storage Select 完整流程
+
+本文档详细描述 Tavern 存储层 `Select` 方法的完整调用链路、数据结构与核心算法。
+
+---
+
+## 1. 调用入口
+
+Select 被调用的场景有三个:
+
+| 入口 | 位置 | 说明 |
+|------|------|------|
+| 缓存中间件 | `server/middleware/caching/processor.go:103` | 需要读写缓存对象时,通过 `object.ID` 找到对应的 Bucket |
+| 全局函数 | `storage/global.go:30` | 持有全局 `defaultStorage` 的锁,代理调用 |
+| PURGE 删除 | `storage/storage.go:220` | 单对象删除时,先 Select 再 Discard |
+
+```go
+// 缓存中间件调用
+bucket := store.Select(req.Context(), objectID)
+
+// 全局函数调用
+bucket := storage.Select(ctx, cacheKey)
+```
+
+---
+
+## 2. 完整流程图
+
+```mermaid
+flowchart TD
+ subgraph 调用入口
+ A1["缓存中间件
processor.go:103
bucket := store.Select(ctx, objectID)"]
+ A2["全局函数
storage.Select(ctx, cacheKey)
global.go:30"]
+ A3["PURGE 单对象删除
storage.go:220
bucket := n.Select(ctx, cacheKey)"]
+ end
+
+ A1 --> B{defaultStorage}
+ A2 --> B
+ A3 --> B
+
+ B --> C{"storage.New() 创建时
config.Migration.Enabled ?"}
+
+ C -->|false
普通模式| D["nativeStorage.Select()
storage.go:136"]
+ C -->|true
迁移模式| E["migratorStorage.Select()
migrator.go:169"]
+
+ subgraph 普通模式 nativeStorage
+ D --> D1["n.selector.Select(ctx, id)
直接委托给 warmSelector"]
+ D1 --> F
+ end
+
+ subgraph 迁移模式 migratorStorage
+ E --> E1["chainSelector(ctx, id,
hotSelector, warmSelector, coldSelector)
migrator.go:240"]
+ E1 --> E2["遍历 selector 列表"]
+ E2 --> E2A["sel.Select(ctx, id) → bucket"]
+ E2A --> E2B{"bucket != nil
&&
bucket.Exist(ctx, id.Bytes()) ?"}
+ E2B -->|"是(对象已存在该层)"| E2C["返回该 bucket"]
+ E2B -->|"否"| E2D{"还有下一个
selector ?"}
+ E2D -->|是| E2
+ E2D -->|"否(所有层都不存在)"| E2E["兜底: warmSelector.Select(ctx, id)"]
+ E2E --> F
+ end
+
+ subgraph SelectLayer 按层选择
+ SL1["migratorStorage.SelectLayer(ctx, id, layer)
migrator.go:220"]
+ SL1 --> SL2{"layer ?"}
+ SL2 -->|"hot"| SL3["hotSelector.Select(ctx, id)"]
+ SL2 -->|"warm/normal"| SL4["warmSelector.Select(ctx, id)"]
+ SL2 -->|"cold"| SL5["coldSelector.Select(ctx, id)"]
+ SL2 -->|"inmemory"| SL6["直接返回 memoryBucket"]
+ end
+
+ subgraph HashRing 一致性哈希选择器
+ F["selector.New(buckets, 'hashring')
selector.go:12"]
+ F --> G["Balancer.Select(ctx, id)
hashring.go:43"]
+
+ G --> G1["遍历 i = 1..len(buckets)"]
+ G1 --> G2["hashring.GetN(id.Bytes(), i)
获取 i 个最近节点"]
+ G2 --> G3["Consistent.GetN(name, n)
consistent.go:173"]
+
+ G3 --> G4["FNV-32a 哈希 key"]
+ G4 --> G5["二分搜索排序哈希环
找到起始位置"]
+ G5 --> G6["从起始位置顺时针遍历
收集 n 个不重复节点"]
+
+ G6 --> G7["取第 i-1 个节点 bucket"]
+ G7 --> G8{"bucket.UseAllow() ?
允许百分比检查"}
+ G8 -->|否| G9["continue 下一个"]
+ G8 -->|是| G10{"bucket.HasBad() ?
健康检查"}
+ G10 -->|"是(不健康)"| G9
+ G10 -->|"否(健康)"| G11["返回该 bucket"]
+ G9 --> G1
+ end
+
+ subgraph DirAware 目录感知包装
+ H["diraware.New(nativeStorage, checker)
diraware/storage.go:19"]
+ H --> H1["wrappedStorage.Select(ctx, id)
diraware/storage.go:34"]
+ H1 --> H2["base.Select(ctx, id) → bucket"]
+ H2 --> H3["wrapBucket(bucket, checker)
包装 bucket 添加过期标记能力"]
+ H3 --> H4["返回 wrappedBucket"]
+ end
+
+ D -->|"config.DirAware.Enabled"| H
+ E -->|"config.DirAware.Enabled"| H
+
+ subgraph 哈希环构建 Consistent Hash Ring
+ R1["Consistent.Set(caches)
consistent.go:109"]
+ R1 --> R2["遍历每个 Node (Bucket)"]
+ R2 --> R3["add(cache, cache.Weight())"]
+ R3 --> R4["对每个节点:
NumberOfReplicas(20) × Weight 个虚拟节点"]
+ R4 --> R5["key = hash('{idx}|{weight}|{bucketID}')
circles[key] = bucket"]
+ R5 --> R6["updateSortedHashes()
排序所有哈希值"]
+ end
+
+ F -.->|初始化时| R1
+
+ subgraph Bucket 存储类型分层
+ BT1["memoryBucket
TypeInMemory"]
+ BT2["hotBucket[]
TypeHot"]
+ BT3["warmBucket[]
TypeWarm / TypeNormal"]
+ BT4["coldBucket[]
TypeCold
(仅 migrator 模式)"]
+ end
+
+ G -.->|选择范围| BT3
+ E2A -.->|hotSelector 范围| BT2
+ E2A -.->|warmSelector 范围| BT3
+ E2A -.->|coldSelector 范围| BT4
+```
+
+---
+
+## 3. 两大模式对比
+
+| 模式 | 条件 | 实现结构体 | 核心逻辑 |
+|------|------|-----------|---------|
+| **普通模式** | `Migration.Enabled = false` | `nativeStorage` | 直接委托给 warmSelector(一致性哈希),不检查对象是否存在 |
+| **迁移模式** | `Migration.Enabled = true` | `migratorStorage` | Hot → Warm → Cold 链式查找,逐层调用 `bucket.Exist()` 检查对象是否存在 |
+
+### 3.1 普通模式 (`nativeStorage`)
+
+```go
+// storage/storage.go:136
+func (n *nativeStorage) Select(ctx context.Context, id *object.ID) storage.Bucket {
+ bucket := n.selector.Select(ctx, id)
+ return bucket
+}
+```
+
+极其简单:一行委托。`n.selector` 即 `warmSelector`,是一个哈希环选择器。
+
+### 3.2 迁移模式 (`migratorStorage`)
+
+```go
+// storage/migrator.go:169
+func (m *migratorStorage) Select(ctx context.Context, id *object.ID) storage.Bucket {
+ return m.chainSelector(ctx, id,
+ m.hotSelector,
+ m.warmSelector,
+ m.coldSelector,
+ )
+}
+```
+
+核心在 `chainSelector`:
+
+```go
+// storage/migrator.go:240
+func (m *migratorStorage) chainSelector(ctx context.Context, id *object.ID, selectors ...storage.Selector) storage.Bucket {
+ for _, sel := range selectors {
+ if sel == nil {
+ continue
+ }
+ if bucket := sel.Select(ctx, id); bucket != nil && bucket.Exist(ctx, id.Bytes()) {
+ return bucket
+ }
+ }
+ // 兜底: 返回 warmSelector 的结果
+ return m.warmSelector.Select(ctx, id)
+}
+```
+
+**关键差异**:迁移模式不仅做哈希定位,还会调用 `bucket.Exist()` 检查对象是否真的在该 Bucket 的 IndexDB 中。这是为了支持对象的 **Promote/Demote** 跨层迁移:
+
+- **Promote**:Cold → Warm → Hot(访问量上升)
+- **Demote**:Hot → Warm → Cold(访问量下降)
+
+### 3.3 按层选择 (`SelectLayer`)
+
+```go
+// storage/migrator.go:220
+func (m *migratorStorage) SelectLayer(ctx context.Context, id *object.ID, layer string) storage.Bucket {
+ switch layer {
+ case storage.TypeHot:
+ if m.hotSelector != nil {
+ return m.hotSelector.Select(ctx, id)
+ }
+ case storage.TypeNormal, storage.TypeWarm:
+ if m.warmSelector != nil {
+ return m.warmSelector.Select(ctx, id)
+ }
+ case storage.TypeCold:
+ if m.coldSelector != nil {
+ return m.coldSelector.Select(ctx, id)
+ }
+ case storage.TypeInMemory:
+ return m.memoryBucket
+ }
+ return nil
+}
+```
+
+用于迁移操作(Promote/Demote)时,根据目标层直接定位到对应 Bucket。
+
+---
+
+## 4. HashRing 一致性哈希(核心算法)
+
+### 4.1 选择器工厂
+
+```go
+// storage/selector/selector.go:12
+func New(buckets []storage.Bucket, typ string) storage.Selector {
+ curr, err := hashring.New(buckets, hashring.WithReplicas(20))
+ if err != nil {
+ panic(err)
+ }
+ return curr
+}
+```
+
+目前仅支持 `hashring` 类型,`typ` 参数实际未使用。
+
+### 4.2 Balancer.Select 算法
+
+```go
+// storage/selector/hashring/hashring.go:43
+func (b *Balancer) Select(ctx context.Context, id *object.ID) storage.Bucket {
+ for i := 1; i <= len(b.buckets); i++ {
+ groups, err := b.hashring.GetN(string(id.Bytes()), i)
+ if err != nil {
+ return nil
+ }
+ bucket := groups[i-1].(storage.Bucket)
+ if bucket.UseAllow() {
+ if bucket.HasBad() {
+ continue
+ }
+ return bucket
+ }
+ }
+ return nil
+}
+```
+
+**算法步骤**:
+
+1. `i=1` 开始,每次递增,调用 `GetN(key, i)` 获取 i 个最近节点
+2. 取第 `i-1` 个(最后一个,即第 i 近的)节点
+3. 检查 `UseAllow()`(允许百分比)和 `HasBad()`(健康状态)
+4. 如果满足条件则返回,否则 `i++` 找下一个更远的节点
+5. 这是一种**退避策略**:优先返回最近的健康节点
+
+### 4.3 Consistent.GetN — 哈希环查找
+
+```go
+// storage/selector/hashring/consistent.go:173
+func (c *Consistent) GetN(name string, n int) ([]Node, error) {
+ key := c.hashKey(name) // FNV-32a 哈希
+ i := c.search(key) // 二分搜索定位
+ // 从起始位置顺时针遍历,收集 n 个不重复节点
+ // ...
+}
+```
+
+**关键参数**:
+
+| 参数 | 值 | 说明 |
+|------|-----|------|
+| 哈希函数 | FNV-32a | `hash/fnv` 标准库 |
+| 虚拟副本数 | 20 | `NumberOfReplicas`,可通过 `WithReplicas` 配置 |
+| 虚拟节点 key | `"{idx}|{weight}|{bucketID}"` | 每个副本 × 每个权重单位生成一个虚拟节点 |
+
+### 4.4 哈希环构建
+
+```go
+// storage/selector/hashring/consistent.go:109
+func (c *Consistent) Set(caches []Node) {
+ // 移除不再存在的节点
+ // 新增节点: add(cache, cache.Weight())
+ // → NumberOfReplicas × Weight 个虚拟节点
+ // → 每个虚拟节点: circles[hash("{idx}|{weight}|{id}")] = bucket
+ // updateSortedHashes() → sort
+}
+```
+
+**总虚拟节点数** = `∑(20 × Bucket.Weight)`,例如 3 个 Bucket 各权重 10 → 600 个虚拟节点。
+
+---
+
+## 5. DirAware 目录感知包装
+
+当 `config.DirAware.Enabled = true` 时,`nativeStorage` 或 `migratorStorage` 被 `wrappedStorage` 包装。
+
+```go
+// storage/diraware/storage.go:34
+func (w *wrappedStorage) Select(ctx context.Context, id *object.ID) storagev1.Bucket {
+ return wrapBucket(w.base.Select(ctx, id), w.checker)
+}
+```
+
+### 5.1 Bucket 包装 — Lookup 注入
+
+```go
+// storage/diraware/bucket.go:26
+func (b *wrappedBucket) Lookup(ctx context.Context, id *object.ID) (*object.Metadata, error) {
+ md, err := b.base.Lookup(ctx, id)
+ if err != nil || md == nil {
+ return md, err
+ }
+ marked, err := b.checker.Marked(ctx, id, md)
+ if marked {
+ md.ExpiresAt = time.Now().Add(-1 * time.Second).Unix()
+ }
+ return md, nil
+}
+```
+
+### 5.2 Checker 标记逻辑
+
+```go
+// storage/diraware/diraware.go:74
+func (c *checker) Marked(ctx context.Context, id *object.ID, md *object.Metadata) (bool, error) {
+ unix, found := c.pathtrie.Search(id.Path())
+ if found && md.RespUnix <= unix {
+ return true, nil // 对象在推送目录任务之前保存的 → 标记过期
+ }
+ return false, nil
+}
+```
+
+**逻辑**:前缀树(PathTrie)中存储了被推送的目录路径及推送时间。当 `Lookup` 时,如果对象所在路径被推送标记过,且对象的最后修改时间 ≤ 推送时间,说明这个对象是在推送之前缓存的,应当标记为过期。
+
+---
+
+## 6. 各 Selector 覆盖的 Bucket 范围
+
+| Selector | 管理的 Bucket 列表 | 对应 StoreType |
+|----------|-------------------|----------------|
+| `warmSelector` | `warmBucket[]` | `TypeWarm` / `TypeNormal`(Normal 自动别名为 Warm) |
+| `hotSelector` | `hotBucket[]` | `TypeHot`(仅 migrator 模式) |
+| `coldSelector` | `coldBucket[]` | `TypeCold`(仅 migrator 模式) |
+| `memoryBucket` | 单例 | `TypeInMemory`(直接引用,不走哈希环) |
+
+### 6.1 Bucket 初始化时的分类
+
+```go
+// storage/storage.go:104 (reinit 方法)
+switch bucket.StoreType() {
+case storage.TypeNormal, storage.TypeWarm:
+ n.warmlBucket = append(n.warmlBucket, bucket)
+case storage.TypeHot:
+ n.hotBucket = append(n.hotBucket, bucket)
+case storage.TypeInMemory:
+ n.memoryBucket = bucket // 只能有一个
+case storage.TypeCold: // 仅 migrator
+ m.coldBucket = append(m.coldBucket, bucket)
+}
+```
+
+> **注意**:`TypeNormal` 在 `mergeConfig` 中会被自动转为 `TypeWarm`(`storage/builder.go:61`)。
+
+---
+
+## 7. 关键接口定义
+
+```go
+// api/defined/v1/storage/storage.go
+
+type Selector interface {
+ Select(ctx context.Context, id *object.ID) Bucket
+ Rebuild(ctx context.Context, buckets []Bucket) error
+}
+
+type Storage interface {
+ io.Closer
+ Selector // 内嵌 Selector
+ Buckets() []Bucket
+ SharedKV() SharedKV
+ PURGE(storeUrl string, typ PurgeControl) error
+}
+
+type Migrator interface {
+ Storage // 内嵌 Storage
+ SelectLayer(ctx context.Context, id *object.ID, layer string) Bucket
+}
+```
+
+完整实现矩阵:
+
+| 接口 | `nativeStorage` | `migratorStorage` | `wrappedStorage` (diraware) |
+|------|:---:|:---:|:---:|
+| `Storage` | ✓ | ✓ | ✓ |
+| `Migrator` | ✗ | ✓ | ✗ |
+| `Select()` | 委托 warmSelector | chainSelector(Hot→Warm→Cold) | 包装 base.Select() |
+| `SelectLayer()` | N/A | 按层直接选择 | N/A |
+
+---
+
+## 8. 文件索引
+
+| 文件 | 内容 |
+|------|------|
+| `api/defined/v1/storage/storage.go` | `Selector`、`Storage`、`Migrator`、`Bucket` 接口定义 |
+| `storage/global.go` | 全局 `Select()`、`SetDefault()`、`Current()` |
+| `storage/storage.go` | `nativeStorage` 实现(普通模式) |
+| `storage/migrator.go` | `migratorStorage` 实现(迁移模式) |
+| `storage/builder.go` | `NewBucket()` 工厂、配置合并 |
+| `storage/selector/selector.go` | Selector 工厂(目前仅 hashring) |
+| `storage/selector/hashring/hashring.go` | `Balancer.Select()` 算法 |
+| `storage/selector/hashring/consistent.go` | 一致性哈希环实现 |
+| `storage/diraware/storage.go` | DirAware Storage 包装器 |
+| `storage/diraware/bucket.go` | DirAware Bucket 包装器(Lookup 注入) |
+| `storage/diraware/diraware.go` | Checker 实现(PathTrie + SharedKV) |
diff --git a/pkg/e2e/e2e.go b/pkg/e2e/e2e.go
index 9268bd5..7e67337 100644
--- a/pkg/e2e/e2e.go
+++ b/pkg/e2e/e2e.go
@@ -88,14 +88,16 @@ func (e *E2E) Do(rewrite func(r *http.Request)) (*http.Response, error) {
// wait for a while to let the server ready
time.Sleep(time.Millisecond * 100)
- rewrite(e.req)
+ nr := e.req.Clone(context.Background())
- method := e.req.Method
+ rewrite(nr)
- e.req.Header.Set(protocol.InternalUpstreamAddr, e.ts.Listener.Addr().String())
+ method := nr.Method
+
+ nr.Header.Set(protocol.InternalUpstreamAddr, e.ts.Listener.Addr().String())
if dumpReq.Load() && method != "PURGE" {
- DumpReq(e.req)
+ DumpReq(nr)
}
if manual.Load() {
@@ -103,7 +105,7 @@ func (e *E2E) Do(rewrite func(r *http.Request)) (*http.Response, error) {
time.Sleep(time.Second * 20)
}
- resp, err := e.cs.Do(e.req)
+ resp, err := e.cs.Do(nr)
e.resp = resp
e.err = err
@@ -172,11 +174,15 @@ func DumpResp(resp *http.Response) {
fmt.Println(string(buf))
}
-func Purge(t *testing.T, url string) {
+func PurgeMethod(t *testing.T, url string, dir bool) {
resp, err := New(url, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
}).Do(func(r *http.Request) {
r.Method = "PURGE"
+ r.Header.Set("Purge-Type", "file,hard")
+ if dir {
+ r.Header.Set("Purge-Type", "dir,hard")
+ }
})
assert.NoError(t, err, "purge should not error")
@@ -187,3 +193,7 @@ func Purge(t *testing.T, url string) {
t.Logf("Purge %s success", url)
}
}
+
+func Purge(t *testing.T, url string) {
+ PurgeMethod(t, url, false)
+}
diff --git a/server/middleware/caching/caching.go b/server/middleware/caching/caching.go
index 36f9839..0639a51 100644
--- a/server/middleware/caching/caching.go
+++ b/server/middleware/caching/caching.go
@@ -12,6 +12,7 @@ import (
"time"
"github.com/kelindar/bitmap"
+ "github.com/prometheus/client_golang/prometheus"
"github.com/omalloc/tavern/api/defined/v1/event"
configv1 "github.com/omalloc/tavern/api/defined/v1/middleware"
@@ -25,8 +26,6 @@ import (
"github.com/omalloc/tavern/proxy"
"github.com/omalloc/tavern/server/middleware"
storagev1 "github.com/omalloc/tavern/storage"
-
- "github.com/prometheus/client_golang/prometheus"
)
const BYPASS = "BYPASS"
@@ -114,6 +113,12 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) {
proxyClient := proxy.GetProxy()
store := storagev1.Current()
+ // Flight groups for collapsed forwarding at object and chunk level.
+ // These mirror Squid's collapsed_forwarding: one origin request
+ // serves many waiting clients.
+ objectFlight := &ObjectFlightGroup{}
+ chunkFlight := &ChunkFlightGroup{}
+
return middleware.RoundTripperFunc(func(req *http.Request) (resp *http.Response, err error) {
// only cache GET/HEAD request
if req.Method != http.MethodGet && req.Method != http.MethodHead {
@@ -132,9 +137,20 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) {
// cachingPool.Put(caching)
//}()
+ // Wire up chunk-level collapsed forwarding.
+ caching.chunkFlight = chunkFlight
+
+ // Increment cacheRequestTotal on every response path (BYPASS,
+ // HIT, MISS-collapsed, MISS-direct). Placed as a defer so it
+ // fires after cacheStatus is finalized.
+ defer func() {
+ cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc()
+ }()
+
// err to BYPASS caching
if err != nil {
- caching.log.Warnf("Precache processor failed: %v, BYPASS", err)
+ caching.log.Warnf("Precache processor failed: %v BYPASS", err)
+ caching.cacheStatus = storage.BYPASS
resp, err = caching.doProxy(req, false) // do reverse proxy
if err != nil {
return nil, err
@@ -144,57 +160,70 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) {
// set cache-staus header BYPASS
resp.Header.Set(protocol.ProtocolCacheStatusKey, BYPASS)
}
- cacheRequestTotal.WithLabelValues(storage.BYPASS.String(), caching.bucket.StoreType()).Inc()
return
}
// cache HIT
if caching.hit {
- caching.cacheStatus = storage.CacheHit
-
- rng, err1 := xhttp.SingleRange(req.Header.Get("Range"), caching.md.Size)
- if err1 != nil {
- // 无效 Range 处理
- headers := make(http.Header)
- xhttp.CopyHeader(caching.md.Headers, headers)
- headers.Set("Content-Range", fmt.Sprintf("bytes */%d", caching.md.Size))
- cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc()
- return nil, xhttp.NewBizError(http.StatusRequestedRangeNotSatisfiable, headers)
- }
-
- // mark cache status with Range requests.
- caching.markCacheStatus(rng.Start, rng.End)
+ return caching.respondFromCache(req)
+ }
- // find file seek(start, end)
- resp, err = caching.lazilyRespond(req, rng.Start, rng.End)
- if err != nil {
- // fd leak
- closeBody(resp)
- cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc()
- return nil, err
+ // full MISS — use object-level collapsed forwarding so that
+ // concurrent requests for the same cache object share one
+ // origin fetch (Squid-style collapsed_forwarding).
+ if opts.CollapsedRequest {
+ flightResp, _, flightErr := objectFlight.Do(caching.id.HashStr(), opts.CollapsedRequestWaitTimeout.AsDuration(), func() (*http.Response, error) {
+ r, e := caching.doProxy(req, false)
+ if e != nil {
+ return nil, e
+ }
+ return processor.postCacheProcessor(caching, req, r)
+ })
+ if flightErr != nil {
+ return nil, flightErr
}
-
- // response now
- resp, err = caching.processor.postCacheProcessor(caching, req, resp)
- cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc()
+ resp = flightResp
return
}
- // full MISS
+ // full MISS (collapsed forwarding disabled)
resp, err = caching.doProxy(req, false)
if err != nil {
- cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc()
return nil, err
}
resp, err = processor.postCacheProcessor(caching, req, resp)
- cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc()
return
})
}, middleware.EmptyCleanup, nil
}
+// respondFromCache assembles a response from cached chunks for a cache HIT.
+// It parses the Range header, builds a multi-part reader from disk, and
+// runs post-cache processing (headers, cache status, store).
+func (c *Caching) respondFromCache(req *http.Request) (*http.Response, error) {
+ c.cacheStatus = storage.CacheHit
+
+ rng, err := xhttp.SingleRange(req.Header.Get("Range"), c.md.Size)
+ if err != nil {
+ headers := make(http.Header)
+ xhttp.CopyHeader(c.md.Headers, headers)
+ headers.Set("Content-Range", fmt.Sprintf("bytes */%d", c.md.Size))
+ return nil, xhttp.NewBizError(http.StatusRequestedRangeNotSatisfiable, headers)
+ }
+
+ c.markCacheStatus(rng.Start, rng.End)
+
+ resp, err := c.lazilyRespond(req, rng.Start, rng.End)
+ if err != nil {
+ closeBody(resp)
+ return nil, err
+ }
+
+ return c.processor.postCacheProcessor(c, req, resp)
+}
+
func (c *Caching) lazilyRespond(req *http.Request, start, end int64) (*http.Response, error) {
// 这里通过缓存的块大小来计算,而不是配置默认的 SliceSize
// 这样已缓存的对象可以使用原来的配置块大小,不受配置文件变更影响
@@ -261,27 +290,25 @@ func (c *Caching) lazilyRespond(req *http.Request, start, end int64) (*http.Resp
func (c *Caching) getUpstreamReader(fromByte, toByte uint64, async bool) (io.ReadCloser, error) {
// get from origin request header
rawRange := c.req.Header.Get("Range")
- newRange := fmt.Sprintf("bytes=%d-%d", fromByte, toByte)
- req := c.req.Clone(context.Background())
- req.Header.Set("Range", newRange)
- // add request-id [range]
- // req.Header.Set("X-Request-ID", fmt.Sprintf("%s-%d", req.Header.Get(appctx.ProtocolRequestIDKey), fromByte)) // 附加 Request-ID suffix
- // remove all internal header
- req.Header.Del(protocol.ProtocolCacheStatusKey)
+ // doSubRequest is parameterized by the union range so the flight group
+ // can expand the range to cover multiple concurrent callers.
+ doSubRequest := func(unionFrom, unionTo uint64) (*http.Response, error) {
+ newRange := fmt.Sprintf("bytes=%d-%d", unionFrom, unionTo)
+ subReq := c.req.Clone(context.Background())
+ subReq.Header.Set("Range", newRange)
+ subReq.Header.Del(protocol.ProtocolCacheStatusKey)
- doSubRequest := func() (*http.Response, error) {
now := time.Now()
c.log.Debugf("getUpstreamReader doProxy[chunk]: begin: %s, rawRange: %s, newRange: %s", now, rawRange, newRange)
- resp, err := c.doProxy(req, true)
+ resp, err := c.doProxy(subReq, true)
c.log.Infof("getUpstreamReader doProxy[chunk]: timeCost: %s, rawRange: %s, newRange: %s", time.Since(now), rawRange, newRange)
if err != nil {
closeBody(resp)
return nil, err
}
- // 部分命中
- c.cacheStatus = storage.CachePartHit
- // 发起的是 206 请求,但是返回的非 206
+ // 206 Partial Content is expected for range requests.
+ // If we get a different status code, it may indicate an issue with the upstream response.
if resp.StatusCode != http.StatusPartialContent {
c.log.Warnf("getUpstreamReader doProxy[chunk]: status code: %d, bod size: %d", resp.StatusCode, resp.ContentLength)
return resp, xhttp.NewBizError(resp.StatusCode, resp.Header)
@@ -289,11 +316,37 @@ func (c *Caching) getUpstreamReader(fromByte, toByte uint64, async bool) (io.Rea
return resp, nil
}
+ // Chunk-level collapsed forwarding: if another goroutine is already
+ // fetching (possibly a different) byte range for this object, wait
+ // and share the union response body (io.MultiWriter fan-out + RangeReader
+ // trimming). This mirrors Squid's collapsed_forwarding at the chunk
+ // level, with automatic range union.
+ if c.chunkFlight != nil && c.opt.CollapsedRequest && c.id != nil {
+ reader, _, err := c.chunkFlight.Do(c.id.HashStr(), fromByte, toByte,
+ c.opt.CollapsedRequestWaitTimeout.AsDuration(), doSubRequest)
+ // Both leader and shared callers are partial hits: at least one
+ // chunk was missing and is being fetched from origin.
+
+ // c.cacheStatus = storage.CachePartMiss
+ // if shared {
+ // c.cacheStatus = storage.CachePartHit
+ // }
+ return reader, err
+ }
+
+ // Any path that reaches getUpstreamReader is a partial hit: at least
+ // one chunk was missing from cache and is now being fetched from
+ // origin. Set the status here for the async and sync paths (the
+ // chunk-flight path above already sets it before returning).
+ // c.cacheStatus = storage.CachePartMiss
+
if async {
- return iobuf.AsyncReadCloser(doSubRequest), nil
+ return iobuf.AsyncReadCloser(func() (*http.Response, error) {
+ return doSubRequest(fromByte, toByte)
+ }), nil
}
- resp, err := doSubRequest()
+ resp, err := doSubRequest(fromByte, toByte)
if resp != nil {
return resp.Body, err
}
diff --git a/server/middleware/caching/caching_test.go b/server/middleware/caching/caching_test.go
index 25c7ec2..5712b16 100644
--- a/server/middleware/caching/caching_test.go
+++ b/server/middleware/caching/caching_test.go
@@ -1,6 +1,10 @@
package caching
-import "testing"
+import (
+ "net/http"
+ "strings"
+ "testing"
+)
func BenchmarkWithPooling(b *testing.B) {
for i := 0; i < b.N; i++ {
@@ -8,3 +12,13 @@ func BenchmarkWithPooling(b *testing.B) {
c.reset()
}
}
+
+func TestObjectFlight_PanicRecovery(t *testing.T) {
+ g := &ObjectFlightGroup{}
+ _, _, err := g.Do("key", 0, func() (*http.Response, error) {
+ panic("boom")
+ })
+ if err == nil || !strings.Contains(err.Error(), "panic") {
+ t.Fatalf("expected panic error, got %v", err)
+ }
+}
diff --git a/server/middleware/caching/chunk_flight.go b/server/middleware/caching/chunk_flight.go
new file mode 100644
index 0000000..7a50a01
--- /dev/null
+++ b/server/middleware/caching/chunk_flight.go
@@ -0,0 +1,197 @@
+package caching
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/omalloc/tavern/pkg/iobuf"
+)
+
+// chunkRange holds one caller's desired byte range within the object.
+type chunkRange struct {
+ fromByte, toByte uint64
+}
+
+// chunkCall is an in-flight chunk upstream request. Multiple callers
+// register their desired ranges; the leader computes the union and
+// fetches a single range that covers everyone. Each caller receives
+// their own sub-range trimmed via iobuf.RangeReader.
+type chunkCall struct {
+ pipes []*io.PipeWriter
+ ranges []chunkRange
+ mu sync.Mutex // protects pipes and ranges during registration
+ wg sync.WaitGroup // signals that unionFrom / unionTo are computed
+ unionFrom uint64
+ unionTo uint64
+ err error // set when fn fails
+}
+
+// ChunkFlightGroup collapses concurrent upstream requests for different
+// byte ranges of the same object into a single origin fetch covering the
+// union of all requested ranges. Response body bytes are fanned out to
+// all waiters via io.MultiWriter + io.Pipe, and each caller trims the
+// shared stream down to its own sub-range with iobuf.RangeReader.
+//
+// This mirrors Squid's collapsed_forwarding at the chunk/segment level:
+// when two goroutines request different byte ranges of the same cached
+// object, only one hits origin and the others wait, even when the
+// ranges differ.
+type ChunkFlightGroup struct {
+ mu sync.Mutex
+ m map[string]*chunkCall
+}
+
+// Do executes fn once per objectKey, passing the union of all registered
+// byte ranges to fn. All callers — including the first — receive an
+// io.PipeReader trimmed to exactly [fromByte, toByte]. The returned
+// bool reports whether this caller shared an in-flight request.
+//
+// waiter is the duration the leader goroutine pauses *before* calling fn,
+// giving late-arriving callers a window to register under the same key.
+// In production the network round-trip naturally provides this window;
+// waiter ensures correctness even when fn would otherwise complete nearly
+// instantly (e.g. in tests, or for tiny ranges on a local origin).
+//
+// Contract: fn owns resp.Body. On success ChunkFlightGroup reads and
+// closes it. On error fn must either return (nil, err) or close the body
+// before returning (resp, err).
+func (g *ChunkFlightGroup) Do(objectKey string, fromByte, toByte uint64, waiter time.Duration, fn func(unionFrom, unionTo uint64) (*http.Response, error)) (io.ReadCloser, bool, error) {
+ pr, pw := io.Pipe()
+
+ g.mu.Lock()
+ if g.m == nil {
+ g.m = make(map[string]*chunkCall)
+ }
+ if c, ok := g.m[objectKey]; ok {
+ // Waiter: register pipe writer and desired range, then wait for
+ // the leader to compute the union range.
+ c.mu.Lock()
+ c.pipes = append(c.pipes, pw)
+ c.ranges = append(c.ranges, chunkRange{fromByte: fromByte, toByte: toByte})
+ c.mu.Unlock()
+ g.mu.Unlock()
+
+ c.wg.Wait()
+ c.mu.Lock()
+ flightErr := c.err
+ c.mu.Unlock()
+ if flightErr != nil {
+ _ = pw.CloseWithError(flightErr)
+ return nil, true, flightErr
+ }
+ // Return immediately — fn is executed asynchronously by the
+ // leader goroutine, so the caller can build response headers
+ // before c.md is mutated by the upstream fetch.
+ return iobuf.RangeReader(pr, int(c.unionFrom), int(c.unionTo), int(fromByte), int(toByte)), true, nil
+ }
+
+ // Leader: create the flight and register own range.
+ c := &chunkCall{
+ pipes: []*io.PipeWriter{pw},
+ ranges: []chunkRange{{fromByte: fromByte, toByte: toByte}},
+ }
+ c.wg.Add(1)
+ g.m[objectKey] = c
+ g.mu.Unlock()
+
+ // Pause before hitting origin so concurrent callers have time
+ // to register under this key. Without this window an instant
+ // fn would compute the union and delete the map entry before
+ // anyone else could join.
+ if waiter > 0 {
+ time.Sleep(waiter)
+ }
+
+ // Compute the union range across all registered callers.
+ c.mu.Lock()
+ unionFrom := c.ranges[0].fromByte
+ unionTo := c.ranges[0].toByte
+ for _, r := range c.ranges[1:] {
+ if r.fromByte < unionFrom {
+ unionFrom = r.fromByte
+ }
+ if r.toByte > unionTo {
+ unionTo = r.toByte
+ }
+ }
+ c.unionFrom = unionFrom
+ c.unionTo = unionTo
+ c.mu.Unlock()
+
+ // Release waiters — they now know unionFrom/unionTo and can build
+ // their RangeReader wrappers around the shared pipe. fn has not
+ // been called yet, so callers can safely read c.md headers before
+ // the upstream fetch mutates them.
+ c.wg.Done()
+
+ // Remove the key from the map so that no late-arriving callers can
+ // join this flight with a stale union range. The waiter window
+ // (time.Sleep above) is the intentional batching period — callers
+ // arriving after it will start a fresh flight. This trades a
+ // possible duplicate origin request for guaranteed range correctness.
+ g.mu.Lock()
+ delete(g.m, objectKey)
+ g.mu.Unlock()
+
+ // The leader returns immediately with a pipe reader. The upstream
+ // fetch (fn) and body fan-out run in a background goroutine so that
+ // response headers are built before c.md is touched.
+ go func() {
+
+ // check for panic to avoid leaving waiters hanging indefinitely
+ resp, err := func() (r *http.Response, e error) {
+ defer func() {
+ if rec := recover(); rec != nil {
+ e = fmt.Errorf("chunk flight panic: %v", rec)
+ }
+ }()
+ return fn(unionFrom, unionTo)
+ }()
+
+ // Snapshot pipes under c.mu. The map entry is already deleted
+ // (above), so no further callers can register against this
+ // flight — c.mu only protects against the window where waiters
+ // that registered before the deletion are still being appended.
+ c.mu.Lock()
+ pipes := make([]*io.PipeWriter, len(c.pipes))
+ copy(pipes, c.pipes)
+ if err != nil {
+ c.err = err
+ }
+ c.mu.Unlock()
+
+ if err != nil {
+ for _, p := range pipes {
+ _ = p.CloseWithError(err)
+ }
+ // fn owns resp.Body on error — it must close it before
+ // returning. We only guard against a nil body here.
+ return
+ }
+
+ // Build MultiWriter from all registered pipe writers.
+ writers := make([]io.Writer, len(pipes))
+ for i, p := range pipes {
+ writers[i] = p
+ }
+ mw := io.MultiWriter(writers...)
+
+ _, copyErr := io.Copy(mw, resp.Body)
+ _ = resp.Body.Close()
+
+ for _, p := range pipes {
+ if copyErr != nil && copyErr != io.EOF {
+ _ = p.CloseWithError(copyErr)
+ } else {
+ _ = p.Close()
+ }
+ }
+ }()
+
+ // Leader wraps its reader with RangeReader so that it sees exactly
+ // [fromByte, toByte] trimmed from the union response.
+ return iobuf.RangeReader(pr, int(unionFrom), int(unionTo), int(fromByte), int(toByte)), false, nil
+}
diff --git a/server/middleware/caching/collapsed_forwarding_test.go b/server/middleware/caching/collapsed_forwarding_test.go
new file mode 100644
index 0000000..c8f8768
--- /dev/null
+++ b/server/middleware/caching/collapsed_forwarding_test.go
@@ -0,0 +1,519 @@
+package caching
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "net/http"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+// ---------------------------------------------------------------------------
+// ChunkFlightGroup tests
+// ---------------------------------------------------------------------------
+
+func TestChunkFlight_BasicCollapse(t *testing.T) {
+ g := &ChunkFlightGroup{}
+ var callCount atomic.Int32
+
+ // fn returns a body equal in length to the requested range so callers
+ // can verify their sub-range trimming works.
+ fn := func(unionFrom, unionTo uint64) (*http.Response, error) {
+ callCount.Add(1)
+ size := int(unionTo - unionFrom + 1)
+ return &http.Response{
+ StatusCode: http.StatusPartialContent,
+ Body: io.NopCloser(bytes.NewReader(makebuf(size))),
+ }, nil
+ }
+
+ type result struct {
+ length int
+ shared bool
+ }
+
+ results := make([]result, 3)
+ var wg sync.WaitGroup
+ start := make(chan struct{})
+
+ // All three callers request the same range — classic collapse.
+ for i := 0; i < 3; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ <-start
+ r, shared, err := g.Do("obj1", 0, 1023, 50*time.Millisecond, fn)
+ if err != nil {
+ t.Errorf("caller %d: unexpected error: %v", idx, err)
+ return
+ }
+ data, readErr := io.ReadAll(r)
+ _ = r.Close()
+ if readErr != nil {
+ t.Errorf("caller %d: read error: %v", idx, readErr)
+ return
+ }
+ results[idx] = result{len(data), shared}
+ }(i)
+ }
+
+ // Release all callers simultaneously so they race on the map entry.
+ time.Sleep(10 * time.Millisecond)
+ close(start)
+ wg.Wait()
+
+ if callCount.Load() != 1 {
+ t.Fatalf("expected 1 call, got %d", callCount.Load())
+ }
+
+ sharedCount := 0
+ for _, r := range results {
+ if r.shared {
+ sharedCount++
+ }
+ if r.length != 1024 {
+ t.Errorf("got %d bytes, want 1024", r.length)
+ }
+ }
+ if sharedCount != 2 {
+ t.Errorf("expected 2 shared callers, got %d", sharedCount)
+ }
+}
+
+func TestChunkFlight_RangeUnion(t *testing.T) {
+ g := &ChunkFlightGroup{}
+ var callCount atomic.Int32
+
+ fn := func(unionFrom, unionTo uint64) (*http.Response, error) {
+ callCount.Add(1)
+ size := int(unionTo - unionFrom + 1)
+ return &http.Response{
+ StatusCode: http.StatusPartialContent,
+ Body: io.NopCloser(bytes.NewReader(makebuf(size))),
+ }, nil
+ }
+
+ type result struct {
+ length int
+ shared bool
+ }
+
+ type caller struct {
+ from, to uint64
+ wantLen int
+ }
+
+ callers := []caller{
+ {0, 999, 1000}, // bytes 0-999
+ {500, 1999, 1500}, // bytes 500-1999, overlaps first
+ {1500, 2999, 1500}, // bytes 1500-2999, overlaps second
+ }
+ results := make([]result, len(callers))
+ var wg sync.WaitGroup
+ start := make(chan struct{})
+
+ for i, c := range callers {
+ wg.Add(1)
+ go func(idx int, from, to uint64) {
+ defer wg.Done()
+ <-start
+ r, shared, err := g.Do("union-obj", from, to, 50*time.Millisecond, fn)
+ if err != nil {
+ t.Errorf("caller %d: unexpected error: %v", idx, err)
+ return
+ }
+ data, readErr := io.ReadAll(r)
+ _ = r.Close()
+ if readErr != nil {
+ t.Errorf("caller %d: read error: %v", idx, readErr)
+ return
+ }
+ results[idx] = result{len(data), shared}
+ }(i, c.from, c.to)
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ close(start)
+ wg.Wait()
+
+ // With range union, all three callers share one origin fetch covering
+ // the union range 0-2999.
+ if callCount.Load() != 1 {
+ t.Fatalf("expected 1 call (union), got %d", callCount.Load())
+ }
+
+ for i, r := range results {
+ if r.length != callers[i].wantLen {
+ t.Errorf("caller %d: got %d bytes, want %d", i, r.length, callers[i].wantLen)
+ }
+ }
+
+ sharedCount := 0
+ for _, r := range results {
+ if r.shared {
+ sharedCount++
+ }
+ }
+ if sharedCount != len(callers)-1 {
+ t.Errorf("expected %d shared callers, got %d", len(callers)-1, sharedCount)
+ }
+}
+
+func TestChunkFlight_ErrorPropagation(t *testing.T) {
+ g := &ChunkFlightGroup{}
+
+ fn := func(_, _ uint64) (*http.Response, error) {
+ return nil, errors.New("upstream timeout")
+ }
+
+ var wg sync.WaitGroup
+ start := make(chan struct{})
+ errs := make([]error, 3)
+
+ for i := 0; i < 3; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ <-start
+ r, _, err := g.Do("obj1", 0, 1023, 50*time.Millisecond, fn)
+ if err != nil {
+ errs[idx] = err
+ return
+ }
+ _, readErr := io.ReadAll(r)
+ _ = r.Close()
+ if readErr != nil {
+ errs[idx] = readErr
+ }
+ }(i)
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ close(start)
+ wg.Wait()
+
+ for i, err := range errs {
+ if err == nil {
+ t.Errorf("caller %d: expected error, got nil", i)
+ }
+ }
+}
+
+func TestChunkFlight_KeyIsolation(t *testing.T) {
+ g := &ChunkFlightGroup{}
+ var callCount atomic.Int32
+
+ makeFn := func(data string) func(uint64, uint64) (*http.Response, error) {
+ return func(unionFrom, unionTo uint64) (*http.Response, error) {
+ callCount.Add(1)
+ size := int(unionTo - unionFrom + 1)
+ return &http.Response{
+ StatusCode: http.StatusPartialContent,
+ Body: io.NopCloser(bytes.NewReader(makebuf(size))),
+ }, nil
+ }
+ }
+
+ var wg sync.WaitGroup
+ results := make(map[string]int, 4)
+ var mu sync.Mutex
+
+ // Two objects (obj1, obj2), each with two concurrent callers requesting
+ // different ranges. Within each object the ranges are unioned, but
+ // different objects are isolated.
+ type job struct {
+ key string
+ from, to uint64
+ wantLen int
+ }
+ jobs := []job{
+ {"obj1", 0, 1048575, 1048576},
+ {"obj1", 1048576, 2097151, 1048576},
+ {"obj2", 0, 1048575, 1048576},
+ {"obj2", 1048576, 2097151, 1048576},
+ }
+ for _, j := range jobs {
+ wg.Add(1)
+ go func(k string, from, to uint64) {
+ defer wg.Done()
+ r, _, err := g.Do(k, from, to, 50*time.Millisecond, makeFn(k))
+ if err != nil {
+ t.Errorf("key %s: unexpected error: %v", k, err)
+ return
+ }
+ data, _ := io.ReadAll(r)
+ _ = r.Close()
+ mu.Lock()
+ results[k] = len(data)
+ mu.Unlock()
+ }(j.key, j.from, j.to)
+ }
+ wg.Wait()
+
+ // With object-level keys, obj1's two callers collapse into one
+ // (union: 0-2097151), obj2's two callers collapse into another.
+ if callCount.Load() != 2 {
+ t.Fatalf("expected 2 calls (one per object), got %d", callCount.Load())
+ }
+
+ // Each caller must receive exactly their requested byte count.
+ for _, j := range jobs {
+ if results[j.key] != j.wantLen {
+ t.Errorf("key %s: got %d bytes, want %d", j.key, results[j.key], j.wantLen)
+ }
+ }
+}
+
+func TestChunkFlight_ConcurrentSameKey(t *testing.T) {
+ g := &ChunkFlightGroup{}
+ var callCount atomic.Int32
+
+ fn := func(_, _ uint64) (*http.Response, error) {
+ callCount.Add(1)
+ return &http.Response{
+ StatusCode: http.StatusPartialContent,
+ Body: io.NopCloser(bytes.NewReader(makebuf(1 << 18))),
+ }, nil
+ }
+
+ var wg sync.WaitGroup
+ start := make(chan struct{})
+ const numCallers = 10
+ sharedCount := atomic.Int32{}
+
+ for i := 0; i < numCallers; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ <-start
+ // All callers request the same range 0-262143.
+ r, shared, err := g.Do("same-key", 0, 262143, 100*time.Millisecond, fn)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+ _, _ = io.ReadAll(r)
+ _ = r.Close()
+ if shared {
+ sharedCount.Add(1)
+ }
+ }()
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ close(start)
+ wg.Wait()
+
+ if callCount.Load() != 1 {
+ t.Fatalf("expected exactly 1 origin call, got %d", callCount.Load())
+ }
+ if sharedCount.Load() != numCallers-1 {
+ t.Fatalf("expected %d shared callers, got %d", numCallers-1, sharedCount.Load())
+ }
+}
+
+func TestChunkFlight_PanicRecovery(t *testing.T) {
+ g := &ChunkFlightGroup{}
+
+ pr, shared, err := g.Do("panic-key", 0, 1023, 0, func(_, _ uint64) (*http.Response, error) {
+ panic("boom")
+ })
+ if shared {
+ t.Fatal("expected leader, not shared")
+ }
+
+ // fn is now called asynchronously — the leader gets the error through
+ // the pipe, not from the Do return value.
+ if err != nil {
+ t.Fatalf("unexpected error from Do: %v", err)
+ }
+
+ _, readErr := io.ReadAll(pr)
+ _ = pr.Close()
+ if readErr == nil || !strings.Contains(readErr.Error(), "panic") {
+ t.Fatalf("expected panic error from pipe, got %v", readErr)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// ObjectFlightGroup tests
+// ---------------------------------------------------------------------------
+
+func TestObjectFlight_BasicCollapse(t *testing.T) {
+ g := &ObjectFlightGroup{}
+ var callCount atomic.Int32
+
+ fn := func() (*http.Response, error) {
+ callCount.Add(1)
+ time.Sleep(30 * time.Millisecond) // simulate origin latency
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader("response-body")),
+ }, nil
+ }
+
+ var wg sync.WaitGroup
+ start := make(chan struct{})
+ bodies := make([]string, 5)
+ shareds := make([]bool, 5)
+
+ for i := 0; i < 5; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ <-start
+ resp, shared, err := g.Do("cache-key-1", 50*time.Millisecond, fn)
+ if err != nil {
+ t.Errorf("caller %d: unexpected error: %v", idx, err)
+ return
+ }
+ shareds[idx] = shared
+ body, readErr := io.ReadAll(resp.Body)
+ _ = resp.Body.Close()
+ if readErr != nil {
+ t.Errorf("caller %d: read error: %v", idx, readErr)
+ return
+ }
+ bodies[idx] = string(body)
+ }(i)
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ close(start)
+ wg.Wait()
+
+ if callCount.Load() != 1 {
+ t.Fatalf("expected 1 call, got %d", callCount.Load())
+ }
+ nonShared := 0
+ shared := 0
+ for _, s := range shareds {
+ if s {
+ shared++
+ } else {
+ nonShared++
+ }
+ }
+ if nonShared != 1 {
+ t.Errorf("expected 1 non-shared caller, got %d", nonShared)
+ }
+ if shared != 4 {
+ t.Errorf("expected 4 shared callers, got %d", shared)
+ }
+ for i, b := range bodies {
+ if b != "response-body" {
+ t.Errorf("caller %d: body = %q, want %q", i, b, "response-body")
+ }
+ }
+}
+
+func TestObjectFlight_ErrorPropagation(t *testing.T) {
+ g := &ObjectFlightGroup{}
+ var callCount atomic.Int32
+
+ testErr := errors.New("origin connection refused")
+ fn := func() (*http.Response, error) {
+ callCount.Add(1)
+ time.Sleep(30 * time.Millisecond) // window for dup callers to register
+ return nil, testErr
+ }
+
+ var wg sync.WaitGroup
+ start := make(chan struct{})
+ errs := make([]error, 3)
+
+ for i := 0; i < 3; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ <-start
+ _, _, err := g.Do("cache-key-err", 50*time.Millisecond, fn)
+ errs[idx] = err
+ }(i)
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ close(start)
+ wg.Wait()
+
+ if callCount.Load() != 1 {
+ t.Fatalf("expected 1 call, got %d", callCount.Load())
+ }
+ for i, err := range errs {
+ if !errors.Is(err, testErr) {
+ t.Errorf("caller %d: got %v, want %v", i, err, testErr)
+ }
+ }
+}
+
+func TestObjectFlight_KeyIsolation(t *testing.T) {
+ g := &ObjectFlightGroup{}
+ var callCount atomic.Int32
+
+ fn := func() (*http.Response, error) {
+ callCount.Add(1)
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader("body")),
+ }, nil
+ }
+
+ var wg sync.WaitGroup
+ for _, key := range []string{"key-a", "key-b", "key-c"} {
+ wg.Add(1)
+ go func(k string) {
+ defer wg.Done()
+ resp, _, err := g.Do(k, 0, fn)
+ if err != nil {
+ t.Errorf("key %s: unexpected error: %v", k, err)
+ return
+ }
+ io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+ }(key)
+ }
+ wg.Wait()
+
+ if callCount.Load() != 3 {
+ t.Fatalf("expected 3 distinct calls, got %d", callCount.Load())
+ }
+}
+
+func TestObjectFlight_SequentialReuse(t *testing.T) {
+ g := &ObjectFlightGroup{}
+ var callCount atomic.Int32
+
+ fn := func() (*http.Response, error) {
+ callCount.Add(1)
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader("body")),
+ }, nil
+ }
+
+ resp, _, err := g.Do("seq-key", 0, fn)
+ if err != nil {
+ t.Fatalf("first call: unexpected error: %v", err)
+ }
+ io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+ if callCount.Load() != 1 {
+ t.Fatalf("first call: expected 1, got %d", callCount.Load())
+ }
+
+ time.Sleep(10 * time.Millisecond)
+
+ resp, _, err = g.Do("seq-key", 0, fn)
+ if err != nil {
+ t.Fatalf("second call: unexpected error: %v", err)
+ }
+ io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+ if callCount.Load() != 2 {
+ t.Fatalf("sequential call: expected 2, got %d", callCount.Load())
+ }
+}
diff --git a/server/middleware/caching/internal.go b/server/middleware/caching/internal.go
index c45c1a9..55aff5a 100644
--- a/server/middleware/caching/internal.go
+++ b/server/middleware/caching/internal.go
@@ -45,6 +45,7 @@ type Caching struct {
rootmd *object.Metadata
bucket storage.Bucket
proxyClient proxy.Proxy
+ chunkFlight *ChunkFlightGroup
cacheStatus storage.CacheStatus
cacheable bool
hit bool
diff --git a/server/middleware/caching/object_flight.go b/server/middleware/caching/object_flight.go
new file mode 100644
index 0000000..c6a0e77
--- /dev/null
+++ b/server/middleware/caching/object_flight.go
@@ -0,0 +1,172 @@
+package caching
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "sync"
+ "time"
+)
+
+// objectFlightCall represents an in-flight full-object origin fetch.
+//
+// Unlike the previous WaitGroup-only approach, this uses io.Pipe +
+// io.MultiWriter to fan out the response body to all concurrent callers.
+// This ensures the leader's response body is consumed (which drives the
+// SavepartAsyncReader → disk writes) while simultaneously providing data
+// to all waiting callers — no cache re-lookup is needed.
+type objectFlightCall struct {
+ resp *http.Response
+ pipes []*io.PipeWriter
+ mu sync.Mutex // protects pipes during registration and snapshot
+ wg sync.WaitGroup // signals that resp headers / err are ready
+ err error
+}
+
+// ObjectFlightGroup collapses concurrent full-MISS requests for the same
+// cache object. Unlike ChunkFlightGroup (which works at the chunk/segment
+// level), this operates at the whole-object level — it ensures only one
+// goroutine hits origin for a given cache key.
+//
+// The returned response carries the headers from the leader's fn and a
+// body that fans out to all concurrent callers. Callers must close the
+// body.
+type ObjectFlightGroup struct {
+ mu sync.Mutex
+ m map[string]*objectFlightCall
+}
+
+// Do executes fn once per key and fans out the response body to all
+// concurrent callers. All callers receive the same response headers
+// (cloned) and a shared body stream.
+//
+// waiter is the duration the leader pauses before calling fn, giving
+// late-arriving callers a window to register under the same key.
+//
+// Returns:
+//
+// resp — a response carrying the leader's headers and a shared body
+// shared — true if this caller joined an existing flight
+// err — error from fn or from body copy
+func (g *ObjectFlightGroup) Do(key string, waiter time.Duration, fn func() (*http.Response, error)) (*http.Response, bool, error) {
+ pr, pw := io.Pipe()
+
+ g.mu.Lock()
+ if g.m == nil {
+ g.m = make(map[string]*objectFlightCall)
+ }
+ if c, ok := g.m[key]; ok {
+ // Waiter: register a pipe writer and wait for headers.
+ c.mu.Lock()
+ c.pipes = append(c.pipes, pw)
+ c.mu.Unlock()
+ g.mu.Unlock()
+
+ c.wg.Wait()
+ if c.err != nil {
+ _ = pw.CloseWithError(c.err)
+ return nil, true, c.err
+ }
+
+ resp := cloneResponse(c.resp)
+ resp.Body = pr
+ return resp, true, nil
+ }
+
+ // Leader: create the flight and execute fn.
+ c := &objectFlightCall{pipes: []*io.PipeWriter{pw}}
+ c.wg.Add(1)
+ g.m[key] = c
+ g.mu.Unlock()
+
+ if waiter > 0 {
+ time.Sleep(waiter)
+ }
+
+ // check for panic to avoid leaving waiters hanging indefinitely
+ resp, err := func() (r *http.Response, e error) {
+ defer func() {
+ if rec := recover(); rec != nil {
+ e = fmt.Errorf("object flight panic: %v", rec)
+ }
+ }()
+ return fn()
+ }()
+
+ g.mu.Lock()
+ delete(g.m, key)
+
+ if err != nil {
+ c.err = err
+ g.mu.Unlock()
+ c.wg.Done()
+
+ // Snapshot pipes under c.mu to avoid racing with waiter registrations.
+ c.mu.Lock()
+ for _, p := range c.pipes {
+ _ = p.CloseWithError(err)
+ }
+ c.mu.Unlock()
+ return nil, false, err
+ }
+
+ c.resp = resp
+ c.wg.Done() // release waiters — headers are now available
+
+ // Snapshot pipes under c.mu to avoid racing with waiter registrations.
+ c.mu.Lock()
+ pipes := make([]*io.PipeWriter, len(c.pipes))
+ copy(pipes, c.pipes)
+ c.mu.Unlock()
+ g.mu.Unlock()
+
+ // Fan out the response body to all pipes (including the leader's).
+ // This also drives the SavepartAsyncReader → disk write chain.
+ go func() {
+ writers := make([]io.Writer, len(pipes))
+ for i, p := range pipes {
+ writers[i] = p
+ }
+ mw := io.MultiWriter(writers...)
+
+ var copyErr error
+ if resp.Body != nil {
+ _, copyErr = io.Copy(mw, resp.Body)
+ _ = resp.Body.Close()
+ }
+
+ for _, p := range pipes {
+ if copyErr != nil && copyErr != io.EOF {
+ _ = p.CloseWithError(copyErr)
+ } else {
+ _ = p.Close()
+ }
+ }
+ }()
+
+ leaderResp := cloneResponse(resp)
+ leaderResp.Body = pr
+ return leaderResp, false, nil
+}
+
+// cloneResponse returns a shallow copy of resp with a cloned Header map.
+// Body is left nil — the caller sets it to a pipe reader.
+func cloneResponse(resp *http.Response) *http.Response {
+ if resp == nil {
+ return nil
+ }
+ return &http.Response{
+ Status: resp.Status,
+ StatusCode: resp.StatusCode,
+ Proto: resp.Proto,
+ ProtoMajor: resp.ProtoMajor,
+ ProtoMinor: resp.ProtoMinor,
+ Header: resp.Header.Clone(),
+ ContentLength: resp.ContentLength,
+ TransferEncoding: resp.TransferEncoding,
+ Close: resp.Close,
+ Uncompressed: resp.Uncompressed,
+ Request: resp.Request,
+ TLS: resp.TLS,
+ }
+}
diff --git a/server/middleware/caching/processor.go b/server/middleware/caching/processor.go
index cc4c79a..6ddd94e 100644
--- a/server/middleware/caching/processor.go
+++ b/server/middleware/caching/processor.go
@@ -102,6 +102,7 @@ func (pc *ProcessorChain) preCacheProcessor(proxyClient proxy.Proxy, store stora
// hashring or diskhash
bucket := store.Select(req.Context(), objectID)
if bucket == nil {
+ // fallback EMPTY storage
return caching, fmt.Errorf("failed select bucket for objectID: %s", objectID)
}
caching.bucket = bucket
diff --git a/storage/bucket/memory/memory.go b/storage/bucket/memory/memory.go
index 0a7f05a..2a42d2a 100644
--- a/storage/bucket/memory/memory.go
+++ b/storage/bucket/memory/memory.go
@@ -283,7 +283,9 @@ func (m *memoryBucket) WriteChunkFile(ctx context.Context, id *object.ID, index
_ = m.fs.MkdirAll(filepath.Dir(wpath), m.fileMode)
if log.Enabled(log.LevelDebug) {
- log.Context(ctx).Infof("write inmemory chunk file %s", wpath)
+ defer func() {
+ log.Context(ctx).Infof("write inmemory chunk file %s", wpath)
+ }()
}
f, err := m.fs.OpenReadWrite(wpath, vfs.WriteCategoryUnspecified)
diff --git a/tests/all-features/caching/collapsed_forwarding_test.go b/tests/all-features/caching/collapsed_forwarding_test.go
new file mode 100644
index 0000000..c2baccf
--- /dev/null
+++ b/tests/all-features/caching/collapsed_forwarding_test.go
@@ -0,0 +1,462 @@
+package caching
+
+import (
+ "io"
+ "net/http"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/omalloc/tavern/pkg/e2e"
+)
+
+func TestCollapsedForwardingObjectFlight(t *testing.T) {
+ f := e2e.GenFile(t, 2<<20)
+
+ t.Run("test Collapsed Forwarding ObjectFlight Collapse", func(t *testing.T) {
+
+ var originCallCount atomic.Int32
+
+ case1 := e2e.New("http://objflight.example.com/of/object/collapse.bin", e2e.RespCallbackFile(f, func(w http.ResponseWriter, r *http.Request) {
+ originCallCount.Add(1)
+ time.Sleep(80 * time.Millisecond) // window for concurrent registrations
+
+ t.Logf("X-Request-Idx: %s", r.Header.Get("X-Request-Idx"))
+
+ w.Header().Set("Cache-Control", "max-age=10")
+ w.Header().Set("ETag", "obj-flight-etag")
+ }))
+ defer case1.Close()
+
+ const N = 5
+ var wg sync.WaitGroup
+ start := make(chan struct{}, N)
+ bodies := make([]string, N)
+ codes := make([]int, N)
+ xCaches := make([]string, N)
+
+ for i := 0; i < N; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ <-start
+
+ t.Logf("started for caller %d", idx)
+
+ resp, err := case1.Do(func(r *http.Request) {
+ r.Header.Set("X-Request-Idx", strconv.Itoa(idx))
+ })
+
+ require.NoError(t, err, "caller %d: request should not error", idx)
+ defer resp.Body.Close()
+
+ hash := e2e.HashBody(resp)
+
+ bodies[idx] = hash
+ codes[idx] = resp.StatusCode
+ xCaches[idx] = resp.Header.Get("X-Cache")
+ }(i)
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ close(start)
+ wg.Wait()
+
+ // Verify only one origin call — ObjectFlightGroup collapsed all 5.
+ assert.Equal(t, int32(1), originCallCount.Load(),
+ "object flight should collapse concurrent full-MISS requests")
+
+ // All callers must receive identical response bodies.
+ for i := 0; i < N; i++ {
+ assert.Equal(t, http.StatusOK, codes[i], "caller %d: status mismatch", i)
+ assert.Equal(t, f.MD5, bodies[i], "caller %d: body mismatch", i)
+ }
+
+ // At least one should be MISS (the first), the rest may be HIT
+ // depending on whether they re-looked up metadata in time.
+ hasMiss := false
+ for _, c := range xCaches {
+ if c != "" {
+ hasMiss = hasMiss || strings.Contains(c, "MISS")
+ }
+ }
+ assert.True(t, hasMiss, "at least one response should report MISS")
+ })
+
+ t.Run("PURGE", func(t *testing.T) {
+ e2e.Purge(t, "http://objflight.example.com/of/object/collapse.bin")
+ })
+
+ t.Run("test Collapsed Forwarding ObjectFlight Sequential", func(t *testing.T) {
+ var originCallCount atomic.Int32
+
+ originCallCount.Store(0)
+
+ case1 := e2e.New("http://objflight.example.com/of/object/sequential.bin", e2e.RespCallbackFile(f, func(w http.ResponseWriter, r *http.Request) {
+ originCallCount.Add(1)
+
+ w.Header().Set("Cache-Control", "max-age=10")
+ w.Header().Set("ETag", "obj-flight-etag")
+ }))
+ defer case1.Close()
+
+ const N = 3
+
+ bodies := make([]string, N)
+
+ // Sequential requests should not be collapsed.
+ for i := 0; i < N; i++ {
+ t.Logf("starting request %d", i)
+
+ resp, err := case1.Do(func(r *http.Request) {
+ r.Header.Set("X-Request-Idx", strconv.Itoa(i))
+ })
+
+ require.NoError(t, err, "request %d should not error", i)
+ bodies[i] = e2e.HashBody(resp)
+
+ resp.Body.Close()
+ }
+
+ assert.Equal(t, int32(1), originCallCount.Load(),
+ "object flight should not collapse sequential requests")
+
+ for i := 0; i < N; i++ {
+ assert.Equal(t, f.MD5, bodies[i], "request %d body-hash mismatch", i)
+ }
+
+ })
+
+ t.Run("PURGE", func(t *testing.T) {
+ e2e.Purge(t, "http://objflight.example.com/of/object/sequential.bin")
+ })
+
+ t.Run("test Collapsed Forwarding ObjectFlight KeyIsolation", func(t *testing.T) {
+ var originCallCount atomic.Int32
+
+ case1 := e2e.New("http://keys.example.com/of/object/", func(w http.ResponseWriter, r *http.Request) {
+ originCallCount.Add(1)
+ time.Sleep(80 * time.Millisecond)
+
+ w.Header().Set("Cache-Control", "max-age=10")
+ w.WriteHeader(http.StatusOK)
+
+ _, _ = w.Write([]byte(r.URL.Path))
+ })
+ defer case1.Close()
+
+ keys := []string{"key-a", "key-b", "key-c"}
+
+ var wg sync.WaitGroup
+ start := make(chan struct{}, len(keys))
+
+ for _, key := range keys {
+ wg.Add(1)
+ go func(k string) {
+ defer wg.Done()
+ <-start
+
+ resp, err := case1.Do(func(r *http.Request) {
+ r.URL.Path += k
+ t.Logf("Requesting key: %s", k)
+ })
+
+ require.NoError(t, err)
+ buf, _ := io.ReadAll(resp.Body)
+ resp.Body.Close()
+
+ assert.Equal(t, "/of/object/"+k, string(buf), "response body should match requested key")
+ }(key)
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ close(start)
+ wg.Wait()
+
+ // Three different keys → three independent origin calls.
+ assert.Equal(t, int32(3), originCallCount.Load(),
+ "different URLs should have independent object flights")
+
+ })
+
+ t.Run("PURGE", func(t *testing.T) {
+ e2e.PurgeMethod(t, "http://keys.example.com/of/object/", true)
+ })
+
+}
+
+func TestCollapsedForwardingChunkFlight(t *testing.T) {
+ file := e2e.GenFile(t, 3<<20) // 3MB → 6 chunks at 512KB
+
+ t.Run("test Collapsed Forwarding ChunkFlight", func(t *testing.T) {
+ var originCallCount atomic.Int32
+
+ case1 := e2e.New("http://chunkflight.example.com/cf/chunk/collapse.bin", e2e.RespCallbackFile(file, func(w http.ResponseWriter, r *http.Request) {
+ originCallCount.Add(1)
+
+ w.Header().Set("Cache-Control", "max-age=30")
+ w.Header().Set("ETag", file.MD5)
+ }))
+ defer case1.Close()
+
+ resp, err := case1.Do(func(r *http.Request) {
+ r.Header.Set("Range", "bytes=0-524287")
+ })
+
+ require.NoError(t, err)
+ io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+ require.Equal(t, http.StatusPartialContent, resp.StatusCode)
+
+ // Give storage time to finish writing indexdb metadata.
+ time.Sleep(300 * time.Millisecond)
+
+ // Phase 2 — concurrent requests for a range that needs missing chunks.
+ originCallCount.Store(0)
+
+ const N = 3
+ var wg sync.WaitGroup
+ start := make(chan struct{}, N)
+ bodies := make([]string, N)
+ codes := make([]int, N)
+ xCaches := make([]string, N)
+ ranges := make([]string, N)
+ cls := make([]string, N)
+
+ for i := 0; i < N; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ <-start
+
+ t.Logf("started for caller %d", idx)
+
+ resp1, err1 := case1.Do(func(r *http.Request) {
+ r.Header.Set("Range", "bytes=524288-2097151")
+ })
+
+ require.NoError(t, err1, "caller %d: request should not error", idx)
+ defer resp1.Body.Close()
+
+ hashStr := e2e.HashBody(resp1)
+
+ bodies[idx] = string(hashStr)
+ codes[idx] = resp1.StatusCode
+ xCaches[idx] = resp1.Header.Get("X-Cache")
+ ranges[idx] = resp1.Header.Get("Content-Range")
+ cls[idx] = strconv.Itoa(int(resp1.ContentLength))
+ }(i)
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ close(start)
+ wg.Wait()
+
+ t.Logf("origin call count for concurrent phase: %d", originCallCount.Load())
+
+ // All callers must receive correct 206 responses.
+ for i := 0; i < N; i++ {
+ assert.Equal(t, http.StatusPartialContent, codes[i], "caller %d: status mismatch", i)
+ assert.NotEmpty(t, bodies[i], "caller %d: body should not be empty", i)
+
+ t.Logf("caller %d: hash: %s range: %s X-Cache: %s Content-Length: %s", i, bodies[i], ranges[i], xCaches[i], cls[i])
+ }
+
+ // Verify body correctness: compare against the source file.
+ expected := e2e.HashFile(file.Path, 524288, 2097151-524288+1)
+ for i := 0; i < N; i++ {
+ actual := bodies[i]
+ assert.Equal(t, expected, actual, "caller %d: body hash mismatch", i)
+ }
+
+ // The concurrent chunk fetch for the missing range must be collapsed.
+ assert.Equal(t, int32(1), originCallCount.Load(),
+ "chunk flight should collapse concurrent chunk fetches to 1 origin call")
+ })
+
+ t.Run("test Collapsed Forwarding ChunkFlight KeyIsolation", func(t *testing.T) {
+ var originCallCount atomic.Int32
+
+ case1 := e2e.New("http://chunkflight.example.com/cf/chunk/keys.bin", e2e.RespCallbackFile(file, func(w http.ResponseWriter, r *http.Request) {
+ t.Logf("process req %s, range %s", r.Header.Get("X-Request-Id"), r.Header.Get("Range"))
+ originCallCount.Add(1)
+
+ w.Header().Set("Cache-Control", "max-age=30")
+ w.Header().Set("ETag", file.MD5)
+ }))
+ defer case1.Close()
+
+ // Phase 1 — cache only the middle chunk (chunk 1, bytes 524288-1048575).
+ resp, err := case1.Do(func(r *http.Request) {
+ r.Header.Set("X-Request-Id", "0")
+ r.Header.Set("Range", "bytes=524288-1048575")
+ })
+
+ require.NoError(t, err)
+ io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+ require.Equal(t, http.StatusPartialContent, resp.StatusCode)
+
+ originCallCount.Store(0)
+
+ time.Sleep(300 * time.Millisecond)
+
+ // Phase 2 — request two different missing ranges concurrently.
+ // Range A: bytes=0-524287 (needs chunk 0, not cached)
+ // Range B: bytes=1048576-2097151 (needs chunk 2+, not cached)
+ // With range union, these two ranges collapse into one origin
+ // fetch for the union range bytes=0-2097151.
+ var wg sync.WaitGroup
+
+ ranges := []string{"bytes=0-524287", "bytes=1048576-2097151"}
+ start := make(chan struct{}, len(ranges))
+ errs := make([]error, len(ranges))
+ xCaches := make([]string, len(ranges))
+
+ for i, rng := range ranges {
+ wg.Add(1)
+ go func(idx int, rng string) {
+ defer wg.Done()
+ <-start
+
+ resp2, e := case1.Do(func(r *http.Request) {
+ t.Logf("started for caller %d with range %s", idx+1, rng)
+
+ r.Header.Set("X-Request-Id", strconv.Itoa(idx+1))
+ r.Header.Set("Range", rng)
+ })
+ if e != nil {
+ t.Logf("caller %d: request error: %v", idx+1, e)
+ errs[idx] = e
+ return
+ }
+
+ io.Copy(io.Discard, resp2.Body)
+ resp2.Body.Close()
+
+ xCaches[idx] = resp2.Header.Get("X-Cache")
+ }(i, rng)
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ close(start)
+ wg.Wait()
+
+ for i, e := range errs {
+ assert.NoError(t, e, "request %d should not error", i)
+ t.Logf("caller %d X-Cache: %s", i+1, xCaches[i])
+ }
+
+ // With range union, the two different ranges are collapsed into
+ // one origin fetch covering the union bytes=0-2097151.
+ assert.Equal(t, int32(1), originCallCount.Load(),
+ "different byte ranges should be collapsed via range union")
+
+ })
+
+ t.Run("test Collapsed Forwarding ChunkFlight RangeUnion", func(t *testing.T) {
+ var originCallCount atomic.Int32
+
+ case1 := e2e.New("http://chunkflight.example.com/cf/chunk/union.bin", e2e.RespCallbackFile(file, func(w http.ResponseWriter, r *http.Request) {
+ t.Logf("process req %s, range %s", r.Header.Get("X-Request-Id"), r.Header.Get("Range"))
+ originCallCount.Add(1)
+
+ w.Header().Set("Cache-Control", "max-age=30")
+ w.Header().Set("ETag", file.MD5)
+ }))
+ defer case1.Close()
+
+ // Phase 1 — cache only the middle chunk (chunk 1, bytes 524288-1048575).
+ resp, err := case1.Do(func(r *http.Request) {
+ r.Header.Set("X-Request-Id", "0")
+ r.Header.Set("Range", "bytes=524288-1048575")
+ })
+
+ require.NoError(t, err)
+ io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+ require.Equal(t, http.StatusPartialContent, resp.StatusCode)
+
+ originCallCount.Store(0)
+
+ time.Sleep(300 * time.Millisecond)
+
+ // Phase 2 — three concurrent requests for different ranges that
+ // need missing chunks. With range union, all three share a single
+ // origin fetch whose range covers the union of all requested ranges.
+ type reqSpec struct {
+ id string
+ rng string
+ from int
+ length int
+ }
+ specs := []reqSpec{
+ {"1", "bytes=0-524287", 0, 524288},
+ {"2", "bytes=0-1048575", 0, 1048576},
+ {"3", "bytes=1048576-2097151", 1048576, 2097151 - 1048576 + 1},
+ }
+
+ var wg sync.WaitGroup
+ start := make(chan struct{})
+ errs := make([]error, len(specs))
+ hashes := make([]string, len(specs))
+ xCaches := make([]string, len(specs))
+
+ for i, spec := range specs {
+ wg.Add(1)
+ go func(idx int, s reqSpec) {
+ defer wg.Done()
+ <-start
+
+ resp2, e := case1.Do(func(r *http.Request) {
+ r.Header.Set("X-Request-Id", s.id)
+ r.Header.Set("Range", s.rng)
+ })
+ if e != nil {
+ t.Logf("caller %s: request error: %v", s.id, e)
+ errs[idx] = e
+ return
+ }
+
+ hashes[idx] = e2e.HashBody(resp2)
+ resp2.Body.Close()
+
+ xCaches[idx] = resp2.Header.Get("X-Cache")
+ }(i, spec)
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ close(start)
+ wg.Wait()
+
+ for i, e := range errs {
+ assert.NoError(t, e, "request %d should not error", i)
+ t.Logf("caller %s X-Cache: %s hash: %s", specs[i].id, xCaches[i], hashes[i])
+ }
+
+ // All three different ranges should collapse into one origin call via
+ // automatic range union.
+ assert.Equal(t, int32(1), originCallCount.Load(),
+ "range union should collapse different ranges into 1 origin call")
+
+ // Each caller must receive the correct bytes for its range.
+ for i, spec := range specs {
+ expected := e2e.HashFile(file.Path, spec.from, spec.length)
+ assert.Equal(t, expected, hashes[i],
+ "caller %s (range %s): body hash mismatch", spec.id, spec.rng)
+ }
+ })
+
+ t.Run("PURGE", func(t *testing.T) {
+ e2e.SetDump(true)
+ e2e.Purge(t, "http://chunkflight.example.com/cf/chunk/collapse.bin")
+ e2e.Purge(t, "http://chunkflight.example.com/cf/chunk/keys.bin")
+ e2e.Purge(t, "http://chunkflight.example.com/cf/chunk/union.bin")
+ })
+}
diff --git a/tests/all-features/filechanged/contentlength_test.go b/tests/all-features/filechanged/contentlength_test.go
index 77773e1..aeb621d 100644
--- a/tests/all-features/filechanged/contentlength_test.go
+++ b/tests/all-features/filechanged/contentlength_test.go
@@ -109,7 +109,7 @@ func TestContengLenChanged(t *testing.T) {
}
func TestContentLenShorter(t *testing.T) {
- f := e2e.GenFile(t, 1<<20)
+ f := e2e.GenFile(t, 1048576)
f2 := e2e.GenFile(t, 1048570)
t.Run("test content-length shorter old file", func(t *testing.T) {
@@ -121,6 +121,7 @@ func TestContentLenShorter(t *testing.T) {
rr := xhttp.NewRequestRange(1, 400000)
resp, err := case1.Do(func(r *http.Request) {
+ r.Header.Set("X-Request-Id", t.Name())
r.Header.Set("Range", rr.String())
})
@@ -147,6 +148,7 @@ func TestContentLenShorter(t *testing.T) {
rr := xhttp.NewRequestRange(600000, 0) // bytes=600000-
resp, err := case1.Do(func(r *http.Request) {
+ r.Header.Set("X-Request-Id", t.Name())
r.Header.Set("Range", rr.String())
})
@@ -172,6 +174,7 @@ func TestContentLenShorter(t *testing.T) {
rr := xhttp.NewRequestRange(0, 0)
resp, err := case1.Do(func(r *http.Request) {
+ r.Header.Set("X-Request-Id", t.Name())
r.Header.Set("Range", rr.String())
})
diff --git a/tests/config.test.yaml b/tests/config.test.yaml
index 3970bd5..2fc5f28 100644
--- a/tests/config.test.yaml
+++ b/tests/config.test.yaml
@@ -74,10 +74,6 @@ plugin:
- "localhost"
- "@"
log_path: ./logs/purge.log
- - name: watchdog
- options:
- check_interval: 1s
- timeout_threshold: 3
- name: verifier
options:
endpoint: https://crc-svc.omalloc.com/receive