Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
421 changes: 421 additions & 0 deletions docs/storage-select-flow.md

Large diffs are not rendered by default.

22 changes: 16 additions & 6 deletions pkg/e2e/e2e.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,24 @@ func (e *E2E) Do(rewrite func(r *http.Request)) (*http.Response, error) {
// wait briefly so the server has time to become ready
time.Sleep(time.Millisecond * 100)

rewrite(e.req)
nr := e.req.Clone(context.Background())

method := e.req.Method
rewrite(nr)

e.req.Header.Set(protocol.InternalUpstreamAddr, e.ts.Listener.Addr().String())
method := nr.Method

nr.Header.Set(protocol.InternalUpstreamAddr, e.ts.Listener.Addr().String())

if dumpReq.Load() && method != "PURGE" {
DumpReq(e.req)
DumpReq(nr)
}

if manual.Load() {
fmt.Printf("manual mode wait 20s, src addr %q\n", e.ts.Listener.Addr().String())
time.Sleep(time.Second * 20)
}

resp, err := e.cs.Do(e.req)
resp, err := e.cs.Do(nr)
e.resp = resp
e.err = err

Expand Down Expand Up @@ -172,11 +174,15 @@ func DumpResp(resp *http.Response) {
fmt.Println(string(buf))
}

func Purge(t *testing.T, url string) {
func PurgeMethod(t *testing.T, url string, dir bool) {
resp, err := New(url, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
}).Do(func(r *http.Request) {
r.Method = "PURGE"
r.Header.Set("Purge-Type", "file,hard")
if dir {
r.Header.Set("Purge-Type", "dir,hard")
}
})

assert.NoError(t, err, "purge should not error")
Expand All @@ -187,3 +193,7 @@ func Purge(t *testing.T, url string) {
t.Logf("Purge %s success", url)
}
}

// Purge is the convenience form of PurgeMethod for a single URL: it
// forwards t and url unchanged and passes false for the dir flag.
func Purge(t *testing.T, url string) {
	const dir = false
	PurgeMethod(t, url, dir)
}
145 changes: 99 additions & 46 deletions server/middleware/caching/caching.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/kelindar/bitmap"
"github.com/prometheus/client_golang/prometheus"

"github.com/omalloc/tavern/api/defined/v1/event"
configv1 "github.com/omalloc/tavern/api/defined/v1/middleware"
Expand All @@ -25,8 +26,6 @@ import (
"github.com/omalloc/tavern/proxy"
"github.com/omalloc/tavern/server/middleware"
storagev1 "github.com/omalloc/tavern/storage"

"github.com/prometheus/client_golang/prometheus"
)

const BYPASS = "BYPASS"
Expand Down Expand Up @@ -114,6 +113,12 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) {
proxyClient := proxy.GetProxy()
store := storagev1.Current()

// Flight groups for collapsed forwarding at object and chunk level.
// These mirror Squid's collapsed_forwarding: one origin request
// serves many waiting clients.
objectFlight := &ObjectFlightGroup{}
chunkFlight := &ChunkFlightGroup{}

return middleware.RoundTripperFunc(func(req *http.Request) (resp *http.Response, err error) {
// only cache GET/HEAD request
if req.Method != http.MethodGet && req.Method != http.MethodHead {
Expand All @@ -132,9 +137,20 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) {
// cachingPool.Put(caching)
//}()

// Wire up chunk-level collapsed forwarding.
caching.chunkFlight = chunkFlight

// Increment cacheRequestTotal on every response path (BYPASS,
// HIT, MISS-collapsed, MISS-direct). Placed as a defer so it
// fires after cacheStatus is finalized.
defer func() {
cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc()
}()

// err to BYPASS caching
if err != nil {
caching.log.Warnf("Precache processor failed: %v, BYPASS", err)
caching.log.Warnf("Precache processor failed: %v BYPASS", err)
caching.cacheStatus = storage.BYPASS
resp, err = caching.doProxy(req, false) // do reverse proxy
if err != nil {
return nil, err
Expand All @@ -144,57 +160,70 @@ func Middleware(c *configv1.Middleware) (middleware.Middleware, func(), error) {
// set cache-status header BYPASS
resp.Header.Set(protocol.ProtocolCacheStatusKey, BYPASS)
}
cacheRequestTotal.WithLabelValues(storage.BYPASS.String(), caching.bucket.StoreType()).Inc()
return
}

// cache HIT
if caching.hit {
caching.cacheStatus = storage.CacheHit

rng, err1 := xhttp.SingleRange(req.Header.Get("Range"), caching.md.Size)
if err1 != nil {
// 无效 Range 处理
headers := make(http.Header)
xhttp.CopyHeader(caching.md.Headers, headers)
headers.Set("Content-Range", fmt.Sprintf("bytes */%d", caching.md.Size))
cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc()
return nil, xhttp.NewBizError(http.StatusRequestedRangeNotSatisfiable, headers)
}

// mark cache status with Range requests.
caching.markCacheStatus(rng.Start, rng.End)
return caching.respondFromCache(req)
}

// find file seek(start, end)
resp, err = caching.lazilyRespond(req, rng.Start, rng.End)
if err != nil {
// fd leak
closeBody(resp)
cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc()
return nil, err
// full MISS — use object-level collapsed forwarding so that
// concurrent requests for the same cache object share one
// origin fetch (Squid-style collapsed_forwarding).
if opts.CollapsedRequest {
flightResp, _, flightErr := objectFlight.Do(caching.id.HashStr(), opts.CollapsedRequestWaitTimeout.AsDuration(), func() (*http.Response, error) {
r, e := caching.doProxy(req, false)
if e != nil {
return nil, e
}
return processor.postCacheProcessor(caching, req, r)
})
if flightErr != nil {
return nil, flightErr
}

// response now
resp, err = caching.processor.postCacheProcessor(caching, req, resp)
cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc()
resp = flightResp
return
}

// full MISS
// full MISS (collapsed forwarding disabled)
resp, err = caching.doProxy(req, false)
if err != nil {
cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc()
return nil, err
}

resp, err = processor.postCacheProcessor(caching, req, resp)
cacheRequestTotal.WithLabelValues(caching.cacheStatus.String(), caching.bucket.StoreType()).Inc()
return
})

}, middleware.EmptyCleanup, nil
}

// respondFromCache builds the response for a request that was found in
// the cache.
//
// It marks the request as a cache hit, resolves the request's Range
// header against the stored object size, and then asks lazilyRespond for
// a response covering that range. The result is passed through the
// post-cache processor before being returned.
//
// When the Range header cannot be resolved, no response is returned;
// instead the error is a 416 BizError carrying a copy of the stored
// headers plus a "Content-Range: bytes */<size>" entry.
func (c *Caching) respondFromCache(req *http.Request) (*http.Response, error) {
	c.cacheStatus = storage.CacheHit

	byteRange, rangeErr := xhttp.SingleRange(req.Header.Get("Range"), c.md.Size)
	if rangeErr != nil {
		// Range could not be resolved: answer 416 with the stored headers
		// and the total size advertised in Content-Range.
		unsatisfiable := make(http.Header)
		xhttp.CopyHeader(c.md.Headers, unsatisfiable)
		unsatisfiable.Set("Content-Range", fmt.Sprintf("bytes */%d", c.md.Size))
		return nil, xhttp.NewBizError(http.StatusRequestedRangeNotSatisfiable, unsatisfiable)
	}

	// Record the cache status for the requested byte window.
	c.markCacheStatus(byteRange.Start, byteRange.End)

	cached, readErr := c.lazilyRespond(req, byteRange.Start, byteRange.End)
	if readErr != nil {
		// Release whatever body may have been opened before failing.
		closeBody(cached)
		return nil, readErr
	}

	return c.processor.postCacheProcessor(c, req, cached)
}

func (c *Caching) lazilyRespond(req *http.Request, start, end int64) (*http.Response, error) {
// 这里通过缓存的块大小来计算,而不是配置默认的 SliceSize
// 这样已缓存的对象可以使用原来的配置块大小,不受配置文件变更影响
Expand Down Expand Up @@ -261,39 +290,63 @@ func (c *Caching) lazilyRespond(req *http.Request, start, end int64) (*http.Resp
func (c *Caching) getUpstreamReader(fromByte, toByte uint64, async bool) (io.ReadCloser, error) {
// get from origin request header
rawRange := c.req.Header.Get("Range")
newRange := fmt.Sprintf("bytes=%d-%d", fromByte, toByte)
req := c.req.Clone(context.Background())
req.Header.Set("Range", newRange)
// add request-id [range]
// req.Header.Set("X-Request-ID", fmt.Sprintf("%s-%d", req.Header.Get(appctx.ProtocolRequestIDKey), fromByte)) // 附加 Request-ID suffix

// remove all internal header
req.Header.Del(protocol.ProtocolCacheStatusKey)
// doSubRequest is parameterized by the union range so the flight group
// can expand the range to cover multiple concurrent callers.
doSubRequest := func(unionFrom, unionTo uint64) (*http.Response, error) {
newRange := fmt.Sprintf("bytes=%d-%d", unionFrom, unionTo)
subReq := c.req.Clone(context.Background())
subReq.Header.Set("Range", newRange)
subReq.Header.Del(protocol.ProtocolCacheStatusKey)

doSubRequest := func() (*http.Response, error) {
now := time.Now()
c.log.Debugf("getUpstreamReader doProxy[chunk]: begin: %s, rawRange: %s, newRange: %s", now, rawRange, newRange)
resp, err := c.doProxy(req, true)
resp, err := c.doProxy(subReq, true)
c.log.Infof("getUpstreamReader doProxy[chunk]: timeCost: %s, rawRange: %s, newRange: %s", time.Since(now), rawRange, newRange)
if err != nil {
closeBody(resp)
return nil, err
}
// 部分命中
c.cacheStatus = storage.CachePartHit
// 发起的是 206 请求,但是返回的非 206
// 206 Partial Content is expected for range requests.
// If we get a different status code, it may indicate an issue with the upstream response.
if resp.StatusCode != http.StatusPartialContent {
c.log.Warnf("getUpstreamReader doProxy[chunk]: status code: %d, bod size: %d", resp.StatusCode, resp.ContentLength)
return resp, xhttp.NewBizError(resp.StatusCode, resp.Header)
}
return resp, nil
}

// Chunk-level collapsed forwarding: if another goroutine is already
// fetching (possibly a different) byte range for this object, wait
// and share the union response body (io.MultiWriter fan-out + RangeReader
// trimming). This mirrors Squid's collapsed_forwarding at the chunk
// level, with automatic range union.
if c.chunkFlight != nil && c.opt.CollapsedRequest && c.id != nil {
reader, _, err := c.chunkFlight.Do(c.id.HashStr(), fromByte, toByte,
c.opt.CollapsedRequestWaitTimeout.AsDuration(), doSubRequest)
// Both leader and shared callers are partial hits: at least one
// chunk was missing and is being fetched from origin.

// c.cacheStatus = storage.CachePartMiss
// if shared {
// c.cacheStatus = storage.CachePartHit
// }
return reader, err
}

// Any path that reaches getUpstreamReader is a partial hit: at least
// one chunk was missing from cache and is now being fetched from
// origin. Set the status here for the async and sync paths (the
// chunk-flight path above already sets it before returning).
// c.cacheStatus = storage.CachePartMiss

if async {
return iobuf.AsyncReadCloser(doSubRequest), nil
return iobuf.AsyncReadCloser(func() (*http.Response, error) {
return doSubRequest(fromByte, toByte)
}), nil
}

resp, err := doSubRequest()
resp, err := doSubRequest(fromByte, toByte)
if resp != nil {
return resp.Body, err
}
Expand Down
16 changes: 15 additions & 1 deletion server/middleware/caching/caching_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
package caching

import "testing"
import (
"net/http"
"strings"
"testing"
)

// BenchmarkWithPooling measures one acquire/reset/release cycle of a
// *Caching taken from cachingPool.
func BenchmarkWithPooling(b *testing.B) {
	for i := 0; i < b.N; i++ {
		c := cachingPool.Get().(*Caching)
		c.reset()
		// Hand the object back. Without this nothing is ever returned to
		// the pool, so no iteration can reuse a previously obtained object
		// and the benchmark would not exercise pooling at all.
		// NOTE(review): assumes cachingPool is a sync.Pool — confirm.
		cachingPool.Put(c)
	}
}

// TestObjectFlight_PanicRecovery verifies that when the function handed
// to ObjectFlightGroup.Do panics, Do reports it to the caller as a
// non-nil error whose message contains "panic".
func TestObjectFlight_PanicRecovery(t *testing.T) {
	group := &ObjectFlightGroup{}

	// The fetch function never returns normally; it always panics.
	panicking := func() (*http.Response, error) {
		panic("boom")
	}

	_, _, gotErr := group.Do("key", 0, panicking)
	if gotErr != nil && strings.Contains(gotErr.Error(), "panic") {
		return
	}
	t.Fatalf("expected panic error, got %v", gotErr)
}
Loading
Loading