
New inverse stream map to accelerate convergence#1919

Merged
unalmis merged 45 commits into master from
ku/bounce
Apr 8, 2026

Conversation

@unalmis
Collaborator

@unalmis unalmis commented Sep 18, 2025


Inverse stream maps

Improvements

  • Check-pointing to increase speed and reduce the memory consumption of reverse-mode differentiation (#1347, "Checkpointing to reduce reverse mode AD memory usage").
  • Adds a low_ram mode, which matches the speed of objective.compute with less memory but is slower for objective.grad, since JAX handles iterative algorithms poorly.
  • Fully resolves #1864 ("Memory regression in bounce integrals") by avoiding materialization of a large tensor in memory. Previously, we had closed the issue by adding nuffts as a workaround; the improvement here actually solves the JAX regression.
  • Reuses some computations when identifying bounce points to improve efficiency.
  • Increases cache hits and fusion, and reduces floating-point error in computing bounce points (very important for accurate integrals).
  • Transforms an improper field-line integral to one on a compact domain where the integrand is periodic, achieving faster convergence.
  • Improves the convergence of interp_to_argmin for Bounce2D from fourth order to spectral, as required for Alpert quadrature.
  • Resolves #1389 ("Use OOP for surface integrals with faster methods for tensor product grids").
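The periodic-domain transform in the list above works because equispaced quadrature converges geometrically for smooth periodic integrands. A minimal numpy sketch of that principle, using a toy integrand rather than DESC's actual transform:

```python
import numpy as np

# Sketch (toy integrand, not DESC's actual transform): the equispaced
# trapezoidal rule converges geometrically for smooth periodic integrands,
# so recasting an improper field-line integral onto a compact domain with a
# periodic integrand buys spectral convergence.

def trapezoid_periodic(f, n):
    """Equispaced trapezoidal rule on [0, 2*pi) for a 2*pi-periodic f."""
    t = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False)
    return 2.0 * np.pi * np.mean(f(t))

f = lambda t: np.exp(np.cos(t))       # smooth and 2*pi-periodic
exact = 2.0 * np.pi * np.i0(1.0)      # known closed form: 2*pi*I_0(1)

errors = [abs(trapezoid_periodic(f, n) - exact) for n in (4, 8, 16)]
# the error drops geometrically: each doubling of n roughly squares the accuracy
```

With only 16 equispaced points the error is already near machine precision, which is the behavior the transform exploits.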

Usability

Bugs

  • Fixes bug in test_compute_everything.
  • Fixes inverse stream map convergence.

Benchmarks

Just go to #2026 and run effective_ripple_profile.py. You will see the large performance improvement from master. The CI benchmarks do not reveal this because those benchmarks are essentially just noise. Note that, using the same parameter inputs, the resolution of this branch is also higher than master due to the faster convergence.

  • If you set use_bounce1d=True in that script, you will run out of memory, as expected since it is an inferior approach: the OOM occurs in the Jacobian before a single bounce integral is computed.
  • If you set nufft_eps=0, you need 175 GB to run that script on master (you'll get an OOM and JAX will tell you it needs 175GB), but only 34 GB on this branch.
  • Using nuffts, the script requires only 6.5 GB on this branch.
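One of the memory wins above comes from avoiding materialization of a large intermediate tensor (#1864). A minimal numpy sketch of the general idea, with hypothetical shapes and names rather than DESC's actual kernel:

```python
import numpy as np

# Hypothetical illustration (shapes and names made up, not DESC's kernel):
# a contraction written with broadcasting materializes a large 3-D
# intermediate, while einsum performs the same reduction without it.

rng = np.random.default_rng(0)
ni, nk, nj = 200, 300, 400
A = rng.standard_normal((ni, nk))
B = rng.standard_normal((nk, nj))

# Naive: allocates a (200, 300, 400) float64 tensor, ~192 MB, just to sum it.
naive = (A[:, :, None] * B[None, :, :]).sum(axis=1)

# Fused: same result, but the 3-D intermediate never lives in memory.
fused = np.einsum("ik,kj->ij", A, B)

assert np.allclose(naive, fused)
```

At the problem sizes in the benchmarks above, the difference between allocating and never allocating such intermediates is what separates 175 GB from 34 GB.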

Examples

HELIOTRON

Master branch

test_theta_chebyshev_HELIOTRON

This branch

test_delta_fourier_chebyshev_HELIOTRON

W7-X

Master branch

test_theta_chebyshev_W7-X

This branch

test_delta_fourier_chebyshev_W7-X

NCSX

Master branch

test_theta_chebyshev_NCSX

This branch

test_delta_fourier_chebyshev_NCSX

Removal of spectral aliasing

Figure_1

Dynamic shapes

  • Resolves #1303 ("Patch for differentiable code with dynamic shapes").
  • Activates a Newton step to find bounce points.
    • It can be shown that bounce integrals with a 1/v_∥ integrand incur O(sqrt(epsilon)) error, where epsilon is the error in the bounce point; for v_∥ integrals the error is a more forgiving O(epsilon * sqrt(epsilon)). Hence the spline method would require thousands of knots per transit for just a couple of digits of accuracy, and its convergence would stall at epsilon <= 1e-5 (so sqrt(epsilon) gives at most ~3 digits) due to the condition number. Of course, the spline=False method has always computed the points with spectral accuracy and has very fast convergence after #1919; that method converges to epsilon = machine precision without the Newton step. With the Newton step, fast convergence is achieved with the spline method as well.
    • I suspect fast-ion confinement optimization will be easier now.
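The Newton step described above can be sketched as follows. This is a hedged toy example, not DESC's implementation: the field strength B, the pitch lam, and the initial guess are made up, and bounce points are taken as the roots of f(zeta) = 1 - lam * B(zeta), where v_∥ vanishes.

```python
import numpy as np

# Hedged toy sketch of the Newton refinement (B, lam, and z0 are made up;
# not DESC's implementation). Polishing a crude bounce-point guess with
# Newton drives the bounce-point error epsilon to machine precision, which
# the O(sqrt(epsilon)) integral-error estimate above then rewards.

def newton_refine(f, df, z0, iters=6):
    """A few Newton iterations to polish an initial root guess."""
    z = z0
    for _ in range(iters):
        z = z - f(z) / df(z)
    return z

lam = 0.8
B = lambda z: 1.0 + 0.5 * np.cos(z)   # toy field strength along the line
dB = lambda z: -0.5 * np.sin(z)
f = lambda z: 1.0 - lam * B(z)        # v_parallel vanishes at the roots
df = lambda z: -lam * dB(z)

z = newton_refine(f, df, z0=1.0)      # crude guess, e.g. from a spline
# exact bounce point for this toy field is arccos(0.5) = pi/3
```

A handful of iterations from a spline-quality guess suffices to reach machine precision, which is why the spline method now converges fast as well.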

Benchmarks

Here is a timing benchmark on my CPU with nufft_eps=1e-6. Prior to this PR, every adjoint call to nufft1 took >= 1 second and the full computation took 34 seconds. Now every adjoint call to nufft1 takes 250 milliseconds, and the full computation takes 14 seconds. These improvements grow as the sparsity increases and as the error tolerance parameter for the nuffts, epsilon, tends to 0. Likewise, the improvement grows linearly with the problem size. As this is called within an optimization loop where time and memory are tight, the improvement is significant.

Before

Screenshot From 2026-03-29 15-22-56 Screenshot From 2026-03-29 15-22-36

After

Screenshot From 2026-03-29 15-12-35 Screenshot From 2026-03-29 15-15-59

@unalmis unalmis self-assigned this Sep 18, 2025
@unalmis unalmis added the performance (New feature or request to make the code faster) and robustness (Make the code more robust) labels Sep 18, 2025

@unalmis unalmis changed the base branch from master to ku/nufft September 18, 2025 07:45
@unalmis unalmis marked this pull request as draft September 18, 2025 07:47
@unalmis unalmis added the theory Requires theory work before coding label Sep 18, 2025
@unalmis unalmis changed the title New inverse stream maps to accelerate convergence New inverse stream map to accelerate convergence Sep 18, 2025
@unalmis unalmis added the bug fix Something was fixed label Sep 18, 2025
@github-actions
Contributor

github-actions Bot commented Sep 18, 2025

Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| --- | --- | --- | --- | --- | --- | --- |
| test_objective_jac_w7x | 1.38 % | 4.202e+03 | 4.260e+03 | 58.14 | 41.77 | 38.76 |
| test_proximal_jac_w7x_with_eq_update | -2.29 % | 6.584e+03 | 6.433e+03 | -151.00 | 158.32 | 157.41 |
| test_proximal_freeb_jac | -0.19 % | 1.340e+04 | 1.338e+04 | -24.80 | 86.52 | 87.00 |
| test_proximal_freeb_jac_blocked | 0.53 % | 7.736e+03 | 7.777e+03 | 41.04 | 76.38 | 76.48 |
| test_proximal_freeb_jac_batched | 0.50 % | 7.669e+03 | 7.707e+03 | 38.25 | 75.87 | 77.10 |
| test_proximal_jac_ripple | -3.71 % | 3.733e+03 | 3.594e+03 | -138.53 | 61.75 | 65.36 |
| test_proximal_jac_ripple_bounce1d | 1.89 % | 3.855e+03 | 3.927e+03 | 72.68 | 74.60 | 76.63 |
| test_eq_solve | 0.23 % | 2.222e+03 | 2.227e+03 | 5.07 | 98.03 | 97.03 |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

@unalmis unalmis removed bug fix Something was fixed theory Requires theory work before coding labels Sep 18, 2025
@unalmis unalmis changed the title New inverse stream map to accelerate convergence better inverse stream map to accelerate convergence Sep 20, 2025
@unalmis unalmis changed the title better inverse stream map to accelerate convergence new inverse stream map to accelerate convergence Sep 20, 2025
@unalmis unalmis added the theory Requires theory work before coding label Sep 20, 2025
@unalmis unalmis force-pushed the ku/bounce branch 2 times, most recently from a6d949b to d685405 Compare September 22, 2025 04:33
@unalmis unalmis removed the theory Requires theory work before coding label Sep 22, 2025
@unalmis unalmis added the P3 Highest Priority, someone is/should be actively working on this label Sep 23, 2025
@unalmis unalmis dismissed f0uriest’s stale review September 23, 2025 08:44

addressed request

@unalmis unalmis marked this pull request as ready for review September 23, 2025 08:44
@unalmis unalmis requested review from a team, f0uriest and rahulgaur104 and removed request for a team September 23, 2025 08:44
@unalmis unalmis mentioned this pull request Feb 26, 2026
6 tasks
@unalmis
Collaborator Author

unalmis commented Feb 27, 2026

Someone else who is still working on desc will need to maintain this PR from now on.

@unalmis
Collaborator Author

unalmis commented Mar 17, 2026

@dpanici @YigitElma @f0uriest @ddudt @rahulgaur104 . I requested an ETA 3 months ago for when this would be merged, but no one has replied.

@YigitElma
Collaborator

I don't like having a new dependency. We already have many problems with JAX and related packages. I don't feel comfortable approving the current state of the PR.

@unalmis
Collaborator Author

unalmis commented Mar 17, 2026

  • The dependency uses the same packages that desc already relies on. It's not going to cause problems. Your issue is with JAX, not my code.
  • I predicted the recent jax-finufft issue and put guards in place to avoid it in the future. When those guards were explicitly removed by other developers without my approval, that issue was a consequence. It would have been avoided if my code had not been modified.
  • The other developers requested I move it to an external package. I waited 3 months to make the change (October to December) in case someone changed their mind, and I even confirmed it again via a poll one week before I added the dependency. To now state that you are blocking the work because of that dependency is unreasonable, especially given the lengths to which I went to avoid this particular issue.

@YigitElma
Collaborator

  • To now state that you are blocking the work because of that dependency is unreasonable

I am not blocking it. If others approve it, you can merge it; that is fine. I don't have to approve all PRs. I've reviewed it multiple times and spent a couple of days on it.

@dpanici
Collaborator

dpanici commented Mar 18, 2026

@dpanici @YigitElma @f0uriest @ddudt @rahulgaur104 . I requested an ETA 3 months ago for when this would be merged, but no one has replied.

Sorry for the delay on my part; I will review this in the next week, hopefully by the end of this week once my thesis is turned in. I also need to see what exactly is going on with the finufft package; maybe we can set an upper bound to avoid the breaking changes they introduced in the latest release.

unalmis added 7 commits April 2, 2026 01:35
I did this on a couple lunch breaks, so it would be very weird if clicking the approve button to merge into `master` took a year.

Labels

  • bug fix — Something was fixed
  • P∞ — P_infty. Ready to merge > 1 year. Top priority to merge to prevent further delay of research.
  • performance — New feature or request to make the code faster
  • robustness — Make the code more robust
  • run_benchmarks — Run timing benchmarks on this PR against current master branch
  • stable — Awaiting merge to master. Only updates will be merging from master.
