Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
2fa4bd0
add timing printouts within cov2d
garrettwrong Feb 26, 2026
ad8c109
add fetch and eval_t times
garrettwrong Feb 26, 2026
9e4b88e
stashing initial rad ctf port, trying compare to our filter
garrettwrong Mar 19, 2026
2bffae3
stashing to_radial and freq pt scaling patches
garrettwrong Mar 31, 2026
8f2a553
cleanup debugging logic a bit
garrettwrong Mar 31, 2026
b8ab9d9
continue cleanup
garrettwrong Mar 31, 2026
88089dc
use existing expand_method, add radial, instead of new flag
garrettwrong Apr 9, 2026
710bcf5
use existing expand_method, add radial, instead of new flag
garrettwrong Apr 9, 2026
ef830a4
cleanup
garrettwrong Apr 9, 2026
052acac
cleanup
garrettwrong Apr 9, 2026
85c96fa
remove warnings
garrettwrong Apr 9, 2026
8d2c1da
fix logic error
garrettwrong Apr 10, 2026
979def8
refix logic error
garrettwrong Apr 10, 2026
c2232fb
add tqdm to cov2d filter to basis mat
garrettwrong Apr 10, 2026
8106ecf
optimal fle basis comp
garrettwrong Apr 14, 2026
3cc687a
stub in bulk ctf code
garrettwrong Apr 16, 2026
b40e4f6
tests passing except FFB opts
garrettwrong Apr 21, 2026
5253d29
ffb equiv checkpoint
garrettwrong Apr 21, 2026
53026c2
cleanup
garrettwrong Apr 21, 2026
922536d
cleanup
garrettwrong Apr 22, 2026
2e34d5f
cleanup FFB doc strings
garrettwrong Apr 22, 2026
e9c61df
dtype sensitivity
garrettwrong Apr 22, 2026
33a7f3c
more cleanup, maybe CI passing
garrettwrong Apr 22, 2026
39d8211
more cleanup, maybe docs runs
garrettwrong Apr 22, 2026
83ac188
cleanup extra loop
garrettwrong Apr 23, 2026
751eb5e
add force diag option to cov2d
garrettwrong Apr 29, 2026
c998bf0
ctf stack unit test patches
garrettwrong May 6, 2026
9180224
minimal xform/pipeline patches
garrettwrong May 12, 2026
ccaf5b7
got both filter (ds) and filter stack running
garrettwrong May 14, 2026
e06fd26
initial attempt extending multiplicative filter bcast
garrettwrong May 14, 2026
5ac1825
hacktastic ctf param passthrough
garrettwrong May 18, 2026
497b75a
revert last approach in favor of using evaluate per dev meeting
garrettwrong May 21, 2026
9680610
rm unused var
garrettwrong May 21, 2026
ce309f8
satisfy tox
garrettwrong May 22, 2026
73a3cfd
stashing, got filter stack to basis mat eval working for ffb2d
garrettwrong May 22, 2026
cf839d6
begin filter_basis_mat cleanup
garrettwrong May 26, 2026
6725c17
continue filter_basis_mat cleanup
garrettwrong May 26, 2026
fd4a92a
initial documentation updates
garrettwrong May 27, 2026
7bafc0e
make np.unique call compat with older numpy
garrettwrong May 27, 2026
8905ddf
should been better, was not
garrettwrong Jun 2, 2026
841cc8d
first round cleanup
garrettwrong Jun 2, 2026
50695bc
cleanup unused rmat and move reshapes out of loop
garrettwrong Jun 2, 2026
966d1bf
cleanup additional ffb2d test
garrettwrong Jun 2, 2026
61a4839
cleanup some filter eval concerns
garrettwrong Jun 2, 2026
2e6dc56
tox cleanup
garrettwrong Jun 2, 2026
a0cf7ba
tox cleanup
garrettwrong Jun 4, 2026
37ad83e
missing xp
garrettwrong Jun 4, 2026
07dbf20
add large covar pytest file and some minor changes for one of the cases
garrettwrong Jun 10, 2026
4a2f961
add a top of file docstring to test
garrettwrong Jun 11, 2026
c743b13
more cleanup
garrettwrong Jun 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gallery/experiments/save_simulation_relion_reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
# that RELION will recover as optics groups.

vol = emdb_2660()
ctf_filters = [RadialCTFFilter(defocus=d) for d in defocus]
ctf_filters = RadialCTFFilter(defocus=defocus)


# %%
Expand All @@ -64,7 +64,7 @@
sim = Simulation(
n=n_particles,
vols=vol,
unique_filters=ctf_filters,
filter_stack=ctf_filters,
noise_adder=WhiteNoiseAdder.from_snr(snr),
)
sim.save(star_path, overwrite=True)
Expand Down
12 changes: 7 additions & 5 deletions gallery/experiments/simulated_abinitio_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,19 @@ def noise_function(x, y):
alpha = 0.1 # Amplitude contrast

# Create filters
ctf_filters = [
RadialCTFFilter(pixel_size, voltage, defocus=d, Cs=2.0, alpha=0.1)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]
ctf_filters = RadialCTFFilter(
voltage,
defocus=np.linspace(defocus_min, defocus_max, defocus_ct),
Cs=2.0,
alpha=0.1,
)

# Finally create the Simulation
src = Simulation(
n=num_imgs,
vols=og_v,
noise_adder=custom_noise,
unique_filters=ctf_filters,
filter_stack=ctf_filters,
)

# Downsample
Expand Down
7 changes: 2 additions & 5 deletions gallery/tutorials/aspire_introduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,10 +570,7 @@ def noise_function(x, y):
defocus_ct = 7

# Generate several CTFs.
ctf_filters = [
RadialCTFFilter(defocus=d)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]
ctf_filters = RadialCTFFilter(defocus=np.linspace(defocus_min, defocus_max, defocus_ct))

# %%
# Combining into a Simulation
Expand All @@ -586,7 +583,7 @@ def noise_function(x, y):
amplitudes=1,
offsets=0,
noise_adder=white_noise_adder,
unique_filters=ctf_filters,
filter_stack=ctf_filters,
seed=42,
)

Expand Down
8 changes: 3 additions & 5 deletions gallery/tutorials/pipeline_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,8 @@
defocus_max = 25000
defocus_ct = 7

ctf_filters = [
RadialCTFFilter(defocus=d)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]
ctf_filters = RadialCTFFilter(defocus=np.linspace(defocus_min, defocus_max, defocus_ct))


# %%
# Initialize Simulation Object
Expand All @@ -96,7 +94,7 @@
n=2500, # number of projections
vols=original_vol, # volume source
offsets=0, # Default: images are randomly shifted
unique_filters=ctf_filters,
filter_stack=ctf_filters,
noise_adder=WhiteNoiseAdder(var=0.0002), # desired noise variance
).cache()

Expand Down
17 changes: 9 additions & 8 deletions gallery/tutorials/tutorials/cov2d_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,13 @@

print("Initialize simulation object and CTF filters.")
# Create filters
ctf_filters = [
RadialCTFFilter(voltage, defocus=d, Cs=2.0, alpha=0.1)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]
ctf_filters = RadialCTFFilter(
voltage,
defocus=np.linspace(defocus_min, defocus_max, defocus_ct),
Cs=2.0,
alpha=0.1,
)


# Load the map file of a 70S Ribosome
print(
Expand All @@ -89,7 +92,7 @@
L=img_size,
n=num_imgs,
vols=vols,
unique_filters=ctf_filters,
filter_stack=ctf_filters,
offsets=0.0,
amplitudes=1.0,
dtype=dtype,
Expand All @@ -109,9 +112,7 @@
h_idx = sim.filter_indices

# Evaluate CTF in the 8X8 FB basis
h_ctf_fb = [
ffbbasis.filter_to_basis_mat(filt, pixel_size=pixel_size) for filt in ctf_filters
]
h_ctf_fb = ffbbasis.filter_stack_to_basis_mats(ctf_filters, pixel_size=pixel_size)

# Get clean images from projections of 3D map.
print("Apply CTF filters to clean images.")
Expand Down
2 changes: 1 addition & 1 deletion gallery/tutorials/tutorials/cov3d_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
L=img_size,
n=num_imgs,
vols=vols,
unique_filters=[RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7)],
filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)),
dtype=dtype,
)

Expand Down
7 changes: 3 additions & 4 deletions gallery/tutorials/tutorials/ctf.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,8 @@ def generate_example_image(L, noise_variance=0.1):

# Construct a range of CTF filters.
defoci = [2500, 5000, 10000, 20000]
ctf_filters = [
RadialCTFFilter(voltage=200, defocus=d, Cs=2.26, alpha=0.07, B=0) for d in defoci
]
ctf_filters = RadialCTFFilter(voltage=200, defocus=defoci, Cs=2.26, alpha=0.07, B=0)


# %%
# Generate CTF corrupted Images
Expand Down Expand Up @@ -334,7 +333,7 @@ def generate_example_image(L, noise_variance=0.1):
from aspire.source import Simulation

# Create the Source. ``ctf_filters`` are re-used from earlier section.
src = Simulation(L=64, n=4, unique_filters=ctf_filters, pixel_size=1)
src = Simulation(L=64, n=4, filter_stack=ctf_filters, pixel_size=1)
src.images[:4].show()

# %%
Expand Down
4 changes: 1 addition & 3 deletions gallery/tutorials/tutorials/micrograph_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,7 @@

# Create our CTF Filter and add it to a list.
# This configuration will apply the same CTF to all particles.
ctfs = [
RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0),
]
ctfs = RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0)

src = MicrographSimulation(
vol,
Expand Down
13 changes: 8 additions & 5 deletions gallery/tutorials/tutorials/orient3d_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,13 @@

print("Initialize simulation object and CTF filters.")
# Create CTF filters
filters = [
RadialCTFFilter(voltage, defocus=d, Cs=2.0, alpha=0.1)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]
filters = RadialCTFFilter(
voltage,
defocus=np.linspace(defocus_min, defocus_max, defocus_ct),
Cs=2.0,
alpha=0.1,
)


# %%
# Downsampling
Expand All @@ -74,7 +77,7 @@
# Create a simulation object with specified filters and the downsampled 3D map
print("Use downsampled map to creat simulation object.")
sim = Simulation(
L=img_size, n=num_imgs, vols=vols, unique_filters=filters, pixel_size=5, dtype=dtype
L=img_size, n=num_imgs, vols=vols, filter_stack=filters, pixel_size=5, dtype=dtype
)

print("Get true rotation angles generated randomly by the simulation object.")
Expand Down
13 changes: 8 additions & 5 deletions gallery/tutorials/tutorials/preprocess_imgs_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@

print("Initialize simulation object and CTF filters.")
# Create CTF filters
ctf_filters = [
RadialCTFFilter(voltage, defocus=d, Cs=2.0, alpha=0.1)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]
ctf_filters = RadialCTFFilter(
voltage,
defocus=np.linspace(defocus_min, defocus_max, defocus_ct),
Cs=2.0,
alpha=0.1,
)


# Load the map file of a 70S ribosome and downsample the 3D map to desired image size.
print("Load 3D map from mrc file")
Expand All @@ -73,7 +76,7 @@
L=img_size,
n=num_imgs,
vols=vols,
unique_filters=ctf_filters,
filter_stack=ctf_filters,
noise_adder=noise_adder,
pixel_size=pixel_size,
)
Expand Down
2 changes: 1 addition & 1 deletion src/aspire/basis/fb_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,6 @@ def calculate_bispectrum(

def filter_to_basis_mat(self, *args, **kwargs):
"""
See `SteerableBasis2D.filter_to_basis_mat`.
See `SteerableBasis2D.filter_stack_to_basis_mat`.
"""
return super().filter_to_basis_mat(*args, **kwargs)
118 changes: 101 additions & 17 deletions src/aspire/basis/ffb_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ def _build(self):
self._precomp["gl_nodes"]
)

# Generate radial filter point set for radial optimized eval
# Weights appear a little sensitive to dtype, otherwise could use self._precomp["gl_nodes"]
k_vals, _ = lgwt(self.n_r, 0, 0.5, dtype=np.float64)
self._filter_pts = np.pad(
2 * np.pi * k_vals.reshape(1, -1), ((0, 1), (0, 0))
).astype(self.dtype)

def _precomp(self):
"""
Precomute the basis functions on a polar Fourier grid
Expand Down Expand Up @@ -236,15 +243,17 @@ def _evaluate_t(self, x):

return xp.asnumpy(v)

def filter_to_basis_mat(self, f, **kwargs):
def _filter_stack_to_basis_mats(self, f, **kwargs):
"""
See `SteerableBasis2D.filter_to_basis_mat`.
"""
# Note 'method' and 'truncate' not relevant for this optimized FFB code.
if kwargs.get("method", None) is not None:
# Note 'method' and 'truncate' not relevant for this specific FFB code.
# Method `radial` should have already been diverted.
expand_method = kwargs.get("expand_method", None)
if expand_method is not None:
raise NotImplementedError(
"`FFBBasis2D.filter_to_basis_mat` method {method} not supported."
" Use `method=None`."
f"`FFBBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported."
" Use `expand_method=None`."
)

pixel_size = kwargs.get("pixel_size", None)
Expand All @@ -271,28 +280,103 @@ def filter_to_basis_mat(self, f, **kwargs):
omegay = k * np.sin(theta)
omega = 2 * np.pi * np.vstack((omegax.flatten("C"), omegay.flatten("C")))

# This should return either a single 2d array, or stack of 2d arrays
# Reshape singleton to stack of 1.
h_vals2d = (
h_fun(omega, pixel_size=pixel_size).reshape(n_k, n_theta).astype(self.dtype)
h_fun(omega, pixel_size=pixel_size)
.reshape(len(f), n_k, n_theta)
.astype(self.dtype)
)
h_vals = np.sum(h_vals2d, axis=1) / n_theta
h_vals = h_vals2d.sum(axis=-1) / n_theta

# Represent 1D function values in basis
h_basis = BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype)
# Represent each 1D functions values in basis
h_basis = [
BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) for _ in h_vals
]
ind_ell = 0
# Reshapes for broadcasting
k_vals = k_vals.reshape(n_k, 1)
wts = wts.reshape(n_k, 1)
h_vals = h_vals.reshape(len(f), n_k, 1)
for ell in range(0, self.ell_max + 1):
k_max = self.k_max[ell]
rmat = 2 * k_vals.reshape(n_k, 1) * self.r0[ell][0:k_max].T
basis_vals = np.zeros_like(rmat)
basis_vals = np.zeros((n_k, k_max), dtype=self.dtype)
ind_radial = np.sum(self.k_max[0:ell])
basis_vals[:, 0:k_max] = radial[ind_radial : ind_radial + k_max].T
h_basis_vals = basis_vals * h_vals.reshape(n_k, 1)
h_basis_ell = basis_vals.T @ (
h_basis_vals * k_vals.reshape(n_k, 1) * wts.reshape(n_k, 1)
)
h_basis[ind_ell] = h_basis_ell
h_basis_vals = basis_vals * h_vals
h_basis_ell = basis_vals.T @ (h_basis_vals * k_vals * wts)

# loop over assignment blocks.
for i in range(len(f)):
h_basis[i][ind_ell] = h_basis_ell[i]
ind_ell += 1
if ell > 0:
h_basis[ind_ell] = h_basis[ind_ell - 1]
for i in range(len(f)):
h_basis[i][ind_ell] = h_basis[i][ind_ell - 1]
ind_ell += 1

return h_basis

def filter_to_basis_mat(self, f, **kwargs):
"""
See `SteerableBasis2D.filter_stack_to_basis_mats`.
"""
if len(f) != 1:
raise RuntimeError("Unexpected filter length.")
return self._filter_stack_to_basis_mats(f, **kwargs)[0]

def expand_radial_vec(self, radial_vec, force_diag=False):
"""
Expands radial vector or stack of vetors `radial_vec` to basis matrix.

:param radial_vec: Array holding radial vector,
shaped (n_radial_pts) or (n_vectors, n_radial_pts)
:force_diag: Optionally flush off-diagonal elements to zero and return `DiagMatrix`
:return: List of `BlkDiagMatrix`, or list of `DiagMatrix`
"""
# Convert vector to (1,...)
if radial_vec.ndim == 1:
radial_vec = radial_vec.reshape(1, *radial_vec.shape)
# Optionally transfer to GPU
radial_vec = xp.asarray(radial_vec)

# Set same dimensions as basis object
n_k = self.n_r
radial = self._precomp["radial"]

k_vals = xp.asarray(self._precomp["gl_nodes"])
wts = xp.asarray(self._precomp["gl_weights"])

# Represent 1D function values in basis
h_basis = [
BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype)
for _ in radial_vec
]

ind_ell = 0
# Reshapes for broadcasting
radial_vec = radial_vec.reshape(len(h_basis), n_k, 1)
k_vals = k_vals.reshape(1, n_k, 1)
wts = wts.reshape(1, n_k, 1)
for ell in range(0, self.ell_max + 1):
k_max = self.k_max[ell]
basis_vals = xp.zeros((n_k, k_max), dtype=self.dtype)
ind_radial = np.sum(self.k_max[0:ell])
basis_vals[:, 0:k_max] = xp.asarray(
radial[ind_radial : ind_radial + k_max]
).T
h_basis_vals = basis_vals * radial_vec
h_basis_ell = basis_vals.T @ (h_basis_vals * k_vals * wts)
h_basis_ell = xp.asnumpy(h_basis_ell)
for _filter in range(len(radial_vec)):
_tmp = h_basis[_filter][ind_ell] = h_basis_ell[_filter]
if ell > 0:
h_basis[_filter][ind_ell + 1] = _tmp
if _filter == len(radial_vec) - 1:
ind_ell += 1
if ell > 0:
ind_ell += 1
if force_diag:
h_basis = [h.diag() for h in h_basis]

return h_basis
Loading