diff --git a/testing/gradcheck/ops.go b/testing/gradcheck/ops.go index 4e2d2c9..fc42159 100644 --- a/testing/gradcheck/ops.go +++ b/testing/gradcheck/ops.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math" "github.com/zerfoo/ztensor/compute" "github.com/zerfoo/ztensor/graph" @@ -869,3 +870,396 @@ func (n *groupNormNode[T]) Backward(ctx context.Context, _ types.BackwardMode, g } return []tn[T]{dx}, nil } + +// --- cross-attention (scaled dot-product attention) --------------------------- + +// newCrossAttentionNode builds a single-head scaled-dot-product-attention op +// over three 2D inputs Q[Lq,d], K[Lk,d], V[Lk,d] -> [Lq,d]: +// +// scores = (Q @ K^T) / sqrt(d); A = softmax(scores, axis=1); out = A @ V +// +// This is the packaged cross-attention primitive (separate Q vs K/V +// projections) that E127 composes for text<->stream and audio<->video +// coupling. Landing it extends the ADR-091 oracle to the cross-attention op +// class (T127.1.0a). torch oracle: +// torch.nn.functional.scaled_dot_product_attention(x0, x1, x2) (default scale +// 1/sqrt(E), E = query last dim = d -- matches). No trainable parameters; the +// three inputs are Q, K, V and the gradient flows to all three. +func newCrossAttentionNode[T tensor.Float](e compute.Engine[T]) *opNode[T] { + // Intermediates captured across Forward/Backward. gradcheck builds a fresh + // node per evaluation and never resets an arena between the passes, so + // closure capture is sufficient (no Saver needed -- see the file header). + var ( + attn tn[T] // softmax weights A [Lq, Lk] + qIn tn[T] + kIn tn[T] + vIn tn[T] + scale float64 + ) + return &opNode[T]{ + opType: "CrossAttention", + fwd: func(ctx context.Context, in []tn[T]) (tn[T], error) { + if len(in) != 3 { + return nil, fmt.Errorf("CrossAttention: want 3 inputs (Q,K,V), got %d", len(in)) + } + q, k, v := in[0], in[1], in[2] + qs := q.Shape() + scale = 1.0 / math.Sqrt(float64(qs[len(qs)-1])) + kt, err := e.Transpose(ctx, k, []int{1, 0}) + if err != nil { + return nil, err + } + scores, err := e.MatMul(ctx, q, kt) + if err != nil { + return nil, err + } + scaled, err := e.MulScalar(ctx, scores, T(scale)) + if err != nil { + return nil, err + } + a, err := e.Softmax(ctx, scaled, 1) + if err != nil { + return nil, err + } + attn, qIn, kIn, vIn = a, q, k, v + return e.MatMul(ctx, a, v) + }, + bwd: func(ctx context.Context, g tn[T], _ []tn[T], _ tn[T]) ([]tn[T], error) { + if attn == nil { + return nil, errors.New("CrossAttention: Backward called before Forward") + } + // dV = A^T @ g. + at, err := e.Transpose(ctx, attn, []int{1, 0}) + if err != nil { + return nil, err + } + dV, err := e.MatMul(ctx, at, g) + if err != nil { + return nil, err + } + // dA = g @ V^T. + vt, err := e.Transpose(ctx, vIn, []int{1, 0}) + if err != nil { + return nil, err + } + dA, err := e.MatMul(ctx, g, vt) + if err != nil { + return nil, err + } + // Softmax backward over rows: dScaled = A * (dA - sum(dA*A, axis=1)). + dAA, err := e.Mul(ctx, dA, attn) + if err != nil { + return nil, err + } + s, err := e.ReduceSum(ctx, dAA, 1, true) + if err != nil { + return nil, err + } + dAm, err := e.Sub(ctx, dA, s) + if err != nil { + return nil, err + } + dScaled, err := e.Mul(ctx, attn, dAm) + if err != nil { + return nil, err + } + dScores, err := e.MulScalar(ctx, dScaled, T(scale)) + if err != nil { + return nil, err + } + // scores = Q @ K^T => dQ = dScores @ K ; dK = dScores^T @ Q. + dQ, err := e.MatMul(ctx, dScores, kIn) + if err != nil { + return nil, err + } + dsT, err := e.Transpose(ctx, dScores, []int{1, 0}) + if err != nil { + return nil, err + } + dK, err := e.MatMul(ctx, dsT, qIn) + if err != nil { + return nil, err + } + return []tn[T]{dQ, dK, dV}, nil + }, + } +} + +// --- adaLN (adaptive layer-norm modulation) ----------------------------------- + +// adaLNNode applies the AdaLN affine modulation used by DiT-family diffusion +// models: from a conditioning vector c it projects per-channel scale and shift +// and applies out = x * (1 + scale) + shift, where scale = c @ Ws and +// shift = c @ Wsh. Inputs are x0 = x [N, C] (the pre-normalized activations) +// and x1 = c [N, cond]; Ws, Wsh are [cond, C]. This is the modulation core of +// AdaLN-Zero (the zero-init projection is an initialization detail mapped in +// the arch builder, ADR-092; the op math is identical). Landing it extends the +// ADR-091 oracle to the AdaLN op class (E127/T127.1.0a) and unlocks every +// AdaLN-DiT model. torch oracle: x0 * (1 + x1 @ Ws) + (x1 @ Wsh). +type adaLNNode[T tensor.Float] struct { + engine compute.Engine[T] + ws *graph.Parameter[T] + wsh *graph.Parameter[T] + + onePlusScale tn[T] + xIn tn[T] + cIn tn[T] + saver graph.Saver[T] +} + +func newAdaLNNode[T tensor.Float](e compute.Engine[T], dim, cond int) (*adaLNNode[T], error) { + mk := func(name string, scale float64) (*graph.Parameter[T], error) { + data := make([]T, cond*dim) + for i := range data { + // Deterministic, non-uniform, small values so parameter gradients + // are structurally informative (mirrors newLayerNormNode). + data[i] = T(scale * (0.05 + 0.03*float64(i%7))) + } + v, err := newTensorOf([]int{cond, dim}, data) + if err != nil { + return nil, err + } + return graph.NewParameter[T](name, v, newTensorOf[T]) + } + ws, err := mk("Ws", 1.0) + if err != nil { + return nil, err + } + wsh, err := mk("Wsh", -1.0) + if err != nil { + return nil, err + } + return &adaLNNode[T]{engine: e, ws: ws, wsh: wsh}, nil +} + +func (n *adaLNNode[T]) OpType() string { return "AdaLN" } +func (n *adaLNNode[T]) Attributes() map[string]interface{} { return nil } +func (n *adaLNNode[T]) Parameters() []*graph.Parameter[T] { + return []*graph.Parameter[T]{n.ws, n.wsh} +} +func (n *adaLNNode[T]) SetSaver(s graph.Saver[T]) { n.saver = s } + +func (n *adaLNNode[T]) OutputShape() []int { + if n.onePlusScale == nil { + return nil + } + return n.onePlusScale.Shape() +} + +func (n *adaLNNode[T]) Forward(ctx context.Context, inputs ...tn[T]) (tn[T], error) { + if len(inputs) != 2 { + return nil, fmt.Errorf("AdaLN: want 2 inputs (x, c), got %d", len(inputs)) + } + x, c := inputs[0], inputs[1] + e := n.engine + scale, err := e.MatMul(ctx, c, n.ws.Value) + if err != nil { + return nil, err + } + shift, err := e.MatMul(ctx, c, n.wsh.Value) + if err != nil { + return nil, err + } + onePlusScale, err := e.AddScalar(ctx, scale, T(1)) + if err != nil { + return nil, err + } + n.onePlusScale = onePlusScale + n.xIn = x + n.cIn = c + if n.saver != nil { + n.saver.SaveForBackward(onePlusScale, x, c) + } + xs, err := e.Mul(ctx, x, onePlusScale) + if err != nil { + return nil, err + } + return e.Add(ctx, xs, shift) +} + +func (n *adaLNNode[T]) Backward(ctx context.Context, _ types.BackwardMode, g tn[T], _ ...tn[T]) ([]tn[T], error) { + if n.onePlusScale == nil { + return nil, errors.New("AdaLN: Backward called before Forward") + } + e := n.engine + // out = x*(1+scale) + shift. + // dx = g * (1 + scale). + dx, err := e.Mul(ctx, g, n.onePlusScale) + if err != nil { + return nil, err + } + // dscale = g * x ; dWs = c^T @ dscale. + dscale, err := e.Mul(ctx, g, n.xIn) + if err != nil { + return nil, err + } + cT, err := e.Transpose(ctx, n.cIn, []int{1, 0}) + if err != nil { + return nil, err + } + dWs, err := e.MatMul(ctx, cT, dscale) + if err != nil { + return nil, err + } + if err := n.ws.AddGradient(dWs); err != nil { + return nil, err + } + // dshift = g ; dWsh = c^T @ g. + dWsh, err := e.MatMul(ctx, cT, g) + if err != nil { + return nil, err + } + if err := n.wsh.AddGradient(dWsh); err != nil { + return nil, err + } + // dc = dscale @ Ws^T + g @ Wsh^T. + wsT, err := e.Transpose(ctx, n.ws.Value, []int{1, 0}) + if err != nil { + return nil, err + } + dcScale, err := e.MatMul(ctx, dscale, wsT) + if err != nil { + return nil, err + } + wshT, err := e.Transpose(ctx, n.wsh.Value, []int{1, 0}) + if err != nil { + return nil, err + } + dcShift, err := e.MatMul(ctx, g, wshT) + if err != nil { + return nil, err + } + dc, err := e.Add(ctx, dcScale, dcShift) + if err != nil { + return nil, err + } + return []tn[T]{dx, dc}, nil +} + +// --- timestep sinusoidal embedding -------------------------------------------- + +// timestepEmbedNode is the sinusoidal frequency embedding at the head of every +// diffusion-DiT timestep embedder: from scalar timesteps t [N, 1] it produces +// [N, 2H] = concat(sin(t @ freqs), cos(t @ freqs)) over a learned (here: leaf) +// frequency row freqs [1, H]. The downstream MLP is a plain Linear (already +// covered by MatMul), so this op isolates the sinusoidal piece. Landing it +// extends the ADR-091 oracle to the timestep-embedding op class +// (E127/T127.1.0a). torch oracle: +// torch.cat([torch.sin(x0 @ freqs), torch.cos(x0 @ freqs)], dim=1). +type timestepEmbedNode[T tensor.Float] struct { + engine compute.Engine[T] + freqs *graph.Parameter[T] // [1, H] + half int // H + + sinv tn[T] // sin(arg) [N, H] + cosv tn[T] // cos(arg) [N, H] + tIn tn[T] + saver graph.Saver[T] +} + +func newTimestepEmbedNode[T tensor.Float](e compute.Engine[T], half int) (*timestepEmbedNode[T], error) { + data := make([]T, half) + for j := 0; j < half; j++ { + // Deterministic, moderate frequencies so sin/cos curvature is exercised + // across the input domain (well-conditioned central differences). + data[j] = T(0.5 + 0.4*float64(j)) + } + v, err := newTensorOf([]int{1, half}, data) + if err != nil { + return nil, err + } + freqs, err := graph.NewParameter[T]("freqs", v, newTensorOf[T]) + if err != nil { + return nil, err + } + return ×tepEmbedNode[T]{engine: e, freqs: freqs, half: half}, nil +} + +func (n *timestepEmbedNode[T]) OpType() string { return "TimestepEmbed" } +func (n *timestepEmbedNode[T]) Attributes() map[string]interface{} { return nil } +func (n *timestepEmbedNode[T]) Parameters() []*graph.Parameter[T] { + return []*graph.Parameter[T]{n.freqs} +} +func (n *timestepEmbedNode[T]) SetSaver(s graph.Saver[T]) { n.saver = s } + +func (n *timestepEmbedNode[T]) OutputShape() []int { + if n.sinv == nil { + return nil + } + s := n.sinv.Shape() + return []int{s[0], 2 * s[1]} +} + +func (n *timestepEmbedNode[T]) Forward(ctx context.Context, inputs ...tn[T]) (tn[T], error) { + if len(inputs) != 1 { + return nil, fmt.Errorf("TimestepEmbed: want 1 input (t), got %d", len(inputs)) + } + t := inputs[0] + e := n.engine + arg, err := e.MatMul(ctx, t, n.freqs.Value) // [N,1] @ [1,H] -> [N,H] + if err != nil { + return nil, err + } + sinv, err := e.Sin(ctx, arg) + if err != nil { + return nil, err + } + cosv, err := e.Cos(ctx, arg) + if err != nil { + return nil, err + } + n.sinv, n.cosv, n.tIn = sinv, cosv, t + if n.saver != nil { + n.saver.SaveForBackward(sinv, cosv, t) + } + return e.Concat(ctx, []tn[T]{sinv, cosv}, 1) +} + +func (n *timestepEmbedNode[T]) Backward(ctx context.Context, _ types.BackwardMode, g tn[T], _ ...tn[T]) ([]tn[T], error) { + if n.sinv == nil { + return nil, errors.New("TimestepEmbed: Backward called before Forward") + } + e := n.engine + // Split the [N,2H] upstream gradient into the sin and cos halves. + parts, err := e.Split(ctx, g, 2, 1) + if err != nil { + return nil, err + } + gs, gc := parts[0], parts[1] + // arg = t @ freqs; d/darg[sin] = cos(arg), d/darg[cos] = -sin(arg). + // darg = gs*cos(arg) - gc*sin(arg). + a1, err := e.Mul(ctx, gs, n.cosv) + if err != nil { + return nil, err + } + a2, err := e.Mul(ctx, gc, n.sinv) + if err != nil { + return nil, err + } + darg, err := e.Sub(ctx, a1, a2) + if err != nil { + return nil, err + } + // dt = darg @ freqs^T ([N,H] @ [H,1] -> [N,1]). + fT, err := e.Transpose(ctx, n.freqs.Value, []int{1, 0}) + if err != nil { + return nil, err + } + dt, err := e.MatMul(ctx, darg, fT) + if err != nil { + return nil, err + } + // dfreqs = t^T @ darg ([1,N] @ [N,H] -> [1,H]). + tT, err := e.Transpose(ctx, n.tIn, []int{1, 0}) + if err != nil { + return nil, err + } + dfreqs, err := e.MatMul(ctx, tT, darg) + if err != nil { + return nil, err + } + if err := n.freqs.AddGradient(dfreqs); err != nil { + return nil, err + } + return []tn[T]{dt}, nil +} diff --git a/testing/gradcheck/registry.go b/testing/gradcheck/registry.go index 049be55..6a0e2dd 100644 --- a/testing/gradcheck/registry.go +++ b/testing/gradcheck/registry.go @@ -69,6 +69,12 @@ func NewRegistryNode[T tensor.Float](name string, e compute.Engine[T]) (graph.No return newLayerNormNode(e, 4) case "GroupNorm": return newGroupNormNode(e, 4, 2) + case "CrossAttention": + return newCrossAttentionNode(e), nil + case "AdaLN": + return newAdaLNNode(e, 4, 3) + case "TimestepEmbed": + return newTimestepEmbedNode(e, 4) default: return nil, fmt.Errorf("gradcheck: no registry op named %q", name) } @@ -249,5 +255,26 @@ func Registry() []OpInfo { Make: registryMake("GroupNorm"), InputShapes: [][]int{{3, 4}}, }, + // Cross-attention (scaled dot-product attention): Q[2,4], K[3,4], V[3,4] + // -> [2,4]. Three inputs, no params (E127/T127.1.0a cross-attention). + { + Name: "CrossAttention", Seed: 28, + Make: registryMake("CrossAttention"), + InputShapes: [][]int{{2, 4}, {3, 4}, {3, 4}}, + }, + // AdaLN modulation: x[2,4], c[2,3]; params Ws,Wsh [3,4]. Two inputs + + // two params (E127/T127.1.0a AdaLN-Zero modulation core). + { + Name: "AdaLN", Seed: 29, + Make: registryMake("AdaLN"), + InputShapes: [][]int{{2, 4}, {2, 3}}, + }, + // Timestep sinusoidal embedding: t[3,1] -> [3,8]; freqs [1,4] leaf + // (E127/T127.1.0a timestep-embedding op class). + { + Name: "TimestepEmbed", Seed: 30, + Make: registryMake("TimestepEmbed"), + InputShapes: [][]int{{3, 1}}, + }, } } diff --git a/testing/oracle/torchmap.go b/testing/oracle/torchmap.go index edd8c72..3a9a0f9 100644 --- a/testing/oracle/torchmap.go +++ b/testing/oracle/torchmap.go @@ -67,6 +67,19 @@ var torchMap = map[string]torchOp{ // per-channel affine. Matches gradcheck.newGroupNormNode(e, 4, 2); gamma/beta // leaves stay (1, 4) and reshape to (4,) inside the graph, like LayerNorm. "GroupNorm": {Expr: "torch.nn.functional.group_norm(x0, 2, weight=gamma.reshape(4), bias=beta.reshape(4), eps=1e-05)"}, + + // CrossAttention: single-head scaled dot-product attention over Q=x0, K=x1, + // V=x2. torch's default scale is 1/sqrt(E) with E the query last dim, which + // matches gradcheck.newCrossAttentionNode's 1/sqrt(d). + "CrossAttention": {Expr: "torch.nn.functional.scaled_dot_product_attention(x0, x1, x2)"}, + + // AdaLN: out = x*(1+scale)+shift with scale = c@Ws, shift = c@Wsh. x0=x, + // x1=c; Ws,Wsh are the named [cond,dim] projection leaves. + "AdaLN": {Expr: "x0 * (1 + x1 @ Ws) + (x1 @ Wsh)"}, + + // TimestepEmbed: sinusoidal embedding concat(sin(t@freqs), cos(t@freqs)). + // x0 = t [N,1]; freqs is the named [1,H] leaf. + "TimestepEmbed": {Expr: "torch.cat([torch.sin(x0 @ freqs), torch.cos(x0 @ freqs)], dim=1)"}, } // defaultTolerance is the first-cut f32 comparison bar: ztensor CPU/GPU f32 @@ -85,8 +98,9 @@ var defaultTolerance = Tolerance{ var toleranceOverrides = map[string]Tolerance{ "Softmax": {FwdAtol: 1e-6, FwdRtol: 1e-4, GradAtol: 1e-6, GradRtol: 1e-3}, "LayerNorm": {FwdAtol: 1e-5, FwdRtol: 1e-4, GradAtol: 1e-5, GradRtol: 1e-3}, - "GroupNorm": {FwdAtol: 1e-5, FwdRtol: 1e-4, GradAtol: 1e-5, GradRtol: 1e-3}, - "MatMul": {FwdAtol: 1e-6, FwdRtol: 1e-4, GradAtol: 1e-6, GradRtol: 1e-3}, + "GroupNorm": {FwdAtol: 1e-5, FwdRtol: 1e-4, GradAtol: 1e-5, GradRtol: 1e-3}, + "MatMul": {FwdAtol: 1e-6, FwdRtol: 1e-4, GradAtol: 1e-6, GradRtol: 1e-3}, + "CrossAttention": {FwdAtol: 1e-5, FwdRtol: 1e-4, GradAtol: 1e-5, GradRtol: 1e-3}, } // toleranceFor returns the per-op tolerance, falling back to the default.