虚树
有时,题目会在初始时给定一棵树,然后多次询问,每次询问给定 \(k\) 个关键点 \(p_1,p_2\cdots p_k\),要求回答相关信息,同时保证 \(\sum_{k}\)。我们显然不能 \(O(n)\) 的回答每次询问。因此,我们可以使用虚树来将单次询问的时间降低到 \(O(k)\)。
虚树通过只保留树上的关键点以及维持树形态的必要节点(总共 \(O(k)\) 个)来将实际的计算量减小到只和 \(k\) 有关。
建立
我们先给出结论:当且仅当一个节点存在三棵不同的包含关键点的子树,它应当被保留在虚树中。这样,虚树中的一条边就代表了原树中的一条链和挂在链上的所有子树;同时要注意虚树中的一个点还包含了被省略掉的子树。
为了更方便建立虚树,我们放宽上面的条件,加入所有 \(\operatorname{lca}(p_i,p_j)\)。不难发现放宽条件后最多比放宽前多一个节点。
到这里,我们已经能写出一种方法建出一种虚树:
- 将所有关键点按
dfn
序排序;
- 单调栈清空,初始加入节点 \(1\)(如果没有);
- 依次将关键点加入单调栈,设当前节点为 \(u\),如果栈顶节点是 \(u\) 的祖先,那么直接将 \(u\) 入栈;
- 否则不断弹出单调栈中的元素,直到满足栈顶第二个元素是 \(u\) 的祖先;
- 将栈顶和 \(u\) 的 \(lca\) 记录下来,连边 $lca\to $ 栈顶,弹出栈顶,压入 \(lca\);
- 压入 \(u\);
- 结束后,弹出栈内的剩余元素;
- 每次正常弹栈之前从栈顶第二个元素连一条边到栈顶;
然后在虚树上求解答案即可。
性质
- 虚树的节点数量和关键点数量同阶;
- 虚树上所有节点的相对祖先关系和原树相同;
- 虚树上任意两个节点在原树上的 \(\operatorname{lca}\) 还在虚树上;
题意
给定一棵 \(n\) 个点的树,有若干次询问,每次询问给定 \(k\) 个关键点 \(p_1,p_2,\cdots p_k\),表示询问 \((p_i,p_j)\) 二元组(\(i\ne j\))中 \(p_i,p_j\) 距离的总和,距离的最小值和最大值。
\(n\le 10^6,\ \sum k\le 2n\)
先建出虚树,然后在虚树上跑一遍 dfs
,记录子树内离当前点最近、最远的关键点,和关键点的总数,然后在 lca
处统计贡献即可。
参考代码
本代码采用模拟调用栈的方法,不建议使用这种写法。
| #include<bits/stl_algobase.h>
#include<ctype.h>
#include<cstdio>
#include<cassert>
#include<algorithm>
#define ll long long
using namespace std;
const int N = 1e6 + 10;
const int LOGN = 21;
const ll INF = 0x3f3f3f3f3f3f3f3f;
struct istream {
char ch;
template<typename _Tp>
inline istream &operator>>(_Tp &x) {
while(!isdigit(ch = getchar_unlocked()));
x = ch - '0';
while(isdigit(ch = getchar_unlocked())) x = x * 10 + ch - '0';
return *this;
}
} cin;
struct ostream {
char buf[60], top;
inline ostream() { top = 0; }
inline ostream &operator<<(char c) {
putchar_unlocked(c);
return *this;
}
template<typename _Tp>
inline ostream &operator<<(_Tp x) {
do buf[++top] = x % 10, x /= 10; while(x);
while(top) putchar_unlocked(buf[top--] + '0');
return *this;
}
} cout;
struct Edge {
int v, next;
} pool[2 * N];
int ne, head[N];
void addEdge(int u, int v) {
pool[++ne] = {v, head[u]};
head[u] = ne;
}
int n, q, k;
int dfn[N], id[N], dt;
int dep[N], mn[LOGN][N], lg[N];
void dfs(int u, int fa) {
dfn[u] = ++dt;
id[dt] = u;
dep[u] = dep[fa] + 1;
mn[0][dfn[u]] = fa;
for(int e = head[u]; e; e = pool[e].next) {
int v = pool[e].v;
if(v == fa) continue;
dfs(v, u);
}
}
int getLCA(int x, int y) {
if(x == y) return x;
x = dfn[x], y = dfn[y];
if(x > y) swap(x, y);
++x;
int d = lg[y - x + 1];
int t1 = mn[d][x], t2 = mn[d][y - (1 << d) + 1];
return dep[t1] < dep[t2] ? t1 : t2;
}
int getDis(int x, int y) {
return dep[x] + dep[y] - 2 * dep[getLCA(x, y)];
}
int p[2 * N], cnt;
int sta[N], top;
bool cmp_dfn(int a, int b) {
return dfn[a] < dfn[b];
}
int imp[N];
ll ans1, ans2, ans3;
ll sz[N], mxd[N], mnd[N];
void pop_stack() {
int u = sta[top - 1], v = sta[top];
if(u) {
int w = getDis(u, v);
ans1 += (ll)w * sz[v] * (k - sz[v]);
sz[u] += sz[v];
ans2 = min(ans2, mnd[u] + w + mnd[v]);
ans3 = max(ans3, mxd[u] + w + mxd[v]);
mnd[u] = min(mnd[u], w + mnd[v]);
mxd[u] = max(mxd[u], w + mxd[v]);
}
--top;
}
void push_stack(int u) {
sz[u] = imp[u];
mxd[u] = imp[u] ? 0 : -INF;
mnd[u] = imp[u] ? 0 : INF;
sta[++top] = u;
}
int main() {
for(int i = 2; i < N; i++) lg[i] = lg[i >> 1] + 1;
cin >> n;
for(int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
addEdge(u, v);
addEdge(v, u);
}
dfs(1, 0);
for(int k = 1; k < LOGN; k++) {
for(int i = 1; i + (1 << k) - 1 <= n; i++) {
int t1 = mn[k - 1][i], t2 = mn[k - 1][i + (1 << (k - 1))];
mn[k][i] = dep[t1] < dep[t2] ? t1 : t2;
}
}
cin >> q;
while(q--) {
ans1 = 0, ans2 = INF, ans3 = 0, top = 0;
cin >> k;
for(int i = 1; i <= k; i++) cin >> p[i];
for(int i = 1; i <= k; i++) imp[p[i]] = 1;
sort(p + 1, p + 1 + k, cmp_dfn);
cnt = k;
for(int i = 1; i < k; i++) p[++cnt] = getLCA(p[i], p[i + 1]);
sort(p + 1, p + 1 + cnt, cmp_dfn);
cnt = unique(p + 1, p + 1 + cnt) - (p + 1);
push_stack(p[1]);
for(int i = 2; i <= cnt; i++) {
int lca = getLCA(sta[top], p[i]);
while(lca != sta[top]) {
pop_stack();
}
push_stack(p[i]);
}
while(top) pop_stack();
for(int i = 1; i <= cnt; i++) imp[p[i]] = 0;
cout << ans1 << ' ' << ans2 << ' ' << ans3 << '\n';
}
return 0;
}
|
题意
给定一棵 \(n\) 个点的树,有若干次询问,每次询问给定 \(k\) 个关键点 \(p_1,p_2,\cdots p_k\);对于树上的每个点,它都会被离它最近的关键点控制,若距离相同则被标号小者控制;问每个关键点控制多少个节点。
\(n\le 3\times 10^5,\ \sum k\le 3\times 10^5\)
先建出虚树 \(T'\)(边有边权),然后求出虚树上每个节点会被哪个关键点控制(记为 \(s_i\)),以及控制的距离是多少。
考察每一条边 \((u,v)\)(不妨设 \(u=fa[v]\)),\(u,v\) 被同一关键点控制的情况是平凡的;若不同,我们用倍增求出 \((u,v)\) 在原树上对应的链的分界点 \(t\)(即 \(t\) 是链上最深的被 \(s_v\) 控制的节点),记节点 \(s\) 是在链上 \(u\) 下面的一个节点,那么 \(s_u\) 会控制这条链上 \(sz[s]-sz[t]\) 个节点,\(s_v\) 会控制链上 \(sz[t]-sz[v]\) 个节点,累加到 \(s_u,s_v\) 的贡献上即可。
考察每一个节点 \(u\),除去链上的贡献,它还会对 \(s_u\) 产生 \(sz[u]-\sum_{v\in son[u]\wedge v\in T'}{sz[v]}\) 的贡献,将它们也累加到 ans
数组即可。
代码
| #include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
const int N = 3e5 + 10;
const int LOGN = 20;
const int INF = 0x3f3f3f3f;
struct Edge {
int v, next;
} pool[2 * N];
int ne, head[N];
void addEdge(int u, int v) {
pool[++ne] = {v, head[u]};
head[u] = ne;
}
int n, T;
int k;
int p[N];
int anc[LOGN][N];
int mnl[LOGN][N], lg[N];
int dep[N], sz[N];
int dfn[N], dt;
void dfs1(int u, int fa) {
dfn[u] = ++dt;
dep[u] = dep[fa] + 1;
anc[0][u] = fa;
sz[u] = 1;
mnl[0][dfn[u]] = fa;
for(int i = 1; i < LOGN; i++) anc[i][u] = anc[i - 1][ anc[i - 1][u] ];
for(int e = head[u]; e; e = pool[e].next) {
int v = pool[e].v;
if(v == fa) continue;
dfs1(v, u);
sz[u] += sz[v];
}
}
void init_st() {
for(int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
for(int k = 1; k < LOGN; k++) {
for(int i = 1; i + (1 << k) - 1 <= n; i++) {
int t1 = mnl[k - 1][i], t2 = mnl[k - 1][i + (1 << (k - 1))];
mnl[k][i] = dep[t1] < dep[t2] ? t1 : t2;
}
}
}
int getLCA(int x, int y) {
if(x == y) return x;
x = dfn[x], y = dfn[y];
if(x > y) { swap(x, y); } ++x;
int d = lg[y - x + 1];
return dep[mnl[d][x]] < dep[mnl[d][y - (1 << d) + 1]] ? mnl[d][x] : mnl[d][y - (1 << d) + 1];
}
int get_son(int x, int y) {
int x1 = y;
for(int i = LOGN - 1; i >= 0; i--) {
if(dfn[anc[i][x1]] > dfn[x]) x1 = anc[i][x1];
}
return x1;
}
inline bool cmp_dfn(int a, int b) { return dfn[a] < dfn[b]; }
int iskey[N];
int sta[N], top;
int ans[N];
vector<int> qr; // 询问的点
vector<int> kp; // 所有虚树点
namespace VT {
struct Edge {
int v, son, next;
} pool[2 * N];
int ne, head[N];
void addEdge(int u, int v) {
pool[++ne] = {v, get_son(u, v), head[u]};
head[u] = ne;
}
struct myPair {
int p, d;
inline myPair operator+(int w) const {
return {p, d + w};
}
inline bool operator<(const myPair &b) const {
if(d != b.d) return d < b.d;
return p < b.p;
}
} mn[N];
int val[N];
void dfs1(int u) {
val[u] = sz[u];
if(iskey[u]) {
mn[u] = {u, 0};
} else mn[u] = {0, INF};
for(int e = head[u]; e; e = pool[e].next) {
int v = pool[e].v, w = dep[v] - dep[u];
val[u] -= sz[pool[e].son];
dfs1(v);
if(mn[v] + w < mn[u]) mn[u] = mn[v] + w;
}
}
void calc(int x, int y, int s) {
int p = y;
for(int i = LOGN - 1; i >= 0; i--) {
int p1 = anc[i][p];
if(dfn[p1] <= dfn[x]) continue;
int w1 = dep[p1] - dep[x];
int w2 = dep[y] - dep[p1];
if(mn[y] + w2 < mn[x] + w1) p = p1;
}
ans[mn[y].p] += sz[p] - sz[y];
ans[mn[x].p] += sz[s] - sz[p];
}
void dfs2(int u) {
ans[mn[u].p] += val[u];
for(int e = head[u]; e; e = pool[e].next) {
int v = pool[e].v, w = dep[v] - dep[u];
if(mn[u] + w < mn[v]) mn[v] = mn[u] + w;
calc(u, v, pool[e].son);
dfs2(v);
}
}
void work() {
dfs1(1);
dfs2(1);
}
void clear(int u) {
for(int e = head[u]; e; e = pool[e].next) {
int v = pool[e].v;
clear(v);
}
head[u] = 0;
}
// 清边
void clear() {
clear(1);
ne = 0;
}
}
int main() {
cin >> n;
for(int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
addEdge(u, v);
addEdge(v, u);
}
qr.reserve(n);
kp.reserve(n);
dfs1(1, 0);
init_st();
cin >> T;
while(T--) {
cin >> k;
qr.clear(); kp.clear(); top = 0;
for(int i = 1; i <= k; i++) {
cin >> p[i];
ans[p[i]] = 0;
qr.push_back(p[i]);
kp.push_back(p[i]);
}
sort(p + 1, p + 1 + k, cmp_dfn);
sta[++top] = 1;
for(int i = 1 + (p[1] == 1); i <= k; i++) {
int lca = getLCA(p[i], sta[top]);
if(lca == sta[top]) {
sta[++top] = p[i];
continue;
}
while(dfn[sta[top - 1]] > dfn[lca]) {
VT::addEdge(sta[top - 1], sta[top]);
--top;
}
if(sta[top - 1] == lca) {
VT::addEdge(sta[top - 1], sta[top]);
--top;
} else {
VT::addEdge(lca, sta[top]);
--top;
sta[++top] = lca;
kp.push_back(lca);
}
sta[++top] = p[i];
}
while(top >= 2) {
VT::addEdge(sta[top - 1], sta[top]);
--top;
}
for(int i = 1; i <= k; i++) iskey[p[i]] = 1;
VT::work();
for(int i : qr) cout << ans[i] << ' ';
cout << '\n';
for(int i = 1; i <= k; i++) iskey[p[i]] = 0;
VT::clear();
}
return 0;
}
|