Skip to content

Fix abs jvp for complex inputs#3745

Open
obchain wants to merge 1 commit into
ml-explore:mainfrom
obchain:fix/abs-complex-jvp
Open

Fix abs jvp for complex inputs#3745
obchain wants to merge 1 commit into
ml-explore:mainfrom
obchain:fix/abs-complex-jvp

Conversation

@obchain

@obchain obchain commented Jun 22, 2026

Copy link
Copy Markdown
Contributor

Proposed changes

Fixes #3744.

Abs::jvp returned tangents[0] * sign(z) for complex inputs. Since |z| is real-valued, the framework dropped the imaginary part of that complex product, leaving Re(z * t) / |z| instead of the correct directional derivative

d|z| = Re(conj(z) * t) / |z| = Re(conj(sign(z)) * t)

so the imaginary contribution had the wrong sign. The fix multiplies by conj(sign(z)) and takes the real part, giving a correct (and real) tangent for complex inputs; the real path is unchanged.

Abs::vjp previously delegated to jvp. For complex inputs the vjp and jvp differ by a conjugate, so it no longer delegates — it keeps its existing, already-correct cotangent * sign(z) form (whose real/imaginary parts are the gradients w.r.t. Re(z) and Im(z)).

The bug also surfaced through ops like mx.abs(mx.fft.rfft(x)), whose jvp is now correct.

Before:

>>> z = mx.array([1+2j]); t = mx.array([0.5-1j])
>>> mx.jvp(mx.abs, [z], [t])[1][0]      # Re(z*t)/|z|, wrong
>>> mx.real(mx.conj(z)*t)/mx.abs(z)     # correct

Added test_complex_abs_grad to test_autograd.py checking the jvp against the hand-computed value, the vjp against cotangent * sign(z), and that real inputs are unaffected.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Abs::jvp returned tangents[0] * sign(z) for complex inputs. Since |z| is
real-valued, the framework dropped the imaginary part of that complex
product, leaving Re(z * t) / |z| instead of the correct directional
derivative Re(conj(z) * t) / |z|. Multiply by conj(sign(z)) and take the
real part so the tangent is correct (and real) for complex inputs; the
real path is unchanged. The vjp no longer delegates to the jvp (the two
differ by a conjugate for complex inputs) but keeps its existing, correct
cotangent * sign(z) form.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] jvp of mx.abs is wrong for complex inputs (missing conjugate)

1 participant