From 39924dfc3529c6f7d457c4cf07363bada1a469a3 Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Fri, 27 Mar 2026 10:45:19 -0400 Subject: [PATCH 1/4] Gradient engine speed improvements --- phaser/engines/common/simulation.py | 13 +++-- phaser/engines/gradient/run.py | 87 +++++++++++++++++++---------- phaser/hooks/solver.py | 2 +- phaser/plan.py | 16 +++++- phaser/state.py | 7 +++ phaser/utils/_jax_kernels.py | 7 ++- phaser/utils/num.py | 84 +++++++++++++++++++++------- 7 files changed, 156 insertions(+), 60 deletions(-) diff --git a/phaser/engines/common/simulation.py b/phaser/engines/common/simulation.py index a6e15b6..ec50bbc 100644 --- a/phaser/engines/common/simulation.py +++ b/phaser/engines/common/simulation.py @@ -265,7 +265,8 @@ def cutout_group( def slice_forwards( props: t.Optional[NDArray[numpy.complexfloating]], state: StateT, - f: t.Callable[[int, t.Optional[NDArray[numpy.complexfloating]], StateT], StateT] + f: t.Callable[[int, t.Optional[NDArray[numpy.complexfloating]], StateT], StateT], *, + jit_unroll_slices: t.Union[int, bool] = False, ) -> StateT: if props is None: return f(0, None, state) @@ -278,7 +279,7 @@ def step_fn(carry, slice_i): new_state = f(slice_i, props[slice_i], carry) return new_state, None - state, _ = jax.lax.scan(step_fn, state, jax.numpy.arange(n_slices - 1)) + state, _ = jax.lax.scan(step_fn, state, jax.numpy.arange(n_slices - 1), unroll=jit_unroll_slices) return f(n_slices - 1, None, state) # fallback numpy mode @@ -290,7 +291,8 @@ def step_fn(carry, slice_i): def slice_backwards( props: t.Optional[NDArray[numpy.complexfloating]], state: StateT, - f: t.Callable[[int, t.Optional[NDArray[numpy.complexfloating]], StateT], StateT] + f: t.Callable[[int, t.Optional[NDArray[numpy.complexfloating]], StateT], StateT], + jit_unroll_slices: t.Union[int, bool] = False, ) -> StateT: if props is None: return f(0, None, state) @@ -299,7 +301,10 @@ def slice_backwards( if is_jax(props): import jax - state = jax.lax.fori_loop(1, n_slices, 
lambda i, state: f(n_slices - i, props[n_slices - i - 1], state), state, unroll=False) + state = jax.lax.fori_loop( + 1, n_slices, lambda i, state: f(n_slices - i, props[n_slices - i - 1], state), state, + unroll=jit_unroll_slices + ) return f(0, None, state) for slice_i in range(n_slices - 1, 0, -1): diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index 7482f1d..d29c68c 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -9,7 +9,8 @@ from phaser.hooks.solver import NoiseModel from phaser.utils.num import ( assert_dtype, get_array_module, cast_array_module, jit, to_numpy, - fft2, ifft2, abs2, check_finite, at, Float, to_complex_dtype, to_real_dtype + fft2, ifft2, fft2shift, ifft2shift, abs2, at, Float, + to_complex_dtype, to_real_dtype, ) import phaser.utils.tree as tree from phaser.utils.optics import fourier_shift_filter @@ -157,10 +158,6 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: observer: Observer = args.get('observer', Observer()) state = args['state'] seed = args['seed'] - patterns = args['data'].patterns - pattern_mask = xp.array(args['data'].pattern_mask) - assert_dtype(patterns, dtype) - assert_dtype(pattern_mask, dtype) noise_model = props.noise_model(None) @@ -186,13 +183,36 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: ) start_i = int(state.iter.total_iter) + # check patterns dtype + assert_dtype(args['data'].patterns, dtype) + assert_dtype(args['data'].pattern_mask, dtype) + # load pattern mask + pattern_mask = xp.array(args['data'].pattern_mask) + + # and load/stream patterns + if props.buffer_n_groups is None: + logging.info("Loading raw data to GPU ('buffer_n_groups' is disabled)...") + patterns = xp.array(args['data'].patterns) + else: + logging.info(f"Streaming raw data to GPU (buffering {props.buffer_n_groups} groups)") + patterns = args['data'].patterns + + def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> 
t.Iterable[t.Tuple[NDArray[numpy.int_], NDArray[numpy.floating]]]: + if props.buffer_n_groups is None: + return ((group, patterns[tuple(xp.asarray(group))]) for group in groups) + return stream_patterns( + groups, patterns, xp=xp, buf_n=props.buffer_n_groups + ) + propagators = make_propagators(state, props.bwlim_frac) # runs rescaling rescale_factors = [] - for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan), - patterns, xp=xp, buf_n=props.buffer_n_groups)): - group_rescale_factors = dry_run(state, group, propagators, group_patterns, xp=xp, dtype=dtype) + for (group_i, (group, group_patterns)) in enumerate(iter_patterns(groups.iter(state.scan))): + group_rescale_factors = dry_run( + state, group, propagators, group_patterns, + xp=xp, dtype=dtype, + ) rescale_factors.append(group_rescale_factors) rescale_factors = xp.concatenate(rescale_factors, axis=0) @@ -252,8 +272,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: for (solver, solver_state) in zip(iter_solvers, iter_solver_states) ] - for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan, i, iter_shuffle_groups), - patterns, xp=xp, buf_n=props.buffer_n_groups)): + for (group_i, (group, group_patterns)) in enumerate(iter_patterns(groups.iter(state.scan, i, iter_shuffle_groups))): (state, group_losses, iter_grads, solver_states) = run_group( state, group=group, vars=iter_vars, noise_model=noise_model, @@ -266,12 +285,15 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: group_patterns=group_patterns, #load_group(group), pattern_mask=pattern_mask, probe_int=probe_int, - xp=xp, dtype=dtype + xp=xp, dtype=dtype, + jit_unroll_slices=props.jit_unroll_slices, ) - losses = tree.map(xp.add, losses, group_losses) - check_finite(state.object.data, state.probe.data, context=f"object or probe, group {group_i}") + # we only need to check for nan in total_loss + if not 
xp.isfinite(losses['total_loss']): + raise ValueError(f"NaN or inf encountered, group {group_i}") + observer.update_group(state, props.send_every_group) # report losses normalized by # of probe positions @@ -326,7 +348,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: @partial( jit, - static_argnames=('vars', 'xp', 'dtype', 'noise_model', 'group_solvers', 'group_constraints', 'regularizers'), + static_argnames=('vars', 'xp', 'dtype', 'noise_model', 'group_solvers', 'group_constraints', 'regularizers', 'jit_unroll_slices'), donate_argnames=('state', 'iter_grads', 'solver_states'), ) def run_group( @@ -345,6 +367,7 @@ def run_group( probe_int: t.Union[float, numpy.floating], xp: t.Any, dtype: t.Type[numpy.floating], + jit_unroll_slices: t.Union[int, bool], ) -> t.Tuple[ReconsState, t.Dict[str, Float], t.Dict[ReconsVar, t.Any], SolverStates]: xp = cast_array_module(xp) @@ -352,7 +375,7 @@ def run_group( *extract_vars(state, vars, group), group=group, props=props, group_patterns=group_patterns, pattern_mask=pattern_mask, noise_model=noise_model, regularizers=regularizers, solver_states=solver_states, - xp=xp, dtype=dtype + xp=xp, dtype=dtype, jit_unroll_slices=jit_unroll_slices ) for k in grad.keys(): # scale gradients appropriately @@ -385,7 +408,7 @@ def run_group( @partial( jit, - static_argnames=('xp', 'dtype', 'noise_model', 'regularizers'), + static_argnames=('xp', 'dtype', 'noise_model', 'regularizers', 'jit_unroll_slices'), donate_argnames=('solver_states',), ) def run_model( @@ -400,6 +423,7 @@ def run_model( solver_states: SolverStates, xp: t.Any, dtype: t.Type[numpy.floating], + jit_unroll_slices: t.Union[int, bool], ) -> t.Tuple[Float, t.Tuple[SolverStates, t.Dict[str, Float]]]: # apply vars to simulation sim = insert_vars(vars, sim, group) @@ -409,23 +433,23 @@ def run_model( (ky, kx) = sim.probe.sampling.recip_grid(dtype=dtype, xp=xp) xp = get_array_module(sim.probe.data) dtype = to_real_dtype(sim.probe.data.dtype) - 
#complex_dtype = to_complex_dtype(dtype) - probes = sim.probe.data - group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, group_scan, probes.shape[-2:]) - group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(group_scan, probes.shape[-2:]))[:, None, ...] - probes = ifft2(fft2(probes) * group_subpx_filters) + # preshift probe and object + probes = ifft2shift(sim.probe.data) + group_obj = ifft2shift(sim.object.sampling.get_view_at_pos(sim.object.data, group_scan, probes.shape[-2:])) + group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(group_scan, probes.shape[-2:])) + # (group, mode, y, x) + probes = ifft2(fft2(probes, shift=False) * group_subpx_filters[:, None], shift=False) def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi): - # psi: (batch, n_probe, Ny, Nx) if prop is not None: - return ifft2(fft2(psi * group_obj[:, slice_i, None]) * prop[:, None]) + return ifft2(fft2(psi * group_obj[:, slice_i, None], shift=False) * prop[:, None], shift=False) return psi * group_obj[:, slice_i, None] t_props = tilt_propagators(ky, kx, sim, props, group_tilts) - model_wave = fft2(slice_forwards(t_props, probes, sim_slice)) - model_intensity = xp.sum(abs2(model_wave), axis=1) + model_wave = fft2(slice_forwards(t_props, probes, sim_slice, jit_unroll_slices=jit_unroll_slices), shift=False) + model_intensity = xp.sum(abs2(model_wave), axis=1) (loss, solver_states.noise_model_state) = noise_model.calc_loss( model_wave, model_intensity, group_patterns, pattern_mask, solver_states.noise_model_state ) @@ -458,19 +482,20 @@ def dry_run( dtype: t.Type[numpy.floating], ) -> NDArray[numpy.floating]: (ky, kx) = sim.probe.sampling.recip_grid(dtype=dtype, xp=xp) + group_scan = sim.scan[tuple(group)] - probes = sim.probe.data - group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, sim.scan[tuple(group)], probes.shape[-2:]) - group_subpx_filters = fourier_shift_filter(ky, 
kx, sim.object.sampling.get_subpx_shifts(sim.scan[tuple(group)], probes.shape[-2:]))[:, None, ...] - probes = ifft2(fft2(probes) * group_subpx_filters) + probes = ifft2shift(sim.probe.data) + group_obj = ifft2shift(sim.object.sampling.get_view_at_pos(sim.object.data, group_scan, probes.shape[-2:])) + group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(group_scan, probes.shape[-2:])) + probes = ifft2(fft2(probes, shift=False) * group_subpx_filters[:, None], shift=False) def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi): if prop is not None: - return ifft2(fft2(psi * group_obj[:, slice_i, None]) * prop[:, None]) + return ifft2(fft2(psi * group_obj[:, slice_i, None], shift=False) * prop[:, None], shift=False) return psi * group_obj[:, slice_i, None] t_props = tilt_propagators(ky, kx, sim, props, sim.tilt[tuple(group)] if sim.tilt is not None else None) - model_wave = fft2(slice_forwards(t_props, probes, sim_slice)) + model_wave = fft2(slice_forwards(t_props, probes, sim_slice), shift=False) model_intensity = xp.sum(abs2(model_wave), axis=(1, -2, -1)) exp_intensity = xp.sum(group_patterns, axis=(-2, -1)) diff --git a/phaser/hooks/solver.py b/phaser/hooks/solver.py index 4b402d0..0b87b4e 100644 --- a/phaser/hooks/solver.py +++ b/phaser/hooks/solver.py @@ -162,7 +162,7 @@ def update( class GradientSolverArgs(t.TypedDict): - plan: 'GradientEnginePlan' + plan: t.Optional['GradientEnginePlan'] params: t.Iterable[ReconsVar] diff --git a/phaser/plan.py b/phaser/plan.py index 8690eae..920f5e6 100644 --- a/phaser/plan.py +++ b/phaser/plan.py @@ -57,7 +57,21 @@ class EnginePlan(Dataclass, kw_only=True): grouping: t.Optional[int] = None compact: bool = False shuffle_groups: t.Optional[FlagLike] = None - buffer_n_groups: int = 2 + buffer_n_groups: t.Optional[int] = 2 + """ + How many groups of patterns to buffer onto the device simultaneously. 
+ Set to 0 to disable buffering, or `None` (`~` in YAML) to preload the + entire dataset to the device. + """ + + jit_unroll_slices: t.Union[int, bool] = 10 + """ + Slices to unroll during JIT compilation (JAX backend only). + Larger unrolling may be faster, at the expense of increased compilation time. + + `True` or `0` unrolls all slices, `False` or `1` disables unrolling. + `10` should be a good default value. + """ update_probe: FlagLike = True update_object: FlagLike = True diff --git a/phaser/state.py b/phaser/state.py index 5502a74..09bdd9e 100644 --- a/phaser/state.py +++ b/phaser/state.py @@ -243,6 +243,13 @@ class PreparedRecons: name: str observer: 'ObserverSet' + def to_xp(self, xp: t.Any) -> Self: + return self.__class__( + self.patterns, + self.state.to_xp(xp), + self.name, self.observer, + ) + def to_numpy(self) -> Self: return self.__class__( self.patterns.to_numpy(), self.state.to_numpy(), self.name, self.observer diff --git a/phaser/utils/_jax_kernels.py b/phaser/utils/_jax_kernels.py index 621f1db..3c19f91 100644 --- a/phaser/utils/_jax_kernels.py +++ b/phaser/utils/_jax_kernels.py @@ -54,8 +54,11 @@ def inner(obj: jax.Array, v: t.Tuple[jax.Array, jax.Array]) -> t.Tuple[jax.Array @partial(jax.jit, static_argnums=2) def get_cutouts(obj: jax.Array, start_idxs: jax.Array, cutout_shape: t.Tuple[int, int]) -> jax.Array: - return jax.vmap(jax.vmap(lambda start_idx, obj: jax.lax.dynamic_slice(obj, start_idx, cutout_shape), (None, 0)), (0, None))( - to_2d(start_idxs), to_3d(obj) + idxs = to_2d(start_idxs) + ys = idxs[:, 0:1] + jax.numpy.arange(cutout_shape[0]) + xs = idxs[:, 1:2] + jax.numpy.arange(cutout_shape[1]) + return jax.numpy.swapaxes( + to_3d(obj)[..., ys[:, :, None], xs[:, None, :]], 0, 1 ).reshape((*start_idxs.shape[:-1], *obj.shape[:-2], *cutout_shape)) diff --git a/phaser/utils/num.py b/phaser/utils/num.py index 3a82158..809b3a9 100644 --- a/phaser/utils/num.py +++ b/phaser/utils/num.py @@ -587,60 +587,101 @@ def to_real_dtype(dtype: 
DTypeLike) -> t.Type[numpy.floating]: @t.overload -def ifft2(a: t.Union[NDArray[numpy.float64], NDArray[numpy.complex128]]) -> NDArray[numpy.complex128]: +def ifft2(a: t.Union[NDArray[numpy.float64], NDArray[numpy.complex128]], *, shift: bool = True) -> NDArray[numpy.complex128]: ... @t.overload -def ifft2(a: t.Union[NDArray[numpy.float32], NDArray[numpy.complex64]]) -> NDArray[numpy.complex64]: +def ifft2(a: t.Union[NDArray[numpy.float32], NDArray[numpy.complex64]], *, shift: bool = True) -> NDArray[numpy.complex64]: ... @t.overload -def ifft2(a: NDArray[NumT]) -> NDArray[numpy.complexfloating]: +def ifft2(a: NDArray[numpy.number], *, shift: bool = True) -> NDArray[numpy.complexfloating]: ... @t.overload -def ifft2(a: ArrayLike) -> NDArray[numpy.complexfloating]: +def ifft2(a: ArrayLike, *, shift: bool = True) -> NDArray[numpy.complexfloating]: ... -def ifft2(a: ArrayLike) -> NDArray[numpy.complexfloating]: +def ifft2(a: ArrayLike, *, shift: bool = True) -> NDArray[numpy.complexfloating]: """ Perform an inverse FFT on the last two axes of `a`. - - Follows our convention of centering real space and normalizing intensities. - """ + Follows our convention of centering real space and normalizing intensities (when `shift` is `True`). + """ xp = get_array_module(a) - if xp_is_torch(xp): - return xp.fft.fftshift(xp.fft.ifft2(a, norm='ortho'), dim=(-2, -1)) # type: ignore - return xp.fft.fftshift(xp.fft.ifft2(a, norm='ortho'), axes=(-2, -1)) + if shift: + if xp_is_torch(xp): + return xp.fft.fftshift(xp.fft.ifft2(a, norm='ortho'), dim=(-2, -1)) # type: ignore + return xp.fft.fftshift(xp.fft.ifft2(a, norm='ortho'), axes=(-2, -1)) + return xp.fft.ifft2(a, norm='ortho') + @t.overload -def fft2(a: t.Union[NDArray[numpy.float64], NDArray[numpy.complex128]]) -> NDArray[numpy.complex128]: +def fft2(a: t.Union[NDArray[numpy.float64], NDArray[numpy.complex128]], *, shift: bool = True) -> NDArray[numpy.complex128]: ... 
@t.overload -def fft2(a: t.Union[NDArray[numpy.float32], NDArray[numpy.complex64]]) -> NDArray[numpy.complex64]: +def fft2(a: t.Union[NDArray[numpy.float32], NDArray[numpy.complex64]], *, shift: bool = True) -> NDArray[numpy.complex64]: ... @t.overload -def fft2(a: NDArray[NumT]) -> NDArray[numpy.complexfloating]: +def fft2(a: NDArray[numpy.number], *, shift: bool = True) -> NDArray[numpy.complexfloating]: ... @t.overload -def fft2(a: ArrayLike) -> NDArray[numpy.complexfloating]: +def fft2(a: ArrayLike, *, shift: bool = True) -> NDArray[numpy.complexfloating]: ... -def fft2(a: ArrayLike) -> NDArray[numpy.complexfloating]: +def fft2(a: ArrayLike, *, shift: bool = True) -> NDArray[numpy.complexfloating]: """ Perform a forward FFT on the last two axes of `a`. - Follows our convention of centering real space and normalizing intensities. + Follows our convention of centering real space and normalizing intensities (when `shift` is `True`). """ + xp = get_array_module(a) + if shift: + if xp_is_torch(xp): + return xp.fft.fft2(xp.fft.ifftshift(a, dim=(-2, -1)), norm='ortho') # type: ignore + return xp.fft.fft2(xp.fft.ifftshift(a, axes=(-2, -1)), norm='ortho') + return xp.fft.fft2(a, norm='ortho') + +@t.overload +def fft2shift(a: NDArray[NumT]) -> NDArray[NumT]: + ... + +@t.overload +def fft2shift(a: ArrayLike) -> NDArray[t.Any]: + ... + +def fft2shift(a: ArrayLike) -> NDArray[t.Any]: + """ + FFT-shift the last two axes of `a`. + + Shifts the zero-frequency component to the center of the image + (i.e. this is needed to center realspace after an inverse transform). + """ xp = get_array_module(a) - if xp_is_torch(xp): - return xp.fft.fft2(xp.fft.ifftshift(a, dim=(-2, -1)), norm='ortho') # type: ignore - return xp.fft.fft2(xp.fft.ifftshift(a, axes=(-2, -1)), norm='ortho') + return xp.fft.fftshift(a, axes=(-2, -1)) + + +@t.overload +def ifft2shift(a: NDArray[NumT]) -> NDArray[NumT]: + ... + +@t.overload +def ifft2shift(a: ArrayLike) -> NDArray[t.Any]: + ... 
+ +def ifft2shift(a: ArrayLike) -> NDArray[t.Any]: + """ + Inverse FFT-shift the last two axes of `a`. + + Shifts the zero-frequency component to the corner of the image + (i.e. this is needed to corner realspace before a forward transform). + """ + xp = get_array_module(a) + return xp.fft.ifftshift(a, axes=(-2, -1)) def split_array(arr: NDArray[DTypeT], axis: int = 0, *, keepdims: bool = False) -> t.Tuple[NDArray[DTypeT], ...]: @@ -1037,7 +1078,8 @@ def at(arr: NDArray[DTypeT], idx: IndexLike) -> _AtImpl[DTypeT]: 'is_cupy', 'is_jax', 'xp_is_cupy', 'xp_is_jax', 'jit', 'fuse', 'debug_callback', 'to_complex_dtype', 'to_real_dtype', - 'fft2', 'ifft2', 'abs2', 'split_array', 'unstack', + 'fft2', 'ifft2', 'fft2shift', 'ifft2shift', + 'abs2', 'split_array', 'unstack', 'at', 'ufunc_outer', 'check_finite', 'Sampling', 'IndexLike', ] \ No newline at end of file From 7db6ad1a51af8df7458f0d086343adc8d03eb00b Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Fri, 27 Mar 2026 14:41:50 -0400 Subject: [PATCH 2/4] Some more speedup --- phaser/engines/common/simulation.py | 3 ++- phaser/engines/gradient/run.py | 35 ++++++++++++++++------------- phaser/plan.py | 3 ++- phaser/utils/num.py | 8 +++++++ 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/phaser/engines/common/simulation.py b/phaser/engines/common/simulation.py index ec50bbc..02fe937 100644 --- a/phaser/engines/common/simulation.py +++ b/phaser/engines/common/simulation.py @@ -75,7 +75,6 @@ def stream_patterns( while len(buf) > 0: (group, group_patterns) = buf.popleft() - yield group, block_until_ready(group_patterns) # attempt to feed queue try: @@ -84,6 +83,8 @@ def stream_patterns( except StopIteration: continue + yield group, block_until_ready(group_patterns) + @tree_dataclass(init=False, static_fields=('xp', 'dtype', 'noise_model', 'group_constraints', 'iter_constraints'), drop_fields=('ky', 'kx')) class SimulationState: diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py 
index d29c68c..489ac37 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -10,7 +10,7 @@ from phaser.utils.num import ( assert_dtype, get_array_module, cast_array_module, jit, to_numpy, fft2, ifft2, fft2shift, ifft2shift, abs2, at, Float, - to_complex_dtype, to_real_dtype, + to_complex_dtype, to_real_dtype, block_until_ready, ) import phaser.utils.tree as tree from phaser.utils.optics import fourier_shift_filter @@ -187,12 +187,12 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: assert_dtype(args['data'].patterns, dtype) assert_dtype(args['data'].pattern_mask, dtype) # load pattern mask - pattern_mask = xp.array(args['data'].pattern_mask) + pattern_mask = xp.asarray(args['data'].pattern_mask) # and load/stream patterns if props.buffer_n_groups is None: logging.info("Loading raw data to GPU ('buffer_n_groups' is disabled)...") - patterns = xp.array(args['data'].patterns) + patterns = xp.asarray(args['data'].patterns) else: logging.info(f"Streaming raw data to GPU (buffering {props.buffer_n_groups} groups)") patterns = args['data'].patterns @@ -259,7 +259,7 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple iter_shuffle_groups = shuffle_groups({'state': state, 'niter': props.niter}) # accumulated losses across groups - losses: t.Dict[str, float] = {k: 0.0 for k in loss_keys} + losses_gpu = {k: t.cast(numpy.floating, xp.array(0.0)) for k in loss_keys} # update schedules for this iteration # this needs to be done outside the JIT context, which makes this kinda hacky @@ -273,12 +273,16 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple ] for (group_i, (group, group_patterns)) in enumerate(iter_patterns(groups.iter(state.scan, i, iter_shuffle_groups))): - (state, group_losses, iter_grads, solver_states) = run_group( + # prevent the loop running ahead of the GPU stream + block_until_ready(losses_gpu['total_loss']) + + (state, losses_gpu, 
iter_grads, solver_states) = run_group( state, group=group, vars=iter_vars, noise_model=noise_model, group_solvers=group_solvers, group_constraints=group_constraints, regularizers=regularizers, + losses=losses_gpu, iter_grads=iter_grads, solver_states=solver_states, props=propagators, @@ -288,16 +292,16 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple xp=xp, dtype=dtype, jit_unroll_slices=props.jit_unroll_slices, ) - losses = tree.map(xp.add, losses, group_losses) - - # we only need to check for nan in total_loss - if not xp.isfinite(losses['total_loss']): + if props.check_every_group and not numpy.isfinite(float(losses_gpu['total_loss'])): raise ValueError(f"NaN or inf encountered, group {group_i}") - observer.update_group(state, props.send_every_group) + if not numpy.isfinite(float(losses_gpu['total_loss'])): + raise ValueError(f"NaN or inf encountered, iteration {i}") + # report losses normalized by # of probe positions - losses = tree.map(lambda v: float(v / groups.n_pos), losses) + # this also moves losses to CPU + losses: t.Dict[str, float] = tree.map(lambda v: float(v) / groups.n_pos, losses_gpu) for (k, v) in losses.items(): progress[k].iters.append(i + start_i) progress[k].values.append(v) @@ -325,7 +329,6 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple progress['tilt_update_rms'].values.append(tilt_update_rms) logger.info(f"Tilt update: mean {mean_tilt_update} mrad") # average tilt update, [y, x] - for (reg_i, reg) in enumerate(iter_constraints): (state, iter_constraint_states[reg_i]) = reg.apply_iter( state, iter_constraint_states[reg_i] @@ -359,6 +362,7 @@ def run_group( group_solvers: t.Sequence[GradientSolver[t.Any]], group_constraints: t.Sequence[GroupConstraint[t.Any]], regularizers: t.Sequence[CostRegularizer[t.Any]], + losses: t.Dict[str, numpy.floating], iter_grads: t.Dict[ReconsVar, t.Any], solver_states: SolverStates, props: t.Optional[NDArray[numpy.complexfloating]], @@ 
-368,10 +372,10 @@ def run_group( xp: t.Any, dtype: t.Type[numpy.floating], jit_unroll_slices: t.Union[int, bool], -) -> t.Tuple[ReconsState, t.Dict[str, Float], t.Dict[ReconsVar, t.Any], SolverStates]: +) -> t.Tuple[ReconsState, t.Dict[str, numpy.floating], t.Dict[ReconsVar, t.Any], SolverStates]: xp = cast_array_module(xp) - (grad, (solver_states, losses)) = tree.grad(run_model, has_aux=True, xp=xp, sign=-1)( + (grad, (solver_states, group_losses)) = tree.grad(run_model, has_aux=True, xp=xp, sign=-1)( *extract_vars(state, vars, group), group=group, props=props, group_patterns=group_patterns, pattern_mask=pattern_mask, noise_model=noise_model, regularizers=regularizers, solver_states=solver_states, @@ -394,7 +398,7 @@ def run_group( if len(solver_grads) == 0: continue (update, solver_states.group_solver_states[sol_i]) = solver.update( - state, solver_states.group_solver_states[sol_i], solver_grads, losses['total_loss'] + state, solver_states.group_solver_states[sol_i], solver_grads, group_losses['total_loss'] ) state = apply_update(state, update) @@ -403,6 +407,7 @@ def run_group( group, state, solver_states.group_constraint_states[reg_i] ) + losses = tree.map(xp.add, losses, group_losses) return (state, losses, iter_grads, solver_states) diff --git a/phaser/plan.py b/phaser/plan.py index 920f5e6..0e68ad6 100644 --- a/phaser/plan.py +++ b/phaser/plan.py @@ -64,7 +64,7 @@ class EnginePlan(Dataclass, kw_only=True): entire dataset to the device. """ - jit_unroll_slices: t.Union[int, bool] = 10 + jit_unroll_slices: t.Union[bool, int] = 10 """ Slices to unroll during JIT compilation (JAX backend only). Larger unrolling may be faster, at the expense of increased compilation time. 
@@ -94,6 +94,7 @@ class EnginePlan(Dataclass, kw_only=True): (smooths over ~1/smoothing iterations) """ + check_every_group: bool = False send_every_group: bool = False diff --git a/phaser/utils/num.py b/phaser/utils/num.py index 809b3a9..7cb6b6e 100644 --- a/phaser/utils/num.py +++ b/phaser/utils/num.py @@ -388,7 +388,15 @@ def xp_is_torch(xp: t.Any) -> bool: return xp is torch +@t.overload +def block_until_ready(arr: DTypeT) -> DTypeT: + ... + +@t.overload def block_until_ready(arr: NDArray[DTypeT]) -> NDArray[DTypeT]: + ... + +def block_until_ready(arr: t.Any) -> t.Any: if hasattr(arr, 'block_until_ready'): # jax return arr.block_until_ready() # type: ignore From 42949e9267930005883d2e64c2913d50bb430a38 Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Fri, 27 Mar 2026 10:45:59 -0400 Subject: [PATCH 3/4] Bump pane version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 12e5477..3d08640 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "rich>=12.0.0,<15", "tifffile>=2023.8.25", "optree>=0.13.0", - "py-pane==0.11.3", + "py-pane==0.11.4", "typing_extensions~=4.7", ] From 8676f2e14402d755a040c7d160623d0070467c2e Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Fri, 27 Mar 2026 10:53:33 -0400 Subject: [PATCH 4/4] Fix star import for phaser.types --- phaser/types.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/phaser/types.py b/phaser/types.py index c6a83b7..c28f7f3 100644 --- a/phaser/types.py +++ b/phaser/types.py @@ -387,9 +387,12 @@ def collect_errors(self, val: t.Any) -> t.Optional[ErrorNode]: __all__ = [ - 'BackendName', 'Dataclass', 'Slices', 'Flag', + 'cast_length', 'BackendName', 'ReconsVar', 'ReconsVars', + 'EmptyDict', 'EarlyTermination', 'Dataclass', + 'SliceList', 'SliceStep', 'SliceTotal', 'Slices', 'ComplexCartesian', 'ComplexPolar', - 'Krivanek', 'KrivanekComplex', 'KrivanekCartesian', - 'KrivanekPolar', 
'KnownAberration', 'Aberration', - 'process_aberrations', 'process_flag', 'flag_any_true', + 'Krivanek', 'KrivanekComplex', 'KrivanekCartesian', 'KrivanekPolar', + 'KnownAberration', 'Aberration', 'process_aberrations', + 'SimpleFlag', 'process_flag', 'process_schedule', 'flag_any_true', + 'IsVersion', 'Version', ] \ No newline at end of file