diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c67db4d..cde1f742 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,17 @@ ----------- +### Version 0.5.3 (unreleased) +- Fixes #134: Add `xr.Dataset` as input type for appropriate modules. + Most public functions now transparently accept `xr.Dataset` in addition + to `xr.DataArray`. Single-input functions (slope, aspect, curvature, + hillshade, focal.mean, all classification functions, proximity/allocation/ + direction) iterate over data variables and return a Dataset. Multi-input + functions (all multispectral indices) accept a Dataset with band-name + keyword arguments. `zonal.stats` computes per-variable statistics and + returns a merged DataFrame with prefixed columns. + + ### Version 0.5.2 - 2025-12-18 - Make dask optional (#835) - Fixes 832 update citation info in readme (#834) diff --git a/README.md b/README.md index 5170e392..881566f4 100644 --- a/README.md +++ b/README.md @@ -258,6 +258,21 @@ my_dataarray = xr.DataArray(...) hillshaded_dataarray = hillshade(my_dataarray) ``` +##### Dataset Support + +Most functions also accept an `xr.Dataset`. Single-input functions (surface, classification, focal, proximity) apply the operation to each data variable and return a Dataset. Multi-input functions (multispectral indices) accept a Dataset with band-name keyword arguments. + +```python +# Single-input: returns a Dataset with slope computed for each variable +slope_ds = slope(my_dataset) + +# Multi-input: map Dataset variables to band parameters +ndvi_result = ndvi(my_dataset, nir='band_5', red='band_4') + +# Zonal stats: columns prefixed by variable name +stats_df = zonal.stats(zones, my_dataset) # → elevation_mean, temperature_mean, ... +``` + Check out the user guide [here](/examples/user_guide/). 
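The `zonal.stats` column scheme shown in the README example above can be sketched with plain dicts standing in for the per-variable DataFrames; `merge_prefixed` is a hypothetical helper for illustration, not the library's implementation:

```python
# Sketch of the documented '<variable>_<stat>' column naming for Dataset
# input to zonal.stats. Plain dicts stand in for the per-variable
# DataFrames; merge_prefixed is a hypothetical helper, not xrspatial API.

def merge_prefixed(per_variable_stats):
    """Flatten {variable: {stat: value}} into one prefixed record."""
    merged = {}
    for var, stats in per_variable_stats.items():
        for stat, value in stats.items():
            merged[f"{var}_{stat}"] = value
    return merged

per_var = {
    "elevation": {"mean": 120.5, "max": 300.0},
    "temperature": {"mean": 15.2, "max": 21.7},
}
row = merge_prefixed(per_var)
# → {'elevation_mean': 120.5, 'elevation_max': 300.0,
#    'temperature_mean': 15.2, 'temperature_max': 21.7}
```

The real function additionally keeps unprefixed zone columns and merges per-variable tables row-wise; this only illustrates the naming convention.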
------ diff --git a/docs/source/getting_started/usage.rst b/docs/source/getting_started/usage.rst index 2e86ea33..008002fe 100644 --- a/docs/source/getting_started/usage.rst +++ b/docs/source/getting_started/usage.rst @@ -15,6 +15,27 @@ Basic Pattern my_dataarray = xr.DataArray(...) hillshaded_dataarray = hillshade(my_dataarray) + +Dataset Support +================ + +Most functions also accept an ``xr.Dataset``. Single-input functions apply +the operation to each data variable and return a Dataset. Multi-input +functions (multispectral indices) accept a Dataset with band-name keyword +arguments. + +.. code-block:: python + + from xrspatial import slope + from xrspatial.multispectral import ndvi + + # Single-input: returns a Dataset with slope for each variable + slope_ds = slope(my_dataset) + + # Multi-input: map Dataset variables to band parameters + ndvi_result = ndvi(my_dataset, nir='band_5', red='band_4') + + Check out the user guide `here `_. diff --git a/docs/source/user_guide/data_types.rst b/docs/source/user_guide/data_types.rst index 0d9da66c..e6775f6c 100644 --- a/docs/source/user_guide/data_types.rst +++ b/docs/source/user_guide/data_types.rst @@ -227,11 +227,49 @@ Best Practices combined = (ndvi_result + savi_result) / 2 +Dataset Input Support +===================== + +Most functions accept an ``xr.Dataset`` in addition to ``xr.DataArray``. +When a Dataset is passed, the operation is applied to each data variable +independently and the result is returned as a new Dataset. + +Single-input functions (surface, classification, focal, proximity): + +.. code-block:: python + + from xrspatial import slope + + # Apply slope to every variable in the Dataset + slope_ds = slope(my_dataset) + # Returns an xr.Dataset with the same variable names + +Multi-input functions (multispectral indices) accept a Dataset with keyword +arguments that map band aliases to variable names: + +.. 
code-block:: python + + from xrspatial.multispectral import ndvi + + # Map Dataset variables to band parameters + ndvi_result = ndvi(my_dataset, nir='band_5', red='band_4') + +``zonal.stats`` also accepts a Dataset for the ``values`` parameter, returning +a merged DataFrame with columns prefixed by variable name: + +.. code-block:: python + + from xrspatial.zonal import stats + + df = stats(zones, my_dataset) + # Columns: zone, elevation_mean, elevation_max, ..., temperature_mean, ... + + Summary ======= -- **Input**: xarray-spatial accepts any numeric data type (int or float) +- **Input**: xarray-spatial accepts any numeric data type (int or float), as either ``xr.DataArray`` or ``xr.Dataset`` - **Processing**: All calculations are performed in float32 precision -- **Output**: Results are returned as float32 DataArrays +- **Output**: Results are returned as float32 DataArrays (or a Dataset of float32 DataArrays when a Dataset is passed) - **Consistency**: This behavior is consistent across NumPy, Dask, and CuPy backends - **Rationale**: Float32 provides adequate precision for geospatial analysis while using half the memory of float64 diff --git a/docs/source/user_guide/multispectral.ipynb b/docs/source/user_guide/multispectral.ipynb index 86aa2446..f736de73 100644 --- a/docs/source/user_guide/multispectral.ipynb +++ b/docs/source/user_guide/multispectral.ipynb @@ -10,13 +10,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "Xarray-spatial's Multispectral tools provide a range of functions pertaining to remote sensing data such as satellite imagery. A range of functions are available to calculate various vegetation and environmental parameters from the range of band data available for an area. 
These functions accept and output data in the form of xarray.DataArray rasters.\n", - "\n", - "- [Generate terrain](#Generate-Terrain-Data) \n", - "- [Bump](#Bump) \n", - "- [NDVI](#NDVI) " - ] + "source": "Xarray-spatial's Multispectral tools provide a range of functions pertaining to remote sensing data such as satellite imagery. A range of functions are available to calculate various vegetation and environmental parameters from the range of band data available for an area. These functions accept and output data in the form of xarray.DataArray rasters. They also accept an xr.Dataset as the first argument with band-name keyword arguments to map variables to bands (e.g. `ndvi(ds, nir='B5', red='B4')`).\n\n- [Generate terrain](#Generate-Terrain-Data) \n- [Bump](#Bump) \n- [NDVI](#NDVI) " }, { "cell_type": "markdown", @@ -64,11 +58,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "The following functions apply to image data with bands in different parts of the UV/Visible/IR spectrum (multispectral), so we'll bring in some multispectral satellite image data to work with.\n", - "\n", - "Below, we loaded all of the images and transformed them into the form of an xarray DataArray to use in the Xarray-spatial functions." - ] + "source": "The following functions apply to image data with bands in different parts of the UV/Visible/IR spectrum (multispectral), so we'll bring in some multispectral satellite image data to work with.\n\nBelow, we loaded all of the images and transformed them into the form of an xarray DataArray to use in the Xarray-spatial functions. Note: you can also load bands into an `xr.Dataset` and pass it directly to multispectral functions with band-name keyword arguments (e.g. `ndvi(ds, nir='nir', red='red')`)." 
}, { "cell_type": "code", @@ -1408,4 +1398,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/docs/source/user_guide/surface.ipynb b/docs/source/user_guide/surface.ipynb index 99dece31..eb77f66c 100644 --- a/docs/source/user_guide/surface.ipynb +++ b/docs/source/user_guide/surface.ipynb @@ -10,17 +10,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "With the Surface tools, you can quantify and visualize a terrain landform represented by a digital elevation model.\n", - "\n", - "Starting with a raster elevation surface that represented as an Xarray DataArray, these tools help you in identifying some specific patterns that were not readily apparent in the original surface. Return of each function is also an Xarray DataArray.\n", - "\n", - "- [Hillshade](#Hillshade): Creates a shaded relief from a surface raster by considering the illumination source angle and shadows.\n", - "- [Slope](#Slope): Identifies the slope from each cell of a raster.\n", - "- [Curvature](#Curvature): Calculates the curvature of a raster surface.\n", - "- [Aspect](#Aspect): Derives the aspect from each cell of a raster surface.\n", - "- [Viewshed](#Viewshed): Determines visible locations in the input raster surface from a viewpoint with some optional observer features." - ] + "source": "With the Surface tools, you can quantify and visualize a terrain landform represented by a digital elevation model.\n\nStarting with a raster elevation surface that represented as an Xarray DataArray (or an Xarray Dataset containing multiple elevation variables), these tools help you in identifying some specific patterns that were not readily apparent in the original surface. When a DataArray is passed, the return is a DataArray. 
When a Dataset is passed, the function is applied to each variable independently and the return is a Dataset.\n\n- [Hillshade](#Hillshade): Creates a shaded relief from a surface raster by considering the illumination source angle and shadows.\n- [Slope](#Slope): Identifies the slope from each cell of a raster.\n- [Curvature](#Curvature): Calculates the curvature of a raster surface.\n- [Aspect](#Aspect): Derives the aspect from each cell of a raster surface.\n- [Viewshed](#Viewshed): Determines visible locations in the input raster surface from a viewpoint with some optional observer features." }, { "cell_type": "markdown", @@ -932,4 +922,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/examples/user_guide/1_Surface.ipynb b/examples/user_guide/1_Surface.ipynb index d31ef942..2990e1e1 100644 --- a/examples/user_guide/1_Surface.ipynb +++ b/examples/user_guide/1_Surface.ipynb @@ -3,26 +3,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "# Xarray-spatial\n", - "### User Guide: Surface tools\n", - "-----\n", - "With the Surface tools, you can quantify and visualize a terrain landform represented by a digital elevation model.\n", - "\n", - "Starting with a raster elevation surface, represented as an Xarray DataArray, these tools can help you identify some specific patterns that may not be readily apparent in the original surface. 
The return of each function is also an Xarray DataArray.\n", - "\n", - "[Hillshade](#Hillshade): Creates a shaded relief from a surface raster by considering the illumination source angle and shadows.\n", - "\n", - "[Slope](#Slope): Identifies the slope for each cell of a raster.\n", - "\n", - "[Curvature](#Curvature): Calculates the curvature of a raster surface.\n", - "\n", - "[Aspect](#Aspect): Derives the aspect for each cell of a raster surface.\n", - "\n", - "[Viewshed](#Viewshed): Determines visible locations in the input raster surface from a viewpoint with an optional observer height.\n", - "\n", - "-----------\n" - ] + "source": "# Xarray-spatial\n### User Guide: Surface tools\n-----\nWith the Surface tools, you can quantify and visualize a terrain landform represented by a digital elevation model.\n\nStarting with a raster elevation surface, represented as an Xarray DataArray (or an Xarray Dataset containing multiple elevation variables), these tools can help you identify some specific patterns that may not be readily apparent in the original surface. When a DataArray is passed, the return is a DataArray. 
When a Dataset is passed, the function is applied to each variable independently and the return is a Dataset.\n\n[Hillshade](#Hillshade): Creates a shaded relief from a surface raster by considering the illumination source angle and shadows.\n\n[Slope](#Slope): Identifies the slope for each cell of a raster.\n\n[Curvature](#Curvature): Calculates the curvature of a raster surface.\n\n[Aspect](#Aspect): Derives the aspect for each cell of a raster surface.\n\n[Viewshed](#Viewshed): Determines visible locations in the input raster surface from a viewpoint with an optional observer height.\n\n-----------\n" }, { "cell_type": "markdown", @@ -524,4 +505,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/examples/user_guide/6_Remote_Sensing.ipynb b/examples/user_guide/6_Remote_Sensing.ipynb index 95ff0a8c..bb0be943 100644 --- a/examples/user_guide/6_Remote_Sensing.ipynb +++ b/examples/user_guide/6_Remote_Sensing.ipynb @@ -3,24 +3,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "# Xarray-spatial\n", - "### User Guide: Remote Sensing tools\n", - "-----\n", - "\n", - "Xarray-spatial's Remote Sensing tools provide a range of functions pertaining to remote sensing data such as satellite imagery. A range of functions are available to calculate various vegetation and environmental parameters from the range of band data available for an area. 
These functions accept and output data in the form of xarray.DataArray rasters.\n", - "\n", - "[True Color](#True-Color) \n", - "[Vegetation Index](#Vegetation-Index): [NDVI](#NDVI), [SAVI](#SAVI), [ARVI](#ARVI), [EVI](#EVI) \n", - "[Green Chlorophyll Index - GCI](#Green-Chlorophyll-Index-(GCI)) \n", - "[Normalized Burn Ratio](#Normalized-Burn-Ratio): [NBR](#NBR), [NBR2](#NBR2) \n", - "[Normalized Difference Moisture Index - NDMI](#Normalized-Difference-Moisture-Index-(NDMI)) \n", - "[Structure Insensitive Pigment Index - SIPI](#Structure-Insensitive-Pigment-Index-(SIPI)) \n", - "[Enhanced Built-Up and Bareness Index - EBBI](#Enhanced-Built-Up-and-Bareness-Index-(EBBI)) \n", - "[Bump Mapping](#Bump-Mapping) \n", - "\n", - "-----------\n" - ] + "source": "# Xarray-spatial\n### User Guide: Remote Sensing tools\n-----\n\nXarray-spatial's Remote Sensing tools provide a range of functions pertaining to remote sensing data such as satellite imagery. A range of functions are available to calculate various vegetation and environmental parameters from the range of band data available for an area. These functions accept and output data in the form of xarray.DataArray rasters. They also accept an xr.Dataset as the first argument with band-name keyword arguments to map variables to bands (e.g. 
`ndvi(ds, nir='B5', red='B4')`).\n\n[True Color](#True-Color) \n[Vegetation Index](#Vegetation-Index): [NDVI](#NDVI), [SAVI](#SAVI), [ARVI](#ARVI), [EVI](#EVI) \n[Green Chlorophyll Index - GCI](#Green-Chlorophyll-Index-(GCI)) \n[Normalized Burn Ratio](#Normalized-Burn-Ratio): [NBR](#NBR), [NBR2](#NBR2) \n[Normalized Difference Moisture Index - NDMI](#Normalized-Difference-Moisture-Index-(NDMI)) \n[Structure Insensitive Pigment Index - SIPI](#Structure-Insensitive-Pigment-Index-(SIPI)) \n[Enhanced Built-Up and Bareness Index - EBBI](#Enhanced-Built-Up-and-Bareness-Index-(EBBI)) \n[Bump Mapping](#Bump-Mapping) \n\n-----------\n" }, { "cell_type": "markdown", @@ -68,11 +51,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "The following functions apply to image data with bands in different parts of the UV/Visible/IR spectrum (multispectral), so we'll bring in some multispectral satellite image data to work with.\n", - "\n", - "Below, we loaded all of the images and transformed them into the form of an xarray DataArray to use in the Xarray-spatial functions." - ] + "source": "The following functions apply to image data with bands in different parts of the UV/Visible/IR spectrum (multispectral), so we'll bring in some multispectral satellite image data to work with.\n\nBelow, we loaded all of the images and transformed them into the form of an xarray DataArray to use in the Xarray-spatial functions. Note: you can also load bands into an `xr.Dataset` and pass it directly to multispectral functions with band-name keyword arguments (e.g. `ndvi(ds, nir='nir', red='red')`)." 
}, { "cell_type": "code", @@ -695,4 +674,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/xrspatial/aspect.py b/xrspatial/aspect.py index bc9bf080..28832b1c 100644 --- a/xrspatial/aspect.py +++ b/xrspatial/aspect.py @@ -19,6 +19,7 @@ from xrspatial.utils import _extract_latlon_coords from xrspatial.utils import cuda_args from xrspatial.utils import ngjit +from xrspatial.dataset_support import supports_dataset def _geodesic_cuda_dims(shape): @@ -270,6 +271,7 @@ def _run_dask_cupy_geodesic(data, lat_2d, lon_2d, a2, b2, z_factor): # Public API # ===================================================================== +@supports_dataset def aspect(agg: xr.DataArray, name: Optional[str] = 'aspect', method: str = 'planar', @@ -296,9 +298,11 @@ def aspect(agg: xr.DataArray, Parameters ---------- - agg : xarray.DataArray + agg : xarray.DataArray or xr.Dataset 2D NumPy, CuPy, or Dask with NumPy-backed xarray DataArray of elevation values. + If a Dataset is passed, the operation is applied to each + data variable independently. name : str, default='aspect' Name of ouput DataArray. method : str, default='planar' @@ -313,7 +317,10 @@ def aspect(agg: xr.DataArray, Returns ------- - aspect_agg : xarray.DataArray of the same type as `agg` + aspect_agg : xarray.DataArray or xr.Dataset + If `agg` is a DataArray, returns a DataArray of the same type. + If `agg` is a Dataset, returns a Dataset with aspect computed + for each data variable. 2D aggregate array of calculated aspect values. All other input attributes are preserved. 
diff --git a/xrspatial/classify.py b/xrspatial/classify.py index 5551beec..5867e040 100644 --- a/xrspatial/classify.py +++ b/xrspatial/classify.py @@ -25,6 +25,7 @@ class cupy(object): import numpy as np from xrspatial.utils import ArrayTypeFunctionMapping, cuda_args, ngjit, not_implemented_func +from xrspatial.dataset_support import supports_dataset @ngjit @@ -83,6 +84,7 @@ def _run_dask_cupy_binary(data, values_cupy): return out +@supports_dataset def binary(agg, values, name='binary'): """ Binarize a data array based on a set of values. Data that equals to a value in the set will be @@ -91,7 +93,7 @@ def binary(agg, values, name='binary'): Parameters ---------- - agg : xarray.DataArray + agg : xr.DataArray or xr.Dataset 2D NumPy, CuPy, NumPy-backed Dask, or Cupy-backed Dask array of values to be reclassified. values : array-like object @@ -101,9 +103,11 @@ def binary(agg, values, name='binary'): Returns ------- - binarized_agg : xarray.DataArray, of the same type as `agg` + binarized_agg : xr.DataArray or xr.Dataset 2D aggregate array of binarized data array. All other input attributes are preserved. + If `agg` is a Dataset, returns a Dataset with each variable + classified independently. Examples -------- @@ -266,6 +270,7 @@ def _bin(agg, bins, new_values): return out +@supports_dataset def reclassify(agg: xr.DataArray, bins: List[int], new_values: List[int], @@ -276,7 +281,7 @@ def reclassify(agg: xr.DataArray, Parameters ---------- - agg : xarray.DataArray + agg : xr.DataArray or xr.Dataset 2D NumPy, CuPy, NumPy-backed Dask, or Cupy-backed Dask array of values to be reclassified. bins : array-like object @@ -288,9 +293,11 @@ def reclassify(agg: xr.DataArray, Returns ------- - reclass_agg : xarray.DataArray, of the same type as `agg` + reclass_agg : xr.DataArray or xr.Dataset 2D aggregate array of reclassified allocations. All other input attributes are preserved. + If `agg` is a Dataset, returns a Dataset with each variable + classified independently. 
References ---------- @@ -416,6 +423,7 @@ def _quantile(agg, k): return out +@supports_dataset def quantile(agg: xr.DataArray, k: int = 4, name: Optional[str] = 'quantile') -> xr.DataArray: @@ -425,7 +433,7 @@ def quantile(agg: xr.DataArray, Parameters ---------- - agg : xarray.DataArray + agg : xr.DataArray or xr.Dataset 2D NumPy, CuPy, NumPy-backed Dask, or Cupy-backed Dask array of values to be reclassified. k : int, default=4 @@ -435,9 +443,11 @@ def quantile(agg: xr.DataArray, Returns ------- - quantile_agg : xarray.DataArray, of the same type as `agg` + quantile_agg : xr.DataArray or xr.Dataset 2D aggregate array, of quantile allocations. All other input attributes are preserved. + If `agg` is a Dataset, returns a Dataset with each variable + classified independently. Notes ----- @@ -723,6 +733,7 @@ def _run_dask_cupy_natural_break(agg, num_sample, k): return out +@supports_dataset def natural_breaks(agg: xr.DataArray, num_sample: Optional[int] = 20000, name: Optional[str] = 'natural_breaks', @@ -735,7 +746,7 @@ def natural_breaks(agg: xr.DataArray, Parameters ---------- - agg : xarray.DataArray + agg : xr.DataArray or xr.Dataset 2D NumPy, CuPy, NumPy-backed Dask, or CuPy-backed Dask array of values to be reclassified. num_sample : int, default=20000 @@ -751,9 +762,11 @@ def natural_breaks(agg: xr.DataArray, Returns ------- - natural_breaks_agg : xarray.DataArray of the same type as `agg` + natural_breaks_agg : xr.DataArray or xr.Dataset 2D aggregate array of natural break allocations. All other input attributes are preserved. + If `agg` is a Dataset, returns a Dataset with each variable + classified independently. 
References ---------- @@ -854,6 +867,7 @@ def _run_equal_interval(agg, k, module): return out +@supports_dataset def equal_interval(agg: xr.DataArray, k: int = 5, name: Optional[str] = 'equal_interval') -> xr.DataArray: @@ -863,7 +877,7 @@ def equal_interval(agg: xr.DataArray, Parameters ---------- - agg : xarray.DataArray + agg : xr.DataArray or xr.Dataset 2D NumPy, CuPy, NumPy-backed Dask, or Cupy-backed Dask array of values to be reclassified. k : int, default=5 @@ -873,9 +887,11 @@ def equal_interval(agg: xr.DataArray, Returns ------- - equal_interval_agg : xarray.DataArray of the same type as `agg` + equal_interval_agg : xr.DataArray or xr.Dataset 2D aggregate array of equal interval allocations. All other input attributes are preserved. + If `agg` is a Dataset, returns a Dataset with each variable + classified independently. References ---------- @@ -952,6 +968,7 @@ def _run_std_mean(agg, module): return out +@supports_dataset def std_mean(agg: xr.DataArray, name: Optional[str] = 'std_mean') -> xr.DataArray: """ @@ -961,7 +978,7 @@ def std_mean(agg: xr.DataArray, Parameters ---------- - agg : xarray.DataArray + agg : xr.DataArray or xr.Dataset 2D NumPy, CuPy, NumPy-backed Dask, or CuPy-backed Dask array of values to be classified. name : str, default='std_mean' @@ -969,9 +986,11 @@ def std_mean(agg: xr.DataArray, Returns ------- - std_mean_agg : xarray.DataArray, of the same type as `agg` + std_mean_agg : xr.DataArray or xr.Dataset 2D aggregate array of standard deviation classifications. All other input attributes are preserved. + If `agg` is a Dataset, returns a Dataset with each variable + classified independently. 
References ---------- @@ -1044,6 +1063,7 @@ def _run_dask_head_tail_breaks(agg): return out +@supports_dataset def head_tail_breaks(agg: xr.DataArray, name: Optional[str] = 'head_tail_breaks') -> xr.DataArray: """ @@ -1055,7 +1075,7 @@ def head_tail_breaks(agg: xr.DataArray, Parameters ---------- - agg : xarray.DataArray + agg : xr.DataArray or xr.Dataset 2D NumPy, CuPy, NumPy-backed Dask, or CuPy-backed Dask array of values to be classified. name : str, default='head_tail_breaks' @@ -1063,9 +1083,11 @@ def head_tail_breaks(agg: xr.DataArray, Returns ------- - head_tail_agg : xarray.DataArray, of the same type as `agg` + head_tail_agg : xr.DataArray or xr.Dataset 2D aggregate array of head/tail break classifications. All other input attributes are preserved. + If `agg` is a Dataset, returns a Dataset with each variable + classified independently. References ---------- @@ -1096,6 +1118,7 @@ def _run_dask_cupy_percentiles(data, pct): return _run_percentiles(data_cpu, pct, da) +@supports_dataset def percentiles(agg: xr.DataArray, pct: Optional[List] = None, name: Optional[str] = 'percentiles') -> xr.DataArray: @@ -1104,7 +1127,7 @@ def percentiles(agg: xr.DataArray, Parameters ---------- - agg : xarray.DataArray + agg : xr.DataArray or xr.Dataset 2D NumPy, CuPy, NumPy-backed Dask, or CuPy-backed Dask array of values to be classified. pct : list of float, default=[1, 10, 50, 90, 99] @@ -1114,9 +1137,11 @@ def percentiles(agg: xr.DataArray, Returns ------- - percentiles_agg : xarray.DataArray, of the same type as `agg` + percentiles_agg : xr.DataArray or xr.Dataset 2D aggregate array of percentile classifications. All other input attributes are preserved. + If `agg` is a Dataset, returns a Dataset with each variable + classified independently. 
References ---------- @@ -1212,6 +1237,7 @@ def _run_dask_cupy_maximum_breaks(agg, k): return out +@supports_dataset def maximum_breaks(agg: xr.DataArray, k: int = 5, name: Optional[str] = 'maximum_breaks') -> xr.DataArray: @@ -1223,7 +1249,7 @@ def maximum_breaks(agg: xr.DataArray, Parameters ---------- - agg : xarray.DataArray + agg : xr.DataArray or xr.Dataset 2D NumPy, CuPy, NumPy-backed Dask, or CuPy-backed Dask array of values to be classified. k : int, default=5 @@ -1233,9 +1259,11 @@ def maximum_breaks(agg: xr.DataArray, Returns ------- - max_breaks_agg : xarray.DataArray, of the same type as `agg` + max_breaks_agg : xr.DataArray or xr.Dataset 2D aggregate array of maximum break classifications. All other input attributes are preserved. + If `agg` is a Dataset, returns a Dataset with each variable + classified independently. References ---------- @@ -1312,6 +1340,7 @@ def _run_dask_cupy_box_plot(agg, hinge): return out +@supports_dataset def box_plot(agg: xr.DataArray, hinge: float = 1.5, name: Optional[str] = 'box_plot') -> xr.DataArray: @@ -1323,7 +1352,7 @@ def box_plot(agg: xr.DataArray, Parameters ---------- - agg : xarray.DataArray + agg : xr.DataArray or xr.Dataset 2D NumPy, CuPy, NumPy-backed Dask, or CuPy-backed Dask array of values to be classified. hinge : float, default=1.5 @@ -1333,9 +1362,11 @@ def box_plot(agg: xr.DataArray, Returns ------- - box_plot_agg : xarray.DataArray, of the same type as `agg` + box_plot_agg : xr.DataArray or xr.Dataset 2D aggregate array of box plot classifications. All other input attributes are preserved. + If `agg` is a Dataset, returns a Dataset with each variable + classified independently. 
References ---------- diff --git a/xrspatial/curvature.py b/xrspatial/curvature.py index 97eff8d0..f460b925 100644 --- a/xrspatial/curvature.py +++ b/xrspatial/curvature.py @@ -25,6 +25,7 @@ class cupy(object): from xrspatial.utils import cuda_args from xrspatial.utils import get_dataarray_resolution from xrspatial.utils import ngjit +from xrspatial.dataset_support import supports_dataset @ngjit @@ -107,6 +108,7 @@ def _run_dask_cupy(data: da.Array, return out +@supports_dataset def curvature(agg: xr.DataArray, name: Optional[str] = 'curvature') -> xr.DataArray: """ @@ -121,15 +123,20 @@ def curvature(agg: xr.DataArray, Parameters ---------- - agg : xarray.DataArray + agg : xarray.DataArray or xr.Dataset 2D NumPy, CuPy, NumPy-backed Dask xarray DataArray of elevation values. Must contain `res` attribute. + If a Dataset is passed, the operation is applied to each + data variable independently. name : str, default='curvature' Name of output DataArray. Returns ------- - curvature_agg : xarray.DataArray, of the same type as `agg` + curvature_agg : xarray.DataArray or xr.Dataset + If `agg` is a DataArray, returns a DataArray of the same type. + If `agg` is a Dataset, returns a Dataset with curvature computed + for each data variable. 2D aggregate array of curvature values. All other input attributes are preserved. diff --git a/xrspatial/dataset_support.py b/xrspatial/dataset_support.py new file mode 100644 index 00000000..fc758d80 --- /dev/null +++ b/xrspatial/dataset_support.py @@ -0,0 +1,80 @@ +"""Decorators for transparent xr.Dataset support on xr.DataArray functions.""" + +from __future__ import annotations + +import functools +import inspect + +import xarray as xr + + +def supports_dataset(func): + """Decorator that lets single-input DataArray functions accept a Dataset. + + When a Dataset is passed as the first argument, the wrapped function + is called on each data variable and the results are collected into + a new Dataset. 
+ """ + sig = inspect.signature(func) + has_name_param = 'name' in sig.parameters + + @functools.wraps(func) + def wrapper(agg, *args, **kwargs): + if isinstance(agg, xr.Dataset): + results = {} + for var_name in agg.data_vars: + kw = dict(kwargs) + if has_name_param: + kw['name'] = var_name + results[var_name] = func(agg[var_name], *args, **kw) + return xr.Dataset(results, attrs=agg.attrs) + return func(agg, *args, **kwargs) + + return wrapper + + +def supports_dataset_bands(**band_param_map): + """Decorator for multi-input functions that take separate band DataArrays. + + Enables passing a single Dataset with keyword arguments that map + band aliases to Dataset variable names. + + Example:: + + @supports_dataset_bands(nir='nir_agg', red='red_agg') + def ndvi(nir_agg, red_agg, name='ndvi'): ... + + # Enables: + ndvi(ds, nir='band_8', red='band_4') + """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if args and isinstance(args[0], xr.Dataset): + ds = args[0] + func_kwargs = {} + used = set() + for alias, param in band_param_map.items(): + if alias not in kwargs: + raise TypeError( + f"'{alias}' keyword required when passing a Dataset" + ) + var_name = kwargs[alias] + if var_name not in ds.data_vars: + raise ValueError( + f"'{var_name}' not in Dataset. " + f"Available: {list(ds.data_vars)}" + ) + func_kwargs[param] = ds[var_name] + used.add(alias) + # Pass through remaining kwargs (name, soil_factor, etc.) 
+ for k, v in kwargs.items(): + if k not in used: + func_kwargs[k] = v + return func(**func_kwargs) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/xrspatial/focal.py b/xrspatial/focal.py index a198c9ac..5cadede3 100644 --- a/xrspatial/focal.py +++ b/xrspatial/focal.py @@ -29,6 +29,7 @@ class cupy(object): from xrspatial.convolution import convolve_2d, custom_kernel from xrspatial.utils import ArrayTypeFunctionMapping, cuda_args, ngjit, not_implemented_func +from xrspatial.dataset_support import supports_dataset # TODO: Make convolution more generic with numba first-class functions. @@ -158,6 +159,7 @@ def _mean(data, excludes): return out +@supports_dataset def mean(agg, passes=1, excludes=[np.nan], name='mean'): """ Returns Mean filtered array using a 3x3 window. @@ -165,8 +167,10 @@ def mean(agg, passes=1, excludes=[np.nan], name='mean'): Parameters ---------- - agg : xarray.DataArray + agg : xarray.DataArray or xr.Dataset 2D array of input values to be filtered. + If a Dataset is passed, the operation is applied to each + data variable independently. passes : int, default=1 Number of times to run mean. name : str, default='mean' @@ -174,7 +178,10 @@ def mean(agg, passes=1, excludes=[np.nan], name='mean'): Returns ------- - mean_agg : xarray.DataArray of same type as `agg` + mean_agg : xarray.DataArray or xr.Dataset + If `agg` is a DataArray, returns a DataArray of the same type. + If `agg` is a Dataset, returns a Dataset with mean computed + for each data variable. 2D aggregate array of filtered values. 
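The per-variable dispatch that the new `supports_dataset` decorator implements can be sketched in isolation with plain Python — a dict stands in for `xr.Dataset`, and `dataset_aware`/`fake_mean` are hypothetical names for illustration, not xarray-spatial API:

```python
# Minimal, dependency-free sketch of the supports_dataset dispatch: when
# a "Dataset" (here a plain dict of named arrays) is passed, the wrapped
# function runs once per variable and the results are collected under
# the same variable names. dataset_aware and fake_mean are hypothetical
# stand-ins, not xarray-spatial API.
import functools

def dataset_aware(func):
    @functools.wraps(func)
    def wrapper(agg, **kwargs):
        if isinstance(agg, dict):  # stand-in for isinstance(agg, xr.Dataset)
            return {name: func(var, **kwargs) for name, var in agg.items()}
        return func(agg, **kwargs)
    return wrapper

@dataset_aware
def fake_mean(values, passes=1):
    # Toy stand-in for focal.mean: repeatedly average toward the mean.
    for _ in range(passes):
        m = sum(values) / len(values)
        values = [(v + m) / 2 for v in values]
    return values

single = fake_mean([2.0, 4.0])                          # "DataArray" path
ds_out = fake_mean({"a": [2.0, 4.0], "b": [0.0, 8.0]})  # per-variable path
# single → [2.5, 3.5]; ds_out → {'a': [2.5, 3.5], 'b': [2.0, 6.0]}
```

The real decorator additionally inspects the wrapped signature so each result can be renamed after its source variable via the `name` parameter.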
Examples diff --git a/xrspatial/hillshade.py b/xrspatial/hillshade.py index 137be95a..a13cb8c9 100644 --- a/xrspatial/hillshade.py +++ b/xrspatial/hillshade.py @@ -14,6 +14,7 @@ from .gpu_rtx import has_rtx from .utils import calc_cuda_dims, has_cuda_and_cupy, is_cupy_array, is_cupy_backed +from .dataset_support import supports_dataset def _run_numpy(data, azimuth=225, angle_altitude=25): @@ -99,6 +100,7 @@ def _run_cupy(d_data, azimuth, angle_altitude): return output +@supports_dataset def hillshade(agg: xr.DataArray, azimuth: int = 225, angle_altitude: int = 25, @@ -111,9 +113,11 @@ def hillshade(agg: xr.DataArray, Parameters ---------- - agg : xarray.DataArray + agg : xarray.DataArray or xr.Dataset 2D NumPy, CuPy, NumPy-backed Dask, or Cupy-backed Dask array of elevation values. + If a Dataset is passed, the operation is applied to each + data variable independently. angle_altitude : int, default=25 Altitude angle of the sun specified in degrees. azimuth : int, default=225 @@ -129,7 +133,10 @@ def hillshade(agg: xr.DataArray, Returns ------- - hillshade_agg : xarray.DataArray, of same type as `agg` + hillshade_agg : xarray.DataArray or xr.Dataset + If `agg` is a DataArray, returns a DataArray of the same type. + If `agg` is a Dataset, returns a Dataset with hillshade computed + for each data variable. 2D aggregate array of illumination values. 
References diff --git a/xrspatial/multispectral.py b/xrspatial/multispectral.py index c0852866..a6db0e7c 100644 --- a/xrspatial/multispectral.py +++ b/xrspatial/multispectral.py @@ -11,6 +11,7 @@ from xrspatial.utils import (ArrayTypeFunctionMapping, cuda_args, ngjit, not_implemented_func, validate_arrays) +from xrspatial.dataset_support import supports_dataset_bands # 3rd-party try: @@ -75,6 +76,7 @@ def _arvi_dask_cupy(nir_data, red_data, blue_data): return out +@supports_dataset_bands(nir='nir_agg', red='red_agg', blue='blue_agg') def arvi(nir_agg: xr.DataArray, red_agg: xr.DataArray, blue_agg: xr.DataArray, @@ -95,6 +97,12 @@ def arvi(nir_agg: xr.DataArray, name : str, default='arvi' Name of output DataArray. + Alternatively, a single ``xr.Dataset`` may be passed as the first + argument with keyword arguments mapping band names to Dataset + variables. For example:: + + arvi(ds, nir='B8', red='B4', blue='B2') + Returns ------- arvi_agg : xarray.DataArray of the same type as inputs. @@ -215,6 +223,7 @@ def _evi_dask_cupy(nir_data, red_data, blue_data, c1, c2, soil_factor, gain): return out +@supports_dataset_bands(nir='nir_agg', red='red_agg', blue='blue_agg') def evi(nir_agg: xr.DataArray, red_agg: xr.DataArray, blue_agg: xr.DataArray, @@ -247,6 +256,12 @@ def evi(nir_agg: xr.DataArray, name : str, default='evi' Name of output DataArray. + Alternatively, a single ``xr.Dataset`` may be passed as the first + argument with keyword arguments mapping band names to Dataset + variables. For example:: + + evi(ds, nir='B8', red='B4', blue='B2') + Returns ------- evi_agg : xarray.DataArray of same type as inputs @@ -374,6 +389,7 @@ def _gci_dask_cupy(nir_data, green_data): return out +@supports_dataset_bands(nir='nir_agg', green='green_agg') def gci(nir_agg: xr.DataArray, green_agg: xr.DataArray, name='gci'): @@ -391,6 +407,12 @@ def gci(nir_agg: xr.DataArray, name : str, default='gci' Name of output DataArray. 
+ Alternatively, a single ``xr.Dataset`` may be passed as the first + argument with keyword arguments mapping band names to Dataset + variables. For example:: + + gci(ds, nir='B8', green='B3') + Returns ------- gci_agg : xarray.DataArray of the same type as inputs @@ -451,6 +473,7 @@ def gci(nir_agg: xr.DataArray, # NBR ---------- +@supports_dataset_bands(nir='nir_agg', swir2='swir2_agg') def nbr(nir_agg: xr.DataArray, swir2_agg: xr.DataArray, name='nbr'): @@ -469,6 +492,12 @@ def nbr(nir_agg: xr.DataArray, name : str, default='nbr' Name of output DataArray. + Alternatively, a single ``xr.Dataset`` may be passed as the first + argument with keyword arguments mapping band names to Dataset + variables. For example:: + + nbr(ds, nir='B8', swir2='B12') + Returns ------- nbr_agg : xr.DataArray of the same type as inputs @@ -529,6 +558,7 @@ def nbr(nir_agg: xr.DataArray, attrs=nir_agg.attrs) +@supports_dataset_bands(swir1='swir1_agg', swir2='swir2_agg') def nbr2(swir1_agg: xr.DataArray, swir2_agg: xr.DataArray, name='nbr2'): @@ -552,6 +582,12 @@ def nbr2(swir1_agg: xr.DataArray, name : str default='nbr2' Name of output DataArray. + Alternatively, a single ``xr.Dataset`` may be passed as the first + argument with keyword arguments mapping band names to Dataset + variables. For example:: + + nbr2(ds, swir1='B11', swir2='B12') + Returns ------- nbr2_agg : xr.DataArray of same type as inputs. @@ -614,6 +650,7 @@ def nbr2(swir1_agg: xr.DataArray, # NDVI ---------- +@supports_dataset_bands(nir='nir_agg', red='red_agg') def ndvi(nir_agg: xr.DataArray, red_agg: xr.DataArray, name='ndvi'): @@ -630,6 +667,12 @@ def ndvi(nir_agg: xr.DataArray, name : str default='ndvi' Name of output DataArray. + Alternatively, a single ``xr.Dataset`` may be passed as the first + argument with keyword arguments mapping band names to Dataset + variables. 
For example:: + + ndvi(ds, nir='B8', red='B4') + Returns ------- ndvi_agg : xarray.DataArray of same type as inputs @@ -691,6 +734,7 @@ def ndvi(nir_agg: xr.DataArray, # NDMI ---------- +@supports_dataset_bands(nir='nir_agg', swir1='swir1_agg') def ndmi(nir_agg: xr.DataArray, swir1_agg: xr.DataArray, name='ndmi'): @@ -711,6 +755,12 @@ def ndmi(nir_agg: xr.DataArray, name: str, default='ndmi' Name of output DataArray. + Alternatively, a single ``xr.Dataset`` may be passed as the first + argument with keyword arguments mapping band names to Dataset + variables. For example:: + + ndmi(ds, nir='B8', swir1='B11') + Returns ------- ndmi_agg : xr.DataArray of same type as inputs @@ -874,6 +924,7 @@ def _savi_dask_cupy(nir_data, red_data, soil_factor): # SAVI ---------- +@supports_dataset_bands(nir='nir_agg', red='red_agg') def savi(nir_agg: xr.DataArray, red_agg: xr.DataArray, soil_factor: float = 1.0, @@ -895,6 +946,12 @@ def savi(nir_agg: xr.DataArray, name : str, default='savi' Name of output DataArray. + Alternatively, a single ``xr.Dataset`` may be passed as the first + argument with keyword arguments mapping band names to Dataset + variables. For example:: + + savi(ds, nir='B8', red='B4') + Returns ------- savi_agg : xr.DataArray of same type as inputs @@ -1006,6 +1063,7 @@ def _sipi_dask_cupy(nir_data, red_data, blue_data): return out +@supports_dataset_bands(nir='nir_agg', red='red_agg', blue='blue_agg') def sipi(nir_agg: xr.DataArray, red_agg: xr.DataArray, blue_agg: xr.DataArray, @@ -1025,6 +1083,12 @@ def sipi(nir_agg: xr.DataArray, name: str, default='sipi' Name of output DataArray. + Alternatively, a single ``xr.Dataset`` may be passed as the first + argument with keyword arguments mapping band names to Dataset + variables. 
For example:: + + sipi(ds, nir='B8', red='B4', blue='B2') + Returns ------- sipi_agg : xr.DataArray of same type as inputs @@ -1142,6 +1206,7 @@ def _ebbi_dask_cupy(red_data, swir_data, tir_data): return out +@supports_dataset_bands(red='red_agg', swir='swir_agg', tir='tir_agg') def ebbi(red_agg: xr.DataArray, swir_agg: xr.DataArray, tir_agg: xr.DataArray, @@ -1161,6 +1226,12 @@ def ebbi(red_agg: xr.DataArray, name: str, default='ebbi' Name of output DataArray. + Alternatively, a single ``xr.Dataset`` may be passed as the first + argument with keyword arguments mapping band names to Dataset + variables. For example:: + + ebbi(ds, red='B4', swir='B11', tir='B10') + Returns ------- ebbi_agg = xr.DataArray of same type as inputs diff --git a/xrspatial/proximity.py b/xrspatial/proximity.py index 4960df2a..44602308 100644 --- a/xrspatial/proximity.py +++ b/xrspatial/proximity.py @@ -10,6 +10,7 @@ from numba import prange from xrspatial.utils import get_dataarray_resolution, ngjit +from xrspatial.dataset_support import supports_dataset EUCLIDEAN = 0 GREAT_CIRCLE = 1 @@ -648,6 +649,7 @@ def _process_dask(raster, xs, ys): # ported from # https://github.com/OSGeo/gdal/blob/master/gdal/alg/gdalproximity.cpp +@supports_dataset def proximity( raster: xr.DataArray, x: str = "x", @@ -682,8 +684,10 @@ def proximity( Parameters ---------- - raster : xr.DataArray + raster : xr.DataArray or xr.Dataset 2D array image with `raster.shape` = (height, width). + If a Dataset is passed, the function is applied to each + data variable independently, returning a Dataset. x : str, default='x' Name of x-coordinates. @@ -722,7 +726,10 @@ def proximity( Returns ------- - proximity_agg: xr.DataArray of same type as `raster` + xr.DataArray or xr.Dataset + If ``raster`` is a DataArray, returns a DataArray. + If ``raster`` is a Dataset, returns a Dataset with each + variable processed independently. 2D array of proximity values. All other input attributes are preserved. 
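The multispectral hunks above all attach `@supports_dataset_bands(...)`, but the diff only shows the tail of `xrspatial/dataset_support.py` (the kwarg-forwarding loop visible at the top of the focal.py hunk). Based on that tail and on the error messages asserted in `test_dataset_support.py` (`"'red' keyword required"`, `"'nonexistent' not in Dataset"`), the decorator might look roughly like this — an illustrative sketch, not the actual implementation:

```python
import functools

import xarray as xr


def supports_dataset_bands(**band_map):
    """Sketch: map Dataset variables onto a function's band parameters.

    ``band_map`` maps user-facing band aliases (e.g. ``nir``) to the
    wrapped function's parameter names (e.g. ``nir_agg``). Illustrative
    reconstruction only; names and messages are assumptions.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if args and isinstance(args[0], xr.Dataset):
                ds = args[0]
                func_kwargs = {}
                used = set()
                for alias, param in band_map.items():
                    if alias not in kwargs:
                        raise TypeError(
                            f"'{alias}' keyword required when passing a Dataset"
                        )
                    var_name = kwargs[alias]
                    if var_name not in ds.data_vars:
                        raise ValueError(f"'{var_name}' not in Dataset")
                    func_kwargs[param] = ds[var_name]
                    used.add(alias)
                # Forward remaining kwargs untouched (e.g. soil_factor, gain).
                for k, v in kwargs.items():
                    if k not in used:
                        func_kwargs[k] = v
                return func(**func_kwargs)
            # DataArray path: unchanged passthrough.
            return func(*args, **kwargs)
        return wrapper
    return decorator
```

Note the two distinct failure modes: a missing band keyword is a `TypeError` (the call is malformed), while a band keyword naming a variable absent from the Dataset is a `ValueError` (the call is well-formed but the data doesn't match), matching the tests below.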
@@ -783,6 +790,7 @@ def proximity( return result +@supports_dataset def allocation( raster: xr.DataArray, x: str = "x", @@ -816,8 +824,10 @@ def allocation( Parameters ---------- - raster : xr.DataArray + raster : xr.DataArray or xr.Dataset 2D array of target data. + If a Dataset is passed, the function is applied to each + data variable independently, returning a Dataset. x : str, default='x' Name of x-coordinates. @@ -855,7 +865,10 @@ def allocation( Returns ------- - allocation_agg: xr.DataArray of same type as `raster` + xr.DataArray or xr.Dataset + If ``raster`` is a DataArray, returns a DataArray. + If ``raster`` is a Dataset, returns a Dataset with each + variable processed independently. 2D array of allocation values. All other input attributes are preserved. @@ -915,6 +928,7 @@ def allocation( return result +@supports_dataset def direction( raster: xr.DataArray, x: str = "x", @@ -952,8 +966,10 @@ def direction( Parameters ---------- - raster : xr.DataArray + raster : xr.DataArray or xr.Dataset 2D array image with `raster.shape` = (height, width). + If a Dataset is passed, the function is applied to each + data variable independently, returning a Dataset. x : str, default='x' Name of x-coordinates. @@ -992,7 +1008,10 @@ def direction( Returns ------- - direction_agg: xr.DataArray of same type as `raster` + xr.DataArray or xr.Dataset + If ``raster`` is a DataArray, returns a DataArray. + If ``raster`` is a Dataset, returns a Dataset with each + variable processed independently. 2D array of direction values. All other input attributes are preserved. 
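The per-variable behavior documented for `proximity`, `allocation`, and `direction` ("applied to each data variable independently, returning a Dataset") is essentially the pattern `xr.Dataset.map` already provides. A minimal standalone illustration of that equivalence, using a toy function in place of the xrspatial operations:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        'dem1': xr.DataArray(np.arange(9.0).reshape(3, 3), dims=['y', 'x']),
        'dem2': xr.DataArray(np.ones((3, 3)), dims=['y', 'x']),
    },
    attrs={'source': 'example'},
)

# Applying a DataArray -> DataArray function to every data variable,
# with Dataset-level attrs preserved, mirrors what the decorated
# functions do when handed a Dataset.
doubled = ds.map(lambda da: da * 2, keep_attrs=True)

assert set(doubled.data_vars) == {'dem1', 'dem2'}
assert doubled.attrs == ds.attrs
```

The decorator approach is still preferable in the library itself, since it keeps a single public entry point and a clean `DataArray` passthrough.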
diff --git a/xrspatial/slope.py b/xrspatial/slope.py index 753d905c..0f57baca 100644 --- a/xrspatial/slope.py +++ b/xrspatial/slope.py @@ -28,6 +28,7 @@ class cupy(object): from xrspatial.utils import cuda_args from xrspatial.utils import get_dataarray_resolution from xrspatial.utils import ngjit +from xrspatial.dataset_support import supports_dataset def _geodesic_cuda_dims(shape): @@ -267,6 +268,7 @@ def _run_dask_cupy_geodesic(data, lat_2d, lon_2d, a2, b2, z_factor): # Public API # ===================================================================== +@supports_dataset def slope(agg: xr.DataArray, name: str = 'slope', method: str = 'planar', @@ -276,8 +278,10 @@ def slope(agg: xr.DataArray, Parameters ---------- - agg : xr.DataArray + agg : xr.DataArray or xr.Dataset 2D array of elevation data. + If a Dataset is passed, the operation is applied to each + data variable independently. name : str, default='slope' Name of output DataArray. method : str, default='planar' @@ -292,7 +296,10 @@ def slope(agg: xr.DataArray, Returns ------- - slope_agg : xr.DataArray of same type as `agg` + slope_agg : xr.DataArray or xr.Dataset + If `agg` is a DataArray, returns a DataArray of the same type. + If `agg` is a Dataset, returns a Dataset with slope computed + for each data variable. 2D array of slope values. All other input attributes are preserved. 
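The single-input decorator `supports_dataset` applied throughout (slope, hillshade, focal.mean, proximity, ...) is likewise not shown in full in this diff. Given the documented contract — iterate over data variables, return a Dataset with the same variable names, preserve Dataset-level attrs, and leave the DataArray path untouched — a minimal sketch under those assumptions:

```python
import functools

import xarray as xr


def supports_dataset(func):
    """Sketch of the single-input Dataset decorator (illustrative only).

    If the first argument is a Dataset, apply ``func`` to each data
    variable and rebuild a Dataset with the same variable names and
    the original Dataset-level attrs; otherwise pass through unchanged.
    """
    @functools.wraps(func)
    def wrapper(agg, *args, **kwargs):
        if isinstance(agg, xr.Dataset):
            results = {
                var: func(agg[var], *args, **kwargs)
                for var in agg.data_vars
            }
            return xr.Dataset(results, attrs=agg.attrs)
        return func(agg, *args, **kwargs)
    return wrapper
```

Because variables are processed eagerly in order, an error in any one variable propagates immediately — the behavior `test_dataset_error_propagation` relies on.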
diff --git a/xrspatial/tests/test_dataset_support.py b/xrspatial/tests/test_dataset_support.py new file mode 100644 index 00000000..06b305e0 --- /dev/null +++ b/xrspatial/tests/test_dataset_support.py @@ -0,0 +1,211 @@ +"""Tests for xr.Dataset support (issue #134).""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from xrspatial import slope, aspect +from xrspatial.classify import quantile +from xrspatial.focal import mean as focal_mean +from xrspatial.multispectral import ndvi, evi +from xrspatial.zonal import stats as zonal_stats + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def elevation_dataset(): + """Dataset with two elevation-like variables.""" + np.random.seed(42) + y = np.linspace(0, 1, 20) + x = np.linspace(0, 1, 20) + dem1 = xr.DataArray( + np.random.rand(20, 20).astype(np.float64) * 1000, + dims=['y', 'x'], coords={'y': y, 'x': x}, + attrs={'res': (y[1] - y[0], x[1] - x[0])}, + ) + dem2 = xr.DataArray( + np.random.rand(20, 20).astype(np.float64) * 500, + dims=['y', 'x'], coords={'y': y, 'x': x}, + attrs={'res': (y[1] - y[0], x[1] - x[0])}, + ) + return xr.Dataset({'dem1': dem1, 'dem2': dem2}, attrs={'source': 'test'}) + + +@pytest.fixture +def spectral_dataset(): + """Dataset mimicking multi-band satellite imagery.""" + np.random.seed(123) + data = lambda: np.random.rand(30, 30).astype(np.float64) * 0.5 + 0.1 + dims = ['y', 'x'] + return xr.Dataset({ + 'nir': xr.DataArray(data(), dims=dims), + 'red': xr.DataArray(data(), dims=dims), + 'blue': xr.DataArray(data(), dims=dims), + }) + + +@pytest.fixture +def zones_and_values(): + """Zones raster + values Dataset for zonal stats tests.""" + np.random.seed(7) + zones_data = np.zeros((10, 10), dtype=np.float64) + zones_data[:5, :] = 1.0 + zones_data[5:, :] = 2.0 + zones = xr.DataArray(zones_data, dims=['y', 'x']) + + vals_a = 
np.random.rand(10, 10).astype(np.float64) * 100 + vals_b = np.random.rand(10, 10).astype(np.float64) * 50 + ds = xr.Dataset({ + 'elevation': xr.DataArray(vals_a, dims=['y', 'x']), + 'temperature': xr.DataArray(vals_b, dims=['y', 'x']), + }) + return zones, ds + + +# =================================================================== +# A. Single-input decorator (supports_dataset) +# =================================================================== + +class TestSupportsDataset: + + def test_slope_dataset_returns_dataset(self, elevation_dataset): + result = slope(elevation_dataset) + assert isinstance(result, xr.Dataset) + assert set(result.data_vars) == {'dem1', 'dem2'} + + def test_slope_dataset_matches_individual(self, elevation_dataset): + ds = elevation_dataset + result = slope(ds) + for var in ds.data_vars: + expected = slope(ds[var]) + xr.testing.assert_allclose(result[var], expected) + + def test_slope_dataset_preserves_attrs(self, elevation_dataset): + result = slope(elevation_dataset) + assert result.attrs == elevation_dataset.attrs + + def test_slope_dataarray_unchanged(self, elevation_dataset): + """Existing DataArray path is a passthrough.""" + da = elevation_dataset['dem1'] + result = slope(da) + assert isinstance(result, xr.DataArray) + + def test_classify_quantile_dataset(self, elevation_dataset): + result = quantile(elevation_dataset, k=4) + assert isinstance(result, xr.Dataset) + assert set(result.data_vars) == {'dem1', 'dem2'} + for var in result.data_vars: + expected = quantile(elevation_dataset[var], k=4) + xr.testing.assert_equal(result[var], expected) + + def test_single_var_dataset(self, elevation_dataset): + ds = elevation_dataset[['dem1']] + result = slope(ds) + assert isinstance(result, xr.Dataset) + assert set(result.data_vars) == {'dem1'} + + def test_dataset_error_propagation(self): + """Bad data in one variable raises immediately.""" + ds = xr.Dataset({ + 'ok': xr.DataArray(np.random.rand(5, 5), dims=['y', 'x']), + 'bad': 
xr.DataArray(np.array(['a', 'b', 'c', 'd', 'e']), dims=['z']), + }) + with pytest.raises(Exception): + slope(ds) + + def test_aspect_dataset(self, elevation_dataset): + result = aspect(elevation_dataset) + assert isinstance(result, xr.Dataset) + for var in elevation_dataset.data_vars: + expected = aspect(elevation_dataset[var]) + xr.testing.assert_allclose(result[var], expected) + + def test_focal_mean_dataset(self, elevation_dataset): + result = focal_mean(elevation_dataset) + assert isinstance(result, xr.Dataset) + for var in elevation_dataset.data_vars: + expected = focal_mean(elevation_dataset[var]) + xr.testing.assert_allclose(result[var], expected) + + +# =================================================================== +# B. Multi-input decorator (supports_dataset_bands) +# =================================================================== + +class TestSupportsDatasetBands: + + def test_ndvi_dataset_band_kwargs(self, spectral_dataset): + result = ndvi(spectral_dataset, nir='nir', red='red') + assert isinstance(result, xr.DataArray) + + def test_ndvi_dataset_matches_individual(self, spectral_dataset): + ds = spectral_dataset + from_ds = ndvi(ds, nir='nir', red='red') + from_da = ndvi(ds['nir'], ds['red']) + xr.testing.assert_allclose(from_ds, from_da) + + def test_ndvi_missing_band_kwarg(self, spectral_dataset): + with pytest.raises(TypeError, match="'red' keyword required"): + ndvi(spectral_dataset, nir='nir') # missing red= + + def test_ndvi_invalid_var_name(self, spectral_dataset): + with pytest.raises(ValueError, match="'nonexistent' not in Dataset"): + ndvi(spectral_dataset, nir='nonexistent', red='red') + + def test_evi_extra_kwargs_passthrough(self, spectral_dataset): + """Extra kwargs like soil_factor, gain are passed through.""" + ds = spectral_dataset + result = evi(ds, nir='nir', red='red', blue='blue', + soil_factor=0.5, gain=3.0) + assert isinstance(result, xr.DataArray) + expected = evi(ds['nir'], ds['red'], ds['blue'], + soil_factor=0.5, 
gain=3.0) + xr.testing.assert_allclose(result, expected) + + def test_ndvi_dataarray_unchanged(self, spectral_dataset): + """Existing positional DataArray call still works.""" + ds = spectral_dataset + result = ndvi(ds['nir'], ds['red']) + assert isinstance(result, xr.DataArray) + + +# =================================================================== +# C. Zonal stats Dataset +# =================================================================== + +class TestZonalStatsDataset: + + def test_zonal_stats_dataset_column_naming(self, zones_and_values): + zones, ds = zones_and_values + result = zonal_stats(zones, ds) + assert isinstance(result, pd.DataFrame) + assert 'zone' in result.columns + # Check prefixed columns exist + for var in ds.data_vars: + assert f'{var}_mean' in result.columns + assert f'{var}_max' in result.columns + + def test_zonal_stats_dataset_matches_individual(self, zones_and_values): + zones, ds = zones_and_values + merged = zonal_stats(zones, ds) + for var in ds.data_vars: + individual = zonal_stats(zones, ds[var]) + for stat_col in individual.columns: + if stat_col == 'zone': + continue + prefixed = f'{var}_{stat_col}' + assert prefixed in merged.columns + pd.testing.assert_series_equal( + merged[prefixed].reset_index(drop=True), + individual[stat_col].reset_index(drop=True), + check_names=False, + ) + + def test_zonal_stats_dataset_return_type_error(self, zones_and_values): + zones, ds = zones_and_values + with pytest.raises(ValueError, match="return_type must be 'pandas.DataFrame'"): + zonal_stats(zones, ds, return_type='xarray.DataArray') diff --git a/xrspatial/zonal.py b/xrspatial/zonal.py index 614c5e1f..c04d4d52 100644 --- a/xrspatial/zonal.py +++ b/xrspatial/zonal.py @@ -455,12 +455,15 @@ def stats( the shape, values, and locations of the zones. An integer field in the input `zones` DataArray defines a zone. 
- values : xr.DataArray + values : xr.DataArray or xr.Dataset values is a 2D xarray DataArray of numeric values (integers or floats). The input `values` raster contains the input values used in calculating the output statistic for each zone. In dask case, the chunksizes of `zones` and `values` should be matching. If not, `values` will be rechunked to be the same as of `zones`. + When a Dataset is passed, stats are computed for each variable + and columns are prefixed with the variable name + (e.g. ``elevation_mean``). zone_ids : list of ints, or floats List of zones to be included in calculation. If no zone_ids provided, @@ -492,6 +495,10 @@ def stats( stats_df : Union[pandas.DataFrame, dask.dataframe.DataFrame] A pandas DataFrame, or a dask DataFrame where each column is a statistic and each row is a zone with zone id. + When ``values`` is a Dataset, the returned DataFrame has + columns prefixed by the variable name (e.g. ``elevation_mean``, + ``elevation_max``), and ``return_type`` must be + ``'pandas.DataFrame'``. 
Examples -------- @@ -566,6 +573,27 @@ def stats( 3 30 77.0 99 55 1925 14.21267 202.0 25 """ + # Dataset support: run stats per variable and merge into one DataFrame + if isinstance(values, xr.Dataset): + if return_type != 'pandas.DataFrame': + raise ValueError( + "return_type must be 'pandas.DataFrame' when values is a Dataset" + ) + dfs = [] + for var_name in values.data_vars: + df = stats( + zones, values[var_name], zone_ids, stats_funcs, + nodata_values, 'pandas.DataFrame', + ) + df = df.rename( + columns={c: f'{var_name}_{c}' for c in df.columns if c != 'zone'} + ) + dfs.append(df) + result = dfs[0] + for df in dfs[1:]: + result = result.merge(df, on='zone', how='outer') + return result + validate_arrays(zones, values) if not ( @@ -591,12 +619,12 @@ def stats( if isinstance(stats_funcs, list): # create a dict of stats stats_funcs_dict = {} - for stats in stats_funcs: - func = _DEFAULT_STATS.get(stats, None) + for stat_name in stats_funcs: + func = _DEFAULT_STATS.get(stat_name, None) if func is None: - err_str = f"Invalid stat name. {stats} option not supported." + err_str = f"Invalid stat name. {stat_name} option not supported." raise ValueError(err_str) - stats_funcs_dict[stats] = func + stats_funcs_dict[stat_name] = func elif isinstance(stats_funcs, dict): stats_funcs_dict = stats_funcs.copy()
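The Dataset branch added to `zonal.stats` above renames each per-variable frame's columns with a `{var_name}_` prefix (leaving `zone` untouched) and outer-merges the frames on `zone`. A self-contained pandas sketch of exactly that merge logic, using hypothetical per-variable frames in place of real `zonal.stats` output:

```python
import pandas as pd

# Hypothetical per-variable stats frames, shaped like what zonal.stats
# returns for a single DataArray (one row per zone, one column per stat).
elevation = pd.DataFrame({'zone': [1, 2], 'mean': [10.0, 20.0], 'max': [15.0, 25.0]})
temperature = pd.DataFrame({'zone': [1, 2], 'mean': [1.0, 2.0], 'max': [1.5, 2.5]})

dfs = []
for var_name, df in [('elevation', elevation), ('temperature', temperature)]:
    # Prefix every stat column with the variable name; keep 'zone' as-is.
    dfs.append(df.rename(
        columns={c: f'{var_name}_{c}' for c in df.columns if c != 'zone'}
    ))

# Outer-merge on 'zone' so a zone missing from one variable still appears.
result = dfs[0]
for df in dfs[1:]:
    result = result.merge(df, on='zone', how='outer')

print(list(result.columns))
# → ['zone', 'elevation_mean', 'elevation_max', 'temperature_mean', 'temperature_max']
```

The outer merge matters when `nodata_values` removes a zone from one variable but not another: the merged row survives with `NaN` in the missing variable's columns rather than dropping the zone entirely.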