diff --git a/CHANGELOG.md b/CHANGELOG.md index 61cde1ddfef..25b58b7bc56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum * Added implementation of `dpnp.divmod` [#2674](https://github.com/IntelPython/dpnp/pull/2674) * Added implementation of `dpnp.isin` function [#2595](https://github.com/IntelPython/dpnp/pull/2595) * Added implementation of `dpnp.scipy.linalg.lu` (SciPy-compatible) [#2787](https://github.com/IntelPython/dpnp/pull/2787) +* Added support for ndarray subclassing via `dpnp.ndarray.view` method with `type` parameter [#2815](https://github.com/IntelPython/dpnp/issues/2815) ### Changed diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index dad67fc1b58..951f782c300 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -644,6 +644,136 @@ def _create_from_usm_ndarray(usm_ary: dpt.usm_ndarray): res._array_obj._set_namespace(dpnp) return res + def _create_view(self, array_class, shape, dtype, strides): + """ + Create a view of an array with the specified class. + + The method handles subclass instantiation by creating a usm_ndarray + view and then wrapping it in the appropriate class. + + Parameters + ---------- + array_class : type + The class to instantiate (dpnp_array or a subclass). + shape : tuple + Shape of the view. + dtype : dtype + Data type of the view (can be None to keep source's dtype). + strides : tuple + Strides of the view. + + Returns + ------- + view : array_class instance + A view of the array as the specified class. + + """ + + if dtype is None: + dtype = self.dtype + + # create the underlying usm_ndarray view + usm_view = dpt.usm_ndarray( + shape, + dtype=dtype, + buffer=self._array_obj, + strides=tuple(s // dpnp.dtype(dtype).itemsize for s in strides), + ) + + # wrap the view into the appropriate class + if array_class is dpnp_array: + res = dpnp_array._create_from_usm_ndarray(usm_view) + else: + # for subclasses, create using __new__ and set up manually + res = array_class.__new__(array_class) + res._array_obj = usm_view + res._array_obj._set_namespace(dpnp) + + if hasattr(res, "__array_finalize__"): + res.__array_finalize__(self) + + return res + + def _view_impl(self, dtype=None, array_class=None): + """ + Internal implementation of view method to avoid an issue where + `type` parameter in ndarray.view method shadowing builtin type. + + """ + + # check if dtype is actually a type + if dtype is not None: + if isinstance(dtype, type) and issubclass(dtype, dpnp_array): + if array_class is not None: + raise ValueError("Cannot specify output type twice") + array_class = dtype + dtype = None + + # validate array_class parameter + if not ( + array_class is None + or isinstance(array_class, type) + and issubclass(array_class, dpnp_array) + ): + raise ValueError("Type must be a sub-type of ndarray type") + + if array_class is None: + # it's a view on dpnp.ndarray + array_class = self.__class__ + + old_sh = self.shape + old_strides = self.strides + + if dtype is None: + return self._create_view(array_class, old_sh, None, old_strides) + + new_dt = dpnp.dtype(dtype) + new_dt = dtu._to_device_supported_dtype(new_dt, self.sycl_device) + + new_itemsz = new_dt.itemsize + old_itemsz = self.dtype.itemsize + if new_itemsz == old_itemsz: + return self._create_view(array_class, old_sh, new_dt, old_strides) + + ndim = self.ndim + if ndim == 0: + raise ValueError( + "Changing the dtype of a 0d array is only supported " + "if the itemsize is unchanged" + ) + + # resize on last axis only + axis = ndim - 1 + if ( + old_sh[axis] != 1 + and self.size != 0 + and old_strides[axis] != old_itemsz + ): + raise ValueError( + "To change to a dtype of a different size, " + "the last axis must be contiguous" + ) + + # normalize strides whenever itemsize changes + new_strides = tuple( + old_strides[i] if i != axis else new_itemsz for i in range(ndim) + ) + + new_dim = old_sh[axis] * old_itemsz + if new_dim % new_itemsz != 0: + raise ValueError( + "When changing to a larger dtype, its size must be a divisor " + "of the total size in bytes of the last axis of the array" + ) + + # normalize shape whenever itemsize changes + new_sh = tuple( + old_sh[i] if i != axis else new_dim // new_itemsz + for i in range(ndim) + ) + + return self._create_view(array_class, new_sh, new_dt, new_strides) + def all(self, axis=None, *, out=None, keepdims=False, where=True): """ Return ``True`` if all elements evaluate to ``True``. @@ -2322,10 +2452,18 @@ def view(self, /, dtype=None, *, type=None): Parameters ---------- - dtype : {None, str, dtype object}, optional + dtype : {None, str, dtype object, type}, optional The desired data type of the returned view, e.g. :obj:`dpnp.float32` - or :obj:`dpnp.int16`. By default, it results in the view having the - same data type. + or :obj:`dpnp.int16`. Omitting it results in the view having the + same data type. Can also be a subclass of :class:`dpnp.ndarray` to + create a view of that type (this is equivalent to setting the `type` + parameter). + + Default: ``None``. + type : {None, type}, optional + Type of the returned view, e.g. a subclass of :class:`dpnp.ndarray`. + If specified, the returned array will be an instance of `type`. + Omitting it results in type preservation. Default: ``None``. @@ -2340,11 +2478,6 @@ def view(self, /, dtype=None, *, type=None): Only the last axis has to be contiguous. - Limitations - ----------- - Parameter `type` is supported only with default value ``None``. - Otherwise, the function raises ``NotImplementedError`` exception. - Examples -------- >>> import dpnp as np @@ -2368,73 +2501,17 @@ def view(self, /, dtype=None, *, type=None): [[2312, 2826], [5396, 5910]]], dtype=int16) - """ - - if type is not None: - raise NotImplementedError( - "Keyword argument `type` is supported only with " - f"default value ``None``, but got {type}." - ) - - old_sh = self.shape - old_strides = self.strides - - if dtype is None: - return dpnp_array(old_sh, buffer=self, strides=old_strides) - - new_dt = dpnp.dtype(dtype) - new_dt = dtu._to_device_supported_dtype(new_dt, self.sycl_device) - - new_itemsz = new_dt.itemsize - old_itemsz = self.dtype.itemsize - if new_itemsz == old_itemsz: - return dpnp_array( - old_sh, dtype=new_dt, buffer=self, strides=old_strides - ) - - ndim = self.ndim - if ndim == 0: - raise ValueError( - "Changing the dtype of a 0d array is only supported " - "if the itemsize is unchanged" - ) - - # resize on last axis only - axis = ndim - 1 - if ( - old_sh[axis] != 1 - and self.size != 0 - and old_strides[axis] != old_itemsz - ): - raise ValueError( - "To change to a dtype of a different size, " - "the last axis must be contiguous" - ) + Creating a view with a custom ndarray subclass: - # normalize strides whenever itemsize changes - new_strides = tuple( - old_strides[i] if i != axis else new_itemsz for i in range(ndim) - ) - - new_dim = old_sh[axis] * old_itemsz - if new_dim % new_itemsz != 0: - raise ValueError( - "When changing to a larger dtype, its size must be a divisor " - "of the total size in bytes of the last axis of the array" - ) - - # normalize shape whenever itemsize changes - new_sh = tuple( - old_sh[i] if i != axis else new_dim // new_itemsz - for i in range(ndim) - ) + >>> class MyArray(np.ndarray): + ... pass + >>> x = np.array([1, 2, 3]) + >>> y = x.view(MyArray) + >>> type(y) + - return dpnp_array( - new_sh, - dtype=new_dt, - buffer=self, - strides=new_strides, - ) + """ + return self._view_impl(dtype=dtype, array_class=type) @property def usm_type(self): diff --git a/dpnp/tests/test_ndarray.py b/dpnp/tests/test_ndarray.py index 4e4e42bbc85..6ce8645a11d 100644 --- a/dpnp/tests/test_ndarray.py +++ b/dpnp/tests/test_ndarray.py @@ -228,10 +228,87 @@ def test_python_types(self, dt): expected = a.view(dt) assert_allclose(result, expected) - def test_type_error(self): - x = dpnp.ones(4, dtype="i4") - with pytest.raises(NotImplementedError): - x.view("i2", type=dpnp.ndarray) + def test_subclass_basic(self): + class MyArray(dpnp.ndarray): + pass + + x = dpnp.array([1, 2, 3]) + view = x.view(type=MyArray) + + assert isinstance(view, MyArray) + assert type(view) is MyArray + assert (view == x).all() + + def test_dtype_type_subclass(self): + class MyArray(dpnp.ndarray): + pass + + x = dpnp.array([1, 2, 3]) + + # All three syntaxes should work identically + view1 = x.view(type=MyArray) + view2 = x.view(MyArray) + view3 = x.view(dtype=MyArray) + + assert type(view1) is MyArray + assert type(view2) is MyArray + assert type(view3) is MyArray + + def test_subclass_array_finalize(self): + class ArrayWithInfo(dpnp.ndarray): + def __array_finalize__(self, obj): + self.info = getattr(obj, "info", "default") + + x = dpnp.array([1, 2, 3]).view(type=ArrayWithInfo) + x.info = "metadata" + + # Create a view - __array_finalize__ should be called + view = x.view() + assert hasattr(view, "info") + assert view.info == "metadata" + assert type(view) is ArrayWithInfo + + def test_subclass_self_class_preservation(self): + class MyArray(dpnp.ndarray): + pass + + x = dpnp.array([1, 2, 3]).view(type=MyArray) + + # View without type parameter should preserve MyArray + view = x.view() + assert type(view) is MyArray + + def test_subclass_with_dtype_change(self): + class MyArray(dpnp.ndarray): + pass + + x = dpnp.array([1.0, 2.0], dtype=dpnp.float32) + view = x.view(dtype=dpnp.int32, type=MyArray) + + assert type(view) is MyArray + assert view.dtype == dpnp.int32 + + @pytest.mark.parametrize("xp", [dpnp, numpy]) + def test_subclass_invalid_type(self, xp): + x = xp.array([1, 2, 3]) + with pytest.raises( + ValueError, match="Type must be a sub-type of ndarray type" + ): + x.view(type=list) + + @pytest.mark.parametrize("xp", [dpnp, numpy]) + def test_subclass_double_type_specification(self, xp): + class MyArray(xp.ndarray): + pass + + class OtherArray(xp.ndarray): + pass + + x = xp.array([1, 2, 3]) + with pytest.raises( + ValueError, match="Cannot specify output type twice" + ): + x.view(dtype=MyArray, type=OtherArray) @pytest.mark.parametrize( diff --git a/dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py b/dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py index 7b503f1997a..5df4322ba0b 100644 --- a/dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py +++ b/dpnp/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py @@ -466,7 +466,6 @@ def __array_finalize__(self, obj): self.info = getattr(obj, "info", None) -@pytest.mark.skip("subclass array is not supported") class TestSubclassArrayView: def test_view_casting(self):