线段树合并
有时,我们需要给树上每一个节点维护一个一维 dp
数组。若其转移过程满足以下条件,可以使用线段树合并优化:
- 每个节点的 \(dp\) 值都是由其子节点转移而来;
- \(dp_u\) 的一个下标位置由 \(dp_v\) 中相同位置下标转移而来;
- (可选)\(dp_u\) 的一个下标位置 \(i\) 额外需要 \(dp_u\) 或 \(dp_v\) 在 \([1,i]\) 区间前缀和来转移;
题目大意
村落里的一共有 \(n\) 座房屋,并形成一个树状结构。救济粮分 \(m\) 次发放,每次选择两个房屋 \((x, y)\),然后对于 \(x\) 到 \(y\) 的路径上(含 \(x\) 和 \(y\))每座房子里发放一袋 \(z\) 类型的救济粮。
我们想知道,当所有的救济粮发放完毕后,每座房子里存放的最多的是哪种救济粮。
我们发现如果暴力枚举路径上的每个节点,时间复杂度就已经炸了。因此我们首先考虑使用数据结构维护路径信息,或者使用树上差分。因为每个节点还可能有多个元素,所以直接使用数据结构维护路径,时间和空间都会爆炸。
考虑树上差分。对于一条路径 \((u,v)\),我们修改 \(x\),\(y\),\(lca(x,y)\),\(fa[lca(x,y)]\),并在第二次 dfs
的时候做子树求和。记
- \(g_{u,j}\) 表示节点 \(u\) 中 \(j\) 类型救济粮的差分数组;
- \(f_{u,j}\) 表示发放完所有粮食之后,节点 \(u\) 中 \(j\) 类型救济粮的总数量;
由此得到:
\[
f_{u,j}=g_{u,j}+\sum_{v\in to[u]}f_{v,j}
\]
这种和子树数组对应位置直接相加的式子可以使用线段树合并优化。
具体的,对于每条操作路径,我们向 \(x\) 和 \(y\) 节点对应的线段树(动态开点的权值线段树)的 \(w\) 位置 \(+1\),\(lca(x,y)\) 对应的位置 \(-1\)。
接着,在 dfs
中将子节点 \(v\) 的线段树合并到 \(u\) 的线段树上。把所有子节点的线段树都合并之后,当前线段树就是 \(u\) 的 \(f_{u}\) 数组。直接查询最大值即可。
时间复杂度分析
当两棵线段树有一个重合的节点时,才会继续向下递归,产生 \(O(1)\) 的时间复杂度。同时两个重合的节点被合并成了一个,这等价于删除了一个节点。由此得到:每产生 \(O(1)\) 的时间复杂度,都会删除 \(1\) 个节点。
因为初始情况下一共有 \(O(n\log V)\) 个节点,并且合并过程中不会产生新的节点,因此删除节点数一定小于等于 \(n\log V\)。由上面的结论得到,时间复杂度的上界是 \(O(n\log V)\)。
代码
| #include<iostream>
#define int long long
using namespace std;
const int N = 1E5 + 10;
const int Z = 1E5;
const int LOGN = 20;
struct myPair {
int pos, val;
};
inline const myPair &max(const myPair &a, const myPair &b) {
if(a.val >= b.val) return a;
return b;
}
struct Edge {
int v;
int next;
} pool[2 * N];
int ne, head[N];
void addEdge(int u, int v) {
pool[++ne] = {v, head[u]};
head[u] = ne;
}
int n, m;
int anc[N][LOGN], dep[N];
void dfs0(int u, int fa) {
dep[u] = dep[fa] + 1;
anc[u][0] = fa;
for(int i = 1; i < LOGN; i++) {
anc[u][i] = anc[ anc[u][i - 1] ][i - 1];
}
for(int i = head[u]; i; i = pool[i].next) {
int v = pool[i].v;
if(v == fa) continue;
dfs0(v, u);
}
}
int getlca(int x, int y) {
if(dep[x] < dep[y]) {
swap(x, y);
}
for(int i = LOGN - 1; i >= 0; i--) {
if(dep[anc[x][i]] >= dep[y]) x = anc[x][i];
}
if(x == y) return x;
for(int i = LOGN - 1; i >= 0; i--) {
if(anc[x][i] != anc[y][i]) {
x = anc[x][i];
y = anc[y][i];
}
}
return anc[x][0];
}
int ans[N];
int lc[4 * 20 * N], rc[4 * 20 * N], nn;
myPair mx[4 * 20 * N];
int rt[N];
void push_up(int p) {
mx[p] = max(mx[lc[p]], mx[rc[p]]);
}
void insert(int &p, int l, int r, int q, int v) {
if(p == 0) p = ++nn;
if(l == r) {
mx[p].val += v;
mx[p].pos = l;
return;
}
int mid = (l + r) >> 1;
if(mid >= q) insert(lc[p], l, mid, q, v);
else insert(rc[p], mid + 1, r, q, v);
push_up(p);
}
int merge(int p1, int p2, int l, int r) {
if(p1 == 0 || p2 == 0) return p1 | p2;
if(l == r) {
mx[p1].val += mx[p2].val;
return p1;
}
int mid = (l + r) >> 1;
lc[p1] = merge(lc[p1], lc[p2], l, mid);
rc[p1] = merge(rc[p1], rc[p2], mid + 1, r);
push_up(p1);
return p1;
}
void dfs(int u, int f) {
for(int i = head[u]; i; i = pool[i].next) {
int v = pool[i].v;
if(v == f) continue;
dfs(v, u);
rt[u] = merge(rt[u], rt[v], 1, Z);
}
if(mx[rt[u]].val != 0) ans[u] = mx[rt[u]].pos;
}
signed main() {
cin >> n >> m;
for(int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
addEdge(u, v);
addEdge(v, u);
}
dfs0(1, 0);
for(int i = 1; i <= m; i++) {
int x, y, z;
cin >> x >> y >> z;
int lca = getlca(x, y);
insert(rt[x], 1, Z, z, 1);
insert(rt[y], 1, Z, z, 1);
insert(rt[lca], 1, Z, z, -1);
insert(rt[anc[lca][0]], 1, Z, z, -1);
}
dfs(1, 0);
for(int i = 1; i <= n; i++) {
cout << ans[i] << '\n';
}
return 0;
}
|
子树深度查询
题目大意
给定一棵有根树,根节点为 \(1\) 号节点,每个点有点权。有 \(q\) 次询问,每次询问给定 \(u,l,r\),询问在 \(u\) 子树内所有满足 \(dep[v]\in [l,r]\) 的节点 \(v\) 的点权的最大值。
二维数点做不了,因为两个维度都有上下界的限制,且答案不可差分。
这里我们使用线段树合并解决。使用线段树维护深度,可以 \(O(\log n)\) 求出区间点权最大值。因为深度是不变的,满足使用线段树合并的要求,直接套板子即可。
题目大意
给定一棵 \(n\)(\(n\le 5\times 10^5\))个节点的有根树和 \(m\)(\(m \le 5\times 10^5\))条返祖链。现在要求用黑白两种颜色给所有边染色,使得每条返祖链都有至少一条边是白色的。问染色方案数对 \(998244353\) 取模的结果。
记返祖链的底端为其起点,顶端为其末端。
容易发现,从同一个节点出发的多条返祖链,只有末端最深的一个是有效的。我们给每个节点维护一个数组 \(w[u]\) 表示从该起点出发的所有返祖链中,末端的最大深度。若没有从该节点出发的返祖链,则 \(w[u]=0\)。
因为我们可以在根节点上方挂一条虚边,钦定这条虚边始终为白色,再从所有节点都向这个虚边上方连一条返祖链,就可以将有、无剩余返祖链的情况简单的统一。
| for(int i = 1; i <= m; i++) {
int u, v;
cin >> u >> v;
w[v] = max(w[v], dep[u]);
}
|
考虑一种朴素的 dp
,设 \(dp_{u,j}\) 表示节点 \(u\) 的子树中,整个都在子树内的限制(返祖链)都被满足,而没有被满足的所有返祖链的末端节点中最深的一个深度为 \(j\),方案数是多少。(若所有都被满足,则 \(j=0\))
容易写出状态转移方程:
\[
dp'_{u,j}=dp_{u,j}\times (\sum_{k=0}^{j}{dp_{v,k}}+\sum_{k=0}^{dep_u}{dp_{v,k}})+(\sum_{k=0}^{j-1}{dp_{u,k}})\times dp_{v,j}
\]
其中,第 \(1\)、\(2\) 项表示新增的子树 \(v\) 没有改变最深的末端深度;第 \(3\) 项表示 \(v\) 产生了一条末端更深的未处理的返祖链。
这种式子可以直接使用线段树合并优化。
- 对于 \(\sum\limits_{k=0}^{dep_u}{dp_{v,k}}\) 的常数,可以在调用
merge()
之前对 \(v\) 的线段树调用 query()
得到,记为 \(c\);
- 对于 \(dp_{v,k}\) 和 \(dp_{u,k}\) 的前缀和,因为线段树合并的递归顺序是一种前序遍历的
dfs
序,可以在 dfs
的时候动态维护两个变量 \(sum_1\) 和 \(sum_2\),分别记录两个前缀和。
| /*
* p1 线段树对应 dp[u] 数组
* p2 线段树对应 dp[v] 数组
*/
int merge(int p1, int p2, int l, int r) {
if(p1 == 0 && p2 == 0) return 0;
/*
* p2 == 0 说明 dp[v][j] 为 0,则转移方程的第三项为 0。
* 且 sum1 没有变化。
* 我们直接在原先 p1(dp[u][j])的基础上乘以 (sum1 + c) 即可(调用 move_tag 就是这个效果)。
*/
if(p2 == 0) {
sum2 = sum2 + sum[p1];
move_tag(p1, sum1 + c);
return p1;
}
/*
* p1 == 0 说明 dp[u][j] 为 0,则方程第 1、2 项为 0。
* 且 sum2 没有变化。
* 可以直接将 sum2 乘在 p2 树上面,并将其 copy 到 p1 上返回。
*/
if(p1 == 0) {
sum1 = sum1 + sum[p2];
move_tag(p2, sum2);
return p2;
}
/*
* 单节点处,直接根据公式处理即可。
* 注意,sum2 的求和上界是 j-1,并且求和的元素是 原先 的 dp[u][1](区分更新后的 dp[u][1])
* 因此要先开一个临时变量记录 sum2 更新后的结果,再去合并 p1 和 p2。
*/
if(l == r) {
int tmp;
sum1 = sum1 + sum[p2];
tmp = sum2 + sum[p1];
sum[p1] = sum[p1] * (sum1 + c) + sum[p2] * sum2;
sum2 = tmp;
return p1;
}
/*
* 线段树合并板子
*/
push_down(p1);
push_down(p2);
int mid = (l + r) >> 1;
lc[p1] = merge(lc[p1], lc[p2], l, mid);
rc[p1] = merge(rc[p1], rc[p2], mid + 1, r);
push_up(p1);
return p1;
}
|
调用处:
| void dfs(int u, int f) {
for(int i = head[u]; i; i = pool[i].next) {
int v = pool[i].v;
if(v == f) continue;
dfs(v, u);
sum1 = sum2 = 0;
c = query(rt[v], 0, n, 0, dep[u]);
rt[u] = merge(rt[u], rt[v], 0, n);
}
}
|
完整代码
| #include<iostream>
#define int long long
using namespace std;
const int N = 5E5 + 10;
const int MOD = 998244353;
struct Edge {
int v;
int next;
} pool[2 * N];
int ne, head[N];
void addEdge(int u, int v) {
pool[++ne] = {v, head[u]};
head[u] = ne;
}
int n, m;
int dep[N], w[N];
void dfs1(int u, int f) {
dep[u] = dep[f] + 1;
for(int i = head[u]; i; i = pool[i].next) {
int v = pool[i].v;
if(v == f) continue;
dfs1(v, u);
}
}
int nn;
int rt[N];
int sum[32 * N], lc[32 * N], rc[32 * N], mul[32 * N];
void push_up(int p) {
sum[p] = ((lc[p] ? sum[lc[p]] : 0) + (rc[p] ? sum[rc[p]] : 0)) % MOD;
}
void move_tag(int p, int tg) {
if(p == 0) return;
sum[p] = sum[p] * tg % MOD;
mul[p] = mul[p] * tg % MOD;
}
void push_down(int p) {
move_tag(lc[p], mul[p]);
move_tag(rc[p], mul[p]);
mul[p] = 1;
}
void insert(int &p, int l, int r, int q, int v) {
if(p == 0) {
p = ++nn;
mul[p] = 1;
}
if(l == r) {
sum[p] = (sum[p] + v) % MOD;
return;
}
push_down(p);
int mid = (l + r) >> 1;
if(q <= mid) insert(lc[p], l, mid, q, v);
else insert(rc[p], mid + 1, r, q, v);
push_up(p);
}
int query(int p, int l, int r, int ql, int qr) {
if(p == 0) return 0;
if(ql <= l && r <= qr) {
return sum[p];
}
push_down(p);
int mid = (l + r) >> 1, res = 0;
if(mid >= ql) res = query(lc[p], l, mid, ql, qr);
if(mid < qr) res = (res + query(rc[p], mid + 1, r, ql, qr)) % MOD;
return res;
}
int c, sum1, sum2;
int merge(int p1, int p2, int l, int r) {
if(p1 == 0 && p2 == 0) return 0;
if(p2 == 0) {
sum2 = (sum2 + sum[p1]) % MOD;
move_tag(p1, (sum1 + c) % MOD);
return p1;
}
if(p1 == 0) {
sum1 = (sum1 + sum[p2]) % MOD;
move_tag(p2, sum2);
return p2;
}
if(l == r) {
int tmp;
sum1 = (sum1 + sum[p2]) % MOD;
tmp = (sum2 + sum[p1]) % MOD;
sum[p1] = (sum[p1] * ((sum1 + c) % MOD) % MOD + sum[p2] * sum2 % MOD) % MOD;
sum2 = tmp;
return p1;
}
push_down(p1);
push_down(p2);
int mid = (l + r) >> 1;
lc[p1] = merge(lc[p1], lc[p2], l, mid);
rc[p1] = merge(rc[p1], rc[p2], mid + 1, r);
push_up(p1);
return p1;
}
void dfs(int u, int f) {
for(int i = head[u]; i; i = pool[i].next) {
int v = pool[i].v;
if(v == f) continue;
dfs(v, u);
sum1 = sum2 = 0;
c = query(rt[v], 0, n, 0, dep[u]);
rt[u] = merge(rt[u], rt[v], 0, n);
}
}
signed main() {
cin >> n;
for(int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
addEdge(u, v);
addEdge(v, u);
}
dfs1(1, 0);
cin >> m;
for(int i = 1; i <= m; i++) {
int u, v;
cin >> u >> v;
w[v] = max(w[v], dep[u]);
}
for(int i = 1; i <= n; i++) {
insert(rt[i], 0, n, w[i], 1);
}
dfs(1, 0);
cout << query(rt[1], 0, n, 0, 0) << endl;
return 0;
}
|
做法高度相似于上一题,也需要前缀和处理,以及线段树乘法 tag
。