Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| --- | --- | --- | --- | --- | --- | --- |
| test_objective_jac_w7x | 1.38 % | 4.202e+03 | 4.260e+03 | 58.14 | 41.77 | 38.76 |
| test_proximal_jac_w7x_with_eq_update | -2.29 % | 6.584e+03 | 6.433e+03 | -151.00 | 158.32 | 157.41 |
| test_proximal_freeb_jac | -0.19 % | 1.340e+04 | 1.338e+04 | -24.80 | 86.52 | 87.00 |
| test_proximal_freeb_jac_blocked | 0.53 % | 7.736e+03 | 7.777e+03 | 41.04 | 76.38 | 76.48 |
| test_proximal_freeb_jac_batched | 0.50 % | 7.669e+03 | 7.707e+03 | 38.25 | 75.87 | 77.10 |
| test_proximal_jac_ripple | -3.71 % | 3.733e+03 | 3.594e+03 | -138.53 | 61.75 | 65.36 |
| test_proximal_jac_ripple_bounce1d | 1.89 % | 3.855e+03 | 3.927e+03 | 72.68 | 74.60 | 76.63 |
| test_eq_solve | 0.23 % | 2.222e+03 | 2.227e+03 | 5.07 | 98.03 | 97.03 |

For the memory plots, go to the summary of
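To make the table's conventions explicit, the derived columns satisfy Δ = PR − Master and %Δ = 100 · Δ / Master. A quick check against two rows of the table (the Master/PR values below are the rounded table entries, so the recomputed deltas differ slightly from the tabulated ones):

```python
# Recompute Δ (MB) and %Δ from the Master and PR columns of the
# memory benchmark table. Values are the rounded table entries.
rows = {
    # name: (master_mb, pr_mb)
    "test_objective_jac_w7x": (4202.0, 4260.0),
    "test_proximal_jac_ripple": (3733.0, 3594.0),
}
for name, (master, pr) in rows.items():
    delta = pr - master          # Δ (MB): positive means the PR uses more memory
    pct = 100.0 * delta / master  # %Δ relative to master
    print(f"{name}: delta = {delta:+.2f} MB, pct = {pct:+.2f} %")
```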
Someone else who is still working on DESC will need to maintain this PR from now on.
@dpanici @YigitElma @f0uriest @ddudt @rahulgaur104 . I requested an ETA 3 months ago for when this would be merged, but no one has replied. |
I don't like having a new dependency. We already have many problems with JAX and related packages. I don't feel comfortable approving the current state of the PR. |
I am not blocking it. If others approve it, you can merge it; that is fine. I don't have to approve all PRs. I've reviewed it multiple times and spent a couple of days on it. |
Sorry for the delay on my part; I will review this in the next week, hopefully by the end of this week once my thesis is turned in. I also need to see what exactly is going on with the finufft package; maybe we can set an upper bound to avoid the breaking changes they introduced in the latest release.
- [x] Resolves #1303.
- [x] Activates a Newton step to find bounce points.

It can be shown that there is O(sqrt(epsilon)) error in bounce integrals of 1/v_∥, where epsilon is the error in the bounce point. For v_∥ integrals the error is conveniently O(epsilon · sqrt(epsilon)). Hence, the `spline` method would require thousands of knots per transit for just a couple of digits of accuracy, and its convergence would stall at epsilon <= 1e-5 (so, via sqrt(epsilon), at <= 3 digits of accuracy) due to the condition number. Of course, the `spline=False` method has always computed the points with spectral accuracy and has very fast convergence after #1919; that method converges to epsilon = machine precision without the Newton step. With the Newton step, fast convergence is achieved with the `spline` method as well.

I suspect fast ion confinement optimization will be easier now. I did this on a couple of lunch breaks, so it would be very weird if clicking the approve button to merge into `master` took a year.

## Benchmarks

Here is a timing benchmark on my CPU with `nufft_eps=1e-6`. Prior to this PR, every adjoint call to nufft1 took >= 1 second and the full computation was 34 seconds. Now every adjoint call to nufft1 takes 250 milliseconds, and the full computation is 14 seconds. These improvements become larger as the sparsity grows and as the NUFFT error tolerance parameter epsilon tends to 0. Likewise, the improvement grows linearly with the problem size. As this is called within an optimization loop where time and memory are tight, the improvement is significant.
### Before

<img width="510" height="135" alt="Screenshot From 2026-03-29 15-22-56" src="https://github.com/user-attachments/assets/0baf9a88-b775-4a9c-bb39-db568099aae7" /> <img width="510" height="135" alt="Screenshot From 2026-03-29 15-22-36" src="https://github.com/user-attachments/assets/2bc1f730-56af-43ad-9c5e-45e9991c0bd1" />

### After

<img width="510" height="135" alt="Screenshot From 2026-03-29 15-12-35" src="https://github.com/user-attachments/assets/306a6c34-32a3-46e3-8b8c-c4f259a32169" /> <img width="510" height="135" alt="Screenshot From 2026-03-29 15-15-59" src="https://github.com/user-attachments/assets/db12ecb3-210a-4732-a5fc-666ba47475d8" />
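A minimal sketch of the Newton polishing idea for bounce points (roots of 1 − λB(ζ) along a field line, where v_∥ = 0): a cheap initial guess (e.g. from a spline) is refined to near machine precision by Newton iterations. `newton_refine` and the model field below are hypothetical illustrations, not DESC's actual API.

```python
import numpy as np

def newton_refine(f, df, x0, tol=1e-14, maxiter=10):
    """Refine a root of f starting from the guess x0 using Newton's method.
    f, df : callables returning the function value and its derivative.
    """
    x = x0
    for _ in range(maxiter):
        step = f(x) / df(x)
        x = x - step
        if abs(step) < tol:
            break
    return x

# Toy model: field strength B(z) = 1 + 0.1*cos(z) along a field line,
# pitch lam = 1/1.05, so bounce points satisfy 1 - lam*B(z) = 0,
# i.e. cos(z) = 0.5, with exact root z = pi/3.
lam = 1.0 / 1.05
f = lambda z: 1.0 - lam * (1.0 + 0.1 * np.cos(z))
df = lambda z: lam * 0.1 * np.sin(z)
z = newton_refine(f, df, x0=1.0)  # rough guess, as a spline root would give
```

Reducing the bounce-point error epsilon this way is what removes the O(sqrt(epsilon)) floor on the bounce-integral accuracy described above.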
## Inverse stream maps

- Increases the `Y` parameter for Heliotron (NFP=19) in the effective ripple tutorial #1928 and increases the default objective resolution for `Y_B`.

## Improvements

- Adds a `low_ram` mode which is the same speed and uses less memory for `objective.compute`, but is slower for `objective.grad`, since JAX is poor at iterative algorithms.
- Upgrades `interp_to_argmin` for `Bounce2D` from fourth order to spectral, as required for the Alpert quadrature.

## Usability

- `kwargs`, as was needed for my article.
- interpax#1388.

## Bugs

- `test_compute_everything`.

## Benchmarks

Just go to #2026 and run `effective_ripple_profile.py`. You will see the large performance improvement over `master`. The CI benchmarks do not reveal this because those benchmarks are essentially just noise. Note that, using the same parameter inputs, the resolution of this branch is also higher than `master` due to the faster convergence.

- With `use_bounce1d=True` on that script, you will run out of memory, as expected since it is an inferior approach (you get the OOM in the Jacobian before a single bounce integral is computed).
- With `nufft_eps=0`, you need 175 GB to run that script on `master` (you'll get an OOM and JAX will tell you it needs 175 GB), but only 34 GB on this branch.

## Examples
### HELIOTRON

`master` branch | This branch

### W7-X

`master` branch | This branch

### NCSX

`master` branch | This branch
## Removal of spectral aliasing

## Dynamic shapes
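The `low_ram` trade-off noted under Improvements (same speed and less memory for `objective.compute`, slower for `objective.grad`) is essentially a chunked-evaluation pattern: only a block of Jacobian columns is materialized at a time, bounding peak memory at the cost of a loop. A minimal NumPy sketch of the idea, using finite differences in place of DESC's JAX autodiff; `jac_chunked` is a hypothetical name, not DESC's implementation:

```python
import numpy as np

def jac_chunked(f, x, chunk=2, eps=1e-6):
    """Forward-difference Jacobian of f at x, built `chunk` columns at a time.
    Peak memory scales with the chunk size rather than the full Jacobian
    width, at the cost of an outer loop over blocks.
    """
    f0 = f(x)
    n = x.size
    J = np.empty((f0.size, n))
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        # Only this block of columns is materialized per iteration.
        block = np.stack(
            [(f(x + eps * np.eye(n)[j]) - f0) / eps for j in range(start, stop)],
            axis=1,
        )
        J[:, start:stop] = block
    return J

# Usage: Jacobian of a small quadratic map at x = (1, 2).
f = lambda x: np.array([x[0] ** 2, x[0] * x[1], x[1] ** 2])
J = jac_chunked(f, np.array([1.0, 2.0]), chunk=1)
```

In an autodiff setting the same pattern appears as looping `vjp`/`jvp` calls over row or column blocks instead of forming the whole Jacobian in one batched call, which is why the gradient path becomes slower when memory is prioritized.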