diff --git a/fingers/cache_test.go b/fingers/cache_test.go new file mode 100644 index 0000000..fe19055 --- /dev/null +++ b/fingers/cache_test.go @@ -0,0 +1,219 @@ +package fingers + +import ( + "fmt" + "io" + "net/http" + "strings" + "sync" + "testing" +) + +func TestCachingSenderOnlyCachesSuccessfulResponses(t *testing.T) { + calls := 0 + sender := cachingSender(func(_ []byte) ([]byte, bool) { + calls++ + if calls == 1 { + return nil, false + } + return []byte("ok"), true + }) + + if _, ok := sender([]byte("/probe")); ok { + t.Fatal("first failed response should not be reported as success") + } + if calls != 1 { + t.Fatalf("calls after first attempt = %d, want 1", calls) + } + + resp, ok := sender([]byte("/probe")) + if !ok || string(resp) != "ok" { + t.Fatalf("second response = %q, %v; want ok,true", resp, ok) + } + if calls != 2 { + t.Fatalf("failed response was cached; calls = %d, want 2", calls) + } + + resp, ok = sender([]byte("/probe")) + if !ok || string(resp) != "ok" { + t.Fatalf("cached response = %q, %v; want ok,true", resp, ok) + } + if calls != 2 { + t.Fatalf("successful response was not cached; calls = %d, want 2", calls) + } +} + +func TestPathCachedTransportSeparatesRequestVariants(t *testing.T) { + base := &recordingRoundTripper{} + transport := &pathCachedTransport{base: base, cache: make(map[string]*pathCachedEntry)} + + first := mustCachedRequest(t, http.MethodGet, "http://example.test/probe?a=1", "", "one") + body := mustRoundTripBody(t, transport, first) + if body != "GET /probe?a=1 one #1" { + t.Fatalf("first body = %q", body) + } + + firstAgain := mustCachedRequest(t, http.MethodGet, "http://example.test/probe?a=1", "", "one") + body = mustRoundTripBody(t, transport, firstAgain) + if body != "GET /probe?a=1 one #1" { + t.Fatalf("cached body = %q", body) + } + + differentQuery := mustCachedRequest(t, http.MethodGet, "http://example.test/probe?a=2", "", "one") + body = mustRoundTripBody(t, transport, differentQuery) + if body != "GET /probe?a=2 one #2" { + t.Fatalf("query-variant body = %q", body) + } + + differentHeader := mustCachedRequest(t, http.MethodGet, "http://example.test/probe?a=1", "", "two") + body = mustRoundTripBody(t, transport, differentHeader) + if body != "GET /probe?a=1 two #3" { + t.Fatalf("header-variant body = %q", body) + } + + postNoBody := mustCachedRequest(t, http.MethodPost, "http://example.test/probe?a=1", "", "one") + body = mustRoundTripBody(t, transport, postNoBody) + if body != "POST /probe?a=1 one #4" { + t.Fatalf("method-variant body = %q", body) + } + + if base.calls != 4 { + t.Fatalf("base calls = %d, want 4", base.calls) + } +} + +func TestPathCachedTransportDoesNotCacheRequestsWithBody(t *testing.T) { + base := &recordingRoundTripper{} + transport := &pathCachedTransport{base: base, cache: make(map[string]*pathCachedEntry)} + + first := mustCachedRequest(t, http.MethodPost, "http://example.test/probe", "x=1", "") + if body := mustRoundTripBody(t, transport, first); body != "POST /probe #1" { + t.Fatalf("first body request = %q", body) + } + second := mustCachedRequest(t, http.MethodPost, "http://example.test/probe", "x=1", "") + if body := mustRoundTripBody(t, transport, second); body != "POST /probe #2" { + t.Fatalf("second body request = %q", body) + } + if base.calls != 2 { + t.Fatalf("base calls = %d, want 2", base.calls) + } +} + +func TestPathCachedTransportReturnsIndependentHeaders(t *testing.T) { + base := &recordingRoundTripper{} + transport := &pathCachedTransport{base: base, cache: make(map[string]*pathCachedEntry)} + + req := mustCachedRequest(t, http.MethodGet, "http://example.test/probe", "", "") + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + resp.Header.Set("X-Call", "mutated") + _ = resp.Body.Close() + + resp, err = transport.RoundTrip(mustCachedRequest(t, http.MethodGet, "http://example.test/probe", "", "")) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if got := resp.Header.Get("X-Call"); got != "1" { + t.Fatalf("cached response header = %q, want 1", got) + } +} + +func TestPathCachedTransportConcurrentAccess(t *testing.T) { + base := &recordingRoundTripper{} + transport := &pathCachedTransport{base: base} + + var wg sync.WaitGroup + errCh := make(chan error, 20) + for i := 0; i < 20; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + req, err := http.NewRequest( + http.MethodGet, + fmt.Sprintf("http://example.test/probe?id=%d", i%5), + nil, + ) + if err != nil { + errCh <- err + return + } + req.Header.Set("X-Mode", fmt.Sprintf("mode-%d", j%3)) + resp, err := transport.RoundTrip(req) + if err != nil { + errCh <- err + return + } + _, err = io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + errCh <- err + return + } + } + }() + } + wg.Wait() + close(errCh) + for err := range errCh { + if err != nil { + t.Fatal(err) + } + } +} + +func mustCachedRequest(t *testing.T, method, rawURL, body, mode string) *http.Request { + t.Helper() + var reader io.Reader + if body != "" { + reader = strings.NewReader(body) + } + req, err := http.NewRequest(method, rawURL, reader) + if err != nil { + t.Fatal(err) + } + if mode != "" { + req.Header.Set("X-Mode", mode) + } + return req +} + +func mustRoundTripBody(t *testing.T, rt http.RoundTripper, req *http.Request) string { + t.Helper() + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + return string(body) +} + +type recordingRoundTripper struct { + mu sync.Mutex + calls int +} + +func (rt *recordingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + rt.mu.Lock() + rt.calls++ + call := rt.calls + rt.mu.Unlock() + + body := fmt.Sprintf("%s %s %s #%d", req.Method, req.URL.RequestURI(), req.Header.Get("X-Mode"), call) + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Proto: "HTTP/1.1", + Header: http.Header{"X-Call": {fmt.Sprint(call)}}, + Body: io.NopCloser(strings.NewReader(body)), + Request: req, + }, nil +} diff --git a/fingers/context_nil_test.go b/fingers/context_nil_test.go new file mode 100644 index 0000000..166ed87 --- /dev/null +++ b/fingers/context_nil_test.go @@ -0,0 +1,52 @@ +package fingers + +import ( + "context" + "testing" + + "github.com/chainreactors/sdk/pkg/types" +) + +func TestContextNormalizesNilContext(t *testing.T) { + if NewContext().WithContext(nil).Context() == nil { + t.Fatal("WithContext(nil) returned nil context") + } + + var ctx *Context + if ctx.Context() == nil { + t.Fatal("nil receiver Context returned nil") + } + if ctx.WithContext(nil).Context() == nil { + t.Fatal("nil receiver WithContext(nil) returned nil context") + } +} + +func TestContextPreservesCancelledContext(t *testing.T) { + base, cancel := context.WithCancel(context.Background()) + cancel() + + if err := NewContext().WithContext(base).Context().Err(); err == nil { + t.Fatal("cancelled context was not preserved") + } +} + +func TestExecuteHandlesTypedNilContext(t *testing.T) { + eng := newDetailTestEngine(t, NewConfig(), &types.Finger{ + Name: "typed-nil-app", + Protocol: "http", + Rules: types.FingerRules{{ + Regexps: &types.FingerRegexps{Body: []string{"TypedNilMarker"}}, + }}, + }) + + var ctx *Context + resultCh, err := eng.Execute(ctx, NewMatchTask(rawHTTP("TypedNilMarker"))) + if err != nil { + t.Fatalf("execute with typed nil context: %v", err) + } + + result := <-resultCh + if result == nil || !result.Success() { + t.Fatalf("expected successful result, got %#v err=%v", result, result.Error()) + } +} diff --git a/fingers/engine.go b/fingers/engine.go index f9db137..abd7f87 100644 --- a/fingers/engine.go +++ b/fingers/engine.go @@ -1,13 +1,18 @@ package fingers import ( + "bytes" "context" "fmt" + "io" "net" "net/http" "net/url" + "sort" "strings" + "sync" "time" + "unicode" "encoding/json" @@ -477,18 +482,21 @@ func (e *Engine) ActiveMatch(baseURL string, level int, transport http.RoundTrip // 1. native fingers 引擎 — 通过 Sender 回调发包 if fEngine := e.engine.Fingers(); fEngine != nil { client := &http.Client{Transport: transport} - sender := fingersEngine.Sender(func(data []byte) ([]byte, bool) { + sender := cachingSender(fingersEngine.Sender(func(data []byte) ([]byte, bool) { sendPath := string(data) if sendPath == "" { sendPath = "/" } + if !strings.HasPrefix(sendPath, "/") { + sendPath = "/" + sendPath + } fullURL := strings.TrimSuffix(baseURL, "/") + sendPath req, err := http.NewRequest(http.MethodGet, fullURL, nil) if err != nil { return nil, false } - req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36") + sdkhttpx.ApplyBrowserProfileHeaders(req.Header) resp, err := client.Do(req) if err != nil { @@ -497,7 +505,7 @@ func (e *Engine) ActiveMatch(baseURL string, level int, transport http.RoundTrip defer resp.Body.Close() return httputils.ReadRaw(resp), true - }) + })) for _, finger := range fEngine.HTTPFingers { frame, vuln, ok := finger.ActiveMatch(level, sender) @@ -559,6 +567,310 @@ func (e *Engine) HTTPMatch(ctx *Context, urls []string) ([]*TargetResult, error) return results, nil } +// HTTPFocusedMatch runs active HTTP template matching with a path-derived +// template subset. It is intended for scanner enrichment paths where the +// caller already has a live URL such as /mas/ or /smartbi/ and must avoid +// executing every template in a large remote corpus. +func (e *Engine) HTTPFocusedMatch(ctx *Context, urls []string) ([]*TargetResult, error) { + if e == nil || e.config == nil || e.config.FullFingers.Len() == 0 { + return nil, nil + } + + var results []*TargetResult + for _, rawURL := range urls { + tokens := uniqueActiveTemplateTokens(append(activeTemplatePathTokens(rawURL), activeTemplateProbeTokens(ctx, rawURL)...)) + if len(tokens) == 0 { + continue + } + focused := e.config.FullFingers.Filter(func(item *FullFinger) bool { + if item == nil || item.Template == nil || item.RawContent == "" { + return false + } + return templateContentMatchesAnyPathToken(item.RawContent, tokens) + }) + if focused.Len() == 0 { + continue + } + focusedEngine, err := NewEngineWithFingers(focused) + if err != nil { + return results, err + } + matches, err := focusedEngine.HTTPMatch(ctx, []string{rawURL}) + if err != nil { + return results, err + } + results = append(results, matches...) + } + return results, nil +} + +func activeTemplateProbeTokens(ctx *Context, rawURL string) []string { + if ctx == nil { + return nil + } + req, err := http.NewRequestWithContext(ctx.Context(), http.MethodGet, rawURL, nil) + if err != nil { + return nil + } + sdkhttpx.ApplyBrowserProfileHeaders(req.Header) + + resp, err := ctx.GetClient().Do(req) + if err != nil { + return nil + } + defer resp.Body.Close() + + var samples []string + for _, key := range []string{"Server", "Content-Type", "X-Powered-By", "WWW-Authenticate", "Location"} { + if value := resp.Header.Get(key); value != "" { + samples = append(samples, value) + } + } + if body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)); err == nil && len(body) > 0 { + samples = append(samples, string(body)) + } + return activeTemplateTokensFromText(strings.Join(samples, "\n")) +} + +func activeTemplatePathTokens(rawURL string) []string { + parsed, err := url.Parse(strings.TrimSpace(rawURL)) + if err != nil { + return nil + } + path := parsed.EscapedPath() + if path == "" || path == "/" { + return nil + } + if unescaped, err := url.PathUnescape(path); err == nil { + path = unescaped + } + + seen := make(map[string]struct{}) + var out []string + add := func(token string) { + token = strings.ToLower(strings.Trim(token, " \t\r\n/._-")) + if token == "" || activeTemplatePathTokenTooBroad(token) { + return + } + if _, ok := seen[token]; ok { + return + } + seen[token] = struct{}{} + out = append(out, token) + } + for _, part := range strings.Split(path, "/") { + add(part) + } + return out +} + +func activeTemplateTokensFromText(text string) []string { + seen := make(map[string]struct{}) + var out []string + var token strings.Builder + flush := func() { + value := strings.ToLower(strings.Trim(token.String(), " \t\r\n/._-")) + token.Reset() + if value == "" || activeTemplatePathTokenTooBroad(value) { + return + } + if _, ok := seen[value]; ok { + return + } + seen[value] = struct{}{} + out = append(out, value) + } + for _, r := range text { + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '-' || r == '_' { + token.WriteRune(r) + continue + } + flush() + if len(out) >= 64 { + return out + } + } + flush() + if len(out) > 64 { + return out[:64] + } + return out +} + +func activeTemplatePathTokenTooBroad(token string) bool { + if len(token) < 3 { + return true + } + switch token { + case "api", "app", "apps", "assets", "console", "css", "dist", "home", "html", "index", "js", "login", "main", "portal", "public", "service", "services", "static", "web": + return true + default: + return false + } +} + +func uniqueActiveTemplateTokens(tokens []string) []string { + seen := make(map[string]struct{}, len(tokens)) + out := make([]string, 0, len(tokens)) + for _, token := range tokens { + token = strings.ToLower(strings.Trim(token, " \t\r\n/._-")) + if token == "" || activeTemplatePathTokenTooBroad(token) { + continue + } + if _, ok := seen[token]; ok { + continue + } + seen[token] = struct{}{} + out = append(out, token) + } + return out +} + +func templateContentMatchesAnyPathToken(rawContent string, tokens []string) bool { + raw := strings.ToLower(rawContent) + for _, token := range tokens { + if token == "" { + continue + } + if strings.Contains(raw, "/"+token+"/") || + strings.Contains(raw, "/"+token) || + strings.Contains(raw, token+"/") { + return true + } + if len(token) >= 5 && strings.Contains(raw, token) { + return true + } + } + return false +} + +// cachingSender wraps a fingersEngine.Sender with a send_data-level cache so +// that multiple fingers probing the same path share a single HTTP round-trip. +func cachingSender(sender fingersEngine.Sender) fingersEngine.Sender { + type cached struct { + resp []byte + } + m := make(map[string]cached) + return func(data []byte) ([]byte, bool) { + key := string(data) + if entry, found := m[key]; found { + return entry.resp, true + } + resp, ok := sender(data) + if ok { + m[key] = cached{resp: resp} + } + return resp, ok + } +} + +// pathCachedTransport wraps an http.RoundTripper with request-level response +// caching so that FingerPrintHub and Xray engines sharing the same instance +// avoid duplicate HTTP requests without conflating distinct probes. +type pathCachedTransport struct { + base http.RoundTripper + mu sync.Mutex + cache map[string]*pathCachedEntry +} + +type pathCachedEntry struct { + resp *http.Response + body []byte +} + +func (t *pathCachedTransport) RoundTrip(req *http.Request) (*http.Response, error) { + key, cacheable := pathCachedTransportKey(req) + if !cacheable { + return t.baseTransport().RoundTrip(req) + } + + t.mu.Lock() + if entry, ok := t.cache[key]; ok { + t.mu.Unlock() + resp := *entry.resp + resp.Header = entry.resp.Header.Clone() + resp.Body = io.NopCloser(bytes.NewReader(entry.body)) + resp.Request = req + return &resp, nil + } + t.mu.Unlock() + + resp, err := t.baseTransport().RoundTrip(req) + if err != nil { + return nil, err + } + + var body []byte + if resp.Body != nil { + body, err = io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, err + } + } + + hdr := *resp + hdr.Body = nil + hdr.Header = resp.Header.Clone() + t.mu.Lock() + if t.cache == nil { + t.cache = make(map[string]*pathCachedEntry) + } + t.cache[key] = &pathCachedEntry{resp: &hdr, body: body} + t.mu.Unlock() + + resp.Body = io.NopCloser(bytes.NewReader(body)) + return resp, nil +} + +func (t *pathCachedTransport) baseTransport() http.RoundTripper { + if t.base != nil { + return t.base + } + return http.DefaultTransport +} + +func pathCachedTransportKey(req *http.Request) (string, bool) { + if req == nil || req.URL == nil { + return "", false + } + if req.Body != nil && req.Body != http.NoBody { + return "", false + } + + method := req.Method + if method == "" { + method = http.MethodGet + } + + var b strings.Builder + b.WriteString(method) + b.WriteByte(' ') + b.WriteString(req.URL.Scheme) + b.WriteString("://") + b.WriteString(req.URL.Host) + b.WriteString(req.URL.RequestURI()) + if req.Host != "" { + b.WriteString("\nhost: ") + b.WriteString(req.Host) + } + + keys := make([]string, 0, len(req.Header)) + for key := range req.Header { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + values := req.Header.Values(key) + b.WriteByte('\n') + b.WriteString(http.CanonicalHeaderKey(key)) + b.WriteString(": ") + b.WriteString(strings.Join(values, "\x00")) + } + + return b.String(), true +} + // scanHTTPTarget 扫描单个 HTTP 目标,自动对所有注册引擎执行主动探测。 func (e *Engine) scanHTTPTarget(ctx *Context, url string, level int) *TargetResult { result := &TargetResult{ @@ -580,6 +892,8 @@ func (e *Engine) scanHTTPTarget(ctx *Context, url string, level int) *TargetResu if transport == nil { transport = http.DefaultTransport } + transport = wrapRedirectResolvingTransport(transport) + transport = &pathCachedTransport{base: transport, cache: make(map[string]*pathCachedEntry)} result.Results = e.ActiveMatch(baseURL, level, transport) return result @@ -778,6 +1092,9 @@ func (e *Engine) Execute(ctx types.Context, task types.Task) (<-chan types.Resul if !ok { return nil, fmt.Errorf("unsupported context type: %T", ctx) } + if runCtx == nil { + runCtx = NewContext() + } } return e.executeMatch(runCtx, matchTask) @@ -898,3 +1215,23 @@ func pathJoin(base, append string) string { return base + append } +type redirectResolvingTransport struct { + base http.RoundTripper +} + +func wrapRedirectResolvingTransport(base http.RoundTripper) http.RoundTripper { + if base == nil { + base = http.DefaultTransport + } + return redirectResolvingTransport{base: base} +} + +func (t redirectResolvingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + client := &http.Client{ + Transport: t.base, + } + clone := req.Clone(req.Context()) + clone.Body = req.Body + clone.GetBody = req.GetBody + return client.Do(clone) +} diff --git a/fingers/sender.go b/fingers/sender.go index a5224a6..29fcb87 100644 --- a/fingers/sender.go +++ b/fingers/sender.go @@ -74,8 +74,7 @@ func (s *DefaultHTTPSender) Send(url string) (*http.Response, error) { return nil, fmt.Errorf("create request failed: %w", err) } - // 设置默认User-Agent - req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36") + sdkhttpx.ApplyBrowserProfileHeaders(req.Header) resp, err := s.client.Do(req) if err != nil { @@ -92,7 +91,7 @@ func (s *DefaultHTTPSender) SendWithMethod(url, method string, body io.Reader) ( return nil, fmt.Errorf("create request failed: %w", err) } - req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36") + sdkhttpx.ApplyBrowserProfileHeaders(req.Header) resp, err := s.client.Do(req) if err != nil { diff --git a/fingers/types.go b/fingers/types.go index 5540db3..0dc1f75 100644 --- a/fingers/types.go +++ b/fingers/types.go @@ -37,10 +37,20 @@ func NewContext() *Context { } } +func normalizeContext(ctx context.Context) context.Context { + if ctx == nil { + return context.Background() + } + return ctx +} + // WithContext 基于给定的 context.Context 复制 Context func (c *Context) WithContext(ctx context.Context) *Context { + if c == nil { + return NewContext().WithContext(ctx) + } return &Context{ - ctx: ctx, + ctx: normalizeContext(ctx), httpSender: c.httpSender, client: c.client, defaultClient: c.defaultClient, @@ -51,6 +61,9 @@ func (c *Context) WithContext(ctx context.Context) *Context { } func (c *Context) Context() context.Context { + if c == nil || c.ctx == nil { + return context.Background() + } return c.ctx } diff --git a/gogo/context_nil_test.go b/gogo/context_nil_test.go new file mode 100644 index 0000000..f67642c --- /dev/null +++ b/gogo/context_nil_test.go @@ -0,0 +1,52 @@ +package gogo + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" +) + +func TestContextNormalizesNilContext(t *testing.T) { + if NewContext().WithContext(nil).Context() == nil { + t.Fatal("WithContext(nil) returned nil context") + } + + var ctx *Context + if ctx.Context() == nil { + t.Fatal("nil receiver Context returned nil") + } + if ctx.WithContext(nil).Context() == nil { + t.Fatal("nil receiver WithContext(nil) returned nil context") + } +} + +func TestContextPreservesCancelledContext(t *testing.T) { + base, cancel := context.WithCancel(context.Background()) + cancel() + + if err := NewContext().WithContext(base).Context().Err(); err == nil { + t.Fatal("cancelled context was not preserved") + } +} + +func TestExecuteHandlesTypedNilContext(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("gogo typed nil")) + })) + defer server.Close() + + host, port := splitTestServerHostPort(t, server.URL) + eng, err := NewEngine(testConfig()) + if err != nil { + t.Fatal(err) + } + + var ctx *Context + resultCh, err := eng.Execute(ctx, NewScanTask(host, port)) + if err != nil { + t.Fatalf("execute with typed nil context: %v", err) + } + for range resultCh { + } +} diff --git a/gogo/gogo.go b/gogo/gogo.go index dfd7d84..57398d0 100644 --- a/gogo/gogo.go +++ b/gogo/gogo.go @@ -11,10 +11,10 @@ import ( "github.com/chainreactors/gogo/v2/engine" "github.com/chainreactors/gogo/v2/pkg" "github.com/chainreactors/logs" - "github.com/chainreactors/utils" sdkfingers "github.com/chainreactors/sdk/fingers" "github.com/chainreactors/sdk/neutron" "github.com/chainreactors/sdk/pkg/types" + "github.com/chainreactors/utils" "github.com/panjf2000/ants/v2" ) @@ -149,6 +149,16 @@ func (e *Engine) applyInjectedFingers() bool { return false } pkg.FingerEngine = fingerImpl + + // 同时注入 FingerprintHubEngine,使 CyberHub 的 fingerprinthub + // 模板在 gogo 的被动和主动指纹匹配阶段都能生效。 + libEngine := e.fingersEngine.Get() + if libEngine != nil { + if fpHub := libEngine.FingerPrintHub(); fpHub != nil { + pkg.FingerprintHubEngine = fpHub + } + } + return true } @@ -206,6 +216,9 @@ func (e *Engine) Execute(ctx types.Context, task types.Task) (<-chan types.Resul if !ok { return nil, fmt.Errorf("unsupported context type: %T", ctx) } + if runCtx == nil { + runCtx = NewContext() + } } switch t := task.(type) { diff --git a/gogo/types.go b/gogo/types.go index c9787fe..f8dd01d 100644 --- a/gogo/types.go +++ b/gogo/types.go @@ -43,10 +43,20 @@ func NewContext() *Context { } } +func normalizeContext(ctx context.Context) context.Context { + if ctx == nil { + return context.Background() + } + return ctx +} + // WithContext 基于给定的 context.Context 复制 Context func (c *Context) WithContext(ctx context.Context) *Context { + if c == nil { + return NewContext().WithContext(ctx) + } return &Context{ - ctx: ctx, + ctx: normalizeContext(ctx), threads: c.threads, mod: c.mod, excludes: c.excludes, @@ -69,6 +79,9 @@ func (c *Context) SetExcludes(excludes ...string) *Context { } func (c *Context) Context() context.Context { + if c == nil || c.ctx == nil { + return context.Background() + } return c.ctx } @@ -100,9 +113,11 @@ func (c *Context) SetStatsHandler(handler func(types.Stats)) *Context { } func (c *Context) emitStats(stats types.Stats) { - if c != nil && c.statsHandler != nil { - c.statsHandler(stats) + // ctx 已取消(consumer 已拆掉它的 channel)时跳过统计回调,避免 send on closed channel panic + if c == nil || c.statsHandler == nil || c.Context().Err() != nil { + return } + c.statsHandler(stats) } // SetVersionLevel 设置指纹识别级别 diff --git a/neutron/context_nil_test.go b/neutron/context_nil_test.go new file mode 100644 index 0000000..a10361e --- /dev/null +++ b/neutron/context_nil_test.go @@ -0,0 +1,65 @@ +package neutron + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/chainreactors/sdk/pkg/types" +) + +func TestContextNormalizesNilContext(t *testing.T) { + if NewContext().WithContext(nil).Context() == nil { + t.Fatal("WithContext(nil) returned nil context") + } + + var ctx *Context + if ctx.Context() == nil { + t.Fatal("nil receiver Context returned nil") + } + if ctx.WithContext(nil).Context() == nil { + t.Fatal("nil receiver WithContext(nil) returned nil context") + } +} + +func TestContextPreservesCancelledContext(t *testing.T) { + base, cancel := context.WithCancel(context.Background()) + cancel() + + if err := NewContext().WithContext(base).Context().Err(); err == nil { + t.Fatal("cancelled context was not preserved") + } +} + +func TestExecuteHandlesTypedNilContext(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("typed nil neutron marker")) + })) + defer server.Close() + + tpl := parseTemplateForTest(t, `id: typed-nil-context +info: + name: typed nil context + severity: info +http: + - method: GET + path: + - "{{BaseURL}}/" + matchers: + - type: word + words: + - "typed nil neutron marker" +`) + eng := &Engine{config: NewConfig()} + eng.templates = eng.compileTemplates([]*types.Template{tpl}) + eng.SetCapacity(1) + + var ctx *Context + resultCh, err := eng.Execute(ctx, NewExecuteTask(server.URL)) + if err != nil { + t.Fatalf("execute with typed nil context: %v", err) + } + for range resultCh { + } +} diff --git a/neutron/engine.go b/neutron/engine.go index 82b171e..ffc2992 100644 --- a/neutron/engine.go +++ b/neutron/engine.go @@ -118,6 +118,9 @@ func (e *Engine) Execute(ctx types.Context, task types.Task) (<-chan types.Resul if !ok { return nil, fmt.Errorf("unsupported context type: %T", ctx) } + if runCtx == nil { + runCtx = NewContext() + } } return e.executeTemplates(runCtx, templates, execTask.Target, execTask.Payload) diff --git a/neutron/types.go b/neutron/types.go index 8ebc749..8a24ef0 100644 --- a/neutron/types.go +++ b/neutron/types.go @@ -40,14 +40,24 @@ func NewContext() *Context { } } +func normalizeContext(ctx context.Context) context.Context { + if ctx == nil { + return context.Background() + } + return ctx +} + // WithContext 基于给定的 context.Context 复制 Context func (c *Context) WithContext(ctx context.Context) *Context { return &Context{ - ctx: ctx, + ctx: normalizeContext(ctx), } } func (c *Context) Context() context.Context { + if c == nil || c.ctx == nil { + return context.Background() + } return c.ctx } diff --git a/pkg/httpx/httpx.go b/pkg/httpx/httpx.go index 311291e..3386e1e 100644 --- a/pkg/httpx/httpx.go +++ b/pkg/httpx/httpx.go @@ -67,5 +67,7 @@ func NewClient(cfg Config) (*http.Client, error) { uc.DialContext = dialer.DialContext } } - return utilshttpx.NewHTTPClient(uc), nil + client := utilshttpx.NewHTTPClient(uc) + wrapProfileTransport(client) + return client, nil } diff --git a/pkg/httpx/profile.go b/pkg/httpx/profile.go new file mode 100644 index 0000000..3a06b25 --- /dev/null +++ b/pkg/httpx/profile.go @@ -0,0 +1,57 @@ +package httpx + +import ( + "net/http" + "strings" +) + +const ( + ProfileBrowser = "browser" + + browserUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36" +) + +var browserProfileHeaders = map[string]string{ + "User-Agent": browserUserAgent, + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8", + "Accept-Language": "zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7", + "Connection": "keep-alive", +} + +func ApplyBrowserProfileHeaders(header http.Header) { + if header == nil { + return + } + for key, value := range browserProfileHeaders { + if strings.TrimSpace(header.Get(key)) == "" { + header.Set(key, value) + } + } +} + +type profileTransport struct { + base http.RoundTripper +} + +func (t *profileTransport) RoundTrip(req *http.Request) (*http.Response, error) { + base := t.base + if base == nil { + base = http.DefaultTransport + } + ApplyBrowserProfileHeaders(req.Header) + return base.RoundTrip(req) +} + +func wrapProfileTransport(client *http.Client) { + if client == nil { + return + } + base := client.Transport + if base == nil { + base = http.DefaultTransport + } + if _, ok := base.(*profileTransport); ok { + return + } + client.Transport = &profileTransport{base: base} +} diff --git a/pkg/httpx/profile_test.go b/pkg/httpx/profile_test.go new file mode 100644 index 0000000..557516a --- /dev/null +++ b/pkg/httpx/profile_test.go @@ -0,0 +1,64 @@ +package httpx + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestProfileTransportDoesNotRetryBlockedStatus(t *testing.T) { + var requests int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + if r.Header.Get("Sec-Fetch-Mode") != "" { + t.Errorf("unexpected secondary profile header: %#v", r.Header) + } + w.WriteHeader(http.StatusForbidden) + fmt.Fprint(w, "blocked") + })) + defer server.Close() + + client, err := NewClient(Config{Timeout: 3 * time.Second, FollowRedirects: false}) + if err != nil { + t.Fatal(err) + } + resp, err := client.Get(server.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("status = %d, want 403", resp.StatusCode) + } + if requests != 1 { + t.Fatalf("requests = %d, want 1", requests) + } +} + +func TestProfileTransportAddsBrowserHeaders(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("User-Agent") == "" || r.Header.Get("Accept") == "" || r.Header.Get("Accept-Language") == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewClient(Config{Timeout: 3 * time.Second}) + if err != nil { + t.Fatal(err) + } + resp, err := client.Get(server.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } +} diff --git a/pkg/types/spray.go b/pkg/types/spray.go index f6828f7..d9db75b 100644 --- a/pkg/types/spray.go +++ b/pkg/types/spray.go @@ -15,6 +15,12 @@ func NewDefaultSprayOption() *SprayOption { opt.PortRange = "80,443" opt.MaxBodyLength = 100 opt.RandomUserAgent = false + opt.Headers = []string{ + "User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36", + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", + "Accept-Language: en-US,en;q=0.9", + "Connection: close", + } // Status defaults. opt.BlackStatus = "400,410" diff --git a/proton/context_nil_test.go b/proton/context_nil_test.go new file mode 100644 index 0000000..b3a2a1d --- /dev/null +++ b/proton/context_nil_test.go @@ -0,0 +1,43 @@ +package proton + +import ( + "context" + "testing" +) + +func TestContextNormalizesNilContext(t *testing.T) { + if NewContext().WithContext(nil).Context() == nil { + t.Fatal("WithContext(nil) returned nil context") + } + + var ctx *Context + if ctx.Context() == nil { + t.Fatal("nil receiver Context returned nil") + } + if ctx.WithContext(nil).Context() == nil { + t.Fatal("nil receiver WithContext(nil) returned nil context") + } +} + +func TestContextPreservesCancelledContext(t *testing.T) { + base, cancel := context.WithCancel(context.Background()) + cancel() + + if err := NewContext().WithContext(base).Context().Err(); err == nil { + t.Fatal("cancelled context was not preserved") + } +} + +func TestExecuteHandlesTypedNilContext(t *testing.T) { + eng := mustEngine(t, NewConfig(). + WithCapacity(1). + WithTemplatePaths(writeTempTemplate(t, tmplPrivateKey))) + + var ctx *Context + resultCh, err := eng.Execute(ctx, NewScanDataTask([]byte("PRIVATE KEY\n"), "test.txt")) + if err != nil { + t.Fatalf("execute with typed nil context: %v", err) + } + for range resultCh { + } +} diff --git a/proton/engine.go b/proton/engine.go index 58c06ac..43a6164 100644 --- a/proton/engine.go +++ b/proton/engine.go @@ -88,6 +88,9 @@ func (e *Engine) Execute(ctx types.Context, task types.Task) (<-chan types.Resul if !ok { return nil, fmt.Errorf("unsupported context type: %T", ctx) } + if runCtx == nil { + runCtx = NewContext() + } } switch t := task.(type) { @@ -178,10 +181,10 @@ func (e *Engine) executeScanData(ctx *Context, task *ScanDataTask) (<-chan types } ctx.emitStats(types.Stats{ - Engine: e.Name(), - Task: task.Type(), - Targets: 1, - Results: findingCount, + Engine: e.Name(), + Task: task.Type(), + Targets: 1, + Results: findingCount, Duration: time.Since(started), }) }() diff --git a/proton/types.go b/proton/types.go index b4668ed..ae38028 100644 --- a/proton/types.go +++ b/proton/types.go @@ -35,9 +35,19 @@ func NewContext() *Context { } } +func normalizeContext(ctx context.Context) context.Context { + if ctx == nil { + return context.Background() + } + return ctx +} + func (c *Context) WithContext(ctx context.Context) *Context { + if c == nil { + return NewContext().WithContext(ctx) + } return &Context{ - ctx: ctx, + ctx: normalizeContext(ctx), statsHandler: c.statsHandler, } } @@ -48,13 +58,17 @@ func (c *Context) SetStatsHandler(handler func(types.Stats)) *Context { } func (c *Context) Context() context.Context { + if c == nil || c.ctx == nil { + return context.Background() + } return c.ctx } func (c *Context) emitStats(stats types.Stats) { - if c != nil && c.statsHandler != nil { - c.statsHandler(stats) + if c == nil || c.statsHandler == nil || c.Context().Err() != nil { + return } + c.statsHandler(stats) } // ======================================== diff --git a/spray/context_nil_test.go b/spray/context_nil_test.go new file mode 100644 index 0000000..e1cf931 --- /dev/null +++ b/spray/context_nil_test.go @@ -0,0 +1,51 @@ +package spray + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" +) + +func TestContextNormalizesNilContext(t *testing.T) { + if NewContext().WithContext(nil).Context() == nil { + t.Fatal("WithContext(nil) returned nil context") + } + + var ctx *Context + if ctx.Context() == nil { + t.Fatal("nil receiver Context returned nil") + } + if ctx.WithContext(nil).Context() == nil { + t.Fatal("nil receiver WithContext(nil) returned nil context") + } +} + +func TestContextPreservesCancelledContext(t *testing.T) { + base, cancel := context.WithCancel(context.Background()) + cancel() + + if err := NewContext().WithContext(base).Context().Err(); err == nil { + t.Fatal("cancelled context was not preserved") + } +} + +func TestExecuteHandlesTypedNilContext(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("spray typed nil")) + })) + defer server.Close() + + eng, err := NewEngine(nil) + if err != nil { + t.Fatal(err) + } + + var ctx *Context + resultCh, err := eng.Execute(ctx, NewCheckTask([]string{server.URL})) + if err != nil { + t.Fatalf("execute with typed nil context: %v", err) + } + for range resultCh { + } +} diff --git a/spray/spray.go b/spray/spray.go index e3a8fa4..c5c4d4c 100644 --- a/spray/spray.go +++ b/spray/spray.go @@ -1,7 +1,6 @@ package spray import ( - "context" "fmt" "sync" "time" @@ -138,6 +137,9 @@ func (e *Engine) applyInjectedFingers() bool { } pkg.FingerEngine = libEngine pkg.ActivePath = pkg.ActivePath[:0] + // 当注入了自定义 fingers 引擎(含 fingerprinthub 等子引擎)时, + // 强制开启多引擎指纹识别,确保 CyberHub 的 fingerprinthub 模板被动匹配生效。 + pkg.EnableAllFingerEngine = true logs.Log.Infof("resources type=fingers source=custom %s", libEngine.String()) e.refreshActivePath() return true @@ -248,14 +250,19 @@ func newResult(success bool, err error, data *types.SprayResult) types.Result { return types.NewResult(success, err, data) } -func (e *Engine) handler(ctx context.Context, runner *core.Runner, ch chan types.Result) { +func (e *Engine) handler(ctx *Context, runner *core.Runner, ch chan types.Result) { + var done <-chan struct{} + if ctx != nil && ctx.Context() != nil { + done = ctx.Context().Done() + } // 启动结果处理 goroutine - 处理 OutputCh go func() { for bl := range runner.OutputCh { + result := bl.SprayResult select { - case ch <- newResult(bl.IsValid, nil, bl.SprayResult): + case ch <- newResult(bl.IsValid, nil, result): runner.OutWg.Done() - case <-ctx.Done(): + case <-done: runner.OutWg.Done() continue } @@ -265,10 +272,11 @@ func (e *Engine) handler(ctx context.Context, runner *core.Runner, ch chan types // 启动结果处理 goroutine - 处理 FuzzyCh go func() { for bl := range runner.FuzzyCh { + result := bl.SprayResult select { - case ch <- newResult(bl.IsValid, nil, bl.SprayResult): + case ch <- newResult(bl.IsValid, nil, result): runner.OutWg.Done() - case <-ctx.Done(): + case <-done: runner.OutWg.Done() continue } @@ -276,6 +284,45 @@ func (e *Engine) handler(ctx context.Context, runner *core.Runner, ch chan types }() } +func (e *Engine) mergeActiveFingers(ctx *Context, result *types.SprayResult) *types.SprayResult { + if e == nil || e.fingersEngine == nil || result == nil || result.UrlString == "" { + return result + } + if !result.IsValid && result.Status == 0 { + return result + } + fctx := sdkfingers.NewContext().WithLevel(1) + if ctx != nil { + fctx = fctx.WithContext(ctx.Context()) + if ctx.opt != nil { + if ctx.opt.Timeout > 0 { + fctx = fctx.WithTimeout(ctx.opt.Timeout) + } + if len(ctx.opt.Proxies) > 0 { + fctx = fctx.WithProxy(ctx.opt.Proxies[0]) + } + } + } + matches, err := e.fingersEngine.HTTPMatch(fctx, []string{result.UrlString}) + if err != nil { + return result + } + if result.Frameworks == nil { + result.Frameworks = make(types.Frameworks) + } + for _, target := range matches { + if target == nil { + continue + } + for _, match := range target.Results { + if match != nil && match.Framework != nil { + result.Frameworks.Add(match.Framework) + } + } + } + return result +} + func (e *Engine) executeCheck(ctx *Context, task *CheckTask) (<-chan types.Result, error) { return e.execute(ctx, task.Type(), task.URLs, nil) } @@ -314,6 +361,12 @@ func (e *Engine) execute(ctx *Context, taskType string, urls []string, wordlist return nil, fmt.Errorf("create runner failed: %v", err) } + // NewRunner 内部的 Prepare → LoadDynamicFingers 会重建 pkg.FingerEngine + // 为只含原生 fingers 的引擎,丢失 fingerprinthub 等子引擎。 + // 在此处重新注入 SDK 构建的完整多引擎,确保 CyberHub fingerprinthub + // 模板在被动匹配阶段生效。 + e.applyInjectedFingers() + if wordlist != nil { runner.Wordlist = wordlist runner.Total = len(wordlist) @@ -341,7 +394,7 @@ func (e *Engine) execute(ctx *Context, taskType string, urls []string, wordlist Duration: time.Since(started), }) }() - e.handler(ctx.Context(), runner, ch) + e.handler(ctx, runner, ch) if runner.IsCheck { runner.RunWithCheck(ctx.Context()) diff --git a/spray/types.go b/spray/types.go index 50dc117..471f5ef 100644 --- a/spray/types.go +++ b/spray/types.go @@ -29,16 +29,29 @@ func NewContext() *Context { } } +func normalizeContext(ctx context.Context) context.Context { + if ctx == nil { + return context.Background() + } + return ctx +} + // WithContext 基于给定的 context.Context 复制 Context func (c *Context) WithContext(ctx context.Context) *Context { + if c == nil { + return NewContext().WithContext(ctx) + } return &Context{ - ctx: ctx, + ctx: normalizeContext(ctx), opt: cloneOption(c.opt), statsHandler: c.statsHandler, } } func (c *Context) Context() context.Context { + if c == nil || c.ctx == nil { + return context.Background() + } return c.ctx } @@ -110,7 +123,8 @@ func (c *Context) SetStatsHandler(handler func(types.Stats)) *Context { } func (c *Context) emitStats(stats types.Stats) { - if c == nil || c.statsHandler == nil || c.ctx.Err() != nil { + // ctx 已取消(consumer 已拆掉它的 channel)时跳过统计回调,避免 send on closed channel panic + if c == nil || c.statsHandler == nil || c.Context().Err() != nil { return } c.statsHandler(stats) diff --git a/zombie/context_nil_test.go b/zombie/context_nil_test.go new file mode 100644 index 0000000..7eae3cb --- /dev/null +++ b/zombie/context_nil_test.go @@ -0,0 +1,47 @@ +package zombie + +import ( + "context" + "testing" +) + +func TestContextNormalizesNilContext(t *testing.T) { + if NewContext().WithContext(nil).Context() == nil { + t.Fatal("WithContext(nil) returned nil context") + } + + var ctx *Context + if ctx.Context() == nil { + t.Fatal("nil receiver Context returned nil") + } + if ctx.WithContext(nil).Context() == nil { + t.Fatal("nil receiver WithContext(nil) returned nil context") + } +} + +func TestContextPreservesCancelledContext(t *testing.T) { + base, cancel := context.WithCancel(context.Background()) + cancel() + + if err := NewContext().WithContext(base).Context().Err(); err == nil { + t.Fatal("cancelled context was not preserved") + } +} + +func TestExecuteHandlesTypedNilContext(t *testing.T) { + eng, err := NewEngine(nil) + if err != nil { + t.Fatal(err) + } + + task := NewBruteTask([]Target{{IP: "127.0.0.1", Port: "1", Service: "redis"}}) + task.Passwords = []string{"x"} + + var ctx *Context + resultCh, err := eng.Execute(ctx, task) + if err != nil { + t.Fatalf("execute with typed nil context: %v", err) + } + for range resultCh { + } +} diff --git a/zombie/engine.go b/zombie/engine.go index 2e14cb6..15f1237 100644 --- a/zombie/engine.go +++ b/zombie/engine.go @@ -79,6 +79,9 @@ func (e *Engine) Execute(ctx types.Context, task types.Task) (<-chan types.Resul if !ok { return nil, fmt.Errorf("unsupported context type: %T", ctx) } + if runCtx == nil { + runCtx = NewContext() + } } switch t := task.(type) { diff --git a/zombie/types.go b/zombie/types.go index ede5694..ec6da2c 100644 --- a/zombie/types.go +++ b/zombie/types.go @@ -23,9 +23,19 @@ func NewContext() *Context { } } +func normalizeContext(ctx context.Context) context.Context { + if ctx == nil { + return context.Background() + } + return ctx +} + func (c *Context) WithContext(ctx context.Context) *Context { + if c == nil { + return NewContext().WithContext(ctx) + } return &Context{ - ctx: ctx, + ctx: normalizeContext(ctx), opt: types.CloneZombieOption(c.opt), statsHandler: c.statsHandler, proxy: c.proxy, @@ -41,6 +51,9 @@ func (c *Context) SetProxy(proxies ...string) *Context { } func (c *Context) Context() context.Context { + if c == nil || c.ctx == nil { + return context.Background() + } return c.ctx } @@ -86,9 +99,11 @@ func (c *Context) SetStatsHandler(handler func(types.Stats)) *Context { } func (c *Context) emitStats(stats types.Stats) { - if c != nil && c.statsHandler != nil { - c.statsHandler(stats) + // ctx 已取消(consumer 已拆掉它的 channel)时跳过统计回调,避免 send on closed channel panic + if c == nil || c.statsHandler == nil || c.Context().Err() != nil { + return } + c.statsHandler(stats) } type Config struct {