From 2fa4bd08f58f9da3bba9a59dcf6e6691bfd3eeff Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 25 Feb 2026 19:35:54 -0500 Subject: [PATCH 01/50] add timing printouts within cov2d --- src/aspire/covariance/covar2d.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index d0bb73b57e..1c852fe81f 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -554,6 +554,7 @@ def _build(self): self.basis.filter_to_basis_mat(f, pixel_size=self.src.pixel_size) for f in unique_filters ] + logger.info("Represent CTF filters in basis complete") def _calc_rhs(self): src = self.src @@ -683,6 +684,7 @@ def _solve_covar(self, A_covar, b_covar, M, covar_est_opt): return method(A_covar, b_covar, M, covar_est_opt) def _solve_covar_direct(self, A_covar, b_covar, M, covar_est_opt): + t0 = perf_counter() # A_covar is a list of DiagMatrix, representing each ctf in self.basis. # b_covar is a BlkDiagMatrix # M is sum of weighted A squared. @@ -700,9 +702,13 @@ def _solve_covar_direct(self, A_covar, b_covar, M, covar_est_opt): # in Yunpeng's code. res = Minv @ b_covar @ Minv + t1 = perf_counter() + logger.info(f"_solve_covar_direct elapsed: {t1-t0}") return res def _solve_covar_cg(self, A_covar, b_covar, M, covar_est_opt): + t0 = perf_counter() + def precond_fun(S, x): p = np.size(S, 0) assert np.size(x) == p * p, "The sizes of S and x are not consistent." @@ -736,6 +742,8 @@ def apply(A, x): ) covar_coef[ell] = covar_coef_ell.reshape(p, p) + t1 = perf_counter() + logger.info(f"_solve_covar_cgelapsed: {t1-t0}") return covar_coef def get_mean(self): From ad8c109a82eb0397e2dbb2178850ef10f664f7bc Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 25 Feb 2026 19:40:41 -0500 Subject: [PATCH 02/50] add fetch and eval_t times --- src/aspire/covariance/covar2d.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 1c852fe81f..1dfc1823ac 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -1,4 +1,5 @@ import logging +from time import perf_counter import numpy as np from numpy.linalg import eig, inv @@ -569,11 +570,20 @@ def _calc_rhs(self): b_covar = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape, dtype=self.dtype) + cumulative_image_fetch_time = 0 + cumulative_eval_t_time = 0 for start in range(0, src.n, self.batch_size): batch = np.arange(start, min(start + self.batch_size, src.n)) + t0 = perf_counter() im = src.images[batch[0] : batch[0] + len(batch)] + t1 = perf_counter() + cumulative_image_fetch_time += t1 - t0 + + t0 = perf_counter() coef = basis.evaluate_t(im).asnumpy() + t1 = perf_counter() + cumulative_eval_t_time += t1 - t0 for k in np.unique(ctf_idx[batch]): coef_k = coef[ctf_idx[batch] == k] @@ -603,6 +613,8 @@ def _calc_rhs(self): self.b_mean = b_mean self.b_covar = b_covar + logger.info(f"cumulative_image_fetch_time {cumulative_image_fetch_time}") + logger.info(f"cumulative_eval_t_time {cumulative_eval_t_time}") def _calc_op(self): src = self.src From 9e4b88e54b38f429042d239543be0864aaf46ce4 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 19 Mar 2026 14:12:43 -0400 Subject: [PATCH 03/50] stashing initial rad ctf port, trying compare to our filter --- src/aspire/basis/fb_2d.py | 4 +- src/aspire/basis/ffb_2d.py | 2 +- src/aspire/basis/fle_2d.py | 96 ++++++++++++++++++++++++++++++-- src/aspire/basis/fpswf_2d.py | 4 +- src/aspire/basis/fspca.py | 2 +- src/aspire/basis/pswf_2d.py | 4 +- src/aspire/basis/steerable.py | 67 +++++++++++++++++++++- src/aspire/covariance/covar2d.py | 4 +- 8 files changed, 165 insertions(+), 18 deletions(-) diff --git a/src/aspire/basis/fb_2d.py b/src/aspire/basis/fb_2d.py index 6477d23e5c..3c02414535 100644 --- a/src/aspire/basis/fb_2d.py +++ b/src/aspire/basis/fb_2d.py @@ -289,8 +289,8 @@ def calculate_bispectrum( freq_cutoff=freq_cutoff, ) - def filter_to_basis_mat(self, *args, **kwargs): + def _filter_to_basis_mat(self, *args, **kwargs): """ See `SteerableBasis2D.filter_to_basis_mat`. """ - return super().filter_to_basis_mat(*args, **kwargs) + return super()._filter_to_basis_mat(*args, **kwargs) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 2f871af9f5..151cef4066 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -236,7 +236,7 @@ def _evaluate_t(self, x): return xp.asnumpy(v) - def filter_to_basis_mat(self, f, **kwargs): + def _filter_to_basis_mat(self, f, **kwargs): """ See `SteerableBasis2D.filter_to_basis_mat`. """ diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index cc6ea04672..a860c69369 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -298,7 +298,7 @@ def _compute_nufft_points(self): nodes = ( self.greatest_lambda - self.smallest_lambda ) * nodes + self.smallest_lambda - nodes = nodes.reshape(self.num_radial_nodes, 1) + self.nodes = nodes.reshape(self.num_radial_nodes, 1) radius = self.nres / 2 h = 1 / radius @@ -314,7 +314,7 @@ def _compute_nufft_points(self): ) grid_xy[0] = xp.cos(phi) # x grid_xy[1] = xp.sin(phi) # y - grid_xy[:] = grid_xy * nodes * h + grid_xy[:] = grid_xy * self.nodes * h self.grid_xy = grid_xy.reshape(2, -1) def _build_interpolation_matrix(self): @@ -738,6 +738,8 @@ def radial_convolve(self, coefs, radial_img): _coefs = coefs[k, :] z = self._step1_t(radial_img) b = self._step2_t(z) + # squeeze previously in _radial_convolve_weights + b = b.squeeze() weights = self._radial_convolve_weights(b) b = weights / (self.h**2) b = b.reshape(self.count) @@ -753,14 +755,15 @@ def _radial_convolve_weights(self, b): """ Helper function for step 3 of convolving with a radial function. """ - b = xp.squeeze(b) + # Developer note, this is equivalent `fle2d.expand_radial_vec` up to shapes. b = xp.array(b) # implies copy if self.num_interp > self.num_radial_nodes: b = fft.dct(b, axis=0, type=2) / (2 * self.num_radial_nodes) - bz = xp.zeros(b.shape) + bz = xp.zeros(b.shape, dtype=self.dtype) b = xp.concatenate((b, bz), axis=0) b = fft.idct(b, axis=0, type=2) * 2 * b.shape[0] - a = xp.zeros(self.count, dtype=np.float64) + a = xp.zeros(self.count, dtype=self.dtype) + ## xx note these can be collapsed into one loop later y = [None] * (self.ell_p_max + 1) for i in range(self.ell_p_max + 1): y[i] = (self.A3[i] @ b[:, 0]).flatten() @@ -769,7 +772,7 @@ def _radial_convolve_weights(self, b): return a.flatten() - def filter_to_basis_mat(self, f, **kwargs): + def _filter_to_basis_mat(self, f, **kwargs): """ See `SteerableBasis2D.filter_to_basis_mat`. """ @@ -818,3 +821,84 @@ def filter_to_basis_mat(self, f, **kwargs): h_basis = h_basis[self._fle_to_fb_indices] return DiagMatrix(xp.asnumpy(h_basis)) + + # def _fle_expand_radial_vec(self, radial_vec): + + # radial_vec = radial_vec.T + # #if self.n_interp > self.n_radial: + # if self.num_interp > self.num_radial_nodes: + # radial_vec = fft.dct(radial_vec, axis=0, type=2, workers=-1) / (2 * self.num_radial_nodes) + # radial_vec_z = xp.zeros(radial_vec.shape) + # radial_vec = xp.concatenate((radial_vec, radial_vec_z), axis=0) + # radial_vec = ( + # fft.idct(radial_vec, axis=0, type=2, workers=-1) * 2 * radial_vec.shape[0] + # ) + + # radial_fb = xp.zeros((self.count, radial_vec.shape[1]), dtype=self.dtype) + + # for i in range(self.ell_p_max + 1): + # radial_fb[self.idx_list[i], :] = self.A3[i] @ radial_vec + + # return radial_fb.T + + def expand_radial_vec(self, radial_vec, **kwargs): + coefs = self._radial_convolve_weights(radial_vec) + # _coefs = self._fle_expand_radial_vec(radial_vec.T) + # assert coefs.dtype == _coefs.dtype + # assert np.allclose(coefs,_coefs) + # #breakpoint() + + # check... + # Convert to internal FLE indices ordering + coefs = coefs[..., self._fb_to_fle_indices] + # squeeze should probably be addressed in consuming code, + # for now match old `filter_to_basis_mat` + coefs = xp.asnumpy(coefs).squeeze() + + return DiagMatrix(coefs) + + # def expand_radial_vec(self, radial_vec, **kwargs): + + # radial_vec = xp.asarray(radial_vec) + + # ## XXX looks like we do in fact need the padding/size-correction here... + # if self.num_interp > self.num_radial_nodes: + # radial_vec = fff.dct(radial_vec, axis=1, type=2) / (2 * self.num_radial_nodes) + # radial_vec_z = xp.zeros(radial_vec.shape) + # radial_vec = xp.concatenate((radial_vec, radial_vec_z), axis=1) + # radial_vec = fff.idct(radial_vec, axis=1, type=2) * 2 * radial_vec.shape[1] + + # # appears equiv to angular ordering code + # h_basis = xp.zeros(self.count, dtype=self.dtype) + # # For now we just need to handle 1D (stack of one ctf) + # breakpoint() + # for j in range(self.ell_p_max + 1): + # h_basis[self.idx_list[j]] = self.A3[j] @ radial_vec + + # # Convert from internal FLE ordering to FB convention + # h_basis = h_basis[self._fle_to_fb_indices] + + # return DiagMatrix(xp.asnumpy(h_basis)) + + def _radial_ctf_filter_to_filter_vals(self, f, **kwargs): + """ + Unpack filter attributes and pass to Yunpeng code. + """ + + pts = xp.asnumpy(self.nodes) + + # _filter_pts = np.pad(pts.reshape(1,-1), ((0,1),(0,0))) + # h_vals = f.evaluate(_filter_pts, **kwargs) + + pixel_size = kwargs.get("pixel_size") + h_vals = self._radial_ctf( + f.voltage, + f.Cs, + f.alpha, + (f.defocus_u + f.defocus_v) / 2, + pixel_size, + self.h, + pts, + ) + # breakpoint() + return h_vals diff --git a/src/aspire/basis/fpswf_2d.py b/src/aspire/basis/fpswf_2d.py index 6eac1a91db..4c498757e7 100644 --- a/src/aspire/basis/fpswf_2d.py +++ b/src/aspire/basis/fpswf_2d.py @@ -367,8 +367,8 @@ def _pswf_integration(self, images_nufft): return coef_vec_quad - def filter_to_basis_mat(self, *args, **kwargs): + def _filter_to_basis_mat(self, *args, **kwargs): """ See `SteerableBasis2D.filter_to_basis_mat`. """ - return super().filter_to_basis_mat(*args, **kwargs) + return super()._filter_to_basis_mat(*args, **kwargs) diff --git a/src/aspire/basis/fspca.py b/src/aspire/basis/fspca.py index d2f78b2ff7..053780b27a 100644 --- a/src/aspire/basis/fspca.py +++ b/src/aspire/basis/fspca.py @@ -617,7 +617,7 @@ def shift(self, coef, shifts): self.evaluate_to_image_basis(coef).shift(shifts) ) - def filter_to_basis_mat(self, f, **kwargs): + def _filter_to_basis_mat(self, f, **kwargs): """ Convert a filter into a basis representation. diff --git a/src/aspire/basis/pswf_2d.py b/src/aspire/basis/pswf_2d.py index c9795ec1bc..632aab14a3 100644 --- a/src/aspire/basis/pswf_2d.py +++ b/src/aspire/basis/pswf_2d.py @@ -398,8 +398,8 @@ def _pswf_2d_minor_computations(self, big_n, n, bandlimit, phi_approximate_error range_array = np.arange(approx_length, dtype=self.dtype) return d_vec, approx_length, range_array - def filter_to_basis_mat(self, *args, **kwargs): + def _filter_to_basis_mat(self, *args, **kwargs): """ See `SteerableBasis2D.filter_to_basis_mat`. """ - return super().filter_to_basis_mat(*args, **kwargs) + return super()._filter_to_basis_mat(*args, **kwargs) diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index b5051065da..04a65bbf99 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -1,12 +1,14 @@ import abc import logging +import warnings from collections.abc import Iterable import numpy as np from aspire.basis import Basis, Coef, ComplexCoef -from aspire.operators import BlkDiagMatrix -from aspire.utils import LogFilterByCount, complex_type, real_type, trange +from aspire.numeric import xp +from aspire.operators import BlkDiagMatrix, CTFFilter, DiagMatrix +from aspire.utils import LogFilterByCount, complex_type, grid_1d, real_type, trange logger = logging.getLogger(__name__) @@ -476,12 +478,57 @@ def to_complex(self, coef): return ComplexCoef(self, complex_coef) + # @abc.abstractmethod + # def expand_radial_vec(self, h_vals, **kwargs): + # """ + # Expand a radial vector given by `h_vals` into a basis mat. + + # :param h_vals: Radial vector(s) + # :return: Basis representation (may be `BlkDiagMatrix`, or `DiagMatrix`) depending on basis. + # """ + # # By default code can point here for a slow implementation. + # # A basis with a specialized solution should implementat that in the respective subclass. + # return basis_mat + + def filter_to_basis_mat(self, f, radial=None, **kwargs): + """ + Convert a filter into a basis operator representation. + + See `_filter_to_basis_mat` here and in subclasses for available **kwargs. + + :param f: `Filter` object, usually a `CTFFilter`. + :param radial: Optionally attempt radial approximation if available. + + :return: Representation of filter as `basis` operator. + Return type will be based on the class's `matrix_type`. + """ + + if (radial == True) and callable( + getattr(self.__class__, "expand_radial_vec", None) + ): + # previous code + # _res = self._filter_to_basis_mat(f, **kwargs) + + # kwargs supports passing through pixel_size + h_vals = self._radial_ctf_filter_to_filter_vals(f, **kwargs).reshape(-1, 1) + + warnings.warn("Using `expand_radial_vec`", UserWarning, stacklevel=1) + res = self.expand_radial_vec(h_vals) + # breakpoint() + return res + else: + warnings.warn( + "Using generic `_filter_to_basis_mat'", UserWarning, stacklevel=1 + ) + # use generic (legacy) filter path/code (may return DiagMatrix) + return self._filter_to_basis_mat(f, **kwargs) + # `abstractmethod` enforces when a new subclass of # `SteerableBasis2D` is created that this method is explicitly # implemented. This is intended to encourage future basis authors # to consider this method for their application. @abc.abstractmethod - def filter_to_basis_mat(self, f, method="evaluate_t", truncate=True, **kwargs): + def _filter_to_basis_mat(self, f, method="evaluate_t", truncate=True, **kwargs): """ Convert a filter into a basis operator representation. @@ -534,3 +581,17 @@ def filter_to_basis_mat(self, f, method="evaluate_t", truncate=True, **kwargs): ) return filt + + #### xxx + + def _radial_ctf(self, voltage, cs, alpha, defocus, pixel_size, h, pts): + wavelength = 12.2643247 / np.sqrt(voltage * 1e3 + 0.978466 * voltage**2) + c2_vec = (-np.pi * wavelength * defocus).reshape(-1, 1) + c4_vec = (0.5 * np.pi * (cs * 1e7) * wavelength**3).reshape(-1, 1) + r2 = (pts * h / (pixel_size * 2 * np.pi)) ** 2 + r4 = r2**2 + + gamma = r2 @ c2_vec.T + r4 @ c4_vec.T + ctf_radial = np.sqrt(1 - alpha**2) * np.sin(gamma) - alpha * np.cos(gamma) + # assert ctf_radial.shape == self.num_radial_nodes, f"ctf_radial_shape {ctf_radial.shape} != num_radial_nodes {self.num_radial_nodes}" + return ctf_radial diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 1dfc1823ac..d49b9d0b9a 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -552,7 +552,9 @@ def _build(self): unique_filters = src.unique_filters self.ctf_idx = src.filter_indices self.ctf_basis = [ - self.basis.filter_to_basis_mat(f, pixel_size=self.src.pixel_size) + self.basis.filter_to_basis_mat( + f, pixel_size=self.src.pixel_size, radial=True + ) for f in unique_filters ] logger.info("Represent CTF filters in basis complete") From 2bffae378c3a4bd83e7ca1cc048078f2e6084718 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 31 Mar 2026 12:17:50 -0400 Subject: [PATCH 04/50] stashing to_radial and freq pt scaling patches --- src/aspire/basis/fle_2d.py | 14 +++++++++----- src/aspire/operators/filters.py | 15 +++++++++++++++ tests/test_covar2d_denoiser.py | 6 ++++-- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index a860c69369..157ebd2f53 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -13,7 +13,7 @@ from aspire.nufft import anufft, nufft from aspire.numeric import fft, sparse, xp from aspire.operators import DiagMatrix -from aspire.utils import complex_type, grid_2d +from aspire.utils import complex_type, grid_1d, grid_2d logger = logging.getLogger(__name__) @@ -887,11 +887,12 @@ def _radial_ctf_filter_to_filter_vals(self, f, **kwargs): pts = xp.asnumpy(self.nodes) - # _filter_pts = np.pad(pts.reshape(1,-1), ((0,1),(0,0))) - # h_vals = f.evaluate(_filter_pts, **kwargs) + _filter_pts = np.pad(pts.reshape(1, -1), ((0, 1), (0, 0))) * self.h + _astig_h_vals = f.evaluate(_filter_pts, **kwargs) + h_vals = f.to_radial().evaluate(_filter_pts, **kwargs) pixel_size = kwargs.get("pixel_size") - h_vals = self._radial_ctf( + _h_vals = self._radial_ctf( f.voltage, f.Cs, f.alpha, @@ -900,5 +901,8 @@ def _radial_ctf_filter_to_filter_vals(self, f, **kwargs): self.h, pts, ) - # breakpoint() + print("sum _astig_h_vals", np.sum(_astig_h_vals)) + print("sum h_vals", np.sum(h_vals)) + print("sum _h_vals", np.sum(_h_vals)) + breakpoint() return h_vals diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 3c3f362fc3..9a02c6b5eb 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -505,6 +505,21 @@ def _evaluate(self, omega, **kwargs): return h + def to_radial(self): + """ + Return a new `RadialCTFFilter` with the mean of astigmatic defocus values. + + :return: `RadialCTFFilter` + """ + mean_defocus = (self.defocus_u + self.defocus_v) / 2 + return RadialCTFFilter( + voltage=self.voltage, + defocus=mean_defocus, + Cs=self.Cs, + alpha=self.alpha, + B=self.B, + ) + class RadialCTFFilter(CTFFilter): def __init__(self, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0): diff --git a/tests/test_covar2d_denoiser.py b/tests/test_covar2d_denoiser.py index 8a4660c5c6..d0bb8cf6da 100644 --- a/tests/test_covar2d_denoiser.py +++ b/tests/test_covar2d_denoiser.py @@ -4,7 +4,7 @@ from aspire.basis import FBBasis2D, FFBBasis2D, FLEBasis2D, FPSWFBasis2D, PSWFBasis2D from aspire.denoising import DenoisedSource, DenoiserCov2D from aspire.noise import WhiteNoiseAdder -from aspire.operators import IdentityFilter, RadialCTFFilter +from aspire.operators import CTFFilter, IdentityFilter, RadialCTFFilter from aspire.source import Simulation from aspire.utils import utest_tolerance @@ -16,7 +16,9 @@ noise_adder = WhiteNoiseAdder(var=noise_var) pixel_size = 5 filters = [ - RadialCTFFilter(200, defocus=d, Cs=2.0, alpha=0.1) + CTFFilter( + 200, defocus_ang=np.pi / 3, defocus_u=d, defocus_v=d + 345, Cs=2.0, alpha=0.1 + ) for d in np.linspace(1.5e4, 2.5e4, 7) ] From 8f2a5535cb59c74b2e84e337e68edde590bddb04 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 31 Mar 2026 13:57:04 -0400 Subject: [PATCH 05/50] cleanup debugging logic a bit --- src/aspire/basis/fle_2d.py | 44 +++------------------------------ src/aspire/basis/steerable.py | 27 ++++++++++++++------ src/aspire/operators/filters.py | 8 ++++++ 3 files changed, 30 insertions(+), 49 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 157ebd2f53..0bbad31a75 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -857,30 +857,7 @@ def expand_radial_vec(self, radial_vec, **kwargs): return DiagMatrix(coefs) - # def expand_radial_vec(self, radial_vec, **kwargs): - - # radial_vec = xp.asarray(radial_vec) - - # ## XXX looks like we do in fact need the padding/size-correction here... - # if self.num_interp > self.num_radial_nodes: - # radial_vec = fff.dct(radial_vec, axis=1, type=2) / (2 * self.num_radial_nodes) - # radial_vec_z = xp.zeros(radial_vec.shape) - # radial_vec = xp.concatenate((radial_vec, radial_vec_z), axis=1) - # radial_vec = fff.idct(radial_vec, axis=1, type=2) * 2 * radial_vec.shape[1] - - # # appears equiv to angular ordering code - # h_basis = xp.zeros(self.count, dtype=self.dtype) - # # For now we just need to handle 1D (stack of one ctf) - # breakpoint() - # for j in range(self.ell_p_max + 1): - # h_basis[self.idx_list[j]] = self.A3[j] @ radial_vec - - # # Convert from internal FLE ordering to FB convention - # h_basis = h_basis[self._fle_to_fb_indices] - - # return DiagMatrix(xp.asnumpy(h_basis)) - - def _radial_ctf_filter_to_filter_vals(self, f, **kwargs): + def _radial_filter_to_vals(self, f, **kwargs): """ Unpack filter attributes and pass to Yunpeng code. """ @@ -888,21 +865,6 @@ def _radial_ctf_filter_to_filter_vals(self, f, **kwargs): pts = xp.asnumpy(self.nodes) _filter_pts = np.pad(pts.reshape(1, -1), ((0, 1), (0, 0))) * self.h - _astig_h_vals = f.evaluate(_filter_pts, **kwargs) - h_vals = f.to_radial().evaluate(_filter_pts, **kwargs) - - pixel_size = kwargs.get("pixel_size") - _h_vals = self._radial_ctf( - f.voltage, - f.Cs, - f.alpha, - (f.defocus_u + f.defocus_v) / 2, - pixel_size, - self.h, - pts, - ) - print("sum _astig_h_vals", np.sum(_astig_h_vals)) - print("sum h_vals", np.sum(h_vals)) - print("sum _h_vals", np.sum(_h_vals)) - breakpoint() + h_vals = f.evaluate(_filter_pts, **kwargs) + return h_vals diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index 04a65bbf99..0ccf9a5d82 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -490,31 +490,42 @@ def to_complex(self, coef): # # A basis with a specialized solution should implementat that in the respective subclass. # return basis_mat - def filter_to_basis_mat(self, f, radial=None, **kwargs): + def filter_to_basis_mat(self, f, radial_optimization=None, **kwargs): """ Convert a filter into a basis operator representation. See `_filter_to_basis_mat` here and in subclasses for available **kwargs. :param f: `Filter` object, usually a `CTFFilter`. - :param radial: Optionally attempt radial approximation if available. + :param radial_optimization: Optionally attempt radial approximation if available. :return: Representation of filter as `basis` operator. Return type will be based on the class's `matrix_type`. """ - if (radial == True) and callable( - getattr(self.__class__, "expand_radial_vec", None) + # does the basis have optimized expand for radial vectors? + optimized_expand = callable(getattr(self.__class__, "expand_radial_vec", None)) + + filter_is_radial = f.radial == True + # does the filter have `to_radial`? + filter_has_to_radial = callable(getattr(f, "to_radial", None)) + + if ( + (radial_optimization == True) + and optimized_expand + and (filter_is_radial or filter_has_to_radial) ): - # previous code - # _res = self._filter_to_basis_mat(f, **kwargs) + + # Make `f` radial as needed + if not filter_is_radial: + f = f.to_radial() # kwargs supports passing through pixel_size - h_vals = self._radial_ctf_filter_to_filter_vals(f, **kwargs).reshape(-1, 1) + h_vals = self._radial_filter_to_vals(f, **kwargs).reshape(-1, 1) warnings.warn("Using `expand_radial_vec`", UserWarning, stacklevel=1) res = self.expand_radial_vec(h_vals) - # breakpoint() + return res else: warnings.warn( diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 9a02c6b5eb..002e13c47c 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -533,6 +533,14 @@ def __init__(self, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0): B=B, ) + def to_radial(self): + """ + `RadialCTFFilter` is already radial, Returns self. Supports code interop. + + :return: `RadialCTFFilter` + """ + return self + class BlueFilter(Filter): """ From b8ab9d970d373326251fe64b2bc2fded2ff5f82a Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 31 Mar 2026 14:13:32 -0400 Subject: [PATCH 06/50] continue cleanup [skip ci] --- src/aspire/basis/fle_2d.py | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 0bbad31a75..0752c99bfe 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -822,35 +822,13 @@ def _filter_to_basis_mat(self, f, **kwargs): return DiagMatrix(xp.asnumpy(h_basis)) - # def _fle_expand_radial_vec(self, radial_vec): - - # radial_vec = radial_vec.T - # #if self.n_interp > self.n_radial: - # if self.num_interp > self.num_radial_nodes: - # radial_vec = fft.dct(radial_vec, axis=0, type=2, workers=-1) / (2 * self.num_radial_nodes) - # radial_vec_z = xp.zeros(radial_vec.shape) - # radial_vec = xp.concatenate((radial_vec, radial_vec_z), axis=0) - # radial_vec = ( - # fft.idct(radial_vec, axis=0, type=2, workers=-1) * 2 * radial_vec.shape[0] - # ) - - # radial_fb = xp.zeros((self.count, radial_vec.shape[1]), dtype=self.dtype) - - # for i in range(self.ell_p_max + 1): - # radial_fb[self.idx_list[i], :] = self.A3[i] @ radial_vec - - # return radial_fb.T - def expand_radial_vec(self, radial_vec, **kwargs): coefs = self._radial_convolve_weights(radial_vec) - # _coefs = self._fle_expand_radial_vec(radial_vec.T) - # assert coefs.dtype == _coefs.dtype - # assert np.allclose(coefs,_coefs) - # #breakpoint() # check... # Convert to internal FLE indices ordering coefs = coefs[..., self._fb_to_fle_indices] + # squeeze should probably be addressed in consuming code, # for now match old `filter_to_basis_mat` coefs = xp.asnumpy(coefs).squeeze() From 88089dccc2adf9ab8667f1092ee7b9406c52a64f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 9 Apr 2026 10:18:21 -0400 Subject: [PATCH 07/50] use existing expand_method, add radial, instead of new flag --- src/aspire/basis/steerable.py | 26 ++++++++------------------ src/aspire/covariance/covar2d.py | 5 +++-- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index 0ccf9a5d82..59ac5e2d27 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -490,7 +490,7 @@ def to_complex(self, coef): # # A basis with a specialized solution should implementat that in the respective subclass. # return basis_mat - def filter_to_basis_mat(self, f, radial_optimization=None, **kwargs): + def filter_to_basis_mat(self, f, **kwargs): """ Convert a filter into a basis operator representation. @@ -505,21 +505,11 @@ def filter_to_basis_mat(self, f, radial_optimization=None, **kwargs): # does the basis have optimized expand for radial vectors? optimized_expand = callable(getattr(self.__class__, "expand_radial_vec", None)) - + # is the filter radial? filter_is_radial = f.radial == True - # does the filter have `to_radial`? - filter_has_to_radial = callable(getattr(f, "to_radial", None)) - - if ( - (radial_optimization == True) - and optimized_expand - and (filter_is_radial or filter_has_to_radial) - ): - - # Make `f` radial as needed - if not filter_is_radial: - f = f.to_radial() + radial_method = kwargs.get("expand_method", None) == "radial" + if optimized_expand and filter_is_radial and radial_method: # kwargs supports passing through pixel_size h_vals = self._radial_filter_to_vals(f, **kwargs).reshape(-1, 1) @@ -539,12 +529,12 @@ def filter_to_basis_mat(self, f, radial_optimization=None, **kwargs): # implemented. This is intended to encourage future basis authors # to consider this method for their application. @abc.abstractmethod - def _filter_to_basis_mat(self, f, method="evaluate_t", truncate=True, **kwargs): + def _filter_to_basis_mat(self, f, method=None, truncate=True, **kwargs): """ Convert a filter into a basis operator representation. :param f: `Filter` object, usually a `CTFFilter`. - :param method: `evaluate_t` or `expand`. + :param method: `evaluate_t` or `expand`. Default `None` uses `evaluate_t`. :param truncate: Optionally, truncate dense matrix to BlkDiagMatrix. Defaults to True. @@ -552,13 +542,13 @@ def _filter_to_basis_mat(self, f, method="evaluate_t", truncate=True, **kwargs): Return type will be based on the class's `matrix_type`. """ # evaluate_t is not as accurate, but much much faster... - if method == "evaluate_t": + if method == "evaluate_t" or method == None: expand_method = self.evaluate_t elif method == "expand": expand_method = self.expand else: raise NotImplementedError( - "`filter_to_basis_mat` method {method} not supported." + f"`filter_to_basis_mat` method {method} not supported." " Try `evaluate_t` or `expand`." ) diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index d49b9d0b9a..a8f0449109 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -521,7 +521,7 @@ class BatchedRotCov2D(RotCov2D): scaling up to a larger value such as 8192 may yield better performance. """ - def __init__(self, src, basis=None, batch_size=512): + def __init__(self, src, basis=None, expand_method=None, batch_size=512): self.src = src self.basis = basis self.batch_size = batch_size @@ -532,6 +532,7 @@ def __init__(self, src, basis=None, batch_size=512): self.A_mean = None self.A_covar = None self.M_covar = None + self.expand_method = expand_method self._build() @@ -553,7 +554,7 @@ def _build(self): self.ctf_idx = src.filter_indices self.ctf_basis = [ self.basis.filter_to_basis_mat( - f, pixel_size=self.src.pixel_size, radial=True + f, pixel_size=self.src.pixel_size, method=self.expand_method ) for f in unique_filters ] From 710bcf50fbe47ec397157c4a272741bf32235906 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 9 Apr 2026 11:12:03 -0400 Subject: [PATCH 08/50] use existing expand_method, add radial, instead of new flag --- src/aspire/basis/ffb_2d.py | 7 ++++--- src/aspire/basis/fle_2d.py | 19 +++++++++++++++---- src/aspire/basis/steerable.py | 16 ++++++++-------- src/aspire/covariance/covar2d.py | 2 +- tests/test_covar2d_denoiser.py | 10 ++++++---- 5 files changed, 34 insertions(+), 20 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 151cef4066..64de6a0c5a 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -241,10 +241,11 @@ def _filter_to_basis_mat(self, f, **kwargs): See `SteerableBasis2D.filter_to_basis_mat`. """ # Note 'method' and 'truncate' not relevant for this optimized FFB code. - if kwargs.get("method", None) is not None: + expand_method = kwargs.get("expand_method", None) + if expand_method is not None: raise NotImplementedError( - "`FFBBasis2D.filter_to_basis_mat` method {method} not supported." - " Use `method=None`." + f"`FFBBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported." + " Use `expand_method=None`." ) pixel_size = kwargs.get("pixel_size", None) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 0752c99bfe..3ef6313510 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -775,13 +775,24 @@ def _radial_convolve_weights(self, b): def _filter_to_basis_mat(self, f, **kwargs): """ See `SteerableBasis2D.filter_to_basis_mat`. + + This code implements a radially averaged filter, returning a `DiagMatrix`. + It is meant to be used with non radial filter inputs. + For radial filters, an optimized code path may be available via `expand_method='radial'`. """ - # Note 'method' and 'truncate' not relevant for this optimized FLE code. - if kwargs.get("method", None) is not None: + # Note 'expand_method' and 'truncate' not relevant for this optimized FLE code. + expand_method = kwargs.get("expand_method", None) + if expand_method == "radial" and not f.radial: + raise NotImplementedError( + f"`FLEBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported for non radial {f}." + " Convert filter to radial for optimized radial expansion, or use `expand_method=None` for a radially averaged approximation." + ) + elif expand_method is not None: raise NotImplementedError( - "`FLEBasis2D.filter_to_basis_mat` method {method} not supported." - " Use `method=None`." + f"`FLEBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported." + " Try `expand_method=None` or 'radial'." ) + pixel_size = kwargs.get("pixel_size", None) # Get the filter's evaluate function. diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index 59ac5e2d27..8ed4efe9eb 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -529,12 +529,12 @@ def filter_to_basis_mat(self, f, **kwargs): # implemented. This is intended to encourage future basis authors # to consider this method for their application. @abc.abstractmethod - def _filter_to_basis_mat(self, f, method=None, truncate=True, **kwargs): + def _filter_to_basis_mat(self, f, expand_method=None, truncate=True, **kwargs): """ Convert a filter into a basis operator representation. :param f: `Filter` object, usually a `CTFFilter`. - :param method: `evaluate_t` or `expand`. Default `None` uses `evaluate_t`. + :param expand_method: `evaluate_t` or `expand`. Default `None` uses `evaluate_t`. :param truncate: Optionally, truncate dense matrix to BlkDiagMatrix. Defaults to True. @@ -542,13 +542,13 @@ def _filter_to_basis_mat(self, f, method=None, truncate=True, **kwargs): Return type will be based on the class's `matrix_type`. """ # evaluate_t is not as accurate, but much much faster... - if method == "evaluate_t" or method == None: - expand_method = self.evaluate_t - elif method == "expand": - expand_method = self.expand + if expand_method == "evaluate_t" or expand_method == None: + expand_fun = self.evaluate_t + elif expand_method == "expand": + expand_fun = self.expand else: raise NotImplementedError( - f"`filter_to_basis_mat` method {method} not supported." + f"`filter_to_basis_mat` expand_method '{expand_method}' not supported." " Try `evaluate_t` or `expand`." ) @@ -568,7 +568,7 @@ def _filter_to_basis_mat(self, f, method=None, truncate=True, **kwargs): with LogFilterByCount(logger, 1): for i in trange(self.count): try: - filt[i] = expand_method(img[i].filter(f)).asnumpy()[0] + filt[i] = expand_fun(img[i].filter(f)).asnumpy()[0] except Exception: logger.warning( f"Failed to expand basis vector {i} after filter {f}." diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index a8f0449109..81821c2e6c 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -554,7 +554,7 @@ def _build(self): self.ctf_idx = src.filter_indices self.ctf_basis = [ self.basis.filter_to_basis_mat( - f, pixel_size=self.src.pixel_size, method=self.expand_method + f, pixel_size=self.src.pixel_size, expand_method=self.expand_method ) for f in unique_filters ] diff --git a/tests/test_covar2d_denoiser.py b/tests/test_covar2d_denoiser.py index d0bb8cf6da..984acdd342 100644 --- a/tests/test_covar2d_denoiser.py +++ b/tests/test_covar2d_denoiser.py @@ -224,15 +224,17 @@ def test_filter_to_basis_mat_id_expand(coef, basis): # IdentityFilter should produce id filt = IdentityFilter() - # Some basis do not provide alternative `method`s + # Some basis do not provide alternative `expand_method`s if isinstance(basis, FFBBasis2D) or isinstance(basis, FLEBasis2D): with pytest.raises(NotImplementedError, match=r".*not supported.*"): - _ = basis.filter_to_basis_mat(filt, method="expand") + _ = basis.filter_to_basis_mat(filt, expand_method="expand") return # Apply the basis filter operator. # Note transpose because `apply` expects and returns column vectors. - coef_ftbm = (basis.filter_to_basis_mat(filt, method="expand") @ coef.asnumpy().T).T + coef_ftbm = ( + basis.filter_to_basis_mat(filt, expand_method="expand") @ coef.asnumpy().T + ).T # Apply evaluate->filter->expand manually imgs = coef.evaluate() @@ -252,4 +254,4 @@ def test_filter_to_basis_mat_id_expand(coef, basis): def test_filter_to_basis_mat_bad(coef, basis): filt = IdentityFilter() with pytest.raises(NotImplementedError, match=r".*not supported.*"): - _ = basis.filter_to_basis_mat(filt, method="bad_method") + _ = basis.filter_to_basis_mat(filt, expand_method="bad_method") From ef830a46a73226526056bec9cc3b700fbe945c74 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 9 Apr 2026 11:15:04 -0400 Subject: [PATCH 09/50] cleanup --- src/aspire/basis/steerable.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index 8ed4efe9eb..9d40d0d3a9 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -478,18 +478,6 @@ def to_complex(self, coef): return ComplexCoef(self, complex_coef) - # @abc.abstractmethod - # def expand_radial_vec(self, h_vals, **kwargs): - # """ - # Expand a radial vector given by `h_vals` into a basis mat. - - # :param h_vals: Radial vector(s) - # :return: Basis representation (may be `BlkDiagMatrix`, or `DiagMatrix`) depending on basis. - # """ - # # By default code can point here for a slow implementation. - # # A basis with a specialized solution should implementat that in the respective subclass. - # return basis_mat - def filter_to_basis_mat(self, f, **kwargs): """ Convert a filter into a basis operator representation. @@ -582,17 +570,3 @@ def _filter_to_basis_mat(self, f, expand_method=None, truncate=True, **kwargs): ) return filt - - #### xxx - - def _radial_ctf(self, voltage, cs, alpha, defocus, pixel_size, h, pts): - wavelength = 12.2643247 / np.sqrt(voltage * 1e3 + 0.978466 * voltage**2) - c2_vec = (-np.pi * wavelength * defocus).reshape(-1, 1) - c4_vec = (0.5 * np.pi * (cs * 1e7) * wavelength**3).reshape(-1, 1) - r2 = (pts * h / (pixel_size * 2 * np.pi)) ** 2 - r4 = r2**2 - - gamma = r2 @ c2_vec.T + r4 @ c4_vec.T - ctf_radial = np.sqrt(1 - alpha**2) * np.sin(gamma) - alpha * np.cos(gamma) - # assert ctf_radial.shape == self.num_radial_nodes, f"ctf_radial_shape {ctf_radial.shape} != num_radial_nodes {self.num_radial_nodes}" - return ctf_radial From 052acac3f6cf84f41d5411e14685fc0036e2cff8 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 9 Apr 2026 11:18:32 -0400 Subject: [PATCH 10/50] cleanup --- src/aspire/basis/fle_2d.py | 4 ++-- src/aspire/basis/steerable.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 3ef6313510..e67722f5af 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -13,7 +13,7 @@ from aspire.nufft import anufft, nufft from aspire.numeric import fft, sparse, xp from aspire.operators import DiagMatrix -from aspire.utils import complex_type, grid_1d, grid_2d +from aspire.utils import complex_type, grid_2d logger = logging.getLogger(__name__) @@ -763,7 +763,7 @@ def _radial_convolve_weights(self, b): b = xp.concatenate((b, bz), axis=0) b = fft.idct(b, axis=0, type=2) * 2 * b.shape[0] a = xp.zeros(self.count, dtype=self.dtype) - ## xx note these can be collapsed into one loop later + # xx note these can be collapsed into one loop later y = [None] * (self.ell_p_max + 1) for i in range(self.ell_p_max + 1): y[i] = (self.A3[i] @ b[:, 0]).flatten() diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index 9d40d0d3a9..dcc1d5edcd 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -6,9 +6,8 @@ import numpy as np from aspire.basis import Basis, Coef, ComplexCoef -from aspire.numeric import xp -from aspire.operators import BlkDiagMatrix, CTFFilter, DiagMatrix -from aspire.utils import LogFilterByCount, complex_type, grid_1d, real_type, trange +from aspire.operators import BlkDiagMatrix +from aspire.utils import LogFilterByCount, complex_type, real_type, trange logger = logging.getLogger(__name__) @@ -494,7 +493,8 @@ def filter_to_basis_mat(self, f, **kwargs): # does the basis have optimized expand for radial vectors? optimized_expand = callable(getattr(self.__class__, "expand_radial_vec", None)) # is the filter radial? - filter_is_radial = f.radial == True + filter_is_radial = f.radial is True + # did user request the special radial expansion method? radial_method = kwargs.get("expand_method", None) == "radial" if optimized_expand and filter_is_radial and radial_method: @@ -530,7 +530,7 @@ def _filter_to_basis_mat(self, f, expand_method=None, truncate=True, **kwargs): Return type will be based on the class's `matrix_type`. """ # evaluate_t is not as accurate, but much much faster... - if expand_method == "evaluate_t" or expand_method == None: + if expand_method == "evaluate_t" or expand_method is None: expand_fun = self.evaluate_t elif expand_method == "expand": expand_fun = self.expand From 85c96fae7e31fe47bf54c70cf047638ce47eff27 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 9 Apr 2026 14:00:18 -0400 Subject: [PATCH 11/50] remove warnings warning sysetm needs improving --- src/aspire/basis/steerable.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index dcc1d5edcd..44a4e3de3f 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -1,6 +1,5 @@ import abc import logging -import warnings from collections.abc import Iterable import numpy as np @@ -500,15 +499,9 @@ def filter_to_basis_mat(self, f, **kwargs): if optimized_expand and filter_is_radial and radial_method: # kwargs supports passing through pixel_size h_vals = self._radial_filter_to_vals(f, **kwargs).reshape(-1, 1) - - warnings.warn("Using `expand_radial_vec`", UserWarning, stacklevel=1) res = self.expand_radial_vec(h_vals) - return res else: - warnings.warn( - "Using generic `_filter_to_basis_mat'", UserWarning, stacklevel=1 - ) # use generic (legacy) filter path/code (may return DiagMatrix) return self._filter_to_basis_mat(f, **kwargs) From 8d2c1dabf435ddcb97f1fc99496fec83ef01a50f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 10 Apr 2026 09:10:57 -0400 Subject: [PATCH 12/50] fix logic error --- src/aspire/basis/fle_2d.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index e67722f5af..3a6bf6121b 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -782,16 +782,16 @@ def _filter_to_basis_mat(self, f, **kwargs): """ # Note 'expand_method' and 'truncate' not relevant for this optimized FLE code. expand_method = kwargs.get("expand_method", None) + if expand_method not in [None, "radial"]: + raise NotImplementedError( + f"`FLEBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported." + " Try `expand_method=None` or 'radial'." + ) if expand_method == "radial" and not f.radial: raise NotImplementedError( f"`FLEBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported for non radial {f}." " Convert filter to radial for optimized radial expansion, or use `expand_method=None` for a radially averaged approximation." ) - elif expand_method is not None: - raise NotImplementedError( - f"`FLEBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported." - " Try `expand_method=None` or 'radial'." - ) pixel_size = kwargs.get("pixel_size", None) From 979def8612979956da62b9fb6d6d9ff4d4b8b1e0 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 10 Apr 2026 09:33:52 -0400 Subject: [PATCH 13/50] refix logic error --- src/aspire/basis/fle_2d.py | 15 ++++++++++----- src/aspire/basis/steerable.py | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 3a6bf6121b..dd0cdeaf3d 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -782,16 +782,21 @@ def _filter_to_basis_mat(self, f, **kwargs): """ # Note 'expand_method' and 'truncate' not relevant for this optimized FLE code. expand_method = kwargs.get("expand_method", None) - if expand_method not in [None, "radial"]: - raise NotImplementedError( - f"`FLEBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported." - " Try `expand_method=None` or 'radial'." - ) + # This first check is to guide users to either + # the optimized purely radial (1d) calc, + # or the radially averaged (2d) calc. + # Errors requesting a (1d) calc on a non radial filter. + # Errors on unknown `expand_method`. if expand_method == "radial" and not f.radial: raise NotImplementedError( f"`FLEBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported for non radial {f}." " Convert filter to radial for optimized radial expansion, or use `expand_method=None` for a radially averaged approximation." ) + elif expand_method is not None: + raise NotImplementedError( + f"`FLEBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported." + " Try `expand_method=None` or 'radial'." + ) pixel_size = kwargs.get("pixel_size", None) diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index 44a4e3de3f..94cea0a96a 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -492,7 +492,7 @@ def filter_to_basis_mat(self, f, **kwargs): # does the basis have optimized expand for radial vectors? optimized_expand = callable(getattr(self.__class__, "expand_radial_vec", None)) # is the filter radial? - filter_is_radial = f.radial is True + filter_is_radial = f.radial # did user request the special radial expansion method? radial_method = kwargs.get("expand_method", None) == "radial" From c2232fb4ee29d80d2e707ab65217dd3d6cd58038 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 10 Apr 2026 13:29:26 -0400 Subject: [PATCH 14/50] add tqdm to cov2d filter to basis mat --- src/aspire/covariance/covar2d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 81821c2e6c..f25a8c4beb 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -8,7 +8,7 @@ from aspire.basis import Coef, FFBBasis2D from aspire.operators import BlkDiagMatrix, DiagMatrix from aspire.optimization import conj_grad, fill_struct -from aspire.utils import make_symmat +from aspire.utils import make_symmat, tqdm logger = logging.getLogger(__name__) @@ -556,7 +556,7 @@ def _build(self): self.basis.filter_to_basis_mat( f, pixel_size=self.src.pixel_size, expand_method=self.expand_method ) - for f in unique_filters + for f in tqdm(unique_filters,desc='Converting filters to basis') ] logger.info("Represent CTF filters in basis complete") From 8106ecf92f881471f6b5ea93ae9eafaaca927bb6 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 14 Apr 2026 09:08:47 -0400 Subject: [PATCH 15/50] optimal fle basis comp --- src/aspire/basis/fle_2d.py | 10 +++--- src/aspire/covariance/covar2d.py | 60 ++++++++++++++++++++++++++------ src/aspire/operators/filters.py | 59 ++++++++++++++++++++++++------- 3 files changed, 102 insertions(+), 27 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index dd0cdeaf3d..cd40e4441a 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -756,21 +756,21 @@ def _radial_convolve_weights(self, b): Helper function for step 3 of convolving with a radial function. """ # Developer note, this is equivalent `fle2d.expand_radial_vec` up to shapes. - b = xp.array(b) # implies copy + b = xp.asarray(b) # implies copy if self.num_interp > self.num_radial_nodes: b = fft.dct(b, axis=0, type=2) / (2 * self.num_radial_nodes) bz = xp.zeros(b.shape, dtype=self.dtype) b = xp.concatenate((b, bz), axis=0) b = fft.idct(b, axis=0, type=2) * 2 * b.shape[0] - a = xp.zeros(self.count, dtype=self.dtype) + a = xp.zeros((b.shape[-1], self.count), dtype=self.dtype) # xx note these can be collapsed into one loop later y = [None] * (self.ell_p_max + 1) for i in range(self.ell_p_max + 1): - y[i] = (self.A3[i] @ b[:, 0]).flatten() + y[i] = self.A3[i] @ b[:, :] # .flatten() for i in range(self.ell_p_max + 1): - a[self.idx_list[i]] = y[i] + a[:, self.idx_list[i]] = y[i].T - return a.flatten() + return a # .flatten() def _filter_to_basis_mat(self, f, **kwargs): """ diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index f25a8c4beb..505154aa7b 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -6,7 +6,8 @@ from scipy.linalg import solve, sqrtm from aspire.basis import Coef, FFBBasis2D -from aspire.operators import BlkDiagMatrix, DiagMatrix +from aspire.numeric import xp +from aspire.operators import BlkDiagMatrix, CTFFilter, DiagMatrix from aspire.optimization import conj_grad, fill_struct from aspire.utils import make_symmat, tqdm @@ -549,16 +550,55 @@ def _build(self): self.ctf_basis = [self._identity_mat()] else: - logger.info("Represent CTF filters in basis") - unique_filters = src.unique_filters + logger.info("Representing filters in basis") self.ctf_idx = src.filter_indices - self.ctf_basis = [ - self.basis.filter_to_basis_mat( - f, pixel_size=self.src.pixel_size, expand_method=self.expand_method - ) - for f in tqdm(unique_filters,desc='Converting filters to basis') - ] - logger.info("Represent CTF filters in basis complete") + self.ctf_basis = self.filter_to_basis_mats() + logger.info("Representing filters in basis complete") + + def filter_to_basis_mats(self): + if all(isinstance(f, CTFFilter) for f in self.src.unique_filters): + logger.info("Found all filters are CTF, using bulk basis mat eval") + return self._ctf_filter_to_basis_mats() + logger.info("Mixed filters, using sequential basis mat eval") + return self._filter_to_basis_mats() + + def _filter_to_basis_mats(self): + """ + old code, should work with all basis and filters. slow. + """ + unique_filters = self.src.unique_filters + basis_mats = [ + self.basis.filter_to_basis_mat( + f, pixel_size=self.src.pixel_size, expand_method=self.expand_method + ) + for f in tqdm(unique_filters, desc="Converting filters to basis mats") + ] + return basis_mats + + def _ctf_filter_to_basis_mats(self): + unique_filters = self.src.unique_filters + + # lol + logger.info("Extracting CTF filter parameters and generating eval points") + params = np.empty((len(unique_filters), 7), dtype=self.dtype) + for i, f in enumerate(unique_filters): + params[i] = f._ctf_params() + + _pts = xp.asnumpy(self.basis.nodes) + _filter_pts = np.pad(_pts.reshape(1, -1), ((0, 1), (0, 0))) * self.basis.h + + logger.info("Computing CTF filters at eval points") + # if we have many filters, might be worth trip to GPU + if len(unique_filters) >= 2048: + params = xp.asarray(params) + _filter_pts = xp.asarray(_filter_pts) + + _filter_vals = CTFFilter.ctf_formula( + _filter_pts, self.src.pixel_size, *(params.T) + ) + + logger.info("Computing basis radial expansion") + return [DiagMatrix(f) for f in self.basis.expand_radial_vec(_filter_vals.T)] def _calc_rhs(self): src = self.src diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 002e13c47c..3366f8db29 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -448,9 +448,16 @@ def __init__( self.alpha = alpha self.B = B - # Convert angstrom to nm and divide by 2 - self._defocus_mean_nm = 0.05 * (self.defocus_u + self.defocus_v) - self._defocus_diff_nm = 0.05 * (self.defocus_u - self.defocus_v) + def _ctf_params(self): + return ( + self.voltage, + self.defocus_u, + self.defocus_v, + self.defocus_ang, + self.Cs, + self.alpha, + self.B, + ) def _evaluate(self, omega, **kwargs): # Ensure we have a pixel size, @@ -462,6 +469,22 @@ def _evaluate(self, omega, **kwargs): # and that it is a floating point value. pixel_size = float(pixel_size) + return self.ctf_formula( + omega, + pixel_size, + self.voltage, + self.defocus_u, + self.defocus_v, + self.defocus_ang, + self.Cs, + self.alpha, + self.B, + ) + + @staticmethod + def ctf_formula( + omega, pixel_size, voltage, defocus_u_a, defocus_v_a, defocus_ang, Cs, alpha, B + ): # Reference MATLAB code, includes reference to paper # Mindell, J. A.; Grigorieff, N. (2003). # https://github.com/PrincetonUniversity/aspire/blob/760a43b35453e55ff2d9354339e9ffa109a25371/projections/cryo_CTF_Relion.m#L34 @@ -472,6 +495,16 @@ def _evaluate(self, omega, **kwargs): # and further rescale the radii `s` by half below. # # Additionally we upcast so downstream computations remain in doubles. + + # First prepare arrays for broadcasting. + voltage = voltage[:, None] + defocus_u_a = defocus_u_a[:, None] + defocus_v_a = defocus_v_a[:, None] + defocus_ang = defocus_ang[:, None] + Cs = Cs[:, None] + alpha = alpha[:, None] + B = B[:, None] + x, y = omega.astype(np.float64, copy=False) / np.pi # Returns radii such that when multiplied by the @@ -479,29 +512,31 @@ def _evaluate(self, omega, **kwargs): # corresponding to each pixel in our nxn grid. theta, s = cart2pol(x, y) s = s / 2 + theta = theta[None, :] + s = s[None, :] # Wavelength in nm. - lamb = 1.22639 / np.sqrt(self.voltage * 1000 + 0.97845 * self.voltage**2) + lamb = 1.22639 / np.sqrt(voltage * 1000 + 0.97845 * voltage**2) # Divide by 10 to make pixel size in nm. BW is the # bandwidth of the signal corresponding to the given pixel size. BW = 1 / (pixel_size / 10) s = s * BW - DFavg = self._defocus_mean_nm # (DefocusU+DefocusV)/2 - DFdiff = self._defocus_diff_nm # (DefocusU-DefocusV) + DFavg = 0.05 * (defocus_u_a + defocus_v_a) # (u+v)/2 * 1nm/10A + DFdiff = 0.05 * (defocus_u_a - defocus_v_a) # (u-v)/2 * 1nm/10A # Note division by 2 is pre-computed in _defocus_diff_nm - df = DFavg + DFdiff * np.cos(2 * (theta - self.defocus_ang)) - + df = DFavg + DFdiff * np.cos(2 * (theta - defocus_ang)) k2 = np.pi * lamb * df # 10*6 converts Cs from mm to nm. - k4 = np.pi / 2 * 10**6 * self.Cs * lamb**3 + k4 = np.pi / 2 * 10**6 * Cs * lamb**3 chi = k4 * s**4 - k2 * s**2 - h = np.sqrt(1 - self.alpha**2) * np.sin(chi) - self.alpha * np.cos(chi) + alpha = alpha + h = np.sqrt(1 - alpha**2) * np.sin(chi) - alpha * np.cos(chi) - if self.B: - h *= np.exp(-self.B * s**2) + if np.any(B): + h *= np.exp(-B * s**2) return h From 3cc687ac87b8cbd9a60e2ffb3e85f62aca8b8c8a Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 16 Apr 2026 12:20:17 -0400 Subject: [PATCH 16/50] stub in bulk ctf code still needs cleanup --- src/aspire/basis/ffb_2d.py | 52 ++++++++++++++++++++++++++++++++ src/aspire/basis/fle_2d.py | 7 ++++- src/aspire/covariance/covar2d.py | 30 +++++++++--------- 3 files changed, 74 insertions(+), 15 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 64de6a0c5a..279f12c178 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -67,6 +67,11 @@ def _build(self): self._precomp["gl_nodes"] ) + # Generate radial filter point set for radial optimized eval + self._filter_pts = np.pad( + self._precomp["gl_nodes"].reshape(1, -1), ((0, 1), (0, 0)) + ) # should use self.gl_weighted_nodes ?? + def _precomp(self): """ Precomute the basis functions on a polar Fourier grid @@ -297,3 +302,50 @@ def _filter_to_basis_mat(self, f, **kwargs): ind_ell += 1 return h_basis + + def expand_radial_vec(self, h_vals): + """ """ + h_vals = h_vals.T # why fle transposed... + + # Set same dimensions as basis object + n_k = self.n_r + radial = self._precomp["radial"] + + # hrrmm, can we always use the basis precomp, or do we need to use lgwt as in the old filter_to_basis_mat? + k_vals = xp.asarray(self._precomp["gl_nodes"]) + wts = xp.asarray(self._precomp["gl_weights"]) + + # Represent 1D function values in basis + h_basis = [ + BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) for _ in h_vals + ] + + ind_ell = 0 + for ell in range(0, self.ell_max + 1): + k_max = self.k_max[ell] + rmat = ( + 2 * xp.asnumpy(k_vals.reshape(n_k, 1)) * self.r0[ell][0:k_max].T + ) # WHAT IN THE WORLD IS GOING ON HERE + basis_vals = xp.zeros_like(rmat) + ind_radial = np.sum(self.k_max[0:ell]) + basis_vals[:, 0:k_max] = xp.asarray( + radial[ind_radial : ind_radial + k_max] + ).T + h_basis_vals = basis_vals * h_vals.reshape(len(h_basis), n_k, 1) + h_basis_ell = basis_vals.T @ ( + h_basis_vals * k_vals.reshape(1, n_k, 1) * wts.reshape(1, n_k, 1) + ) + h_basis_ell = xp.asnumpy(h_basis_ell) + for _filter in range(len(h_vals)): + _tmp = h_basis[_filter][ind_ell] = h_basis_ell[_filter] + if ell > 0: + h_basis[_filter][ind_ell + 1] = _tmp + if _filter == len(h_vals) - 1: + ind_ell += 1 + if ell > 0: + ind_ell += 1 + + # might as well just take the diagonal elements + h_basis = [h.diag() for h in h_basis] + + return h_basis diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index cd40e4441a..55f9334a6b 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -317,6 +317,11 @@ def _compute_nufft_points(self): grid_xy[:] = grid_xy * self.nodes * h self.grid_xy = grid_xy.reshape(2, -1) + # Generate radial filter point set for radial optimized eval + self._filter_pts = ( + np.pad(xp.asnumpy(self.nodes).reshape(1, -1), ((0, 1), (0, 0))) * self.h + ) + def _build_interpolation_matrix(self): """ Create the matrix used in the third step of evaluate_t() and the first step of evaluate() @@ -849,7 +854,7 @@ def expand_radial_vec(self, radial_vec, **kwargs): # for now match old `filter_to_basis_mat` coefs = xp.asnumpy(coefs).squeeze() - return DiagMatrix(coefs) + return [DiagMatrix(c) for c in coefs] def _radial_filter_to_vals(self, f, **kwargs): """ diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 505154aa7b..210fdc181a 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -552,17 +552,17 @@ def _build(self): else: logger.info("Representing filters in basis") self.ctf_idx = src.filter_indices - self.ctf_basis = self.filter_to_basis_mats() + self.ctf_basis = self.filters_to_basis_mats() logger.info("Representing filters in basis complete") - def filter_to_basis_mats(self): + def filters_to_basis_mats(self): if all(isinstance(f, CTFFilter) for f in self.src.unique_filters): logger.info("Found all filters are CTF, using bulk basis mat eval") - return self._ctf_filter_to_basis_mats() + return self._ctf_filters_to_basis_mats() logger.info("Mixed filters, using sequential basis mat eval") - return self._filter_to_basis_mats() + return self._filters_to_basis_mats() - def _filter_to_basis_mats(self): + def _filters_to_basis_mats(self): """ old code, should work with all basis and filters. slow. """ @@ -575,7 +575,7 @@ def _filter_to_basis_mats(self): ] return basis_mats - def _ctf_filter_to_basis_mats(self): + def _ctf_filters_to_basis_mats(self): unique_filters = self.src.unique_filters # lol @@ -584,21 +584,18 @@ def _ctf_filter_to_basis_mats(self): for i, f in enumerate(unique_filters): params[i] = f._ctf_params() - _pts = xp.asnumpy(self.basis.nodes) - _filter_pts = np.pad(_pts.reshape(1, -1), ((0, 1), (0, 0))) * self.basis.h - logger.info("Computing CTF filters at eval points") # if we have many filters, might be worth trip to GPU if len(unique_filters) >= 2048: params = xp.asarray(params) - _filter_pts = xp.asarray(_filter_pts) + _filter_pts = xp.asarray(self.basis._filter_pts) _filter_vals = CTFFilter.ctf_formula( _filter_pts, self.src.pixel_size, *(params.T) ) logger.info("Computing basis radial expansion") - return [DiagMatrix(f) for f in self.basis.expand_radial_vec(_filter_vals.T)] + return self.basis.expand_radial_vec(_filter_vals.T) def _calc_rhs(self): src = self.src @@ -667,7 +664,12 @@ def _calc_op(self): A_mean = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape, self.dtype) A_covar = [None for _ in ctf_basis] - M_covar = self._zeros_mat() + # If we're given all diag filters, A_covar and M can be diag + if all(isinstance(c, DiagMatrix) for c in ctf_basis): + M_covar = DiagMatrix.zeros(self.basis.count, dtype=self.dtype) + # otherwise, take the default `matrix_type` for the basis + else: + M_covar = self._zeros_mat() for k in np.unique(ctf_idx): weight = float(np.count_nonzero(ctf_idx == k) / src.n) @@ -733,7 +735,7 @@ def _noise_correct_covar_rhs(self, b_covar, b_noise, noise_var, shrinker): def _solve_covar(self, A_covar, b_covar, M, covar_est_opt): method = self._solve_covar_cg - if self.basis.matrix_type == DiagMatrix: + if all(isinstance(a, DiagMatrix) for a in A_covar): method = self._solve_covar_direct return method(A_covar, b_covar, M, covar_est_opt) @@ -798,7 +800,7 @@ def apply(A, x): covar_coef[ell] = covar_coef_ell.reshape(p, p) t1 = perf_counter() - logger.info(f"_solve_covar_cgelapsed: {t1-t0}") + logger.info(f"_solve_covar_cg elapsed: {t1-t0}") return covar_coef def get_mean(self): From b40e4f6a8e1952e8d6bf35013aab5f692e4cf3ac Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 21 Apr 2026 11:16:08 -0400 Subject: [PATCH 17/50] tests passing except FFB opts --- src/aspire/basis/ffb_2d.py | 7 ++++--- src/aspire/basis/fle_2d.py | 24 +++++++++++++++--------- src/aspire/covariance/covar2d.py | 21 +++++++++++++++------ src/aspire/operators/filters.py | 14 +++++++------- 4 files changed, 41 insertions(+), 25 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 279f12c178..e08a325588 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -71,6 +71,9 @@ def _build(self): self._filter_pts = np.pad( self._precomp["gl_nodes"].reshape(1, -1), ((0, 1), (0, 0)) ) # should use self.gl_weighted_nodes ?? + # self._filter_pts = np.pad( + # xp.asnumpy(self.gl_weighted_nodes).reshape(1, -1), ((0, 1), (0, 0)) + # ) def _precomp(self): """ @@ -305,8 +308,6 @@ def _filter_to_basis_mat(self, f, **kwargs): def expand_radial_vec(self, h_vals): """ """ - h_vals = h_vals.T # why fle transposed... - # Set same dimensions as basis object n_k = self.n_r radial = self._precomp["radial"] @@ -346,6 +347,6 @@ def expand_radial_vec(self, h_vals): ind_ell += 1 # might as well just take the diagonal elements - h_basis = [h.diag() for h in h_basis] + # h_basis = [h.diag() for h in h_basis] return h_basis diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 55f9334a6b..58ba43b1ae 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -743,9 +743,7 @@ def radial_convolve(self, coefs, radial_img): _coefs = coefs[k, :] z = self._step1_t(radial_img) b = self._step2_t(z) - # squeeze previously in _radial_convolve_weights - b = b.squeeze() - weights = self._radial_convolve_weights(b) + weights = self._radial_convolve_weights(b[..., 0]) b = weights / (self.h**2) b = b.reshape(self.count) coefs_conv[k, :] = (self.c2r @ (b * (self.r2c @ _coefs).flatten())).real @@ -759,23 +757,31 @@ def radial_convolve(self, coefs, radial_img): def _radial_convolve_weights(self, b): """ Helper function for step 3 of convolving with a radial function. + + :param b: Radial vector or stack of radial vectors """ # Developer note, this is equivalent `fle2d.expand_radial_vec` up to shapes. + # Convert vector to (1,...) + if b.ndim == 1: + b = b.reshape(1, *b.shape) b = xp.asarray(b) # implies copy if self.num_interp > self.num_radial_nodes: - b = fft.dct(b, axis=0, type=2) / (2 * self.num_radial_nodes) + b = fft.dct(b, axis=1, type=2) / (2 * self.num_radial_nodes) bz = xp.zeros(b.shape, dtype=self.dtype) - b = xp.concatenate((b, bz), axis=0) - b = fft.idct(b, axis=0, type=2) * 2 * b.shape[0] - a = xp.zeros((b.shape[-1], self.count), dtype=self.dtype) + b = xp.concatenate((b, bz), axis=1) + b = fft.idct(b, axis=1, type=2) * 2 * b.shape[1] + a = xp.zeros((b.shape[0], self.count), dtype=self.dtype) # xx note these can be collapsed into one loop later y = [None] * (self.ell_p_max + 1) for i in range(self.ell_p_max + 1): - y[i] = self.A3[i] @ b[:, :] # .flatten() + # Wierd mul transpose forced by A3 being CSR. + # Can't reshape A3, but can broadcast over last dim of b. + # T here (c, num_img) and y[i].T below back to (num_img, c) + y[i] = self.A3[i] @ b[:, :].T for i in range(self.ell_p_max + 1): a[:, self.idx_list[i]] = y[i].T - return a # .flatten() + return a def _filter_to_basis_mat(self, f, **kwargs): """ diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 210fdc181a..6312d4fdc1 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -556,11 +556,19 @@ def _build(self): logger.info("Representing filters in basis complete") def filters_to_basis_mats(self): - if all(isinstance(f, CTFFilter) for f in self.src.unique_filters): - logger.info("Found all filters are CTF, using bulk basis mat eval") + optimized_expand = callable( + getattr(self.basis.__class__, "expand_radial_vec", None) + ) + if optimized_expand and all( + isinstance(f, CTFFilter) for f in self.src.unique_filters + ): + logger.info( + "Found all filters are CTF, and `basis.expand_radial_vec` available using, bulk basis mat eval" + ) return self._ctf_filters_to_basis_mats() - logger.info("Mixed filters, using sequential basis mat eval") - return self._filters_to_basis_mats() + else: + logger.info("Using sequential basis mat eval") + return self._filters_to_basis_mats() def _filters_to_basis_mats(self): """ @@ -585,17 +593,18 @@ def _ctf_filters_to_basis_mats(self): params[i] = f._ctf_params() logger.info("Computing CTF filters at eval points") + _filter_pts = self.basis._filter_pts # if we have many filters, might be worth trip to GPU if len(unique_filters) >= 2048: params = xp.asarray(params) - _filter_pts = xp.asarray(self.basis._filter_pts) + _filter_pts = xp.asarray(_filter_pts) _filter_vals = CTFFilter.ctf_formula( _filter_pts, self.src.pixel_size, *(params.T) ) logger.info("Computing basis radial expansion") - return self.basis.expand_radial_vec(_filter_vals.T) + return self.basis.expand_radial_vec(_filter_vals) def _calc_rhs(self): src = self.src diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 3366f8db29..b89f90487c 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -497,13 +497,13 @@ def ctf_formula( # Additionally we upcast so downstream computations remain in doubles. # First prepare arrays for broadcasting. - voltage = voltage[:, None] - defocus_u_a = defocus_u_a[:, None] - defocus_v_a = defocus_v_a[:, None] - defocus_ang = defocus_ang[:, None] - Cs = Cs[:, None] - alpha = alpha[:, None] - B = B[:, None] + voltage = np.atleast_1d(voltage)[:, None] + defocus_u_a = np.atleast_1d(defocus_u_a)[:, None] + defocus_v_a = np.atleast_1d(defocus_v_a)[:, None] + defocus_ang = np.atleast_1d(defocus_ang)[:, None] + Cs = np.atleast_1d(Cs)[:, None] + alpha = np.atleast_1d(alpha)[:, None] + B = np.atleast_1d(B)[:, None] x, y = omega.astype(np.float64, copy=False) / np.pi From 5253d29a5b675595de6272b97aaf7a96a385a13f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 21 Apr 2026 15:05:02 -0400 Subject: [PATCH 18/50] ffb equiv checkpoint --- src/aspire/basis/ffb_2d.py | 26 +++++++++++++++++----- tests/test_FFBbasis2D.py | 44 +++++++++++++++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 6 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index e08a325588..3e495f9330 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -68,13 +68,24 @@ def _build(self): ) # Generate radial filter point set for radial optimized eval - self._filter_pts = np.pad( - self._precomp["gl_nodes"].reshape(1, -1), ((0, 1), (0, 0)) - ) # should use self.gl_weighted_nodes ?? # self._filter_pts = np.pad( - # xp.asnumpy(self.gl_weighted_nodes).reshape(1, -1), ((0, 1), (0, 0)) + # self._precomp["gl_nodes"].reshape(1, -1), ((0, 1), (0, 0)) # ) + # Set same dimensions as basis object + n_k = self.n_r + n_theta = self.n_theta + radial = self._precomp["radial"] + + # get 2D grid in polar coordinate + k_vals, wts = lgwt(n_k, 0, 0.5, dtype=self.dtype) + k, theta = np.meshgrid(k_vals, np.array(0), indexing="ij") + + # Get function values in polar 2D grid and average out angle contribution + omegax = k * np.cos(theta) + self._filter_pts = np.pad(2 * np.pi * omegax.reshape(1, -1), ((0, 1), (0, 0))) + breakpoint() + def _precomp(self): """ Precomute the basis functions on a polar Fourier grid @@ -307,7 +318,10 @@ def _filter_to_basis_mat(self, f, **kwargs): return h_basis def expand_radial_vec(self, h_vals): - """ """ + # Convert vector to (1,...) + if h_vals.ndim == 1: + h_vals = h_vals.reshape(1, *h_vals.shape) + # Set same dimensions as basis object n_k = self.n_r radial = self._precomp["radial"] @@ -320,6 +334,8 @@ def expand_radial_vec(self, h_vals): h_basis = [ BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) for _ in h_vals ] + print(h_vals) + breakpoint() ind_ell = 0 for ell in range(0, self.ell_max + 1): diff --git a/tests/test_FFBbasis2D.py b/tests/test_FFBbasis2D.py index e0132971cf..f4c3cf8ce5 100644 --- a/tests/test_FFBbasis2D.py +++ b/tests/test_FFBbasis2D.py @@ -7,8 +7,9 @@ from aspire.basis import Coef, FFBBasis2D from aspire.nufft import all_backends +from aspire.operators import RadialCTFFilter from aspire.source import Simulation -from aspire.utils.misc import grid_2d +from aspire.utils import grid_2d, utest_tolerance from aspire.volume import Volume from ._basis_util import Steerable2DMixin, UniversalBasisMixin, basis_params_2d @@ -161,3 +162,44 @@ def testHighResFFBBasis2D(L, dtype): np.testing.assert_allclose( im_ffb.asnumpy()[0][mask], im.asnumpy()[0][mask], rtol=1e-05, atol=1e-4 ) + + +def test_bulk_expand_radial_vec(): + """ + For a given stack of radial vectors (such as from + RadialCTFFilters) `expand_radial_vec` should return equivalent + result as calling filter_to_basis_mat on each filter. + """ + + L = 32 + dtype = np.float32 + basis = FFBBasis2D(L, dtype=dtype) + pixel_size = 1.23 + + filters = [RadialCTFFilter(defocus=d) for d in np.linspace(10000, 15000, 3)] + + references = [basis.filter_to_basis_mat(f, pixel_size=pixel_size) for f in filters] + + # from cov code + params = np.empty((len(filters), 7), dtype=dtype) + for i, f in enumerate(filters): + params[i] = f._ctf_params() + + _filter_vals = RadialCTFFilter.ctf_formula( + basis._filter_pts, pixel_size, *(params.T) + ) + + results = basis.expand_radial_vec(_filter_vals) + results2 = [basis.expand_radial_vec(f)[0] for f in _filter_vals] + + # expand_radial_vec should be same as itself called sequentially + assert len(results2) == len(results) + for res, ref in zip(results2, results): + np.testing.assert_allclose(res.dense(), ref.dense()) + + # and should be equivalent to calling filter_to_basis_mat + assert len(results) == len(references) + for res, ref in zip(results, references): + np.testing.assert_allclose( + res.dense(), ref.dense(), atol=utest_tolerance(dtype) + ) From 53026c298fd66bf9068f898cbe71df63936336de Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 21 Apr 2026 15:10:45 -0400 Subject: [PATCH 19/50] cleanup --- src/aspire/basis/ffb_2d.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 3e495f9330..dde4a2cd3c 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -75,7 +75,6 @@ def _build(self): # Set same dimensions as basis object n_k = self.n_r n_theta = self.n_theta - radial = self._precomp["radial"] # get 2D grid in polar coordinate k_vals, wts = lgwt(n_k, 0, 0.5, dtype=self.dtype) @@ -84,7 +83,6 @@ def _build(self): # Get function values in polar 2D grid and average out angle contribution omegax = k * np.cos(theta) self._filter_pts = np.pad(2 * np.pi * omegax.reshape(1, -1), ((0, 1), (0, 0))) - breakpoint() def _precomp(self): """ @@ -334,8 +332,6 @@ def expand_radial_vec(self, h_vals): h_basis = [ BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) for _ in h_vals ] - print(h_vals) - breakpoint() ind_ell = 0 for ell in range(0, self.ell_max + 1): From 922536dc7aebe73cff3650e3e67e7f1b480b2c19 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 22 Apr 2026 07:27:06 -0400 Subject: [PATCH 20/50] cleanup --- src/aspire/basis/ffb_2d.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index dde4a2cd3c..eb15ddf974 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -358,7 +358,4 @@ def expand_radial_vec(self, h_vals): if ell > 0: ind_ell += 1 - # might as well just take the diagonal elements - # h_basis = [h.diag() for h in h_basis] - return h_basis From 2e34d5f6966a1443284b506c725f8ed18b5a879e Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 22 Apr 2026 08:16:49 -0400 Subject: [PATCH 21/50] cleanup FFB doc strings --- src/aspire/basis/ffb_2d.py | 52 ++++++++++++++++++++++---------------- src/aspire/basis/fle_2d.py | 8 ++++++ 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index eb15ddf974..824f0f8971 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -68,22 +68,16 @@ def _build(self): ) # Generate radial filter point set for radial optimized eval + k_vals, _ = lgwt(self.n_r, 0, 0.5, dtype=self.dtype) + self._filter_pts = np.pad(2 * np.pi * k_vals.reshape(1, -1), ((0, 1), (0, 0))) + + # Ask Joakim about this... + # Why does filter_to_basis_mat hard code lgwt instead of following basis self.kcut + # they are the same by default. # self._filter_pts = np.pad( - # self._precomp["gl_nodes"].reshape(1, -1), ((0, 1), (0, 0)) + # 2 * np.pi * self._precomp["gl_nodes"].reshape(1, -1), ((0, 1), (0, 0)) # ) - # Set same dimensions as basis object - n_k = self.n_r - n_theta = self.n_theta - - # get 2D grid in polar coordinate - k_vals, wts = lgwt(n_k, 0, 0.5, dtype=self.dtype) - k, theta = np.meshgrid(k_vals, np.array(0), indexing="ij") - - # Get function values in polar 2D grid and average out angle contribution - omegax = k * np.cos(theta) - self._filter_pts = np.pad(2 * np.pi * omegax.reshape(1, -1), ((0, 1), (0, 0))) - def _precomp(self): """ Precomute the basis functions on a polar Fourier grid @@ -257,7 +251,8 @@ def _filter_to_basis_mat(self, f, **kwargs): """ See `SteerableBasis2D.filter_to_basis_mat`. """ - # Note 'method' and 'truncate' not relevant for this optimized FFB code. + # Note 'method' and 'truncate' not relevant for this specific FFB code. + # Method `radial` should have already been diverted. expand_method = kwargs.get("expand_method", None) if expand_method is not None: raise NotImplementedError( @@ -279,6 +274,7 @@ def _filter_to_basis_mat(self, f, **kwargs): radial = self._precomp["radial"] # get 2D grid in polar coordinate + # Confirm this lgwt call with Joakim (should it follow basis config self.kcut? same by default) k_vals, wts = lgwt(n_k, 0, 0.5, dtype=self.dtype) k, theta = np.meshgrid( k_vals, np.arange(n_theta) * 2 * np.pi / (2 * n_theta), indexing="ij" @@ -315,22 +311,32 @@ def _filter_to_basis_mat(self, f, **kwargs): return h_basis - def expand_radial_vec(self, h_vals): + def expand_radial_vec(self, radial_vec, force_diag=False): + """ + Expands radial vector or stack of vetors `radial_vec` to basis matrix. + + :param radial_vec: Array holding radial vector, + shaped (n_radial_pts) or (n_vectors, n_radial_pts) + :force_diag: Optionally flush off-diagonal elements to zero and return `DiagMatrix` + :return: List of `BlkDiagMatrix`, or list of `DiagMatrix` + """ # Convert vector to (1,...) - if h_vals.ndim == 1: - h_vals = h_vals.reshape(1, *h_vals.shape) + if radial_vec.ndim == 1: + radial_vec = radial_vec.reshape(1, *radial_vec.shape) # Set same dimensions as basis object n_k = self.n_r radial = self._precomp["radial"] - # hrrmm, can we always use the basis precomp, or do we need to use lgwt as in the old filter_to_basis_mat? + # hrrmm, ask Joakim can we always use the basis precomp, or do we need to use lgwt as in the old filter_to_basis_mat? + # This is doing opposite logic (same result) by default. Joy. k_vals = xp.asarray(self._precomp["gl_nodes"]) wts = xp.asarray(self._precomp["gl_weights"]) # Represent 1D function values in basis h_basis = [ - BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) for _ in h_vals + BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) + for _ in radial_vec ] ind_ell = 0 @@ -344,18 +350,20 @@ def expand_radial_vec(self, h_vals): basis_vals[:, 0:k_max] = xp.asarray( radial[ind_radial : ind_radial + k_max] ).T - h_basis_vals = basis_vals * h_vals.reshape(len(h_basis), n_k, 1) + h_basis_vals = basis_vals * radial_vec.reshape(len(h_basis), n_k, 1) h_basis_ell = basis_vals.T @ ( h_basis_vals * k_vals.reshape(1, n_k, 1) * wts.reshape(1, n_k, 1) ) h_basis_ell = xp.asnumpy(h_basis_ell) - for _filter in range(len(h_vals)): + for _filter in range(len(radial_vec)): _tmp = h_basis[_filter][ind_ell] = h_basis_ell[_filter] if ell > 0: h_basis[_filter][ind_ell + 1] = _tmp - if _filter == len(h_vals) - 1: + if _filter == len(radial_vec) - 1: ind_ell += 1 if ell > 0: ind_ell += 1 + if force_diag: + h_basis = [h.diag() for h in h_basis] return h_basis diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 58ba43b1ae..4477d96f13 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -850,6 +850,14 @@ def _filter_to_basis_mat(self, f, **kwargs): return DiagMatrix(xp.asnumpy(h_basis)) def expand_radial_vec(self, radial_vec, **kwargs): + """ + Expands radial vector or stack of vectors `radial_vec` to basis matrix. + + :param radial_vec: Array holding radial vector, + shaped (n_radial_pts) or (n_vectors, n_radial_pts) + :return: List of `DiagMatrix` + """ + coefs = self._radial_convolve_weights(radial_vec) # check... From e9c61df46f1e44c80d375b38411fe12aec850f5c Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 22 Apr 2026 08:30:55 -0400 Subject: [PATCH 22/50] dtype sensitivity --- src/aspire/basis/ffb_2d.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 824f0f8971..7ceaf35924 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -68,8 +68,11 @@ def _build(self): ) # Generate radial filter point set for radial optimized eval - k_vals, _ = lgwt(self.n_r, 0, 0.5, dtype=self.dtype) - self._filter_pts = np.pad(2 * np.pi * k_vals.reshape(1, -1), ((0, 1), (0, 0))) + # Weights appear a little sensitive to dtype ... + k_vals, _ = lgwt(self.n_r, 0, 0.5, dtype=np.float64) + self._filter_pts = np.pad( + 2 * np.pi * k_vals.reshape(1, -1), ((0, 1), (0, 0)) + ).astype(self.dtype) # Ask Joakim about this... # Why does filter_to_basis_mat hard code lgwt instead of following basis self.kcut From 33a7f3c2d69e4a6c87f337114960aa5dbee2d9e3 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 22 Apr 2026 11:37:10 -0400 Subject: [PATCH 23/50] more cleanup, maybe CI passing --- src/aspire/basis/ffb_2d.py | 2 ++ src/aspire/numeric/cupy.py | 13 +++++++++++++ src/aspire/operators/filters.py | 25 ++++++++++++++++++------- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 7ceaf35924..239e945ec5 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -326,6 +326,8 @@ def expand_radial_vec(self, radial_vec, force_diag=False): # Convert vector to (1,...) if radial_vec.ndim == 1: radial_vec = radial_vec.reshape(1, *radial_vec.shape) + # Optionally transfer to GPU + radial_vec = xp.asarray(radial_vec) # Set same dimensions as basis object n_k = self.n_r diff --git a/src/aspire/numeric/cupy.py b/src/aspire/numeric/cupy.py index f02aa393c2..2853cd7b7a 100644 --- a/src/aspire/numeric/cupy.py +++ b/src/aspire/numeric/cupy.py @@ -1,4 +1,5 @@ import cupy as cp +import numpy as np class Cupy: @@ -7,3 +8,15 @@ def __getattr__(self, item): Catch-all method to to allow a straight pass-through of any attribute that is not supported above. """ return getattr(cp, item) + + @staticmethod + def atleast_1d(x): + """ + Provide an agnostic `atleast_1d`. + + Returns same type as input. + """ + _fn = np.atleast_1d + if cp and isinstance(x, cp.ndarray): + _fn = cp.atleast_1d + return _fn(x) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index b89f90487c..f1c8ae4419 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -6,6 +6,7 @@ from scipy.interpolate import RegularGridInterpolator from aspire import config +from aspire.numeric import xp from aspire.utils import cart2pol, grid_2d, voltage_to_wavelength logger = logging.getLogger(__name__) @@ -497,13 +498,23 @@ def ctf_formula( # Additionally we upcast so downstream computations remain in doubles. # First prepare arrays for broadcasting. - voltage = np.atleast_1d(voltage)[:, None] - defocus_u_a = np.atleast_1d(defocus_u_a)[:, None] - defocus_v_a = np.atleast_1d(defocus_v_a)[:, None] - defocus_ang = np.atleast_1d(defocus_ang)[:, None] - Cs = np.atleast_1d(Cs)[:, None] - alpha = np.atleast_1d(alpha)[:, None] - B = np.atleast_1d(B)[:, None] + voltage = xp.atleast_1d(voltage)[:, None] + defocus_u_a = xp.atleast_1d(defocus_u_a)[:, None] + defocus_v_a = xp.atleast_1d(defocus_v_a)[:, None] + defocus_ang = xp.atleast_1d(defocus_ang)[:, None] + Cs = xp.atleast_1d(Cs)[:, None] + alpha = xp.atleast_1d(alpha)[:, None] + B = xp.atleast_1d(B)[:, None] + + # If omega is on GPU, move the params over + if isinstance(omega, xp.ndarray) and not isinstance(omega, np.ndarray): + voltage = xp.asarray(voltage) + defocus_u_a = xp.array(defocus_u_a) + defocus_v_a = xp.array(defocus_v_a) + defocus_ang = xp.array(defocus_ang) + Cs = xp.array(Cs) + alpha = xp.array(alpha) + B = xp.array(B) x, y = omega.astype(np.float64, copy=False) / np.pi From 39d821164ec110f64bbf65cf85253b7bdb4dc1a7 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 22 Apr 2026 15:13:01 -0400 Subject: [PATCH 24/50] more cleanup, maybe docs runs --- src/aspire/operators/filters.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index f1c8ae4419..3ad2c45ed2 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -37,8 +37,9 @@ def evaluate_src_filters_on_grid(src, indices=None): idx_k = np.where(src.filter_indices[indices] == i)[0] if len(idx_k) > 0: filter_values = filt.evaluate(omega, pixel_size=src.pixel_size) - h[:, idx_k] = np.column_stack((filter_values,) * len(idx_k)) - + # convert filter_values row vector to column vector and tile broadcast + filter_values = filter_values.reshape(-1, 1) + h[:, idx_k] = np.tile(filter_values, len(idx_k)) h = np.reshape(h, grid2d["x"].shape + (len(indices),)) return h From 83ac18867f4053194b7c70b7ddcad7d678ff4c33 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 23 Apr 2026 08:18:59 -0400 Subject: [PATCH 25/50] cleanup extra loop --- src/aspire/basis/fle_2d.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 4477d96f13..e39e26b3a2 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -771,15 +771,11 @@ def _radial_convolve_weights(self, b): b = xp.concatenate((b, bz), axis=1) b = fft.idct(b, axis=1, type=2) * 2 * b.shape[1] a = xp.zeros((b.shape[0], self.count), dtype=self.dtype) - # xx note these can be collapsed into one loop later - y = [None] * (self.ell_p_max + 1) for i in range(self.ell_p_max + 1): # Wierd mul transpose forced by A3 being CSR. - # Can't reshape A3, but can broadcast over last dim of b. - # T here (c, num_img) and y[i].T below back to (num_img, c) - y[i] = self.A3[i] @ b[:, :].T - for i in range(self.ell_p_max + 1): - a[:, self.idx_list[i]] = y[i].T + # Can't reshape A3, but can broadcast over dims of b. + # T b first to yield (cnt, num_img) and T result back to (num_img, cnt) + a[:, self.idx_list[i]] = (self.A3[i] @ b[:, :].T).T return a From 751eb5e3ebd7f737f85030dbc6be1cdb40037bb9 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 29 Apr 2026 08:58:57 -0400 Subject: [PATCH 26/50] add force diag option to cov2d --- src/aspire/covariance/covar2d.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 6312d4fdc1..8530d8f4bf 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -522,7 +522,7 @@ class BatchedRotCov2D(RotCov2D): scaling up to a larger value such as 8192 may yield better performance. """ - def __init__(self, src, basis=None, expand_method=None, batch_size=512): + def __init__(self, src, basis=None, expand_method=None, force_diag=False, batch_size=512): self.src = src self.basis = basis self.batch_size = batch_size @@ -534,6 +534,7 @@ def __init__(self, src, basis=None, expand_method=None, batch_size=512): self.A_covar = None self.M_covar = None self.expand_method = expand_method + self.force_diag = force_diag self._build() @@ -604,7 +605,7 @@ def _ctf_filters_to_basis_mats(self): ) logger.info("Computing basis radial expansion") - return self.basis.expand_radial_vec(_filter_vals) + return self.basis.expand_radial_vec(_filter_vals, force_diag=self.force_diag) def _calc_rhs(self): src = self.src From c998bf020f7ee13b13c4332fc8f824c2f2bfeaf3 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 6 May 2026 14:59:00 -0400 Subject: [PATCH 27/50] ctf stack unit test patches --- src/aspire/covariance/covar2d.py | 7 ++- src/aspire/operators/filters.py | 90 ++++++++++++++++++++++++++------ src/aspire/source/simulation.py | 7 +-- tests/test_FFBbasis2D.py | 3 +- tests/test_coordinate_source.py | 52 +++++++++--------- 5 files changed, 109 insertions(+), 50 deletions(-) diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 8530d8f4bf..7c8c45685f 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -522,7 +522,9 @@ class BatchedRotCov2D(RotCov2D): scaling up to a larger value such as 8192 may yield better performance. """ - def __init__(self, src, basis=None, expand_method=None, force_diag=False, batch_size=512): + def __init__( + self, src, basis=None, expand_method=None, force_diag=False, batch_size=512 + ): self.src = src self.basis = basis self.batch_size = batch_size @@ -591,7 +593,8 @@ def _ctf_filters_to_basis_mats(self): logger.info("Extracting CTF filter parameters and generating eval points") params = np.empty((len(unique_filters), 7), dtype=self.dtype) for i, f in enumerate(unique_filters): - params[i] = f._ctf_params() + ### TODO xxx fix up param dump, same as in source/sim + params[i] = np.array(f._ctf_params()).flatten() logger.info("Computing CTF filters at eval points") _filter_pts = self.basis._filter_pts diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 3ad2c45ed2..16562b8389 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -114,7 +114,8 @@ def scale(self, c=1): """ return ScaledFilter(self, c) - @lru_cache(maxsize=config["cache"]["filter_cache_size"].get()) # noqa: B019 + # this cache no longer appears to work? now unhashable? (fine but why not before?) + # @lru_cache(maxsize=config["cache"]["filter_cache_size"].get()) # noqa: B019 def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs): """ Generates a two dimensional grid with prescribed dtype, @@ -441,14 +442,62 @@ def __init__( :param B: Envelope decay in inverse square angstrom (default 0) """ super().__init__(dim=2, radial=defocus_u == defocus_v) - self.voltage = voltage - self.wavelength = voltage_to_wavelength(self.voltage) - self.defocus_u = defocus_u - self.defocus_v = defocus_v - self.defocus_ang = defocus_ang - self.Cs = Cs - self.alpha = alpha - self.B = B + voltage = np.atleast_1d(voltage) # maybe allow singleton here for V + defocus_u = np.atleast_1d(defocus_u) + defocus_v = np.atleast_1d(defocus_v) + defocus_ang = np.atleast_1d(defocus_ang) + Cs = np.atleast_1d(Cs) + alpha = np.atleast_1d(alpha) + B = np.atleast_1d(B) + + # TODO check all sizes match + self.n = max( + len(voltage), + len(defocus_u), + len(defocus_v), + len(defocus_ang), + len(Cs), + len(alpha), + len(B), + ) + + self.voltage = self._to_full(voltage) + self.defocus_u = self._to_full(defocus_u) + self.defocus_v = self._to_full(defocus_v) + self.defocus_ang = self._to_full(defocus_ang) + self.Cs = self._to_full(Cs) + self.alpha = self._to_full(alpha) + self.B = self._to_full(B) + + # derived value + # todo, check/fix broadcast in voltage_to_wavelength + self.wavelength = np.array([voltage_to_wavelength(v) for v in self.voltage]) + + def _to_full(self, vals): + if len(vals) == self.n: + return vals + elif len(vals) == 1: + return np.full(self.n, fill_value=vals[0]) + else: + raise RuntimeError("dont do that") + + def __getitem__(self, items): + return CTFFilter( + self.voltage[items], + # self.wavelength[items], + self.defocus_u[items], + self.defocus_v[items], + self.defocus_ang[items], + self.Cs[items], + self.alpha[items], + self.B[items], + ) + + def __len__(self): + """ + Return stack length + """ + return self.n def _ctf_params(self): return ( @@ -462,6 +511,10 @@ def _ctf_params(self): ) def _evaluate(self, omega, **kwargs): + indices = kwargs.get("indices", None) + if indices is None: + indices = np.arange(self.n) + # Ensure we have a pixel size, pixel_size = kwargs.get("pixel_size", None) if pixel_size is None: @@ -474,13 +527,13 @@ def _evaluate(self, omega, **kwargs): return self.ctf_formula( omega, pixel_size, - self.voltage, - self.defocus_u, - self.defocus_v, - self.defocus_ang, - self.Cs, - self.alpha, - self.B, + self.voltage[indices], + self.defocus_u[indices], + self.defocus_v[indices], + self.defocus_ang[indices], + self.Cs[indices], + self.alpha[indices], + self.B[indices], ) @staticmethod @@ -567,6 +620,11 @@ def to_radial(self): B=self.B, ) + def __eq__(self, other): + if len(self) != len(other): + return False + return self._ctf_params() == other._ctf_params() + class RadialCTFFilter(CTFFilter): def __init__(self, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0): diff --git a/src/aspire/source/simulation.py b/src/aspire/source/simulation.py index b5de869a67..98ff6d9e39 100644 --- a/src/aspire/source/simulation.py +++ b/src/aspire/source/simulation.py @@ -233,9 +233,10 @@ def _populate_ctf_metadata(self, filter_indices): # and for each image (indexed by filter_indices) filter_values = np.zeros((len(filter_indices), len(CTFFilter_attributes))) for i, filt in enumerate(self.unique_filters): - filter_values[filter_indices == i] = [ - getattr(filt, att, np.nan) for att in CTFFilter_attributes - ] + # TODO xxx change to param dump later + filter_values[filter_indices == i] = np.array( + [getattr(filt, att, np.nan) for att in CTFFilter_attributes] + ).flatten() # set the corresponding Relion metadata values that we would expect # from a STAR file self.set_metadata( diff --git a/tests/test_FFBbasis2D.py b/tests/test_FFBbasis2D.py index f4c3cf8ce5..cd07b2ff84 100644 --- a/tests/test_FFBbasis2D.py +++ b/tests/test_FFBbasis2D.py @@ -183,7 +183,8 @@ def test_bulk_expand_radial_vec(): # from cov code params = np.empty((len(filters), 7), dtype=dtype) for i, f in enumerate(filters): - params[i] = f._ctf_params() + # TODO xxx fix param dump, same as sim/source + params[i] = np.array(f._ctf_params()).flatten() _filter_vals = RadialCTFFilter.ctf_formula( basis._filter_pts, pixel_size, *(params.T) diff --git a/tests/test_coordinate_source.py b/tests/test_coordinate_source.py index 3b0f29f1ee..3d521832e4 100644 --- a/tests/test_coordinate_source.py +++ b/tests/test_coordinate_source.py @@ -626,37 +626,34 @@ def _testCtfFilters(self, src, uniform_pixel_sizes=True): # based on the arbitrary values we added to the CTF files # note these values are not realistic filter0 = src.unique_filters[0] - self.assertTrue( - np.allclose( - np.array( - [ - 1000.0, - 900.0, - 800.0 * np.pi / 180.0, - 700.0, - 600.0, - 500.0, - ], - dtype=src.dtype, - ), - np.array( - [ - filter0.defocus_u, - filter0.defocus_v, - filter0.defocus_ang, - filter0.Cs, - filter0.alpha, - filter0.voltage, - ] - ), - ) + np.testing.assert_allclose( + np.array( + [ + 1000.0, + 900.0, + 800.0 * np.pi / 180.0, + 700.0, + 600.0, + 500.0, + ], + dtype=src.dtype, + ), + np.array( + [ + filter0.defocus_u, + filter0.defocus_v, + filter0.defocus_ang, + filter0.Cs, + filter0.alpha, + filter0.voltage, + ] + ).flatten(), ) filter1 = src.unique_filters[1] pixel_size1 = self.pixel_size if not uniform_pixel_sizes: pixel_size1 += 0.01 - self.assertTrue( - np.allclose( + np.testing.assert_allclose( np.array( [ 1001.0, @@ -677,9 +674,8 @@ def _testCtfFilters(self, src, uniform_pixel_sizes=True): filter1.alpha, filter1.voltage, ] - ), + ).flatten(), ) - ) # the first 200 particles should correspond to the first filter # since they came from the first micrograph self.assertTrue( From 91802245414fd2c8b6846e7a60ded720d283e59f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 12 May 2026 09:38:20 -0400 Subject: [PATCH 28/50] minimal xform/pipeline patches [skip ci] --- src/aspire/covariance/covar2d.py | 20 +++----- src/aspire/image/xform.py | 17 +++++-- src/aspire/operators/filters.py | 52 +++++++++++++------- src/aspire/source/coordinates.py | 18 ++----- src/aspire/source/image.py | 48 ++++++++++--------- src/aspire/source/micrograph.py | 9 ++-- src/aspire/source/relion.py | 25 +++------- src/aspire/source/simulation.py | 68 ++++++++++++--------------- tests/test_anisotropic_noise.py | 4 +- tests/test_array_image_source.py | 2 - tests/test_batched_covar2d.py | 11 ++--- tests/test_coordinate_source.py | 6 +-- tests/test_covar2d.py | 16 +++---- tests/test_covar2d_denoiser.py | 12 ++--- tests/test_covar3d.py | 4 +- tests/test_downsample.py | 9 ++-- tests/test_indexed_source.py | 24 ++++------ tests/test_mean_estimator.py | 4 +- tests/test_micrograph_simulation.py | 9 ++-- tests/test_preprocess_pipeline.py | 4 +- tests/test_simulation.py | 28 +++++------ tests/test_simulation_metadata.py | 4 +- tests/test_weighted_mean_estimator.py | 4 +- 23 files changed, 181 insertions(+), 217 deletions(-) diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 7c8c45685f..6e833b1b08 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -546,7 +546,7 @@ def _build(self): if self.basis is None: self.basis = FFBBasis2D((src.L, src.L), dtype=self.dtype) - if not src.unique_filters: + if src.filter_stack is None: logger.info("CTF filters are not included in Cov2D denoising") # set all CTF filters to an identity filter self.ctf_idx = np.zeros(src.n, dtype=int) @@ -562,9 +562,7 @@ def filters_to_basis_mats(self): optimized_expand = callable( getattr(self.basis.__class__, "expand_radial_vec", None) ) - if optimized_expand and all( - isinstance(f, CTFFilter) for f in self.src.unique_filters - ): + if optimized_expand and isinstance(self.src.filter_stack, CTFFilter): logger.info( "Found all filters are CTF, and `basis.expand_radial_vec` available using, bulk basis mat eval" ) @@ -577,29 +575,25 @@ def _filters_to_basis_mats(self): """ old code, should work with all basis and filters. slow. """ - unique_filters = self.src.unique_filters basis_mats = [ self.basis.filter_to_basis_mat( f, pixel_size=self.src.pixel_size, expand_method=self.expand_method ) - for f in tqdm(unique_filters, desc="Converting filters to basis mats") + for f in tqdm( + self.src.filter_stack, desc="Converting filters to basis mats" + ) ] return basis_mats def _ctf_filters_to_basis_mats(self): - unique_filters = self.src.unique_filters - # lol logger.info("Extracting CTF filter parameters and generating eval points") - params = np.empty((len(unique_filters), 7), dtype=self.dtype) - for i, f in enumerate(unique_filters): - ### TODO xxx fix up param dump, same as in source/sim - params[i] = np.array(f._ctf_params()).flatten() + params = self.src.filter_stack._ctf_params() logger.info("Computing CTF filters at eval points") _filter_pts = self.basis._filter_pts # if we have many filters, might be worth trip to GPU - if len(unique_filters) >= 2048: + if len(self.src.filter_stack) >= 2048: params = xp.asarray(params) _filter_pts = xp.asarray(_filter_pts) diff --git a/src/aspire/image/xform.py b/src/aspire/image/xform.py index 20a2e62c5c..7582688b16 100644 --- a/src/aspire/image/xform.py +++ b/src/aspire/image/xform.py @@ -314,6 +314,15 @@ def _forward(self, im, indices): def __str__(self): return f"FilterXform ({self.filter})" + def __len__(self): + """ + Return the len of the underlying filter stack. + """ + return len(self.filter) + + def __getitem__(self, item): + return FilterXform(self.filter[item]) + class Add(Xform): """ @@ -400,7 +409,9 @@ def __init__(self, unique_xforms, indices=None): # A list of references to individual Xform objects, with possibly multiple references pointing to # the same Xform object. - self.xforms = [unique_xforms[i] for i in indices] + # Crap, im stuck + # self.xforms = [unique_xforms[i] for i in indices] + self.xforms = unique_xforms def _indexed_operation(self, im, indices, which): """ @@ -420,7 +431,7 @@ def _indexed_operation(self, im, indices, which): im_data = np.empty_like(im.asnumpy()) # For each individual transformation - for i, xform in enumerate(self.unique_xforms): + for i in range(len(self.unique_xforms)): # Get the indices corresponding to that transformation idx = np.flatnonzero(self.indices == i) # For the incoming Image object, find out which transformation indices are applicable @@ -429,7 +440,7 @@ def _indexed_operation(self, im, indices, which): im_data_indices = np.flatnonzero(np.isin(indices, idx)) # Apply the transformation to the selected indices in the Image object if len(im_data_indices) > 0: - fn_handle = getattr(xform, which) + fn_handle = getattr(self.unique_xforms[i], which) im_data[im_data_indices] = fn_handle(im[im_data_indices]).asnumpy() return Image(im_data, pixel_size=im.pixel_size) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 16562b8389..5eaf97f195 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -31,15 +31,17 @@ def evaluate_src_filters_on_grid(src, indices=None): grid2d = grid_2d(src.L, indexing="yx", dtype=src.dtype) omega = np.pi * np.vstack((grid2d["x"].flatten(), grid2d["y"].flatten())) - # Initialize h as ones to mimic an IdentityFilter when src.unique_filters is None. + # xxx filter opt (eval in bulk instead of loop here), remove branch + # Initialize h as ones to mimic an IdentityFilter when src.filter_stack is None. h = np.ones((omega.shape[-1], len(indices)), dtype=src.dtype) - for i, filt in enumerate(src.unique_filters): - idx_k = np.where(src.filter_indices[indices] == i)[0] - if len(idx_k) > 0: - filter_values = filt.evaluate(omega, pixel_size=src.pixel_size) - # convert filter_values row vector to column vector and tile broadcast - filter_values = filter_values.reshape(-1, 1) - h[:, idx_k] = np.tile(filter_values, len(idx_k)) + if src.filter_stack is not None: + for i, filt in enumerate(src.filter_stack): + idx_k = np.where(src.filter_indices[indices] == i)[0] + if len(idx_k) > 0: + filter_values = filt.evaluate(omega, pixel_size=src.pixel_size) + # convert filter_values row vector to column vector and tile broadcast + filter_values = filter_values.reshape(-1, 1) + h[:, idx_k] = np.tile(filter_values, len(idx_k)) h = np.reshape(h, grid2d["x"].shape + (len(indices),)) return h @@ -252,6 +254,15 @@ def __init__(self, filter, f): def _evaluate(self, omega, **kwargs): return self._f(self._filter.evaluate(omega, **kwargs)) + def __len__(self): + """ + Return length of underlying filter stack + """ + return len(self._filter) + + def __getitem__(self, item): + return LambdaFilter(self._filter[item], self._f) + class MultiplicativeFilter(Filter): """ @@ -441,7 +452,7 @@ def __init__( :param alpha: Amplitude contrast phase in radians :param B: Envelope decay in inverse square angstrom (default 0) """ - super().__init__(dim=2, radial=defocus_u == defocus_v) + super().__init__(dim=2, radial=np.all(defocus_u == defocus_v)) voltage = np.atleast_1d(voltage) # maybe allow singleton here for V defocus_u = np.atleast_1d(defocus_u) defocus_v = np.atleast_1d(defocus_v) @@ -500,15 +511,20 @@ def __len__(self): return self.n def _ctf_params(self): - return ( - self.voltage, - self.defocus_u, - self.defocus_v, - self.defocus_ang, - self.Cs, - self.alpha, - self.B, - ) + """ + Return n_filters-by-n_param array. + """ + return np.array( + [ + self.voltage, + self.defocus_u, + self.defocus_v, + self.defocus_ang, + self.Cs, + self.alpha, + self.B, + ] + ).T def _evaluate(self, omega, **kwargs): indices = kwargs.get("indices", None) diff --git a/src/aspire/source/coordinates.py b/src/aspire/source/coordinates.py index 83d69b72b8..1eacfdc96c 100644 --- a/src/aspire/source/coordinates.py +++ b/src/aspire/source/coordinates.py @@ -176,7 +176,7 @@ def __init__( # set CTF metadata to defaults # this can be updated with import_ctf() self.filter_indices = np.zeros(self.n, dtype=int) - self.unique_filters = [IdentityFilter()] + self.filter_stack = IdentityFilter() self.set_metadata("__filter_indices", np.zeros(self.n, dtype=int)) # populate __mrc_filename and __mrc_index @@ -406,25 +406,15 @@ def _extract_ctf(self, data_block): ) # convert defocus_ang from degrees to radians - filter_params[:, 3] *= np.pi / 180.0 + filter_params[:, 3] = np.deg2rad(filter_params[:, 3]) # Warn if CTF pixel_sizes do match self.pixel_size ctf_pixel_sizes = np.unique(filter_params[:, 6]) check_pixel_size(ctf_pixel_sizes, self.pixel_size) # construct filters - self.unique_filters = [ - CTFFilter( - voltage=filter_params[i, 0], - defocus_u=filter_params[i, 1], - defocus_v=filter_params[i, 2], - defocus_ang=filter_params[i, 3], - Cs=filter_params[i, 4], - alpha=filter_params[i, 5], - B=self.B, - ) - for i in range(len(filter_params)) - ] + # drop pixel size column + self.filter_stack = CTFFilter(*(filter_params[:, :6]).T, B=self.B) # set metadata for mrc_idx, filter_index in enumerate(indices): diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 0660f402c5..2483358578 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -22,13 +22,7 @@ Pipeline, ) from aspire.noise import LegacyNoiseEstimator, NoiseEstimator, WhiteNoiseEstimator -from aspire.operators import ( - CTFFilter, - Filter, - IdentityFilter, - MultiplicativeFilter, - PowerFilter, -) +from aspire.operators import CTFFilter, Filter, MultiplicativeFilter, PowerFilter from aspire.storage import MrcStats, StarFile from aspire.utils import ( Rotation, @@ -214,7 +208,7 @@ def __init__( self._populate_pixel_size(pixel_size) self._populate_symmetry_group(symmetry_group) - self.unique_filters = [] + self.filter_stack = None self.generation_pipeline = Pipeline(xforms=None, memory=memory) logger.info(f"Creating {self.__class__.__name__} with {len(self)} images.") @@ -443,7 +437,10 @@ def n_ctf_filters(self): """ Return the number of CTFFilters found in this Source. """ - return len([f for f in self.unique_filters if isinstance(f, CTFFilter)]) + n = 0 + if isinstance(self.filter_stack, CTFFilter): + n = len(self.filter_stack) + return n @property def states(self): @@ -785,6 +782,11 @@ def _apply_filters( im = im_orig.copy() + if filters is None: + return im + + # else evaluate filters + # XXXX broadcast filter eval for i, filt in enumerate(filters): idx_k = np.where(indices == i)[0] if len(idx_k) > 0: @@ -795,7 +797,7 @@ def _apply_filters( def _apply_source_filters(self, im_orig, indices): return self._apply_filters( im_orig, - self.unique_filters, + self.filter_stack, self.filter_indices[indices], ) @@ -868,8 +870,10 @@ def downsample(self, L, zero_nyquist=True, centered_fft=True): ) ) + # XXXX sigh ds_factor = self.L / L - self.unique_filters = [f.scale(ds_factor) for f in self.unique_filters] + if self.filter_stack is not None: + self.filter_stack = self.filter_stack.scale(ds_factor) if self.pixel_size is not None: self.pixel_size *= ds_factor @@ -956,9 +960,8 @@ def whiten(self, noise_estimate=None, epsilon=None): whiten_filter = PowerFilter(noise_filter, power=-0.5, epsilon=epsilon) logger.info("Transforming all CTF Filters into Multiplicative Filters") - self.unique_filters = [ - MultiplicativeFilter(f, whiten_filter) for f in self.unique_filters - ] + # XXXX + self.filter_stack = MultiplicativeFilter(self.filter_stack, whiten_filter) logger.info("Adding Whitening Filter Xform to end of generation pipeline") self.generation_pipeline.add_xform(FilterXform(whiten_filter)) @@ -1005,8 +1008,9 @@ def phase_flip(self): logger.info("Perform phase flip on source object") - if len(self.unique_filters) >= 1: - unique_xforms = [FilterXform(f.sign) for f in self.unique_filters] + if self.filter_stack is not None: + # XXXX + unique_xforms = FilterXform(self.filter_stack.sign) logger.info("Adding Phase Flip Xform to end of generation pipeline") self.generation_pipeline.add_xform( @@ -1781,18 +1785,20 @@ def __init__(self, src, indices, memory=None): pixel_size=src.pixel_size, ) - if src.unique_filters: + if src.filter_stack is not None: # Remap the filter indices to be unique. # Removes duplicates and filters that are unused in new source. _filter_indices = src.filter_indices[self.index_map] # _unq[_inv] reconstructs _filter_indices _unq, _inv = np.unique(_filter_indices, return_inverse=True) - # Repack unique_filters + # Repack filter_stack self.filter_indices = _inv - self.unique_filters = [copy.copy(src.unique_filters[i]) for i in _unq] + self.filter_stack = copy.copy( + src.filter_stack[_unq] + ) # xxx, this might just work by slicing... else: # Pass through the None case - self.unique_filters = src.unique_filters + self.filter_stack = src.filter_stack self.filter_indices = np.zeros(self.n, dtype=int) # Any further operations should not mutate this instance. @@ -2029,7 +2035,7 @@ def __init__( # Create filter indices, these are required to pass unharmed through filter eval code # that is potentially called by other methods later. self.filter_indices = np.zeros(self.n, dtype=int) - self.unique_filters = [IdentityFilter()] + self.filter_stack = None # Optionally populate angles/rotations. if angles is not None: diff --git a/src/aspire/source/micrograph.py b/src/aspire/source/micrograph.py index a0e11e61ae..6a1b386f90 100644 --- a/src/aspire/source/micrograph.py +++ b/src/aspire/source/micrograph.py @@ -393,10 +393,7 @@ def __init__( self.filter_indices = None if ctf_filters is not None: acceptable_lens = [1, self.micrograph_count, self.total_particle_count] - if ( - not isinstance(ctf_filters, list) - or len(ctf_filters) not in acceptable_lens - ): + if len(ctf_filters) not in acceptable_lens: raise TypeError( f"`ctf_filters` expects a list of len {acceptable_lens[0]}," f" {acceptable_lens[1]}, or {acceptable_lens[2]}." @@ -424,7 +421,7 @@ def __init__( offsets=0, amplitudes=self.particle_amplitudes, angles=self.projection_angles, - unique_filters=ctf_filters, + filter_stack=ctf_filters, filter_indices=self.filter_indices, pixel_size=self.pixel_size, dtype=self.dtype, @@ -699,7 +696,7 @@ def save(self, path, name_prefix="micrograph", overwrite=True): # CTF ctf_metadata = dict() - if self.simulation.unique_filters: + if self.simulation.filter_stack: ctf_metadata = self.simulation.get_metadata( metadata_fields=_meta_fields["ctf"], indices=self.get_particle_indices(m), diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index 963e7ae427..fd37e7d513 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -146,23 +146,12 @@ def __init__( return_inverse=True, axis=0, ) - filters = [] - # for each unique CTF configuration, create a CTFFilter object - for row in filter_params: - filters.append( - CTFFilter( - voltage=row[0], - defocus_u=row[1], - defocus_v=row[2], - defocus_ang=row[3] * np.pi / 180, # degrees to radians - Cs=row[4], - alpha=row[5], - B=B, - ) - ) - self.unique_filters = filters + # Convert `defocus_ang` from degrees to radians + filter_params[:, 3] = np.deg2rad(filter_params[:, 3]) + # Create a CTFFilter stack + self.filter_stack = CTFFilter(*filter_params.T, B=B) # filter_indices stores, for each particle index, the index in - # self.unique_filters of the filter that should be applied + # self.filter_stack of the filter that should be applied self.filter_indices = filter_indices # If we detect ASPIRE added dummy variables, log and initialize identity filter @@ -170,7 +159,7 @@ def __init__( logger.info( "Detected ASPIRE-generated dummy optics; initializing identity filters." ) - self.unique_filters = [IdentityFilter()] + self.filter_stack = IdentityFilter() self.filter_indices = np.zeros(self.n, dtype=int) # We have provided some, but not all the required params @@ -182,7 +171,7 @@ def __init__( # If no CTF info in STAR, we initialize the filter values of metadata with default values else: - self.unique_filters = [IdentityFilter()] + self.filter_stack = IdentityFilter() self.filter_indices = np.zeros(self.n, dtype=int) logger.info(f"Populated {self.n_ctf_filters} CTFFilters from '{filepath}'") diff --git a/src/aspire/source/simulation.py b/src/aspire/source/simulation.py index 98ff6d9e39..997f7f8890 100644 --- a/src/aspire/source/simulation.py +++ b/src/aspire/source/simulation.py @@ -22,7 +22,7 @@ class Simulation(ImageSource): `metadata`. The images are generated via projections of a supplied `Volume` object, `vols`, over orientations define by the Euler angles, `angles`. Various types of corruption, such as noise and CTF effects, can be added to the images by supplying a `Filter` object to the `noise_filter` or - `unique_filters` arguments. + `filter_stack` arguments. """ def __init__( @@ -31,7 +31,7 @@ def __init__( n=1024, vols=None, states=None, - unique_filters=None, + filter_stack=None, filter_indices=None, offsets=None, amplitudes=None, @@ -54,9 +54,9 @@ def __init__( Default is generated with `volume.volume_synthesis.AsymmetricVolume`. :param states: A 1d array of n integers in the interval [0, C). The i'th integer indicates the volume stack index used to produce the i'th projection image. Default is a random set. - :param unique_filters: A list of Filter objects to be applied to projection images. + :param filter_stack: A Filter object to be applied to projection images. :param filter_indices: A 1d array of n integers indicating the `unique_filter` indices associated - with each image. Default is a random set of filter indices, .ie the filters from `unique_filters` + with each image. Default is a random set of filter indices, .ie the filters from `filter_stack` are randomly assigned to the stack of images. :param offsets: A n-by-2 array of coordinates to offset the images. Default is a normally distributed set of offsets. Set `offsets = 0` to disable offsets. @@ -158,17 +158,15 @@ def __init__( self.angles = self._init_angles(angles) - if unique_filters is None: - unique_filters = [] - self.unique_filters = unique_filters + self.filter_stack = filter_stack # sim_filters must be a deep copy so that it is not changed - # when unique_filters is changed - self.sim_filters = copy.deepcopy(unique_filters) + # when filter_stack is changed + self.sim_filters = copy.deepcopy(filter_stack) # Create filter indices and fill the metadata based on unique filters - if unique_filters: + if filter_stack is not None: if filter_indices is None: - filter_indices = randi(len(unique_filters), n, seed=seed) - 1 + filter_indices = randi(len(filter_stack), n, seed=seed) - 1 self._populate_ctf_metadata(filter_indices) self.filter_indices = filter_indices else: @@ -220,35 +218,31 @@ def _populate_ctf_metadata(self, filter_indices): # for these columns # # class attributes of CTFFilter: - CTFFilter_attributes = ( - "voltage", - "defocus_u", - "defocus_v", - "defocus_ang", - "Cs", - "alpha", - ) - - # get the CTF parameters, if they exist, for each filter - # and for each image (indexed by filter_indices) - filter_values = np.zeros((len(filter_indices), len(CTFFilter_attributes))) - for i, filt in enumerate(self.unique_filters): - # TODO xxx change to param dump later - filter_values[filter_indices == i] = np.array( - [getattr(filt, att, np.nan) for att in CTFFilter_attributes] - ).flatten() + CTFFilter_attributes = [ + "_rlnVoltage", + "_rlnDefocusU", + "_rlnDefocusV", + "_rlnDefocusAngle", + "_rlnSphericalAberration", + "_rlnAmplitudeContrast", + ] + + # Unpack the `filter_stack` params across images using `filter_indices` mapping + # Note this does not include the B factor term (unique to ASPIRE?,xxx should we add to star if used?) + filter_stack_params = self.filter_stack._ctf_params()[ + :, :6 + ] # params per filter + image_filter_values = np.zeros( + (len(filter_indices), len(CTFFilter_attributes)) + ) # params per image + for i, params in enumerate(filter_stack_params): + # assign `params` to all matching images + image_filter_values[filter_indices == i] = params # set the corresponding Relion metadata values that we would expect # from a STAR file self.set_metadata( - [ - "_rlnVoltage", - "_rlnDefocusU", - "_rlnDefocusV", - "_rlnDefocusAngle", - "_rlnSphericalAberration", - "_rlnAmplitudeContrast", - ], - filter_values, + CTFFilter_attributes, + image_filter_values, ) @property diff --git a/tests/test_anisotropic_noise.py b/tests/test_anisotropic_noise.py index ee699cb20f..ccfc824e8d 100644 --- a/tests/test_anisotropic_noise.py +++ b/tests/test_anisotropic_noise.py @@ -19,9 +19,7 @@ def setUp(self): self.sim = _LegacySimulation( n=1024, vols=self.vol, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), dtype=self.dtype, ) diff --git a/tests/test_array_image_source.py b/tests/test_array_image_source.py index a2c3ee4f0c..fc2735d15c 100644 --- a/tests/test_array_image_source.py +++ b/tests/test_array_image_source.py @@ -10,7 +10,6 @@ from aspire.basis import FBBasis3D from aspire.image import Image -from aspire.operators import IdentityFilter from aspire.reconstruction import MeanEstimator from aspire.source import ArrayImageSource, RelionSource, Simulation from aspire.utils import Rotation, utest_tolerance @@ -31,7 +30,6 @@ def setUp(self): self.sim = sim = Simulation( n=self.n, L=self.resolution, - unique_filters=[IdentityFilter()], seed=0, dtype=self.dtype, # We'll use random angles diff --git a/tests/test_batched_covar2d.py b/tests/test_batched_covar2d.py index f239354700..70906425d7 100644 --- a/tests/test_batched_covar2d.py +++ b/tests/test_batched_covar2d.py @@ -34,7 +34,7 @@ def setUp(self): self.src = Simulation( L, n, - unique_filters=self.filters, + filter_stack=self.filters, pixel_size=5, dtype=self.dtype, noise_adder=noise_adder, @@ -253,10 +253,9 @@ class BatchedRotCov2DTestCaseCTF(BatchedRotCov2DTestCase): @property def filters(self): - return [ - RadialCTFFilter(200, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(1.5e4, 2.5e4, 7) - ] + return RadialCTFFilter( + 200, defocus=np.linspace(1.5e4, 2.5e4, 7), Cs=2.0, alpha=0.1 + ) @property def ctf_idx(self): @@ -266,5 +265,5 @@ def ctf_idx(self): def ctf_basis(self): return [ self.basis.filter_to_basis_mat(f, pixel_size=self.src.pixel_size) - for f in self.src.unique_filters + for f in self.src.filter_stack ] diff --git a/tests/test_coordinate_source.py b/tests/test_coordinate_source.py index 3d521832e4..4838f75a9d 100644 --- a/tests/test_coordinate_source.py +++ b/tests/test_coordinate_source.py @@ -621,11 +621,11 @@ def testImportCtfFromRelionLegacy(self): def _testCtfFilters(self, src, uniform_pixel_sizes=True): # there are two micrographs and two CTF files, so there should be two # unique CTF filters - self.assertEqual(len(src.unique_filters), 2) + self.assertEqual(len(src.filter_stack), 2) # test the properties of the CTF filters # based on the arbitrary values we added to the CTF files # note these values are not realistic - filter0 = src.unique_filters[0] + filter0 = src.filter_stack[0] np.testing.assert_allclose( np.array( [ @@ -649,7 +649,7 @@ def _testCtfFilters(self, src, uniform_pixel_sizes=True): ] ).flatten(), ) - filter1 = src.unique_filters[1] + filter1 = src.filter_stack[1] pixel_size1 = self.pixel_size if not uniform_pixel_sizes: pixel_size1 += 0.01 diff --git a/tests/test_covar2d.py b/tests/test_covar2d.py index 66e88c4ff8..5630606f30 100644 --- a/tests/test_covar2d.py +++ b/tests/test_covar2d.py @@ -79,22 +79,20 @@ def cov2d_fixture(volume, basis, ctf_enabled): n = 32 # Default CTF params - unique_filters = None + filters = None h_idx = None h_ctf_fb = None # Popluate CTF if ctf_enabled: - unique_filters = [ - RadialCTFFilter(200, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(1.5e4, 2.5e4, 7) - ] + filters = RadialCTFFilter( + 200, defocus=np.linspace(1.5e4, 2.5e4, 7), Cs=2.0, alpha=0.1 + ) # Copied from simulation defaults to match legacy test files. - h_idx = randi(len(unique_filters), n, seed=0) - 1 + h_idx = randi(len(filters), n, seed=0) - 1 h_ctf_fb = [ - basis.filter_to_basis_mat(f, pixel_size=volume.pixel_size) - for f in unique_filters + basis.filter_to_basis_mat(f, pixel_size=volume.pixel_size) for f in filters ] noise_adder = WhiteNoiseAdder(var=NOISE_VAR) @@ -102,7 +100,7 @@ def cov2d_fixture(volume, basis, ctf_enabled): sim = _LegacySimulation( n=n, vols=volume, - unique_filters=unique_filters, + filter_stack=filters, filter_indices=h_idx, offsets=0.0, amplitudes=1.0, diff --git a/tests/test_covar2d_denoiser.py b/tests/test_covar2d_denoiser.py index 984acdd342..46db43a3ee 100644 --- a/tests/test_covar2d_denoiser.py +++ b/tests/test_covar2d_denoiser.py @@ -15,12 +15,10 @@ noise_var = 0.1848 noise_adder = WhiteNoiseAdder(var=noise_var) pixel_size = 5 -filters = [ - CTFFilter( - 200, defocus_ang=np.pi / 3, defocus_u=d, defocus_v=d + 345, Cs=2.0, alpha=0.1 - ) - for d in np.linspace(1.5e4, 2.5e4, 7) -] +d = np.linspace(1.5e4, 2.5e4, 7) +filters = CTFFilter( + 200, defocus_ang=np.pi / 3, defocus_u=d, defocus_v=d + 345, Cs=2.0, alpha=0.1 +) # For (F)PSWFBasis2D we get off-block entries which are truncated # when converting to block-diagonal. We filter these warnings. @@ -61,7 +59,7 @@ def sim(): sim = Simulation( L=img_size, n=num_imgs, - unique_filters=filters, + filter_stack=filters, offsets=0.0, amplitudes=1.0, dtype=dtype, diff --git a/tests/test_covar3d.py b/tests/test_covar3d.py index 98d0fdea13..5b50ec2c56 100644 --- a/tests/test_covar3d.py +++ b/tests/test_covar3d.py @@ -28,9 +28,7 @@ def setUpClass(cls): cls.sim = _LegacySimulation( n=1024, vols=cls.vols, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), dtype=cls.dtype, ) basis = FBBasis3D((8, 8, 8), dtype=cls.dtype) diff --git a/tests/test_downsample.py b/tests/test_downsample.py index 80a6f1e69a..b2da4806f3 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -234,17 +234,16 @@ def test_simulation_relion_downsample(): defocus_max = 25000 defocus_ct = 7 - ctf_filters = [ - RadialCTFFilter(defocus=d) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) - ] + ctf_filters = RadialCTFFilter( + defocus=np.linspace(defocus_min, defocus_max, defocus_ct) + ) # Generate Simulation source and downsampled simulation. src = Simulation( L=64, n=10, C=1, - unique_filters=ctf_filters, + filter_stack=ctf_filters, noise_adder=WhiteNoiseAdder.from_snr(snr=1), pixel_size=1, ) diff --git a/tests/test_indexed_source.py b/tests/test_indexed_source.py index 80ada30003..1a0179d0cc 100644 --- a/tests/test_indexed_source.py +++ b/tests/test_indexed_source.py @@ -59,7 +59,7 @@ def test_repr(sim_fixture): @pytest.mark.expensive def test_filter_mapping(): """ - This test is designed to ensure that `unique_filters` and `filter_indices` + This test is designed to ensure that `filter_stack` and `filter_indices` are being remapped correctly upon slicing. Additionally it tests that a realistic preprocessing pipeline is equivalent @@ -81,18 +81,14 @@ def test_filter_mapping(): angles = Rotation(np.repeat(rots, 2, axis=0)).angles # Generate N//2 rotations and repeat indices - defoci = np.linspace(1000, 25000, N // 2) - ctf_filters = [ - CTFFilter( - 200, - defocus_u=defoci[d], - defocus_v=defoci[-d], - defocus_ang=np.pi / (N // 2) * d, - Cs=2.0, - alpha=0.1, - ) - for d in range(N // 2) - ] + ctf_filters = CTFFilter( + 200, + defocus_u=np.linspace(1000, 25000, N // 2), + defocus_v=np.linspace(1000, 25000, N // 2)[::-1], + defocus_ang=np.linspace(0, np.pi, N // 2), + Cs=2.0, + alpha=0.1, + ) ctf_indices = np.repeat(np.arange(N // 2), 2) # Construct the source @@ -101,7 +97,7 @@ def test_filter_mapping(): n=N, dtype=DT, seed=SEED, - unique_filters=ctf_filters, + filter_stack=ctf_filters, filter_indices=ctf_indices, angles=angles, offsets=0, diff --git a/tests/test_mean_estimator.py b/tests/test_mean_estimator.py index b56a06f4ce..db69ce5222 100644 --- a/tests/test_mean_estimator.py +++ b/tests/test_mean_estimator.py @@ -51,9 +51,7 @@ def sim(L, dtype): L=L, n=256, C=1, # single volume - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), dtype=dtype, seed=SEED, pixel_size=1.234, diff --git a/tests/test_micrograph_simulation.py b/tests/test_micrograph_simulation.py index 2e81d7326f..4620969cdf 100644 --- a/tests/test_micrograph_simulation.py +++ b/tests/test_micrograph_simulation.py @@ -304,7 +304,7 @@ def test_sim_save(): """ v = AsymmetricVolume(L=16, C=1, pixel_size=4, dtype=np.float64).generate() - ctfs = [RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0)] + ctfs = RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0) mg_sim = MicrographSimulation( volume=v, @@ -365,7 +365,7 @@ def test_save_overwrite(caplog): """ v = AsymmetricVolume(L=16, C=1, pixel_size=4, dtype=np.float64).generate() - ctfs = [RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0)] + ctfs = RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0) mg_sim = MicrographSimulation( volume=v, @@ -466,8 +466,5 @@ def test_bad_ctf(vol_fixture): particles_per_micrograph=1, micrograph_count=1, micrograph_size=512, - ctf_filters=[ - RadialCTFFilter(), - ] - * 2, # total particles == 1 + ctf_filters=RadialCTFFilter(defocus=[10000, 20000]), # total particles == 1 ) diff --git a/tests/test_preprocess_pipeline.py b/tests/test_preprocess_pipeline.py index 83cc070930..20716660ff 100644 --- a/tests/test_preprocess_pipeline.py +++ b/tests/test_preprocess_pipeline.py @@ -29,9 +29,7 @@ def get_sim_object(L, dtype): sim = Simulation( L=L, n=num_images, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), noise_adder=noise_adder, pixel_size=1, dtype=dtype, diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 58bea32128..898479bf7a 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -128,9 +128,7 @@ def setUp(self): n=self.n, L=self.L, vols=self.vols, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), noise_adder=WhiteNoiseAdder(var=1), dtype=self.dtype, ) @@ -176,9 +174,7 @@ def testSimulationCached(self): L=self.L, vols=self.vols, offsets=self.sim.offsets, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), noise_adder=WhiteNoiseAdder(var=1), dtype=self.dtype, ) @@ -634,12 +630,10 @@ def test_simulation_save_optics_block(tmp_path): # Radial CTF Filters. Should make 3 distinct optics blocks kv_min, kv_max, kv_ct = 200, 300, 3 voltages = np.linspace(kv_min, kv_max, kv_ct) - ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] + ctf_filters = RadialCTFFilter(voltage=voltages) # Generate and save Simulation - sim = Simulation( - n=9, L=res, C=1, unique_filters=ctf_filters, pixel_size=1.34 - ).cache() + sim = Simulation(n=9, L=res, C=1, filter_stack=ctf_filters, pixel_size=1.34).cache() starpath = tmp_path / "sim.star" sim.save(starpath, overwrite=True) @@ -699,10 +693,10 @@ def test_simulation_slice_save_roundtrip(tmp_path): # Radial CTF Filters kv_min, kv_max, kv_ct = 200, 300, 3 voltages = np.linspace(kv_min, kv_max, kv_ct) - ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] + ctf_filters = RadialCTFFilter(voltage=voltages) # Generate and save slice of Simulation - sim = Simulation(n=9, L=16, C=1, unique_filters=ctf_filters, pixel_size=1.34) + sim = Simulation(n=9, L=16, C=1, filter_stack=ctf_filters, pixel_size=1.34) sliced_sim = sim[::2] save_path = tmp_path / "sliced_sim.star" sliced_sim.save(save_path, overwrite=True) @@ -787,14 +781,14 @@ def test_cached_image_accessors(): Test the behavior of image caching. """ # Create a CTF - ctf = [RadialCTFFilter()] + ctf = RadialCTFFilter() # Create a Simulation with noise and `ctf` src = Simulation( L=32, n=3, C=1, noise_adder=WhiteNoiseAdder(var=0.123), - unique_filters=ctf, + filter_stack=ctf, pixel_size=5, ) # Cache the simulation @@ -818,14 +812,14 @@ def test_projections_and_clean_images_downsample(): L = 32 L_ds = 21 px_sz = 1.23 - ctf = [RadialCTFFilter(1.5e4)] + ctf = RadialCTFFilter(1.5e4) src = Simulation( L=L, n=n, C=1, noise_adder=WhiteNoiseAdder(var=0.123), - unique_filters=ctf, + filter_stack=ctf, pixel_size=px_sz, ) @@ -921,7 +915,7 @@ def test_save_load_dummy_ctf_values(tmp_path, caplog): are present. These values should be detected upon reloading the source. """ star_path = tmp_path / "no_ctf.star" - sim = Simulation(n=8, L=16) # no unique_filters, ie. no CTF info + sim = Simulation(n=8, L=16) # no filter_stack, ie. no CTF info sim.save(star_path, overwrite=True) # STAR file should contain our fallback tag diff --git a/tests/test_simulation_metadata.py b/tests/test_simulation_metadata.py index c47efa1167..f5367038a6 100644 --- a/tests/test_simulation_metadata.py +++ b/tests/test_simulation_metadata.py @@ -20,9 +20,7 @@ def setUp(self): self.sim = MySimulation( n=1024, L=8, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), ) def tearDown(self): diff --git a/tests/test_weighted_mean_estimator.py b/tests/test_weighted_mean_estimator.py index 96f11f2cab..b376891b0e 100644 --- a/tests/test_weighted_mean_estimator.py +++ b/tests/test_weighted_mean_estimator.py @@ -54,9 +54,7 @@ def sim(L, dtype): L=L, n=256, C=1, # single volume - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), pixel_size=1, dtype=dtype, seed=SEED, From ccaf5b791d1e7d5c199e227adedc2348885dc162 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 14 May 2026 07:39:14 -0400 Subject: [PATCH 29/50] got both filter (ds) and filter stack running [skip ci] --- src/aspire/basis/fle_2d.py | 13 +++++++++---- src/aspire/basis/steerable.py | 4 +++- src/aspire/operators/filters.py | 6 ++++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index e39e26b3a2..f6e9b5d905 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -771,6 +771,7 @@ def _radial_convolve_weights(self, b): b = xp.concatenate((b, bz), axis=1) b = fft.idct(b, axis=1, type=2) * 2 * b.shape[1] a = xp.zeros((b.shape[0], self.count), dtype=self.dtype) + for i in range(self.ell_p_max + 1): # Wierd mul transpose forced by A3 being CSR. # Can't reshape A3, but can broadcast over dims of b. @@ -860,11 +861,15 @@ def expand_radial_vec(self, radial_vec, **kwargs): # Convert to internal FLE indices ordering coefs = coefs[..., self._fb_to_fle_indices] - # squeeze should probably be addressed in consuming code, - # for now match old `filter_to_basis_mat` - coefs = xp.asnumpy(coefs).squeeze() + coefs = xp.asnumpy(coefs) + + # who needs this as a list? + if len(coefs) > 1: + coefs = [DiagMatrix(c) for c in coefs] + else: + coefs = DiagMatrix(coefs.flatten()) - return [DiagMatrix(c) for c in coefs] + return coefs def _radial_filter_to_vals(self, f, **kwargs): """ diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index 94cea0a96a..ddb7424b77 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -498,7 +498,9 @@ def filter_to_basis_mat(self, f, **kwargs): if optimized_expand and filter_is_radial and radial_method: # kwargs supports passing through pixel_size - h_vals = self._radial_filter_to_vals(f, **kwargs).reshape(-1, 1) + h_vals = self._radial_filter_to_vals( + f, **kwargs + ) # check dont need #.reshape(-1, 1) res = self.expand_radial_vec(h_vals) return res else: diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 5eaf97f195..606ea35a50 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -301,6 +301,12 @@ def __str__(self): """ return f"ScaledFilter (scales {self._filter} by {self._scale})" + def __getitem__(self, item): + return ScaledFilter(self._filter[item], self._scale) + + def __len__(self): + return len(self._filter) + class ArrayFilter(Filter): def __init__(self, xfer_fn_array): From e06fd26121c10469ba2c7f2e91c75a5c27b1f494 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 14 May 2026 14:47:44 -0400 Subject: [PATCH 30/50] initial attempt extending multiplicative filter bcast [skip ci] --- src/aspire/operators/filters.py | 35 +++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 606ea35a50..23a6e50ce5 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -149,6 +149,14 @@ def sign(self): """ return LambdaFilter(self, np.sign) + def __len__(self): + """ + Default filters are length 1. + + Some filters (eg CTFFilter) provide additional optimizations for stacks. + """ + return 1 + class DualFilter(Filter): """ @@ -162,6 +170,9 @@ def __init__(self, filter_in): def evaluate(self, omega, **kwargs): return self._filter.evaluate(-omega, **kwargs) + def __len__(self): + return len(self._filter) + class FunctionFilter(Filter): """ @@ -240,6 +251,9 @@ def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs): return filter_vals**self._power + def __len__(self): + return len(self._filter) + class LambdaFilter(Filter): """ @@ -272,6 +286,24 @@ class MultiplicativeFilter(Filter): def __init__(self, *args): super().__init__(dim=args[0].dim, radial=all(c.radial for c in args)) self._components = args + self._init_size() + + def _init_size(self): + """ + Check sizes of _components are coherent and initialize resulting length. + """ + filter_lengths = [len(f) for f in self._components] + filter_lengths = np.unique(filter_lengths, sorted=True) + + # Code should be able to broadcast n_filters with n_filters, or n_filters with 1_filters. + # Any other combination is considered an error. + if len(filter_lengths) > 2 or ( + (len(filter_lengths) == 2) and (filter_lengths[0] != 1) + ): + raise RuntimeError(f"Incoherent filter lengths {filter_lengths}") + # filter_lengths is sorted, so this should be the larger of two values, + # or the single value in the 1_filter case + self._n = filter_lengths[-1] def _evaluate(self, omega, **kwargs): res = 1 @@ -279,6 +311,9 @@ def _evaluate(self, omega, **kwargs): res *= c.evaluate(omega, **kwargs) return res + def __len__(self): + return self._n + class ScaledFilter(Filter): """ From 5ac18254264c81b8a67fcecd5922331809cba823 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 18 May 2026 09:46:12 -0400 Subject: [PATCH 31/50] hacktastic ctf param passthrough [skip ci] --- src/aspire/operators/filters.py | 42 +++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 23a6e50ce5..fc9328f1c6 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -157,6 +157,9 @@ def __len__(self): """ return 1 + def _ctf_params(self): + raise NotImplementedError(f"Not implemented for {self.__class__.__name__}") + class DualFilter(Filter): """ @@ -173,6 +176,12 @@ def evaluate(self, omega, **kwargs): def __len__(self): return len(self._filter) + def _ctf_params(self): + """ + Return n_filters-by-n_param array from prior filter. + """ + return self._filter._ctf_params() + class FunctionFilter(Filter): """ @@ -254,6 +263,12 @@ def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs): def __len__(self): return len(self._filter) + def _ctf_params(self): + """ + Return n_filters-by-n_param array from prior filter. + """ + return self._filter._ctf_params() + class LambdaFilter(Filter): """ @@ -277,6 +292,12 @@ def __len__(self): def __getitem__(self, item): return LambdaFilter(self._filter[item], self._f) + def _ctf_params(self): + """ + Return n_filters-by-n_param array from prior filter. + """ + return self._filter._ctf_params() + class MultiplicativeFilter(Filter): """ @@ -314,6 +335,21 @@ def _evaluate(self, omega, **kwargs): def __len__(self): return self._n + def _ctf_params(self): + """ + Return n_filters-by-n_param array from prior filter. + + Raises error if multiple or none found. + """ + _params = [getattr(c, "_ctf_params", None) for c in self._components] + _params = list(filter(_params, None)) + if len(_params) > 1: + raise RuntimeError("Multiple filters with CTF parameters found.") + elif len(_params) == 0: + raise RuntimeError("No CTF parameters found.") + + return _params[0] + class ScaledFilter(Filter): """ @@ -342,6 +378,12 @@ def __getitem__(self, item): def __len__(self): return len(self._filter) + def _ctf_params(self): + """ + Return n_filters-by-n_param array from prior filter. + """ + return self._filter._ctf_params() + class ArrayFilter(Filter): def __init__(self, xfer_fn_array): From 497b75a97da3c07cc4f2cef92098b8354bb0fe40 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 21 May 2026 13:03:19 -0400 Subject: [PATCH 32/50] revert last approach in favor of using evaluate per dev meeting --- src/aspire/covariance/covar2d.py | 23 ++++++++++++----------- src/aspire/operators/filters.py | 29 ++++++++++++++++++++++++----- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 6e833b1b08..9ca3098715 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -562,12 +562,15 @@ def filters_to_basis_mats(self): optimized_expand = callable( getattr(self.basis.__class__, "expand_radial_vec", None) ) - if optimized_expand and isinstance(self.src.filter_stack, CTFFilter): + if optimized_expand and self.src.filter_stack.radial: logger.info( - "Found all filters are CTF, and `basis.expand_radial_vec` available using, bulk basis mat eval" + "Found radial filter stack and `basis.expand_radial_vec` available." + " Using bulk basis mat eval." ) - return self._ctf_filters_to_basis_mats() + return self._filter_stack_to_basis_mats() else: + # Note, can come back and optmize the filter eval to bulk, just not radial + # For now use legacy path. logger.info("Using sequential basis mat eval") return self._filters_to_basis_mats() @@ -585,20 +588,18 @@ def _filters_to_basis_mats(self): ] return basis_mats - def _ctf_filters_to_basis_mats(self): - # lol - logger.info("Extracting CTF filter parameters and generating eval points") - params = self.src.filter_stack._ctf_params() - - logger.info("Computing CTF filters at eval points") + # todo, either rename _radial_filters_to_basis_mats or handle none radial + # same remark as `filters_to_basis_mats` + def _filter_stack_to_basis_mats(self): + logger.info("Generating filter eval points") _filter_pts = self.basis._filter_pts # if we have many filters, might be worth trip to GPU if len(self.src.filter_stack) >= 2048: params = xp.asarray(params) _filter_pts = xp.asarray(_filter_pts) - _filter_vals = CTFFilter.ctf_formula( - _filter_pts, self.src.pixel_size, *(params.T) + _filter_vals = self.src.filter_stack.evaluate( + _filter_pts, pixel_size=self.src.pixel_size ) logger.info("Computing basis radial expansion") diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index fc9328f1c6..ceb6048d8b 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -89,7 +89,16 @@ def evaluate(self, omega, **kwargs): h = self._evaluate(omega, **kwargs) if self.radial: - h = np.take(h, idx) + # The reshape and take axis gynmastics work to provide the + # legacy idx taking functionality for both singleton and + # stack cases. + # reshape (stack, vals) + h = h.reshape(len(self), -1) + # keep stack, take along vals axis + h = np.take(h, idx, axis=-1) + # squeeze off stack dim from singleton case (preserves legacy behavior) + if h.shape[0] == 1: # avoid error when len(h)>1 + h = np.squeeze(h, axis=0) return h @@ -158,7 +167,12 @@ def __len__(self): return 1 def _ctf_params(self): - raise NotImplementedError(f"Not implemented for {self.__class__.__name__}") + """ + Return n_filters-by-n_param array from prior filter. + """ + raise NotImplementedError( + f"_ctf_params not implemented for {self.__class__.__name__}" + ) class DualFilter(Filter): @@ -329,7 +343,7 @@ def _init_size(self): def _evaluate(self, omega, **kwargs): res = 1 for c in self._components: - res *= c.evaluate(omega, **kwargs) + res = res * c.evaluate(omega, **kwargs) return res def __len__(self): @@ -341,8 +355,13 @@ def _ctf_params(self): Raises error if multiple or none found. """ - _params = [getattr(c, "_ctf_params", None) for c in self._components] - _params = list(filter(_params, None)) + _params = [] + for c in self._components: + try: + _params.append(c._ctf_params()) + except NotImplementedError as e: + pass + if len(_params) > 1: raise RuntimeError("Multiple filters with CTF parameters found.") elif len(_params) == 0: From 96806109cce80e7939045c25bb92f2a6f17554f8 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 21 May 2026 13:29:27 -0400 Subject: [PATCH 33/50] rm unused var --- src/aspire/covariance/covar2d.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 9ca3098715..abbc2ecf65 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -595,7 +595,6 @@ def _filter_stack_to_basis_mats(self): _filter_pts = self.basis._filter_pts # if we have many filters, might be worth trip to GPU if len(self.src.filter_stack) >= 2048: - params = xp.asarray(params) _filter_pts = xp.asarray(_filter_pts) _filter_vals = self.src.filter_stack.evaluate( From ce309f85ce562a596815702c0a57ae83345b1f1d Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 22 May 2026 08:57:28 -0400 Subject: [PATCH 34/50] satisfy tox --- src/aspire/covariance/covar2d.py | 2 +- src/aspire/operators/filters.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index abbc2ecf65..d6f2150acc 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -7,7 +7,7 @@ from aspire.basis import Coef, FFBBasis2D from aspire.numeric import xp -from aspire.operators import BlkDiagMatrix, CTFFilter, DiagMatrix +from aspire.operators import BlkDiagMatrix, DiagMatrix from aspire.optimization import conj_grad, fill_struct from aspire.utils import make_symmat, tqdm diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index ceb6048d8b..904a46fe04 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -359,7 +359,7 @@ def _ctf_params(self): for c in self._components: try: _params.append(c._ctf_params()) - except NotImplementedError as e: + except NotImplementedError: pass if len(_params) > 1: From 73a3cfd6622e6d9a0286c3a107fc8bb348e33adf Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 22 May 2026 12:02:34 -0400 Subject: [PATCH 35/50] stashing, got filter stack to basis mat eval working for ffb2d need to add legacy branch logic for other basis [skip ci] --- src/aspire/basis/ffb_2d.py | 77 ++++++++++++++++++++++++++++++++ src/aspire/covariance/covar2d.py | 35 +++++++++++---- src/aspire/operators/filters.py | 9 ++++ 3 files changed, 113 insertions(+), 8 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 239e945ec5..4af6ccd7b9 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -288,6 +288,7 @@ def _filter_to_basis_mat(self, f, **kwargs): omegay = k * np.sin(theta) omega = 2 * np.pi * np.vstack((omegax.flatten("C"), omegay.flatten("C"))) + # this should be either a stack h_vals2d = ( h_fun(omega, pixel_size=pixel_size).reshape(n_k, n_theta).astype(self.dtype) ) @@ -314,6 +315,82 @@ def _filter_to_basis_mat(self, f, **kwargs): return h_basis + def _filter_stack_to_basis_mats(self, f, **kwargs): + """ + See `SteerableBasis2D.filter_to_basis_mat`. + """ + # Note 'method' and 'truncate' not relevant for this specific FFB code. + # Method `radial` should have already been diverted. + expand_method = kwargs.get("expand_method", None) + if expand_method is not None: + raise NotImplementedError( + f"`FFBBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported." + " Use `expand_method=None`." + ) + + pixel_size = kwargs.get("pixel_size", None) + + # These form a circular dependence, import locally until time to clean up. + from aspire.basis.basis_utils import lgwt + + # Get the filter's evaluate function. + h_fun = f.evaluate + + # Set same dimensions as basis object + n_k = self.n_r + n_theta = self.n_theta + radial = self._precomp["radial"] + + # get 2D grid in polar coordinate + # Confirm this lgwt call with Joakim (should it follow basis config self.kcut? same by default) + k_vals, wts = lgwt(n_k, 0, 0.5, dtype=self.dtype) + k, theta = np.meshgrid( + k_vals, np.arange(n_theta) * 2 * np.pi / (2 * n_theta), indexing="ij" + ) + + # Get function values in polar 2D grid and average out angle contribution + omegax = k * np.cos(theta) + omegay = k * np.sin(theta) + omega = 2 * np.pi * np.vstack((omegax.flatten("C"), omegay.flatten("C"))) + + # This should return either a single 2d array, or stack of 2d arrays + # Reshape singleton to stack of 1. + h_vals2d = ( + h_fun(omega, pixel_size=pixel_size) + .reshape(len(f), n_k, n_theta) + .astype(self.dtype) + ) + h_vals = np.sum(h_vals2d, axis=-1) / n_theta + + # Represent each 1D functions values in basis + h_basis = [ + BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) for _ in h_vals + ] + ind_ell = 0 + for ell in range(0, self.ell_max + 1): + k_max = self.k_max[ell] + # xxx todo, we can skip computing rmat, just need the shape + rmat = 2 * k_vals.reshape(n_k, 1) * self.r0[ell][0:k_max].T + basis_vals = np.zeros_like(rmat) + ind_radial = np.sum(self.k_max[0:ell]) + basis_vals[:, 0:k_max] = radial[ind_radial : ind_radial + k_max].T + h_basis_vals = basis_vals * h_vals.reshape( + len(f), n_k, 1 + ) # check bcast here + h_basis_ell = basis_vals.T @ ( + h_basis_vals * k_vals.reshape(n_k, 1) * wts.reshape(n_k, 1) + ) + # loop over assignment blocks. + for i in range(len(f)): + h_basis[i][ind_ell] = h_basis_ell[i] + ind_ell += 1 + if ell > 0: + for i in range(len(f)): + h_basis[i][ind_ell] = h_basis[i][ind_ell - 1] + ind_ell += 1 + + return h_basis + def expand_radial_vec(self, radial_vec, force_diag=False): """ Expands radial vector or stack of vetors `radial_vec` to basis matrix. diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index d6f2150acc..830b957823 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -578,14 +578,33 @@ def _filters_to_basis_mats(self): """ old code, should work with all basis and filters. slow. """ - basis_mats = [ - self.basis.filter_to_basis_mat( - f, pixel_size=self.src.pixel_size, expand_method=self.expand_method - ) - for f in tqdm( - self.src.filter_stack, desc="Converting filters to basis mats" - ) - ] + basis_mats = self.basis._filter_stack_to_basis_mats( + self.src.filter_stack, + pixel_size=self.src.pixel_size, + expand_method=self.expand_method, + ) + + # ## Legacy + # old_basis_mats = [ + # self.basis.filter_to_basis_mat( + # f, pixel_size=self.src.pixel_size, expand_method=self.expand_method + # ) + # for f in tqdm( + # self.src.filter_stack, desc="Converting filters to basis mats" + # ) + # ] + + # from tqdm import trange + # diff = 0 + # for i in trange(len(basis_mats)): + # a = basis_mats[i] + # ref= old_basis_mats[i] + # for j in range(len(ref)): + # diff += np.sum(a[j]-ref[j]) + + # print("sum of diff across all filters", diff) + # breakpoint() + return basis_mats # todo, either rename _radial_filters_to_basis_mats or handle none radial diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 904a46fe04..c732c3951f 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -283,6 +283,9 @@ def _ctf_params(self): """ return self._filter._ctf_params() + def __getitem__(self, item): + return PowerFilter(self._filter[item], power=self._power, epsilon=self._epsilon) + class LambdaFilter(Filter): """ @@ -369,6 +372,9 @@ def _ctf_params(self): return _params[0] + def __getitem__(self, item): + return MultiplicativeFilter(*list(c[item] for c in self._components)) + class ScaledFilter(Filter): """ @@ -511,6 +517,9 @@ def __repr__(self): def _evaluate(self, omega, **kwargs): return self.value * np.ones_like(omega) + def __getitem__(self, item): + return self + class ZeroFilter(ScalarFilter): def __init__(self, dim=None): From cf839d6b678283511a5335c061097977409cdbda Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 26 May 2026 09:29:56 -0400 Subject: [PATCH 36/50] begin filter_basis_mat cleanup --- src/aspire/basis/fb_2d.py | 6 +++--- src/aspire/basis/ffb_2d.py | 11 ++++++++++- src/aspire/basis/fle_2d.py | 11 ++++++++++- src/aspire/basis/fpswf_2d.py | 6 +++--- src/aspire/basis/fspca.py | 2 +- src/aspire/basis/pswf_2d.py | 6 +++--- src/aspire/basis/steerable.py | 32 ++++++++++++++++++++++---------- src/aspire/covariance/covar2d.py | 25 +++++++++---------------- 8 files changed, 61 insertions(+), 38 deletions(-) diff --git a/src/aspire/basis/fb_2d.py b/src/aspire/basis/fb_2d.py index 3c02414535..f6c246b8bc 100644 --- a/src/aspire/basis/fb_2d.py +++ b/src/aspire/basis/fb_2d.py @@ -289,8 +289,8 @@ def calculate_bispectrum( freq_cutoff=freq_cutoff, ) - def _filter_to_basis_mat(self, *args, **kwargs): + def filter_to_basis_mat(self, *args, **kwargs): """ - See `SteerableBasis2D.filter_to_basis_mat`. + See `SteerableBasis2D.filter_stack_to_basis_mat`. """ - return super()._filter_to_basis_mat(*args, **kwargs) + return super().filter_to_basis_mat(*args, **kwargs) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 4af6ccd7b9..0015a29eae 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -250,7 +250,8 @@ def _evaluate_t(self, x): return xp.asnumpy(v) - def _filter_to_basis_mat(self, f, **kwargs): + # XXX for testing comparison + def _legacy_filter_to_basis_mat(self, f, **kwargs): """ See `SteerableBasis2D.filter_to_basis_mat`. """ @@ -391,6 +392,14 @@ def _filter_stack_to_basis_mats(self, f, **kwargs): return h_basis + def filter_to_basis_mat(self, f, **kwargs): + """ + See `SteerableBasis2D.filter_stack_to_basis_mats`. + """ + if len(f) != 1: + raise RuntimeError("Unexpected filter length.") + return self._filter_stack_to_basis_mats(f, **kwargs)[0] + def expand_radial_vec(self, radial_vec, force_diag=False): """ Expands radial vector or stack of vetors `radial_vec` to basis matrix. diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index f6e9b5d905..92e241b9c2 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -780,7 +780,16 @@ def _radial_convolve_weights(self, b): return a - def _filter_to_basis_mat(self, f, **kwargs): + # def filter_to_basis_mat(self, f, **kwargs): + # """ + # See `SteerableBasis2D.filter_stack_to_basis_mats`. + # """ + # if len(f) != 1: + # raise RuntimeError("Unexpected filter length.") + # return self._filter_stack_to_basis_mats(f, **kwargs)[0] + + # XXX TODO, convert to _filter_stack_to_basis_mats via broadcasting. + def filter_to_basis_mat(self, f, **kwargs): """ See `SteerableBasis2D.filter_to_basis_mat`. diff --git a/src/aspire/basis/fpswf_2d.py b/src/aspire/basis/fpswf_2d.py index 4c498757e7..e1215103d0 100644 --- a/src/aspire/basis/fpswf_2d.py +++ b/src/aspire/basis/fpswf_2d.py @@ -367,8 +367,8 @@ def _pswf_integration(self, images_nufft): return coef_vec_quad - def _filter_to_basis_mat(self, *args, **kwargs): + def filter_to_basis_mat(self, *args, **kwargs): """ - See `SteerableBasis2D.filter_to_basis_mat`. + See `SteerableBasis2D.filter_stack_to_basis_mat`. """ - return super()._filter_to_basis_mat(*args, **kwargs) + return super().filter_to_basis_mat(*args, **kwargs) diff --git a/src/aspire/basis/fspca.py b/src/aspire/basis/fspca.py index 053780b27a..d2f78b2ff7 100644 --- a/src/aspire/basis/fspca.py +++ b/src/aspire/basis/fspca.py @@ -617,7 +617,7 @@ def shift(self, coef, shifts): self.evaluate_to_image_basis(coef).shift(shifts) ) - def _filter_to_basis_mat(self, f, **kwargs): + def filter_to_basis_mat(self, f, **kwargs): """ Convert a filter into a basis representation. diff --git a/src/aspire/basis/pswf_2d.py b/src/aspire/basis/pswf_2d.py index 632aab14a3..d03c483a23 100644 --- a/src/aspire/basis/pswf_2d.py +++ b/src/aspire/basis/pswf_2d.py @@ -398,8 +398,8 @@ def _pswf_2d_minor_computations(self, big_n, n, bandlimit, phi_approximate_error range_array = np.arange(approx_length, dtype=self.dtype) return d_vec, approx_length, range_array - def _filter_to_basis_mat(self, *args, **kwargs): + def filter_to_basis_mat(self, *args, **kwargs): """ - See `SteerableBasis2D.filter_to_basis_mat`. + See `SteerableBasis2D.filter_stack_to_basis_mat`. """ - return super()._filter_to_basis_mat(*args, **kwargs) + return super().filter_to_basis_mat(*args, **kwargs) diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index ddb7424b77..03586a7a19 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -6,7 +6,7 @@ from aspire.basis import Basis, Coef, ComplexCoef from aspire.operators import BlkDiagMatrix -from aspire.utils import LogFilterByCount, complex_type, real_type, trange +from aspire.utils import LogFilterByCount, complex_type, real_type, tqdm, trange logger = logging.getLogger(__name__) @@ -476,17 +476,18 @@ def to_complex(self, coef): return ComplexCoef(self, complex_coef) - def filter_to_basis_mat(self, f, **kwargs): + def filter_stack_to_basis_mats(self, f, **kwargs): """ - Convert a filter into a basis operator representation. + Convert a filter stack into a list of basis operator representations. - See `_filter_to_basis_mat` here and in subclasses for available **kwargs. + See `_filter_stack_to_basis_mats` and `filter_to_basis_mat` + here and in subclasses for available **kwargs. - :param f: `Filter` object, usually a `CTFFilter`. - :param radial_optimization: Optionally attempt radial approximation if available. + :param f: `Filter` object, for example a `CTFFilter`. - :return: Representation of filter as `basis` operator. - Return type will be based on the class's `matrix_type`. + :return: List containing representations of Filter as `basis` operators. + Return type of list elements will be based on the class's `matrix_type`, + typically `BlkDiagMatrix` or `DiagMatrix`. """ # does the basis have optimized expand for radial vectors? @@ -496,6 +497,7 @@ def filter_to_basis_mat(self, f, **kwargs): # did user request the special radial expansion method? radial_method = kwargs.get("expand_method", None) == "radial" + # xxx, do we need this block anymore? (i dont think so, I think it was just bridge code?)... if optimized_expand and filter_is_radial and radial_method: # kwargs supports passing through pixel_size h_vals = self._radial_filter_to_vals( @@ -505,14 +507,24 @@ def filter_to_basis_mat(self, f, **kwargs): return res else: # use generic (legacy) filter path/code (may return DiagMatrix) - return self._filter_to_basis_mat(f, **kwargs) + return self._filter_stack_to_basis_mats(f, **kwargs) + + def _filter_stack_to_basis_mats(self, f, **kwargs): + """ + Helper function for sequentially evaluating filters in a basis that does not provide optimized filter_stack_to_basis_mats. + """ + basis_mats = [None] * len(f) + for i, _f in enumerate(tqdm(f, desc="Converting filters to basis mats")): + basis_mats[i] = self.filter_to_basis_mat(_f, **kwargs) + return basis_mats # `abstractmethod` enforces when a new subclass of # `SteerableBasis2D` is created that this method is explicitly # implemented. This is intended to encourage future basis authors # to consider this method for their application. + # When possible, they should prefer to create an optimized _filter_stack_to_basis_mats. @abc.abstractmethod - def _filter_to_basis_mat(self, f, expand_method=None, truncate=True, **kwargs): + def filter_to_basis_mat(self, f, expand_method=None, truncate=True, **kwargs): """ Convert a filter into a basis operator representation. diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 830b957823..8ca0729795 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -9,7 +9,7 @@ from aspire.numeric import xp from aspire.operators import BlkDiagMatrix, DiagMatrix from aspire.optimization import conj_grad, fill_struct -from aspire.utils import make_symmat, tqdm +from aspire.utils import make_symmat logger = logging.getLogger(__name__) @@ -559,15 +559,20 @@ def _build(self): logger.info("Representing filters in basis complete") def filters_to_basis_mats(self): + """ + Dispatch between various methods for converting filter stacks to basis matrices. + """ + # Does the basis provide radially optimized expansion? optimized_expand = callable( getattr(self.basis.__class__, "expand_radial_vec", None) ) + if optimized_expand and self.src.filter_stack.radial: logger.info( "Found radial filter stack and `basis.expand_radial_vec` available." " Using bulk basis mat eval." ) - return self._filter_stack_to_basis_mats() + return self._radial_filter_stack_to_basis_mats() else: # Note, can come back and optmize the filter eval to bulk, just not radial # For now use legacy path. @@ -578,22 +583,12 @@ def _filters_to_basis_mats(self): """ old code, should work with all basis and filters. slow. """ - basis_mats = self.basis._filter_stack_to_basis_mats( + basis_mats = self.basis.filter_stack_to_basis_mats( self.src.filter_stack, pixel_size=self.src.pixel_size, expand_method=self.expand_method, ) - # ## Legacy - # old_basis_mats = [ - # self.basis.filter_to_basis_mat( - # f, pixel_size=self.src.pixel_size, expand_method=self.expand_method - # ) - # for f in tqdm( - # self.src.filter_stack, desc="Converting filters to basis mats" - # ) - # ] - # from tqdm import trange # diff = 0 # for i in trange(len(basis_mats)): @@ -607,9 +602,7 @@ def _filters_to_basis_mats(self): return basis_mats - # todo, either rename _radial_filters_to_basis_mats or handle none radial - # same remark as `filters_to_basis_mats` - def _filter_stack_to_basis_mats(self): + def _radial_filter_stack_to_basis_mats(self): logger.info("Generating filter eval points") _filter_pts = self.basis._filter_pts # if we have many filters, might be worth trip to GPU From 6725c17e4d9d391c1d99a6f3659a1e0d13063b48 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 26 May 2026 14:52:40 -0400 Subject: [PATCH 37/50] continue filter_basis_mat cleanup --- src/aspire/basis/fle_2d.py | 39 ++++++++++++++++------------ src/aspire/covariance/covar2d.py | 44 ++++++++++---------------------- 2 files changed, 36 insertions(+), 47 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 92e241b9c2..ac78051314 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -780,16 +780,15 @@ def _radial_convolve_weights(self, b): return a - # def filter_to_basis_mat(self, f, **kwargs): - # """ - # See `SteerableBasis2D.filter_stack_to_basis_mats`. - # """ - # if len(f) != 1: - # raise RuntimeError("Unexpected filter length.") - # return self._filter_stack_to_basis_mats(f, **kwargs)[0] - - # XXX TODO, convert to _filter_stack_to_basis_mats via broadcasting. def filter_to_basis_mat(self, f, **kwargs): + """ + See `SteerableBasis2D.filter_stack_to_basis_mats`. + """ + if len(f) != 1: + raise RuntimeError("Unexpected filter length.") + return self._filter_stack_to_basis_mats(f, **kwargs) + + def _filter_stack_to_basis_mats(self, f, **kwargs): """ See `SteerableBasis2D.filter_to_basis_mat`. @@ -840,20 +839,28 @@ def filter_to_basis_mat(self, f, **kwargs): h_vals2d = ( xp.asarray(h_fun(omega, pixel_size=pixel_size)) - .reshape(n_k, n_theta) + .reshape(len(f), n_k, n_theta) .astype(self.dtype, copy=False) ) - h_vals = xp.sum(h_vals2d, axis=1) / n_theta + h_vals = xp.sum(h_vals2d, axis=-1) / n_theta - h_basis = xp.zeros(self.count, dtype=self.dtype) - # For now we just need to handle 1D (stack of one ctf) + h_basis = xp.zeros((len(f), self.count), dtype=self.dtype) + # shape gymnastics to get a broadcast with csr A3 + h_vals = h_vals.T for j in range(self.ell_p_max + 1): - h_basis[self.idx_list[j]] = self.A3[j] @ h_vals + h_basis[:, self.idx_list[j]] = (self.A3[j] @ h_vals).T # Convert from internal FLE ordering to FB convention - h_basis = h_basis[self._fle_to_fb_indices] + h_basis = h_basis[:, self._fle_to_fb_indices] + # who needs this as a list? + + coefs = xp.asnumpy(h_basis) + if len(coefs) > 1: + coefs = [DiagMatrix(c) for c in coefs] + else: + coefs = DiagMatrix(coefs.flatten()) - return DiagMatrix(xp.asnumpy(h_basis)) + return coefs def expand_radial_vec(self, radial_vec, **kwargs): """ diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 8ca0729795..74fc63e6e1 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -567,40 +567,22 @@ def filters_to_basis_mats(self): getattr(self.basis.__class__, "expand_radial_vec", None) ) + # Are the filters radial? + if self.src.filter_stack.radial: + logger.info("Found radial filter stack.") + else: + logger.info("Found non-radial filter stack.") + if optimized_expand and self.src.filter_stack.radial: - logger.info( - "Found radial filter stack and `basis.expand_radial_vec` available." - " Using bulk basis mat eval." - ) + logger.info("Using optimized `basis.expand_radial_vec`.") return self._radial_filter_stack_to_basis_mats() else: - # Note, can come back and optmize the filter eval to bulk, just not radial - # For now use legacy path. - logger.info("Using sequential basis mat eval") - return self._filters_to_basis_mats() - - def _filters_to_basis_mats(self): - """ - old code, should work with all basis and filters. slow. - """ - basis_mats = self.basis.filter_stack_to_basis_mats( - self.src.filter_stack, - pixel_size=self.src.pixel_size, - expand_method=self.expand_method, - ) - - # from tqdm import trange - # diff = 0 - # for i in trange(len(basis_mats)): - # a = basis_mats[i] - # ref= old_basis_mats[i] - # for j in range(len(ref)): - # diff += np.sum(a[j]-ref[j]) - - # print("sum of diff across all filters", diff) - # breakpoint() - - return basis_mats + logger.info("Using basis.filter_stack_to_basis_mats.") + return self.basis.filter_stack_to_basis_mats( + self.src.filter_stack, + pixel_size=self.src.pixel_size, + expand_method=self.expand_method, + ) def _radial_filter_stack_to_basis_mats(self): logger.info("Generating filter eval points") From fd4a92aee432e4a18219cfd3a696b850186d82c1 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 27 May 2026 13:07:33 -0400 Subject: [PATCH 38/50] initial documentation updates --- .../save_simulation_relion_reconstruct.py | 4 ++-- .../experiments/simulated_abinitio_pipeline.py | 12 +++++++----- gallery/tutorials/aspire_introduction.py | 7 ++----- gallery/tutorials/pipeline_demo.py | 8 +++----- gallery/tutorials/tutorials/cov2d_simulation.py | 17 +++++++++-------- gallery/tutorials/tutorials/cov3d_simulation.py | 2 +- gallery/tutorials/tutorials/ctf.py | 7 +++---- .../tutorials/tutorials/micrograph_source.py | 4 +--- .../tutorials/tutorials/orient3d_simulation.py | 13 ++++++++----- .../tutorials/tutorials/preprocess_imgs_sim.py | 13 ++++++++----- src/aspire/operators/filters.py | 5 +++++ src/aspire/source/image.py | 6 ++++-- src/aspire/utils/units.py | 5 ++--- 13 files changed, 55 insertions(+), 48 deletions(-) diff --git a/gallery/experiments/save_simulation_relion_reconstruct.py b/gallery/experiments/save_simulation_relion_reconstruct.py index e52f6bc2a8..a7ddfd3aae 100644 --- a/gallery/experiments/save_simulation_relion_reconstruct.py +++ b/gallery/experiments/save_simulation_relion_reconstruct.py @@ -51,7 +51,7 @@ # that RELION will recover as optics groups. vol = emdb_2660() -ctf_filters = [RadialCTFFilter(defocus=d) for d in defocus] +ctf_filters = RadialCTFFilter(defocus=defocus) # %% @@ -64,7 +64,7 @@ sim = Simulation( n=n_particles, vols=vol, - unique_filters=ctf_filters, + filter_stack=ctf_filters, noise_adder=WhiteNoiseAdder.from_snr(snr), ) sim.save(star_path, overwrite=True) diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py index adf228499f..527e69e417 100644 --- a/gallery/experiments/simulated_abinitio_pipeline.py +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -83,17 +83,19 @@ def noise_function(x, y): alpha = 0.1 # Amplitude contrast # Create filters -ctf_filters = [ - RadialCTFFilter(pixel_size, voltage, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) -] +ctf_filters = RadialCTFFilter( + voltage, + defocus=np.linspace(defocus_min, defocus_max, defocus_ct), + Cs=2.0, + alpha=0.1, +) # Finally create the Simulation src = Simulation( n=num_imgs, vols=og_v, noise_adder=custom_noise, - unique_filters=ctf_filters, + filter_stack=ctf_filters, ) # Downsample diff --git a/gallery/tutorials/aspire_introduction.py b/gallery/tutorials/aspire_introduction.py index b0013335f5..cf67b76d03 100644 --- a/gallery/tutorials/aspire_introduction.py +++ b/gallery/tutorials/aspire_introduction.py @@ -570,10 +570,7 @@ def noise_function(x, y): defocus_ct = 7 # Generate several CTFs. -ctf_filters = [ - RadialCTFFilter(defocus=d) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) -] +ctf_filters = RadialCTFFilter(defocus=np.linspace(defocus_min, defocus_max, defocus_ct)) # %% # Combining into a Simulation @@ -586,7 +583,7 @@ def noise_function(x, y): amplitudes=1, offsets=0, noise_adder=white_noise_adder, - unique_filters=ctf_filters, + filter_stack=ctf_filters, seed=42, ) diff --git a/gallery/tutorials/pipeline_demo.py b/gallery/tutorials/pipeline_demo.py index 0a192ecaf1..104f8f7c03 100644 --- a/gallery/tutorials/pipeline_demo.py +++ b/gallery/tutorials/pipeline_demo.py @@ -66,10 +66,8 @@ defocus_max = 25000 defocus_ct = 7 -ctf_filters = [ - RadialCTFFilter(defocus=d) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) -] +ctf_filters = RadialCTFFilter(defocus=np.linspace(defocus_min, defocus_max, defocus_ct)) + # %% # Initialize Simulation Object @@ -96,7 +94,7 @@ n=2500, # number of projections vols=original_vol, # volume source offsets=0, # Default: images are randomly shifted - unique_filters=ctf_filters, + filter_stack=ctf_filters, noise_adder=WhiteNoiseAdder(var=0.0002), # desired noise variance ).cache() diff --git a/gallery/tutorials/tutorials/cov2d_simulation.py b/gallery/tutorials/tutorials/cov2d_simulation.py index 5d9b469b04..f7368e19cb 100644 --- a/gallery/tutorials/tutorials/cov2d_simulation.py +++ b/gallery/tutorials/tutorials/cov2d_simulation.py @@ -67,10 +67,13 @@ print("Initialize simulation object and CTF filters.") # Create filters -ctf_filters = [ - RadialCTFFilter(voltage, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) -] +ctf_filters = RadialCTFFilter( + voltage, + defocus=np.linspace(defocus_min, defocus_max, defocus_ct), + Cs=2.0, + alpha=0.1, +) + # Load the map file of a 70S Ribosome print( @@ -89,7 +92,7 @@ L=img_size, n=num_imgs, vols=vols, - unique_filters=ctf_filters, + filter_stack=ctf_filters, offsets=0.0, amplitudes=1.0, dtype=dtype, @@ -109,9 +112,7 @@ h_idx = sim.filter_indices # Evaluate CTF in the 8X8 FB basis -h_ctf_fb = [ - ffbbasis.filter_to_basis_mat(filt, pixel_size=pixel_size) for filt in ctf_filters -] +h_ctf_fb = ffbbasis.filter_stack_to_basis_mats(ctf_filters, pixel_size=pixel_size) # Get clean images from projections of 3D map. print("Apply CTF filters to clean images.") diff --git a/gallery/tutorials/tutorials/cov3d_simulation.py b/gallery/tutorials/tutorials/cov3d_simulation.py index 4e46596a36..58d8e9d5cf 100644 --- a/gallery/tutorials/tutorials/cov3d_simulation.py +++ b/gallery/tutorials/tutorials/cov3d_simulation.py @@ -44,7 +44,7 @@ L=img_size, n=num_imgs, vols=vols, - unique_filters=[RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7)], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), dtype=dtype, ) diff --git a/gallery/tutorials/tutorials/ctf.py b/gallery/tutorials/tutorials/ctf.py index 218612dc8c..a269e3c35b 100644 --- a/gallery/tutorials/tutorials/ctf.py +++ b/gallery/tutorials/tutorials/ctf.py @@ -154,9 +154,8 @@ def generate_example_image(L, noise_variance=0.1): # Construct a range of CTF filters. defoci = [2500, 5000, 10000, 20000] -ctf_filters = [ - RadialCTFFilter(voltage=200, defocus=d, Cs=2.26, alpha=0.07, B=0) for d in defoci -] +ctf_filters = RadialCTFFilter(voltage=200, defocus=defoci, Cs=2.26, alpha=0.07, B=0) + # %% # Generate CTF corrupted Images @@ -334,7 +333,7 @@ def generate_example_image(L, noise_variance=0.1): from aspire.source import Simulation # Create the Source. ``ctf_filters`` are re-used from earlier section. -src = Simulation(L=64, n=4, unique_filters=ctf_filters, pixel_size=1) +src = Simulation(L=64, n=4, filter_stack=ctf_filters, pixel_size=1) src.images[:4].show() # %% diff --git a/gallery/tutorials/tutorials/micrograph_source.py b/gallery/tutorials/tutorials/micrograph_source.py index 26f2429925..4c6cae10dd 100644 --- a/gallery/tutorials/tutorials/micrograph_source.py +++ b/gallery/tutorials/tutorials/micrograph_source.py @@ -181,9 +181,7 @@ # Create our CTF Filter and add it to a list. # This configuration will apply the same CTF to all particles. -ctfs = [ - RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0), -] +ctfs = RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0) src = MicrographSimulation( vol, diff --git a/gallery/tutorials/tutorials/orient3d_simulation.py b/gallery/tutorials/tutorials/orient3d_simulation.py index 142fe177f4..a41f29d70b 100644 --- a/gallery/tutorials/tutorials/orient3d_simulation.py +++ b/gallery/tutorials/tutorials/orient3d_simulation.py @@ -49,10 +49,13 @@ print("Initialize simulation object and CTF filters.") # Create CTF filters -filters = [ - RadialCTFFilter(voltage, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) -] +filters = RadialCTFFilter( + voltage, + defocus=np.linspace(defocus_min, defocus_max, defocus_ct), + Cs=2.0, + alpha=0.1, +) + # %% # Downsampling @@ -74,7 +77,7 @@ # Create a simulation object with specified filters and the downsampled 3D map print("Use downsampled map to creat simulation object.") sim = Simulation( - L=img_size, n=num_imgs, vols=vols, unique_filters=filters, pixel_size=5, dtype=dtype + L=img_size, n=num_imgs, vols=vols, filter_stack=filters, pixel_size=5, dtype=dtype ) print("Get true rotation angles generated randomly by the simulation object.") diff --git a/gallery/tutorials/tutorials/preprocess_imgs_sim.py b/gallery/tutorials/tutorials/preprocess_imgs_sim.py index 4e70c5f4cc..7c0023b5db 100644 --- a/gallery/tutorials/tutorials/preprocess_imgs_sim.py +++ b/gallery/tutorials/tutorials/preprocess_imgs_sim.py @@ -53,10 +53,13 @@ print("Initialize simulation object and CTF filters.") # Create CTF filters -ctf_filters = [ - RadialCTFFilter(voltage, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) -] +ctf_filters = RadialCTFFilter( + voltage, + defocus=np.linspace(defocus_min, defocus_max, defocus_ct), + Cs=2.0, + alpha=0.1, +) + # Load the map file of a 70S ribosome and downsample the 3D map to desired image size. print("Load 3D map from mrc file") @@ -73,7 +76,7 @@ L=img_size, n=num_imgs, vols=vols, - unique_filters=ctf_filters, + filter_stack=ctf_filters, noise_adder=noise_adder, pixel_size=pixel_size, ) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index c732c3951f..f0f8dedfd3 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -505,6 +505,11 @@ def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs): res = super().evaluate_grid(L, *args, dtype=dtype, **kwargs) return res + def __getitem__(self, item): + # Note, could extend to a stack dimension and lookup. + # For now, we have no use case for that. + return self + class ScalarFilter(Filter): def __init__(self, dim=None, value=1): diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 2483358578..22fd08b473 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -960,8 +960,10 @@ def whiten(self, noise_estimate=None, epsilon=None): whiten_filter = PowerFilter(noise_filter, power=-0.5, epsilon=epsilon) logger.info("Transforming all CTF Filters into Multiplicative Filters") - # XXXX - self.filter_stack = MultiplicativeFilter(self.filter_stack, whiten_filter) + if self.filter_stack is not None: + self.filter_stack = MultiplicativeFilter(self.filter_stack, whiten_filter) + else: + self.filter_stack = whiten_filter logger.info("Adding Whitening Filter Xform to end of generation pipeline") self.generation_pipeline.add_xform(FilterXform(whiten_filter)) diff --git a/src/aspire/utils/units.py b/src/aspire/utils/units.py index fd40864795..97d668f579 100644 --- a/src/aspire/utils/units.py +++ b/src/aspire/utils/units.py @@ -3,7 +3,6 @@ """ import logging -import math import numpy as np @@ -43,7 +42,7 @@ def voltage_to_wavelength(voltage): a = float(12.264259661581491) b = float(0.9784755917869367) - return a / math.sqrt(voltage * 1e3 + b * voltage**2) + return a / np.sqrt(voltage * 1e3 + b * voltage**2) def wavelength_to_voltage(wavelength): @@ -56,4 +55,4 @@ def wavelength_to_voltage(wavelength): a = float(12.264259661581491) b = float(0.9784755917869367) - return (-1e3 + math.sqrt(1e6 + 4 * a**2 * b / wavelength**2)) / (2 * b) + return (-1e3 + np.sqrt(1e6 + 4 * a**2 * b / wavelength**2)) / (2 * b) From 7bafc0e4959591a43025bcf517e5466799094704 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 27 May 2026 13:43:14 -0400 Subject: [PATCH 39/50] make np.unique call compat with older numpy --- src/aspire/operators/filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index f0f8dedfd3..8aebd4eb50 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -331,7 +331,7 @@ def _init_size(self): Check sizes of _components are coherent and initialize resulting length. """ filter_lengths = [len(f) for f in self._components] - filter_lengths = np.unique(filter_lengths, sorted=True) + filter_lengths = np.unique(filter_lengths) # defaults to sorted=True # Code should be able to broadcast n_filters with n_filters, or n_filters with 1_filters. # Any other combination is considered an error. From 8905ddfd5d86faf7fe2d29e1b664ea6ed7ef830e Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 2 Jun 2026 08:10:36 -0400 Subject: [PATCH 40/50] should been better, was not --- src/aspire/source/image.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 22fd08b473..21c2848b4a 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -22,7 +22,13 @@ Pipeline, ) from aspire.noise import LegacyNoiseEstimator, NoiseEstimator, WhiteNoiseEstimator -from aspire.operators import CTFFilter, Filter, MultiplicativeFilter, PowerFilter +from aspire.operators import ( + ArrayFilter, + CTFFilter, + Filter, + MultiplicativeFilter, + PowerFilter, +) from aspire.storage import MrcStats, StarFile from aspire.utils import ( Rotation, @@ -959,12 +965,12 @@ def whiten(self, noise_estimate=None, epsilon=None): logger.info("Whitening source object") whiten_filter = PowerFilter(noise_filter, power=-0.5, epsilon=epsilon) - logger.info("Transforming all CTF Filters into Multiplicative Filters") + logger.info(f"Extending filter stack by whitening filter") if self.filter_stack is not None: self.filter_stack = MultiplicativeFilter(self.filter_stack, whiten_filter) else: self.filter_stack = whiten_filter - logger.info("Adding Whitening Filter Xform to end of generation pipeline") + logger.info("Adding whitening FilterXform to end of generation pipeline") self.generation_pipeline.add_xform(FilterXform(whiten_filter)) @_as_copy @@ -998,6 +1004,16 @@ def legacy_whiten(self, noise_response=None, delta=None, batch_size=512): if delta is None: delta = np.finfo(np.float32).eps + # # XXX This "should be better" but totally breaks things. + # # First guess would be to check the strange normalization. + # # Can't fix everything at once. + # logger.info(f"Extending filter stack by legacy whitening Filter") + # whiten_filter = ArrayFilter(psd) + # if self.filter_stack is not None: + # self.filter_stack = MultiplicativeFilter(self.filter_stack, whiten_filter) + # else: + # self.filter_stack = whiten_filter + logger.info("Adding LegacyWhiten Filter Xform to end of generation pipeline.") self.generation_pipeline.add_xform(LegacyWhiten(psd, delta)) From 841cc8defcc03cd51052e34a169a064322c2e145 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 2 Jun 2026 08:22:03 -0400 Subject: [PATCH 41/50] first round cleanup --- src/aspire/basis/ffb_2d.py | 66 ----------------------------------- src/aspire/basis/steerable.py | 16 ++++----- 2 files changed, 8 insertions(+), 74 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 0015a29eae..d643df57a5 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -250,72 +250,6 @@ def _evaluate_t(self, x): return xp.asnumpy(v) - # XXX for testing comparison - def _legacy_filter_to_basis_mat(self, f, **kwargs): - """ - See `SteerableBasis2D.filter_to_basis_mat`. - """ - # Note 'method' and 'truncate' not relevant for this specific FFB code. - # Method `radial` should have already been diverted. - expand_method = kwargs.get("expand_method", None) - if expand_method is not None: - raise NotImplementedError( - f"`FFBBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported." - " Use `expand_method=None`." - ) - - pixel_size = kwargs.get("pixel_size", None) - - # These form a circular dependence, import locally until time to clean up. - from aspire.basis.basis_utils import lgwt - - # Get the filter's evaluate function. - h_fun = f.evaluate - - # Set same dimensions as basis object - n_k = self.n_r - n_theta = self.n_theta - radial = self._precomp["radial"] - - # get 2D grid in polar coordinate - # Confirm this lgwt call with Joakim (should it follow basis config self.kcut? same by default) - k_vals, wts = lgwt(n_k, 0, 0.5, dtype=self.dtype) - k, theta = np.meshgrid( - k_vals, np.arange(n_theta) * 2 * np.pi / (2 * n_theta), indexing="ij" - ) - - # Get function values in polar 2D grid and average out angle contribution - omegax = k * np.cos(theta) - omegay = k * np.sin(theta) - omega = 2 * np.pi * np.vstack((omegax.flatten("C"), omegay.flatten("C"))) - - # this should be either a stack - h_vals2d = ( - h_fun(omega, pixel_size=pixel_size).reshape(n_k, n_theta).astype(self.dtype) - ) - h_vals = np.sum(h_vals2d, axis=1) / n_theta - - # Represent 1D function values in basis - h_basis = BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) - ind_ell = 0 - for ell in range(0, self.ell_max + 1): - k_max = self.k_max[ell] - rmat = 2 * k_vals.reshape(n_k, 1) * self.r0[ell][0:k_max].T - basis_vals = np.zeros_like(rmat) - ind_radial = np.sum(self.k_max[0:ell]) - basis_vals[:, 0:k_max] = radial[ind_radial : ind_radial + k_max].T - h_basis_vals = basis_vals * h_vals.reshape(n_k, 1) - h_basis_ell = basis_vals.T @ ( - h_basis_vals * k_vals.reshape(n_k, 1) * wts.reshape(n_k, 1) - ) - h_basis[ind_ell] = h_basis_ell - ind_ell += 1 - if ell > 0: - h_basis[ind_ell] = h_basis[ind_ell - 1] - ind_ell += 1 - - return h_basis - def _filter_stack_to_basis_mats(self, f, **kwargs): """ See `SteerableBasis2D.filter_to_basis_mat`. diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index 03586a7a19..f9474e81d4 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -490,28 +490,28 @@ def filter_stack_to_basis_mats(self, f, **kwargs): typically `BlkDiagMatrix` or `DiagMatrix`. """ - # does the basis have optimized expand for radial vectors? + # does the basis have optimized expansion for radial vectors? optimized_expand = callable(getattr(self.__class__, "expand_radial_vec", None)) # is the filter radial? filter_is_radial = f.radial # did user request the special radial expansion method? radial_method = kwargs.get("expand_method", None) == "radial" - # xxx, do we need this block anymore? (i dont think so, I think it was just bridge code?)... if optimized_expand and filter_is_radial and radial_method: # kwargs supports passing through pixel_size - h_vals = self._radial_filter_to_vals( - f, **kwargs - ) # check dont need #.reshape(-1, 1) + h_vals = self._radial_filter_to_vals(f, **kwargs) res = self.expand_radial_vec(h_vals) - return res else: # use generic (legacy) filter path/code (may return DiagMatrix) - return self._filter_stack_to_basis_mats(f, **kwargs) + res = self._filter_stack_to_basis_mats(f, **kwargs) + + return res def _filter_stack_to_basis_mats(self, f, **kwargs): """ - Helper function for sequentially evaluating filters in a basis that does not provide optimized filter_stack_to_basis_mats. + Helper function for sequentially evaluating filters in a basis. + + This is a crude fall back for basis that do not provide optimized `filter_stack_to_basis_mats`. """ basis_mats = [None] * len(f) for i, _f in enumerate(tqdm(f, desc="Converting filters to basis mats")): From 50695bcf837569044b463a22fb69ccb0c64beacd Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 2 Jun 2026 09:10:22 -0400 Subject: [PATCH 42/50] cleanup unused rmat and move reshapes out of loop --- src/aspire/basis/ffb_2d.py | 34 +++++++++++++++------------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index d643df57a5..e4db2bf4a3 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -302,19 +302,18 @@ def _filter_stack_to_basis_mats(self, f, **kwargs): BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) for _ in h_vals ] ind_ell = 0 + # Reshapes for broadcasting + k_vals = k_vals.reshape(n_k, 1) + wts = wts.reshape(n_k, 1) + h_vals = h_vals.reshape(len(f), n_k, 1) for ell in range(0, self.ell_max + 1): k_max = self.k_max[ell] - # xxx todo, we can skip computing rmat, just need the shape - rmat = 2 * k_vals.reshape(n_k, 1) * self.r0[ell][0:k_max].T - basis_vals = np.zeros_like(rmat) + basis_vals = np.zeros((n_k, k_max), dtype=self.dtype) ind_radial = np.sum(self.k_max[0:ell]) basis_vals[:, 0:k_max] = radial[ind_radial : ind_radial + k_max].T - h_basis_vals = basis_vals * h_vals.reshape( - len(f), n_k, 1 - ) # check bcast here - h_basis_ell = basis_vals.T @ ( - h_basis_vals * k_vals.reshape(n_k, 1) * wts.reshape(n_k, 1) - ) + h_basis_vals = basis_vals * h_vals + h_basis_ell = basis_vals.T @ (h_basis_vals * k_vals * wts) + # loop over assignment blocks. for i in range(len(f)): h_basis[i][ind_ell] = h_basis_ell[i] @@ -353,8 +352,6 @@ def expand_radial_vec(self, radial_vec, force_diag=False): n_k = self.n_r radial = self._precomp["radial"] - # hrrmm, ask Joakim can we always use the basis precomp, or do we need to use lgwt as in the old filter_to_basis_mat? - # This is doing opposite logic (same result) by default. Joy. k_vals = xp.asarray(self._precomp["gl_nodes"]) wts = xp.asarray(self._precomp["gl_weights"]) @@ -365,20 +362,19 @@ def expand_radial_vec(self, radial_vec, force_diag=False): ] ind_ell = 0 + # Reshapes for broadcasting + radial_vec = radial_vec.reshape(len(h_basis), n_k, 1) + k_vals = k_vals.reshape(1, n_k, 1) + wts = wts.reshape(1, n_k, 1) for ell in range(0, self.ell_max + 1): k_max = self.k_max[ell] - rmat = ( - 2 * xp.asnumpy(k_vals.reshape(n_k, 1)) * self.r0[ell][0:k_max].T - ) # WHAT IN THE WORLD IS GOING ON HERE - basis_vals = xp.zeros_like(rmat) + basis_vals = np.zeros((n_k, k_max), dtype=self.dtype) ind_radial = np.sum(self.k_max[0:ell]) basis_vals[:, 0:k_max] = xp.asarray( radial[ind_radial : ind_radial + k_max] ).T - h_basis_vals = basis_vals * radial_vec.reshape(len(h_basis), n_k, 1) - h_basis_ell = basis_vals.T @ ( - h_basis_vals * k_vals.reshape(1, n_k, 1) * wts.reshape(1, n_k, 1) - ) + h_basis_vals = basis_vals * radial_vec + h_basis_ell = basis_vals.T @ (h_basis_vals * k_vals * wts) h_basis_ell = xp.asnumpy(h_basis_ell) for _filter in range(len(radial_vec)): _tmp = h_basis[_filter][ind_ell] = h_basis_ell[_filter] From 966d1bf4e19f343e5bd2fd88e787bbc4d70a4d62 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 2 Jun 2026 09:29:23 -0400 Subject: [PATCH 43/50] cleanup additional ffb2d test --- tests/test_FFBbasis2D.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/test_FFBbasis2D.py b/tests/test_FFBbasis2D.py index cd07b2ff84..b21ecb9c11 100644 --- a/tests/test_FFBbasis2D.py +++ b/tests/test_FFBbasis2D.py @@ -168,7 +168,8 @@ def test_bulk_expand_radial_vec(): """ For a given stack of radial vectors (such as from RadialCTFFilters) `expand_radial_vec` should return equivalent - result as calling filter_to_basis_mat on each filter. + result as calling filter_stack_to_basis_mats on filter_stack, + and calling filter_to_basis_mat on each filter. """ L = 32 @@ -176,21 +177,20 @@ def test_bulk_expand_radial_vec(): basis = FFBBasis2D(L, dtype=dtype) pixel_size = 1.23 - filters = [RadialCTFFilter(defocus=d) for d in np.linspace(10000, 15000, 3)] - + filters = RadialCTFFilter(defocus=np.linspace(10000, 15000, 3)) + stack_references = basis.filter_stack_to_basis_mats(filters, pixel_size=pixel_size) references = [basis.filter_to_basis_mat(f, pixel_size=pixel_size) for f in filters] - # from cov code - params = np.empty((len(filters), 7), dtype=dtype) - for i, f in enumerate(filters): - # TODO xxx fix param dump, same as sim/source - params[i] = np.array(f._ctf_params()).flatten() - + # Stack of all filter params + params = filters._ctf_params() + # Stack of filter values _filter_vals = RadialCTFFilter.ctf_formula( basis._filter_pts, pixel_size, *(params.T) ) + # Stack result results = basis.expand_radial_vec(_filter_vals) + # Sequential result results2 = [basis.expand_radial_vec(f)[0] for f in _filter_vals] # expand_radial_vec should be same as itself called sequentially @@ -198,9 +198,16 @@ def test_bulk_expand_radial_vec(): for res, ref in zip(results2, results): np.testing.assert_allclose(res.dense(), ref.dense()) - # and should be equivalent to calling filter_to_basis_mat + # and should be equivalent to calling filter_to_basis_mat (for this radial case) assert len(results) == len(references) for res, ref in zip(results, references): np.testing.assert_allclose( res.dense(), ref.dense(), atol=utest_tolerance(dtype) ) + + # The list from filter_to_basis_mat should be equivalent to list filter_stack_to_basis_mats + assert len(references) == len(stack_references) + for res, ref in zip(stack_references, references): + np.testing.assert_allclose( + res.dense(), ref.dense(), atol=utest_tolerance(dtype) + ) From 61a48394a08d45c13d508e2f82462ff8b3d474ac Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 2 Jun 2026 10:11:56 -0400 Subject: [PATCH 44/50] cleanup some filter eval concerns --- src/aspire/operators/filters.py | 10 +++++++--- src/aspire/source/image.py | 14 +++++++------- src/aspire/source/simulation.py | 3 ++- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 8aebd4eb50..90039469ee 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -31,14 +31,18 @@ def evaluate_src_filters_on_grid(src, indices=None): grid2d = grid_2d(src.L, indexing="yx", dtype=src.dtype) omega = np.pi * np.vstack((grid2d["x"].flatten(), grid2d["y"].flatten())) - # xxx filter opt (eval in bulk instead of loop here), remove branch # Initialize h as ones to mimic an IdentityFilter when src.filter_stack is None. h = np.ones((omega.shape[-1], len(indices)), dtype=src.dtype) + #### XXX I believe this might be what Tony reported ^ + if src.filter_stack is not None: - for i, filt in enumerate(src.filter_stack): + # Evaluate all filters in bulk + filter_stack_values = src.filter_stack.evaluate( + omega, pixel_size=src.pixel_size + ) + for i, filter_values in enumerate(filter_stack_values): idx_k = np.where(src.filter_indices[indices] == i)[0] if len(idx_k) > 0: - filter_values = filt.evaluate(omega, pixel_size=src.pixel_size) # convert filter_values row vector to column vector and tile broadcast filter_values = filter_values.reshape(-1, 1) h[:, idx_k] = np.tile(filter_values, len(idx_k)) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 21c2848b4a..bf370279c0 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -792,7 +792,7 @@ def _apply_filters( return im # else evaluate filters - # XXXX broadcast filter eval + # TODO broadcast filter eval for i, filt in enumerate(filters): idx_k = np.where(indices == i)[0] if len(idx_k) > 0: @@ -876,7 +876,6 @@ def downsample(self, L, zero_nyquist=True, centered_fft=True): ) ) - # XXXX sigh ds_factor = self.L / L if self.filter_stack is not None: self.filter_stack = self.filter_stack.scale(ds_factor) @@ -1004,7 +1003,7 @@ def legacy_whiten(self, noise_response=None, delta=None, batch_size=512): if delta is None: delta = np.finfo(np.float32).eps - # # XXX This "should be better" but totally breaks things. + # # TODO This "should be better" but totally breaks things. # # First guess would be to check the strange normalization. # # Can't fix everything at once. # logger.info(f"Extending filter stack by legacy whitening Filter") @@ -1027,7 +1026,6 @@ def phase_flip(self): logger.info("Perform phase flip on source object") if self.filter_stack is not None: - # XXXX unique_xforms = FilterXform(self.filter_stack.sign) logger.info("Adding Phase Flip Xform to end of generation pipeline") @@ -1811,9 +1809,11 @@ def __init__(self, src, indices, memory=None): _unq, _inv = np.unique(_filter_indices, return_inverse=True) # Repack filter_stack self.filter_indices = _inv - self.filter_stack = copy.copy( - src.filter_stack[_unq] - ) # xxx, this might just work by slicing... + # This would work by slicing with current code, + # but if future code mutated the filter objects, that would be a problem. + # Copy for safety/intent. + # Deep copy may be required if future code mutates underlying objects. + self.filter_stack = copy.copy(src.filter_stack[_unq]) else: # Pass through the None case self.filter_stack = src.filter_stack diff --git a/src/aspire/source/simulation.py b/src/aspire/source/simulation.py index 997f7f8890..7c13366865 100644 --- a/src/aspire/source/simulation.py +++ b/src/aspire/source/simulation.py @@ -228,7 +228,8 @@ def _populate_ctf_metadata(self, filter_indices): ] # Unpack the `filter_stack` params across images using `filter_indices` mapping - # Note this does not include the B factor term (unique to ASPIRE?,xxx should we add to star if used?) + # Note this does not include the B factor term (hence the truncation) + # B factor term looks unique to ASPIRE, should we add to star if used? filter_stack_params = self.filter_stack._ctf_params()[ :, :6 ] # params per filter From 2e6dc56ad44b21f8d56fd68d7b6725df302f720b Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 2 Jun 2026 10:13:25 -0400 Subject: [PATCH 45/50] tox cleanup --- src/aspire/operators/filters.py | 2 +- src/aspire/source/image.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 90039469ee..d333148be5 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -33,7 +33,7 @@ def evaluate_src_filters_on_grid(src, indices=None): # Initialize h as ones to mimic an IdentityFilter when src.filter_stack is None. h = np.ones((omega.shape[-1], len(indices)), dtype=src.dtype) - #### XXX I believe this might be what Tony reported ^ + # ### XXX I believe this might be what Tony reported ^ if src.filter_stack is not None: # Evaluate all filters in bulk diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index bf370279c0..493c12d7c8 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -23,7 +23,6 @@ ) from aspire.noise import LegacyNoiseEstimator, NoiseEstimator, WhiteNoiseEstimator from aspire.operators import ( - ArrayFilter, CTFFilter, Filter, MultiplicativeFilter, @@ -964,7 +963,7 @@ def whiten(self, noise_estimate=None, epsilon=None): logger.info("Whitening source object") whiten_filter = PowerFilter(noise_filter, power=-0.5, epsilon=epsilon) - logger.info(f"Extending filter stack by whitening filter") + logger.info("Extending filter stack by whitening filter") if self.filter_stack is not None: self.filter_stack = MultiplicativeFilter(self.filter_stack, whiten_filter) else: From a0cf7ba0f3f30a690b788678623d23372dafcb0c Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 Jun 2026 08:42:25 -0400 Subject: [PATCH 46/50] tox cleanup --- src/aspire/source/image.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 493c12d7c8..8113ae7073 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -22,12 +22,7 @@ Pipeline, ) from aspire.noise import LegacyNoiseEstimator, NoiseEstimator, WhiteNoiseEstimator -from aspire.operators import ( - CTFFilter, - Filter, - MultiplicativeFilter, - PowerFilter, -) +from aspire.operators import CTFFilter, Filter, MultiplicativeFilter, PowerFilter from aspire.storage import MrcStats, StarFile from aspire.utils import ( Rotation, From 37ad83ed2fc5613e73d71886301b27f1895c28de Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 Jun 2026 10:39:27 -0400 Subject: [PATCH 47/50] missing xp --- src/aspire/basis/ffb_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index e4db2bf4a3..976031f7d7 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -368,7 +368,7 @@ def expand_radial_vec(self, radial_vec, force_diag=False): wts = wts.reshape(1, n_k, 1) for ell in range(0, self.ell_max + 1): k_max = self.k_max[ell] - basis_vals = np.zeros((n_k, k_max), dtype=self.dtype) + basis_vals = xp.zeros((n_k, k_max), dtype=self.dtype) ind_radial = np.sum(self.k_max[0:ell]) basis_vals[:, 0:k_max] = xp.asarray( radial[ind_radial : ind_radial + k_max] From 07dbf201516e3ac4357824a981786e393edaf381 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 10 Jun 2026 14:08:04 -0400 Subject: [PATCH 48/50] add large covar pytest file and some minor changes for one of the cases --- src/aspire/basis/ffb_2d.py | 2 +- src/aspire/basis/fle_2d.py | 20 +++-- src/aspire/operators/filters.py | 29 ++++++- tests/test_large_covar.py | 132 ++++++++++++++++++++++++++++++++ tox.ini | 3 +- 5 files changed, 175 insertions(+), 11 deletions(-) create mode 100644 tests/test_large_covar.py diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 976031f7d7..2ec2b14017 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -295,7 +295,7 @@ def _filter_stack_to_basis_mats(self, f, **kwargs): .reshape(len(f), n_k, n_theta) .astype(self.dtype) ) - h_vals = np.sum(h_vals2d, axis=-1) / n_theta + h_vals = h_vals2d.sum(axis=-1) / n_theta # Represent each 1D functions values in basis h_basis = [ diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index ac78051314..12fe173be8 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -17,6 +17,9 @@ logger = logging.getLogger(__name__) +# Number of elements in filter_to_basis_mat before breaking into batches +MAX_ELEM_COUNT = 2e9 + def _cleanup(): """ @@ -837,12 +840,17 @@ def _filter_stack_to_basis_mats(self, f, **kwargs): omegay = k * xp.sin(theta) omega = 2 * xp.pi * xp.vstack((omegax.flatten("C"), omegay.flatten("C"))) - h_vals2d = ( - xp.asarray(h_fun(omega, pixel_size=pixel_size)) - .reshape(len(f), n_k, n_theta) - .astype(self.dtype, copy=False) - ) - h_vals = xp.sum(h_vals2d, axis=-1) / n_theta + # For high non-radial filter counts at higher pixel counts + # h_vals2d requires a large amount of memory and is too large + # to fit on a GPU + # In the smaller cases, the code attepts using GPU. + h_vals2d = h_fun(omega, pixel_size=pixel_size) + if len(f) * xp.size(omega) < MAX_ELEM_COUNT: + h_vals2d = xp.asarray(h_vals2d) + + h_vals2d = h_vals2d.reshape(len(f), n_k, n_theta).astype(self.dtype, copy=False) + h_vals = h_vals2d.sum(axis=-1) / n_theta + h_vals = xp.asarray(h_vals) # no-op if already fit on GPU h_basis = xp.zeros((len(f), self.count), dtype=self.dtype) # shape gymnastics to get a broadcast with csr A3 diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index d333148be5..b71cef0ce1 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -7,7 +7,7 @@ from aspire import config from aspire.numeric import xp -from aspire.utils import cart2pol, grid_2d, voltage_to_wavelength +from aspire.utils import cart2pol, grid_2d, trange, voltage_to_wavelength logger = logging.getLogger(__name__) @@ -33,7 +33,6 @@ def evaluate_src_filters_on_grid(src, indices=None): # Initialize h as ones to mimic an IdentityFilter when src.filter_stack is None. h = np.ones((omega.shape[-1], len(indices)), dtype=src.dtype) - # ### XXX I believe this might be what Tony reported ^ if src.filter_stack is not None: # Evaluate all filters in bulk @@ -53,6 +52,9 @@ def evaluate_src_filters_on_grid(src, indices=None): # TODO: filters should probably be dtyped... class Filter: + max_size = 4e9 # Max element count for a single evaluate batch + batch_size = 512 # Batch size in elements + def __init__(self, dim=None, radial=False): self.dim = dim self.radial = radial @@ -90,7 +92,28 @@ def evaluate(self, omega, **kwargs): omega, idx = np.unique(omega, return_inverse=True) omega = np.vstack((omega, np.zeros_like(omega))) - h = self._evaluate(omega, **kwargs) + # Batch over large problems that may not fit in memory/GPU + filter_ind_count = len(self) + indices = np.arange(filter_ind_count) + if kwargs.get("indices", None) is not None: + indices = kwargs["indices"] + filter_ind_count = len(indices) + if (omega.shape[-1] * filter_ind_count) > self.max_size: + # Create an empty result array + # For large (2D) problems this will not fit on a GPU + h = np.empty((filter_ind_count, omega.shape[-1]), dtype=np.float64) + # Batch over filter indices + for i in trange( + 0, filter_ind_count, self.batch_size, desc="Filter evaluation" + ): + # Form the subset of filter indices + s = slice(i, min(filter_ind_count, i + self.batch_size)) + kwargs["indices"] = indices[s] + # Evaluate filter for the subset of indices and assign + h[s] = xp.asnumpy(self._evaluate(omega, **kwargs)) + else: + # Compute as one problem + h = self._evaluate(omega, **kwargs) if self.radial: # The reshape and take axis gynmastics work to provide the diff --git a/tests/test_large_covar.py b/tests/test_large_covar.py new file mode 100644 index 0000000000..7ac2196b11 --- /dev/null +++ b/tests/test_large_covar.py @@ -0,0 +1,132 @@ +import os +import socket + +import numpy as np +import pytest + +from aspire.basis import FFBBasis2D, FLEBasis2D +from aspire.covariance import BatchedRotCov2D +from aspire.source import RelionSource + +DTYPES = [ + np.float32, + np.float64, +] + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +IMG_SIZES = [ + 128, + 179, +] + + +@pytest.fixture(params=IMG_SIZES, ids=lambda x: f"img_size={x}", scope="module") +def img_size(request): + return request.param + + +BASI = [ + FFBBasis2D, + FLEBasis2D, +] + + +@pytest.fixture(params=BASI, ids=lambda x: f"basis={x}", scope="module") +def basis(request, img_size, dtype): + return request.param(img_size, dtype=dtype) + + +RADIAL = [ + False, + True, +] + + +@pytest.fixture(params=RADIAL, ids=lambda x: f"force_radial={x}", scope="module") +def force_radial(request): + return request.param + + +MOLECULES = { + 10028: "10028/data/shiny_2sets_fixed9.star", + 11618: "11618/data/particles/J43_particles.star", +} + + +@pytest.fixture(params=MOLECULES.keys(), ids=lambda x: f"molecule={x}", scope="module") +def molecule(request): + return request.param + + +def _raw_data_path(): + """ + Attempt getting a working path to raw EMPIAR data location + """ + # Try to populate a path to raw data + raw_data_path = None + + # Check if we're on a known testing platform. + known_hosts = [ + "caf.math.princeton.edu.private", + "decaf.math.princeton.edu.private", + ] + # Default to their expected location + if socket.gethostname() in known_hosts: + raw_data_path = "/scratch/ExperimentalData/raw" + + # Check if a user has provided or overides the location + raw_data_path = os.environ.get("ASPIRE_RAW_DATA_PATH", raw_data_path) + + if raw_data_path is None: + raise RuntimeError("Must provide path to raw data") + if not os.path.exists(raw_data_path): + raise RuntimeError(f"Provided path {raw_data_path} does not exist.") + + return raw_data_path + + +@pytest.fixture(scope="module") +def preprocessed_src(img_size, molecule, force_radial, dtype): + starfile_path = os.path.join(_raw_data_path(), MOLECULES[molecule]) + if not os.path.exists(starfile_path): + raise RuntimeError(f"Expected starfile path {starfile_path} does not exist.") + + src = RelionSource(starfile_path, dtype=dtype) + + # preprocess + src = src.downsample(img_size).cache() + src = src.phase_flip().cache() + src = src.normalize_background().cache() + src = src.whiten().cache() + src = src.invert_contrast() + + # To run radially optimized code we need + # i) radial filters + # ii) set radial expand mode in cov2d + if force_radial: + src.filter_stack = src.filter_stack.to_radial() + + return src + + +@pytest.mark.covar +def test_covar(preprocessed_src, basis, force_radial): + + # To run radially optimized code we need + # i) radial filters + # ii) set radial expand mode in cov2d + expand_method = None # default for cov2d + if force_radial: + assert ( + preprocessed_src.filter_stack.radial + ), "Expected radial filters under `force_radial=True`" + expand_method = "radial" + + cov2d = BatchedRotCov2D(preprocessed_src, basis, expand_method=expand_method) + # smoke test + _ = cov2d.get_covar() diff --git a/tox.ini b/tox.ini index 79d2566183..37b5bcfcbe 100644 --- a/tox.ini +++ b/tox.ini @@ -92,8 +92,9 @@ line_length = 88 testpaths = tests markers = expensive: mark a test as a long running test. + covar: extended (long) tests for covar components scheduled: tests that should only run in the scheduled workflow -addopts = -m "not expensive and not scheduled" +addopts = -m "not expensive and not scheduled and not covar" [gh-actions] python = From 4a2f961248ceae9357d3764ffc07dc4bc72c616d Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 11 Jun 2026 13:50:57 -0400 Subject: [PATCH 49/50] add a top of file docstring to test --- tests/test_large_covar.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_large_covar.py b/tests/test_large_covar.py index 7ac2196b11..c698fc0a8c 100644 --- a/tests/test_large_covar.py +++ b/tests/test_large_covar.py @@ -1,3 +1,12 @@ +""" +This file contains a collection of parameterized source setups and calls +to the covariance component applied to real problem sizes and real data. + +It can time/excercise the code paths used for CWF denoising and class +averaging under different basis and CTF filter assumptions (ie radial +optimizations). +""" + import os import socket @@ -115,7 +124,7 @@ def preprocessed_src(img_size, molecule, force_radial, dtype): @pytest.mark.covar -def test_covar(preprocessed_src, basis, force_radial): +def test_covar2d(preprocessed_src, basis, force_radial): # To run radially optimized code we need # i) radial filters From c743b13f3fb81c348d9314716d8e4b3fd9abbe79 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 11 Jun 2026 14:10:00 -0400 Subject: [PATCH 50/50] more cleanup --- src/aspire/basis/ffb_2d.py | 10 +--------- src/aspire/basis/fle_2d.py | 1 - src/aspire/image/xform.py | 2 -- 3 files changed, 1 insertion(+), 12 deletions(-) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 2ec2b14017..96eca8000c 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -68,19 +68,12 @@ def _build(self): ) # Generate radial filter point set for radial optimized eval - # Weights appear a little sensitive to dtype ... + # Weights appear a little sensitive to dtype, otherwise could use self._precomp["gl_nodes"] k_vals, _ = lgwt(self.n_r, 0, 0.5, dtype=np.float64) self._filter_pts = np.pad( 2 * np.pi * k_vals.reshape(1, -1), ((0, 1), (0, 0)) ).astype(self.dtype) - # Ask Joakim about this... - # Why does filter_to_basis_mat hard code lgwt instead of following basis self.kcut - # they are the same by default. - # self._filter_pts = np.pad( - # 2 * np.pi * self._precomp["gl_nodes"].reshape(1, -1), ((0, 1), (0, 0)) - # ) - def _precomp(self): """ Precomute the basis functions on a polar Fourier grid @@ -277,7 +270,6 @@ def _filter_stack_to_basis_mats(self, f, **kwargs): radial = self._precomp["radial"] # get 2D grid in polar coordinate - # Confirm this lgwt call with Joakim (should it follow basis config self.kcut? same by default) k_vals, wts = lgwt(n_k, 0, 0.5, dtype=self.dtype) k, theta = np.meshgrid( k_vals, np.arange(n_theta) * 2 * np.pi / (2 * n_theta), indexing="ij" diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 12fe173be8..e9e7ea54de 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -881,7 +881,6 @@ def expand_radial_vec(self, radial_vec, **kwargs): coefs = self._radial_convolve_weights(radial_vec) - # check... # Convert to internal FLE indices ordering coefs = coefs[..., self._fb_to_fle_indices] diff --git a/src/aspire/image/xform.py b/src/aspire/image/xform.py index 7582688b16..ac642914a9 100644 --- a/src/aspire/image/xform.py +++ b/src/aspire/image/xform.py @@ -409,8 +409,6 @@ def __init__(self, unique_xforms, indices=None): # A list of references to individual Xform objects, with possibly multiple references pointing to # the same Xform object. - # Crap, im stuck - # self.xforms = [unique_xforms[i] for i in indices] self.xforms = unique_xforms def _indexed_operation(self, im, indices, which):