矩阵快速幂与矩阵加速

前两天被人大附巨佬喷了。。。于是决定今天学一下矩阵加速

矩阵快速幂

  • 矩阵乘法的规则就不说了,直接说快速幂吧。
  • 矩阵快速幂把原快速幂的代码里乘号的部分改成矩阵乘法就好了,可以写个函数或者重载运算符。
  • 然后就出现了一堆没开longlong的错误。。。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

const ll mod = 1e9 + 7;

struct node{
vector<vector<ll> > array;
void init(int n){
array.resize(n + 1);
for(int i = 1; i <= n; i++) array[i].resize(n + 1);
}
void inite(int n, bool k){
for(int i = 1; i <= n; i++){
for(int j = 1; j <= n; j++){
array[i][j] = 0;
}
}
if(k) for(int i = 1; i <= n; i++) array[i][i] = 1;
}
};

struct Solution{
int n;
ll k;
node a, c, e;

node mul(node x, node y){
c.inite(n, 0);
for(int i = 1; i <= n; i++){
for(int j = 1; j <= n; j++){
for(int k = 1; k <= n; k++){
c.array[i][j] = (c.array[i][j] % mod + x.array[i][k] * y.array[k][j] % mod) % mod;
}
}
}
return c;
}

node power(node x, ll k){
node ans = e;
while(k){
if(k % 2 == 1){
ans = mul(ans, x);
}
x = mul(x, x);
k /= 2;
}
return ans;
}

void Solve(){
scanf("%d%lld", &n, &k);
a.init(n);
c.init(n);
e.init(n);
e.inite(n, 1);
for(int i = 1; i <= n; i++){
for(int j = 1; j <= n; j++){
scanf("%lld", &a.array[i][j]);
}
}
node ans = power(a, k);
for(int i = 1; i <= n; i++){
for(int j = 1; j <= n; j++){
printf("%lld ", ans.array[i][j] % mod);
}
printf("\n");
}
}
};

int main(){
Solution().Solve();
return 0;
}

矩阵加速

  • 矩阵加速,用来加速数列的递推,比如斐波那契数列的递推。

  • 能优化到$O(m^3log_2n)$

  • 分为两步:

    1. 构造初始矩阵
    2. 矩阵快速幂实现加速

感觉不难,就是构造初始矩阵比较懵。

  • 举个栗子:斐波那契数列,设定目标矩阵为$\left[\begin{array}{ccc}fib(n) \\fib(n - 1)\end{array}\right]$
    然后希望$\left[\begin{array}{ccc}fib(n-1) \\fib(n-2)\end{array}\right]$乘一个矩阵变成目标矩阵,就推出了这个矩阵$\left[\begin{array}{ccc}1\quad 1\\1\quad 0\end{array}\right]$。

  • 其他的也一样,可以参考这个

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

const ll mod = 1e9 + 7;

struct node{
ll e[4][4];
void init(){
e[1][1] = 1, e[1][2] = 1, e[1][3] = 0;
e[2][1] = 0, e[2][2] = 0, e[2][3] = 1;
e[3][1] = 1, e[3][2] = 0, e[3][3] = 0;
}
void init0(bool flag){
for(int i = 1; i <= 3; i++){
for(int j = 1; j <= 3; j++){
e[i][j] = 0;
}
}
if(flag){
for(int i = 1; i <= 3; i++){
e[i][i] = 1;
}
}
}
};

struct Solution{
int t;
ll n;

node mul(node x, node y){
node c;
c.init0(0);
for(int i = 1; i <= 3; i++){
for(int j = 1; j <= 3; j++){
for(int k = 1; k <= 3; k++){
c.e[i][j] += (x.e[i][k] * y.e[k][j]) % mod;
c.e[i][j] %= mod;
}
}
}
return c;
}

node power(ll k){
node ans, x;
x.init();
ans.init0(1);
while(k){
if(k & 1){
ans = mul(ans, x);
}
x = mul(x, x);
k >>= 1;
}
return ans;
}

void Solve(){
scanf("%d", &t);
while(t--){
scanf("%lld", &n);
node ans = power(n - 1);
printf("%lld\n", ans.e[1][1]);
}
}
};

int main(){
Solution().Solve();
return 0;
}

评论

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×