Skip to content

Commit aaaaa8d

Browse files
authored
Merge pull request #430 from ev-br/rm_complex128
MAINT: remove explicit xp.complex128 imports
2 parents cf7bc9f + 7b7713f commit aaaaa8d

File tree

7 files changed

+80
-85
lines changed

7 files changed

+80
-85
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,10 +281,23 @@ def __contains__(self, other):
281281
)
282282

283283

284+
# complex128 if available else complex64
285+
widest_complex_dtype = max(
286+
[(dt, dtype_nbits[dt]) for dt in complex_dtypes], key=lambda x: x[1]
287+
)[0]
288+
289+
290+
# float64 if available else float32
291+
widest_real_dtype = max(
292+
[(dt, dtype_nbits[dt]) for dt in real_float_dtypes], key=lambda x: x[1]
293+
)[0]
294+
295+
284296
dtype_components = _make_dtype_mapping_from_names(
285297
{"complex64": xp.float32, "complex128": xp.float64}
286298
)
287299

300+
288301
def as_real_dtype(dtype):
289302
"""
290303
Return the corresponding real dtype for a given floating-point dtype.

array_api_tests/hypothesis_helpers.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from . import xps
2121
from ._array_module import _UndefinedStub
2222
from ._array_module import bool as bool_dtype
23-
from ._array_module import broadcast_to, eye, float32, float64, full, complex64, complex128
23+
from ._array_module import broadcast_to, eye, full
2424
from .stubs import category_to_funcs
2525
from .pytest_helpers import nargs
2626
from .typing import Array, DataType, Scalar, Shape
@@ -465,26 +465,21 @@ def scalars(draw, dtypes, finite=False, **kwds):
465465
m, M = dh.dtype_ranges[dtype]
466466
min_value = kwds.get('min_value', m)
467467
max_value = kwds.get('max_value', M)
468-
469468
return draw(integers(min_value, max_value))
469+
470470
elif dtype == bool_dtype:
471471
return draw(booleans())
472-
elif dtype == float64:
473-
if finite:
474-
return draw(floats(allow_nan=False, allow_infinity=False, **kwds))
475-
return draw(floats(), **kwds)
476-
elif dtype == float32:
477-
if finite:
478-
return draw(floats(width=32, allow_nan=False, allow_infinity=False, **kwds))
479-
return draw(floats(width=32, **kwds))
480-
elif dtype == complex64:
481-
if finite:
482-
return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False))
483-
return draw(complex_numbers(width=32))
484-
elif dtype == complex128:
485-
if finite:
486-
return draw(complex_numbers(allow_nan=False, allow_infinity=False))
487-
return draw(complex_numbers())
472+
473+
elif dtype in dh.real_float_dtypes:
474+
f_kwds = dict(allow_nan=False, allow_infinity=False) if finite else dict()
475+
width = dh.dtype_nbits[dtype] # 32 or 64
476+
return draw(floats(width=width, **f_kwds, **kwds))
477+
478+
elif dtype in dh.complex_dtypes:
479+
f_kwds = dict(allow_nan=False, allow_infinity=False) if finite else dict()
480+
width = dh.dtype_nbits[dtype] # 64 or 128
481+
return draw(complex_numbers(width=width, **f_kwds, **kwds))
482+
488483
else:
489484
raise ValueError(f"Unrecognized dtype {dtype}")
490485

array_api_tests/pytest_helpers.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,8 @@ def assert_dtype(
161161
def assert_float_to_complex_dtype(
162162
func_name: str, *, in_dtype: DataType, out_dtype: DataType
163163
):
164-
if in_dtype == xp.float32:
165-
expected = xp.complex64
166-
else:
167-
assert in_dtype == xp.float64 # sanity check
168-
expected = xp.complex128
164+
assert in_dtype in dh.real_float_dtypes # sanity check
165+
expected = dh.complex_dtype_for(in_dtype)
169166
assert_dtype(
170167
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected
171168
)
@@ -174,13 +171,8 @@ def assert_float_to_complex_dtype(
174171
def assert_complex_to_float_dtype(
175172
func_name: str, *, in_dtype: DataType, out_dtype: DataType, repr_name: str = "out.dtype"
176173
):
177-
if in_dtype == xp.complex64:
178-
expected = xp.float32
179-
elif in_dtype == xp.complex128:
180-
expected = xp.float64
181-
else:
182-
assert in_dtype in (xp.float32, xp.float64) # sanity check
183-
expected = in_dtype
174+
assert in_dtype in dh.all_float_dtypes
175+
expected = dh.real_dtype_for(in_dtype)
184176
assert_dtype(
185177
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected, repr_name=repr_name
186178
)

array_api_tests/test_creation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def test_arange(dtype, data):
198198
), f"out[0]={out[0]}, but should be {_start} {f_func}"
199199
except Exception as exc:
200200
ph.add_note(exc, repro_snippet)
201-
raise
201+
raise
202202

203203

204204
@given(shape=hh.shapes(min_side=1), data=st.data())

array_api_tests/test_data_type_functions.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,7 @@ def test_finfo_dtype(dtype):
215215
try:
216216
out = xp.finfo(dtype)
217217

218-
if dtype == xp.complex64:
219-
assert out.dtype == xp.float32
220-
elif dtype == xp.complex128:
221-
assert out.dtype == xp.float64
222-
else:
223-
assert out.dtype == dtype
218+
assert out.dtype == dh.real_dtype_for(dtype)
224219

225220
# Guard vs. numpy.dtype.__eq__ lax comparison
226221
assert not isinstance(out.dtype, str)

array_api_tests/test_signatures.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,11 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str:
130130
{
131131
"stack": {"arrays": "[xp.ones((5,)), xp.ones((5,))]"},
132132
"iinfo": {"type": "xp.int64"},
133-
"finfo": {"type": "xp.float64"},
134-
"cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)"},
135-
"inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)"},
133+
"finfo": {"type": "xp.float32"},
134+
"cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float32)"},
135+
"inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float32)"},
136136
"solve": {
137-
a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1", "x2"]
137+
a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float32)" for a in ["x1", "x2"]
138138
},
139139
"outer": {"x1": "xp.ones((5,))", "x2": "xp.ones((5,))"},
140140
},

0 commit comments

Comments
 (0)