计算图(自动微分)- Computational Graph

刷 PTA 看到的

  • L3-023 计算图

原理

  1. 在计算图中,节点的计算值沿边依次传递给后继节点(前向传播)以得到最终结果;梯度则从输出节点出发,依次传递给前驱节点(反向传播)以计算各个偏导数。
  2. 由链式求导法则,同一条路径上的局部偏导数相乘后向前驱节点传递;而经由不同路径到达同一节点的偏导数则累加到该节点上。

实现

最初写的时候想简单了,前向和反向两个过程都用 bfs 逐层处理;但当同一个节点可以经由多条长度不同的路径到达时(图中出现"菱形"结构),逐层 bfs 可能在其依赖尚未全部计算完时就处理该节点,结果会出错。因此下面的实现改用拓扑排序,以入度作为指标决定计算顺序。注意计算图必须是有向无环图;若存在有向环,环上的节点不会被处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#include <bits/stdc++.h>
using namespace std;

/// A single vertex of the computational graph.
///
/// `type` selects the operation applied to `operands`:
///   0 = input (value is supplied by the caller), 1 = x + y, 2 = x - y,
///   3 = x * y, 4 = exp(x), 5 = ln(x), 6 = sin(x).
class Node {
public:
    int type = 0;                   // operation code; a fresh node is an input
    double value = 0;               // value produced by the last compute()
    std::vector<double> pfpx;       // local partial derivative w.r.t. each operand
    std::vector<Node*> nextNodes;   // successors (nodes that consume this one)
    std::vector<Node*> operands;    // predecessors (arguments of the operation)
    double d = 0;                   // gradient accumulated during backpropagation

    /// Evaluates this node from its operands' current values and stores both
    /// the result (`value`) and the local partial derivatives (`pfpx`).
    ///
    /// Input nodes and unknown type codes are left untouched. Binary
    /// operations read operands[0] and operands[1]; unary ones read
    /// operands[0]. The operands must already hold their final values.
    void compute() {
        switch (type) {
            case 1:  // addition: x + y
                value = operands[0]->value + operands[1]->value;
                pfpx = {1, 1};
                break;
            case 2:  // subtraction: x - y
                value = operands[0]->value - operands[1]->value;
                pfpx = {1, -1};
                break;
            case 3:  // multiplication: x * y
                value = operands[0]->value * operands[1]->value;
                pfpx = {operands[1]->value, operands[0]->value};
                break;
            case 4:  // exponential: e^x, whose derivative is e^x itself
                value = std::exp(operands[0]->value);
                pfpx = {value};
                break;
            case 5:  // natural logarithm: ln(x)
                value = std::log(operands[0]->value);
                pfpx = {1 / operands[0]->value};
                break;
            case 6:  // sine: sin(x)
                value = std::sin(operands[0]->value);
                pfpx = {std::cos(operands[0]->value)};
                break;
            default:  // input node (0) or unknown code: nothing to evaluate
                break;
        }
    }
};

class ComputationalGraph {
public:
vector<Node> Nodes;
Node* root; // 输出节点(无后继边)
int n = 0; // 已添加节点数

ComputationalGraph() = default;
ComputationalGraph(int sz) { Nodes.resize(sz); }

// 添加输入节点
void addNode(int type, double value) {
Nodes[n].value = value;
Nodes[n].type = type;
++n;
}

// 添加双目运算节点
void addNode(int type, int operand0, int operand1) {
Nodes[n].type = type;
Nodes[n].operands = {&Nodes[operand0], &Nodes[operand1]};
Nodes[operand0].nextNodes.push_back(&Nodes[n]);
Nodes[operand1].nextNodes.push_back(&Nodes[n]);
n++;
}

// 添加单目运算节点
void addNode(int type, int operand0) {
Nodes[n].type = type;
Nodes[n].operands = {&Nodes[operand0]};
Nodes[operand0].nextNodes.push_back(&Nodes[n]);
n++;
}

// 前向传播计算:使用拓扑排序保证每个节点在依赖都计算完后计算
double compute() {
int sz = Nodes.size();
vector<int> indegree(sz, 0);
// 计算各节点的入度
for (int i = 0; i < sz; i++) {
for (auto child : Nodes[i].nextNodes) {
int idx = child - &Nodes[0];
indegree[idx]++;
}
}
queue<Node*> q;
// 入度为 0 的节点先入队
for (int i = 0; i < sz; i++) {
if (indegree[i] == 0) q.push(&Nodes[i]);
}
while (!q.empty()) {
Node* cur = q.front();
q.pop();
cur->compute();
for (auto child : cur->nextNodes) {
int idx = child - &Nodes[0];
indegree[idx]--;
if (indegree[idx] == 0) q.push(child);
}
// 输出节点:没有后继边
if (cur->nextNodes.empty()) {
root = cur;
}
}
return root->value;
}

// 反向传播计算梯度:先获得拓扑序,再倒序传递梯度
vector<double> bp() {
int sz = Nodes.size();
// 计算拓扑序列
vector<int> indegree(sz, 0);
for (int i = 0; i < sz; i++) {
for (auto child : Nodes[i].nextNodes) {
int idx = child - &Nodes[0];
indegree[idx]++;
}
}
queue<Node*> q;
vector<Node*> topo;
for (int i = 0; i < sz; i++) {
if (indegree[i] == 0) q.push(&Nodes[i]);
}
while (!q.empty()) {
Node* cur = q.front();
q.pop();
topo.push_back(cur);
for (auto child : cur->nextNodes) {
int idx = child - &Nodes[0];
indegree[idx]--;
if (indegree[idx] == 0) q.push(child);
}
}
// 倒序:从输出节点开始传递梯度
reverse(topo.begin(), topo.end());
// 清除所有节点梯度
for (auto& node : Nodes) {
node.d = 0;
}
// 输出节点梯度为 1
root->d = 1;
// 反向传播:按照倒序遍历
for (auto node : topo) {
for (int i = 0; i < node->operands.size(); i++) {
node->operands[i]->d += node->d * node->pfpx[i];
}
}
vector<double> res;
// 只输出输入节点的梯度,顺序与输入时相同
for (int i = 0; i < sz; i++) {
if (Nodes[i].type == 0) res.push_back(Nodes[i].d);
}
return res;
}
};

// Reads a graph description from stdin, then prints the output value followed
// by the partial derivatives with respect to every input node.
//
// Input: the node count, then one line per node: `0 v` for an input node,
// `1|2|3 a b` for a binary operation on nodes a and b, `4|5|6 a` for a unary
// operation on node a.
void solve() {
    int nodeCount;
    std::cin >> nodeCount;
    ComputationalGraph cg(nodeCount);

    int type, lhs, rhs;
    double inputValue;
    // A counted loop keeps nodeCount intact for the whole read phase.
    for (int idx = 0; idx < nodeCount; ++idx) {
        std::cin >> type;
        if (type == 0) {
            std::cin >> inputValue;
            cg.addNode(type, inputValue);
        } else if (type >= 1 && type <= 3) {
            std::cin >> lhs >> rhs;
            cg.addNode(type, lhs, rhs);
        } else if (type >= 4 && type <= 6) {
            std::cin >> lhs;
            cg.addNode(type, lhs);
        }
    }

    std::cout << std::fixed << std::setprecision(3);
    std::cout << cg.compute() << std::endl;

    // Gradients are space-separated with no trailing separator.
    const auto grads = cg.bp();
    const char* separator = "";
    for (const double g : grads) {
        std::cout << separator << g;
        separator = " ";
    }
}

// Entry point: configures fast stream I/O and runs the single test case.
signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    // The case count is fixed at one; the loop shape allows raising it later.
    for (int remaining = 1; remaining > 0; --remaining) {
        solve();
    }
    return 0;
}

UPD

更加通用的实现:支持多个输出节点(向量输出),并允许用自定义函数(及其偏导数函数)定义计算节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
#include <bits/stdc++.h>
using namespace std;
using ld = long double;
using ll = long long;
/// One vertex of the generalized computational graph.
///
/// A node is either an input (inputLength == 0, value set directly) or an
/// operation described by two caller-supplied functions: `fp` maps the operand
/// values to the node's value and `dfdxp` maps them to the partial derivatives
/// with respect to each operand.
class ComputeNode {
public:
    std::size_t inputLength = 0;           // number of operands; 0 for inputs
    ld value = 0;                          // current value of the node
    std::vector<ld> pfpx;                  // partials returned by dfdxp
    std::vector<ComputeNode*> nextNodes;   // nodes that consume this node
    std::vector<ComputeNode*> operands;    // nodes this node reads from
    ld (*fp)(std::vector<ld>) = nullptr;                  // value function
    std::vector<ld> (*dfdxp)(std::vector<ld>) = nullptr;  // derivative function
    std::size_t id = 999;                  // identifier assigned by the creator
    ld grad = 0;                           // gradient accumulated in backprop

    /// Builds an operation node taking `arity` operands.
    ComputeNode(std::size_t arity, ld (*valueFn)(std::vector<ld>),
                std::vector<ld> (*gradFn)(std::vector<ld>), std::size_t nodeId)
        : inputLength{arity}, fp{valueFn}, dfdxp{gradFn}, id{nodeId} {}

    /// Builds an input node that simply holds `val`.
    ComputeNode(ld val, std::size_t nodeId) : value{val}, id{nodeId} {}

    /// Recomputes `value` (and `pfpx`) from the operands; defined out of line.
    ld compute();
};

// Evaluates the node and returns its value.
//
// Input nodes (inputLength == 0) return their stored value unchanged. For an
// operation node the first inputLength operand values are gathered, passed to
// fp to refresh `value`, and passed to dfdxp to refresh `pfpx`.
ld ComputeNode::compute() {
    if (inputLength != 0) {
        std::vector<ld> args(inputLength);
        for (std::size_t k = 0; k != inputLength; ++k) {
            args[k] = operands[k]->value;
        }
        value = fp(args);
        pfpx = dfdxp(args);
    }
    return value;
}

class ComputationalGraph {
public:
vector<ComputeNode*> Nodes;
vector<ComputeNode*> outNodes;
vector<ComputeNode*> inNodes;
vector<ComputeNode*> topo_forward;
vector<ComputeNode*> topo_backward;
ComputationalGraph() = default;
ComputeNode* addNode(size_t inputLength, ld (*fp)(vector<ld>),
vector<ld> (*dfdxp)(vector<ld>)); // 计算节点
ComputeNode* addNode(ld value); // 输入节点
void bindNode(vector<ComputeNode*> ops, ComputeNode* ComputeNodePtr);
vector<ld> compute();
vector<ld> bp();
void clearGrad() {
for (auto nd : Nodes) nd->grad = 0;
}
};

// Creates an operation node with `inputLength` operands, value function `fp`
// and derivative function `dfdxp`, appends it to Nodes and returns it.
// The node's id is its position in Nodes.
ComputeNode* ComputationalGraph::addNode(size_t inputLength,
                                         ld (*fp)(vector<ld>),
                                         vector<ld> (*dfdxp)(vector<ld>)) {
    const std::size_t nextId = Nodes.size();
    Nodes.push_back(new ComputeNode(inputLength, fp, dfdxp, nextId));
    return Nodes.back();
}

// Creates an input node holding `value`, appends it to Nodes and returns it.
// The node's id is its position in Nodes.
ComputeNode* ComputationalGraph::addNode(ld value) {
    const std::size_t nextId = Nodes.size();
    Nodes.push_back(new ComputeNode(value, nextId));
    return Nodes.back();
}

// Connects each node in `ops`, in order, as an operand of `ComputeNodePtr`,
// recording the edge on both endpoints.
void ComputationalGraph::bindNode(vector<ComputeNode*> ops,
                                  ComputeNode* ComputeNodePtr) {
    for (std::size_t k = 0; k < ops.size(); ++k) {
        ComputeNode* operand = ops[k];
        ComputeNodePtr->operands.push_back(operand);   // consumer -> operand
        operand->nextNodes.push_back(ComputeNodePtr);  // operand -> consumer
    }
}

vector<ld> ComputationalGraph::compute() {
if (!topo_forward.size()) {
map<ComputeNode*, ll> indegree;
bool flag = false;
for (auto nd : Nodes) {
for (auto nnd : nd->nextNodes) {
indegree[nnd]++;
}
}
if (outNodes.size() != 0) flag = true;
queue<ComputeNode*> q;
if (inNodes.size() == 0) {
for (auto nd : Nodes)
if (indegree[nd] == 0) inNodes.push_back(nd); // 入度为0的是输入节点
sort(inNodes.begin(), inNodes.end(),
[&](auto a, auto b) { return a->id < b->id; }); // 按节点id排序
}
for (auto nd : inNodes) q.push(nd);
while (!q.empty()) {
auto nd = q.front();
q.pop();
topo_forward.push_back(nd);
for (auto nnd : nd->nextNodes) {
--indegree[nnd]; // 依赖nd的节点的入度-1(已经计算完成)
if (indegree[nnd] == 0) q.push(nnd); // 节点依赖都已计算完成,入队
}
if (nd->nextNodes.size() == 0 && !flag) {
outNodes.push_back(nd); // 统计输出节点
}
}
if (!flag) {
sort(outNodes.begin(), outNodes.end(),
[&](auto a, auto b) { return a->id < b->id; }); // 按节点id排序
}
topo_backward = topo_forward;
reverse(topo_backward.begin(), topo_backward.end());
}
for (auto nd : topo_forward) {
nd->compute();
}
vector<ld> res;
for (auto nd : outNodes) {
res.push_back(nd->value);
}
return res;
}

vector<ld> ComputationalGraph::bp() {
clearGrad(); // 清除梯度
for (auto nd : outNodes) nd->grad = 1; // 输出节点导数为1
for (auto nd : topo_backward)
for (size_t i = 0; i < nd->inputLength; i++)
nd->operands[i]->grad +=
nd->grad *
nd->pfpx[i]; // 前向传播,前面节点的梯度累加当前这一路上的梯度乘积
vector<ld> res;
for (auto nd : inNodes) {
res.push_back(nd->grad); // 所有输入变量的偏导数
}
return res;
}

// Elementary operations and their partial derivatives.
//
// Each value function receives the operand values and returns the result;
// the matching derivative function returns one partial per operand, in the
// same order. Binary functions read inputs[0] and inputs[1], unary functions
// read inputs[0].

// x + y
ld add(std::vector<ld> inputs) {
    return inputs[0] + inputs[1];
}
// d(x + y) = (1, 1)
std::vector<ld> dadddx(std::vector<ld>) {
    return {1, 1};
}

// x - y
ld sub(std::vector<ld> inputs) {
    return inputs[0] - inputs[1];
}
// d(x - y) = (1, -1)
std::vector<ld> dsubdx(std::vector<ld>) {
    return {1, -1};
}

// x * y
ld mul(std::vector<ld> inputs) {
    return inputs[0] * inputs[1];
}
// d(x * y) = (y, x)
std::vector<ld> dmuldx(std::vector<ld> inputs) {
    const ld x = inputs[0];
    const ld y = inputs[1];
    return {y, x};
}

// e^x
ld ex(std::vector<ld> inputs) {
    return std::exp(inputs[0]);
}
// d(e^x) = e^x
std::vector<ld> dexdx(std::vector<ld> inputs) {
    return {std::exp(inputs[0])};
}

// ln(x)
ld ln(std::vector<ld> inputs) {
    return std::log(inputs[0]);
}
// d(ln x) = 1 / x
std::vector<ld> dlndx(std::vector<ld> inputs) {
    return {1.0 / inputs[0]};
}

// sin(x); overloads the library name for a vector argument
ld sin(std::vector<ld> inputs) {
    return std::sin(inputs[0]);
}
// d(sin x) = cos(x)
std::vector<ld> dsindx(std::vector<ld> inputs) {
    return {std::cos(inputs[0])};
}

void solve() {
int nodeCount;
cin >> nodeCount;
ComputationalGraph cg;
int type, op1, op2;
double v;
vector<tuple<int, int, int>> nodes;
map<int, ComputeNode*> mp;
// 使用 for 循环,避免破坏节点数
for (int i = 0; i < nodeCount; i++) {
cin >> type;
switch (type) {
case 0:
cin >> v;
mp[i] = cg.addNode(v);
nodes.push_back({i, -1, -1});
break;
case 1:
cin >> op1 >> op2;
mp[i] = cg.addNode(2, add, dadddx);
nodes.push_back({i, op1, op2});
break;
case 2:
cin >> op1 >> op2;
mp[i] = cg.addNode(2, sub, dsubdx);
nodes.push_back({i, op1, op2});
break;
case 3:
cin >> op1 >> op2;
mp[i] = cg.addNode(2, mul, dmuldx);
nodes.push_back({i, op1, op2});
break;
case 4:
cin >> op1;
mp[i] = cg.addNode(1, ex, dexdx);
nodes.push_back({i, op1, -1});
break;
case 5:
cin >> op1;
mp[i] = cg.addNode(1, ln, dlndx);
nodes.push_back({i, op1, -1});
break;
case 6:
cin >> op1;
mp[i] = cg.addNode(1, sin, dsindx);
nodes.push_back({i, op1, -1});
break;
}
}
for (auto [nd, op1, op2] : nodes) {
vector<ComputeNode*> ops;
if (op1 != -1) ops.push_back(mp[op1]);
if (op2 != -1) ops.push_back(mp[op2]);
cg.bindNode(ops, mp[nd]);
}
cout << fixed << setprecision(3);
cout << cg.compute()[0] << endl;
auto res = cg.bp();
for (int i = 0; i < (int)res.size(); i++) {
if (i != 0) cout << " ";
cout << res[i];
}
}

// Entry point: the program handles exactly one test case.
int main() {
    solve();
}