树链剖分 学习笔记

这就是传说中的树链剖分?

  • 树链剖分可以将一棵树的任意一条路径划分成不超过$O(\log n)$条链,并且满足dfs序,反正就很好啦,维护线段树什么的。。。
  • 还可以$O(\log n)$求$lca$,常数小什么的。

树链剖分

定义

  • 定义重子节点表示其子节点中子树最大的子结点。如果有相同的,任意取。如果没有子节点,就没有。
  • 轻子节点就是剩余的其他子节点。
  • 这个节点到重子节点的边叫做重边,其他叫做轻边
  • 把若干条首尾相连的重边称为重链
  • 把单独的节点也当成重链,就把整棵树分成了若干条链。
    如图:树链剖分1

实现

  • 分两个dfs实现:
    1. 第一个dfs记录一下子树大小和深度。
    2. 记录重子节点,dfs序,当前节点的链顶。
  • 代码实现
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    void dfs1(int o, int fat) {
    son[o] = -1;
    siz[o] = 1;
    for (int j = h[o]; j; j = nxt[j])
    if (!dep[p[j]]) {
    dep[p[j]] = dep[o] + 1;
    fa[p[j]] = o;
    dfs1(p[j], o);
    siz[o] += siz[p[j]];
    if (son[o] == -1 || siz[p[j]] > siz[son[o]]) son[o] = p[j];
    }
    }
    void dfs2(int o, int t) {
    top[o] = t;
    cnt++;
    tid[o] = cnt;
    rnk[cnt] = o;
    if (son[o] == -1) return;
    dfs2(son[o], t); //优先对重儿子进行dfs,可以保证同一条重链上的点时间戳连续
    for (int j = h[o]; j; j = nxt[j])
    if (p[j] != son[o] && p[j] != fa[o]) dfs2(p[j], p[j]);
    }

模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#include<bits/stdc++.h>
using namespace std;

int n, m, r, p;
int v[200009];
struct Edge{
int to, nxt;
}e[400009];
int head[200009], tot;
int size[200009], dep[200009], fat[200009];
int dfn[200009], top[200009], id[200009];
int son[200009];

//----------------------------------------------

int tree[800009], tag[800009];

//----------------------------------------------

void add(int x, int y){
e[++tot].to = y;
e[tot].nxt = head[x];
head[x] = tot;
}

void dfs1(int x, int fa, int d){
fat[x] = fa;
dep[x] = d;
size[x] = 1;
int maxx = -1;
for(int i = head[x]; i; i = e[i].nxt){
int y = e[i].to;
if(y == fa) continue;
dfs1(y, x, d + 1);
size[x] += size[y];
if(size[y]>maxx) son[x] = y, maxx = size[y];
}
}

int cnt = 0;
void dfs2(int x, int topp){
dfn[x] = ++cnt;
id[cnt] = x;
top[x] = topp;
if(!son[x]) return;
dfs2(son[x], topp);
for(int i = head[x]; i; i = e[i].nxt){
int y = e[i].to;
if(y == son[x] || y == fat[x]) continue;
dfs2(y, y);
}
}

//-----------------------------------------

void build(int k, int l, int r){
if(l == r){
tree[k] = v[id[l]] % p;
return;
}
int mid = (l + r) / 2;
build(k * 2, l, mid);
build(k * 2 + 1, mid + 1, r);
tree[k] = (tree[k * 2] + tree[k * 2 + 1]) % p;
return;
}

void Add(int k, int l, int r, int w){
tag[k] += w;
tree[k] += w * (r - l + 1);
tree[k] %= p;
return;
}

void pushdown(int k, int l, int r){
int mid = (l + r) / 2;
Add(k * 2, l, mid, tag[k]);
Add(k * 2 + 1, mid + 1, r, tag[k]);
tag[k] = 0;
}

void modify(int k, int l, int r, int x, int y, int w){
if(l >= x && r <= y){
Add(k, l, r, w);
return;
}
pushdown(k, l, r);
int mid = (l + r) / 2;
if(mid >= x) modify(k * 2, l, mid, x, y, w);
if(mid < y) modify(k * 2 + 1, mid + 1, r, x, y, w);
tree[k] = (tree[k * 2] + tree[k * 2 + 1]) % p;
}

int query(int k, int l, int r, int x, int y){
if(l >= x && r <= y){
return tree[k] % p;
}
pushdown(k, l, r);
int mid = (l + r) / 2;
int res = 0;
if(mid >= x) res = (res + query(k * 2, l, mid, x, y)) % p;
if(mid < y) res = (res + query(k * 2 + 1, mid + 1, r, x, y)) % p;
return res;
}

//---------------------------------------------

void modify_tree(int x, int y, int w){
w %= p;
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);
modify(1, 1, n, dfn[top[x]], dfn[x], w);
x = fat[top[x]];
}
if(dep[x] > dep[y]) swap(x, y);
modify(1, 1, n, dfn[x], dfn[y], w);
return;
}


int query_tree(int x, int y){
int res = 0;
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);
res += query(1, 1, n, dfn[top[x]], dfn[x]);
res %= p;
x = fat[top[x]];
}
if(dep[x] > dep[y]) swap(x, y);
res = (res + query(1, 1, n, dfn[x], dfn[y])) % p;
return res;
}

int main(){
scanf("%d%d%d%d", &n, &m, &r, &p);
for(int i = 1; i <= n; i++){
scanf("%d", &v[i]);
}
for(int i = 1; i <= n - 1; i++){
int x, y;
scanf("%d%d", &x, &y);
add(x, y), add(y, x);
}
dfs1(r, -1, 1);
dfs2(r, r);
build(1, 1, n);
for(int i = 1; i <= m; i++){
int op, x, y, w;
scanf("%d", &op);
if(op == 1){
scanf("%d%d%d", &x, &y, &w);
modify_tree(x, y, w);
}
if(op == 2){
scanf("%d%d", &x, &y);
printf("%d\n", query_tree(x, y));
}
if(op == 3){
scanf("%d%d", &x, &w);
modify(1, 1, n, dfn[x], dfn[x] + size[x] - 1, w);
}
if(op == 4){
scanf("%d", &x);
printf("%d\n", query(1, 1, n, dfn[x], dfn[x] + size[x] - 1));
}
}
return 0;
}

评论

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×