diff --git a/src/github.go b/src/github.go index d00f8f1..699ff4e 100644 --- a/src/github.go +++ b/src/github.go @@ -7,7 +7,6 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" - "io" "net/http" "strings" "time" @@ -127,81 +126,7 @@ func validateHMAC(body []byte, signature string, secret []byte) bool { var deliveryDeduperCache = newDeliveryDeduper(defaultDeliveryRetention, defaultDeliveryCacheEntries) func githubEventsHandler(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Unable to read request body", http.StatusInternalServerError) - logger.Error("Unable to read request body", zap.Error(err)) - return - } - - signature := r.Header.Get("X-Hub-Signature-256") - if !validateHMAC(body, signature, githubWebhookSecret) { - http.Error(w, "Invalid signature", http.StatusUnauthorized) - logger.Error("Invalid signature") - return - } - - ctx := r.Context() - eventType := r.Header.Get("X-GitHub-Event") - deliveryID := strings.TrimSpace(r.Header.Get("X-GitHub-Delivery")) - duplicate, duplicateErr := markDuplicateDelivery(ctx, eventType, deliveryID) - if duplicateErr != nil { - http.Error(w, "Unable to record webhook delivery", http.StatusInternalServerError) - logger.Error("Unable to record webhook delivery", zap.String("deliveryID", deliveryID), zap.Error(duplicateErr)) - return - } - if duplicate { - w.WriteHeader(http.StatusOK) - return - } - - if eventProcessor != nil { - if err := eventProcessor.Enqueue(ctx, eventType, body); err != nil { - http.Error(w, "Webhook queue is full", http.StatusServiceUnavailable) - logger.Warn("Dropping webhook event because queue is full", zap.String("eventType", eventType), zap.Error(err)) - return - } - w.WriteHeader(http.StatusAccepted) - return - } - - switch eventType { - case "workflow_run": - updateWorkflowMetrics(ctx, body) - case "workflow_job": - updateJobMetrics(ctx, body) - case "push": - updateCommitMetrics(body) - case "pull_request": - updatePullRequestMetrics(body) - default: - logger.Warn("Invalid GitHub event type", zap.String("eventType", eventType)) - } - - w.WriteHeader(http.StatusOK) -} - -func markDuplicateDelivery(ctx context.Context, eventType, deliveryID string) (bool, error) { - if deliveryID == "" { - return false, nil - } - - if stateStore != nil { - processed, err := stateStore.MarkDeliveryProcessed(ctx, deliveryID) - if err != nil { - return false, err - } - if processed { - return false, nil - } - } else if !deliveryDeduperCache.SeenBefore(deliveryID, time.Now()) { - return false, nil - } - - duplicateDeliveriesSeenCounter.WithLabelValues(eventType).Inc() - duplicateDeliveriesDroppedCounter.WithLabelValues(eventType).Inc() - logger.Info("Skipping duplicate GitHub delivery", zap.String("deliveryID", deliveryID), zap.String("eventType", eventType)) - return true, nil + webhookHTTPHandler(newDefaultWebhookIngestion(), logger).ServeHTTP(w, r) } func normalizeRunState(details runMetricDetails) RunState { diff --git a/src/webhook_ingestion.go b/src/webhook_ingestion.go new file mode 100644 index 0000000..a3679b5 --- /dev/null +++ b/src/webhook_ingestion.go @@ -0,0 +1,225 @@ +package main + +import ( + "context" + "io" + "net/http" + "strings" + "time" + + "go.uber.org/zap" +) + +type webhookIngestionRequest struct { + Body []byte + Signature string + EventType string + DeliveryID string +} + +type webhookIngestionResult struct { + StatusCode int + ErrorMessage string +} + +type webhookAcceptor interface { + Accept(context.Context, webhookIngestionRequest) webhookIngestionResult +} + +type webhookDeliveryStore interface { + MarkDeliveryProcessed(ctx context.Context, deliveryID string) (bool, error) +} + +type webhookEventQueue interface { + Enqueue(ctx context.Context, eventType string, body []byte) error +} + +type webhookEventDispatcher interface { + Dispatch(ctx context.Context, eventType string, body []byte) bool +} + +type webhookIngestionMetrics interface { + RecordDuplicateDelivery(eventType string) +} + +const ( + githubEventWorkflowRun = "workflow_run" + githubEventWorkflowJob = "workflow_job" + githubEventPush = "push" + githubEventPullRequest = "pull_request" +) + +type webhookIngestion struct { + secret []byte + logger *zap.Logger + deliveryStore webhookDeliveryStore + localDeduper *deliveryDeduper + queue webhookEventQueue + dispatcher webhookEventDispatcher + metrics webhookIngestionMetrics + now func() time.Time +} + +func newDefaultWebhookIngestion() *webhookIngestion { + ingestion := &webhookIngestion{ + secret: githubWebhookSecret, + logger: logger, + deliveryStore: stateStore, + localDeduper: deliveryDeduperCache, + dispatcher: defaultWebhookEventDispatcher{}, + metrics: prometheusWebhookIngestionMetrics{}, + now: time.Now, + } + if eventProcessor != nil { + ingestion.queue = eventProcessor + } + + return ingestion +} + +func (i *webhookIngestion) Accept(ctx context.Context, request webhookIngestionRequest) webhookIngestionResult { + if !validateHMAC(request.Body, request.Signature, i.secret) { + i.logError("Invalid signature") + return webhookIngestionResult{ + StatusCode: http.StatusUnauthorized, + ErrorMessage: "Invalid signature", + } + } + + eventType := request.EventType + deliveryID := strings.TrimSpace(request.DeliveryID) + duplicate, err := i.markDuplicateDelivery(ctx, eventType, deliveryID) + if err != nil { + i.logError("Unable to record webhook delivery", zap.String("deliveryID", deliveryID), zap.Error(err)) + return webhookIngestionResult{ + StatusCode: http.StatusInternalServerError, + ErrorMessage: "Unable to record webhook delivery", + } + } + if duplicate { + return webhookIngestionResult{StatusCode: http.StatusOK} + } + + if i.queue != nil { + if err := i.queue.Enqueue(ctx, eventType, request.Body); err != nil { + i.logWarn("Dropping webhook event because queue is full", zap.String("eventType", eventType), zap.Error(err)) + return webhookIngestionResult{ + StatusCode: http.StatusServiceUnavailable, + ErrorMessage: "Webhook queue is full", + } + } + return webhookIngestionResult{StatusCode: http.StatusAccepted} + } + + if i.dispatcher != nil { + if ok := i.dispatcher.Dispatch(ctx, eventType, request.Body); !ok { + i.logWarn("Invalid GitHub event type", zap.String("eventType", eventType)) + } + } + + return webhookIngestionResult{StatusCode: http.StatusOK} +} + +func (i *webhookIngestion) markDuplicateDelivery(ctx context.Context, eventType, deliveryID string) (bool, error) { + if deliveryID == "" { + return false, nil + } + + switch { + case i.deliveryStore != nil: + processed, err := i.deliveryStore.MarkDeliveryProcessed(ctx, deliveryID) + if err != nil { + return false, err + } + if processed { + return false, nil + } + case i.localDeduper != nil: + now := time.Now + if i.now != nil { + now = i.now + } + if !i.localDeduper.SeenBefore(deliveryID, now()) { + return false, nil + } + default: + return false, nil + } + + if i.metrics != nil { + i.metrics.RecordDuplicateDelivery(eventType) + } + i.logInfo("Skipping duplicate GitHub delivery", zap.String("deliveryID", deliveryID), zap.String("eventType", eventType)) + return true, nil +} + +func (i *webhookIngestion) logInfo(message string, fields ...zap.Field) { + if i.logger != nil { + i.logger.Info(message, fields...) + } +} + +func (i *webhookIngestion) logWarn(message string, fields ...zap.Field) { + if i.logger != nil { + i.logger.Warn(message, fields...) + } +} + +func (i *webhookIngestion) logError(message string, fields ...zap.Field) { + if i.logger != nil { + i.logger.Error(message, fields...) + } +} + +type defaultWebhookEventDispatcher struct{} + +func (defaultWebhookEventDispatcher) Dispatch(ctx context.Context, eventType string, body []byte) bool { + switch eventType { + case githubEventWorkflowRun: + updateWorkflowMetrics(ctx, body) + case githubEventWorkflowJob: + updateJobMetrics(ctx, body) + case githubEventPush: + updateCommitMetrics(body) + case githubEventPullRequest: + updatePullRequestMetrics(body) + default: + return false + } + + return true +} + +type prometheusWebhookIngestionMetrics struct{} + +func (prometheusWebhookIngestionMetrics) RecordDuplicateDelivery(eventType string) { + duplicateDeliveriesSeenCounter.WithLabelValues(eventType).Inc() + duplicateDeliveriesDroppedCounter.WithLabelValues(eventType).Inc() +} + +func webhookHTTPHandler(acceptor webhookAcceptor, logger *zap.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Unable to read request body", http.StatusInternalServerError) + if logger != nil { + logger.Error("Unable to read request body", zap.Error(err)) + } + return + } + + result := acceptor.Accept(r.Context(), webhookIngestionRequest{ + Body: body, + Signature: r.Header.Get("X-Hub-Signature-256"), + EventType: r.Header.Get("X-GitHub-Event"), + DeliveryID: r.Header.Get("X-GitHub-Delivery"), + }) + + if result.ErrorMessage != "" { + http.Error(w, result.ErrorMessage, result.StatusCode) + return + } + + w.WriteHeader(result.StatusCode) + } +} diff --git a/src/webhook_ingestion_test.go b/src/webhook_ingestion_test.go new file mode 100644 index 0000000..e3ef66b --- /dev/null +++ b/src/webhook_ingestion_test.go @@ -0,0 +1,227 @@ +//go:build !integration + +package main + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "go.uber.org/zap" +) + +type fakeDeliveryMarker struct { + created bool + err error + calls []string +} + +func (f *fakeDeliveryMarker) MarkDeliveryProcessed(_ context.Context, deliveryID string) (bool, error) { + f.calls = append(f.calls, deliveryID) + return f.created, f.err +} + +type fakeWebhookQueue struct { + err error + events []webhookIngestionRequest + lastCtx context.Context +} + +func (f *fakeWebhookQueue) Enqueue(ctx context.Context, eventType string, body []byte) error { + f.lastCtx = ctx + f.events = append(f.events, webhookIngestionRequest{ + EventType: eventType, + Body: append([]byte(nil), body...), + }) + return f.err +} + +type fakeWebhookDispatcher struct { + events []webhookIngestionRequest +} + +func (f *fakeWebhookDispatcher) Dispatch(_ context.Context, eventType string, body []byte) bool { + f.events = append(f.events, webhookIngestionRequest{ + EventType: eventType, + Body: append([]byte(nil), body...), + }) + return eventType == githubEventWorkflowRun +} + +type fakeWebhookMetrics struct { + duplicates []string +} + +func (f *fakeWebhookMetrics) RecordDuplicateDelivery(eventType string) { + f.duplicates = append(f.duplicates, eventType) +} + +func signedWorkflowRunWebhookRequest(body []byte) webhookIngestionRequest { + return webhookIngestionRequest{ + Body: body, + Signature: computeHMAC(body, []byte("test-secret")), + EventType: githubEventWorkflowRun, + DeliveryID: "delivery-1", + } +} + +func newTestWebhookIngestion() (*webhookIngestion, *fakeDeliveryMarker, *fakeWebhookDispatcher, *fakeWebhookMetrics) { + delivery := &fakeDeliveryMarker{created: true} + dispatcher := &fakeWebhookDispatcher{} + metrics := &fakeWebhookMetrics{} + + return &webhookIngestion{ + secret: []byte("test-secret"), + logger: zap.NewNop(), + deliveryStore: delivery, + dispatcher: dispatcher, + metrics: metrics, + }, delivery, dispatcher, metrics +} + +func TestWebhookIngestionRejectsInvalidSignature(t *testing.T) { + ingestion, delivery, dispatcher, _ := newTestWebhookIngestion() + + result := ingestion.Accept(context.Background(), webhookIngestionRequest{ + Body: []byte(`{"ok":true}`), + Signature: "sha256=bad", + EventType: githubEventWorkflowRun, + DeliveryID: "delivery-1", + }) + + if result.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, result.StatusCode) + } + if len(delivery.calls) != 0 { + t.Fatalf("expected no delivery calls, got %d", len(delivery.calls)) + } + if len(dispatcher.events) != 0 { + t.Fatalf("expected no dispatch calls, got %d", len(dispatcher.events)) + } +} + +func TestWebhookIngestionDropsDuplicateBeforeDispatch(t *testing.T) { + ingestion, delivery, dispatcher, metrics := newTestWebhookIngestion() + delivery.created = false + + result := ingestion.Accept(context.Background(), signedWorkflowRunWebhookRequest([]byte(`{"ok":true}`))) + + if result.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, result.StatusCode) + } + if len(dispatcher.events) != 0 { + t.Fatalf("expected no dispatch calls, got %d", len(dispatcher.events)) + } + if got := metrics.duplicates; len(got) != 1 || got[0] != githubEventWorkflowRun { + t.Fatalf("expected duplicate metric for workflow_run, got %#v", got) + } +} + +func TestWebhookIngestionReturnsServerErrorWhenDeliveryRecordingFails(t *testing.T) { + ingestion, delivery, dispatcher, _ := newTestWebhookIngestion() + delivery.err = errors.New("redis unavailable") + + result := ingestion.Accept(context.Background(), signedWorkflowRunWebhookRequest([]byte(`{"ok":true}`))) + + if result.StatusCode != http.StatusInternalServerError { + t.Fatalf("expected status %d, got %d", http.StatusInternalServerError, result.StatusCode) + } + if len(dispatcher.events) != 0 { + t.Fatalf("expected no dispatch calls, got %d", len(dispatcher.events)) + } +} + +func TestWebhookIngestionDispatchesSynchronouslyWithoutQueue(t *testing.T) { + ingestion, _, dispatcher, _ := newTestWebhookIngestion() + body := []byte(`{"ok":true}`) + + result := ingestion.Accept(context.Background(), signedWorkflowRunWebhookRequest(body)) + + if result.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, result.StatusCode) + } + if len(dispatcher.events) != 1 { + t.Fatalf("expected one dispatch call, got %d", len(dispatcher.events)) + } + if dispatcher.events[0].EventType != githubEventWorkflowRun || !bytes.Equal(dispatcher.events[0].Body, body) { + t.Fatalf("unexpected dispatched event: %#v", dispatcher.events[0]) + } +} + +func TestWebhookIngestionEnqueuesAcceptedEvent(t *testing.T) { + ingestion, _, dispatcher, _ := newTestWebhookIngestion() + queue := &fakeWebhookQueue{} + ingestion.queue = queue + body := []byte(`{"ok":true}`) + + result := ingestion.Accept(context.Background(), signedWorkflowRunWebhookRequest(body)) + + if result.StatusCode != http.StatusAccepted { + t.Fatalf("expected status %d, got %d", http.StatusAccepted, result.StatusCode) + } + if len(queue.events) != 1 { + t.Fatalf("expected one queued event, got %d", len(queue.events)) + } + if len(dispatcher.events) != 0 { + t.Fatalf("expected no synchronous dispatch calls, got %d", len(dispatcher.events)) + } +} + +func TestWebhookIngestionReturnsUnavailableWhenQueueIsFull(t *testing.T) { + ingestion, _, dispatcher, _ := newTestWebhookIngestion() + ingestion.queue = &fakeWebhookQueue{err: errors.New("queue full")} + + result := ingestion.Accept(context.Background(), signedWorkflowRunWebhookRequest([]byte(`{"ok":true}`))) + + if result.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("expected status %d, got %d", http.StatusServiceUnavailable, result.StatusCode) + } + if len(dispatcher.events) != 0 { + t.Fatalf("expected no synchronous dispatch calls, got %d", len(dispatcher.events)) + } +} + +type fakeWebhookAcceptor struct { + request webhookIngestionRequest + result webhookIngestionResult +} + +func (f *fakeWebhookAcceptor) Accept(_ context.Context, request webhookIngestionRequest) webhookIngestionResult { + f.request = request + return f.result +} + +func TestWebhookHTTPHandlerAdaptsHeadersAndBody(t *testing.T) { + acceptor := &fakeWebhookAcceptor{ + result: webhookIngestionResult{StatusCode: http.StatusAccepted}, + } + handler := webhookHTTPHandler(acceptor, zap.NewNop()) + body := []byte(`{"ok":true}`) + + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(body)) + req.Header.Set("X-Hub-Signature-256", "sha256=test") + req.Header.Set("X-GitHub-Event", githubEventWorkflowRun) + req.Header.Set("X-GitHub-Delivery", "delivery-1") + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusAccepted { + t.Fatalf("expected status %d, got %d", http.StatusAccepted, recorder.Code) + } + if acceptor.request.Signature != "sha256=test" { + t.Fatalf("unexpected signature %q", acceptor.request.Signature) + } + if acceptor.request.EventType != githubEventWorkflowRun { + t.Fatalf("unexpected event type %q", acceptor.request.EventType) + } + if acceptor.request.DeliveryID != "delivery-1" { + t.Fatalf("unexpected delivery id %q", acceptor.request.DeliveryID) + } + if !bytes.Equal(acceptor.request.Body, body) { + t.Fatalf("unexpected body %q", string(acceptor.request.Body)) + } +}