快速幂

记录一下几个快速幂算法模板

  • 快速幂
  • 模意义下的快速幂
  • 矩阵快速幂

经典快速幂

原理

对于任意一个整数,我们都可以将其表示为一系列 2 的幂的和即:

其中 对应位上的bit, 的二进制长度, 下文中同理
于是对于 ,可以将其写成:

也就是我们可以通过计算其中的每一项并求积来计算

又由

所以为了计算 每一项我们可以:

  • 首先判断此处的 是不为
    • ,那么我们无需计算这一项
    • ,那么我们将当前的结果乘上
  • 为了不重复计算,我们可以由前一次的 的平方,计算得下一次所需要的

实现

迭代版本
1
2
3
4
5
6
7
8
9
ll qpow(ll a, ll n){
ll res = 1;
while(n){
if (n & 1) res = res * a; // n & 1 其实是取n的二进制最后一位,与 n % 2类似
a *= a; // 平方 a
n >>= 1; // 右移一位即 n /= 2
}
return res;
}

模运算的法则:

模意义下取幂
1
2
3
4
5
6
7
8
9
10
ll qpow(ll a, ll n, ll mod){
ll res = 1;
a %= mod; // 先mod一次防止第一次 a * a 的时候溢出
while(n){
if (n & 1) res = (res * a) % mod;
a = (a * a) % mod;
n >>= 1;
}
return res;
}

复杂度分析

循环执行次数也即是 的二进制位数,所以时间复杂度为

矩阵快速幂

原理

原理类似,算法类似,矩阵的幂也可以拆开算,时间复杂度也是压成
(这里认为计算矩阵乘法的时间复杂度是的)

实现

先给出matrix类的实现

class matrix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class matrix {
public:
pair<ll, ll> shape;
vector<vector<ll>> values;

matrix() = default;

matrix(const ll rowsize, const ll colsize)
: shape(rowsize, colsize), values(rowsize, vector<ll>(colsize, 0)) {}

matrix& operator+=(matrix& mat) {
if (!(mat.shape.first == shape.first && mat.shape.second == shape.second)) {
throw invalid_argument("Wrong mat shape for add operation.");
exit(-1);
}
for (ll i = 0; i < mat.shape.first; ++i) {
for (ll j = 0; j < mat.shape.second; ++j) {
values[i][j] += mat.values[i][j];
}
}
return *this;
}

matrix& operator+(matrix& mat) {
matrix* res = new matrix(mat.shape.first, mat.shape.second);
if (!(mat.shape.first == shape.first && mat.shape.second == shape.second)) {
throw invalid_argument("Wrong mat shape for add operation.");
exit(-1);
}
for (ll i = 0; i < mat.shape.first; ++i) {
for (ll j = 0; j < mat.shape.second; ++j) {
res->values[i][j] = mat.values[i][j] + values[i][j];
}
}
return *res;
}

matrix& operator*(matrix& mat) {
matrix* res = new matrix(mat.shape.first, mat.shape.second);
if (!(mat.shape.second == shape.first)) {
throw invalid_argument("Wrong mat shape for multiply operation.");
exit(-1);
}
for (ll i = 0; i < shape.first; ++i) {
for (ll k = 0; k < mat.shape.second; ++k) {
for (ll j = 0; j < mat.shape.first; ++j) {
res->values[i][k] += values[i][j] * mat.values[j][k];
}
}
}
return *res;
}

void clear() {
for (ll i = 0; i < shape.first; ++i) {
for (ll j = 0; j < shape.second; ++j) {
values[i][j] = 0;
}
}
}

void unitify() {
if (shape.first != shape.second) {
throw invalid_argument("Not a n by n matrix.");
}
clear();
for (ll i = 0; i < shape.first; ++i) {
values[i][i] = 1;
}
}
};

因为有操作符重载,实现看着其实没区别

qpow
1
2
3
4
5
6
7
8
9
10
matrix qpow(matrix a, ll n) {
matrix ans = a;
ans.unitify();
while (n) {
if (n & 1) ans = ans * a;
a = a * a;
n >>= 1;
}
return ans;
}

贴一个测试程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <bits/stdc++.h>
using ll = long long;
using namespace std;
class matrix {
public:
pair<ll, ll> shape;
vector<vector<ll>> values;

matrix() = default;

matrix(const ll rowsize, const ll colsize)
: shape(rowsize, colsize), values(rowsize, vector<ll>(colsize, 0)) {}

matrix& operator+=(matrix& mat) {
if (!(mat.shape.first == shape.first && mat.shape.second == shape.second)) {
throw invalid_argument("Wrong mat shape for add operation.");
exit(-1);
}
for (ll i = 0; i < mat.shape.first; ++i) {
for (ll j = 0; j < mat.shape.second; ++j) {
values[i][j] += mat.values[i][j];
}
}
return *this;
}

matrix& operator+(matrix& mat) {
matrix* res = new matrix(mat.shape.first, mat.shape.second);
if (!(mat.shape.first == shape.first && mat.shape.second == shape.second)) {
throw invalid_argument("Wrong mat shape for add operation.");
exit(-1);
}
for (ll i = 0; i < mat.shape.first; ++i) {
for (ll j = 0; j < mat.shape.second; ++j) {
res->values[i][j] = mat.values[i][j] + values[i][j];
}
}
return *res;
}

matrix& operator*(matrix& mat) {
matrix* res = new matrix(mat.shape.first, mat.shape.second);
if (!(mat.shape.second == shape.first)) {
throw invalid_argument("Wrong mat shape for add operation.");
exit(-1);
}
for (ll i = 0; i < shape.first; ++i) {
for (ll k = 0; k < mat.shape.second; ++k) {
for (ll j = 0; j < mat.shape.first; ++j) {
res->values[i][k] += values[i][j] * mat.values[j][k];
}
}
}
return *res;
}

void clear() {
for (ll i = 0; i < shape.first; ++i) {
for (ll j = 0; j < shape.second; ++j) {
values[i][j] = 0;
}
}
}

void unitify() {
if (shape.first != shape.second) {
throw invalid_argument("Not a n by n matrix.");
}
clear();
for (ll i = 0; i < shape.first; ++i) {
values[i][i] = 1;
}
}
};

matrix qpow(matrix a, ll n) {
matrix ans = a;
ans.unitify();
while (n) {
if (n & 1) ans = ans * a;
a = a * a;
n >>= 1;
}
return ans;
}

int main() {
matrix a(3, 3);
for (ll i = 0; i < 3; i++) {
for (ll j = 0; j < 3; j++) {
a.values[i][j] = 1;
}
}
cout << "mat a" << endl;
for (ll i = 0; i < 3; i++) {
for (ll j = 0; j < 3; j++) {
cout << a.values[i][j] << " ";
}
cout << endl;
}
ll n;
cout << "n: ";
cin >> n;
matrix res = qpow(a, n);
for (ll i = 0; i < 3; i++) {
for (ll j = 0; j < 3; j++) {
cout << res.values[i][j] << " ";
}
cout << endl;
}
return 0;
}

斐波那契数列

这边贴一个矩阵快速幂求斐波那契的例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <bits/stdc++.h>
using ll = long long;
using namespace std;
struct matrix {
ll a1, a2, b1, b2;
matrix(ll a1, ll a2, ll b1, ll b2) : a1(a1), a2(a2), b1(b1), b2(b2) {};
matrix operator*(const matrix& y) {
matrix ans((a1 * y.a1 + a2 * y.b1), (a1 * y.a2 + a2 * y.b2),
(b1 * y.a1 + b2 * y.b1), (b1 * y.a2 + b2 * y.b2));
return ans;
}
};
matrix qpow(matrix a, ll n) {
matrix ans(1, 0, 0, 1);
while (n) {
if (n & 1) ans = ans * a;
a = a * a;
n >>= 1;
}
return ans;
}
ll fib(int n) {
if (n == 1 || n == 2 || n < 1) {
return 1;
} else {
matrix M(0, 1, 1, 1);
matrix ans = qpow(M, n - 1);
return ans.a1 + ans.a2;
}
}
int main() {
ll n;
cin >> n;
cout << fib(n) << endl;
return 0;
}