Strassen's Algorithm for Matrix Multiplication¶
This algorithm follows the divide-and-conquer strategy. We divide the matrices into equally sized square sub-matrices and keep dividing those until their size is 2x2.
These are the matrices on which we will be performing matrix multiplication
In [1]:
# Sample operands used by the driver code and by the numpy cross-check.
# mat3 is 4x4 and mat4 is 4x3
mat3 = [
[5.0,2.0,3.0,1.0],
[7.0,6.0,4.0,2.0],
[5.0,6.0,7.0,5.0],
[5.0,7.0,9.0,7.0]
]
mat4 = [
[5.0,2.0,3.0],
[7.0,6.0,4.0],
[5.0,6.0,7.0],
[5.0,7.0,9.0]
]
# mat5 and mat6 are both 3x3 (odd size, so they need padding)
mat5 = [
[5.0,2.0,3.0],
[7.0,6.0,4.0],
[5.0,6.0,7.0],
]
mat6 = [
[5.0,2.0,3.0],
[7.0,6.0,5.0],
[4.0,6.0,3.0],
]
# mat7 is 5x5 and mat8 is 5x4
mat7 = [
[2.0,3.0,4.0,5.0,6.0],
[3.0,4.0,5.0,6.0,7.0],
[4.0,5.0,6.0,7.0,8.0],
[5.0,6.0,7.0,8.0,9.0],
[6.0,7.0,8.0,9.0,10.0],
]
mat8 = [
[3.0,5.0,7.0,9.0],
[5.0,8.0,11.0,14.0],
[7.0,11.0,15.0,19.0],
[9.0,14.0,19.0,24.0],
[11.0,17.0,23.0,29.0],
]
# mat9 is 5x7 (5 rows of 7 entries); it has more columns than rows
mat9 = [
[3.0,5.0,7.0,9.0,11.0,13.0,15.0],
[5.0,8.0,11.0,14.0,17.0,20.0,23.0],
[7.0,11.0,15.0,19.0,23.0,27.0,31.0],
[9.0,14.0,19.0,24.0,29.0,34.0,39.0],
[11.0,17.0,23.0,29.0,35.0,41.0,47.0],
]
Functions to display and add matrices¶
In [2]:
def disp(mat):
    """Print a matrix, one row per output line.

    Every entry is converted with int() before printing and is followed by
    a single space, so each printed row ends with a trailing space.
    """
    for row in mat:
        print("".join(f"{int(entry)} " for entry in row))
def addMatrices(a, b):
    """Return a new matrix equal to a with the entries of b added onto it.

    The rows of a are copied first, so neither argument is modified. Only
    the positions that exist in b are touched; b must not be larger than a.
    """
    total = [row[:] for row in a]
    for r, addend_row in enumerate(b):
        for c, addend in enumerate(addend_row):
            total[r][c] += addend
    return total
This function pads the matrices with zeros so that the strassen function can work on them¶
In [3]:
def tokenize(matA, matB):
    """Zero-pad matA and matB in place so they can be split into square blocks.

    After the call:
      * matA is square and its side length is even;
      * matB has at least as many rows as matA has columns (rows of zeros
        are appended when it had fewer);
      * the number of columns of matB is a multiple of its number of rows.

    Both outer lists are modified in place. Rows that need widening are
    replaced by new lists, so the row objects passed in are never mutated.
    Every padding row is a distinct list, so writing to one padded row can
    never show up in another. All padding values are the float 0.0.

    Args:
        matA: left operand, a non-empty list of equally long rows.
        matB: right operand, a non-empty list of equally long rows.

    Returns:
        None.
    """
    # --- matrix 1: make both dimensions even --------------------------------
    if len(matA[0]) & 1:
        for i, row in enumerate(matA):
            matA[i] = row + [0.0]
    if len(matA) & 1:
        matA.append([0.0] * len(matA[0]))

    # --- matrix 1: make it square -------------------------------------------
    rows, cols = len(matA), len(matA[0])
    if rows > cols:
        filler = [0.0] * (rows - cols)
        for i, row in enumerate(matA):
            matA[i] = row + filler  # concatenation builds a fresh row
    elif cols > rows:
        # one new list per appended row; never append the same object twice
        matA.extend([0.0] * cols for _ in range(cols - rows))

    # --- matrix 2: as many rows as matrix 1 has columns ---------------------
    missing_rows = len(matA[0]) - len(matB)
    if missing_rows > 0:
        width = len(matB[0])
        matB.extend([0.0] * width for _ in range(missing_rows))

    # --- matrix 2: column count must be a multiple of the row count, so the
    # matrix can be cut into square blocks of side len(matB) ------------------
    remainder = len(matB[0]) % len(matB)
    if remainder:
        filler = [0.0] * (len(matB) - remainder)
        for i, row in enumerate(matB):
            matB[i] = row + filler
This is the main function that is responsible for running the algorithm and does most of the heavy lifting¶
In [4]:
def strassen(mat1, mat2):
    """Multiply two matrices with Strassen's seven-product recursion.

    The operands are copied and zero-padded to a common square size, the
    square product is computed recursively, and the padding is cropped
    away again. The arguments themselves are never modified.

    Args:
        mat1: left operand, a non-empty list of equally long rows.
        mat2: right operand; it must have as many rows as mat1 has columns.

    Returns:
        A new list of len(mat1) rows with len(mat2[0]) entries each.

    Raises:
        ArithmeticError: if an operand is empty or the inner dimensions
            do not match.

    Note:
        Strassen's scheme uses subtractions, so for non-integer floats the
        result can differ from the schoolbook product by rounding error.
    """
    if (not mat1 or not mat2 or not mat1[0] or not mat2[0]
            or len(mat1[0]) != len(mat2)):
        raise ArithmeticError("Matrices cannot be multiplied")

    rows, inner, cols = len(mat1), len(mat2), len(mat2[0])
    size = max(rows, inner, cols)
    square = _strassen_square(
        _strassen_pad(mat1, size), _strassen_pad(mat2, size), size
    )
    # drop the rows/columns that only exist because of the padding
    return [row[:cols] for row in square[:rows]]


def _strassen_pad(mat, size):
    """Return a size x size copy of mat, zero-filled to the right and below."""
    padded = [list(row) + [0.0] * (size - len(row)) for row in mat]
    padded.extend([0.0] * size for _ in range(size - len(mat)))
    return padded


def _strassen_add(x, y):
    """Return the element-wise sum of two equally shaped matrices."""
    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def _strassen_sub(x, y):
    """Return the element-wise difference x - y of two equally shaped matrices."""
    return [[p - q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def _strassen_quadrants(mat, half):
    """Split a square matrix of side 2*half into its four half x half blocks."""
    top, bottom = mat[:half], mat[half:]
    return (
        [row[:half] for row in top],
        [row[half:] for row in top],
        [row[:half] for row in bottom],
        [row[half:] for row in bottom],
    )


def _strassen_square(a, b, size):
    """Return the product of two size x size matrices (Strassen recursion)."""
    if size == 1:
        return [[a[0][0] * b[0][0]]]
    if size == 2:
        # base case: the direct 2x2 formula is cheaper than recursing further
        return [
            [a[0][0] * b[0][0] + a[0][1] * b[1][0],
             a[0][0] * b[0][1] + a[0][1] * b[1][1]],
            [a[1][0] * b[0][0] + a[1][1] * b[1][0],
             a[1][0] * b[0][1] + a[1][1] * b[1][1]],
        ]
    if size & 1:
        # odd side: grow by one zero row/column, multiply, then crop back
        grown = size + 1
        product = _strassen_square(
            _strassen_pad(a, grown), _strassen_pad(b, grown), grown
        )
        return [row[:size] for row in product[:size]]

    half = size >> 1
    a11, a12, a21, a22 = _strassen_quadrants(a, half)
    b11, b12, b21, b22 = _strassen_quadrants(b, half)

    # the seven block products of Strassen's algorithm
    m1 = _strassen_square(_strassen_add(a11, a22), _strassen_add(b11, b22), half)
    m2 = _strassen_square(_strassen_add(a21, a22), b11, half)
    m3 = _strassen_square(a11, _strassen_sub(b12, b22), half)
    m4 = _strassen_square(a22, _strassen_sub(b21, b11), half)
    m5 = _strassen_square(_strassen_add(a11, a12), b22, half)
    m6 = _strassen_square(_strassen_sub(a21, a11), _strassen_add(b11, b12), half)
    m7 = _strassen_square(_strassen_sub(a12, a22), _strassen_add(b21, b22), half)

    c11 = _strassen_add(_strassen_sub(_strassen_add(m1, m4), m5), m7)
    c12 = _strassen_add(m3, m5)
    c21 = _strassen_add(m2, m4)
    c22 = _strassen_add(_strassen_add(_strassen_sub(m1, m2), m3), m6)

    # stitch the four blocks back together, row by row
    upper = [left + right for left, right in zip(c11, c12)]
    lower = [left + right for left, right in zip(c21, c22)]
    return upper + lower
This function removes extra rows and columns of 0s¶
In [5]:
def rem0(res):
    """Delete every all-zero row and every all-zero column from res, in place.

    Rows and columns are judged on the matrix as it is passed in. All rows
    that contain only zeros are removed first, then every column that
    contained only zeros. Zero rows/columns anywhere in the matrix are
    removed, not only trailing ones.

    Args:
        res: non-empty list of non-empty rows; it is modified in place.

    Returns:
        The same list object that was passed in.
    """
    width = len(res[0])
    height = len(res)

    def column_is_blank(col):
        return all(res[row][col] == 0.0 for row in range(height))

    def row_is_blank(row):
        # check the leading entry first, then scan the whole row
        return res[row][0] == 0.0 and all(
            res[row][col] == 0.0 for col in range(width)
        )

    blank_columns = [col for col in range(width) if column_is_blank(col)]
    blank_rows = [row for row in range(height) if row_is_blank(row)]

    # delete from the back so the remaining indices stay valid
    for row in reversed(blank_rows):
        del res[row]
    for col in reversed(blank_columns):
        for line in res:
            del line[col]
    return res
This is the Driver Function¶
In [6]:
def matMul(matX, matY):
    """Return the matrix product matX * matY as a new list of rows.

    The operands are copied and zero-padded so that strassen() only ever
    receives square blocks whose side is a power of two. The right operand
    is cut into such square column blocks, each block is multiplied by the
    padded left operand, and the partial products are joined side by side.
    Finally the result is cropped to the true product shape. The padding
    is removed by position, not by looking for zeros, so rows or columns of
    the product that are genuinely all zero are kept.

    Args:
        matX: left operand, a non-empty list of equally long rows.
        matY: right operand; it must have as many rows as matX has columns.

    Returns:
        A list of len(matX) rows with len(matY[0]) entries each.

    Raises:
        ArithmeticError: if an operand is empty or the inner dimensions
            do not match.
    """
    # if the matrices cannot be multiplied, we throw an exception
    if not matX or not matY or not matX[0] or not matY[0]:
        raise ArithmeticError("Matrices cannot be multiplied")
    if len(matX[0]) != len(matY):
        raise ArithmeticError("Matrices cannot be multiplied")

    rows, inner, cols = len(matX), len(matY), len(matY[0])

    # strassen() halves its operands at every level; a power-of-two side
    # keeps every level evenly divisible down to the 2x2 base case
    side = 2
    while side < max(rows, inner):
        side <<= 1

    parts = -(-cols // side)  # ceiling division: number of column blocks
    padded_cols = parts * side

    # padded copies; the caller's matrices are left untouched
    matA = [list(row) + [0.0] * (side - inner) for row in matX]
    matA.extend([0.0] * side for _ in range(side - rows))
    matB = [list(row) + [0.0] * (padded_cols - cols) for row in matY]
    matB.extend([0.0] * padded_cols for _ in range(side - inner))

    res = [[] for _ in range(rows)]
    for p in range(parts):
        start = p * side
        block = [row[start:start + side] for row in matB]
        partial = strassen(matA, block)
        # zip stops after `rows` rows, which drops the padded bottom rows
        for out_row, part_row in zip(res, partial):
            out_row.extend(part_row)

    # cut off the columns that only exist because of the padding
    return [row[:cols] for row in res]
Driver Code¶
In [7]:
if __name__ == "__main__":
    # Multiply all four sample pairs first, then print the products with a
    # blank line between consecutive matrices.
    res1, res2, res3, res4 = (
        matMul(left, right)
        for left, right in ((mat3, mat4), (mat5, mat6), (mat7, mat8), (mat7, mat9))
    )
    for position, product in enumerate((res1, res2, res3, res4)):
        if position:
            print()
        disp(product)
59 47 53
107 88 91
127 123 133
154 155 169

51 40 34
93 74 63
95 88 66

160 250 340 430
195 305 415 525
230 360 490 620
265 415 565 715
300 470 640 810

160 250 340 430 520 610 700
195 305 415 525 635 745 855
230 360 490 620 750 880 1010
265 415 565 715 865 1015 1165
300 470 640 810 980 1150 1320
Now we'll check if our calculations are correct, using numpy¶
In [8]:
import numpy as np
# Convert every sample matrix into a numpy array.
nmat3, nmat4, nmat5, nmat6, nmat7, nmat8, nmat9 = (
    np.array(sample)
    for sample in (mat3, mat4, mat5, mat6, mat7, mat8, mat9)
)
# Print numpy's products for the same four pairs the driver code multiplies,
# separated by a blank line, so the two outputs can be compared by eye.
for position, (left, right) in enumerate(
    ((nmat3, nmat4), (nmat5, nmat6), (nmat7, nmat8), (nmat7, nmat9))
):
    if position:
        print()
    disp(np.dot(left, right))
59 47 53
107 88 91
127 123 133
154 155 169

51 40 34
93 74 63
95 88 66

160 250 340 430
195 305 415 525
230 360 490 620
265 415 565 715
300 470 640 810

160 250 340 430 520 610 700
195 305 415 525 635 745 855
230 360 490 620 750 880 1010
265 415 565 715 865 1015 1165
300 470 640 810 980 1150 1320