CS/알고리즘

[알고리즘] Divide and Conquer - Strassen’s Matrix Multiplication

윤곰이 2024. 4. 23. 17:01
학교에서 들은 알고리즘 분석 강의 내용을 복습하면서 작성한 글입니다.
March 25, 2024 12:17 AM

 

기존의 행렬 곱셈 방식

for (i = 1; i <= n; i++)
	for (j = 1; j <= n; j++) {
		C[i][j] = 0;
		for (k = 1; k <= n; k++)
			C[i][j] = C[i][j] + A[i][k] * B[k][j];}

중첩 세번 → 시간 복잡도: θ(n^3)

Strassen

  • 곱셈 7번 한 번에 하고 → 이후 곱셈 없이 더하기만!! (곱셈 줄이기..!)
  • 행렬의 크기가 커질수록 이득

 

Algorithm

  • 두 개의 n X n 행렬, n은 2의 거듭제곱 형태
void strassen(int n, nxn_matrix A, nxn_matrix B, nxn_matrix &C){
	if (n <= thresshold) //n이 임계값보다 작거나 같으면 일반적 행렬곱셈 알고리즘 사용
		compute C = A x B using the standard algorithm
	else{
		partition A into four submatrices A11, A12, A21, A22;
		partition B into four submatrices B11, B12, B21, B22;
		compute C = A x B using Strassen's method;
	}
}

 

Time Complexity

  • 곱셈 (n=2^k → k=log_2(n))
    • T(1)=1 → 입력행렬 크기 1: 곱셈 연산 한번만
    • T(n) = 7(n/2) → 각 부분 문제를 해결할 때 7번의 재귀 호출을 수행하며, 각 호출마다 입력 크기가 n/2로 줄어든다k=log_2(n)이므로
    • T(n) = 7^k T(1) = 7^log_2(n) * 1 = n^log_2(7) = n^2.81
    • T(n) = 7T(n/2) = 7^2 T(n/2^2) = 7^3 T(n/2^3) = ... = 7^k T(n/2^k)
  • 덧셈까지 고려
    • T(1)=0
    • T(n) = 7T(n/2) + 18(n/2)^2
    • T(n) = 6n^log_2(7) - 6n^2 ~~ 6n^2.81 - 6n^2