Skip to content

[BUG] Fixing matmul to support leading dimensions > 1#88

Merged
ngoldbaum merged 4 commits into
numpy:mainfrom
SwayamInSync:matmul-nd
May 21, 2026
Merged

[BUG] Fixing matmul to support leading dimensions > 1#88
ngoldbaum merged 4 commits into
numpy:mainfrom
SwayamInSync:matmul-nd

Conversation

@SwayamInSync
Copy link
Copy Markdown
Member

@SwayamInSync SwayamInSync commented May 12, 2026

closes #87
As per the title

Note: This also found a separate bug of GEMM dispatching not supporting fortran ordered arrays, the related tests are added here and marked as xfail will be discussed in a different issue

@SwayamInSync SwayamInSync changed the title fixing matmul N-D batch issue [BUG] Fixing matmul to support leading dimensions > 1 May 12, 2026
@SwayamInSync
Copy link
Copy Markdown
Member Author

Interesting, this might be the race condition issue from the NumPy side on the lazy attribute loader?
In this PR the fix might be simple to pre-import the rec

click to expand
 ==================================== ERRORS ====================================
  _____________________ ERROR at call of test_pandas_strrep ______________________
  
      def test_pandas_strrep():
          """Test that we can construct a pandas data frame with quad precision columns
      
          Make sure the string representation can be generated
          """
          import pandas as pd
      
          BIG_NUMBER=123456789098765432123456789
          x = np.arange(500, dtype=np.float64) * BIG_NUMBER
          y = np.arange(500, dtype=QuadPrecDType()) * BIG_NUMBER
          df = pd.DataFrame({"col1": x, "col2": y})
  >       assert isinstance(str(df), str) # Make sure this doesn't fail
                            ^^^^^^^
  
  /Users/runner/work/numpy-quaddtype/numpy-quaddtype/tests/test_quaddtype.py:6041: 
  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
  ../venv-test-arm64/lib/python3.14t/site-packages/pandas/core/frame.py:1201: in __repr__
      return self.to_string(**repr_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  ../venv-test-arm64/lib/python3.14t/site-packages/pandas/core/frame.py:1380: in to_string
      return fmt.DataFrameRenderer(formatter).to_string(
  ../venv-test-arm64/lib/python3.14t/site-packages/pandas/io/formats/format.py:973: in to_string
      string = string_formatter.to_string()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  ../venv-test-arm64/lib/python3.14t/site-packages/pandas/io/formats/string.py:30: in to_string
      text = self._get_string_representation()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  ../venv-test-arm64/lib/python3.14t/site-packages/pandas/io/formats/string.py:45: in _get_string_representation
      strcols = self._get_strcols()
                ^^^^^^^^^^^^^^^^^^^
  ../venv-test-arm64/lib/python3.14t/site-packages/pandas/io/formats/string.py:36: in _get_strcols
      strcols = self.fmt.get_strcols()
                ^^^^^^^^^^^^^^^^^^^^^^
  ../venv-test-arm64/lib/python3.14t/site-packages/pandas/io/formats/format.py:476: in get_strcols
      strcols = self._get_strcols_without_index()
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  ../venv-test-arm64/lib/python3.14t/site-packages/pandas/io/formats/format.py:729: in _get_strcols_without_index
      str_columns = self._get_formatted_column_labels(self.tr_frame)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  ../venv-test-arm64/lib/python3.14t/site-packages/pandas/io/formats/format.py:788: in _get_formatted_column_labels
      fmt_columns = columns._format_flat(include_name=False)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  ../venv-test-arm64/lib/python3.14t/site-packages/numpy/__init__.py:745: in __getattr__
      import numpy.rec as rec
  ../venv-test-arm64/lib/python3.14t/site-packages/numpy/__init__.py:745: in __getattr__
      import numpy.rec as rec
  ../venv-test-arm64/lib/python3.14t/site-packages/numpy/__init__.py:745: in __getattr__
      import numpy.rec as rec
  ../venv-test-arm64/lib/python3.14t/site-packages/numpy/__init__.py:745: in __getattr__
      import numpy.rec as rec
  ../venv-test-arm64/lib/python3.14t/site-packages/numpy/__init__.py:745: in __getattr__
      import numpy.rec as rec
  ../venv-test-arm64/lib/python3.14t/site-packages/numpy/__init__.py:745: in __getattr__
      import numpy.rec as rec
  ../venv-test-arm64/lib/python3.14t/site-packages/numpy/__init__.py:745: in __getattr__
      import numpy.rec as rec
  ../venv-test-arm64/lib/python3.14t/site-packages/numpy/__init__.py:745: in __getattr__
      import numpy.rec as rec
  ../venv-test-arm64/lib/python3.14t/site-packages/numpy/__init__.py:745: in __getattr__
      import numpy.rec as rec
  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
  
  attr = 'rec'
  
      def __getattr__(attr):
          # Warn for expired attributes
          import warnings
      
          if attr == "linalg":
              import numpy.linalg as linalg
              return linalg
          elif attr == "fft":
              import numpy.fft as fft
              return fft
          elif attr == "dtypes":
              import numpy.dtypes as dtypes
              return dtypes
          elif attr == "random":
              import numpy.random as random
              return random
          elif attr == "polynomial":
              import numpy.polynomial as polynomial
              return polynomial
          elif attr == "ma":
              import numpy.ma as ma
              return ma
          elif attr == "ctypeslib":
              import numpy.ctypeslib as ctypeslib
              return ctypeslib
          elif attr == "exceptions":
              import numpy.exceptions as exceptions
              return exceptions
          elif attr == "testing":
              import numpy.testing as testing
              return testing
          elif attr == "matlib":
              import numpy.matlib as matlib
              return matlib
          elif attr == "f2py":
              import numpy.f2py as f2py
              return f2py
          elif attr == "typing":
              import numpy.typing as typing
              return typing
          elif attr == "rec":
  >           import numpy.rec as rec
  E           RecursionError: maximum recursion depth exceeded
  
  ../venv-test-arm64/lib/python3.14t/site-packages/numpy/__init__.py:745: RecursionError
  !!! Recursion error detected, but an error occurred locating the origin of recursion.
    The following exception happened when comparing locals in the stack frame:
      ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
    Displaying first and last 10 stack frames out of 996.

@SwayamInSync
Copy link
Copy Markdown
Member Author

Oh that's from cpython, fix is straightforward, will raise the Issue and PR there

@SwayamInSync
Copy link
Copy Markdown
Member Author

As a fix here, adding conftest.py (to pre-import the public submodules) in a separate PR, will merge it then re-run the workflows here

Copy link
Copy Markdown
Member

@ngoldbaum ngoldbaum left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I asked Claude to review this for C++ correctness and it spotted some issues, see below. Tests and overall implementation look good.

Comment thread tests/test_dot.py Outdated
A_f = rng.standard_normal((batch, m, k))
B_f = rng.standard_normal((batch, k, n))
_assert_matmul_matches_float64(_qnd(A_f), _qnd(B_f), A_f, B_f,
rtol=1e-13, atol=1e-13) No newline at end of file
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No newline at end of file. Maybe add a lint or pre-commit hook for this? See e.g. for why this matters.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add a pre-commit workflow to ensure this in future

Comment thread src/csrc/umath/matmul.cpp
case MATMUL_GEMM:
temp_A_buffer = new Sleef_quad[m * n];
temp_B_buffer = new Sleef_quad[n * p];
temp_C_buffer = new Sleef_quad[m * p];
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude points out that using new like this isn't exception-safe. To keep the allocation as-is you can either use std::unique_ptr to ensure RAII cleanup if an exception happens or new (std::nothrow) to disable exceptions for the allocation.

That said, IMO in an extension it's probably better to use PyMem_RawMalloc because it'll integrate with the interpreter better and scale on multithreaded parallelism better on the free-threaded build where it will use CPython's mimalloc.

{
if (!alpha || !A || !x || !beta || !y || m == 0 || n == 0) {
if (m == 0 || n == 0) {
return 0;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be consistent with qblas_dot, shouldn't you write zero to y? Similarly qblas_gemm should probably do the same to C.

Comment thread src/csrc/umath/matmul.cpp
case MATMUL_DOT: {
size_t incx = A_col_stride / sizeof(Sleef_quad);
size_t incy = B_row_stride / sizeof(Sleef_quad);
switch (op_type) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commenting here but the same applies to all the switch statements in your PR: add a default case that e.g. aborts to ensure you don't accidentally add code later that relies on falling through for an invalid value.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, good catch!

Comment thread src/csrc/umath/matmul.cpp
case MATMUL_GEMM:
temp_A_buffer = new Sleef_quad[m * n];
temp_B_buffer = new Sleef_quad[n * p];
temp_C_buffer = new Sleef_quad[m * p];
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also here and below for all signed integer multiplication: you need to do overflow checking. Since signed integer overflow is UB and the compiler is free to optimize this code away or other badness if it detects a possible UB here.

@SwayamInSync
Copy link
Copy Markdown
Member Author

Thanks @ngoldbaum all the reviews (except the ones I commented on), were already planned for a different PR.
I can perform all of them here if you feel right, but that might go out of scope for this PR and issue.

Let me know what you feel right?

@SwayamInSync
Copy link
Copy Markdown
Member Author

The pre-commit workflow can also come in a different PR because I am guessing it might flag unrelated positions to fix

@ngoldbaum
Copy link
Copy Markdown
Member

Sure, let's do the cleanups in future PRs.

@ngoldbaum ngoldbaum merged commit 60b1222 into numpy:main May 21, 2026
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] matmul produces incorrect results for batched / N-D inputs

2 participants