From 1fae65b7273123b0b8681b7a37a8d6959b93c75f Mon Sep 17 00:00:00 2001 From: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Date: Mon, 18 May 2026 22:40:18 +0530 Subject: [PATCH] Adding default_xp context manager for xp_assert functions Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> --- src/array_api_extra/_lib/_testing.py | 60 +++++++++++++++++++++++++--- tests/test_testing.py | 13 +++++- 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 9447d5c3..b071d968 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -8,6 +8,9 @@ from __future__ import annotations import math +from collections.abc import Generator +from contextlib import contextmanager +from contextvars import ContextVar from types import ModuleType from typing import Any, cast @@ -30,6 +33,45 @@ __all__ = ["as_numpy_array", "xp_assert_close", "xp_assert_equal", "xp_assert_less"] +_default_xp_ctxvar: ContextVar[ModuleType] = ContextVar("_default_xp") + + +@contextmanager +def default_xp(xp: ModuleType) -> Generator[None, None, None]: + """In all ``xp_assert_*`` function calls executed within this + context manager, test by default that the array namespace is + the provided across all arrays, unless one explicitly passes the ``xp=`` + parameter. + + Without this context manager, the default value for `xp` is the namespace + for the desired array (the second parameter of the tests). + """ + token = _default_xp_ctxvar.set(xp) + try: + yield + finally: + _default_xp_ctxvar.reset(token) + + +def _assert_matching_namespace(actual: Array, desired: Array, xp: ModuleType) -> None: + desired_arr_space = array_namespace(desired) + _msg = ( + "Namespace of desired array does not match expectations " + "set by the `default_xp` context manager or by the `xp`" + "pytest fixture.\n" + f"Desired array's space: {desired_arr_space.__name__}\n" + f"Expected namespace: {xp.__name__}" + ) + assert desired_arr_space == xp, _msg + + actual_arr_space = array_namespace(actual) + _msg = ( + "Namespace of actual and desired arrays do not match.\n" + f"Actual: {actual_arr_space.__name__}\n" + f"Desired: {xp.__name__}" + ) + assert actual_arr_space == xp, _msg + def _check_ns_shape_dtype( actual: Array, @@ -37,6 +79,7 @@ def _check_ns_shape_dtype( check_dtype: bool, check_shape: bool, check_scalar: bool, + xp: ModuleType | None = None, ) -> tuple[Array, Array, ModuleType]: # numpydoc ignore=RT03 """ Assert that namespace, shape and dtype of the two arrays match. @@ -60,8 +103,12 @@ def _check_ns_shape_dtype( actual_xp = array_namespace(actual) # Raises on Python scalars and lists desired_xp = array_namespace(desired) - msg = f"namespaces do not match: {actual_xp} != f{desired_xp}" - assert actual_xp == desired_xp, msg + if xp is None: + try: + xp = _default_xp_ctxvar.get() + except LookupError: + xp = array_namespace(desired) + _assert_matching_namespace(actual, desired, xp) if is_numpy_namespace(actual_xp) and check_scalar: # only NumPy distinguishes between scalars and arrays; we do if check_scalar. @@ -148,6 +195,7 @@ def xp_assert_equal( check_dtype: bool = True, check_shape: bool = True, check_scalar: bool = False, + xp: ModuleType | None = None, ) -> None: """ Array-API compatible version of `np.testing.assert_array_equal`. @@ -174,7 +222,7 @@ def xp_assert_equal( numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ actual, desired, xp = _check_ns_shape_dtype( - actual, desired, check_dtype, check_shape, check_scalar + actual, desired, check_dtype, check_shape, check_scalar, xp ) if not _is_materializable(actual): return @@ -194,6 +242,7 @@ def xp_assert_less( check_dtype: bool = True, check_shape: bool = True, check_scalar: bool = False, + xp: ModuleType | None = None, ) -> None: """ Array-API compatible version of `np.testing.assert_array_less`. @@ -217,7 +266,7 @@ def xp_assert_less( xp_assert_close : Similar function for inexact equality checks. numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ - x, y, xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar) + x, y, xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar, xp) if not _is_materializable(x): return x_np = as_numpy_array(x, xp=xp) @@ -237,6 +286,7 @@ def xp_assert_close( check_dtype: bool = True, check_shape: bool = True, check_scalar: bool = False, + xp: ModuleType | None = None, ) -> None: """ Array-API compatible version of `np.testing.assert_allclose`. @@ -276,7 +326,7 @@ def xp_assert_close( Array arguments to `atol` and `rtol` must be valid input to :py:func:`float`. """ actual, desired, xp = _check_ns_shape_dtype( - actual, desired, check_dtype, check_shape, check_scalar + actual, desired, check_dtype, check_shape, check_scalar, xp ) if not _is_materializable(actual): return diff --git a/tests/test_testing.py b/tests/test_testing.py index b70de734..aac2284f 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -9,6 +9,7 @@ from array_api_extra._lib._backends import Backend from array_api_extra._lib._testing import ( as_numpy_array, + default_xp, xp_assert_close, xp_assert_equal, xp_assert_less, @@ -63,12 +64,22 @@ def test_shape_dtype(self, xp: ModuleType, func: Callable[..., None]): ) @pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less]) def test_namespace(self, xp: ModuleType, func: Callable[..., None]): - with pytest.raises(AssertionError, match="namespaces do not match"): + with pytest.raises( + AssertionError, match="Namespace of actual and desired arrays do not match" + ): func(xp.asarray(0), np.asarray(0)) with pytest.raises(TypeError, match=r"array_namespace requires .* array input"): func(xp.asarray(0), 0) with pytest.raises(TypeError, match="list is not a supported array type"): func(xp.asarray([0]), [0]) + with ( + default_xp(np), + pytest.raises( + AssertionError, + match="Namespace of desired array does not match expectations", + ), + ): + func(xp.asarray(0), xp.asarray(0)) @pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less]) def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):