Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions httomolibgpu/memory_estimator_helpers.py

This file was deleted.

6 changes: 3 additions & 3 deletions httomolibgpu/prep/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import numpy as np
from httomolibgpu import cupywrapper
from httomolibgpu.memory_estimator_helpers import _DeviceMemStack
from tomobar.supp.memory_estimator_helpers import DeviceMemStack

cp = cupywrapper.cp
cupy_run = cupywrapper.cupy_run
Expand Down Expand Up @@ -91,7 +91,7 @@ def paganin_filter(
cp.ndarray
The 3D array of Paganin phase-filtered projection images.
"""
mem_stack = _DeviceMemStack() if calc_peak_gpu_mem else None
mem_stack = DeviceMemStack() if calc_peak_gpu_mem else None
# Check the input data is valid
if not mem_stack and tomo.ndim != 3:
raise ValueError(
Expand Down Expand Up @@ -301,7 +301,7 @@ def _pad_projections(
"next_power_of_2", "next_fast_length", "use_pad_x_y"
],
pad_x_y: Optional[list],
mem_stack: Optional[_DeviceMemStack],
mem_stack: Optional[DeviceMemStack],
) -> Tuple[cp.ndarray, Tuple[int, int]]:
"""
Performs padding of each projection to a size optimal for FFT.
Expand Down
41 changes: 9 additions & 32 deletions httomolibgpu/prep/stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from unittest.mock import Mock

if cupy_run:
from tomobar.supp.memory_estimator_helpers import DeviceMemStack
from cupyx.scipy.ndimage import median_filter, binary_dilation, uniform_filter1d
from cupyx.scipy.fft import fft2, ifft2, fftshift
from cupyx.scipy.fftpack import get_fft_plan
Expand Down Expand Up @@ -204,32 +205,8 @@ def _reflect(x: np.ndarray, minx: float, maxx: float) -> np.ndarray:
return np.array(out, dtype=x.dtype)


class _DeviceMemStack:
def __init__(self) -> None:
self.allocations = []
self.current = 0
self.highwater = 0

def malloc(self, bytes):
self.allocations.append(bytes)
allocated = self._round_up(bytes)
self.current += allocated
self.highwater = max(self.current, self.highwater)

def free(self, bytes):
assert bytes in self.allocations
self.allocations.remove(bytes)
self.current -= self._round_up(bytes)
assert self.current >= 0

def _round_up(self, size):
ALLOCATION_UNIT_SIZE = 512
size = (size + ALLOCATION_UNIT_SIZE - 1) // ALLOCATION_UNIT_SIZE
return size * ALLOCATION_UNIT_SIZE


def _mypad(
x: cp.ndarray, pad: Tuple[int, int, int, int], mem_stack: Optional[_DeviceMemStack]
x: cp.ndarray, pad: Tuple[int, int, int, int], mem_stack: Optional[DeviceMemStack]
) -> cp.ndarray:
"""Function to do numpy like padding on Arrays. Only works for 2-D
padding.
Expand Down Expand Up @@ -272,7 +249,7 @@ def _conv2d(
w: np.ndarray,
stride: Tuple[int, int],
groups: int,
mem_stack: Optional[_DeviceMemStack],
mem_stack: Optional[DeviceMemStack],
) -> cp.ndarray:
"""Convolution (equivalent pytorch.conv2d)"""
b, ci, hi, wi = x.shape if not mem_stack else x
Expand Down Expand Up @@ -355,7 +332,7 @@ def _conv_transpose2d(
stride: Tuple[int, int],
pad: Tuple[int, int],
groups: int,
mem_stack: Optional[_DeviceMemStack],
mem_stack: Optional[DeviceMemStack],
) -> cp.ndarray:
"""Transposed convolution (equivalent pytorch.conv_transpose2d)"""
b, co, ho, wo = x.shape if not mem_stack else x
Expand Down Expand Up @@ -428,7 +405,7 @@ def _afb1d(
h0: np.ndarray,
h1: np.ndarray,
dim: int,
mem_stack: Optional[_DeviceMemStack],
mem_stack: Optional[DeviceMemStack],
) -> cp.ndarray:
"""1D analysis filter bank (along one dimension only) of an image

Expand Down Expand Up @@ -476,7 +453,7 @@ def _sfb1d(
g0: np.ndarray,
g1: np.ndarray,
dim: int,
mem_stack: Optional[_DeviceMemStack],
mem_stack: Optional[DeviceMemStack],
) -> cp.ndarray:
"""1D synthesis filter bank of an image Array"""

Expand Down Expand Up @@ -520,7 +497,7 @@ def __init__(self, wave: str):
self.h1_row = np.array(h1_row).astype("float32")[::-1].reshape((1, 1, 1, -1))

def apply(
self, x: cp.ndarray, mem_stack: Optional[_DeviceMemStack] = None
self, x: cp.ndarray, mem_stack: Optional[DeviceMemStack] = None
) -> Tuple[cp.ndarray, cp.ndarray]:
"""Forward pass of the DWT.

Expand Down Expand Up @@ -582,7 +559,7 @@ def __init__(self, wave: str):
def apply(
self,
coeffs: Tuple[cp.ndarray, cp.ndarray],
mem_stack: Optional[_DeviceMemStack] = None,
mem_stack: Optional[DeviceMemStack] = None,
) -> cp.ndarray:
"""
Args:
Expand Down Expand Up @@ -672,7 +649,7 @@ def remove_stripe_fw(
sli_shape = [nz, 1, nproj_pad, ni]

if calc_peak_gpu_mem:
mem_stack = _DeviceMemStack()
mem_stack = DeviceMemStack()
# A data copy is assumed when invoking the function
mem_stack.malloc(np.prod(data) * np.float32().itemsize)
mem_stack.malloc(np.prod(sli_shape) * np.float32().itemsize)
Expand Down
Loading
Loading