diff --git a/docs/storage-select-flow.md b/docs/storage-select-flow.md new file mode 100644 index 0000000..b0ab2fa --- /dev/null +++ b/docs/storage-select-flow.md @@ -0,0 +1,421 @@ +# Storage Select 完整流程 + +本文档详细描述 Tavern 存储层 `Select` 方法的完整调用链路、数据结构与核心算法。 + +--- + +## 1. 调用入口 + +Select 被调用的场景有三个: + +| 入口 | 位置 | 说明 | +|------|------|------| +| 缓存中间件 | `server/middleware/caching/processor.go:103` | 需要读写缓存对象时,通过 `object.ID` 找到对应的 Bucket | +| 全局函数 | `storage/global.go:30` | 持有全局 `defaultStorage` 的锁,代理调用 | +| PURGE 删除 | `storage/storage.go:220` | 单对象删除时,先 Select 再 Discard | + +```go +// 缓存中间件调用 +bucket := store.Select(req.Context(), objectID) + +// 全局函数调用 +bucket := storage.Select(ctx, cacheKey) +``` + +--- + +## 2. 完整流程图 + +```mermaid +flowchart TD + subgraph 调用入口 + A1["缓存中间件
processor.go:103
bucket := store.Select(ctx, objectID)"] + A2["全局函数
storage.Select(ctx, cacheKey)
global.go:30"] + A3["PURGE 单对象删除
storage.go:220
bucket := n.Select(ctx, cacheKey)"] + end + + A1 --> B{defaultStorage} + A2 --> B + A3 --> B + + B --> C{"storage.New() 创建时
config.Migration.Enabled ?"} + + C -->|false
普通模式| D["nativeStorage.Select()
storage.go:136"] + C -->|true
迁移模式| E["migratorStorage.Select()
migrator.go:169"] + + subgraph 普通模式 nativeStorage + D --> D1["n.selector.Select(ctx, id)
直接委托给 warmSelector"] + D1 --> F + end + + subgraph 迁移模式 migratorStorage + E --> E1["chainSelector(ctx, id,
hotSelector, warmSelector, coldSelector)
migrator.go:240"] + E1 --> E2["遍历 selector 列表"] + E2 --> E2A["sel.Select(ctx, id) → bucket"] + E2A --> E2B{"bucket != nil
&&
bucket.Exist(ctx, id.Bytes()) ?"} + E2B -->|"是(对象已存在该层)"| E2C["返回该 bucket"] + E2B -->|"否"| E2D{"还有下一个
selector ?"} + E2D -->|是| E2 + E2D -->|"否(所有层都不存在)"| E2E["兜底: warmSelector.Select(ctx, id)"] + E2E --> F + end + + subgraph SelectLayer 按层选择 + SL1["migratorStorage.SelectLayer(ctx, id, layer)
migrator.go:220"] + SL1 --> SL2{"layer ?"} + SL2 -->|"hot"| SL3["hotSelector.Select(ctx, id)"] + SL2 -->|"warm/normal"| SL4["warmSelector.Select(ctx, id)"] + SL2 -->|"cold"| SL5["coldSelector.Select(ctx, id)"] + SL2 -->|"inmemory"| SL6["直接返回 memoryBucket"] + end + + subgraph HashRing 一致性哈希选择器 + F["selector.New(buckets, 'hashring')
selector.go:12"] + F --> G["Balancer.Select(ctx, id)
hashring.go:43"] + + G --> G1["遍历 i = 1..len(buckets)"] + G1 --> G2["hashring.GetN(id.Bytes(), i)
获取 i 个最近节点"] + G2 --> G3["Consistent.GetN(name, n)
consistent.go:173"] + + G3 --> G4["FNV-32a 哈希 key"] + G4 --> G5["二分搜索排序哈希环
找到起始位置"] + G5 --> G6["从起始位置顺时针遍历
收集 n 个不重复节点"] + + G6 --> G7["取第 i-1 个节点 bucket"] + G7 --> G8{"bucket.UseAllow() ?
允许百分比检查"} + G8 -->|否| G9["continue 下一个"] + G8 -->|是| G10{"bucket.HasBad() ?
健康检查"} + G10 -->|"是(不健康)"| G9 + G10 -->|"否(健康)"| G11["返回该 bucket"] + G9 --> G1 + end + + subgraph DirAware 目录感知包装 + H["diraware.New(nativeStorage, checker)
diraware/storage.go:19"] + H --> H1["wrappedStorage.Select(ctx, id)
diraware/storage.go:34"] + H1 --> H2["base.Select(ctx, id) → bucket"] + H2 --> H3["wrapBucket(bucket, checker)
包装 bucket 添加过期标记能力"] + H3 --> H4["返回 wrappedBucket"] + end + + D -->|"config.DirAware.Enabled"| H + E -->|"config.DirAware.Enabled"| H + + subgraph 哈希环构建 Consistent Hash Ring + R1["Consistent.Set(caches)
consistent.go:109"] + R1 --> R2["遍历每个 Node (Bucket)"] + R2 --> R3["add(cache, cache.Weight())"] + R3 --> R4["对每个节点:
NumberOfReplicas(20) × Weight 个虚拟节点"] + R4 --> R5["key = hash('{idx}|{weight}|{bucketID}')
circles[key] = bucket"] + R5 --> R6["updateSortedHashes()
排序所有哈希值"] + end + + F -.->|初始化时| R1 + + subgraph Bucket 存储类型分层 + BT1["memoryBucket
TypeInMemory"] + BT2["hotBucket[]
TypeHot"] + BT3["warmBucket[]
TypeWarm / TypeNormal"] + BT4["coldBucket[]
TypeCold
(仅 migrator 模式)"] + end + + G -.->|选择范围| BT3 + E2A -.->|hotSelector 范围| BT2 + E2A -.->|warmSelector 范围| BT3 + E2A -.->|coldSelector 范围| BT4 +``` + +--- + +## 3. 两大模式对比 + +| 模式 | 条件 | 实现结构体 | 核心逻辑 | +|------|------|-----------|---------| +| **普通模式** | `Migration.Enabled = false` | `nativeStorage` | 直接委托给 warmSelector(一致性哈希),不检查对象是否存在 | +| **迁移模式** | `Migration.Enabled = true` | `migratorStorage` | Hot → Warm → Cold 链式查找,逐层调用 `bucket.Exist()` 检查对象是否存在 | + +### 3.1 普通模式 (`nativeStorage`) + +```go +// storage/storage.go:136 +func (n *nativeStorage) Select(ctx context.Context, id *object.ID) storage.Bucket { + bucket := n.selector.Select(ctx, id) + return bucket +} +``` + +极其简单:一行委托。`n.selector` 即 `warmSelector`,是一个哈希环选择器。 + +### 3.2 迁移模式 (`migratorStorage`) + +```go +// storage/migrator.go:169 +func (m *migratorStorage) Select(ctx context.Context, id *object.ID) storage.Bucket { + return m.chainSelector(ctx, id, + m.hotSelector, + m.warmSelector, + m.coldSelector, + ) +} +``` + +核心在 `chainSelector`: + +```go +// storage/migrator.go:240 +func (m *migratorStorage) chainSelector(ctx context.Context, id *object.ID, selectors ...storage.Selector) storage.Bucket { + for _, sel := range selectors { + if sel == nil { + continue + } + if bucket := sel.Select(ctx, id); bucket != nil && bucket.Exist(ctx, id.Bytes()) { + return bucket + } + } + // 兜底: 返回 warmSelector 的结果 + return m.warmSelector.Select(ctx, id) +} +``` + +**关键差异**:迁移模式不仅做哈希定位,还会调用 `bucket.Exist()` 检查对象是否真的在该 Bucket 的 IndexDB 中。这是为了支持对象的 **Promote/Demote** 跨层迁移: + +- **Promote**:Cold → Warm → Hot(访问量上升) +- **Demote**:Hot → Warm → Cold(访问量下降) + +### 3.3 按层选择 (`SelectLayer`) + +```go +// storage/migrator.go:220 +func (m *migratorStorage) SelectLayer(ctx context.Context, id *object.ID, layer string) storage.Bucket { + switch layer { + case storage.TypeHot: + if m.hotSelector != nil { + return m.hotSelector.Select(ctx, id) + } + case storage.TypeNormal, storage.TypeWarm: + if m.warmSelector != nil { + return 
m.warmSelector.Select(ctx, id) + } + case storage.TypeCold: + if m.coldSelector != nil { + return m.coldSelector.Select(ctx, id) + } + case storage.TypeInMemory: + return m.memoryBucket + } + return nil +} +``` + +用于迁移操作(Promote/Demote)时,根据目标层直接定位到对应 Bucket。 + +--- + +## 4. HashRing 一致性哈希(核心算法) + +### 4.1 选择器工厂 + +```go +// storage/selector/selector.go:12 +func New(buckets []storage.Bucket, typ string) storage.Selector { + curr, err := hashring.New(buckets, hashring.WithReplicas(20)) + if err != nil { + panic(err) + } + return curr +} +``` + +目前仅支持 `hashring` 类型,`typ` 参数实际未使用。 + +### 4.2 Balancer.Select 算法 + +```go +// storage/selector/hashring/hashring.go:43 +func (b *Balancer) Select(ctx context.Context, id *object.ID) storage.Bucket { + for i := 1; i <= len(b.buckets); i++ { + groups, err := b.hashring.GetN(string(id.Bytes()), i) + if err != nil { + return nil + } + bucket := groups[i-1].(storage.Bucket) + if bucket.UseAllow() { + if bucket.HasBad() { + continue + } + return bucket + } + } + return nil +} +``` + +**算法步骤**: + +1. `i=1` 开始,每次递增,调用 `GetN(key, i)` 获取 i 个最近节点 +2. 取第 `i-1` 个(最后一个,即第 i 近的)节点 +3. 检查 `UseAllow()`(允许百分比)和 `HasBad()`(健康状态) +4. 如果满足条件则返回,否则 `i++` 找下一个更远的节点 +5. 这是一种**回退(fallback)策略**:优先返回最近的健康节点;若 `GetN` 出错或没有任何节点满足条件,则返回 `nil` + +### 4.3 Consistent.GetN — 哈希环查找 + +```go +// storage/selector/hashring/consistent.go:173 +func (c *Consistent) GetN(name string, n int) ([]Node, error) { + key := c.hashKey(name) // FNV-32a 哈希 + i := c.search(key) // 二分搜索定位 + // 从起始位置顺时针遍历,收集 n 个不重复节点 + // ... 
+} +``` + +**关键参数**: + +| 参数 | 值 | 说明 | +|------|-----|------| +| 哈希函数 | FNV-32a | `hash/fnv` 标准库 | +| 虚拟副本数 | 20 | `NumberOfReplicas`,可通过 `WithReplicas` 配置 | +| 虚拟节点 key | `"{idx}\|{weight}\|{bucketID}"` | 每个副本 × 每个权重单位生成一个虚拟节点 | + +### 4.4 哈希环构建 + +```go +// storage/selector/hashring/consistent.go:109 +func (c *Consistent) Set(caches []Node) { + // 移除不再存在的节点 + // 新增节点: add(cache, cache.Weight()) + // → NumberOfReplicas × Weight 个虚拟节点 + // → 每个虚拟节点: circles[hash("{idx}|{weight}|{id}")] = bucket + // updateSortedHashes() → sort +} +``` + +**总虚拟节点数** = `∑(20 × Bucket.Weight)`,例如 3 个 Bucket 各权重 10 → 600 个虚拟节点。 + +--- + +## 5. DirAware 目录感知包装 + +当 `config.DirAware.Enabled = true` 时,`nativeStorage` 或 `migratorStorage` 被 `wrappedStorage` 包装。 + +```go +// storage/diraware/storage.go:34 +func (w *wrappedStorage) Select(ctx context.Context, id *object.ID) storagev1.Bucket { + return wrapBucket(w.base.Select(ctx, id), w.checker) +} +``` + +### 5.1 Bucket 包装 — Lookup 注入 + +```go +// storage/diraware/bucket.go:26 +func (b *wrappedBucket) Lookup(ctx context.Context, id *object.ID) (*object.Metadata, error) { + md, err := b.base.Lookup(ctx, id) + if err != nil || md == nil { + return md, err + } + marked, err := b.checker.Marked(ctx, id, md) + if marked { + md.ExpiresAt = time.Now().Add(-1 * time.Second).Unix() + } + return md, nil +} +``` + +### 5.2 Checker 标记逻辑 + +```go +// storage/diraware/diraware.go:74 +func (c *checker) Marked(ctx context.Context, id *object.ID, md *object.Metadata) (bool, error) { + unix, found := c.pathtrie.Search(id.Path()) + if found && md.RespUnix <= unix { + return true, nil // 对象在推送目录任务之前保存的 → 标记过期 + } + return false, nil +} +``` + +**逻辑**:前缀树(PathTrie)中存储了被推送的目录路径及推送时间。当 `Lookup` 时,如果对象所在路径被推送标记过,且对象元数据中的 `RespUnix` 时间戳 ≤ 推送时间,说明这个对象是在推送之前缓存的,应当标记为过期。 + +--- + +## 6. 
各 Selector 覆盖的 Bucket 范围 + +| Selector | 管理的 Bucket 列表 | 对应 StoreType | +|----------|-------------------|----------------| +| `warmSelector` | `warmBucket[]` | `TypeWarm` / `TypeNormal`(Normal 自动别名为 Warm) | +| `hotSelector` | `hotBucket[]` | `TypeHot`(仅 migrator 模式) | +| `coldSelector` | `coldBucket[]` | `TypeCold`(仅 migrator 模式) | +| `memoryBucket` | 单例 | `TypeInMemory`(直接引用,不走哈希环) | + +### 6.1 Bucket 初始化时的分类 + +```go +// storage/storage.go:104 (reinit 方法) +switch bucket.StoreType() { +case storage.TypeNormal, storage.TypeWarm: + n.warmlBucket = append(n.warmlBucket, bucket) +case storage.TypeHot: + n.hotBucket = append(n.hotBucket, bucket) +case storage.TypeInMemory: + n.memoryBucket = bucket // 只能有一个 +case storage.TypeCold: // 仅 migrator + m.coldBucket = append(m.coldBucket, bucket) +} +``` + +> **注意**:`TypeNormal` 在 `mergeConfig` 中会被自动转为 `TypeWarm`(`storage/builder.go:61`)。 + +--- + +## 7. 关键接口定义 + +```go +// api/defined/v1/storage/storage.go + +type Selector interface { + Select(ctx context.Context, id *object.ID) Bucket + Rebuild(ctx context.Context, buckets []Bucket) error +} + +type Storage interface { + io.Closer + Selector // 内嵌 Selector + Buckets() []Bucket + SharedKV() SharedKV + PURGE(storeUrl string, typ PurgeControl) error +} + +type Migrator interface { + Storage // 内嵌 Storage + SelectLayer(ctx context.Context, id *object.ID, layer string) Bucket +} +``` + +完整实现矩阵: + +| 接口 | `nativeStorage` | `migratorStorage` | `wrappedStorage` (diraware) | +|------|:---:|:---:|:---:| +| `Storage` | ✓ | ✓ | ✓ | +| `Migrator` | ✗ | ✓ | ✗ | +| `Select()` | 委托 warmSelector | chainSelector(Hot→Warm→Cold) | 包装 base.Select() | +| `SelectLayer()` | N/A | 按层直接选择 | N/A | + +--- + +## 8. 
文件索引 + +| 文件 | 内容 | +|------|------| +| `api/defined/v1/storage/storage.go` | `Selector`、`Storage`、`Migrator`、`Bucket` 接口定义 | +| `storage/global.go` | 全局 `Select()`、`SetDefault()`、`Current()` | +| `storage/storage.go` | `nativeStorage` 实现(普通模式) | +| `storage/migrator.go` | `migratorStorage` 实现(迁移模式) | +| `storage/builder.go` | `NewBucket()` 工厂、配置合并 | +| `storage/selector/selector.go` | Selector 工厂(目前仅 hashring) | +| `storage/selector/hashring/hashring.go` | `Balancer.Select()` 算法 | +| `storage/selector/hashring/consistent.go` | 一致性哈希环实现 | +| `storage/diraware/storage.go` | DirAware Storage 包装器 | +| `storage/diraware/bucket.go` | DirAware Bucket 包装器(Lookup 注入) | +| `storage/diraware/diraware.go` | Checker 实现(PathTrie + SharedKV) | diff --git a/pkg/e2e/e2e.go b/pkg/e2e/e2e.go index 9268bd5..7e67337 100644 --- a/pkg/e2e/e2e.go +++ b/pkg/e2e/e2e.go @@ -88,14 +88,16 @@ func (e *E2E) Do(rewrite func(r *http.Request)) (*http.Response, error) { // wait for a while to let the server ready time.Sleep(time.Millisecond * 100) - rewrite(e.req) + nr := e.req.Clone(context.Background()) - method := e.req.Method + rewrite(nr) - e.req.Header.Set(protocol.InternalUpstreamAddr, e.ts.Listener.Addr().String()) + method := nr.Method + + nr.Header.Set(protocol.InternalUpstreamAddr, e.ts.Listener.Addr().String()) if dumpReq.Load() && method != "PURGE" { - DumpReq(e.req) + DumpReq(nr) } if manual.Load() { @@ -103,7 +105,7 @@ func (e *E2E) Do(rewrite func(r *http.Request)) (*http.Response, error) { time.Sleep(time.Second * 20) } - resp, err := e.cs.Do(e.req) + resp, err := e.cs.Do(nr) e.resp = resp e.err = err @@ -172,11 +174,15 @@ func DumpResp(resp *http.Response) { fmt.Println(string(buf)) } -func Purge(t *testing.T, url string) { +func PurgeMethod(t *testing.T, url string, dir bool) { resp, err := New(url, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadGateway) }).Do(func(r *http.Request) { r.Method = "PURGE" + r.Header.Set("Purge-Type", "file,hard") + if dir 
{ + r.Header.Set("Purge-Type", "dir,hard") + } }) assert.NoError(t, err, "purge should not error") @@ -187,3 +193,7 @@ func Purge(t *testing.T, url string) { t.Logf("Purge %s success", url) } } + +func Purge(t *testing.T, url string) { + PurgeMethod(t, url, false) +} diff --git a/server/middleware/caching/caching.go b/server/middleware/caching/caching.go index 36f9839..0639a51 100644 --- a/server/middleware/caching/caching.go +++ b/server/middleware/caching/caching.go @@ -12,6 +12,7 @@ import ( "time" "github.com/kelindar/bitmap" + "github.com/prometheus/client_golang/prometheus" "github.com/omalloc/tavern/api/defined/v1/event" configv1 "github.com/omalloc/tavern/api/defined/v1/middleware" @@ -25,8 +26,6 @@ import ( "github.com/omalloc/tavern/proxy" "github.com/omalloc/tavern/server/middleware" storagev1 "github.com/omalloc/tavern/storage" - - "github.com/prometheus/client_golang/prometheus" ) const BYPASS = "BYPASS" @@ -114,6 +113,12 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) { proxyClient := proxy.GetProxy() store := storagev1.Current() + // Flight groups for collapsed forwarding at object and chunk level. + // These mirror Squid's collapsed_forwarding: one origin request + // serves many waiting clients. + objectFlight := &ObjectFlightGroup{} + chunkFlight := &ChunkFlightGroup{} + return middleware.RoundTripperFunc(func(req *http.Request) (resp *http.Response, err error) { // only cache GET/HEAD request if req.Method != http.MethodGet && req.Method != http.MethodHead { @@ -132,9 +137,20 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) { // cachingPool.Put(caching) //}() + // Wire up chunk-level collapsed forwarding. + caching.chunkFlight = chunkFlight + + // Increment cacheRequestTotal on every response path (BYPASS, + // HIT, MISS-collapsed, MISS-direct). Placed as a defer so it + // fires after cacheStatus is finalized. 
+ defer func() { + cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc() + }() + // err to BYPASS caching if err != nil { - caching.log.Warnf("Precache processor failed: %v, BYPASS", err) + caching.log.Warnf("Precache processor failed: %v BYPASS", err) + caching.cacheStatus = storage.BYPASS resp, err = caching.doProxy(req, false) // do reverse proxy if err != nil { return nil, err @@ -144,57 +160,70 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) { // set cache-staus header BYPASS resp.Header.Set(protocol.ProtocolCacheStatusKey, BYPASS) } - cacheRequestTotal.WithLabelValues(storage.BYPASS.String(), caching.bucket.StoreType()).Inc() return } // cache HIT if caching.hit { - caching.cacheStatus = storage.CacheHit - - rng, err1 := xhttp.SingleRange(req.Header.Get("Range"), caching.md.Size) - if err1 != nil { - // 无效 Range 处理 - headers := make(http.Header) - xhttp.CopyHeader(caching.md.Headers, headers) - headers.Set("Content-Range", fmt.Sprintf("bytes */%d", caching.md.Size)) - cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc() - return nil, xhttp.NewBizError(http.StatusRequestedRangeNotSatisfiable, headers) - } - - // mark cache status with Range requests. - caching.markCacheStatus(rng.Start, rng.End) + return caching.respondFromCache(req) + } - // find file seek(start, end) - resp, err = caching.lazilyRespond(req, rng.Start, rng.End) - if err != nil { - // fd leak - closeBody(resp) - cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc() - return nil, err + // full MISS — use object-level collapsed forwarding so that + // concurrent requests for the same cache object share one + // origin fetch (Squid-style collapsed_forwarding). 
+ if opts.CollapsedRequest { + flightResp, _, flightErr := objectFlight.Do(caching.id.HashStr(), opts.CollapsedRequestWaitTimeout.AsDuration(), func() (*http.Response, error) { + r, e := caching.doProxy(req, false) + if e != nil { + return nil, e + } + return processor.postCacheProcessor(caching, req, r) + }) + if flightErr != nil { + return nil, flightErr } - - // response now - resp, err = caching.processor.postCacheProcessor(caching, req, resp) - cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc() + resp = flightResp return } - // full MISS + // full MISS (collapsed forwarding disabled) resp, err = caching.doProxy(req, false) if err != nil { - cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc() return nil, err } resp, err = processor.postCacheProcessor(caching, req, resp) - cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc() return }) }, middleware.EmptyCleanup, nil } +// respondFromCache assembles a response from cached chunks for a cache HIT. +// It parses the Range header, builds a multi-part reader from disk, and +// runs post-cache processing (headers, cache status, store). 
+func (c *Caching) respondFromCache(req *http.Request) (*http.Response, error) { + c.cacheStatus = storage.CacheHit + + rng, err := xhttp.SingleRange(req.Header.Get("Range"), c.md.Size) + if err != nil { + headers := make(http.Header) + xhttp.CopyHeader(c.md.Headers, headers) + headers.Set("Content-Range", fmt.Sprintf("bytes */%d", c.md.Size)) + return nil, xhttp.NewBizError(http.StatusRequestedRangeNotSatisfiable, headers) + } + + c.markCacheStatus(rng.Start, rng.End) + + resp, err := c.lazilyRespond(req, rng.Start, rng.End) + if err != nil { + closeBody(resp) + return nil, err + } + + return c.processor.postCacheProcessor(c, req, resp) +} + func (c *Caching) lazilyRespond(req *http.Request, start, end int64) (*http.Response, error) { // 这里通过缓存的块大小来计算,而不是配置默认的 SliceSize // 这样已缓存的对象可以使用原来的配置块大小,不受配置文件变更影响 @@ -261,27 +290,25 @@ func (c *Caching) lazilyRespond(req *http.Request, start, end int64) (*http.Resp func (c *Caching) getUpstreamReader(fromByte, toByte uint64, async bool) (io.ReadCloser, error) { // get from origin request header rawRange := c.req.Header.Get("Range") - newRange := fmt.Sprintf("bytes=%d-%d", fromByte, toByte) - req := c.req.Clone(context.Background()) - req.Header.Set("Range", newRange) - // add request-id [range] - // req.Header.Set("X-Request-ID", fmt.Sprintf("%s-%d", req.Header.Get(appctx.ProtocolRequestIDKey), fromByte)) // 附加 Request-ID suffix - // remove all internal header - req.Header.Del(protocol.ProtocolCacheStatusKey) + // doSubRequest is parameterized by the union range so the flight group + // can expand the range to cover multiple concurrent callers. 
+ doSubRequest := func(unionFrom, unionTo uint64) (*http.Response, error) { + newRange := fmt.Sprintf("bytes=%d-%d", unionFrom, unionTo) + subReq := c.req.Clone(context.Background()) + subReq.Header.Set("Range", newRange) + subReq.Header.Del(protocol.ProtocolCacheStatusKey) - doSubRequest := func() (*http.Response, error) { now := time.Now() c.log.Debugf("getUpstreamReader doProxy[chunk]: begin: %s, rawRange: %s, newRange: %s", now, rawRange, newRange) - resp, err := c.doProxy(req, true) + resp, err := c.doProxy(subReq, true) c.log.Infof("getUpstreamReader doProxy[chunk]: timeCost: %s, rawRange: %s, newRange: %s", time.Since(now), rawRange, newRange) if err != nil { closeBody(resp) return nil, err } - // 部分命中 - c.cacheStatus = storage.CachePartHit - // 发起的是 206 请求,但是返回的非 206 + // 206 Partial Content is expected for range requests. + // If we get a different status code, it may indicate an issue with the upstream response. if resp.StatusCode != http.StatusPartialContent { c.log.Warnf("getUpstreamReader doProxy[chunk]: status code: %d, bod size: %d", resp.StatusCode, resp.ContentLength) return resp, xhttp.NewBizError(resp.StatusCode, resp.Header) @@ -289,11 +316,37 @@ func (c *Caching) getUpstreamReader(fromByte, toByte uint64, async bool) (io.Rea return resp, nil } + // Chunk-level collapsed forwarding: if another goroutine is already + // fetching (possibly a different) byte range for this object, wait + // and share the union response body (io.MultiWriter fan-out + RangeReader + // trimming). This mirrors Squid's collapsed_forwarding at the chunk + // level, with automatic range union. + if c.chunkFlight != nil && c.opt.CollapsedRequest && c.id != nil { + reader, _, err := c.chunkFlight.Do(c.id.HashStr(), fromByte, toByte, + c.opt.CollapsedRequestWaitTimeout.AsDuration(), doSubRequest) + // Both leader and shared callers are partial hits: at least one + // chunk was missing and is being fetched from origin. 
+ + // c.cacheStatus = storage.CachePartMiss + // if shared { + // c.cacheStatus = storage.CachePartHit + // } + return reader, err + } + + // Any path that reaches getUpstreamReader is a partial hit: at least + // one chunk was missing from cache and is now being fetched from + // origin. Set the status here for the async and sync paths (the + // chunk-flight path above already sets it before returning). + // c.cacheStatus = storage.CachePartMiss + if async { - return iobuf.AsyncReadCloser(doSubRequest), nil + return iobuf.AsyncReadCloser(func() (*http.Response, error) { + return doSubRequest(fromByte, toByte) + }), nil } - resp, err := doSubRequest() + resp, err := doSubRequest(fromByte, toByte) if resp != nil { return resp.Body, err } diff --git a/server/middleware/caching/caching_test.go b/server/middleware/caching/caching_test.go index 25c7ec2..5712b16 100644 --- a/server/middleware/caching/caching_test.go +++ b/server/middleware/caching/caching_test.go @@ -1,6 +1,10 @@ package caching -import "testing" +import ( + "net/http" + "strings" + "testing" +) func BenchmarkWithPooling(b *testing.B) { for i := 0; i < b.N; i++ { @@ -8,3 +12,13 @@ func BenchmarkWithPooling(b *testing.B) { c.reset() } } + +func TestObjectFlight_PanicRecovery(t *testing.T) { + g := &ObjectFlightGroup{} + _, _, err := g.Do("key", 0, func() (*http.Response, error) { + panic("boom") + }) + if err == nil || !strings.Contains(err.Error(), "panic") { + t.Fatalf("expected panic error, got %v", err) + } +} diff --git a/server/middleware/caching/chunk_flight.go b/server/middleware/caching/chunk_flight.go new file mode 100644 index 0000000..7a50a01 --- /dev/null +++ b/server/middleware/caching/chunk_flight.go @@ -0,0 +1,197 @@ +package caching + +import ( + "fmt" + "io" + "net/http" + "sync" + "time" + + "github.com/omalloc/tavern/pkg/iobuf" +) + +// chunkRange holds one caller's desired byte range within the object. 
+type chunkRange struct { + fromByte, toByte uint64 +} + +// chunkCall is an in-flight chunk upstream request. Multiple callers +// register their desired ranges; the leader computes the union and +// fetches a single range that covers everyone. Each caller receives +// their own sub-range trimmed via iobuf.RangeReader. +type chunkCall struct { + pipes []*io.PipeWriter + ranges []chunkRange + mu sync.Mutex // protects pipes and ranges during registration + wg sync.WaitGroup // signals that unionFrom / unionTo are computed + unionFrom uint64 + unionTo uint64 + err error // set when fn fails +} + +// ChunkFlightGroup collapses concurrent upstream requests for different +// byte ranges of the same object into a single origin fetch covering the +// union of all requested ranges. Response body bytes are fanned out to +// all waiters via io.MultiWriter + io.Pipe, and each caller trims the +// shared stream down to its own sub-range with iobuf.RangeReader. +// +// This mirrors Squid's collapsed_forwarding at the chunk/segment level: +// when two goroutines request different byte ranges of the same cached +// object, only one hits origin and the others wait, even when the +// ranges differ. +type ChunkFlightGroup struct { + mu sync.Mutex + m map[string]*chunkCall +} + +// Do executes fn once per objectKey, passing the union of all registered +// byte ranges to fn. All callers — including the first — receive an +// io.PipeReader trimmed to exactly [fromByte, toByte]. The returned +// bool reports whether this caller shared an in-flight request. +// +// waiter is the duration the leader goroutine pauses *before* calling fn, +// giving late-arriving callers a window to register under the same key. +// In production the network round-trip naturally provides this window; +// waiter ensures correctness even when fn would otherwise complete nearly +// instantly (e.g. in tests, or for tiny ranges on a local origin). +// +// Contract: fn owns resp.Body. 
On success ChunkFlightGroup reads and +// closes it. On error fn must either return (nil, err) or close the body +// before returning (resp, err). +func (g *ChunkFlightGroup) Do(objectKey string, fromByte, toByte uint64, waiter time.Duration, fn func(unionFrom, unionTo uint64) (*http.Response, error)) (io.ReadCloser, bool, error) { + pr, pw := io.Pipe() + + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*chunkCall) + } + if c, ok := g.m[objectKey]; ok { + // Waiter: register pipe writer and desired range, then wait for + // the leader to compute the union range. + c.mu.Lock() + c.pipes = append(c.pipes, pw) + c.ranges = append(c.ranges, chunkRange{fromByte: fromByte, toByte: toByte}) + c.mu.Unlock() + g.mu.Unlock() + + c.wg.Wait() + c.mu.Lock() + flightErr := c.err + c.mu.Unlock() + if flightErr != nil { + _ = pw.CloseWithError(flightErr) + return nil, true, flightErr + } + // Return immediately — fn is executed asynchronously by the + // leader goroutine, so the caller can build response headers + // before c.md is mutated by the upstream fetch. + return iobuf.RangeReader(pr, int(c.unionFrom), int(c.unionTo), int(fromByte), int(toByte)), true, nil + } + + // Leader: create the flight and register own range. + c := &chunkCall{ + pipes: []*io.PipeWriter{pw}, + ranges: []chunkRange{{fromByte: fromByte, toByte: toByte}}, + } + c.wg.Add(1) + g.m[objectKey] = c + g.mu.Unlock() + + // Pause before hitting origin so concurrent callers have time + // to register under this key. Without this window an instant + // fn would compute the union and delete the map entry before + // anyone else could join. + if waiter > 0 { + time.Sleep(waiter) + } + + // Compute the union range across all registered callers. 
+ c.mu.Lock() + unionFrom := c.ranges[0].fromByte + unionTo := c.ranges[0].toByte + for _, r := range c.ranges[1:] { + if r.fromByte < unionFrom { + unionFrom = r.fromByte + } + if r.toByte > unionTo { + unionTo = r.toByte + } + } + c.unionFrom = unionFrom + c.unionTo = unionTo + c.mu.Unlock() + + // Release waiters — they now know unionFrom/unionTo and can build + // their RangeReader wrappers around the shared pipe. fn has not + // been called yet, so callers can safely read c.md headers before + // the upstream fetch mutates them. + c.wg.Done() + + // Remove the key from the map so that no late-arriving callers can + // join this flight with a stale union range. The waiter window + // (time.Sleep above) is the intentional batching period — callers + // arriving after it will start a fresh flight. This trades a + // possible duplicate origin request for guaranteed range correctness. + g.mu.Lock() + delete(g.m, objectKey) + g.mu.Unlock() + + // The leader returns immediately with a pipe reader. The upstream + // fetch (fn) and body fan-out run in a background goroutine so that + // response headers are built before c.md is touched. + go func() { + + // check for panic to avoid leaving waiters hanging indefinitely + resp, err := func() (r *http.Response, e error) { + defer func() { + if rec := recover(); rec != nil { + e = fmt.Errorf("chunk flight panic: %v", rec) + } + }() + return fn(unionFrom, unionTo) + }() + + // Snapshot pipes under c.mu. The map entry is already deleted + // (above), so no further callers can register against this + // flight — c.mu only protects against the window where waiters + // that registered before the deletion are still being appended. + c.mu.Lock() + pipes := make([]*io.PipeWriter, len(c.pipes)) + copy(pipes, c.pipes) + if err != nil { + c.err = err + } + c.mu.Unlock() + + if err != nil { + for _, p := range pipes { + _ = p.CloseWithError(err) + } + // fn owns resp.Body on error — it must close it before + // returning. 
We only guard against a nil body here. + return + } + + // Build MultiWriter from all registered pipe writers. + writers := make([]io.Writer, len(pipes)) + for i, p := range pipes { + writers[i] = p + } + mw := io.MultiWriter(writers...) + + _, copyErr := io.Copy(mw, resp.Body) + _ = resp.Body.Close() + + for _, p := range pipes { + if copyErr != nil && copyErr != io.EOF { + _ = p.CloseWithError(copyErr) + } else { + _ = p.Close() + } + } + }() + + // Leader wraps its reader with RangeReader so that it sees exactly + // [fromByte, toByte] trimmed from the union response. + return iobuf.RangeReader(pr, int(unionFrom), int(unionTo), int(fromByte), int(toByte)), false, nil +} diff --git a/server/middleware/caching/collapsed_forwarding_test.go b/server/middleware/caching/collapsed_forwarding_test.go new file mode 100644 index 0000000..c8f8768 --- /dev/null +++ b/server/middleware/caching/collapsed_forwarding_test.go @@ -0,0 +1,519 @@ +package caching + +import ( + "bytes" + "errors" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// --------------------------------------------------------------------------- +// ChunkFlightGroup tests +// --------------------------------------------------------------------------- + +func TestChunkFlight_BasicCollapse(t *testing.T) { + g := &ChunkFlightGroup{} + var callCount atomic.Int32 + + // fn returns a body equal in length to the requested range so callers + // can verify their sub-range trimming works. + fn := func(unionFrom, unionTo uint64) (*http.Response, error) { + callCount.Add(1) + size := int(unionTo - unionFrom + 1) + return &http.Response{ + StatusCode: http.StatusPartialContent, + Body: io.NopCloser(bytes.NewReader(makebuf(size))), + }, nil + } + + type result struct { + length int + shared bool + } + + results := make([]result, 3) + var wg sync.WaitGroup + start := make(chan struct{}) + + // All three callers request the same range — classic collapse. 
+ for i := 0; i < 3; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + r, shared, err := g.Do("obj1", 0, 1023, 50*time.Millisecond, fn) + if err != nil { + t.Errorf("caller %d: unexpected error: %v", idx, err) + return + } + data, readErr := io.ReadAll(r) + _ = r.Close() + if readErr != nil { + t.Errorf("caller %d: read error: %v", idx, readErr) + return + } + results[idx] = result{len(data), shared} + }(i) + } + + // Release all callers simultaneously so they race on the map entry. + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + if callCount.Load() != 1 { + t.Fatalf("expected 1 call, got %d", callCount.Load()) + } + + sharedCount := 0 + for _, r := range results { + if r.shared { + sharedCount++ + } + if r.length != 1024 { + t.Errorf("got %d bytes, want 1024", r.length) + } + } + if sharedCount != 2 { + t.Errorf("expected 2 shared callers, got %d", sharedCount) + } +} + +func TestChunkFlight_RangeUnion(t *testing.T) { + g := &ChunkFlightGroup{} + var callCount atomic.Int32 + + fn := func(unionFrom, unionTo uint64) (*http.Response, error) { + callCount.Add(1) + size := int(unionTo - unionFrom + 1) + return &http.Response{ + StatusCode: http.StatusPartialContent, + Body: io.NopCloser(bytes.NewReader(makebuf(size))), + }, nil + } + + type result struct { + length int + shared bool + } + + type caller struct { + from, to uint64 + wantLen int + } + + callers := []caller{ + {0, 999, 1000}, // bytes 0-999 + {500, 1999, 1500}, // bytes 500-1999, overlaps first + {1500, 2999, 1500}, // bytes 1500-2999, overlaps second + } + results := make([]result, len(callers)) + var wg sync.WaitGroup + start := make(chan struct{}) + + for i, c := range callers { + wg.Add(1) + go func(idx int, from, to uint64) { + defer wg.Done() + <-start + r, shared, err := g.Do("union-obj", from, to, 50*time.Millisecond, fn) + if err != nil { + t.Errorf("caller %d: unexpected error: %v", idx, err) + return + } + data, readErr := io.ReadAll(r) + _ = r.Close() + if 
readErr != nil { + t.Errorf("caller %d: read error: %v", idx, readErr) + return + } + results[idx] = result{len(data), shared} + }(i, c.from, c.to) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + // With range union, all three callers share one origin fetch covering + // the union range 0-2999. + if callCount.Load() != 1 { + t.Fatalf("expected 1 call (union), got %d", callCount.Load()) + } + + for i, r := range results { + if r.length != callers[i].wantLen { + t.Errorf("caller %d: got %d bytes, want %d", i, r.length, callers[i].wantLen) + } + } + + sharedCount := 0 + for _, r := range results { + if r.shared { + sharedCount++ + } + } + if sharedCount != len(callers)-1 { + t.Errorf("expected %d shared callers, got %d", len(callers)-1, sharedCount) + } +} + +func TestChunkFlight_ErrorPropagation(t *testing.T) { + g := &ChunkFlightGroup{} + + fn := func(_, _ uint64) (*http.Response, error) { + return nil, errors.New("upstream timeout") + } + + var wg sync.WaitGroup + start := make(chan struct{}) + errs := make([]error, 3) + + for i := 0; i < 3; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + r, _, err := g.Do("obj1", 0, 1023, 50*time.Millisecond, fn) + if err != nil { + errs[idx] = err + return + } + _, readErr := io.ReadAll(r) + _ = r.Close() + if readErr != nil { + errs[idx] = readErr + } + }(i) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + for i, err := range errs { + if err == nil { + t.Errorf("caller %d: expected error, got nil", i) + } + } +} + +func TestChunkFlight_KeyIsolation(t *testing.T) { + g := &ChunkFlightGroup{} + var callCount atomic.Int32 + + makeFn := func(data string) func(uint64, uint64) (*http.Response, error) { + return func(unionFrom, unionTo uint64) (*http.Response, error) { + callCount.Add(1) + size := int(unionTo - unionFrom + 1) + return &http.Response{ + StatusCode: http.StatusPartialContent, + Body: io.NopCloser(bytes.NewReader(makebuf(size))), + }, nil + } + } + + 
var wg sync.WaitGroup + results := make(map[string]int, 4) + var mu sync.Mutex + + // Two objects (obj1, obj2), each with two concurrent callers requesting + // different ranges. Within each object the ranges are unioned, but + // different objects are isolated. + type job struct { + key string + from, to uint64 + wantLen int + } + jobs := []job{ + {"obj1", 0, 1048575, 1048576}, + {"obj1", 1048576, 2097151, 1048576}, + {"obj2", 0, 1048575, 1048576}, + {"obj2", 1048576, 2097151, 1048576}, + } + for _, j := range jobs { + wg.Add(1) + go func(k string, from, to uint64) { + defer wg.Done() + r, _, err := g.Do(k, from, to, 50*time.Millisecond, makeFn(k)) + if err != nil { + t.Errorf("key %s: unexpected error: %v", k, err) + return + } + data, _ := io.ReadAll(r) + _ = r.Close() + mu.Lock() + results[k] = len(data) + mu.Unlock() + }(j.key, j.from, j.to) + } + wg.Wait() + + // With object-level keys, obj1's two callers collapse into one + // (union: 0-2097151), obj2's two callers collapse into another. + if callCount.Load() != 2 { + t.Fatalf("expected 2 calls (one per object), got %d", callCount.Load()) + } + + // Each caller must receive exactly their requested byte count. + for _, j := range jobs { + if results[j.key] != j.wantLen { + t.Errorf("key %s: got %d bytes, want %d", j.key, results[j.key], j.wantLen) + } + } +} + +func TestChunkFlight_ConcurrentSameKey(t *testing.T) { + g := &ChunkFlightGroup{} + var callCount atomic.Int32 + + fn := func(_, _ uint64) (*http.Response, error) { + callCount.Add(1) + return &http.Response{ + StatusCode: http.StatusPartialContent, + Body: io.NopCloser(bytes.NewReader(makebuf(1 << 18))), + }, nil + } + + var wg sync.WaitGroup + start := make(chan struct{}) + const numCallers = 10 + sharedCount := atomic.Int32{} + + for i := 0; i < numCallers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + // All callers request the same range 0-262143. 
+ r, shared, err := g.Do("same-key", 0, 262143, 100*time.Millisecond, fn) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + _, _ = io.ReadAll(r) + _ = r.Close() + if shared { + sharedCount.Add(1) + } + }() + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + if callCount.Load() != 1 { + t.Fatalf("expected exactly 1 origin call, got %d", callCount.Load()) + } + if sharedCount.Load() != numCallers-1 { + t.Fatalf("expected %d shared callers, got %d", numCallers-1, sharedCount.Load()) + } +} + +func TestChunkFlight_PanicRecovery(t *testing.T) { + g := &ChunkFlightGroup{} + + pr, shared, err := g.Do("panic-key", 0, 1023, 0, func(_, _ uint64) (*http.Response, error) { + panic("boom") + }) + if shared { + t.Fatal("expected leader, not shared") + } + + // fn is now called asynchronously — the leader gets the error through + // the pipe, not from the Do return value. + if err != nil { + t.Fatalf("unexpected error from Do: %v", err) + } + + _, readErr := io.ReadAll(pr) + _ = pr.Close() + if readErr == nil || !strings.Contains(readErr.Error(), "panic") { + t.Fatalf("expected panic error from pipe, got %v", readErr) + } +} + +// --------------------------------------------------------------------------- +// ObjectFlightGroup tests +// --------------------------------------------------------------------------- + +func TestObjectFlight_BasicCollapse(t *testing.T) { + g := &ObjectFlightGroup{} + var callCount atomic.Int32 + + fn := func() (*http.Response, error) { + callCount.Add(1) + time.Sleep(30 * time.Millisecond) // simulate origin latency + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("response-body")), + }, nil + } + + var wg sync.WaitGroup + start := make(chan struct{}) + bodies := make([]string, 5) + shareds := make([]bool, 5) + + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + resp, shared, err := g.Do("cache-key-1", 50*time.Millisecond, fn) + 
if err != nil { + t.Errorf("caller %d: unexpected error: %v", idx, err) + return + } + shareds[idx] = shared + body, readErr := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if readErr != nil { + t.Errorf("caller %d: read error: %v", idx, readErr) + return + } + bodies[idx] = string(body) + }(i) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + if callCount.Load() != 1 { + t.Fatalf("expected 1 call, got %d", callCount.Load()) + } + nonShared := 0 + shared := 0 + for _, s := range shareds { + if s { + shared++ + } else { + nonShared++ + } + } + if nonShared != 1 { + t.Errorf("expected 1 non-shared caller, got %d", nonShared) + } + if shared != 4 { + t.Errorf("expected 4 shared callers, got %d", shared) + } + for i, b := range bodies { + if b != "response-body" { + t.Errorf("caller %d: body = %q, want %q", i, b, "response-body") + } + } +} + +func TestObjectFlight_ErrorPropagation(t *testing.T) { + g := &ObjectFlightGroup{} + var callCount atomic.Int32 + + testErr := errors.New("origin connection refused") + fn := func() (*http.Response, error) { + callCount.Add(1) + time.Sleep(30 * time.Millisecond) // window for dup callers to register + return nil, testErr + } + + var wg sync.WaitGroup + start := make(chan struct{}) + errs := make([]error, 3) + + for i := 0; i < 3; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + _, _, err := g.Do("cache-key-err", 50*time.Millisecond, fn) + errs[idx] = err + }(i) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + if callCount.Load() != 1 { + t.Fatalf("expected 1 call, got %d", callCount.Load()) + } + for i, err := range errs { + if !errors.Is(err, testErr) { + t.Errorf("caller %d: got %v, want %v", i, err, testErr) + } + } +} + +func TestObjectFlight_KeyIsolation(t *testing.T) { + g := &ObjectFlightGroup{} + var callCount atomic.Int32 + + fn := func() (*http.Response, error) { + callCount.Add(1) + return &http.Response{ + StatusCode: http.StatusOK, + Body: 
io.NopCloser(strings.NewReader("body")), + }, nil + } + + var wg sync.WaitGroup + for _, key := range []string{"key-a", "key-b", "key-c"} { + wg.Add(1) + go func(k string) { + defer wg.Done() + resp, _, err := g.Do(k, 0, fn) + if err != nil { + t.Errorf("key %s: unexpected error: %v", k, err) + return + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }(key) + } + wg.Wait() + + if callCount.Load() != 3 { + t.Fatalf("expected 3 distinct calls, got %d", callCount.Load()) + } +} + +func TestObjectFlight_SequentialReuse(t *testing.T) { + g := &ObjectFlightGroup{} + var callCount atomic.Int32 + + fn := func() (*http.Response, error) { + callCount.Add(1) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("body")), + }, nil + } + + resp, _, err := g.Do("seq-key", 0, fn) + if err != nil { + t.Fatalf("first call: unexpected error: %v", err) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + if callCount.Load() != 1 { + t.Fatalf("first call: expected 1, got %d", callCount.Load()) + } + + time.Sleep(10 * time.Millisecond) + + resp, _, err = g.Do("seq-key", 0, fn) + if err != nil { + t.Fatalf("second call: unexpected error: %v", err) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + if callCount.Load() != 2 { + t.Fatalf("sequential call: expected 2, got %d", callCount.Load()) + } +} diff --git a/server/middleware/caching/internal.go b/server/middleware/caching/internal.go index c45c1a9..55aff5a 100644 --- a/server/middleware/caching/internal.go +++ b/server/middleware/caching/internal.go @@ -45,6 +45,7 @@ type Caching struct { rootmd *object.Metadata bucket storage.Bucket proxyClient proxy.Proxy + chunkFlight *ChunkFlightGroup cacheStatus storage.CacheStatus cacheable bool hit bool diff --git a/server/middleware/caching/object_flight.go b/server/middleware/caching/object_flight.go new file mode 100644 index 0000000..c6a0e77 --- /dev/null +++ b/server/middleware/caching/object_flight.go @@ -0,0 +1,172 @@ 
// objectFlightCall is the shared state of one in-flight full-object origin
// fetch: the leader that runs fn plus every follower that joined it.
//
// Each participant owns one io.Pipe; the origin body is copied into all of
// the pipe writers so every caller streams the same bytes without a second
// origin request. NOTE(review): the original author's comment states that
// consuming the leader's body is also what drives the downstream cache write
// (SavepartAsyncReader); that wiring is not visible in this file.
type objectFlightCall struct {
	resp  *http.Response   // origin response; set by the leader before wg is released
	pipes []*io.PipeWriter // one writer per participant (leader first)
	mu    sync.Mutex       // protects pipes
	wg    sync.WaitGroup   // released once resp / err are final
	err   error            // error from fn; set by the leader before wg is released
}

// ObjectFlightGroup collapses concurrent requests for the same key so that
// fn runs once per key while a flight is in progress.
//
// The zero value is ready to use. Every response returned by Do carries a
// body that the caller must read to EOF or close.
type ObjectFlightGroup struct {
	mu sync.Mutex
	m  map[string]*objectFlightCall
}

// Do runs fn at most once per key among concurrent callers and fans the
// response body out to all of them.
//
// The first caller for a key becomes the leader: it sleeps for waiter (when
// waiter > 0) so that late callers can join, then calls fn. Callers that find
// an existing flight for key become followers and block until the leader has
// the response headers or an error.
//
// Returns:
//
//	resp   — a copy of the leader's response (cloned headers) whose Body is
//	         this caller's private pipe reader
//	shared — true if this caller joined an existing flight
//	err    — the error returned by fn, an error describing a panic inside fn,
//	         or an error if fn returned neither a response nor an error
//
// The fan-out writes to the pipes synchronously, so one caller that stops
// reading without closing its body stalls the others. A caller that closes
// its body early is simply dropped from the fan-out; the remaining callers
// keep receiving data.
func (g *ObjectFlightGroup) Do(key string, waiter time.Duration, fn func() (*http.Response, error)) (*http.Response, bool, error) {
	pr, pw := io.Pipe()

	g.mu.Lock()
	if g.m == nil {
		g.m = make(map[string]*objectFlightCall)
	}
	if c, ok := g.m[key]; ok {
		// Follower: register the pipe while still holding g.mu so the leader
		// cannot retire the flight between lookup and registration.
		c.mu.Lock()
		c.pipes = append(c.pipes, pw)
		c.mu.Unlock()
		g.mu.Unlock()

		c.wg.Wait()
		if c.err != nil {
			_ = pw.CloseWithError(c.err)
			return nil, true, c.err
		}

		resp := cloneResponse(c.resp)
		resp.Body = pr
		return resp, true, nil
	}

	// Leader: publish the flight, then fetch.
	c := &objectFlightCall{pipes: []*io.PipeWriter{pw}}
	c.wg.Add(1)
	g.m[key] = c
	g.mu.Unlock()

	if waiter > 0 {
		time.Sleep(waiter)
	}

	resp, err := callObjectFlightFn(fn)
	if err == nil && resp == nil {
		// Without this guard the nil response would be dereferenced below.
		err = fmt.Errorf("object flight %q: fn returned neither a response nor an error", key)
	}

	// Retire the flight. Followers register under g.mu, so once the key is
	// gone the pipe list can no longer grow.
	g.mu.Lock()
	delete(g.m, key)
	g.mu.Unlock()

	c.mu.Lock()
	pipes := c.pipes
	c.pipes = nil
	c.mu.Unlock()

	// Publish the outcome before releasing the followers.
	c.resp, c.err = resp, err
	c.wg.Done()

	if err != nil {
		for _, p := range pipes {
			_ = p.CloseWithError(err)
		}
		return nil, false, err
	}

	go fanOutObjectFlight(resp.Body, pipes)

	leaderResp := cloneResponse(resp)
	leaderResp.Body = pr
	return leaderResp, false, nil
}

// callObjectFlightFn invokes fn and converts a panic into an error so that
// followers blocked on the flight are always released.
func callObjectFlightFn(fn func() (*http.Response, error)) (resp *http.Response, err error) {
	defer func() {
		if rec := recover(); rec != nil {
			resp, err = nil, fmt.Errorf("object flight panic: %v", rec)
		}
	}()
	return fn()
}

// fanOutObjectFlight copies body into every pipe, closes body, and then
// closes the pipes that are still attached: with the copy error if there was
// one, otherwise cleanly so readers observe EOF.
func fanOutObjectFlight(body io.ReadCloser, pipes []*io.PipeWriter) {
	dst := &objectFlightFanOut{pipes: pipes}

	var copyErr error
	if body != nil {
		_, copyErr = io.Copy(dst, body)
		_ = body.Close()
	}

	for _, p := range dst.pipes {
		_ = p.CloseWithError(copyErr) // a nil error is a plain Close
	}
}

// objectFlightFanOut is an io.Writer that duplicates each write to all pipes
// that still have a reader. Unlike io.MultiWriter it does not abort on the
// first failing writer: a pipe whose reader went away is dropped and the
// remaining pipes keep receiving data.
type objectFlightFanOut struct {
	pipes []*io.PipeWriter
}

// Write sends p to every live pipe. It fails with io.ErrClosedPipe only when
// no reader is left, which stops the surrounding io.Copy.
func (f *objectFlightFanOut) Write(p []byte) (int, error) {
	live := f.pipes[:0]
	for _, pw := range f.pipes {
		if _, err := pw.Write(p); err != nil {
			continue // reader closed its end; stop feeding it
		}
		live = append(live, pw)
	}
	f.pipes = live

	if len(live) == 0 {
		return 0, io.ErrClosedPipe
	}
	return len(p), nil
}

// cloneResponse returns a shallow copy of resp with a cloned Header map.
// Body is left nil — the caller sets it to a pipe reader. A nil resp yields
// nil.
func cloneResponse(resp *http.Response) *http.Response {
	if resp == nil {
		return nil
	}
	return &http.Response{
		Status:           resp.Status,
		StatusCode:       resp.StatusCode,
		Proto:            resp.Proto,
		ProtoMajor:       resp.ProtoMajor,
		ProtoMinor:       resp.ProtoMinor,
		Header:           resp.Header.Clone(),
		ContentLength:    resp.ContentLength,
		TransferEncoding: resp.TransferEncoding,
		Close:            resp.Close,
		Uncompressed:     resp.Uncompressed,
		Request:          resp.Request,
		TLS:              resp.TLS,
	}
}
log.Enabled(log.LevelDebug) { - log.Context(ctx).Infof("write inmemory chunk file %s", wpath) + defer func() { + log.Context(ctx).Infof("write inmemory chunk file %s", wpath) + }() } f, err := m.fs.OpenReadWrite(wpath, vfs.WriteCategoryUnspecified) diff --git a/tests/all-features/caching/collapsed_forwarding_test.go b/tests/all-features/caching/collapsed_forwarding_test.go new file mode 100644 index 0000000..c2baccf --- /dev/null +++ b/tests/all-features/caching/collapsed_forwarding_test.go @@ -0,0 +1,462 @@ +package caching + +import ( + "io" + "net/http" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/omalloc/tavern/pkg/e2e" +) + +func TestCollapsedForwardingObjectFlight(t *testing.T) { + f := e2e.GenFile(t, 2<<20) + + t.Run("test Collapsed Forwarding ObjectFlight Collapse", func(t *testing.T) { + + var originCallCount atomic.Int32 + + case1 := e2e.New("http://objflight.example.com/of/object/collapse.bin", e2e.RespCallbackFile(f, func(w http.ResponseWriter, r *http.Request) { + originCallCount.Add(1) + time.Sleep(80 * time.Millisecond) // window for concurrent registrations + + t.Logf("X-Request-Idx: %s", r.Header.Get("X-Request-Idx")) + + w.Header().Set("Cache-Control", "max-age=10") + w.Header().Set("ETag", "obj-flight-etag") + })) + defer case1.Close() + + const N = 5 + var wg sync.WaitGroup + start := make(chan struct{}, N) + bodies := make([]string, N) + codes := make([]int, N) + xCaches := make([]string, N) + + for i := 0; i < N; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + + t.Logf("started for caller %d", idx) + + resp, err := case1.Do(func(r *http.Request) { + r.Header.Set("X-Request-Idx", strconv.Itoa(idx)) + }) + + require.NoError(t, err, "caller %d: request should not error", idx) + defer resp.Body.Close() + + hash := e2e.HashBody(resp) + + bodies[idx] = hash + codes[idx] = resp.StatusCode + xCaches[idx] = 
resp.Header.Get("X-Cache") + }(i) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + // Verify only one origin call — ObjectFlightGroup collapsed all 5. + assert.Equal(t, int32(1), originCallCount.Load(), + "object flight should collapse concurrent full-MISS requests") + + // All callers must receive identical response bodies. + for i := 0; i < N; i++ { + assert.Equal(t, http.StatusOK, codes[i], "caller %d: status mismatch", i) + assert.Equal(t, f.MD5, bodies[i], "caller %d: body mismatch", i) + } + + // At least one should be MISS (the first), the rest may be HIT + // depending on whether they re-looked up metadata in time. + hasMiss := false + for _, c := range xCaches { + if c != "" { + hasMiss = hasMiss || strings.Contains(c, "MISS") + } + } + assert.True(t, hasMiss, "at least one response should report MISS") + }) + + t.Run("PURGE", func(t *testing.T) { + e2e.Purge(t, "http://objflight.example.com/of/object/collapse.bin") + }) + + t.Run("test Collapsed Forwarding ObjectFlight Sequential", func(t *testing.T) { + var originCallCount atomic.Int32 + + originCallCount.Store(0) + + case1 := e2e.New("http://objflight.example.com/of/object/sequential.bin", e2e.RespCallbackFile(f, func(w http.ResponseWriter, r *http.Request) { + originCallCount.Add(1) + + w.Header().Set("Cache-Control", "max-age=10") + w.Header().Set("ETag", "obj-flight-etag") + })) + defer case1.Close() + + const N = 3 + + bodies := make([]string, N) + + // Sequential requests should not be collapsed. 
+ for i := 0; i < N; i++ { + t.Logf("starting request %d", i) + + resp, err := case1.Do(func(r *http.Request) { + r.Header.Set("X-Request-Idx", strconv.Itoa(i)) + }) + + require.NoError(t, err, "request %d should not error", i) + bodies[i] = e2e.HashBody(resp) + + resp.Body.Close() + } + + assert.Equal(t, int32(1), originCallCount.Load(), + "object flight should not collapse sequential requests") + + for i := 0; i < N; i++ { + assert.Equal(t, f.MD5, bodies[i], "request %d body-hash mismatch", i) + } + + }) + + t.Run("PURGE", func(t *testing.T) { + e2e.Purge(t, "http://objflight.example.com/of/object/sequential.bin") + }) + + t.Run("test Collapsed Forwarding ObjectFlight KeyIsolation", func(t *testing.T) { + var originCallCount atomic.Int32 + + case1 := e2e.New("http://keys.example.com/of/object/", func(w http.ResponseWriter, r *http.Request) { + originCallCount.Add(1) + time.Sleep(80 * time.Millisecond) + + w.Header().Set("Cache-Control", "max-age=10") + w.WriteHeader(http.StatusOK) + + _, _ = w.Write([]byte(r.URL.Path)) + }) + defer case1.Close() + + keys := []string{"key-a", "key-b", "key-c"} + + var wg sync.WaitGroup + start := make(chan struct{}, len(keys)) + + for _, key := range keys { + wg.Add(1) + go func(k string) { + defer wg.Done() + <-start + + resp, err := case1.Do(func(r *http.Request) { + r.URL.Path += k + t.Logf("Requesting key: %s", k) + }) + + require.NoError(t, err) + buf, _ := io.ReadAll(resp.Body) + resp.Body.Close() + + assert.Equal(t, "/of/object/"+k, string(buf), "response body should match requested key") + }(key) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + // Three different keys → three independent origin calls. 
+ assert.Equal(t, int32(3), originCallCount.Load(), + "different URLs should have independent object flights") + + }) + + t.Run("PURGE", func(t *testing.T) { + e2e.PurgeMethod(t, "http://keys.example.com/of/object/", true) + }) + +} + +func TestCollapsedForwardingChunkFlight(t *testing.T) { + file := e2e.GenFile(t, 3<<20) // 3MB → 6 chunks at 512KB + + t.Run("test Collapsed Forwarding ChunkFlight", func(t *testing.T) { + var originCallCount atomic.Int32 + + case1 := e2e.New("http://chunkflight.example.com/cf/chunk/collapse.bin", e2e.RespCallbackFile(file, func(w http.ResponseWriter, r *http.Request) { + originCallCount.Add(1) + + w.Header().Set("Cache-Control", "max-age=30") + w.Header().Set("ETag", file.MD5) + })) + defer case1.Close() + + resp, err := case1.Do(func(r *http.Request) { + r.Header.Set("Range", "bytes=0-524287") + }) + + require.NoError(t, err) + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + require.Equal(t, http.StatusPartialContent, resp.StatusCode) + + // Give storage time to finish writing indexdb metadata. + time.Sleep(300 * time.Millisecond) + + // Phase 2 — concurrent requests for a range that needs missing chunks. 
+ originCallCount.Store(0) + + const N = 3 + var wg sync.WaitGroup + start := make(chan struct{}, N) + bodies := make([]string, N) + codes := make([]int, N) + xCaches := make([]string, N) + ranges := make([]string, N) + cls := make([]string, N) + + for i := 0; i < N; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + + t.Logf("started for caller %d", idx) + + resp1, err1 := case1.Do(func(r *http.Request) { + r.Header.Set("Range", "bytes=524288-2097151") + }) + + require.NoError(t, err1, "caller %d: request should not error", idx) + defer resp1.Body.Close() + + hashStr := e2e.HashBody(resp1) + + bodies[idx] = string(hashStr) + codes[idx] = resp1.StatusCode + xCaches[idx] = resp1.Header.Get("X-Cache") + ranges[idx] = resp1.Header.Get("Content-Range") + cls[idx] = strconv.Itoa(int(resp1.ContentLength)) + }(i) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + t.Logf("origin call count for concurrent phase: %d", originCallCount.Load()) + + // All callers must receive correct 206 responses. + for i := 0; i < N; i++ { + assert.Equal(t, http.StatusPartialContent, codes[i], "caller %d: status mismatch", i) + assert.NotEmpty(t, bodies[i], "caller %d: body should not be empty", i) + + t.Logf("caller %d: hash: %s range: %s X-Cache: %s Content-Length: %s", i, bodies[i], ranges[i], xCaches[i], cls[i]) + } + + // Verify body correctness: compare against the source file. + expected := e2e.HashFile(file.Path, 524288, 2097151-524288+1) + for i := 0; i < N; i++ { + actual := bodies[i] + assert.Equal(t, expected, actual, "caller %d: body hash mismatch", i) + } + + // The concurrent chunk fetch for the missing range must be collapsed. 
+ assert.Equal(t, int32(1), originCallCount.Load(), + "chunk flight should collapse concurrent chunk fetches to 1 origin call") + }) + + t.Run("test Collapsed Forwarding ChunkFlight KeyIsolation", func(t *testing.T) { + var originCallCount atomic.Int32 + + case1 := e2e.New("http://chunkflight.example.com/cf/chunk/keys.bin", e2e.RespCallbackFile(file, func(w http.ResponseWriter, r *http.Request) { + t.Logf("process req %s, range %s", r.Header.Get("X-Request-Id"), r.Header.Get("Range")) + originCallCount.Add(1) + + w.Header().Set("Cache-Control", "max-age=30") + w.Header().Set("ETag", file.MD5) + })) + defer case1.Close() + + // Phase 1 — cache only the middle chunk (chunk 1, bytes 524288-1048575). + resp, err := case1.Do(func(r *http.Request) { + r.Header.Set("X-Request-Id", "0") + r.Header.Set("Range", "bytes=524288-1048575") + }) + + require.NoError(t, err) + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + require.Equal(t, http.StatusPartialContent, resp.StatusCode) + + originCallCount.Store(0) + + time.Sleep(300 * time.Millisecond) + + // Phase 2 — request two different missing ranges concurrently. + // Range A: bytes=0-524287 (needs chunk 0, not cached) + // Range B: bytes=1048576-2097151 (needs chunk 2+, not cached) + // With range union, these two ranges collapse into one origin + // fetch for the union range bytes=0-2097151. 
+ var wg sync.WaitGroup + + ranges := []string{"bytes=0-524287", "bytes=1048576-2097151"} + start := make(chan struct{}, len(ranges)) + errs := make([]error, len(ranges)) + xCaches := make([]string, len(ranges)) + + for i, rng := range ranges { + wg.Add(1) + go func(idx int, rng string) { + defer wg.Done() + <-start + + resp2, e := case1.Do(func(r *http.Request) { + t.Logf("started for caller %d with range %s", idx+1, rng) + + r.Header.Set("X-Request-Id", strconv.Itoa(idx+1)) + r.Header.Set("Range", rng) + }) + if e != nil { + t.Logf("caller %d: request error: %v", idx+1, e) + errs[idx] = e + return + } + + io.Copy(io.Discard, resp2.Body) + resp2.Body.Close() + + xCaches[idx] = resp2.Header.Get("X-Cache") + }(i, rng) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + for i, e := range errs { + assert.NoError(t, e, "request %d should not error", i) + t.Logf("caller %d X-Cache: %s", i+1, xCaches[i]) + } + + // With range union, the two different ranges are collapsed into + // one origin fetch covering the union bytes=0-2097151. + assert.Equal(t, int32(1), originCallCount.Load(), + "different byte ranges should be collapsed via range union") + + }) + + t.Run("test Collapsed Forwarding ChunkFlight RangeUnion", func(t *testing.T) { + var originCallCount atomic.Int32 + + case1 := e2e.New("http://chunkflight.example.com/cf/chunk/union.bin", e2e.RespCallbackFile(file, func(w http.ResponseWriter, r *http.Request) { + t.Logf("process req %s, range %s", r.Header.Get("X-Request-Id"), r.Header.Get("Range")) + originCallCount.Add(1) + + w.Header().Set("Cache-Control", "max-age=30") + w.Header().Set("ETag", file.MD5) + })) + defer case1.Close() + + // Phase 1 — cache only the middle chunk (chunk 1, bytes 524288-1048575). 
+ resp, err := case1.Do(func(r *http.Request) { + r.Header.Set("X-Request-Id", "0") + r.Header.Set("Range", "bytes=524288-1048575") + }) + + require.NoError(t, err) + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + require.Equal(t, http.StatusPartialContent, resp.StatusCode) + + originCallCount.Store(0) + + time.Sleep(300 * time.Millisecond) + + // Phase 2 — three concurrent requests for different ranges that + // need missing chunks. With range union, all three share a single + // origin fetch whose range covers the union of all requested ranges. + type reqSpec struct { + id string + rng string + from int + length int + } + specs := []reqSpec{ + {"1", "bytes=0-524287", 0, 524288}, + {"2", "bytes=0-1048575", 0, 1048576}, + {"3", "bytes=1048576-2097151", 1048576, 2097151 - 1048576 + 1}, + } + + var wg sync.WaitGroup + start := make(chan struct{}) + errs := make([]error, len(specs)) + hashes := make([]string, len(specs)) + xCaches := make([]string, len(specs)) + + for i, spec := range specs { + wg.Add(1) + go func(idx int, s reqSpec) { + defer wg.Done() + <-start + + resp2, e := case1.Do(func(r *http.Request) { + r.Header.Set("X-Request-Id", s.id) + r.Header.Set("Range", s.rng) + }) + if e != nil { + t.Logf("caller %s: request error: %v", s.id, e) + errs[idx] = e + return + } + + hashes[idx] = e2e.HashBody(resp2) + resp2.Body.Close() + + xCaches[idx] = resp2.Header.Get("X-Cache") + }(i, spec) + } + + time.Sleep(10 * time.Millisecond) + close(start) + wg.Wait() + + for i, e := range errs { + assert.NoError(t, e, "request %d should not error", i) + t.Logf("caller %s X-Cache: %s hash: %s", specs[i].id, xCaches[i], hashes[i]) + } + + // All three different ranges should collapse into one origin call via + // automatic range union. + assert.Equal(t, int32(1), originCallCount.Load(), + "range union should collapse different ranges into 1 origin call") + + // Each caller must receive the correct bytes for its range. 
+ for i, spec := range specs { + expected := e2e.HashFile(file.Path, spec.from, spec.length) + assert.Equal(t, expected, hashes[i], + "caller %s (range %s): body hash mismatch", spec.id, spec.rng) + } + }) + + t.Run("PURGE", func(t *testing.T) { + e2e.SetDump(true) + e2e.Purge(t, "http://chunkflight.example.com/cf/chunk/collapse.bin") + e2e.Purge(t, "http://chunkflight.example.com/cf/chunk/keys.bin") + e2e.Purge(t, "http://chunkflight.example.com/cf/chunk/union.bin") + }) +} diff --git a/tests/all-features/filechanged/contentlength_test.go b/tests/all-features/filechanged/contentlength_test.go index 77773e1..aeb621d 100644 --- a/tests/all-features/filechanged/contentlength_test.go +++ b/tests/all-features/filechanged/contentlength_test.go @@ -109,7 +109,7 @@ func TestContengLenChanged(t *testing.T) { } func TestContentLenShorter(t *testing.T) { - f := e2e.GenFile(t, 1<<20) + f := e2e.GenFile(t, 1048576) f2 := e2e.GenFile(t, 1048570) t.Run("test content-length shorter old file", func(t *testing.T) { @@ -121,6 +121,7 @@ func TestContentLenShorter(t *testing.T) { rr := xhttp.NewRequestRange(1, 400000) resp, err := case1.Do(func(r *http.Request) { + r.Header.Set("X-Request-Id", t.Name()) r.Header.Set("Range", rr.String()) }) @@ -147,6 +148,7 @@ func TestContentLenShorter(t *testing.T) { rr := xhttp.NewRequestRange(600000, 0) // bytes=600000- resp, err := case1.Do(func(r *http.Request) { + r.Header.Set("X-Request-Id", t.Name()) r.Header.Set("Range", rr.String()) }) @@ -172,6 +174,7 @@ func TestContentLenShorter(t *testing.T) { rr := xhttp.NewRequestRange(0, 0) resp, err := case1.Do(func(r *http.Request) { + r.Header.Set("X-Request-Id", t.Name()) r.Header.Set("Range", rr.String()) }) diff --git a/tests/config.test.yaml b/tests/config.test.yaml index 3970bd5..2fc5f28 100644 --- a/tests/config.test.yaml +++ b/tests/config.test.yaml @@ -74,10 +74,6 @@ plugin: - "localhost" - "@" log_path: ./logs/purge.log - - name: watchdog - options: - check_interval: 1s - 
timeout_threshold: 3 - name: verifier options: endpoint: https://crc-svc.omalloc.com/receive