Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions phaser/engines/common/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def stream_patterns(

while len(buf) > 0:
(group, group_patterns) = buf.popleft()
yield group, block_until_ready(group_patterns)

# attempt to feed queue
try:
Expand All @@ -84,6 +83,8 @@ def stream_patterns(
except StopIteration:
continue

yield group, block_until_ready(group_patterns)


@tree_dataclass(init=False, static_fields=('xp', 'dtype', 'noise_model', 'group_constraints', 'iter_constraints'), drop_fields=('ky', 'kx'))
class SimulationState:
Expand Down Expand Up @@ -265,7 +266,8 @@ def cutout_group(
def slice_forwards(
props: t.Optional[NDArray[numpy.complexfloating]],
state: StateT,
f: t.Callable[[int, t.Optional[NDArray[numpy.complexfloating]], StateT], StateT]
f: t.Callable[[int, t.Optional[NDArray[numpy.complexfloating]], StateT], StateT], *,
jit_unroll_slices: t.Union[int, bool] = False,
) -> StateT:
if props is None:
return f(0, None, state)
Expand All @@ -278,7 +280,7 @@ def step_fn(carry, slice_i):
new_state = f(slice_i, props[slice_i], carry)
return new_state, None

state, _ = jax.lax.scan(step_fn, state, jax.numpy.arange(n_slices - 1))
state, _ = jax.lax.scan(step_fn, state, jax.numpy.arange(n_slices - 1), unroll=jit_unroll_slices)
return f(n_slices - 1, None, state)

# fallback numpy mode
Expand All @@ -290,7 +292,8 @@ def step_fn(carry, slice_i):
def slice_backwards(
props: t.Optional[NDArray[numpy.complexfloating]],
state: StateT,
f: t.Callable[[int, t.Optional[NDArray[numpy.complexfloating]], StateT], StateT]
f: t.Callable[[int, t.Optional[NDArray[numpy.complexfloating]], StateT], StateT],
jit_unroll_slices: t.Union[int, bool] = False,
) -> StateT:
if props is None:
return f(0, None, state)
Expand All @@ -299,7 +302,10 @@ def slice_backwards(

if is_jax(props):
import jax
state = jax.lax.fori_loop(1, n_slices, lambda i, state: f(n_slices - i, props[n_slices - i - 1], state), state, unroll=False)
state = jax.lax.fori_loop(
1, n_slices, lambda i, state: f(n_slices - i, props[n_slices - i - 1], state), state,
unroll=jit_unroll_slices
)
return f(0, None, state)

for slice_i in range(n_slices - 1, 0, -1):
Expand Down
110 changes: 70 additions & 40 deletions phaser/engines/gradient/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from phaser.hooks.solver import NoiseModel
from phaser.utils.num import (
assert_dtype, get_array_module, cast_array_module, jit, to_numpy,
fft2, ifft2, abs2, check_finite, at, Float, to_complex_dtype, to_real_dtype
fft2, ifft2, fft2shift, ifft2shift, abs2, at, Float,
to_complex_dtype, to_real_dtype, block_until_ready,
)
import phaser.utils.tree as tree
from phaser.utils.optics import fourier_shift_filter
Expand Down Expand Up @@ -157,10 +158,6 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState:
observer: Observer = args.get('observer', Observer())
state = args['state']
seed = args['seed']
patterns = args['data'].patterns
pattern_mask = xp.array(args['data'].pattern_mask)
assert_dtype(patterns, dtype)
assert_dtype(pattern_mask, dtype)

noise_model = props.noise_model(None)

Expand All @@ -186,13 +183,36 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState:
)
start_i = int(state.iter.total_iter)

# check patterns dtype
assert_dtype(args['data'].patterns, dtype)
assert_dtype(args['data'].pattern_mask, dtype)
# load pattern mask
pattern_mask = xp.asarray(args['data'].pattern_mask)

# and load/stream patterns
if props.buffer_n_groups is None:
logging.info("Loading raw data to GPU ('buffer_n_groups' is disabled)...")
patterns = xp.asarray(args['data'].patterns)
else:
logging.info(f"Streaming raw data to GPU (buffering {props.buffer_n_groups} groups)")
patterns = args['data'].patterns

def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple[NDArray[numpy.int_], NDArray[numpy.floating]]]:
if props.buffer_n_groups is None:
return ((group, patterns[tuple(xp.asarray(group))]) for group in groups)
return stream_patterns(
groups, patterns, xp=xp, buf_n=props.buffer_n_groups
)

propagators = make_propagators(state, props.bwlim_frac)

# runs rescaling
rescale_factors = []
for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan),
patterns, xp=xp, buf_n=props.buffer_n_groups)):
group_rescale_factors = dry_run(state, group, propagators, group_patterns, xp=xp, dtype=dtype)
for (group_i, (group, group_patterns)) in enumerate(iter_patterns(groups.iter(state.scan))):
group_rescale_factors = dry_run(
state, group, propagators, group_patterns,
xp=xp, dtype=dtype,
)
rescale_factors.append(group_rescale_factors)

rescale_factors = xp.concatenate(rescale_factors, axis=0)
Expand Down Expand Up @@ -239,7 +259,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState:
iter_shuffle_groups = shuffle_groups({'state': state, 'niter': props.niter})

# accumulated losses across groups
losses: t.Dict[str, float] = {k: 0.0 for k in loss_keys}
losses_gpu = {k: t.cast(numpy.floating, xp.array(0.0)) for k in loss_keys}

# update schedules for this iteration
# this needs to be done outside the JIT context, which makes this kinda hacky
Expand All @@ -252,30 +272,36 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState:
for (solver, solver_state) in zip(iter_solvers, iter_solver_states)
]

for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan, i, iter_shuffle_groups),
patterns, xp=xp, buf_n=props.buffer_n_groups)):
(state, group_losses, iter_grads, solver_states) = run_group(
for (group_i, (group, group_patterns)) in enumerate(iter_patterns(groups.iter(state.scan, i, iter_shuffle_groups))):
# prevent the loop running ahead of the GPU stream
block_until_ready(losses_gpu['total_loss'])

(state, losses_gpu, iter_grads, solver_states) = run_group(
state, group=group, vars=iter_vars,
noise_model=noise_model,
group_solvers=group_solvers,
group_constraints=group_constraints,
regularizers=regularizers,
losses=losses_gpu,
iter_grads=iter_grads,
solver_states=solver_states,
props=propagators,
group_patterns=group_patterns, #load_group(group),
pattern_mask=pattern_mask,
probe_int=probe_int,
xp=xp, dtype=dtype
xp=xp, dtype=dtype,
jit_unroll_slices=props.jit_unroll_slices,
)

losses = tree.map(xp.add, losses, group_losses)

check_finite(state.object.data, state.probe.data, context=f"object or probe, group {group_i}")
if props.check_every_group and not numpy.isfinite(float(losses_gpu['total_loss'])):
raise ValueError(f"NaN or inf encountered, group {group_i}")
observer.update_group(state, props.send_every_group)

if not numpy.isfinite(float(losses_gpu['total_loss'])):
raise ValueError(f"NaN or inf encountered, iteration {i}")

# report losses normalized by # of probe positions
losses = tree.map(lambda v: float(v / groups.n_pos), losses)
# this also moves losses to CPU
losses: t.Dict[str, float] = tree.map(lambda v: float(v) / groups.n_pos, losses_gpu)
for (k, v) in losses.items():
progress[k].iters.append(i + start_i)
progress[k].values.append(v)
Expand Down Expand Up @@ -303,7 +329,6 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState:
progress['tilt_update_rms'].values.append(tilt_update_rms)
logger.info(f"Tilt update: mean {mean_tilt_update} mrad") # average tilt update, [y, x]


for (reg_i, reg) in enumerate(iter_constraints):
(state, iter_constraint_states[reg_i]) = reg.apply_iter(
state, iter_constraint_states[reg_i]
Expand All @@ -326,7 +351,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState:

@partial(
jit,
static_argnames=('vars', 'xp', 'dtype', 'noise_model', 'group_solvers', 'group_constraints', 'regularizers'),
static_argnames=('vars', 'xp', 'dtype', 'noise_model', 'group_solvers', 'group_constraints', 'regularizers', 'jit_unroll_slices'),
donate_argnames=('state', 'iter_grads', 'solver_states'),
)
def run_group(
Expand All @@ -337,6 +362,7 @@ def run_group(
group_solvers: t.Sequence[GradientSolver[t.Any]],
group_constraints: t.Sequence[GroupConstraint[t.Any]],
regularizers: t.Sequence[CostRegularizer[t.Any]],
losses: t.Dict[str, numpy.floating],
iter_grads: t.Dict[ReconsVar, t.Any],
solver_states: SolverStates,
props: t.Optional[NDArray[numpy.complexfloating]],
Expand All @@ -345,14 +371,15 @@ def run_group(
probe_int: t.Union[float, numpy.floating],
xp: t.Any,
dtype: t.Type[numpy.floating],
) -> t.Tuple[ReconsState, t.Dict[str, Float], t.Dict[ReconsVar, t.Any], SolverStates]:
jit_unroll_slices: t.Union[int, bool],
) -> t.Tuple[ReconsState, t.Dict[str, numpy.floating], t.Dict[ReconsVar, t.Any], SolverStates]:
xp = cast_array_module(xp)

(grad, (solver_states, losses)) = tree.grad(run_model, has_aux=True, xp=xp, sign=-1)(
(grad, (solver_states, group_losses)) = tree.grad(run_model, has_aux=True, xp=xp, sign=-1)(
*extract_vars(state, vars, group),
group=group, props=props, group_patterns=group_patterns, pattern_mask=pattern_mask,
noise_model=noise_model, regularizers=regularizers, solver_states=solver_states,
xp=xp, dtype=dtype
xp=xp, dtype=dtype, jit_unroll_slices=jit_unroll_slices
)
for k in grad.keys():
# scale gradients appropriately
Expand All @@ -371,7 +398,7 @@ def run_group(
if len(solver_grads) == 0:
continue
(update, solver_states.group_solver_states[sol_i]) = solver.update(
state, solver_states.group_solver_states[sol_i], solver_grads, losses['total_loss']
state, solver_states.group_solver_states[sol_i], solver_grads, group_losses['total_loss']
)
state = apply_update(state, update)

Expand All @@ -380,12 +407,13 @@ def run_group(
group, state, solver_states.group_constraint_states[reg_i]
)

losses = tree.map(xp.add, losses, group_losses)
return (state, losses, iter_grads, solver_states)


@partial(
jit,
static_argnames=('xp', 'dtype', 'noise_model', 'regularizers'),
static_argnames=('xp', 'dtype', 'noise_model', 'regularizers', 'jit_unroll_slices'),
donate_argnames=('solver_states',),
)
def run_model(
Expand All @@ -400,6 +428,7 @@ def run_model(
solver_states: SolverStates,
xp: t.Any,
dtype: t.Type[numpy.floating],
jit_unroll_slices: t.Union[int, bool],
) -> t.Tuple[Float, t.Tuple[SolverStates, t.Dict[str, Float]]]:
# apply vars to simulation
sim = insert_vars(vars, sim, group)
Expand All @@ -409,23 +438,23 @@ def run_model(
(ky, kx) = sim.probe.sampling.recip_grid(dtype=dtype, xp=xp)
xp = get_array_module(sim.probe.data)
dtype = to_real_dtype(sim.probe.data.dtype)
#complex_dtype = to_complex_dtype(dtype)

probes = sim.probe.data
group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, group_scan, probes.shape[-2:])
group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(group_scan, probes.shape[-2:]))[:, None, ...]
probes = ifft2(fft2(probes) * group_subpx_filters)
# preshift probe and object
probes = ifft2shift(sim.probe.data)
group_obj = ifft2shift(sim.object.sampling.get_view_at_pos(sim.object.data, group_scan, probes.shape[-2:]))
group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(group_scan, probes.shape[-2:]))
# (group, mode, y, x)
probes = ifft2(fft2(probes, shift=False) * group_subpx_filters[:, None], shift=False)

def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi):
# psi: (batch, n_probe, Ny, Nx)
if prop is not None:
return ifft2(fft2(psi * group_obj[:, slice_i, None]) * prop[:, None])
return ifft2(fft2(psi * group_obj[:, slice_i, None], shift=False) * prop[:, None], shift=False)
return psi * group_obj[:, slice_i, None]

t_props = tilt_propagators(ky, kx, sim, props, group_tilts)
model_wave = fft2(slice_forwards(t_props, probes, sim_slice))
model_intensity = xp.sum(abs2(model_wave), axis=1)
model_wave = fft2(slice_forwards(t_props, probes, sim_slice, jit_unroll_slices=jit_unroll_slices), shift=False)

model_intensity = xp.sum(abs2(model_wave), axis=1)
(loss, solver_states.noise_model_state) = noise_model.calc_loss(
model_wave, model_intensity, group_patterns, pattern_mask, solver_states.noise_model_state
)
Expand Down Expand Up @@ -458,19 +487,20 @@ def dry_run(
dtype: t.Type[numpy.floating],
) -> NDArray[numpy.floating]:
(ky, kx) = sim.probe.sampling.recip_grid(dtype=dtype, xp=xp)
group_scan = sim.scan[tuple(group)]

probes = sim.probe.data
group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, sim.scan[tuple(group)], probes.shape[-2:])
group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(sim.scan[tuple(group)], probes.shape[-2:]))[:, None, ...]
probes = ifft2(fft2(probes) * group_subpx_filters)
probes = ifft2shift(sim.probe.data)
group_obj = ifft2shift(sim.object.sampling.get_view_at_pos(sim.object.data, group_scan, probes.shape[-2:]))
group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(group_scan, probes.shape[-2:]))
probes = ifft2(fft2(probes, shift=False) * group_subpx_filters[:, None], shift=False)

def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi):
if prop is not None:
return ifft2(fft2(psi * group_obj[:, slice_i, None]) * prop[:, None])
return ifft2(fft2(psi * group_obj[:, slice_i, None], shift=False) * prop[:, None], shift=False)
return psi * group_obj[:, slice_i, None]

t_props = tilt_propagators(ky, kx, sim, props, sim.tilt[tuple(group)] if sim.tilt is not None else None)
model_wave = fft2(slice_forwards(t_props, probes, sim_slice))
model_wave = fft2(slice_forwards(t_props, probes, sim_slice), shift=False)
model_intensity = xp.sum(abs2(model_wave), axis=(1, -2, -1))
exp_intensity = xp.sum(group_patterns, axis=(-2, -1))

Expand Down
2 changes: 1 addition & 1 deletion phaser/hooks/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def update(


class GradientSolverArgs(t.TypedDict):
plan: 'GradientEnginePlan'
plan: t.Optional['GradientEnginePlan']
params: t.Iterable[ReconsVar]


Expand Down
17 changes: 16 additions & 1 deletion phaser/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,21 @@ class EnginePlan(Dataclass, kw_only=True):
grouping: t.Optional[int] = None
compact: bool = False
shuffle_groups: t.Optional[FlagLike] = None
buffer_n_groups: int = 2
buffer_n_groups: t.Optional[int] = 2
"""
How many groups of patterns to buffer onto the device simultaneously.
Set to 0 to disable buffering, or `None` (`~` in YAML) to preload the
entire dataset to the device.
"""

jit_unroll_slices: t.Union[bool, int] = 10
"""
Slices to unroll during JIT compilation (JAX backend only).
Larger unrolling may be faster, at the expense of increased compilation time.

`True` or `0` unrolls all slices, `False` or `1` disables unrolling.
`10` should be a good default value.
"""

update_probe: FlagLike = True
update_object: FlagLike = True
Expand All @@ -80,6 +94,7 @@ class EnginePlan(Dataclass, kw_only=True):
(smooths over ~1/smoothing iterations)
"""

check_every_group: bool = False
send_every_group: bool = False


Expand Down
7 changes: 7 additions & 0 deletions phaser/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,13 @@ class PreparedRecons:
name: str
observer: 'ObserverSet'

def to_xp(self, xp: t.Any) -> Self:
return self.__class__(
self.patterns,
self.state.to_xp(xp),
self.name, self.observer,
)

def to_numpy(self) -> Self:
return self.__class__(
self.patterns.to_numpy(), self.state.to_numpy(), self.name, self.observer
Expand Down
11 changes: 7 additions & 4 deletions phaser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,12 @@ def collect_errors(self, val: t.Any) -> t.Optional[ErrorNode]:


__all__ = [
'BackendName', 'Dataclass', 'Slices', 'Flag',
'cast_length', 'BackendName', 'ReconsVar', 'ReconsVars',
'EmptyDict', 'EarlyTermination', 'Dataclass',
'SliceList', 'SliceStep', 'SliceTotal', 'Slices',
'ComplexCartesian', 'ComplexPolar',
'Krivanek', 'KrivanekComplex', 'KrivanekCartesian',
'KrivanekPolar', 'KnownAberration', 'Aberration',
'process_aberrations', 'process_flag', 'flag_any_true',
'Krivanek', 'KrivanekComplex', 'KrivanekCartesian', 'KrivanekPolar',
'KnownAberration', 'Aberration', 'process_aberrations',
'SimpleFlag', 'process_flag', 'process_schedule', 'flag_any_true',
'IsVersion', 'Version',
]
7 changes: 5 additions & 2 deletions phaser/utils/_jax_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,11 @@ def inner(obj: jax.Array, v: t.Tuple[jax.Array, jax.Array]) -> t.Tuple[jax.Array

@partial(jax.jit, static_argnums=2)
def get_cutouts(obj: jax.Array, start_idxs: jax.Array, cutout_shape: t.Tuple[int, int]) -> jax.Array:
return jax.vmap(jax.vmap(lambda start_idx, obj: jax.lax.dynamic_slice(obj, start_idx, cutout_shape), (None, 0)), (0, None))(
to_2d(start_idxs), to_3d(obj)
idxs = to_2d(start_idxs)
ys = idxs[:, 0:1] + jax.numpy.arange(cutout_shape[0])
xs = idxs[:, 1:2] + jax.numpy.arange(cutout_shape[1])
return jax.numpy.swapaxes(
to_3d(obj)[..., ys[:, :, None], xs[:, None, :]], 0, 1
).reshape((*start_idxs.shape[:-1], *obj.shape[:-2], *cutout_shape))


Expand Down
Loading
Loading