CS/알고리즘
[알고리즘] Divide and Conquer - Strassen’s Matrix Multiplication
윤곰이
2024. 4. 23. 17:01
학교에서 들은 알고리즘 분석 강의 내용을 복습하면서 작성한 글입니다.
March 25, 2024 12:17 AM
기존의 행렬 곱셈 방식
for (i = 1; i <= n; i++)
for (j = 1; j <= n; j++) {
C[i][j] = 0;
for (k = 1; k <= n; k++)
C[i][j] = C[i][j] + A[i][k] * B[k][j];}

중첩 세번 → 시간 복잡도: θ(n^3)

Strassen
- 곱셈 7번 한 번에 하고 → 이후 곱셈 없이 더하기만!! (곱셈 줄이기..!)
- 행렬의 크기가 커질수록 이득

Algorithm
- 두 개의 n X n 행렬, n은 2의 거듭제곱 형태
void strassen(int n, nxn_matrix A, nxn_matrix B, nxn_matrix &C){
if (n <= thresshold) //n이 임계값보다 작거나 같으면 일반적 행렬곱셈 알고리즘 사용
compute C = A x B using the standard algorithm
else{
partition A into four submatrices A11, A12, A21, A22;
partition B into four submatrices B11, B12, B21, B22;
compute C = A x B using Strassen's method;
}
}
Time Complexity
- 곱셈 (n=2^k → k=log_2(n))
- T(1)=1 → 입력행렬 크기 1: 곱셈 연산 한번만
- T(n) = 7(n/2) → 각 부분 문제를 해결할 때 7번의 재귀 호출을 수행하며, 각 호출마다 입력 크기가 n/2로 줄어든다k=log_2(n)이므로
- T(n) = 7^k T(1) = 7^log_2(n) * 1 = n^log_2(7) = n^2.81
- T(n) = 7T(n/2) = 7^2 T(n/2^2) = 7^3 T(n/2^3) = ... = 7^k T(n/2^k)
- 덧셈까지 고려
- T(1)=0
- T(n) = 7T(n/2) + 18(n/2)^2
- T(n) = 6n^log_2(7) - 6n^2 ~~ 6n^2.81 - 6n^2