Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions divide_and_conquer/strassen_matrix_multiplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@

def default_matrix_multiplication(a: list, b: list) -> list:
"""
Multiplication only for 2x2 matrices
Standard multiplication for 2x2 matrices (base case).

Used as the base case for Strassen's algorithm when the matrix
cannot be subdivided further. Uses 8 multiplications.

Time complexity: O(1) — fixed size input.

>>> default_matrix_multiplication([[1, 2], [3, 4]], [[5, 6], [7, 8]])
[[19, 22], [43, 50]]
"""
if len(a) != 2 or len(a[0]) != 2 or len(b) != 2 or len(b[0]) != 2:
raise Exception("Matrices are not 2x2")
Expand All @@ -17,13 +25,15 @@ def default_matrix_multiplication(a: list, b: list) -> list:


def matrix_addition(matrix_a: list, matrix_b: list):
"""Element-wise addition of two matrices of equal dimensions."""
return [
[matrix_a[row][col] + matrix_b[row][col] for col in range(len(matrix_a[row]))]
for row in range(len(matrix_a))
]


def matrix_subtraction(matrix_a: list, matrix_b: list):
"""Element-wise subtraction of two matrices of equal dimensions."""
return [
[matrix_a[row][col] - matrix_b[row][col] for col in range(len(matrix_a[row]))]
for row in range(len(matrix_a))
Expand Down Expand Up @@ -64,6 +74,7 @@ def split_matrix(a: list) -> tuple[list, list, list, list]:


def matrix_dimensions(matrix: list) -> tuple[int, int]:
"""Return (rows, columns) of a matrix."""
return len(matrix), len(matrix[0])


Expand All @@ -73,8 +84,22 @@ def print_matrix(matrix: list) -> None:

def actual_strassen(matrix_a: list, matrix_b: list) -> list:
"""
Recursive function to calculate the product of two matrices, using the Strassen
Algorithm. It only supports square matrices of any size that is a power of 2.
Recursive function to calculate the product of two matrices using Strassen's
algorithm. Only supports square matrices whose dimensions are a power of 2.

Strassen's algorithm reduces matrix multiplication from 8 recursive
multiplications (naive divide-and-conquer) to 7, at the cost of more
additions and subtractions. This gives a better asymptotic complexity:

- Naive matrix multiplication: O(n^3)
- Naive divide-and-conquer: O(n^3) — 8 multiplications of n/2 size
- Strassen's algorithm: O(n^2.807) — 7 multiplications of n/2 size

The 7 intermediate products (t1-t7) are combined to form the four
quadrants of the result matrix using only additions and subtractions.

Reference: Strassen, V. (1969). Gaussian elimination is not optimal.
Numerische Mathematik, 13(4), 354-356.
"""
if matrix_dimensions(matrix_a) == (2, 2):
return default_matrix_multiplication(matrix_a, matrix_b)
Expand Down Expand Up @@ -106,6 +131,26 @@ def actual_strassen(matrix_a: list, matrix_b: list) -> list:

def strassen(matrix1: list, matrix2: list) -> list:
"""
Multiply two matrices of arbitrary dimensions using Strassen's algorithm.

Handles non-square and non-power-of-2 matrices by padding with zeros
to the next power of 2, running Strassen's algorithm, then removing
the padding from the result.

Time complexity: O(n^2.807) where n is the padded dimension.
Space complexity: O(n^2) for the padded matrices.

Args:
matrix1: First matrix (m x n).
matrix2: Second matrix (n x p). Number of columns in matrix1
must equal number of rows in matrix2.

Returns:
Result matrix (m x p).

Raises:
Exception: If matrix dimensions are incompatible for multiplication.

>>> strassen([[2,1,3],[3,4,6],[1,4,2],[7,6,7]], [[4,2,3,4],[2,1,1,1],[8,6,4,2]])
[[34, 23, 19, 15], [68, 46, 37, 28], [28, 18, 15, 12], [96, 62, 55, 48]]
>>> strassen([[3,7,5,6,9],[1,5,3,7,8],[1,4,4,5,7]], [[2,4],[5,2],[1,7],[5,5],[7,8]])
Expand Down
Loading