diff --git a/gallery/experiments/save_simulation_relion_reconstruct.py b/gallery/experiments/save_simulation_relion_reconstruct.py index e52f6bc2a8..a7ddfd3aae 100644 --- a/gallery/experiments/save_simulation_relion_reconstruct.py +++ b/gallery/experiments/save_simulation_relion_reconstruct.py @@ -51,7 +51,7 @@ # that RELION will recover as optics groups. vol = emdb_2660() -ctf_filters = [RadialCTFFilter(defocus=d) for d in defocus] +ctf_filters = RadialCTFFilter(defocus=defocus) # %% @@ -64,7 +64,7 @@ sim = Simulation( n=n_particles, vols=vol, - unique_filters=ctf_filters, + filter_stack=ctf_filters, noise_adder=WhiteNoiseAdder.from_snr(snr), ) sim.save(star_path, overwrite=True) diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py index adf228499f..527e69e417 100644 --- a/gallery/experiments/simulated_abinitio_pipeline.py +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -83,17 +83,19 @@ def noise_function(x, y): alpha = 0.1 # Amplitude contrast # Create filters -ctf_filters = [ - RadialCTFFilter(pixel_size, voltage, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) -] +ctf_filters = RadialCTFFilter( + voltage, + defocus=np.linspace(defocus_min, defocus_max, defocus_ct), + Cs=2.0, + alpha=0.1, +) # Finally create the Simulation src = Simulation( n=num_imgs, vols=og_v, noise_adder=custom_noise, - unique_filters=ctf_filters, + filter_stack=ctf_filters, ) # Downsample diff --git a/gallery/tutorials/aspire_introduction.py b/gallery/tutorials/aspire_introduction.py index b0013335f5..cf67b76d03 100644 --- a/gallery/tutorials/aspire_introduction.py +++ b/gallery/tutorials/aspire_introduction.py @@ -570,10 +570,7 @@ def noise_function(x, y): defocus_ct = 7 # Generate several CTFs. -ctf_filters = [ - RadialCTFFilter(defocus=d) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) -] +ctf_filters = RadialCTFFilter(defocus=np.linspace(defocus_min, defocus_max, defocus_ct)) # %% # Combining into a Simulation @@ -586,7 +583,7 @@ def noise_function(x, y): amplitudes=1, offsets=0, noise_adder=white_noise_adder, - unique_filters=ctf_filters, + filter_stack=ctf_filters, seed=42, ) diff --git a/gallery/tutorials/pipeline_demo.py b/gallery/tutorials/pipeline_demo.py index 0a192ecaf1..104f8f7c03 100644 --- a/gallery/tutorials/pipeline_demo.py +++ b/gallery/tutorials/pipeline_demo.py @@ -66,10 +66,8 @@ defocus_max = 25000 defocus_ct = 7 -ctf_filters = [ - RadialCTFFilter(defocus=d) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) -] +ctf_filters = RadialCTFFilter(defocus=np.linspace(defocus_min, defocus_max, defocus_ct)) + # %% # Initialize Simulation Object @@ -96,7 +94,7 @@ n=2500, # number of projections vols=original_vol, # volume source offsets=0, # Default: images are randomly shifted - unique_filters=ctf_filters, + filter_stack=ctf_filters, noise_adder=WhiteNoiseAdder(var=0.0002), # desired noise variance ).cache() diff --git a/gallery/tutorials/tutorials/cov2d_simulation.py b/gallery/tutorials/tutorials/cov2d_simulation.py index 5d9b469b04..f7368e19cb 100644 --- a/gallery/tutorials/tutorials/cov2d_simulation.py +++ b/gallery/tutorials/tutorials/cov2d_simulation.py @@ -67,10 +67,13 @@ print("Initialize simulation object and CTF filters.") # Create filters -ctf_filters = [ - RadialCTFFilter(voltage, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) -] +ctf_filters = RadialCTFFilter( + voltage, + defocus=np.linspace(defocus_min, defocus_max, defocus_ct), + Cs=2.0, + alpha=0.1, +) + # Load the map file of a 70S Ribosome print( @@ -89,7 +92,7 @@ L=img_size, n=num_imgs, vols=vols, - unique_filters=ctf_filters, + filter_stack=ctf_filters, offsets=0.0, amplitudes=1.0, dtype=dtype, @@ -109,9 +112,7 @@ h_idx = sim.filter_indices # Evaluate CTF in the 8X8 FB basis -h_ctf_fb = [ - ffbbasis.filter_to_basis_mat(filt, pixel_size=pixel_size) for filt in ctf_filters -] +h_ctf_fb = ffbbasis.filter_stack_to_basis_mats(ctf_filters, pixel_size=pixel_size) # Get clean images from projections of 3D map. print("Apply CTF filters to clean images.") diff --git a/gallery/tutorials/tutorials/cov3d_simulation.py b/gallery/tutorials/tutorials/cov3d_simulation.py index 4e46596a36..58d8e9d5cf 100644 --- a/gallery/tutorials/tutorials/cov3d_simulation.py +++ b/gallery/tutorials/tutorials/cov3d_simulation.py @@ -44,7 +44,7 @@ L=img_size, n=num_imgs, vols=vols, - unique_filters=[RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7)], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), dtype=dtype, ) diff --git a/gallery/tutorials/tutorials/ctf.py b/gallery/tutorials/tutorials/ctf.py index 218612dc8c..a269e3c35b 100644 --- a/gallery/tutorials/tutorials/ctf.py +++ b/gallery/tutorials/tutorials/ctf.py @@ -154,9 +154,8 @@ def generate_example_image(L, noise_variance=0.1): # Construct a range of CTF filters. defoci = [2500, 5000, 10000, 20000] -ctf_filters = [ - RadialCTFFilter(voltage=200, defocus=d, Cs=2.26, alpha=0.07, B=0) for d in defoci -] +ctf_filters = RadialCTFFilter(voltage=200, defocus=defoci, Cs=2.26, alpha=0.07, B=0) + # %% # Generate CTF corrupted Images @@ -334,7 +333,7 @@ def generate_example_image(L, noise_variance=0.1): from aspire.source import Simulation # Create the Source. ``ctf_filters`` are re-used from earlier section. -src = Simulation(L=64, n=4, unique_filters=ctf_filters, pixel_size=1) +src = Simulation(L=64, n=4, filter_stack=ctf_filters, pixel_size=1) src.images[:4].show() # %% diff --git a/gallery/tutorials/tutorials/micrograph_source.py b/gallery/tutorials/tutorials/micrograph_source.py index 26f2429925..4c6cae10dd 100644 --- a/gallery/tutorials/tutorials/micrograph_source.py +++ b/gallery/tutorials/tutorials/micrograph_source.py @@ -181,9 +181,7 @@ # Create our CTF Filter and add it to a list. # This configuration will apply the same CTF to all particles. -ctfs = [ - RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0), -] +ctfs = RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0) src = MicrographSimulation( vol, diff --git a/gallery/tutorials/tutorials/orient3d_simulation.py b/gallery/tutorials/tutorials/orient3d_simulation.py index 142fe177f4..a41f29d70b 100644 --- a/gallery/tutorials/tutorials/orient3d_simulation.py +++ b/gallery/tutorials/tutorials/orient3d_simulation.py @@ -49,10 +49,13 @@ print("Initialize simulation object and CTF filters.") # Create CTF filters -filters = [ - RadialCTFFilter(voltage, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) -] +filters = RadialCTFFilter( + voltage, + defocus=np.linspace(defocus_min, defocus_max, defocus_ct), + Cs=2.0, + alpha=0.1, +) + # %% # Downsampling @@ -74,7 +77,7 @@ # Create a simulation object with specified filters and the downsampled 3D map print("Use downsampled map to creat simulation object.") sim = Simulation( - L=img_size, n=num_imgs, vols=vols, unique_filters=filters, pixel_size=5, dtype=dtype + L=img_size, n=num_imgs, vols=vols, filter_stack=filters, pixel_size=5, dtype=dtype ) print("Get true rotation angles generated randomly by the simulation object.") diff --git a/gallery/tutorials/tutorials/preprocess_imgs_sim.py b/gallery/tutorials/tutorials/preprocess_imgs_sim.py index 4e70c5f4cc..7c0023b5db 100644 --- a/gallery/tutorials/tutorials/preprocess_imgs_sim.py +++ b/gallery/tutorials/tutorials/preprocess_imgs_sim.py @@ -53,10 +53,13 @@ print("Initialize simulation object and CTF filters.") # Create CTF filters -ctf_filters = [ - RadialCTFFilter(voltage, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) -] +ctf_filters = RadialCTFFilter( + voltage, + defocus=np.linspace(defocus_min, defocus_max, defocus_ct), + Cs=2.0, + alpha=0.1, +) + # Load the map file of a 70S ribosome and downsample the 3D map to desired image size. print("Load 3D map from mrc file") @@ -73,7 +76,7 @@ L=img_size, n=num_imgs, vols=vols, - unique_filters=ctf_filters, + filter_stack=ctf_filters, noise_adder=noise_adder, pixel_size=pixel_size, ) diff --git a/src/aspire/basis/fb_2d.py b/src/aspire/basis/fb_2d.py index 6477d23e5c..f6c246b8bc 100644 --- a/src/aspire/basis/fb_2d.py +++ b/src/aspire/basis/fb_2d.py @@ -291,6 +291,6 @@ def calculate_bispectrum( def filter_to_basis_mat(self, *args, **kwargs): """ - See `SteerableBasis2D.filter_to_basis_mat`. + See `SteerableBasis2D.filter_stack_to_basis_mat`. """ return super().filter_to_basis_mat(*args, **kwargs) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 2f871af9f5..96eca8000c 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -67,6 +67,13 @@ def _build(self): self._precomp["gl_nodes"] ) + # Generate radial filter point set for radial optimized eval + # Weights appear a little sensitive to dtype, otherwise could use self._precomp["gl_nodes"] + k_vals, _ = lgwt(self.n_r, 0, 0.5, dtype=np.float64) + self._filter_pts = np.pad( + 2 * np.pi * k_vals.reshape(1, -1), ((0, 1), (0, 0)) + ).astype(self.dtype) + def _precomp(self): """ Precomute the basis functions on a polar Fourier grid @@ -236,15 +243,17 @@ def _evaluate_t(self, x): return xp.asnumpy(v) - def filter_to_basis_mat(self, f, **kwargs): + def _filter_stack_to_basis_mats(self, f, **kwargs): """ See `SteerableBasis2D.filter_to_basis_mat`. """ - # Note 'method' and 'truncate' not relevant for this optimized FFB code. - if kwargs.get("method", None) is not None: + # Note 'method' and 'truncate' not relevant for this specific FFB code. + # Method `radial` should have already been diverted. + expand_method = kwargs.get("expand_method", None) + if expand_method is not None: raise NotImplementedError( - "`FFBBasis2D.filter_to_basis_mat` method {method} not supported." - " Use `method=None`." + f"`FFBBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported." + " Use `expand_method=None`." ) pixel_size = kwargs.get("pixel_size", None) @@ -271,28 +280,103 @@ def filter_to_basis_mat(self, f, **kwargs): omegay = k * np.sin(theta) omega = 2 * np.pi * np.vstack((omegax.flatten("C"), omegay.flatten("C"))) + # This should return either a single 2d array, or stack of 2d arrays + # Reshape singleton to stack of 1. h_vals2d = ( - h_fun(omega, pixel_size=pixel_size).reshape(n_k, n_theta).astype(self.dtype) + h_fun(omega, pixel_size=pixel_size) + .reshape(len(f), n_k, n_theta) + .astype(self.dtype) ) - h_vals = np.sum(h_vals2d, axis=1) / n_theta + h_vals = h_vals2d.sum(axis=-1) / n_theta - # Represent 1D function values in basis - h_basis = BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) + # Represent each 1D functions values in basis + h_basis = [ + BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) for _ in h_vals + ] ind_ell = 0 + # Reshapes for broadcasting + k_vals = k_vals.reshape(n_k, 1) + wts = wts.reshape(n_k, 1) + h_vals = h_vals.reshape(len(f), n_k, 1) for ell in range(0, self.ell_max + 1): k_max = self.k_max[ell] - rmat = 2 * k_vals.reshape(n_k, 1) * self.r0[ell][0:k_max].T - basis_vals = np.zeros_like(rmat) + basis_vals = np.zeros((n_k, k_max), dtype=self.dtype) ind_radial = np.sum(self.k_max[0:ell]) basis_vals[:, 0:k_max] = radial[ind_radial : ind_radial + k_max].T - h_basis_vals = basis_vals * h_vals.reshape(n_k, 1) - h_basis_ell = basis_vals.T @ ( - h_basis_vals * k_vals.reshape(n_k, 1) * wts.reshape(n_k, 1) - ) - h_basis[ind_ell] = h_basis_ell + h_basis_vals = basis_vals * h_vals + h_basis_ell = basis_vals.T @ (h_basis_vals * k_vals * wts) + + # loop over assignment blocks. + for i in range(len(f)): + h_basis[i][ind_ell] = h_basis_ell[i] ind_ell += 1 if ell > 0: - h_basis[ind_ell] = h_basis[ind_ell - 1] + for i in range(len(f)): + h_basis[i][ind_ell] = h_basis[i][ind_ell - 1] ind_ell += 1 return h_basis + + def filter_to_basis_mat(self, f, **kwargs): + """ + See `SteerableBasis2D.filter_stack_to_basis_mats`. + """ + if len(f) != 1: + raise RuntimeError("Unexpected filter length.") + return self._filter_stack_to_basis_mats(f, **kwargs)[0] + + def expand_radial_vec(self, radial_vec, force_diag=False): + """ + Expands radial vector or stack of vetors `radial_vec` to basis matrix. + + :param radial_vec: Array holding radial vector, + shaped (n_radial_pts) or (n_vectors, n_radial_pts) + :force_diag: Optionally flush off-diagonal elements to zero and return `DiagMatrix` + :return: List of `BlkDiagMatrix`, or list of `DiagMatrix` + """ + # Convert vector to (1,...) + if radial_vec.ndim == 1: + radial_vec = radial_vec.reshape(1, *radial_vec.shape) + # Optionally transfer to GPU + radial_vec = xp.asarray(radial_vec) + + # Set same dimensions as basis object + n_k = self.n_r + radial = self._precomp["radial"] + + k_vals = xp.asarray(self._precomp["gl_nodes"]) + wts = xp.asarray(self._precomp["gl_weights"]) + + # Represent 1D function values in basis + h_basis = [ + BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) + for _ in radial_vec + ] + + ind_ell = 0 + # Reshapes for broadcasting + radial_vec = radial_vec.reshape(len(h_basis), n_k, 1) + k_vals = k_vals.reshape(1, n_k, 1) + wts = wts.reshape(1, n_k, 1) + for ell in range(0, self.ell_max + 1): + k_max = self.k_max[ell] + basis_vals = xp.zeros((n_k, k_max), dtype=self.dtype) + ind_radial = np.sum(self.k_max[0:ell]) + basis_vals[:, 0:k_max] = xp.asarray( + radial[ind_radial : ind_radial + k_max] + ).T + h_basis_vals = basis_vals * radial_vec + h_basis_ell = basis_vals.T @ (h_basis_vals * k_vals * wts) + h_basis_ell = xp.asnumpy(h_basis_ell) + for _filter in range(len(radial_vec)): + _tmp = h_basis[_filter][ind_ell] = h_basis_ell[_filter] + if ell > 0: + h_basis[_filter][ind_ell + 1] = _tmp + if _filter == len(radial_vec) - 1: + ind_ell += 1 + if ell > 0: + ind_ell += 1 + if force_diag: + h_basis = [h.diag() for h in h_basis] + + return h_basis diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index cc6ea04672..e9e7ea54de 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -17,6 +17,9 @@ logger = logging.getLogger(__name__) +# Number of elements in filter_to_basis_mat before breaking into batches +MAX_ELEM_COUNT = 2e9 + def _cleanup(): """ @@ -298,7 +301,7 @@ def _compute_nufft_points(self): nodes = ( self.greatest_lambda - self.smallest_lambda ) * nodes + self.smallest_lambda - nodes = nodes.reshape(self.num_radial_nodes, 1) + self.nodes = nodes.reshape(self.num_radial_nodes, 1) radius = self.nres / 2 h = 1 / radius @@ -314,9 +317,14 @@ def _compute_nufft_points(self): ) grid_xy[0] = xp.cos(phi) # x grid_xy[1] = xp.sin(phi) # y - grid_xy[:] = grid_xy * nodes * h + grid_xy[:] = grid_xy * self.nodes * h self.grid_xy = grid_xy.reshape(2, -1) + # Generate radial filter point set for radial optimized eval + self._filter_pts = ( + np.pad(xp.asnumpy(self.nodes).reshape(1, -1), ((0, 1), (0, 0))) * self.h + ) + def _build_interpolation_matrix(self): """ Create the matrix used in the third step of evaluate_t() and the first step of evaluate() @@ -738,7 +746,7 @@ def radial_convolve(self, coefs, radial_img): _coefs = coefs[k, :] z = self._step1_t(radial_img) b = self._step2_t(z) - weights = self._radial_convolve_weights(b) + weights = self._radial_convolve_weights(b[..., 0]) b = weights / (self.h**2) b = b.reshape(self.count) coefs_conv[k, :] = (self.c2r @ (b * (self.r2c @ _coefs).flatten())).real @@ -752,33 +760,63 @@ def radial_convolve(self, coefs, radial_img): def _radial_convolve_weights(self, b): """ Helper function for step 3 of convolving with a radial function. + + :param b: Radial vector or stack of radial vectors """ - b = xp.squeeze(b) - b = xp.array(b) # implies copy + # Developer note, this is equivalent `fle2d.expand_radial_vec` up to shapes. + # Convert vector to (1,...) + if b.ndim == 1: + b = b.reshape(1, *b.shape) + b = xp.asarray(b) # implies copy if self.num_interp > self.num_radial_nodes: - b = fft.dct(b, axis=0, type=2) / (2 * self.num_radial_nodes) - bz = xp.zeros(b.shape) - b = xp.concatenate((b, bz), axis=0) - b = fft.idct(b, axis=0, type=2) * 2 * b.shape[0] - a = xp.zeros(self.count, dtype=np.float64) - y = [None] * (self.ell_p_max + 1) - for i in range(self.ell_p_max + 1): - y[i] = (self.A3[i] @ b[:, 0]).flatten() + b = fft.dct(b, axis=1, type=2) / (2 * self.num_radial_nodes) + bz = xp.zeros(b.shape, dtype=self.dtype) + b = xp.concatenate((b, bz), axis=1) + b = fft.idct(b, axis=1, type=2) * 2 * b.shape[1] + a = xp.zeros((b.shape[0], self.count), dtype=self.dtype) + for i in range(self.ell_p_max + 1): - a[self.idx_list[i]] = y[i] + # Wierd mul transpose forced by A3 being CSR. + # Can't reshape A3, but can broadcast over dims of b. + # T b first to yield (cnt, num_img) and T result back to (num_img, cnt) + a[:, self.idx_list[i]] = (self.A3[i] @ b[:, :].T).T - return a.flatten() + return a def filter_to_basis_mat(self, f, **kwargs): """ - See `SteerableBasis2D.filter_to_basis_mat`. + See `SteerableBasis2D.filter_stack_to_basis_mats`. + """ + if len(f) != 1: + raise RuntimeError("Unexpected filter length.") + return self._filter_stack_to_basis_mats(f, **kwargs) + + def _filter_stack_to_basis_mats(self, f, **kwargs): """ - # Note 'method' and 'truncate' not relevant for this optimized FLE code. - if kwargs.get("method", None) is not None: + See `SteerableBasis2D.filter_to_basis_mat`. + + This code implements a radially averaged filter, returning a `DiagMatrix`. + It is meant to be used with non radial filter inputs. + For radial filters, an optimized code path may be available via `expand_method='radial'`. + """ + # Note 'expand_method' and 'truncate' not relevant for this optimized FLE code. + expand_method = kwargs.get("expand_method", None) + # This first check is to guide users to either + # the optimized purely radial (1d) calc, + # or the radially averaged (2d) calc. + # Errors requesting a (1d) calc on a non radial filter. + # Errors on unknown `expand_method`. + if expand_method == "radial" and not f.radial: raise NotImplementedError( - "`FLEBasis2D.filter_to_basis_mat` method {method} not supported." - " Use `method=None`." + f"`FLEBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported for non radial {f}." + " Convert filter to radial for optimized radial expansion, or use `expand_method=None` for a radially averaged approximation." ) + elif expand_method is not None: + raise NotImplementedError( + f"`FLEBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported." + " Try `expand_method=None` or 'radial'." + ) + pixel_size = kwargs.get("pixel_size", None) # Get the filter's evaluate function. @@ -802,19 +840,68 @@ def filter_to_basis_mat(self, f, **kwargs): omegay = k * xp.sin(theta) omega = 2 * xp.pi * xp.vstack((omegax.flatten("C"), omegay.flatten("C"))) - h_vals2d = ( - xp.asarray(h_fun(omega, pixel_size=pixel_size)) - .reshape(n_k, n_theta) - .astype(self.dtype, copy=False) - ) - h_vals = xp.sum(h_vals2d, axis=1) / n_theta - - h_basis = xp.zeros(self.count, dtype=self.dtype) - # For now we just need to handle 1D (stack of one ctf) + # For high non-radial filter counts at higher pixel counts + # h_vals2d requires a large amount of memory and is too large + # to fit on a GPU + # In the smaller cases, the code attepts using GPU. + h_vals2d = h_fun(omega, pixel_size=pixel_size) + if len(f) * xp.size(omega) < MAX_ELEM_COUNT: + h_vals2d = xp.asarray(h_vals2d) + + h_vals2d = h_vals2d.reshape(len(f), n_k, n_theta).astype(self.dtype, copy=False) + h_vals = h_vals2d.sum(axis=-1) / n_theta + h_vals = xp.asarray(h_vals) # no-op if already fit on GPU + + h_basis = xp.zeros((len(f), self.count), dtype=self.dtype) + # shape gymnastics to get a broadcast with csr A3 + h_vals = h_vals.T for j in range(self.ell_p_max + 1): - h_basis[self.idx_list[j]] = self.A3[j] @ h_vals + h_basis[:, self.idx_list[j]] = (self.A3[j] @ h_vals).T # Convert from internal FLE ordering to FB convention - h_basis = h_basis[self._fle_to_fb_indices] + h_basis = h_basis[:, self._fle_to_fb_indices] + # who needs this as a list? + + coefs = xp.asnumpy(h_basis) + if len(coefs) > 1: + coefs = [DiagMatrix(c) for c in coefs] + else: + coefs = DiagMatrix(coefs.flatten()) + + return coefs + + def expand_radial_vec(self, radial_vec, **kwargs): + """ + Expands radial vector or stack of vectors `radial_vec` to basis matrix. + + :param radial_vec: Array holding radial vector, + shaped (n_radial_pts) or (n_vectors, n_radial_pts) + :return: List of `DiagMatrix` + """ + + coefs = self._radial_convolve_weights(radial_vec) + + # Convert to internal FLE indices ordering + coefs = coefs[..., self._fb_to_fle_indices] + + coefs = xp.asnumpy(coefs) + + # who needs this as a list? + if len(coefs) > 1: + coefs = [DiagMatrix(c) for c in coefs] + else: + coefs = DiagMatrix(coefs.flatten()) + + return coefs + + def _radial_filter_to_vals(self, f, **kwargs): + """ + Unpack filter attributes and pass to Yunpeng code. + """ + + pts = xp.asnumpy(self.nodes) + + _filter_pts = np.pad(pts.reshape(1, -1), ((0, 1), (0, 0))) * self.h + h_vals = f.evaluate(_filter_pts, **kwargs) - return DiagMatrix(xp.asnumpy(h_basis)) + return h_vals diff --git a/src/aspire/basis/fpswf_2d.py b/src/aspire/basis/fpswf_2d.py index 6eac1a91db..e1215103d0 100644 --- a/src/aspire/basis/fpswf_2d.py +++ b/src/aspire/basis/fpswf_2d.py @@ -369,6 +369,6 @@ def _pswf_integration(self, images_nufft): def filter_to_basis_mat(self, *args, **kwargs): """ - See `SteerableBasis2D.filter_to_basis_mat`. + See `SteerableBasis2D.filter_stack_to_basis_mat`. """ return super().filter_to_basis_mat(*args, **kwargs) diff --git a/src/aspire/basis/pswf_2d.py b/src/aspire/basis/pswf_2d.py index c9795ec1bc..d03c483a23 100644 --- a/src/aspire/basis/pswf_2d.py +++ b/src/aspire/basis/pswf_2d.py @@ -400,6 +400,6 @@ def _pswf_2d_minor_computations(self, big_n, n, bandlimit, phi_approximate_error def filter_to_basis_mat(self, *args, **kwargs): """ - See `SteerableBasis2D.filter_to_basis_mat`. + See `SteerableBasis2D.filter_stack_to_basis_mat`. """ return super().filter_to_basis_mat(*args, **kwargs) diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index b5051065da..f9474e81d4 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -6,7 +6,7 @@ from aspire.basis import Basis, Coef, ComplexCoef from aspire.operators import BlkDiagMatrix -from aspire.utils import LogFilterByCount, complex_type, real_type, trange +from aspire.utils import LogFilterByCount, complex_type, real_type, tqdm, trange logger = logging.getLogger(__name__) @@ -476,17 +476,60 @@ def to_complex(self, coef): return ComplexCoef(self, complex_coef) + def filter_stack_to_basis_mats(self, f, **kwargs): + """ + Convert a filter stack into a list of basis operator representations. + + See `_filter_stack_to_basis_mats` and `filter_to_basis_mat` + here and in subclasses for available **kwargs. + + :param f: `Filter` object, for example a `CTFFilter`. + + :return: List containing representations of Filter as `basis` operators. + Return type of list elements will be based on the class's `matrix_type`, + typically `BlkDiagMatrix` or `DiagMatrix`. + """ + + # does the basis have optimized expansion for radial vectors? + optimized_expand = callable(getattr(self.__class__, "expand_radial_vec", None)) + # is the filter radial? + filter_is_radial = f.radial + # did user request the special radial expansion method? + radial_method = kwargs.get("expand_method", None) == "radial" + + if optimized_expand and filter_is_radial and radial_method: + # kwargs supports passing through pixel_size + h_vals = self._radial_filter_to_vals(f, **kwargs) + res = self.expand_radial_vec(h_vals) + else: + # use generic (legacy) filter path/code (may return DiagMatrix) + res = self._filter_stack_to_basis_mats(f, **kwargs) + + return res + + def _filter_stack_to_basis_mats(self, f, **kwargs): + """ + Helper function for sequentially evaluating filters in a basis. + + This is a crude fall back for basis that do not provide optimized `filter_stack_to_basis_mats`. + """ + basis_mats = [None] * len(f) + for i, _f in enumerate(tqdm(f, desc="Converting filters to basis mats")): + basis_mats[i] = self.filter_to_basis_mat(_f, **kwargs) + return basis_mats + # `abstractmethod` enforces when a new subclass of # `SteerableBasis2D` is created that this method is explicitly # implemented. This is intended to encourage future basis authors # to consider this method for their application. + # When possible, they should prefer to create an optimized _filter_stack_to_basis_mats. @abc.abstractmethod - def filter_to_basis_mat(self, f, method="evaluate_t", truncate=True, **kwargs): + def filter_to_basis_mat(self, f, expand_method=None, truncate=True, **kwargs): """ Convert a filter into a basis operator representation. :param f: `Filter` object, usually a `CTFFilter`. - :param method: `evaluate_t` or `expand`. + :param expand_method: `evaluate_t` or `expand`. Default `None` uses `evaluate_t`. :param truncate: Optionally, truncate dense matrix to BlkDiagMatrix. Defaults to True. @@ -494,13 +537,13 @@ def filter_to_basis_mat(self, f, method="evaluate_t", truncate=True, **kwargs): Return type will be based on the class's `matrix_type`. """ # evaluate_t is not as accurate, but much much faster... - if method == "evaluate_t": - expand_method = self.evaluate_t - elif method == "expand": - expand_method = self.expand + if expand_method == "evaluate_t" or expand_method is None: + expand_fun = self.evaluate_t + elif expand_method == "expand": + expand_fun = self.expand else: raise NotImplementedError( - "`filter_to_basis_mat` method {method} not supported." + f"`filter_to_basis_mat` expand_method '{expand_method}' not supported." " Try `evaluate_t` or `expand`." ) @@ -520,7 +563,7 @@ def filter_to_basis_mat(self, f, method="evaluate_t", truncate=True, **kwargs): with LogFilterByCount(logger, 1): for i in trange(self.count): try: - filt[i] = expand_method(img[i].filter(f)).asnumpy()[0] + filt[i] = expand_fun(img[i].filter(f)).asnumpy()[0] except Exception: logger.warning( f"Failed to expand basis vector {i} after filter {f}." diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index d0bb73b57e..74fc63e6e1 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -1,10 +1,12 @@ import logging +from time import perf_counter import numpy as np from numpy.linalg import eig, inv from scipy.linalg import solve, sqrtm from aspire.basis import Coef, FFBBasis2D +from aspire.numeric import xp from aspire.operators import BlkDiagMatrix, DiagMatrix from aspire.optimization import conj_grad, fill_struct from aspire.utils import make_symmat @@ -520,7 +522,9 @@ class BatchedRotCov2D(RotCov2D): scaling up to a larger value such as 8192 may yield better performance. """ - def __init__(self, src, basis=None, batch_size=512): + def __init__( + self, src, basis=None, expand_method=None, force_diag=False, batch_size=512 + ): self.src = src self.basis = basis self.batch_size = batch_size @@ -531,6 +535,8 @@ def __init__(self, src, basis=None, batch_size=512): self.A_mean = None self.A_covar = None self.M_covar = None + self.expand_method = expand_method + self.force_diag = force_diag self._build() @@ -540,20 +546,57 @@ def _build(self): if self.basis is None: self.basis = FFBBasis2D((src.L, src.L), dtype=self.dtype) - if not src.unique_filters: + if src.filter_stack is None: logger.info("CTF filters are not included in Cov2D denoising") # set all CTF filters to an identity filter self.ctf_idx = np.zeros(src.n, dtype=int) self.ctf_basis = [self._identity_mat()] else: - logger.info("Represent CTF filters in basis") - unique_filters = src.unique_filters + logger.info("Representing filters in basis") self.ctf_idx = src.filter_indices - self.ctf_basis = [ - self.basis.filter_to_basis_mat(f, pixel_size=self.src.pixel_size) - for f in unique_filters - ] + self.ctf_basis = self.filters_to_basis_mats() + logger.info("Representing filters in basis complete") + + def filters_to_basis_mats(self): + """ + Dispatch between various methods for converting filter stacks to basis matrices. + """ + # Does the basis provide radially optimized expansion? + optimized_expand = callable( + getattr(self.basis.__class__, "expand_radial_vec", None) + ) + + # Are the filters radial? + if self.src.filter_stack.radial: + logger.info("Found radial filter stack.") + else: + logger.info("Found non-radial filter stack.") + + if optimized_expand and self.src.filter_stack.radial: + logger.info("Using optimized `basis.expand_radial_vec`.") + return self._radial_filter_stack_to_basis_mats() + else: + logger.info("Using basis.filter_stack_to_basis_mats.") + return self.basis.filter_stack_to_basis_mats( + self.src.filter_stack, + pixel_size=self.src.pixel_size, + expand_method=self.expand_method, + ) + + def _radial_filter_stack_to_basis_mats(self): + logger.info("Generating filter eval points") + _filter_pts = self.basis._filter_pts + # if we have many filters, might be worth trip to GPU + if len(self.src.filter_stack) >= 2048: + _filter_pts = xp.asarray(_filter_pts) + + _filter_vals = self.src.filter_stack.evaluate( + _filter_pts, pixel_size=self.src.pixel_size + ) + + logger.info("Computing basis radial expansion") + return self.basis.expand_radial_vec(_filter_vals, force_diag=self.force_diag) def _calc_rhs(self): src = self.src @@ -568,11 +611,20 @@ def _calc_rhs(self): b_covar = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape, dtype=self.dtype) + cumulative_image_fetch_time = 0 + cumulative_eval_t_time = 0 for start in range(0, src.n, self.batch_size): batch = np.arange(start, min(start + self.batch_size, src.n)) + t0 = perf_counter() im = src.images[batch[0] : batch[0] + len(batch)] + t1 = perf_counter() + cumulative_image_fetch_time += t1 - t0 + + t0 = perf_counter() coef = basis.evaluate_t(im).asnumpy() + t1 = perf_counter() + cumulative_eval_t_time += t1 - t0 for k in np.unique(ctf_idx[batch]): coef_k = coef[ctf_idx[batch] == k] @@ -602,6 +654,8 @@ def _calc_rhs(self): self.b_mean = b_mean self.b_covar = b_covar + logger.info(f"cumulative_image_fetch_time {cumulative_image_fetch_time}") + logger.info(f"cumulative_eval_t_time {cumulative_eval_t_time}") def _calc_op(self): src = self.src @@ -611,7 +665,12 @@ def _calc_op(self): A_mean = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape, self.dtype) A_covar = [None for _ in ctf_basis] - M_covar = self._zeros_mat() + # If we're given all diag filters, A_covar and M can be diag + if all(isinstance(c, DiagMatrix) for c in ctf_basis): + M_covar = DiagMatrix.zeros(self.basis.count, dtype=self.dtype) + # otherwise, take the default `matrix_type` for the basis + else: + M_covar = self._zeros_mat() for k in np.unique(ctf_idx): weight = float(np.count_nonzero(ctf_idx == k) / src.n) @@ -677,12 +736,13 @@ def _noise_correct_covar_rhs(self, b_covar, b_noise, noise_var, shrinker): def _solve_covar(self, A_covar, b_covar, M, covar_est_opt): method = self._solve_covar_cg - if self.basis.matrix_type == DiagMatrix: + if all(isinstance(a, DiagMatrix) for a in A_covar): method = self._solve_covar_direct return method(A_covar, b_covar, M, covar_est_opt) def _solve_covar_direct(self, A_covar, b_covar, M, covar_est_opt): + t0 = perf_counter() # A_covar is a list of DiagMatrix, representing each ctf in self.basis. # b_covar is a BlkDiagMatrix # M is sum of weighted A squared. @@ -700,9 +760,13 @@ def _solve_covar_direct(self, A_covar, b_covar, M, covar_est_opt): # in Yunpeng's code. res = Minv @ b_covar @ Minv + t1 = perf_counter() + logger.info(f"_solve_covar_direct elapsed: {t1-t0}") return res def _solve_covar_cg(self, A_covar, b_covar, M, covar_est_opt): + t0 = perf_counter() + def precond_fun(S, x): p = np.size(S, 0) assert np.size(x) == p * p, "The sizes of S and x are not consistent." @@ -736,6 +800,8 @@ def apply(A, x): ) covar_coef[ell] = covar_coef_ell.reshape(p, p) + t1 = perf_counter() + logger.info(f"_solve_covar_cg elapsed: {t1-t0}") return covar_coef def get_mean(self): diff --git a/src/aspire/image/xform.py b/src/aspire/image/xform.py index 20a2e62c5c..ac642914a9 100644 --- a/src/aspire/image/xform.py +++ b/src/aspire/image/xform.py @@ -314,6 +314,15 @@ def _forward(self, im, indices): def __str__(self): return f"FilterXform ({self.filter})" + def __len__(self): + """ + Return the len of the underlying filter stack. + """ + return len(self.filter) + + def __getitem__(self, item): + return FilterXform(self.filter[item]) + class Add(Xform): """ @@ -400,7 +409,7 @@ def __init__(self, unique_xforms, indices=None): # A list of references to individual Xform objects, with possibly multiple references pointing to # the same Xform object. - self.xforms = [unique_xforms[i] for i in indices] + self.xforms = unique_xforms def _indexed_operation(self, im, indices, which): """ @@ -420,7 +429,7 @@ def _indexed_operation(self, im, indices, which): im_data = np.empty_like(im.asnumpy()) # For each individual transformation - for i, xform in enumerate(self.unique_xforms): + for i in range(len(self.unique_xforms)): # Get the indices corresponding to that transformation idx = np.flatnonzero(self.indices == i) # For the incoming Image object, find out which transformation indices are applicable @@ -429,7 +438,7 @@ def _indexed_operation(self, im, indices, which): im_data_indices = np.flatnonzero(np.isin(indices, idx)) # Apply the transformation to the selected indices in the Image object if len(im_data_indices) > 0: - fn_handle = getattr(xform, which) + fn_handle = getattr(self.unique_xforms[i], which) im_data[im_data_indices] = fn_handle(im[im_data_indices]).asnumpy() return Image(im_data, pixel_size=im.pixel_size) diff --git a/src/aspire/numeric/cupy.py b/src/aspire/numeric/cupy.py index f02aa393c2..2853cd7b7a 100644 --- a/src/aspire/numeric/cupy.py +++ b/src/aspire/numeric/cupy.py @@ -1,4 +1,5 @@ import cupy as cp +import numpy as np class Cupy: @@ -7,3 +8,15 @@ def __getattr__(self, item): Catch-all method to to allow a straight pass-through of any attribute that is not supported above. """ return getattr(cp, item) + + @staticmethod + def atleast_1d(x): + """ + Provide an agnostic `atleast_1d`. + + Returns same type as input. + """ + _fn = np.atleast_1d + if cp and isinstance(x, cp.ndarray): + _fn = cp.atleast_1d + return _fn(x) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 3c3f362fc3..b71cef0ce1 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -6,7 +6,8 @@ from scipy.interpolate import RegularGridInterpolator from aspire import config -from aspire.utils import cart2pol, grid_2d, voltage_to_wavelength +from aspire.numeric import xp +from aspire.utils import cart2pol, grid_2d, trange, voltage_to_wavelength logger = logging.getLogger(__name__) @@ -30,14 +31,20 @@ def evaluate_src_filters_on_grid(src, indices=None): grid2d = grid_2d(src.L, indexing="yx", dtype=src.dtype) omega = np.pi * np.vstack((grid2d["x"].flatten(), grid2d["y"].flatten())) - # Initialize h as ones to mimic an IdentityFilter when src.unique_filters is None. + # Initialize h as ones to mimic an IdentityFilter when src.filter_stack is None. h = np.ones((omega.shape[-1], len(indices)), dtype=src.dtype) - for i, filt in enumerate(src.unique_filters): - idx_k = np.where(src.filter_indices[indices] == i)[0] - if len(idx_k) > 0: - filter_values = filt.evaluate(omega, pixel_size=src.pixel_size) - h[:, idx_k] = np.column_stack((filter_values,) * len(idx_k)) + if src.filter_stack is not None: + # Evaluate all filters in bulk + filter_stack_values = src.filter_stack.evaluate( + omega, pixel_size=src.pixel_size + ) + for i, filter_values in enumerate(filter_stack_values): + idx_k = np.where(src.filter_indices[indices] == i)[0] + if len(idx_k) > 0: + # convert filter_values row vector to column vector and tile broadcast + filter_values = filter_values.reshape(-1, 1) + h[:, idx_k] = np.tile(filter_values, len(idx_k)) h = np.reshape(h, grid2d["x"].shape + (len(indices),)) return h @@ -45,6 +52,9 @@ def evaluate_src_filters_on_grid(src, indices=None): # TODO: filters should probably be dtyped... class Filter: + max_size = 4e9 # Max element count for a single evaluate batch + batch_size = 512 # Batch size in elements + def __init__(self, dim=None, radial=False): self.dim = dim self.radial = radial @@ -82,10 +92,40 @@ def evaluate(self, omega, **kwargs): omega, idx = np.unique(omega, return_inverse=True) omega = np.vstack((omega, np.zeros_like(omega))) - h = self._evaluate(omega, **kwargs) + # Batch over large problems that may not fit in memory/GPU + filter_ind_count = len(self) + indices = np.arange(filter_ind_count) + if kwargs.get("indices", None) is not None: + indices = kwargs["indices"] + filter_ind_count = len(indices) + if (omega.shape[-1] * filter_ind_count) > self.max_size: + # Create an empty result array + # For large (2D) problems this will not fit on a GPU + h = np.empty((filter_ind_count, omega.shape[-1]), dtype=np.float64) + # Batch over filter indices + for i in trange( + 0, filter_ind_count, self.batch_size, desc="Filter evaluation" + ): + # Form the subset of filter indices + s = slice(i, min(filter_ind_count, i + self.batch_size)) + kwargs["indices"] = indices[s] + # Evaluate filter for the subset of indices and assign + h[s] = xp.asnumpy(self._evaluate(omega, **kwargs)) + else: + # Compute as one problem + h = self._evaluate(omega, **kwargs) if self.radial: - h = np.take(h, idx) + # The reshape and take axis gynmastics work to provide the + # legacy idx taking functionality for both singleton and + # stack cases. + # reshape (stack, vals) + h = h.reshape(len(self), -1) + # keep stack, take along vals axis + h = np.take(h, idx, axis=-1) + # squeeze off stack dim from singleton case (preserves legacy behavior) + if h.shape[0] == 1: # avoid error when len(h)>1 + h = np.squeeze(h, axis=0) return h @@ -112,7 +152,8 @@ def scale(self, c=1): """ return ScaledFilter(self, c) - @lru_cache(maxsize=config["cache"]["filter_cache_size"].get()) # noqa: B019 + # this cache no longer appears to work? now unhashable? (fine but why not before?) + # @lru_cache(maxsize=config["cache"]["filter_cache_size"].get()) # noqa: B019 def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs): """ Generates a two dimensional grid with prescribed dtype, @@ -144,6 +185,22 @@ def sign(self): """ return LambdaFilter(self, np.sign) + def __len__(self): + """ + Default filters are length 1. + + Some filters (eg CTFFilter) provide additional optimizations for stacks. + """ + return 1 + + def _ctf_params(self): + """ + Return n_filters-by-n_param array from prior filter. + """ + raise NotImplementedError( + f"_ctf_params not implemented for {self.__class__.__name__}" + ) + class DualFilter(Filter): """ @@ -157,6 +214,15 @@ def __init__(self, filter_in): def evaluate(self, omega, **kwargs): return self._filter.evaluate(-omega, **kwargs) + def __len__(self): + return len(self._filter) + + def _ctf_params(self): + """ + Return n_filters-by-n_param array from prior filter. + """ + return self._filter._ctf_params() + class FunctionFilter(Filter): """ @@ -235,6 +301,18 @@ def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs): return filter_vals**self._power + def __len__(self): + return len(self._filter) + + def _ctf_params(self): + """ + Return n_filters-by-n_param array from prior filter. + """ + return self._filter._ctf_params() + + def __getitem__(self, item): + return PowerFilter(self._filter[item], power=self._power, epsilon=self._epsilon) + class LambdaFilter(Filter): """ @@ -249,6 +327,21 @@ def __init__(self, filter, f): def _evaluate(self, omega, **kwargs): return self._f(self._filter.evaluate(omega, **kwargs)) + def __len__(self): + """ + Return length of underlying filter stack + """ + return len(self._filter) + + def __getitem__(self, item): + return LambdaFilter(self._filter[item], self._f) + + def _ctf_params(self): + """ + Return n_filters-by-n_param array from prior filter. + """ + return self._filter._ctf_params() + class MultiplicativeFilter(Filter): """ @@ -258,13 +351,57 @@ class MultiplicativeFilter(Filter): def __init__(self, *args): super().__init__(dim=args[0].dim, radial=all(c.radial for c in args)) self._components = args + self._init_size() + + def _init_size(self): + """ + Check sizes of _components are coherent and initialize resulting length. + """ + filter_lengths = [len(f) for f in self._components] + filter_lengths = np.unique(filter_lengths) # defaults to sorted=True + + # Code should be able to broadcast n_filters with n_filters, or n_filters with 1_filters. + # Any other combination is considered an error. + if len(filter_lengths) > 2 or ( + (len(filter_lengths) == 2) and (filter_lengths[0] != 1) + ): + raise RuntimeError(f"Incoherent filter lengths {filter_lengths}") + # filter_lengths is sorted, so this should be the larger of two values, + # or the single value in the 1_filter case + self._n = filter_lengths[-1] def _evaluate(self, omega, **kwargs): res = 1 for c in self._components: - res *= c.evaluate(omega, **kwargs) + res = res * c.evaluate(omega, **kwargs) return res + def __len__(self): + return self._n + + def _ctf_params(self): + """ + Return n_filters-by-n_param array from prior filter. + + Raises error if multiple or none found. + """ + _params = [] + for c in self._components: + try: + _params.append(c._ctf_params()) + except NotImplementedError: + pass + + if len(_params) > 1: + raise RuntimeError("Multiple filters with CTF parameters found.") + elif len(_params) == 0: + raise RuntimeError("No CTF parameters found.") + + return _params[0] + + def __getitem__(self, item): + return MultiplicativeFilter(*list(c[item] for c in self._components)) + class ScaledFilter(Filter): """ @@ -287,6 +424,18 @@ def __str__(self): """ return f"ScaledFilter (scales {self._filter} by {self._scale})" + def __getitem__(self, item): + return ScaledFilter(self._filter[item], self._scale) + + def __len__(self): + return len(self._filter) + + def _ctf_params(self): + """ + Return n_filters-by-n_param array from prior filter. + """ + return self._filter._ctf_params() + class ArrayFilter(Filter): def __init__(self, xfer_fn_array): @@ -383,6 +532,11 @@ def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs): res = super().evaluate_grid(L, *args, dtype=dtype, **kwargs) return res + def __getitem__(self, item): + # Note, could extend to a stack dimension and lookup. + # For now, we have no use case for that. + return self + class ScalarFilter(Filter): def __init__(self, dim=None, value=1): @@ -395,6 +549,9 @@ def __repr__(self): def _evaluate(self, omega, **kwargs): return self.value * np.ones_like(omega) + def __getitem__(self, item): + return self + class ZeroFilter(ScalarFilter): def __init__(self, dim=None): @@ -438,21 +595,85 @@ def __init__( :param alpha: Amplitude contrast phase in radians :param B: Envelope decay in inverse square angstrom (default 0) """ - super().__init__(dim=2, radial=defocus_u == defocus_v) - self.voltage = voltage - self.wavelength = voltage_to_wavelength(self.voltage) - self.defocus_u = defocus_u - self.defocus_v = defocus_v - self.defocus_ang = defocus_ang - self.Cs = Cs - self.alpha = alpha - self.B = B + super().__init__(dim=2, radial=np.all(defocus_u == defocus_v)) + voltage = np.atleast_1d(voltage) # maybe allow singleton here for V + defocus_u = np.atleast_1d(defocus_u) + defocus_v = np.atleast_1d(defocus_v) + defocus_ang = np.atleast_1d(defocus_ang) + Cs = np.atleast_1d(Cs) + alpha = np.atleast_1d(alpha) + B = np.atleast_1d(B) + + # TODO check all sizes match + self.n = max( + len(voltage), + len(defocus_u), + len(defocus_v), + len(defocus_ang), + len(Cs), + len(alpha), + len(B), + ) + + self.voltage = self._to_full(voltage) + self.defocus_u = self._to_full(defocus_u) + self.defocus_v = self._to_full(defocus_v) + self.defocus_ang = self._to_full(defocus_ang) + self.Cs = self._to_full(Cs) + self.alpha = self._to_full(alpha) + self.B = self._to_full(B) + + # derived value + # todo, check/fix broadcast in voltage_to_wavelength + self.wavelength = np.array([voltage_to_wavelength(v) for v in self.voltage]) + + def _to_full(self, vals): + if len(vals) == self.n: + return vals + elif len(vals) == 1: + return np.full(self.n, fill_value=vals[0]) + else: + raise RuntimeError("dont do that") + + def __getitem__(self, items): + return CTFFilter( + self.voltage[items], + # self.wavelength[items], + self.defocus_u[items], + self.defocus_v[items], + self.defocus_ang[items], + self.Cs[items], + self.alpha[items], + self.B[items], + ) - # Convert angstrom to nm and divide by 2 - self._defocus_mean_nm = 0.05 * (self.defocus_u + self.defocus_v) - self._defocus_diff_nm = 0.05 * (self.defocus_u - self.defocus_v) + def __len__(self): + """ + Return stack length + """ + return self.n + + def _ctf_params(self): + """ + Return n_filters-by-n_param array. + """ + return np.array( + [ + self.voltage, + self.defocus_u, + self.defocus_v, + self.defocus_ang, + self.Cs, + self.alpha, + self.B, + ] + ).T def _evaluate(self, omega, **kwargs): + indices = kwargs.get("indices", None) + if indices is None: + indices = np.arange(self.n) + # Ensure we have a pixel size, pixel_size = kwargs.get("pixel_size", None) if pixel_size is None: @@ -462,6 +683,22 @@ def _evaluate(self, omega, **kwargs): # and that it is a floating point value. pixel_size = float(pixel_size) + return self.ctf_formula( + omega, + pixel_size, + self.voltage[indices], + self.defocus_u[indices], + self.defocus_v[indices], + self.defocus_ang[indices], + self.Cs[indices], + self.alpha[indices], + self.B[indices], + ) + + @staticmethod + def ctf_formula( + omega, pixel_size, voltage, defocus_u_a, defocus_v_a, defocus_ang, Cs, alpha, B + ): # Reference MATLAB code, includes reference to paper # Mindell, J. A.; Grigorieff, N. (2003). # https://github.com/PrincetonUniversity/aspire/blob/760a43b35453e55ff2d9354339e9ffa109a25371/projections/cryo_CTF_Relion.m#L34 @@ -472,6 +709,26 @@ def _evaluate(self, omega, **kwargs): # and further rescale the radii `s` by half below. # # Additionally we upcast so downstream computations remain in doubles. + + # First prepare arrays for broadcasting. + voltage = xp.atleast_1d(voltage)[:, None] + defocus_u_a = xp.atleast_1d(defocus_u_a)[:, None] + defocus_v_a = xp.atleast_1d(defocus_v_a)[:, None] + defocus_ang = xp.atleast_1d(defocus_ang)[:, None] + Cs = xp.atleast_1d(Cs)[:, None] + alpha = xp.atleast_1d(alpha)[:, None] + B = xp.atleast_1d(B)[:, None] + + # If omega is on GPU, move the params over + if isinstance(omega, xp.ndarray) and not isinstance(omega, np.ndarray): + voltage = xp.asarray(voltage) + defocus_u_a = xp.array(defocus_u_a) + defocus_v_a = xp.array(defocus_v_a) + defocus_ang = xp.array(defocus_ang) + Cs = xp.array(Cs) + alpha = xp.array(alpha) + B = xp.array(B) + x, y = omega.astype(np.float64, copy=False) / np.pi # Returns radii such that when multiplied by the @@ -479,32 +736,54 @@ def _evaluate(self, omega, **kwargs): # corresponding to each pixel in our nxn grid. theta, s = cart2pol(x, y) s = s / 2 + theta = theta[None, :] + s = s[None, :] # Wavelength in nm. - lamb = 1.22639 / np.sqrt(self.voltage * 1000 + 0.97845 * self.voltage**2) + lamb = 1.22639 / np.sqrt(voltage * 1000 + 0.97845 * voltage**2) # Divide by 10 to make pixel size in nm. BW is the # bandwidth of the signal corresponding to the given pixel size. BW = 1 / (pixel_size / 10) s = s * BW - DFavg = self._defocus_mean_nm # (DefocusU+DefocusV)/2 - DFdiff = self._defocus_diff_nm # (DefocusU-DefocusV) + DFavg = 0.05 * (defocus_u_a + defocus_v_a) # (u+v)/2 * 1nm/10A + DFdiff = 0.05 * (defocus_u_a - defocus_v_a) # (u-v)/2 * 1nm/10A # Note division by 2 is pre-computed in _defocus_diff_nm - df = DFavg + DFdiff * np.cos(2 * (theta - self.defocus_ang)) - + df = DFavg + DFdiff * np.cos(2 * (theta - defocus_ang)) k2 = np.pi * lamb * df # 10*6 converts Cs from mm to nm. - k4 = np.pi / 2 * 10**6 * self.Cs * lamb**3 + k4 = np.pi / 2 * 10**6 * Cs * lamb**3 chi = k4 * s**4 - k2 * s**2 - h = np.sqrt(1 - self.alpha**2) * np.sin(chi) - self.alpha * np.cos(chi) + alpha = alpha + h = np.sqrt(1 - alpha**2) * np.sin(chi) - alpha * np.cos(chi) - if self.B: - h *= np.exp(-self.B * s**2) + if np.any(B): + h *= np.exp(-B * s**2) return h + def to_radial(self): + """ + Return a new `RadialCTFFilter` with the mean of astigmatic defocus values. + + :return: `RadialCTFFilter` + """ + mean_defocus = (self.defocus_u + self.defocus_v) / 2 + return RadialCTFFilter( + voltage=self.voltage, + defocus=mean_defocus, + Cs=self.Cs, + alpha=self.alpha, + B=self.B, + ) + + def __eq__(self, other): + if len(self) != len(other): + return False + return self._ctf_params() == other._ctf_params() + class RadialCTFFilter(CTFFilter): def __init__(self, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0): @@ -518,6 +797,14 @@ def __init__(self, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0): B=B, ) + def to_radial(self): + """ + `RadialCTFFilter` is already radial, Returns self. Supports code interop. + + :return: `RadialCTFFilter` + """ + return self + class BlueFilter(Filter): """ diff --git a/src/aspire/source/coordinates.py b/src/aspire/source/coordinates.py index 83d69b72b8..1eacfdc96c 100644 --- a/src/aspire/source/coordinates.py +++ b/src/aspire/source/coordinates.py @@ -176,7 +176,7 @@ def __init__( # set CTF metadata to defaults # this can be updated with import_ctf() self.filter_indices = np.zeros(self.n, dtype=int) - self.unique_filters = [IdentityFilter()] + self.filter_stack = IdentityFilter() self.set_metadata("__filter_indices", np.zeros(self.n, dtype=int)) # populate __mrc_filename and __mrc_index @@ -406,25 +406,15 @@ def _extract_ctf(self, data_block): ) # convert defocus_ang from degrees to radians - filter_params[:, 3] *= np.pi / 180.0 + filter_params[:, 3] = np.deg2rad(filter_params[:, 3]) # Warn if CTF pixel_sizes do match self.pixel_size ctf_pixel_sizes = np.unique(filter_params[:, 6]) check_pixel_size(ctf_pixel_sizes, self.pixel_size) # construct filters - self.unique_filters = [ - CTFFilter( - voltage=filter_params[i, 0], - defocus_u=filter_params[i, 1], - defocus_v=filter_params[i, 2], - defocus_ang=filter_params[i, 3], - Cs=filter_params[i, 4], - alpha=filter_params[i, 5], - B=self.B, - ) - for i in range(len(filter_params)) - ] + # drop pixel size column + self.filter_stack = CTFFilter(*(filter_params[:, :6]).T, B=self.B) # set metadata for mrc_idx, filter_index in enumerate(indices): diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 0660f402c5..8113ae7073 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -22,13 +22,7 @@ Pipeline, ) from aspire.noise import LegacyNoiseEstimator, NoiseEstimator, WhiteNoiseEstimator -from aspire.operators import ( - CTFFilter, - Filter, - IdentityFilter, - MultiplicativeFilter, - PowerFilter, -) +from aspire.operators import CTFFilter, Filter, MultiplicativeFilter, PowerFilter from aspire.storage import MrcStats, StarFile from aspire.utils import ( Rotation, @@ -214,7 +208,7 @@ def __init__( self._populate_pixel_size(pixel_size) self._populate_symmetry_group(symmetry_group) - self.unique_filters = [] + self.filter_stack = None self.generation_pipeline = Pipeline(xforms=None, memory=memory) logger.info(f"Creating {self.__class__.__name__} with {len(self)} images.") @@ -443,7 +437,10 @@ def n_ctf_filters(self): """ Return the number of CTFFilters found in this Source. """ - return len([f for f in self.unique_filters if isinstance(f, CTFFilter)]) + n = 0 + if isinstance(self.filter_stack, CTFFilter): + n = len(self.filter_stack) + return n @property def states(self): @@ -785,6 +782,11 @@ def _apply_filters( im = im_orig.copy() + if filters is None: + return im + + # else evaluate filters + # TODO broadcast filter eval for i, filt in enumerate(filters): idx_k = np.where(indices == i)[0] if len(idx_k) > 0: @@ -795,7 +797,7 @@ def _apply_filters( def _apply_source_filters(self, im_orig, indices): return self._apply_filters( im_orig, - self.unique_filters, + self.filter_stack, self.filter_indices[indices], ) @@ -869,7 +871,8 @@ def downsample(self, L, zero_nyquist=True, centered_fft=True): ) ds_factor = self.L / L - self.unique_filters = [f.scale(ds_factor) for f in self.unique_filters] + if self.filter_stack is not None: + self.filter_stack = self.filter_stack.scale(ds_factor) if self.pixel_size is not None: self.pixel_size *= ds_factor @@ -955,11 +958,12 @@ def whiten(self, noise_estimate=None, epsilon=None): logger.info("Whitening source object") whiten_filter = PowerFilter(noise_filter, power=-0.5, epsilon=epsilon) - logger.info("Transforming all CTF Filters into Multiplicative Filters") - self.unique_filters = [ - MultiplicativeFilter(f, whiten_filter) for f in self.unique_filters - ] - logger.info("Adding Whitening Filter Xform to end of generation pipeline") + logger.info("Extending filter stack by whitening filter") + if self.filter_stack is not None: + self.filter_stack = MultiplicativeFilter(self.filter_stack, whiten_filter) + else: + self.filter_stack = whiten_filter + logger.info("Adding whitening FilterXform to end of generation pipeline") self.generation_pipeline.add_xform(FilterXform(whiten_filter)) @_as_copy @@ -993,6 +997,16 @@ def legacy_whiten(self, noise_response=None, delta=None, batch_size=512): if delta is None: delta = np.finfo(np.float32).eps + # # TODO This "should be better" but totally breaks things. + # # First guess would be to check the strange normalization. + # # Can't fix everything at once. + # logger.info(f"Extending filter stack by legacy whitening Filter") + # whiten_filter = ArrayFilter(psd) + # if self.filter_stack is not None: + # self.filter_stack = MultiplicativeFilter(self.filter_stack, whiten_filter) + # else: + # self.filter_stack = whiten_filter + logger.info("Adding LegacyWhiten Filter Xform to end of generation pipeline.") self.generation_pipeline.add_xform(LegacyWhiten(psd, delta)) @@ -1005,8 +1019,8 @@ def phase_flip(self): logger.info("Perform phase flip on source object") - if len(self.unique_filters) >= 1: - unique_xforms = [FilterXform(f.sign) for f in self.unique_filters] + if self.filter_stack is not None: + unique_xforms = FilterXform(self.filter_stack.sign) logger.info("Adding Phase Flip Xform to end of generation pipeline") self.generation_pipeline.add_xform( @@ -1781,18 +1795,22 @@ def __init__(self, src, indices, memory=None): pixel_size=src.pixel_size, ) - if src.unique_filters: + if src.filter_stack is not None: # Remap the filter indices to be unique. # Removes duplicates and filters that are unused in new source. _filter_indices = src.filter_indices[self.index_map] # _unq[_inv] reconstructs _filter_indices _unq, _inv = np.unique(_filter_indices, return_inverse=True) - # Repack unique_filters + # Repack filter_stack self.filter_indices = _inv - self.unique_filters = [copy.copy(src.unique_filters[i]) for i in _unq] + # This would work by slicing with current code, + # but if future code mutated the filter objects, that would be a problem. + # Copy for safety/intent. + # Deep copy may be required if future code mutates underlying objects. + self.filter_stack = copy.copy(src.filter_stack[_unq]) else: # Pass through the None case - self.unique_filters = src.unique_filters + self.filter_stack = src.filter_stack self.filter_indices = np.zeros(self.n, dtype=int) # Any further operations should not mutate this instance. @@ -2029,7 +2047,7 @@ def __init__( # Create filter indices, these are required to pass unharmed through filter eval code # that is potentially called by other methods later. self.filter_indices = np.zeros(self.n, dtype=int) - self.unique_filters = [IdentityFilter()] + self.filter_stack = None # Optionally populate angles/rotations. if angles is not None: diff --git a/src/aspire/source/micrograph.py b/src/aspire/source/micrograph.py index a0e11e61ae..6a1b386f90 100644 --- a/src/aspire/source/micrograph.py +++ b/src/aspire/source/micrograph.py @@ -393,10 +393,7 @@ def __init__( self.filter_indices = None if ctf_filters is not None: acceptable_lens = [1, self.micrograph_count, self.total_particle_count] - if ( - not isinstance(ctf_filters, list) - or len(ctf_filters) not in acceptable_lens - ): + if len(ctf_filters) not in acceptable_lens: raise TypeError( f"`ctf_filters` expects a list of len {acceptable_lens[0]}," f" {acceptable_lens[1]}, or {acceptable_lens[2]}." @@ -424,7 +421,7 @@ def __init__( offsets=0, amplitudes=self.particle_amplitudes, angles=self.projection_angles, - unique_filters=ctf_filters, + filter_stack=ctf_filters, filter_indices=self.filter_indices, pixel_size=self.pixel_size, dtype=self.dtype, @@ -699,7 +696,7 @@ def save(self, path, name_prefix="micrograph", overwrite=True): # CTF ctf_metadata = dict() - if self.simulation.unique_filters: + if self.simulation.filter_stack: ctf_metadata = self.simulation.get_metadata( metadata_fields=_meta_fields["ctf"], indices=self.get_particle_indices(m), diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index 963e7ae427..fd37e7d513 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -146,23 +146,12 @@ def __init__( return_inverse=True, axis=0, ) - filters = [] - # for each unique CTF configuration, create a CTFFilter object - for row in filter_params: - filters.append( - CTFFilter( - voltage=row[0], - defocus_u=row[1], - defocus_v=row[2], - defocus_ang=row[3] * np.pi / 180, # degrees to radians - Cs=row[4], - alpha=row[5], - B=B, - ) - ) - self.unique_filters = filters + # Convert `defocus_ang` from degrees to radians + filter_params[:, 3] = np.deg2rad(filter_params[:, 3]) + # Create a CTFFilter stack + self.filter_stack = CTFFilter(*filter_params.T, B=B) # filter_indices stores, for each particle index, the index in - # self.unique_filters of the filter that should be applied + # self.filter_stack of the filter that should be applied self.filter_indices = filter_indices # If we detect ASPIRE added dummy variables, log and initialize identity filter @@ -170,7 +159,7 @@ def __init__( logger.info( "Detected ASPIRE-generated dummy optics; initializing identity filters." ) - self.unique_filters = [IdentityFilter()] + self.filter_stack = IdentityFilter() self.filter_indices = np.zeros(self.n, dtype=int) # We have provided some, but not all the required params @@ -182,7 +171,7 @@ def __init__( # If no CTF info in STAR, we initialize the filter values of metadata with default values else: - self.unique_filters = [IdentityFilter()] + self.filter_stack = IdentityFilter() self.filter_indices = np.zeros(self.n, dtype=int) logger.info(f"Populated {self.n_ctf_filters} CTFFilters from '{filepath}'") diff --git a/src/aspire/source/simulation.py b/src/aspire/source/simulation.py index b5de869a67..7c13366865 100644 --- a/src/aspire/source/simulation.py +++ b/src/aspire/source/simulation.py @@ -22,7 +22,7 @@ class Simulation(ImageSource): `metadata`. The images are generated via projections of a supplied `Volume` object, `vols`, over orientations define by the Euler angles, `angles`. Various types of corruption, such as noise and CTF effects, can be added to the images by supplying a `Filter` object to the `noise_filter` or - `unique_filters` arguments. + `filter_stack` arguments. """ def __init__( @@ -31,7 +31,7 @@ def __init__( n=1024, vols=None, states=None, - unique_filters=None, + filter_stack=None, filter_indices=None, offsets=None, amplitudes=None, @@ -54,9 +54,9 @@ def __init__( Default is generated with `volume.volume_synthesis.AsymmetricVolume`. :param states: A 1d array of n integers in the interval [0, C). The i'th integer indicates the volume stack index used to produce the i'th projection image. Default is a random set. - :param unique_filters: A list of Filter objects to be applied to projection images. + :param filter_stack: A Filter object to be applied to projection images. :param filter_indices: A 1d array of n integers indicating the `unique_filter` indices associated - with each image. Default is a random set of filter indices, .ie the filters from `unique_filters` + with each image. Default is a random set of filter indices, .ie the filters from `filter_stack` are randomly assigned to the stack of images. :param offsets: A n-by-2 array of coordinates to offset the images. Default is a normally distributed set of offsets. Set `offsets = 0` to disable offsets. @@ -158,17 +158,15 @@ def __init__( self.angles = self._init_angles(angles) - if unique_filters is None: - unique_filters = [] - self.unique_filters = unique_filters + self.filter_stack = filter_stack # sim_filters must be a deep copy so that it is not changed - # when unique_filters is changed - self.sim_filters = copy.deepcopy(unique_filters) + # when filter_stack is changed + self.sim_filters = copy.deepcopy(filter_stack) # Create filter indices and fill the metadata based on unique filters - if unique_filters: + if filter_stack is not None: if filter_indices is None: - filter_indices = randi(len(unique_filters), n, seed=seed) - 1 + filter_indices = randi(len(filter_stack), n, seed=seed) - 1 self._populate_ctf_metadata(filter_indices) self.filter_indices = filter_indices else: @@ -220,34 +218,32 @@ def _populate_ctf_metadata(self, filter_indices): # for these columns # # class attributes of CTFFilter: - CTFFilter_attributes = ( - "voltage", - "defocus_u", - "defocus_v", - "defocus_ang", - "Cs", - "alpha", - ) - - # get the CTF parameters, if they exist, for each filter - # and for each image (indexed by filter_indices) - filter_values = np.zeros((len(filter_indices), len(CTFFilter_attributes))) - for i, filt in enumerate(self.unique_filters): - filter_values[filter_indices == i] = [ - getattr(filt, att, np.nan) for att in CTFFilter_attributes - ] + CTFFilter_attributes = [ + "_rlnVoltage", + "_rlnDefocusU", + "_rlnDefocusV", + "_rlnDefocusAngle", + "_rlnSphericalAberration", + "_rlnAmplitudeContrast", + ] + + # Unpack the `filter_stack` params across images using `filter_indices` mapping + # Note this does not include the B factor term (hence the truncation) + # B factor term looks unique to ASPIRE, should we add to star if used? + filter_stack_params = self.filter_stack._ctf_params()[ + :, :6 + ] # params per filter + image_filter_values = np.zeros( + (len(filter_indices), len(CTFFilter_attributes)) + ) # params per image + for i, params in enumerate(filter_stack_params): + # assign `params` to all matching images + image_filter_values[filter_indices == i] = params # set the corresponding Relion metadata values that we would expect # from a STAR file self.set_metadata( - [ - "_rlnVoltage", - "_rlnDefocusU", - "_rlnDefocusV", - "_rlnDefocusAngle", - "_rlnSphericalAberration", - "_rlnAmplitudeContrast", - ], - filter_values, + CTFFilter_attributes, + image_filter_values, ) @property diff --git a/src/aspire/utils/units.py b/src/aspire/utils/units.py index fd40864795..97d668f579 100644 --- a/src/aspire/utils/units.py +++ b/src/aspire/utils/units.py @@ -3,7 +3,6 @@ """ import logging -import math import numpy as np @@ -43,7 +42,7 @@ def voltage_to_wavelength(voltage): a = float(12.264259661581491) b = float(0.9784755917869367) - return a / math.sqrt(voltage * 1e3 + b * voltage**2) + return a / np.sqrt(voltage * 1e3 + b * voltage**2) def wavelength_to_voltage(wavelength): @@ -56,4 +55,4 @@ def wavelength_to_voltage(wavelength): a = float(12.264259661581491) b = float(0.9784755917869367) - return (-1e3 + math.sqrt(1e6 + 4 * a**2 * b / wavelength**2)) / (2 * b) + return (-1e3 + np.sqrt(1e6 + 4 * a**2 * b / wavelength**2)) / (2 * b) diff --git a/tests/test_FFBbasis2D.py b/tests/test_FFBbasis2D.py index e0132971cf..b21ecb9c11 100644 --- a/tests/test_FFBbasis2D.py +++ b/tests/test_FFBbasis2D.py @@ -7,8 +7,9 @@ from aspire.basis import Coef, FFBBasis2D from aspire.nufft import all_backends +from aspire.operators import RadialCTFFilter from aspire.source import Simulation -from aspire.utils.misc import grid_2d +from aspire.utils import grid_2d, utest_tolerance from aspire.volume import Volume from ._basis_util import Steerable2DMixin, UniversalBasisMixin, basis_params_2d @@ -161,3 +162,52 @@ def testHighResFFBBasis2D(L, dtype): np.testing.assert_allclose( im_ffb.asnumpy()[0][mask], im.asnumpy()[0][mask], rtol=1e-05, atol=1e-4 ) + + +def test_bulk_expand_radial_vec(): + """ + For a given stack of radial vectors (such as from + RadialCTFFilters) `expand_radial_vec` should return equivalent + result as calling filter_stack_to_basis_mats on filter_stack, + and calling filter_to_basis_mat on each filter. + """ + + L = 32 + dtype = np.float32 + basis = FFBBasis2D(L, dtype=dtype) + pixel_size = 1.23 + + filters = RadialCTFFilter(defocus=np.linspace(10000, 15000, 3)) + stack_references = basis.filter_stack_to_basis_mats(filters, pixel_size=pixel_size) + references = [basis.filter_to_basis_mat(f, pixel_size=pixel_size) for f in filters] + + # Stack of all filter params + params = filters._ctf_params() + # Stack of filter values + _filter_vals = RadialCTFFilter.ctf_formula( + basis._filter_pts, pixel_size, *(params.T) + ) + + # Stack result + results = basis.expand_radial_vec(_filter_vals) + # Sequential result + results2 = [basis.expand_radial_vec(f)[0] for f in _filter_vals] + + # expand_radial_vec should be same as itself called sequentially + assert len(results2) == len(results) + for res, ref in zip(results2, results): + np.testing.assert_allclose(res.dense(), ref.dense()) + + # and should be equivalent to calling filter_to_basis_mat (for this radial case) + assert len(results) == len(references) + for res, ref in zip(results, references): + np.testing.assert_allclose( + res.dense(), ref.dense(), atol=utest_tolerance(dtype) + ) + + # The list from filter_to_basis_mat should be equivalent to list filter_stack_to_basis_mats + assert len(references) == len(stack_references) + for res, ref in zip(stack_references, references): + np.testing.assert_allclose( + res.dense(), ref.dense(), atol=utest_tolerance(dtype) + ) diff --git a/tests/test_anisotropic_noise.py b/tests/test_anisotropic_noise.py index ee699cb20f..ccfc824e8d 100644 --- a/tests/test_anisotropic_noise.py +++ b/tests/test_anisotropic_noise.py @@ -19,9 +19,7 @@ def setUp(self): self.sim = _LegacySimulation( n=1024, vols=self.vol, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), dtype=self.dtype, ) diff --git a/tests/test_array_image_source.py b/tests/test_array_image_source.py index a2c3ee4f0c..fc2735d15c 100644 --- a/tests/test_array_image_source.py +++ b/tests/test_array_image_source.py @@ -10,7 +10,6 @@ from aspire.basis import FBBasis3D from aspire.image import Image -from aspire.operators import IdentityFilter from aspire.reconstruction import MeanEstimator from aspire.source import ArrayImageSource, RelionSource, Simulation from aspire.utils import Rotation, utest_tolerance @@ -31,7 +30,6 @@ def setUp(self): self.sim = sim = Simulation( n=self.n, L=self.resolution, - unique_filters=[IdentityFilter()], seed=0, dtype=self.dtype, # We'll use random angles diff --git a/tests/test_batched_covar2d.py b/tests/test_batched_covar2d.py index f239354700..70906425d7 100644 --- a/tests/test_batched_covar2d.py +++ b/tests/test_batched_covar2d.py @@ -34,7 +34,7 @@ def setUp(self): self.src = Simulation( L, n, - unique_filters=self.filters, + filter_stack=self.filters, pixel_size=5, dtype=self.dtype, noise_adder=noise_adder, @@ -253,10 +253,9 @@ class BatchedRotCov2DTestCaseCTF(BatchedRotCov2DTestCase): @property def filters(self): - return [ - RadialCTFFilter(200, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(1.5e4, 2.5e4, 7) - ] + return RadialCTFFilter( + 200, defocus=np.linspace(1.5e4, 2.5e4, 7), Cs=2.0, alpha=0.1 + ) @property def ctf_idx(self): @@ -266,5 +265,5 @@ def ctf_idx(self): def ctf_basis(self): return [ self.basis.filter_to_basis_mat(f, pixel_size=self.src.pixel_size) - for f in self.src.unique_filters + for f in self.src.filter_stack ] diff --git a/tests/test_coordinate_source.py b/tests/test_coordinate_source.py index 3b0f29f1ee..4838f75a9d 100644 --- a/tests/test_coordinate_source.py +++ b/tests/test_coordinate_source.py @@ -621,42 +621,39 @@ def testImportCtfFromRelionLegacy(self): def _testCtfFilters(self, src, uniform_pixel_sizes=True): # there are two micrographs and two CTF files, so there should be two # unique CTF filters - self.assertEqual(len(src.unique_filters), 2) + self.assertEqual(len(src.filter_stack), 2) # test the properties of the CTF filters # based on the arbitrary values we added to the CTF files # note these values are not realistic - filter0 = src.unique_filters[0] - self.assertTrue( - np.allclose( - np.array( - [ - 1000.0, - 900.0, - 800.0 * np.pi / 180.0, - 700.0, - 600.0, - 500.0, - ], - dtype=src.dtype, - ), - np.array( - [ - filter0.defocus_u, - filter0.defocus_v, - filter0.defocus_ang, - filter0.Cs, - filter0.alpha, - filter0.voltage, - ] - ), - ) - ) - filter1 = src.unique_filters[1] + filter0 = src.filter_stack[0] + np.testing.assert_allclose( + np.array( + [ + 1000.0, + 900.0, + 800.0 * np.pi / 180.0, + 700.0, + 600.0, + 500.0, + ], + dtype=src.dtype, + ), + np.array( + [ + filter0.defocus_u, + filter0.defocus_v, + filter0.defocus_ang, + filter0.Cs, + filter0.alpha, + filter0.voltage, + ] + ).flatten(), + ) + filter1 = src.filter_stack[1] pixel_size1 = self.pixel_size if not uniform_pixel_sizes: pixel_size1 += 0.01 - self.assertTrue( - np.allclose( + np.testing.assert_allclose( np.array( [ 1001.0, @@ -677,9 +674,8 @@ def _testCtfFilters(self, src, uniform_pixel_sizes=True): filter1.alpha, filter1.voltage, ] - ), + ).flatten(), ) - ) # the first 200 particles should correspond to the first filter # since they came from the first micrograph self.assertTrue( diff --git a/tests/test_covar2d.py b/tests/test_covar2d.py index 66e88c4ff8..5630606f30 100644 --- a/tests/test_covar2d.py +++ b/tests/test_covar2d.py @@ -79,22 +79,20 @@ def cov2d_fixture(volume, basis, ctf_enabled): n = 32 # Default CTF params - unique_filters = None + filters = None h_idx = None h_ctf_fb = None # Popluate CTF if ctf_enabled: - unique_filters = [ - RadialCTFFilter(200, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(1.5e4, 2.5e4, 7) - ] + filters = RadialCTFFilter( + 200, defocus=np.linspace(1.5e4, 2.5e4, 7), Cs=2.0, alpha=0.1 + ) # Copied from simulation defaults to match legacy test files. - h_idx = randi(len(unique_filters), n, seed=0) - 1 + h_idx = randi(len(filters), n, seed=0) - 1 h_ctf_fb = [ - basis.filter_to_basis_mat(f, pixel_size=volume.pixel_size) - for f in unique_filters + basis.filter_to_basis_mat(f, pixel_size=volume.pixel_size) for f in filters ] noise_adder = WhiteNoiseAdder(var=NOISE_VAR) @@ -102,7 +100,7 @@ def cov2d_fixture(volume, basis, ctf_enabled): sim = _LegacySimulation( n=n, vols=volume, - unique_filters=unique_filters, + filter_stack=filters, filter_indices=h_idx, offsets=0.0, amplitudes=1.0, diff --git a/tests/test_covar2d_denoiser.py b/tests/test_covar2d_denoiser.py index 8a4660c5c6..46db43a3ee 100644 --- a/tests/test_covar2d_denoiser.py +++ b/tests/test_covar2d_denoiser.py @@ -4,7 +4,7 @@ from aspire.basis import FBBasis2D, FFBBasis2D, FLEBasis2D, FPSWFBasis2D, PSWFBasis2D from aspire.denoising import DenoisedSource, DenoiserCov2D from aspire.noise import WhiteNoiseAdder -from aspire.operators import IdentityFilter, RadialCTFFilter +from aspire.operators import CTFFilter, IdentityFilter, RadialCTFFilter from aspire.source import Simulation from aspire.utils import utest_tolerance @@ -15,10 +15,10 @@ noise_var = 0.1848 noise_adder = WhiteNoiseAdder(var=noise_var) pixel_size = 5 -filters = [ - RadialCTFFilter(200, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(1.5e4, 2.5e4, 7) -] +d = np.linspace(1.5e4, 2.5e4, 7) +filters = CTFFilter( + 200, defocus_ang=np.pi / 3, defocus_u=d, defocus_v=d + 345, Cs=2.0, alpha=0.1 +) # For (F)PSWFBasis2D we get off-block entries which are truncated # when converting to block-diagonal. We filter these warnings. @@ -59,7 +59,7 @@ def sim(): sim = Simulation( L=img_size, n=num_imgs, - unique_filters=filters, + filter_stack=filters, offsets=0.0, amplitudes=1.0, dtype=dtype, @@ -222,15 +222,17 @@ def test_filter_to_basis_mat_id_expand(coef, basis): # IdentityFilter should produce id filt = IdentityFilter() - # Some basis do not provide alternative `method`s + # Some basis do not provide alternative `expand_method`s if isinstance(basis, FFBBasis2D) or isinstance(basis, FLEBasis2D): with pytest.raises(NotImplementedError, match=r".*not supported.*"): - _ = basis.filter_to_basis_mat(filt, method="expand") + _ = basis.filter_to_basis_mat(filt, expand_method="expand") return # Apply the basis filter operator. # Note transpose because `apply` expects and returns column vectors. - coef_ftbm = (basis.filter_to_basis_mat(filt, method="expand") @ coef.asnumpy().T).T + coef_ftbm = ( + basis.filter_to_basis_mat(filt, expand_method="expand") @ coef.asnumpy().T + ).T # Apply evaluate->filter->expand manually imgs = coef.evaluate() @@ -250,4 +252,4 @@ def test_filter_to_basis_mat_id_expand(coef, basis): def test_filter_to_basis_mat_bad(coef, basis): filt = IdentityFilter() with pytest.raises(NotImplementedError, match=r".*not supported.*"): - _ = basis.filter_to_basis_mat(filt, method="bad_method") + _ = basis.filter_to_basis_mat(filt, expand_method="bad_method") diff --git a/tests/test_covar3d.py b/tests/test_covar3d.py index 98d0fdea13..5b50ec2c56 100644 --- a/tests/test_covar3d.py +++ b/tests/test_covar3d.py @@ -28,9 +28,7 @@ def setUpClass(cls): cls.sim = _LegacySimulation( n=1024, vols=cls.vols, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), dtype=cls.dtype, ) basis = FBBasis3D((8, 8, 8), dtype=cls.dtype) diff --git a/tests/test_downsample.py b/tests/test_downsample.py index 80a6f1e69a..b2da4806f3 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -234,17 +234,16 @@ def test_simulation_relion_downsample(): defocus_max = 25000 defocus_ct = 7 - ctf_filters = [ - RadialCTFFilter(defocus=d) - for d in np.linspace(defocus_min, defocus_max, defocus_ct) - ] + ctf_filters = RadialCTFFilter( + defocus=np.linspace(defocus_min, defocus_max, defocus_ct) + ) # Generate Simulation source and downsampled simulation. src = Simulation( L=64, n=10, C=1, - unique_filters=ctf_filters, + filter_stack=ctf_filters, noise_adder=WhiteNoiseAdder.from_snr(snr=1), pixel_size=1, ) diff --git a/tests/test_indexed_source.py b/tests/test_indexed_source.py index 80ada30003..1a0179d0cc 100644 --- a/tests/test_indexed_source.py +++ b/tests/test_indexed_source.py @@ -59,7 +59,7 @@ def test_repr(sim_fixture): @pytest.mark.expensive def test_filter_mapping(): """ - This test is designed to ensure that `unique_filters` and `filter_indices` + This test is designed to ensure that `filter_stack` and `filter_indices` are being remapped correctly upon slicing. Additionally it tests that a realistic preprocessing pipeline is equivalent @@ -81,18 +81,14 @@ def test_filter_mapping(): angles = Rotation(np.repeat(rots, 2, axis=0)).angles # Generate N//2 rotations and repeat indices - defoci = np.linspace(1000, 25000, N // 2) - ctf_filters = [ - CTFFilter( - 200, - defocus_u=defoci[d], - defocus_v=defoci[-d], - defocus_ang=np.pi / (N // 2) * d, - Cs=2.0, - alpha=0.1, - ) - for d in range(N // 2) - ] + ctf_filters = CTFFilter( + 200, + defocus_u=np.linspace(1000, 25000, N // 2), + defocus_v=np.linspace(1000, 25000, N // 2)[::-1], + defocus_ang=np.linspace(0, np.pi, N // 2), + Cs=2.0, + alpha=0.1, + ) ctf_indices = np.repeat(np.arange(N // 2), 2) # Construct the source @@ -101,7 +97,7 @@ def test_filter_mapping(): n=N, dtype=DT, seed=SEED, - unique_filters=ctf_filters, + filter_stack=ctf_filters, filter_indices=ctf_indices, angles=angles, offsets=0, diff --git a/tests/test_large_covar.py b/tests/test_large_covar.py new file mode 100644 index 0000000000..c698fc0a8c --- /dev/null +++ b/tests/test_large_covar.py @@ -0,0 +1,141 @@ +""" +This file contains a collection of parameterized source setups and calls +to the covariance component applied to real problem sizes and real data. + +It can time/excercise the code paths used for CWF denoising and class +averaging under different basis and CTF filter assumptions (ie radial +optimizations). +""" + +import os +import socket + +import numpy as np +import pytest + +from aspire.basis import FFBBasis2D, FLEBasis2D +from aspire.covariance import BatchedRotCov2D +from aspire.source import RelionSource + +DTYPES = [ + np.float32, + np.float64, +] + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +IMG_SIZES = [ + 128, + 179, +] + + +@pytest.fixture(params=IMG_SIZES, ids=lambda x: f"img_size={x}", scope="module") +def img_size(request): + return request.param + + +BASI = [ + FFBBasis2D, + FLEBasis2D, +] + + +@pytest.fixture(params=BASI, ids=lambda x: f"basis={x}", scope="module") +def basis(request, img_size, dtype): + return request.param(img_size, dtype=dtype) + + +RADIAL = [ + False, + True, +] + + +@pytest.fixture(params=RADIAL, ids=lambda x: f"force_radial={x}", scope="module") +def force_radial(request): + return request.param + + +MOLECULES = { + 10028: "10028/data/shiny_2sets_fixed9.star", + 11618: "11618/data/particles/J43_particles.star", +} + + +@pytest.fixture(params=MOLECULES.keys(), ids=lambda x: f"molecule={x}", scope="module") +def molecule(request): + return request.param + + +def _raw_data_path(): + """ + Attempt getting a working path to raw EMPIAR data location + """ + # Try to populate a path to raw data + raw_data_path = None + + # Check if we're on a known testing platform. + known_hosts = [ + "caf.math.princeton.edu.private", + "decaf.math.princeton.edu.private", + ] + # Default to their expected location + if socket.gethostname() in known_hosts: + raw_data_path = "/scratch/ExperimentalData/raw" + + # Check if a user has provided or overides the location + raw_data_path = os.environ.get("ASPIRE_RAW_DATA_PATH", raw_data_path) + + if raw_data_path is None: + raise RuntimeError("Must provide path to raw data") + if not os.path.exists(raw_data_path): + raise RuntimeError(f"Provided path {raw_data_path} does not exist.") + + return raw_data_path + + +@pytest.fixture(scope="module") +def preprocessed_src(img_size, molecule, force_radial, dtype): + starfile_path = os.path.join(_raw_data_path(), MOLECULES[molecule]) + if not os.path.exists(starfile_path): + raise RuntimeError(f"Expected starfile path {starfile_path} does not exist.") + + src = RelionSource(starfile_path, dtype=dtype) + + # preprocess + src = src.downsample(img_size).cache() + src = src.phase_flip().cache() + src = src.normalize_background().cache() + src = src.whiten().cache() + src = src.invert_contrast() + + # To run radially optimized code we need + # i) radial filters + # ii) set radial expand mode in cov2d + if force_radial: + src.filter_stack = src.filter_stack.to_radial() + + return src + + +@pytest.mark.covar +def test_covar2d(preprocessed_src, basis, force_radial): + + # To run radially optimized code we need + # i) radial filters + # ii) set radial expand mode in cov2d + expand_method = None # default for cov2d + if force_radial: + assert ( + preprocessed_src.filter_stack.radial + ), "Expected radial filters under `force_radial=True`" + expand_method = "radial" + + cov2d = BatchedRotCov2D(preprocessed_src, basis, expand_method=expand_method) + # smoke test + _ = cov2d.get_covar() diff --git a/tests/test_mean_estimator.py b/tests/test_mean_estimator.py index b56a06f4ce..db69ce5222 100644 --- a/tests/test_mean_estimator.py +++ b/tests/test_mean_estimator.py @@ -51,9 +51,7 @@ def sim(L, dtype): L=L, n=256, C=1, # single volume - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), dtype=dtype, seed=SEED, pixel_size=1.234, diff --git a/tests/test_micrograph_simulation.py b/tests/test_micrograph_simulation.py index 2e81d7326f..4620969cdf 100644 --- a/tests/test_micrograph_simulation.py +++ b/tests/test_micrograph_simulation.py @@ -304,7 +304,7 @@ def test_sim_save(): """ v = AsymmetricVolume(L=16, C=1, pixel_size=4, dtype=np.float64).generate() - ctfs = [RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0)] + ctfs = RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0) mg_sim = MicrographSimulation( volume=v, @@ -365,7 +365,7 @@ def test_save_overwrite(caplog): """ v = AsymmetricVolume(L=16, C=1, pixel_size=4, dtype=np.float64).generate() - ctfs = [RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0)] + ctfs = RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0) mg_sim = MicrographSimulation( volume=v, @@ -466,8 +466,5 @@ def test_bad_ctf(vol_fixture): particles_per_micrograph=1, micrograph_count=1, micrograph_size=512, - ctf_filters=[ - RadialCTFFilter(), - ] - * 2, # total particles == 1 + ctf_filters=RadialCTFFilter(defocus=[10000, 20000]), # total particles == 1 ) diff --git a/tests/test_preprocess_pipeline.py b/tests/test_preprocess_pipeline.py index 83cc070930..20716660ff 100644 --- a/tests/test_preprocess_pipeline.py +++ b/tests/test_preprocess_pipeline.py @@ -29,9 +29,7 @@ def get_sim_object(L, dtype): sim = Simulation( L=L, n=num_images, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), noise_adder=noise_adder, pixel_size=1, dtype=dtype, diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 58bea32128..898479bf7a 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -128,9 +128,7 @@ def setUp(self): n=self.n, L=self.L, vols=self.vols, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), noise_adder=WhiteNoiseAdder(var=1), dtype=self.dtype, ) @@ -176,9 +174,7 @@ def testSimulationCached(self): L=self.L, vols=self.vols, offsets=self.sim.offsets, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), noise_adder=WhiteNoiseAdder(var=1), dtype=self.dtype, ) @@ -634,12 +630,10 @@ def test_simulation_save_optics_block(tmp_path): # Radial CTF Filters. Should make 3 distinct optics blocks kv_min, kv_max, kv_ct = 200, 300, 3 voltages = np.linspace(kv_min, kv_max, kv_ct) - ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] + ctf_filters = RadialCTFFilter(voltage=voltages) # Generate and save Simulation - sim = Simulation( - n=9, L=res, C=1, unique_filters=ctf_filters, pixel_size=1.34 - ).cache() + sim = Simulation(n=9, L=res, C=1, filter_stack=ctf_filters, pixel_size=1.34).cache() starpath = tmp_path / "sim.star" sim.save(starpath, overwrite=True) @@ -699,10 +693,10 @@ def test_simulation_slice_save_roundtrip(tmp_path): # Radial CTF Filters kv_min, kv_max, kv_ct = 200, 300, 3 voltages = np.linspace(kv_min, kv_max, kv_ct) - ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] + ctf_filters = RadialCTFFilter(voltage=voltages) # Generate and save slice of Simulation - sim = Simulation(n=9, L=16, C=1, unique_filters=ctf_filters, pixel_size=1.34) + sim = Simulation(n=9, L=16, C=1, filter_stack=ctf_filters, pixel_size=1.34) sliced_sim = sim[::2] save_path = tmp_path / "sliced_sim.star" sliced_sim.save(save_path, overwrite=True) @@ -787,14 +781,14 @@ def test_cached_image_accessors(): Test the behavior of image caching. """ # Create a CTF - ctf = [RadialCTFFilter()] + ctf = RadialCTFFilter() # Create a Simulation with noise and `ctf` src = Simulation( L=32, n=3, C=1, noise_adder=WhiteNoiseAdder(var=0.123), - unique_filters=ctf, + filter_stack=ctf, pixel_size=5, ) # Cache the simulation @@ -818,14 +812,14 @@ def test_projections_and_clean_images_downsample(): L = 32 L_ds = 21 px_sz = 1.23 - ctf = [RadialCTFFilter(1.5e4)] + ctf = RadialCTFFilter(1.5e4) src = Simulation( L=L, n=n, C=1, noise_adder=WhiteNoiseAdder(var=0.123), - unique_filters=ctf, + filter_stack=ctf, pixel_size=px_sz, ) @@ -921,7 +915,7 @@ def test_save_load_dummy_ctf_values(tmp_path, caplog): are present. These values should be detected upon reloading the source. """ star_path = tmp_path / "no_ctf.star" - sim = Simulation(n=8, L=16) # no unique_filters, ie. no CTF info + sim = Simulation(n=8, L=16) # no filter_stack, ie. no CTF info sim.save(star_path, overwrite=True) # STAR file should contain our fallback tag diff --git a/tests/test_simulation_metadata.py b/tests/test_simulation_metadata.py index c47efa1167..f5367038a6 100644 --- a/tests/test_simulation_metadata.py +++ b/tests/test_simulation_metadata.py @@ -20,9 +20,7 @@ def setUp(self): self.sim = MySimulation( n=1024, L=8, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), ) def tearDown(self): diff --git a/tests/test_weighted_mean_estimator.py b/tests/test_weighted_mean_estimator.py index 96f11f2cab..b376891b0e 100644 --- a/tests/test_weighted_mean_estimator.py +++ b/tests/test_weighted_mean_estimator.py @@ -54,9 +54,7 @@ def sim(L, dtype): L=L, n=256, C=1, # single volume - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], + filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)), pixel_size=1, dtype=dtype, seed=SEED, diff --git a/tox.ini b/tox.ini index 79d2566183..37b5bcfcbe 100644 --- a/tox.ini +++ b/tox.ini @@ -92,8 +92,9 @@ line_length = 88 testpaths = tests markers = expensive: mark a test as a long running test. + covar: extended (long) tests for covar components scheduled: tests that should only run in the scheduled workflow -addopts = -m "not expensive and not scheduled" +addopts = -m "not expensive and not scheduled and not covar" [gh-actions] python =