From 09e247bccd7bb495f2efe9c6beda872244b48925 Mon Sep 17 00:00:00 2001 From: dvezinet Date: Mon, 10 Mar 2025 22:45:33 +0000 Subject: [PATCH 1/4] [#202] First operational version of _check_all_broadcastable() --- datastock/_generic_check.py | 89 ++++++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/datastock/_generic_check.py b/datastock/_generic_check.py index c864655..30d72ff 100644 --- a/datastock/_generic_check.py +++ b/datastock/_generic_check.py @@ -550,6 +550,93 @@ def _obj_key(d0=None, short=None, key=None, ndigits=None): ) +# ############################################################################# +# ############################################################################# +# Utilities for plotting +# ############################################################################# + + +def _check_all_broadcastable(**kwdargs): + + # ------------------- + # Preliminary check + # ------------------- + + dout = {} + dfail = {} + for k0, v0 in kwdargs.items(): + try: + dout[k0] = np.atleast_1d(v0) + except Exception as err: + dfail[k0] = f"Not convertible to np.ndarray! - {v0}" + + # Raise Exception + if len(dfail) > 0: + lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()] + msg = ( + "The following kwdargs are non-conform:\n" + + "\n".join(lstr) + ) + raise Exception(msg) + + # ------------------- + # check ndim + # ------------------- + + dndim = {k0: v0.ndim for k0, v0 in dout.items() if v0.shape != (1,)} + lndim = list(set(dndim.values())) + + if len(lndim) == 0: + # all scalar + return {k0: v0[0] for k0, v0 in dout.items()}, None + + elif len(lndim) == 1: + ndim = lndim[0] + + else: + lstr = [f"-t {k0}: {v0}" for k0, v0 in dndim.items()] + msg = ( + "Some keyword args have non-compatible dimensions:\n" + + "\n".join(lstr) + ) + raise Exception(msg) + + # ------------------- + # check shapes + # ------------------- + + dfail = {} + shapef = np.ones((ndim,), dtype=int) + for k0, v0 in dout.items(): + + if v0.shape == (1,): + dout[k0] = v0[0] + continue + + for ii in range(ndim): + if v0.shape[ii] == 1: + pass + elif shapef[ii] == 1: + shapef[ii] = v0.shape[ii] + elif v0.shape[ii] == shapef[ii]: + pass + else: + dfail[k0] = f"Non-compatible shape = {v0.shape} (ii = {ii})" + continue + + # raise Exception if needed + if len(dfail) > 1: + lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()] + msg = ( + "The following keywords args have non-compatible shape:\n" + + "\n".join(lstr) + + f"Reference shape: {shapef}\n" + ) + raise Exception(msg) + + return dout, shapef + + # ############################################################################# # ############################################################################# # Utilities for plotting @@ -929,4 +1016,4 @@ def _check_cmap_vminvmax(data=None, cmap=None, vmin=None, vmax=None): else: vmax = nanmax - return cmap, vmin, vmax \ No newline at end of file + return cmap, vmin, vmax From b7ccef080653ef675928ea9260e5a3dadcfbd438 Mon Sep 17 00:00:00 2001 From: dvezinet Date: Tue, 11 Mar 2025 12:00:09 +0000 Subject: [PATCH 2/4] [#202] Added unit test --- datastock/tests/test_01_DataStock.py | 34 +++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/datastock/tests/test_01_DataStock.py b/datastock/tests/test_01_DataStock.py index 59da29a..56dde3b 100644 --- a/datastock/tests/test_01_DataStock.py +++ b/datastock/tests/test_01_DataStock.py @@ -14,6 +14,7 @@ import matplotlib.pyplot as plt # datastock-specific +from .._generic_check import _check_all_broadcastable from .._class import DataStock from .._saveload import load @@ -228,6 +229,37 @@ def test02_add_data(self): def test03_add_obj(self): _add_obj(st=self.st, nc=self.nc) + # ------------------------ + # Tools + # ------------------------ + + def test04_check_all_broadcastable(self): + # all scalar + dout, shape = _check_all_broadcastable(a=1, b=2) + + # scalar + arrays + dout, shape = _check_all_broadcastable(a=1, b=(1, 2, 3)) + + # all arrays + dout, shape = _check_all_broadcastable( + a=(1, 2, 3), + b=(1, 2, 3), + ) + + # all arrays - 2d + dout, shape = _check_all_broadcastable( + a=np.r_[1, 2, 3][:, None], + b=np.r_[10, 20][None, :], + ) + + # check flag + err = False + try: + dout, shape = _check_all_broadcastable(a=(1, 2), b=(1, 2, 3)) + except Exception as err: + err = True + assert err is True + ####################################################### # @@ -639,4 +671,4 @@ def test26_saveload_coll(self, verb=False): msg = st2.__eq__(self.st, returnas=str) if msg is not True: raise Exception(msg) - os.remove(pfe) \ No newline at end of file + os.remove(pfe) From 53fdc64e242065a8aa428ca3dbddd1709613b07d Mon Sep 17 00:00:00 2001 From: dvezinet Date: Tue, 11 Mar 2025 12:07:30 +0000 Subject: [PATCH 3/4] [#202] Unit test debug 1 --- datastock/_generic_check.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/datastock/_generic_check.py b/datastock/_generic_check.py index 30d72ff..ee31ee6 100644 --- a/datastock/_generic_check.py +++ b/datastock/_generic_check.py @@ -624,13 +624,15 @@ def _check_all_broadcastable(**kwdargs): dfail[k0] = f"Non-compatible shape = {v0.shape} (ii = {ii})" continue + shapef = tuple(shapef) + # raise Exception if needed - if len(dfail) > 1: + if len(dfail) > 0: lstr = [f"\t- {k0}: {v0}" for k0, v0 in dfail.items()] msg = ( "The following keywords args have non-compatible shape:\n" + "\n".join(lstr) - + f"Reference shape: {shapef}\n" + + f"\nReference shape: {shapef}\n" ) raise Exception(msg) From 152ed4bd95d3eddb86588b25f785ace3dd48b2c3 Mon Sep 17 00:00:00 2001 From: dvezinet Date: Tue, 11 Mar 2025 12:17:44 +0000 Subject: [PATCH 4/4] [#202] Unit test debug 2 --- datastock/tests/test_01_DataStock.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datastock/tests/test_01_DataStock.py b/datastock/tests/test_01_DataStock.py index 56dde3b..3e39bd0 100644 --- a/datastock/tests/test_01_DataStock.py +++ b/datastock/tests/test_01_DataStock.py @@ -256,8 +256,9 @@ def test04_check_all_broadcastable(self): err = False try: dout, shape = _check_all_broadcastable(a=(1, 2), b=(1, 2, 3)) - except Exception as err: + except Exception: err = True + assert err is True