ItsTipTop, 2012-05-13 22:13:28

Strassen's algorithm: optimization

Hey Habr.
I have implemented Strassen's algorithm for matrix multiplication, but my implementation is not particularly fast. I would be glad to hear your comments on optimizing the code.

For 512×512 matrices I got the following results:
Normal mode: 1538911
Strassen mode: 784078

I hope the code is readable :)
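For reference, the recursion in the code follows the standard Strassen identities (the same formulas appear in its comments), and the recurrence gives its asymptotic cost. The last line is plain arithmetic that assumes N = 512 and the threshold N <= 32 used in the code (four levels of splitting, 7^4 leaf calls to MUL of size 32); it counts only the scalar multiplications inside MUL and ignores the additions and copying.

```latex
\begin{aligned}
M_1 &= (A_{11} + A_{22})(B_{11} + B_{22}) & C_{11} &= M_1 + M_4 - M_5 + M_7\\
M_2 &= (A_{21} + A_{22})\,B_{11}          & C_{12} &= M_3 + M_5\\
M_3 &= A_{11}\,(B_{12} - B_{22})          & C_{21} &= M_2 + M_4\\
M_4 &= A_{22}\,(B_{21} - B_{11})          & C_{22} &= M_1 - M_2 + M_3 + M_6\\
M_5 &= (A_{11} + A_{12})\,B_{22}\\
M_6 &= (A_{21} - A_{11})(B_{11} + B_{12})\\
M_7 &= (A_{12} - A_{22})(B_{21} + B_{22})
\end{aligned}

T(N) = 7\,T(N/2) + O(N^2) \;\Longrightarrow\; T(N) = O\!\left(N^{\log_2 7}\right) \approx O\!\left(N^{2.807}\right)

N = 512,\ \text{threshold } 32:\quad
\frac{7^4 \cdot 32^3}{512^3} = \frac{78\,675\,968}{134\,217\,728} = \left(\tfrac{7}{8}\right)^4 \approx 0.586
```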

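// forward declarations: ADD, SUB and MUL are defined after Strassen
int ADD(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize);
int SUB(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize);
int MUL(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize);

int Strassen(int N, int **MatrixA, int **MatrixB, int **MatrixC)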
{
    
    int HalfSize = N/2;
    int newSize = N/2;
    
    if ( N <= 32 )// choosing the threshold is extremely important; try N <= 2 to see the difference. Assumes N stays even while N > 32 (e.g. N is a power of two).
    {
        MUL(MatrixA,MatrixB,MatrixC,N);
    }
    else
    {
        int** A11;
        int** A12;
        int** A21;
        int** A22;
        
        int** B11;
        int** B12;
        int** B21;
        int** B22;
        
        int** C11;
        int** C12;
        int** C21;
        int** C22;
        
        int** M1;
        int** M2;
        int** M3;
        int** M4;
        int** M5;
        int** M6;
        int** M7;
        int** AResult;
        int** BResult;
        
        // allocate the row-pointer arrays (one per submatrix)
        A11 = new int *[newSize];
        A12 = new int *[newSize];
        A21 = new int *[newSize];
        A22 = new int *[newSize];
        
        B11 = new int *[newSize];
        B12 = new int *[newSize];
        B21 = new int *[newSize];
        B22 = new int *[newSize];
        
        C11 = new int *[newSize];
        C12 = new int *[newSize];
        C21 = new int *[newSize];
        C22 = new int *[newSize];
        
        M1 = new int *[newSize];
        M2 = new int *[newSize];
        M3 = new int *[newSize];
        M4 = new int *[newSize];
        M5 = new int *[newSize];
        M6 = new int *[newSize];
        M7 = new int *[newSize];
        
        AResult = new int *[newSize];
        BResult = new int *[newSize];
        
        int newLength = newSize;
        
        // allocate each row, turning those pointer arrays into 2D arrays
        for ( int i = 0; i < newSize; i++)
        {
            A11[i] = new int[newLength];
            A12[i] = new int[newLength];
            A21[i] = new int[newLength];
            A22[i] = new int[newLength];
            
            B11[i] = new int[newLength];
            B12[i] = new int[newLength];
            B21[i] = new int[newLength];
            B22[i] = new int[newLength];
            
            C11[i] = new int[newLength];
            C12[i] = new int[newLength];
            C21[i] = new int[newLength];
            C22[i] = new int[newLength];
            
            M1[i] = new int[newLength];
            M2[i] = new int[newLength];
            M3[i] = new int[newLength];
            M4[i] = new int[newLength];
            M5[i] = new int[newLength];
            M6[i] = new int[newLength];
            M7[i] = new int[newLength];
            
            AResult[i] = new int[newLength];
            BResult[i] = new int[newLength];
            
            
        }
        // split each input matrix into 4 submatrices
        for (int i = 0; i < N / 2; i++)
        {
            for (int j = 0; j < N / 2; j++)
            {
                A11[i][j] = MatrixA[i][j];
                A12[i][j] = MatrixA[i][j + N / 2];
                A21[i][j] = MatrixA[i + N / 2][j];
                A22[i][j] = MatrixA[i + N / 2][j + N / 2];
                
                B11[i][j] = MatrixB[i][j];
                B12[i][j] = MatrixB[i][j + N / 2];
                B21[i][j] = MatrixB[i + N / 2][j];
                B22[i][j] = MatrixB[i + N / 2][j + N / 2];
                
            }
        }
        
        // compute the seven products M1..M7
        //M1[][]
        ADD( A11,A22,AResult, HalfSize);
        ADD( B11,B22,BResult, HalfSize);
        Strassen( HalfSize, AResult, BResult, M1 ); //M1=(A11+A22)(B11+B22); multiply recursively with Strassen itself
        
        
        //M2[][]
        ADD( A21,A22,AResult, HalfSize);              //M2=(A21+A22)B11
        Strassen(HalfSize, AResult, B11, M2);       //Mul(AResult,B11,M2);
        
        //M3[][]
        SUB( B12,B22,BResult, HalfSize);              //M3=A11(B12-B22)
        Strassen(HalfSize, A11, BResult, M3);       //Mul(A11,BResult,M3);
        
        //M4[][]
        SUB( B21, B11, BResult, HalfSize);           //M4=A22(B21-B11)
        Strassen(HalfSize, A22, BResult, M4);       //Mul(A22,BResult,M4);
        
        //M5[][]
        ADD( A11, A12, AResult, HalfSize);           //M5=(A11+A12)B22
        Strassen(HalfSize, AResult, B22, M5);       //Mul(AResult,B22,M5);
        
        
        //M6[][]
        SUB( A21, A11, AResult, HalfSize);
        ADD( B11, B12, BResult, HalfSize);             //M6=(A21-A11)(B11+B12)
        Strassen( HalfSize, AResult, BResult, M6);    //Mul(AResult,BResult,M6);
        
        //M7[][]
        SUB(A12, A22, AResult, HalfSize);
        ADD(B21, B22, BResult, HalfSize);             //M7=(A12-A22)(B21+B22)
        Strassen(HalfSize, AResult, BResult, M7);     //Mul(AResult,BResult,M7);
        
        //C11 = M1 + M4 - M5 + M7;
        ADD( M1, M4, AResult, HalfSize);
        SUB( M7, M5, BResult, HalfSize);
        ADD( AResult, BResult, C11, HalfSize);
        
        //C12 = M3 + M5;
        ADD( M3, M5, C12, HalfSize);
        
        //C21 = M2 + M4;
        ADD( M2, M4, C21, HalfSize);
        
        //C22 = M1 + M3 - M2 + M6;
        ADD( M1, M3, AResult, HalfSize);
        SUB( M6, M2, BResult, HalfSize);
        ADD( AResult, BResult, C22, HalfSize);
        
        
        // at this point we have computed C11..C22; now assemble them
        // into the single resulting matrix MatrixC.
        for (int i = 0; i < N/2 ; i++)
        {
            for (int j = 0 ; j < N/2 ; j++)
            {
                MatrixC[i][j] = C11[i][j];
                MatrixC[i][j + N / 2] = C12[i][j];
                MatrixC[i + N / 2][j] = C21[i][j];
                MatrixC[i + N / 2][j + N / 2] = C22[i][j];
            }
        }
        
        // don't forget to free the memory allocated for the submatrices
        for (int i = 0; i < newLength; i++)
        {
            delete[] A11[i];delete[] A12[i];delete[] A21[i];
            delete[] A22[i];
            
            delete[] B11[i];delete[] B12[i];delete[] B21[i];
            delete[] B22[i];
            delete[] C11[i];delete[] C12[i];delete[] C21[i];
            delete[] C22[i];
            delete[] M1[i];delete[] M2[i];delete[] M3[i];delete[] M4[i];
            delete[] M5[i];delete[] M6[i];delete[] M7[i];
            delete[] AResult[i];delete[] BResult[i] ;
        }
        delete[] A11;delete[] A12;delete[] A21;delete[] A22;
        delete[] B11;delete[] B12;delete[] B21;delete[] B22;
        delete[] C11;delete[] C12;delete[] C21;delete[] C22;
        delete[] M1;delete[] M2;delete[] M3;delete[] M4;delete[] M5;
        delete[] M6;delete[] M7;
        delete[] AResult;
        delete[] BResult ;
        
        
    }//end of else
    
    
    return 0;
}

int ADD(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize )
{
    for ( int i = 0; i < MatrixSize; i++)
    {
        for ( int j = 0; j < MatrixSize; j++)
        {
            MatrixResult[i][j] =  MatrixA[i][j] + MatrixB[i][j];
        }
    }
    return 0;
}

int SUB(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize )
{
    for ( int i = 0; i < MatrixSize; i++)
    {
        for ( int j = 0; j < MatrixSize; j++)
        {
            MatrixResult[i][j] =  MatrixA[i][j] - MatrixB[i][j];
        }
    }
    return 0;
}

int MUL( int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize )
{
    for (int i=0;i<MatrixSize ;i++)
    {
        for (int j=0;j<MatrixSize ;j++)
        {
            MatrixResult[i][j]=0;
            for (int k=0;k<MatrixSize ;k++)
            {
                MatrixResult[i][j]=MatrixResult[i][j]+MatrixA[i][k]*MatrixB[k][j];
            }
        }
    }
    return 0;
}
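The code above relies on N/2 being exact at every level until N <= 32. A size that turns odd above the threshold (for example 66 → 33) makes N/2 truncate, so the last row and column of that block are dropped. One common workaround, sketched here, is to zero-pad the inputs to the next power of two, multiply, and copy back the top-left block. The sketch assumes flat row-major storage (as suggested in one of the answers below) instead of int**; `next_pow2`, `pad` and `unpad` are hypothetical helper names.

```cpp
#include <cstddef>
#include <vector>

// Smallest power of two >= n (for n >= 1).
std::size_t next_pow2(std::size_t n) {
    std::size_t p = 1;
    while (p < n) p <<= 1;
    return p;
}

// Copy an n x n row-major matrix into the top-left corner of an m x m
// zero matrix (m >= n). The extra zero rows/columns do not change the
// top-left n x n block of the product.
std::vector<int> pad(const std::vector<int>& a, std::size_t n, std::size_t m) {
    std::vector<int> out(m * m, 0);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j)
            out[i * m + j] = a[i * n + j];
    return out;
}

// Extract the top-left n x n block of an m x m row-major matrix.
std::vector<int> unpad(const std::vector<int>& a, std::size_t m, std::size_t n) {
    std::vector<int> out(n * n);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j)
            out[i * n + j] = a[i * m + j];
    return out;
}
```

Padding to a power of two can nearly double each dimension (513 → 1024); a tighter target is the smallest size that halves evenly down to the threshold, which this sketch does not attempt.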



Thank you!

7 answers
Gribozavr, 2012-05-14
@ItsTipTop

First, convert all matrices to flat arrays int[N * N], and access the elements as M[j * N + i].
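A minimal sketch of that layout, assuming row-major order (element (i, j) at i * N + j; the answer's M[j * N + i] is the same idea with the index roles swapped). With one flat buffer, a quadrant such as A12 can be described by a pointer to its first element plus the row stride of the parent matrix, so splitting no longer needs any copying. `View`, `whole`, `quadrant` and `add` are hypothetical names for illustration.

```cpp
#include <cstddef>
#include <vector>

// A square block inside a larger row-major buffer: element (i, j) of the
// block lives at data[i * stride + j].
struct View {
    int* data;
    std::size_t n;       // block size
    std::size_t stride;  // row length of the underlying buffer
    int& operator()(std::size_t i, std::size_t j) const { return data[i * stride + j]; }
};

// The whole n x n matrix stored in a flat vector.
View whole(std::vector<int>& buf, std::size_t n) { return View{buf.data(), n, n}; }

// Quadrant (r, c), with r and c in {0, 1}, of an even-sized view; no copying.
View quadrant(const View& v, int r, int c) {
    std::size_t h = v.n / 2;
    return View{v.data + r * h * v.stride + c * h, h, v.stride};
}

// out = a + b element-wise, for views of equal size.
void add(const View& a, const View& b, const View& out) {
    for (std::size_t i = 0; i < a.n; ++i)
        for (std::size_t j = 0; j < a.n; ++j)
            out(i, j) = a(i, j) + b(i, j);
}
```

The same pointer-plus-stride view could be passed into the recursion in place of the copied A11..B22 arrays; only the scratch results still need their own storage.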

BrainHacker, 2012-05-13
@BrainHacker

Damn, you have so much dynamic memory allocation it makes my head spin! At least try to allocate memory in one big block. Better still, allocate all the memory you need once, at the start of the algorithm.
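A sketch of "allocate once", assuming the recursion keeps the structure of the question's code: every level that splits uses 21 half-size scratch matrices (A11..A22, B11..B22, C11..C22, M1..M7, AResult, BResult), and because the seven recursive calls run one after another, the calls at one level can reuse the same region for the level below. The total can then be computed up front and handed out from a single buffer by a simple bump (stack) allocator. `workspace_ints` and `Arena` are hypothetical names.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Ints needed for all scratch matrices, assuming every level with
// n > threshold allocates 21 matrices of (n/2) x (n/2) and the recursive
// calls at one level reuse the same space one after another.
std::size_t workspace_ints(std::size_t n, std::size_t threshold) {
    std::size_t total = 0;
    while (n > threshold) {
        std::size_t h = n / 2;
        total += 21 * h * h;
        n = h;
    }
    return total;
}

// One buffer allocated up front; take() hands out consecutive slices and
// release() rolls back to an earlier mark (LIFO, matching the recursion).
class Arena {
public:
    explicit Arena(std::size_t ints) : buf_(ints), top_(0) {}
    int* take(std::size_t count) {
        if (top_ + count > buf_.size()) throw std::runtime_error("arena exhausted");
        int* p = buf_.data() + top_;
        top_ += count;
        return p;
    }
    std::size_t mark() const { return top_; }
    void release(std::size_t m) { top_ = m; }
private:
    std::vector<int> buf_;
    std::size_t top_;
};
```

At each level you would record `mark()`, `take()` the 21 blocks, recurse, and `release()` back to the mark before returning, so the next sibling call reuses the same memory. For N = 512 and the threshold of 32 this comes to 21 · (256² + 128² + 64² + 32²) = 1,827,840 ints, about 7.3 MB with 4-byte ints.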

kefiijrw, 2012-05-14
@kefiijrw

I think you have a bug in the MUL function. I'm too lazy to recall the algorithm, so this is just a guess.
In the line

     MatrixResult[i][j]=MatrixResult[i][j]+MatrixA[i][k]*MatrixB[k][j];

that is, instead of = there should be += or something like that.
Also, can it be parallelized?
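Parallelizing is possible at two levels. In the Strassen recursion the seven products M1..M7 depend only on the input quadrants, so they could run concurrently (although in the question's code they share the AResult/BResult scratch matrices, so each would need its own). A simpler sketch, shown here, parallelizes the plain multiplication by splitting the rows of C among std::thread workers, since every row of the result is computed independently. It assumes flat row-major storage; `parallel_mul` is a hypothetical name, and some toolchains need -pthread to build it.

```cpp
#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

// c = a * b for n x n row-major matrices. Rows of c are split into
// contiguous ranges, one per thread, so no two threads write the same row.
void parallel_mul(const std::vector<int>& a, const std::vector<int>& b,
                  std::vector<int>& c, std::size_t n, unsigned threads) {
    if (threads == 0) threads = 1;
    c.assign(n * n, 0);
    std::vector<std::thread> pool;
    std::size_t chunk = (n + threads - 1) / threads;
    for (unsigned t = 0; t < threads; ++t) {
        std::size_t begin = t * chunk;
        std::size_t end = std::min(n, begin + chunk);
        if (begin >= end) break;
        pool.emplace_back([&a, &b, &c, n, begin, end] {
            // ikj order inside each row range: rows of b and c are read sequentially
            for (std::size_t i = begin; i < end; ++i)
                for (std::size_t k = 0; k < n; ++k) {
                    int aik = a[i * n + k];
                    for (std::size_t j = 0; j < n; ++j)
                        c[i * n + j] += aik * b[k * n + j];
                }
        });
    }
    for (auto& th : pool) th.join();
}
```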

bayandin, 2012-05-14
@bayandin

You can read zealint.ru/fast-matrix-multiplication-results.html.
Here is a short excerpt from that post:
"First I wrote the dumbest version, which multiplies row by column in a triple loop (C[i][j] += A[i][k] * B[k][j], with the loops in ijk order). These were the same 800 s (you have to start somewhere). Here you can also play with rearranging the loops: the fastest order is ikj, in which all three matrices are scanned row by row. This version already ran in 3 minutes (almost 5 times faster). You can also simply transpose the second matrix right at input and swap the indices k and j when accessing matrix B. This is even faster (150 s). What else can be done without resorting to assembler? Change int (32 bits) to short (16 bits) for the input matrices and compile with the compiler native to the processor at maximum optimization settings. This gives 94 seconds. You can get down to 55 seconds by unrolling the matrices into flat arrays. Then you get accesses of the form A[i * n + k], but all such multiplications must either be moved out of the inner loop or replaced altogether by advancing pointers by n in the right places. In this case the compiler optimizes much better. That's it: without assembler there is only one more thing you can do, namely split the matrix into blocks that fit into the cache, but I decided to do that directly in assembler."

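The loop-order, transposition and blocking ideas from that excerpt can be sketched on flat row-major int arrays as follows. This is an illustration under that assumption, not the quoted author's code; `mul_ikj`, `mul_transposed` and `mul_blocked` are hypothetical names, and the tile size of `mul_blocked` is a parameter to tune for the target cache. All three compute the same product as the plain ijk triple loop.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

using Mat = std::vector<int>;  // n x n, row-major: (i, j) at i * n + j

// ikj order: the inner loop scans row k of B and row i of C sequentially.
Mat mul_ikj(const Mat& a, const Mat& b, std::size_t n) {
    Mat c(n * n, 0);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t k = 0; k < n; ++k) {
            const int aik = a[i * n + k];
            for (std::size_t j = 0; j < n; ++j)
                c[i * n + j] += aik * b[k * n + j];
        }
    return c;
}

// Transpose B once; then each c(i, j) is a dot product of two contiguous rows.
Mat mul_transposed(const Mat& a, const Mat& b, std::size_t n) {
    Mat bt(n * n);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j)
            bt[j * n + i] = b[i * n + j];
    Mat c(n * n);
    for (std::size_t i = 0; i < n; ++i) {
        const int* ai = &a[i * n];  // row pointers: no i * n in the inner loop
        for (std::size_t j = 0; j < n; ++j) {
            const int* bj = &bt[j * n];
            int sum = 0;
            for (std::size_t k = 0; k < n; ++k) sum += ai[k] * bj[k];
            c[i * n + j] = sum;
        }
    }
    return c;
}

// Blocked ikj: work on tile x tile sub-blocks (tile > 0) so the active
// parts of A, B and C stay in cache.
Mat mul_blocked(const Mat& a, const Mat& b, std::size_t n, std::size_t tile) {
    Mat c(n * n, 0);
    for (std::size_t ii = 0; ii < n; ii += tile)
        for (std::size_t kk = 0; kk < n; kk += tile)
            for (std::size_t jj = 0; jj < n; jj += tile)
                for (std::size_t i = ii; i < std::min(ii + tile, n); ++i)
                    for (std::size_t k = kk; k < std::min(kk + tile, n); ++k) {
                        const int aik = a[i * n + k];
                        for (std::size_t j = jj; j < std::min(jj + tile, n); ++j)
                            c[i * n + j] += aik * b[k * n + j];
                    }
    return c;
}
```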
Eddy_Em, 2012-05-13
@Eddy_Em

Do you need to implement this method yourself (i.e., is it some kind of lab assignment), or are you reinventing the wheel?
If the latter, then I advise you to use, for example, the BLAS or GSL libraries.

Konstantin, 2012-05-14
@Norraxx

Have you tried pastebin.com/ for the code?

IlVin, 2012-05-16
@IlVin

Read about Alexey Tutubalin's optimizations. Way back in 2007 he was "feeling out" processors by multiplying matrices: blog.lexa.ru/2007/01/04/o_peremnozhenii_matric_i_prochix_arxitekturnix_zamorochkax.html
