Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 55 additions & 5 deletions src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from __future__ import annotations

import math
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from types import ModuleType
from typing import Any, cast

Expand All @@ -30,13 +33,53 @@

__all__ = ["as_numpy_array", "xp_assert_close", "xp_assert_equal", "xp_assert_less"]

_default_xp_ctxvar: ContextVar[ModuleType] = ContextVar("_default_xp")


@contextmanager
def default_xp(xp: ModuleType) -> Generator[None, None, None]:
"""In all ``xp_assert_*`` function calls executed within this
context manager, test by default that the array namespace is
the provided across all arrays, unless one explicitly passes the ``xp=``
parameter.

Without this context manager, the default value for `xp` is the namespace
for the desired array (the second parameter of the tests).
"""
token = _default_xp_ctxvar.set(xp)
try:
yield
finally:
_default_xp_ctxvar.reset(token)


def _assert_matching_namespace(actual: Array, desired: Array, xp: ModuleType) -> None:
desired_arr_space = array_namespace(desired)
_msg = (
"Namespace of desired array does not match expectations "
"set by the `default_xp` context manager or by the `xp`"
"pytest fixture.\n"
f"Desired array's space: {desired_arr_space.__name__}\n"
f"Expected namespace: {xp.__name__}"
)
assert desired_arr_space == xp, _msg

actual_arr_space = array_namespace(actual)
_msg = (
"Namespace of actual and desired arrays do not match.\n"
f"Actual: {actual_arr_space.__name__}\n"
f"Desired: {xp.__name__}"
)
assert actual_arr_space == xp, _msg


def _check_ns_shape_dtype(
actual: Array,
desired: Array,
check_dtype: bool,
check_shape: bool,
check_scalar: bool,
xp: ModuleType | None = None,
) -> tuple[Array, Array, ModuleType]: # numpydoc ignore=RT03
"""
Assert that namespace, shape and dtype of the two arrays match.
Expand All @@ -60,8 +103,12 @@ def _check_ns_shape_dtype(
actual_xp = array_namespace(actual) # Raises on Python scalars and lists
desired_xp = array_namespace(desired)

msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
assert actual_xp == desired_xp, msg
if xp is None:
try:
xp = _default_xp_ctxvar.get()
except LookupError:
xp = array_namespace(desired)
Comment on lines +106 to +110
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove the default_xp context manager code from scipy and keep it here (and import from here), but we would still do this before converting the inputs to asarrays in scipy. On another thought, we can wrap this around a function and expose it.

_assert_matching_namespace(actual, desired, xp)

if is_numpy_namespace(actual_xp) and check_scalar:
# only NumPy distinguishes between scalars and arrays; we do if check_scalar.
Expand Down Expand Up @@ -148,6 +195,7 @@ def xp_assert_equal(
check_dtype: bool = True,
check_shape: bool = True,
check_scalar: bool = False,
xp: ModuleType | None = None,
) -> None:
"""
Array-API compatible version of `np.testing.assert_array_equal`.
Expand All @@ -174,7 +222,7 @@ def xp_assert_equal(
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
"""
actual, desired, xp = _check_ns_shape_dtype(
actual, desired, check_dtype, check_shape, check_scalar
actual, desired, check_dtype, check_shape, check_scalar, xp
)
if not _is_materializable(actual):
return
Expand All @@ -194,6 +242,7 @@ def xp_assert_less(
check_dtype: bool = True,
check_shape: bool = True,
check_scalar: bool = False,
xp: ModuleType | None = None,
) -> None:
"""
Array-API compatible version of `np.testing.assert_array_less`.
Expand All @@ -217,7 +266,7 @@ def xp_assert_less(
xp_assert_close : Similar function for inexact equality checks.
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
"""
x, y, xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
x, y, xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar, xp)
if not _is_materializable(x):
return
x_np = as_numpy_array(x, xp=xp)
Expand All @@ -237,6 +286,7 @@ def xp_assert_close(
check_dtype: bool = True,
check_shape: bool = True,
check_scalar: bool = False,
xp: ModuleType | None = None,
) -> None:
"""
Array-API compatible version of `np.testing.assert_allclose`.
Expand Down Expand Up @@ -276,7 +326,7 @@ def xp_assert_close(
Array arguments to `atol` and `rtol` must be valid input to :py:func:`float`.
"""
actual, desired, xp = _check_ns_shape_dtype(
actual, desired, check_dtype, check_shape, check_scalar
actual, desired, check_dtype, check_shape, check_scalar, xp
)
if not _is_materializable(actual):
return
Expand Down
13 changes: 12 additions & 1 deletion tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from array_api_extra._lib._backends import Backend
from array_api_extra._lib._testing import (
as_numpy_array,
default_xp,
xp_assert_close,
xp_assert_equal,
xp_assert_less,
Expand Down Expand Up @@ -63,12 +64,22 @@ def test_shape_dtype(self, xp: ModuleType, func: Callable[..., None]):
)
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
def test_namespace(self, xp: ModuleType, func: Callable[..., None]):
with pytest.raises(AssertionError, match="namespaces do not match"):
with pytest.raises(
AssertionError, match="Namespace of actual and desired arrays do not match"
):
func(xp.asarray(0), np.asarray(0))
with pytest.raises(TypeError, match=r"array_namespace requires .* array input"):
func(xp.asarray(0), 0)
with pytest.raises(TypeError, match="list is not a supported array type"):
func(xp.asarray([0]), [0])
with (
default_xp(np),
pytest.raises(
AssertionError,
match="Namespace of desired array does not match expectations",
),
):
func(xp.asarray(0), xp.asarray(0))

@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):
Expand Down