diff --git a/xrspatial/focal.py b/xrspatial/focal.py index 5cadede3..53c12ec4 100644 --- a/xrspatial/focal.py +++ b/xrspatial/focal.py @@ -27,7 +27,7 @@ class cupy(object): ndarray = False -from xrspatial.convolution import convolve_2d, custom_kernel +from xrspatial.convolution import convolve_2d, custom_kernel, _convolve_2d_numpy from xrspatial.utils import ArrayTypeFunctionMapping, cuda_args, ngjit, not_implemented_func from xrspatial.dataset_support import supports_dataset @@ -938,32 +938,50 @@ def _hotspots_numpy(raster, kernel): def _hotspots_dask_numpy(raster, kernel): - data = raster.data.astype(np.float32) + data = raster.data + if not np.issubdtype(data.dtype, np.floating): + data = data.astype(np.float32) + + # Pass 1: eagerly compute global statistics (two scalars). + # This reads all chunks once, produces 16 bytes, then frees all + # intermediate state -- no barrier that would force materialization + # of the full convolution output. + global_mean, global_std = da.compute(da.nanmean(data), da.nanstd(data)) + global_mean = np.float32(global_mean) + global_std = np.float32(global_std) - # apply kernel to raster values - mean_array = convolve_2d(data, kernel / kernel.sum()) - - # calculate z-scores - global_mean = da.nanmean(data) - global_std = da.nanstd(data) - - # commented out to avoid early compute to check if global_std is zero - # if global_std == 0: - # raise ZeroDivisionError( - # "Standard deviation of the input raster values is 0." - # ) + if global_std == 0: + raise ZeroDivisionError( + "Standard deviation of the input raster values is 0." + ) - z_array = (mean_array - global_mean) / global_std + norm_kernel = (kernel / kernel.sum()).astype(np.float32) + pad_h = norm_kernel.shape[0] // 2 + pad_w = norm_kernel.shape[1] // 2 + + # Pass 2: fuse convolution + z-score + classification into one + # map_overlap call. Each chunk reads source + halo, produces int8 + # output, and frees all intermediates immediately. No spill needed. + _func = partial( + _hotspots_chunk, + kernel=norm_kernel, + global_mean=global_mean, + global_std=global_std, + ) + out = data.map_overlap( + _func, + depth=(pad_h, pad_w), + boundary=np.nan, + meta=np.array((), dtype=np.int8), + ) + return out - _func = partial(_calc_hotspots_numpy) - pad_h = kernel.shape[0] // 2 - pad_w = kernel.shape[1] // 2 - out = z_array.map_overlap(_func, - depth=(pad_h, pad_w), - boundary=np.nan, - meta=np.array(())) - return out +def _hotspots_chunk(chunk, kernel, global_mean, global_std): + """Fused per-chunk: convolve -> z-score -> classify.""" + convolved = _convolve_2d_numpy(chunk, kernel) + z = (convolved - global_mean) / global_std + return _calc_hotspots_numpy(z) @nb.cuda.jit(device=True)