B
B
Bogdan Zhuvak, 2021-01-06 20:09:04
Algorithms
Bogdan Zhuvak, 2021-01-06 20:09:04

Why is the Strassen–Winograd algorithm slow?

Compared to the classical algorithm, my implementation is very slow and consumes a lot of memory. For example, the classical algorithm multiplies two 256×256 matrices in 8 ms, while my Strassen implementation takes 26000 ms.

#include <chrono>
#include <cstdint>
#include <iostream>
#include <random>

/// Fills the N x N matrix A (N row pointers) with pseudo-random integers drawn
/// uniformly from the closed range [0, 100], in row-major order.
/// A new Mersenne Twister engine, seeded from std::random_device, is created on
/// every call, so results differ between calls and runs.
void generate(int32_t** A, int32_t N)
{
    std::random_device seedSource;
    std::mt19937 engine{ seedSource() };
    std::uniform_int_distribution<int32_t> valueDist{ 0, 100 };

    for (int32_t row = 0; row < N; ++row)
    {
        int32_t* const cells = A[row];
        for (int32_t col = 0; col < N; ++col)
            cells[col] = valueDist(engine);
    }
}

/// Allocates an n x n matrix as an array of n separately heap-allocated rows.
/// Elements are left uninitialized. The caller owns the result and must
/// delete[] every row and then the row array itself.
int32_t** initializeMatrix(int32_t n) {
    int32_t** rows = new int32_t* [n];
    int32_t** const end = rows + n;
    for (int32_t** row = rows; row != end; ++row)
        *row = new int32_t[n];
    return rows;
}

/// Classical O(n^3) product of two n x n matrices: each output element is the
/// dot product of a row of A with a column of B.
/// @return a newly allocated n x n matrix holding A*B; the caller owns it.
int32_t** normalMultiply(int32_t** A, int32_t** B, int32_t n)
{
    int32_t** product = initializeMatrix(n);
    for (int32_t row = 0; row < n; ++row)
    {
        const int32_t* const lhsRow = A[row];
        int32_t* const outRow = product[row];
        for (int32_t col = 0; col < n; ++col)
        {
            int32_t dot = 0;
            for (int32_t t = 0; t < n; ++t)
                dot += lhsRow[t] * B[t][col];
            outRow[col] = dot;
        }
    }
    return product;
}

/// Prompts on stdout, then reads n*n integers from stdin into M in row-major
/// order, and finally prints a blank line.
void input(int32_t** M, int32_t n) {
    std::cout << "Enter matrix: " << std::endl;
    for (int32_t row = 0; row < n; ++row) {
        int32_t* const dst = M[row];
        for (int32_t col = 0; col < n; ++col) {
            std::cin >> dst[col];
        }
    }
    std::cout << std::endl;
}

/// Writes the n x n matrix M to stdout, one row per line with every value
/// followed by a space, then emits one extra blank line.
void printMatrix(int32_t** M, int32_t n) {
    for (int32_t row = 0; row < n; ++row) {
        const int32_t* const cells = M[row];
        for (int32_t col = 0; col < n; ++col) {
            std::cout << cells[col] << " ";
        }
        std::cout << std::endl;
    }
    std::cout << std::endl;
}

/// Returns a newly allocated n x n matrix holding the element-wise sum M1 + M2.
/// The caller owns the result.
int32_t** add(int32_t** M1, int32_t** M2, int32_t n) {
    int32_t** sum = initializeMatrix(n);
    for (int32_t row = 0; row < n; ++row) {
        const int32_t* const lhs = M1[row];
        const int32_t* const rhs = M2[row];
        int32_t* const out = sum[row];
        for (int32_t col = 0; col < n; ++col)
            out[col] = lhs[col] + rhs[col];
    }
    return sum;
}

/// Returns a newly allocated n x n matrix holding the element-wise difference
/// M1 - M2. The caller owns the result.
int32_t** subtract(int32_t** M1, int32_t** M2, int32_t n) {
    int32_t** diff = initializeMatrix(n);
    for (int32_t row = 0; row < n; ++row) {
        const int32_t* const lhs = M1[row];
        const int32_t* const rhs = M2[row];
        int32_t* const out = diff[row];
        for (int32_t col = 0; col < n; ++col)
            out[col] = lhs[col] - rhs[col];
    }
    return diff;
}

int32_t** strassenMultiply(int32_t** A, int32_t** B, int32_t n) {
    if (n == 1) {
        int32_t** C = initializeMatrix(1);
        C[0][0] = A[0][0] * B[0][0];
        return C;
    }
    int32_t** C = initializeMatrix(n);
    int32_t k = n / 2;

    int32_t** A11 = initializeMatrix(k);
    int32_t** A12 = initializeMatrix(k);
    int32_t** A21 = initializeMatrix(k);
    int32_t** A22 = initializeMatrix(k);
    int32_t** B11 = initializeMatrix(k);
    int32_t** B12 = initializeMatrix(k);
    int32_t** B21 = initializeMatrix(k);
    int32_t** B22 = initializeMatrix(k);

    for (int32_t i = 0; i < k; i++)
        for (int32_t j = 0; j < k; j++) {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][k + j];
            A21[i][j] = A[k + i][j];
            A22[i][j] = A[k + i][k + j];
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][k + j];
            B21[i][j] = B[k + i][j];
            B22[i][j] = B[k + i][k + j];
        }
    int32_t** S1 = add(A21, A22, k);
    int32_t** S2 = subtract(S1, A11, k);
    int32_t** S3 = subtract(A11, A21, k);
    int32_t** S4 = subtract(A12, S2, k);
    int32_t** S5 = subtract(B12, B11, k);
    int32_t** S6 = subtract(B22, S5, k);
    int32_t** S7 = subtract(B22, B12, k);
    int32_t** S8 = subtract(S6, B21, k);


    int32_t** P1 = strassenMultiply(S2, S6, k);
    int32_t** P2 = strassenMultiply(A11, B11, k);
    int32_t** P3 = strassenMultiply(A12, B21, k);
    int32_t** P4 = strassenMultiply(S3, S7, k);
    int32_t** P5 = strassenMultiply(S1, S5, k);
    int32_t** P6 = strassenMultiply(S4, B22, k);
    int32_t** P7 = strassenMultiply(A22, S8, k);

    int32_t** T1 = add(P1, P2, k);
    int32_t** T2 = add(T1, P4, k);

    int32_t** C11 = add(P2, P3, k);
    int32_t** C12 = add(add(P5, P6, k), T1, k);
    int32_t** C21 = subtract(T2, P7, k);
    int32_t** C22 = add(T2, P5, k);

    for (int32_t i = 0; i < k; i++)
        for (int32_t j = 0; j < k; j++) {
            C[i][j] = C11[i][j];
            C[i][j + k] = C12[i][j];
            C[k + i][j] = C21[i][j];
            C[k + i][k + j] = C22[i][j];
        }

    for (int32_t i = 0; i < k; i++) {
        delete[] A11[i];
        delete[] A12[i];
        delete[] A21[i];
        delete[] A22[i];
        delete[] B11[i];
        delete[] B12[i];
        delete[] B21[i];
        delete[] B22[i];
        delete[] S1[i];
        delete[] S2[i];
        delete[] S3[i];
        delete[] S4[i];
        delete[] S5[i];
        delete[] S6[i];
        delete[] S7[i];
        delete[] S8[i];
        delete[] P1[i];
        delete[] P2[i];
        delete[] P3[i];
        delete[] P4[i];
        delete[] P5[i];
        delete[] P6[i];
        delete[] P7[i];
        delete[] T1[i];
        delete[] T2[i];
        delete[] C11[i];
        delete[] C12[i];
        delete[] C21[i];
        delete[] C22[i];
    }

    delete[] A11;
    delete[] A12;
    delete[] A21;
    delete[] A22;
    delete[] B11;
    delete[] B12;
    delete[] B21;
    delete[] B22;
    delete[] S1;
    delete[] S2;
    delete[] S3;
    delete[] S4;
    delete[] S5;
    delete[] S6;
    delete[] S7;
    delete[] S8;
    delete[] P1;
    delete[] P2;
    delete[] P3;
    delete[] P4;
    delete[] P5;
    delete[] P6;
    delete[] P7;
    delete[] T1;
    delete[] T2;
    delete[] C11;
    delete[] C12;
    delete[] C21;
    delete[] C22;

    return C;
}

/// Benchmark driver: reads a matrix size from stdin, fills two random n x n
/// matrices, multiplies them, and prints the elapsed wall-clock time in ms.
int main() {
    std::cout << "Enter size of the matrix: ";
    int32_t n = 0;
    if (!(std::cin >> n) || n <= 0) {
        std::cerr << "Invalid matrix size" << std::endl;
        return 1;
    }

    int32_t** A = initializeMatrix(n);
    int32_t** B = initializeMatrix(n);
    //input(A, n);
    generate(A, n);
    //std::cout << "Matrix A:" << std::endl;
    //printMatrix(A, n);
    //input(B, n);
    generate(B, n);
    //std::cout << "Matrix B:" << std::endl;
    //printMatrix(B, n);

    // The multiply functions allocate their own result, so C is not
    // pre-allocated here (doing so leaked n rows when C was reassigned).
    // steady_clock gives a monotonic time in real units, unlike clock() whose
    // tick rate (CLOCKS_PER_SEC) is platform-dependent.
    const auto start = std::chrono::steady_clock::now();
    //int32_t** C = strassenMultiply(A, B, n);
    int32_t** C = normalMultiply(A, B, n);
    const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() - start);
    std::cout << elapsed.count() << " ms" << '\n';
    //std::cout << "Multipliction result:" << std::endl;
    //printMatrix(C, n);

    // Each matrix owns n rows plus the row array.
    auto release = [n](int32_t** M) {
        for (int32_t i = 0; i < n; i++)
            delete[] M[i];
        delete[] M;
    };
    release(A);
    release(B);
    release(C);

    return 0;
}

Answer the question

In order to leave comments, you need to log in

1 answer(s)
A
Andrey Ezhgurov, 2021-01-06
@eandr_67

For example, because of the enormous number of memory allocations and deallocations, which are never fast.
There is also a memory leak: in `add(add(P5, P6, k), T1, k);` the matrix allocated by the inner `add(P5, P6, k)` is never freed.

Didn't find what you were looking for?

Ask your question

Ask a Question

731 491 924 answers to any question