Even Idiots Can Make Game

백준 10830 G4 행렬 곱셉

Date/Lastmod
Section
PS
Title
백준 10830 G4 행렬 곱셈
Ps-Tags
분할 정복행렬
Solved
04-08/25

#1 접근

#2 완전 탐색으로 생각하기

작성한 알고리즘이 $O(B)$여도 $10^{12}$의 연산을 하게 되므로 제한 시간인 1초를 벗어나게 된다. 때문에 나이브하게 곱셈을 $B$번 반복하는 것으로 풀이해서는 안된다.

또한 행렬 곱셈에 $O(N^3)$이 소요된다. 행렬 곱셈 자체가 가장 부담되는 연산이므로, 이 곱셈 연산 수 자체를 줄이는 것으로 알고리즘 방향을 잡아야 한다.

#2 곱셈 짝짓기

곱셈을 뭉텅이로 두 개씩 짝짓는다면 어떨까? $A^9$를 예로 들어보자.

그림으로 표현하면 아래와 같다.

식으로 표현하면 아래와 같다.

$A^9 = A \cdot A^8 = A \cdot (A^4)^2 = A \cdot ((A^2)^2)^2 = A \cdot ((A\cdot A)^2)^2$

연산 횟수가 9회에서, $\log_{2}{(9-1)} + 1$회로 줄어들었다.

이를 일반화해보자. $A^n$을 구하는 함수 $f(A,n)$은 다음과 같이 정의할 수 있다.

$$ f(A, n) = \begin{cases} A & \text{if } n = 1 \\ f(A, \frac{n}{2})^2 & \text{if } n \text{ is even} \\ A \cdot f(A, n - 1) & \text{if } n \text{ is odd} \end{cases} $$

복잡도는 $O(N^3\log{B})$이다.

이를 그대로 한번 코드로 옮겨 보자.

우선 행렬 곱셈 함수인 mul은 아래와 같다.

using mat_t = vector<vector<long long>>;

mat_t mul(const mat_t& l, const ma_t& r) {
  mat_t ret(l.size(), vector<long long>(l.front().size(), 0));

  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < N; ++j) {
      for (int k = 0; k < N; ++k) {
        ret[i][j] += l[i][k] * r[k][j];
      }
    }
  }
  
  return ret;
}

참고

보면 알겠지만 mulfor문이 3번 중첩된 흉측한 $N^3$의 함수이므로, 우리가 이런 짓을 하는 것은 모두 다 이 mul의 호출 횟수를 줄이기 위함이다.

행렬 지수승 최적화 재귀 함수는 아래와 같다.

mat_t f(const mat_t& A, long long n) {
  // 기저 사례
  if (n == 1) {
    return A;
  }

  // 지수가 짝수
  if (n % 2 == 0) {
    // 결과를 재사용하기 위해 변수에 저장
    mat_t half = f(A, n / 2);
    return mul(half, half);
  // 지수가 홀수
  } else {
    return mul(A, f(A, n - 1));
  }
}

참고

결과를 재사용하기 위해 half에 저장하여 mul(half, half)처럼 작성한 것을 주의한다. 만일 half를 사용하지 않고 mul(f(A, n / 2), f(A, n / 2))를 하면 그냥 곱셈을 계속해서 반복한 것과 다름없는 연산 횟수가 되어버린다.

예를 들어 $A^{11}$의 경우 다음과 같이 재귀가 호출된다.

#1 코드

코드는 재귀에서 반복을 사용하였고, 1000으로 나눈 나머지를 저장한다는 조건을 추가하였다.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <iostream>
#include <vector>

using namespace std;
using ll = long long;
using mat_t = vector<vector<ll>>;

int N;
ll B;
mat_t mat;

/// @brief 행렬 l과 r의 곱연산 수행
mat_t mat_mul(const mat_t& l, const mat_t& r) {
  mat_t ret = mat_t(N, vector<ll>(N, 0));

  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < N; ++j) {
      for (int k = 0; k < N; ++k) {
        ret[i][j] += l[i][k] * r[k][j];
      }
      ret[i][j] %= 1000;
    }
  }

  return ret;
}

/// @brief 행렬 제곱 연산 최적화
mat_t mat_pow(const mat_t& m, ll p) {
  // 반환값을 단위 행렬로 초기화 시켜둠
  mat_t ret = mat_t(N, vector<ll>(N, 0));
  for (int i = 0; i < N; ++i) {
    ret[i][i] = 1;
  }

  mat_t cpy = m;

  while (p > 0) {
    // 지수가 홀수일 때 
    // (마지막에 지수가 1이 되므로 한번은 무조건 거쳐감)
    if (p % 2 == 1) {
      ret = mat_mul(ret, cpy);
    }

    // 지수가 짝수일 때
    cpy = mat_mul(cpy, cpy);
    p /= 2;
  }
  return ret;
}

int main() {
  // 입력
  cin >> N >> B;
  mat = mat_t(N, vector<ll>(N));
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < N; ++j) {
      cin >> mat[i][j];
    }
  }

  // 처리
  mat_t ret = mat_pow(mat, B);

  // 출력
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < N; ++j) {
      cout << ret[i][j] << ' ';
    }
    cout << '\n';
  }
}
comments powered by Disqus