diff --git a/divide_and_conquer/strassen_matrix_multiplication.py b/divide_and_conquer/strassen_matrix_multiplication.py index f529a255d2ef..525fe8e3c649 100644 --- a/divide_and_conquer/strassen_matrix_multiplication.py +++ b/divide_and_conquer/strassen_matrix_multiplication.py @@ -5,7 +5,15 @@ def default_matrix_multiplication(a: list, b: list) -> list: """ - Multiplication only for 2x2 matrices + Standard multiplication for 2x2 matrices (base case). + + Used as the base case for Strassen's algorithm when the matrix + cannot be subdivided further. Uses 8 multiplications. + + Time complexity: O(1) — fixed size input. + + >>> default_matrix_multiplication([[1, 2], [3, 4]], [[5, 6], [7, 8]]) + [[19, 22], [43, 50]] """ if len(a) != 2 or len(a[0]) != 2 or len(b) != 2 or len(b[0]) != 2: raise Exception("Matrices are not 2x2") @@ -17,6 +25,7 @@ def default_matrix_multiplication(a: list, b: list) -> list: def matrix_addition(matrix_a: list, matrix_b: list): + """Element-wise addition of two matrices of equal dimensions.""" return [ [matrix_a[row][col] + matrix_b[row][col] for col in range(len(matrix_a[row]))] for row in range(len(matrix_a)) @@ -24,6 +33,7 @@ def matrix_addition(matrix_a: list, matrix_b: list): def matrix_subtraction(matrix_a: list, matrix_b: list): + """Element-wise subtraction of two matrices of equal dimensions.""" return [ [matrix_a[row][col] - matrix_b[row][col] for col in range(len(matrix_a[row]))] for row in range(len(matrix_a)) @@ -64,6 +74,7 @@ def split_matrix(a: list) -> tuple[list, list, list, list]: def matrix_dimensions(matrix: list) -> tuple[int, int]: + """Return (rows, columns) of a matrix.""" return len(matrix), len(matrix[0]) @@ -73,8 +84,22 @@ def print_matrix(matrix: list) -> None: def actual_strassen(matrix_a: list, matrix_b: list) -> list: """ - Recursive function to calculate the product of two matrices, using the Strassen - Algorithm. It only supports square matrices of any size that is a power of 2. + Recursive function to calculate the product of two matrices using Strassen's + algorithm. Only supports square matrices whose dimensions are a power of 2. + + Strassen's algorithm reduces matrix multiplication from 8 recursive + multiplications (naive divide-and-conquer) to 7, at the cost of more + additions and subtractions. This gives a better asymptotic complexity: + + - Naive matrix multiplication: O(n^3) + - Naive divide-and-conquer: O(n^3) — 8 multiplications of n/2 size + - Strassen's algorithm: O(n^2.807) — 7 multiplications of n/2 size + + The 7 intermediate products (t1-t7) are combined to form the four + quadrants of the result matrix using only additions and subtractions. + + Reference: Strassen, V. (1969). Gaussian elimination is not optimal. + Numerische Mathematik, 13(4), 354-356. """ if matrix_dimensions(matrix_a) == (2, 2): return default_matrix_multiplication(matrix_a, matrix_b) @@ -106,6 +131,26 @@ def actual_strassen(matrix_a: list, matrix_b: list) -> list: def strassen(matrix1: list, matrix2: list) -> list: """ + Multiply two matrices of arbitrary dimensions using Strassen's algorithm. + + Handles non-square and non-power-of-2 matrices by padding with zeros + to the next power of 2, running Strassen's algorithm, then removing + the padding from the result. + + Time complexity: O(n^2.807) where n is the padded dimension. + Space complexity: O(n^2) for the padded matrices. + + Args: + matrix1: First matrix (m x n). + matrix2: Second matrix (n x p). Number of columns in matrix1 + must equal number of rows in matrix2. + + Returns: + Result matrix (m x p). + + Raises: + Exception: If matrix dimensions are incompatible for multiplication. + >>> strassen([[2,1,3],[3,4,6],[1,4,2],[7,6,7]], [[4,2,3,4],[2,1,1,1],[8,6,4,2]]) [[34, 23, 19, 15], [68, 46, 37, 28], [28, 18, 15, 12], [96, 62, 55, 48]] >>> strassen([[3,7,5,6,9],[1,5,3,7,8],[1,4,4,5,7]], [[2,4],[5,2],[1,7],[5,5],[7,8]])