Summary
(co-authored by Claude)
The PyTorch fused attention bindings support an optional return_max_logit parameter that returns per-head maximum attention scores, but the JAX fused attention API (transformer_engine.jax.attention.fused_attn) has no equivalent mechanism. Additionally, the softmax_aux tensor (log-sum-exp statistics) computed during the forward pass is kept internal to the custom VJP machinery and never exposed to callers.
Primary ask: Expose max_logit (the per-head maximum attention score) through transformer_engine.jax.attention.fused_attn, matching the PyTorch API's return_max_logit parameter.
Secondary ask: Optionally expose softmax_aux (log-sum-exp statistics) for users who need them.
Motivation
The max_logit value is already supported in the PyTorch frontend via return_max_logit: bool in transformer_engine/pytorch/cpp_extensions/fused_attn.py. Use cases include:
- Custom loss functions and reward signals that depend on attention statistics
- Numerical stability diagnostics and debugging
- Implementing custom backward passes outside of JAX's built-in autodiff
- Feature parity with the PyTorch frontend
Current Behavior
PyTorch (supports max_logit):
In transformer_engine/pytorch/cpp_extensions/fused_attn.py, the forward function accepts return_max_logit: bool and returns max_logit as a separate output tensor when requested.
JAX (no max_logit support):
In transformer_engine/jax/attention.py, the public API fused_attn (line ~1394) returns only the output tensor. The internal forward rule _fused_attn_fwd_rule (line ~1272) calls tex.fused_attn_fwd which returns (output, softmax_aux, rng_state), but only output is surfaced to the caller. The softmax_aux and rng_state are stored in the custom VJP context for the backward pass and discarded from the public return value.
# Current JAX public API — no way to get aux stats
def fused_attn(qkv, bias, sequence_descriptor, seed, ...) -> jnp.ndarray:
    ...
    return output  # softmax_aux is only available internally for the backward pass
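For context, the flow described above can be sketched as follows. This is a paraphrase of the pattern, not the actual TE source; the (output, softmax_aux, rng_state) return of tex.fused_attn_fwd matches the description above, while the exact residual contents and the backward call are assumptions for illustration.

# Paraphrased sketch of the existing custom-VJP flow (not the actual TE source).
def _fused_attn_fwd_rule(qkv, bias, sequence_descriptor, seed, ...):
    output, softmax_aux, rng_state = tex.fused_attn_fwd(qkv, bias, ...)
    # softmax_aux and rng_state survive only as residuals for the backward rule.
    residuals = (qkv, bias, softmax_aux, rng_state, output)
    return output, residuals

def _fused_attn_bwd_rule(residuals, grad_output):
    qkv, bias, softmax_aux, rng_state, output = residuals
    # Gradients are recomputed from the saved statistics; nothing extra reaches the caller.
    return _backward_kernel(qkv, bias, output, softmax_aux, rng_state, grad_output)  # hypothetical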
Expected Behavior
from transformer_engine.jax.attention import fused_attn
# Primary: max_logit support
output, aux = fused_attn(q, k, v, ..., return_max_logit=True)
max_logit = aux["max_logit"] # per-head max attention scores
# Secondary: softmax_aux (log-sum-exp stats)
output, aux = fused_attn(q, k, v, ..., return_aux=True)
softmax_aux = aux["softmax_aux"] # shape [B, H, Sq, 1], float32 (log-sum-exp)
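Because max_logit is auxiliary statistics rather than a differentiable output, the proposed tuple return composes naturally with JAX's has_aux machinery. A sketch of intended usage (return_max_logit is the proposed API, not existing behavior; the loss is a placeholder and the elided arguments follow the example above):

import jax
import jax.numpy as jnp
from transformer_engine.jax.attention import fused_attn

def loss_fn(q, k, v):
    output, aux = fused_attn(q, k, v, ..., return_max_logit=True)  # proposed API
    loss = jnp.mean(output ** 2)       # placeholder loss
    return loss, aux["max_logit"]      # aux stats flow out without affecting gradients

(loss, max_logit), grads = jax.value_and_grad(loss_fn, has_aux=True)(q, k, v)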
Proposed Changes
1. Add return_max_logit support to the JAX C++ extension layer
The underlying cuDNN kernels already support computing max_logit (as evidenced by PyTorch support). Wire this through fused_attn_fwd in transformer_engine/jax/cpp_extensions/attention.py.
2. Propagate through the custom VJP wrapper
Update _fused_attn and its forward/backward rules in transformer_engine/jax/attention.py to optionally return max_logit alongside output. Care is needed since changing the return signature affects jax.custom_vjp — one approach is to always compute max_logit when requested and pass it through the VJP context without requiring gradients for it.
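One way to structure this is to make the differentiated inner function always return the (output, max_logit) pair and have the backward rule simply ignore the max_logit cotangent. A minimal generic sketch of that custom_vjp pattern (the _fwd_impl/_bwd_impl helpers are hypothetical stand-ins for the tex calls, not TE's actual code):

import jax

@jax.custom_vjp
def _attn_with_max_logit(q, k, v):
    out, max_logit, _ = _fwd_impl(q, k, v)  # hypothetical kernel call
    return out, max_logit

def _fwd_rule(q, k, v):
    out, max_logit, softmax_aux = _fwd_impl(q, k, v)  # hypothetical kernel call
    return (out, max_logit), (q, k, v, out, softmax_aux)

def _bwd_rule(residuals, cotangents):
    dout, _ = cotangents  # cotangent w.r.t. max_logit is discarded
    q, k, v, out, softmax_aux = residuals
    return _bwd_impl(q, k, v, out, softmax_aux, dout)  # hypothetical: returns (dq, dk, dv)

_attn_with_max_logit.defvjp(_fwd_rule, _bwd_rule)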
3. Expose in the public fused_attn API
Add return_max_logit: bool = False (and optionally return_softmax_aux: bool = False) to the fused_attn signature at line ~1394. When enabled, return a tuple (output, aux_dict) instead of just output.
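At the public API level, the change can stay backward compatible by only switching the return type when the flag is set. A sketch (return_max_logit and the dict-valued aux return are the proposal, not existing behavior; it assumes the inner VJP-wrapped _fused_attn now returns a pair):

def fused_attn(*args, return_max_logit: bool = False, **kwargs):
    # Sketch only; the full existing signature is listed under item 4 below.
    output, max_logit = _fused_attn(*args, **kwargs)  # assumed pair return from the inner call
    if return_max_logit:
        return output, {"max_logit": max_logit}
    return output  # default path unchanged for existing callers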
4. Current fused_attn signature for reference
def fused_attn(
    qkv, bias, sequence_descriptor, seed,
    attn_bias_type, attn_mask_type, qkv_layout, softmax_type,
    scaling_factor, dropout_probability, is_training,
    max_segments_per_seq=1, window_size=None,
    context_parallel_strategy=CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced=False,
    context_parallel_axis="", context_checkpoint_name="context",
    softmax_offset=None, stripe_size=None,
) -> jnp.ndarray
Softmax Aux Details
The softmax_aux tensor in JAX contains log-sum-exp statistics with shape determined by cuDNN version:
- cuDNN ≥ 9.6: [B, H, Sq, 1] for BSHD layouts, [B, Sq, H, 1] for THD layouts
- cuDNN < 9.6: [B, H, Sq, max_segments_per_seq]
- Dtype: always float32
Note: softmax_aux contains log(Σ exp(x - max(x))), not the raw max logits. max_logit is a separate output tensor.
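For reference, the row-wise statistic described in this note can be reproduced directly from unfused attention scores. This is a pure-JAX illustration of the definition (shapes follow the cuDNN ≥ 9.6 BSHD case above), not TE code:

import jax.numpy as jnp

def reference_softmax_aux(logits):
    # logits: unfused attention scores of shape [B, H, Sq, Skv]
    logits = logits.astype(jnp.float32)                   # statistics are kept in float32
    row_max = jnp.max(logits, axis=-1, keepdims=True)     # per-row max, shape [B, H, Sq, 1]
    # log(sum(exp(x - max(x)))) per row, matching the note above
    return jnp.log(jnp.sum(jnp.exp(logits - row_max), axis=-1, keepdims=True))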
References
- return_max_logit support in PyTorch: transformer_engine/pytorch/cpp_extensions/fused_attn.py
- transformer_engine/jax/attention.py: fused_attn (~L1394)
- transformer_engine/jax/attention.py: _fused_attn_fwd_rule (~L1272)
- transformer_engine/jax/cpp_extensions/attention.py: fused_attn_fwd (~L3359)
Environment
- TransformerEngine version: main branch (HEAD)
- Framework: JAX