Answer the question
In order to leave comments, you need to log in
Why is the Strassen-Winegrad algorithm slow?
The algorithm, compared to the classical one, is very slow and consumes a lot of memory. For example, the classic for 256 by 256 matrices does a multiplication in 8 ms. And I have Strassen for 26000 ms
#include <iostream>
#include <random>
void generate(int32_t** A, int32_t N)
{
std::random_device rnd_device;
std::mt19937 mersenne_engine{ rnd_device() };
std::uniform_int_distribution<int32_t> dist{ 0, 100 };
auto gen = [&dist, &mersenne_engine]() {
return dist(mersenne_engine);
};
for (int32_t i = 0; i < N; i++)
{
for (int32_t j = 0; j < N; j++)
{
A[i][j] = gen();
}
}
}
int32_t** initializeMatrix(int32_t n) {
int32_t** temp = new int32_t* [n];
for (int32_t i = 0; i < n; i++)
temp[i] = new int32_t[n];
return temp;
}
int32_t** normalMultiply(int32_t** A, int32_t** B, int32_t n)
{
int32_t** C = initializeMatrix(n);
for (int32_t i = 0; i < n; i++)
{
for (int32_t j = 0; j < n; j++)
{
C[i][j] = 0;
for (int32_t k = 0; k < n; k++)
{
C[i][j] += A[i][k] * B[k][j];
}
}
}
return C;
}
void input(int32_t** M, int32_t n) {
std::cout << "Enter matrix: " << std::endl;
for (int32_t i = 0; i < n; i++)
for (int32_t j = 0; j < n; j++)
std::cin >> M[i][j];
std::cout << std::endl;
}
void printMatrix(int32_t** M, int32_t n) {
for (int32_t i = 0; i < n; i++) {
for (int32_t j = 0; j < n; j++)
std::cout << M[i][j] << " ";
std::cout << std::endl;
}
std::cout << std::endl;
}
int32_t** add(int32_t** M1, int32_t** M2, int32_t n) {
int32_t** temp = initializeMatrix(n);
for (int32_t i = 0; i < n; i++)
for (int32_t j = 0; j < n; j++)
temp[i][j] = M1[i][j] + M2[i][j];
return temp;
}
int32_t** subtract(int32_t** M1, int32_t** M2, int32_t n) {
int32_t** temp = initializeMatrix(n);
for (int32_t i = 0; i < n; i++)
for (int32_t j = 0; j < n; j++)
temp[i][j] = M1[i][j] - M2[i][j];
return temp;
}
int32_t** strassenMultiply(int32_t** A, int32_t** B, int32_t n) {
if (n == 1) {
int32_t** C = initializeMatrix(1);
C[0][0] = A[0][0] * B[0][0];
return C;
}
int32_t** C = initializeMatrix(n);
int32_t k = n / 2;
int32_t** A11 = initializeMatrix(k);
int32_t** A12 = initializeMatrix(k);
int32_t** A21 = initializeMatrix(k);
int32_t** A22 = initializeMatrix(k);
int32_t** B11 = initializeMatrix(k);
int32_t** B12 = initializeMatrix(k);
int32_t** B21 = initializeMatrix(k);
int32_t** B22 = initializeMatrix(k);
for (int32_t i = 0; i < k; i++)
for (int32_t j = 0; j < k; j++) {
A11[i][j] = A[i][j];
A12[i][j] = A[i][k + j];
A21[i][j] = A[k + i][j];
A22[i][j] = A[k + i][k + j];
B11[i][j] = B[i][j];
B12[i][j] = B[i][k + j];
B21[i][j] = B[k + i][j];
B22[i][j] = B[k + i][k + j];
}
int32_t** S1 = add(A21, A22, k);
int32_t** S2 = subtract(S1, A11, k);
int32_t** S3 = subtract(A11, A21, k);
int32_t** S4 = subtract(A12, S2, k);
int32_t** S5 = subtract(B12, B11, k);
int32_t** S6 = subtract(B22, S5, k);
int32_t** S7 = subtract(B22, B12, k);
int32_t** S8 = subtract(S6, B21, k);
int32_t** P1 = strassenMultiply(S2, S6, k);
int32_t** P2 = strassenMultiply(A11, B11, k);
int32_t** P3 = strassenMultiply(A12, B21, k);
int32_t** P4 = strassenMultiply(S3, S7, k);
int32_t** P5 = strassenMultiply(S1, S5, k);
int32_t** P6 = strassenMultiply(S4, B22, k);
int32_t** P7 = strassenMultiply(A22, S8, k);
int32_t** T1 = add(P1, P2, k);
int32_t** T2 = add(T1, P4, k);
int32_t** C11 = add(P2, P3, k);
int32_t** C12 = add(add(P5, P6, k), T1, k);
int32_t** C21 = subtract(T2, P7, k);
int32_t** C22 = add(T2, P5, k);
for (int32_t i = 0; i < k; i++)
for (int32_t j = 0; j < k; j++) {
C[i][j] = C11[i][j];
C[i][j + k] = C12[i][j];
C[k + i][j] = C21[i][j];
C[k + i][k + j] = C22[i][j];
}
for (int32_t i = 0; i < k; i++) {
delete[] A11[i];
delete[] A12[i];
delete[] A21[i];
delete[] A22[i];
delete[] B11[i];
delete[] B12[i];
delete[] B21[i];
delete[] B22[i];
delete[] S1[i];
delete[] S2[i];
delete[] S3[i];
delete[] S4[i];
delete[] S5[i];
delete[] S6[i];
delete[] S7[i];
delete[] S8[i];
delete[] P1[i];
delete[] P2[i];
delete[] P3[i];
delete[] P4[i];
delete[] P5[i];
delete[] P6[i];
delete[] P7[i];
delete[] T1[i];
delete[] T2[i];
delete[] C11[i];
delete[] C12[i];
delete[] C21[i];
delete[] C22[i];
}
delete[] A11;
delete[] A12;
delete[] A21;
delete[] A22;
delete[] B11;
delete[] B12;
delete[] B21;
delete[] B22;
delete[] S1;
delete[] S2;
delete[] S3;
delete[] S4;
delete[] S5;
delete[] S6;
delete[] S7;
delete[] S8;
delete[] P1;
delete[] P2;
delete[] P3;
delete[] P4;
delete[] P5;
delete[] P6;
delete[] P7;
delete[] T1;
delete[] T2;
delete[] C11;
delete[] C12;
delete[] C21;
delete[] C22;
return C;
}
int32_t main() {
std::cout << "Enter size of the matrix: ";
int32_t n;
std::cin >> n;
int32_t** A = initializeMatrix(n);
int32_t** B = initializeMatrix(n);
//input(A, n);
generate(A, n);
//std::cout << "Matrix A:" << std::endl;
//printMatrix(A, n);
//input(B, n);
generate(B, n);
//std::cout << "Matrix B:" << std::endl;
//printMatrix(B, n);
int32_t** C = initializeMatrix(n);
int32_t time = clock();
//C = strassenMultiply(A, B, n);
C = normalMultiply(A, B, n);
time = clock() - time;
std::cout << time;
//std::cout << "Multipliction result:" << std::endl;
//printMatrix(C, n);
for (int32_t i = 0; i < n; i++)
delete[] A[i];
delete[] A;
for (int32_t i = 0; i < n; i++)
delete[] B[i];
delete[] B;
for (int32_t i = 0; i < n; i++)
delete[] C[i];
delete[] C;
return 0;
}
Answer the question
In order to leave comments, you need to log in
For example, because the wild number of allocations / releases of memory. And it's never fast.
And also memory leaks: add(add(P5, P6, k), T1, k); - the memory allocated in add(P5, P6, k) is not freed.
Didn't find what you were looking for?
Ask your questionAsk a Question
731 491 924 answers to any question