跳转至

换根DP

有些问题需要对树上的每个节点求出答案。如果对每个节点都应用一遍 \(dfs\),时间复杂度将会达到 \(O(n^2)\)。这是因为每次 dfs 都会和上次得出的答案有一些重复的计算。换根 DP 就能充分运用这些重复的信息,将时间复杂度降低到 \(O(n)\)

换根 DP 的具体流程如下:

  • 随便选择一个节点作为初始根节点 \(root\),跑一遍 dfs;
  • 再从 \(root\) 跑一遍换根转移的 dfs,该转移能通过父亲的答案以较低时间复杂度转移到儿子。

例题

P3647 [APIO2014] 连珠线

题目大意

有一种游戏:

初始图上有一个节点,你需要加入 \(n-1\) 个点,加点有两种方式:

  • 从一个已有的节点 \(u\) 向新点 \(w\) 连一条红边;
  • 插入在两个由红边相连的节点 \(u,v\) 之间,即删去 \(u,v\) 之间红线,分别用蓝线连接 \(u,w\)\(w,v\)

给定一个该游戏的终止局面,但不给定边的颜色,问合法的染色方案中蓝边边权和最大是多少。

不保证 \(1\) 号节点是游戏的初始节点。

\(n\le 2\times 10^5\)

蓝边显然可以两两配对,形成一些长度为 \(2\) 条边的链。

我们考虑 DP,但有一些情况无法考虑,比如:

alt text

就是不合法的,但是 DP 不易去除这种情况。这是因为那些不是返祖链的蓝链要求父节点(链的 LCA)加入的时间最晚,导致状态数很多,转移复杂。

在思考的过程中,容易发现游戏的初始节点比较特殊,我们考虑以这个节点为根节点时蓝链的情况。容易发现此时每条蓝链都一定是返祖链,证明比较显然。同时,如果所有蓝链都是返祖链且不交,则一定是一种合法的局面。这极大的简化了 DP,因为我们可以只考虑蓝链是返祖链的情况。

假设我们已知起始节点(所有蓝链都是返祖链),考虑 DP,设 \(f_{u,0/1}\) 表示节点 \(u\) 是否是返祖链的中点,子树内产生的最大得分。记

\[ \begin{align*} v&\in to[u]\\ g_{v,0}&=\max\big(f_{v,0},f_{v,1}+w(u,v)\big)\\ g_{v,1}&=f_{v,0}+w(u,v) \end{align*} \]

容易写出转移方程:

\[ \begin{cases} f_{u,0}=\sum\limits_{v\in to[u]}{g_{v,0}}\\ f_{u,1}=f_{u,0}+\max\{g_{v,1}-g_{v,0}\} \end{cases} \]

到这里已经可以 \(O(n^2)\) 做了。但是过不了 \(2\times 10^5\)。这是因为我们不知道初始节点是哪个节点,只能分别尝试每一个节点作为根节点时的答案。这正是换根 DP 能够优化的一类问题。

考虑换根,子节点求和很好处理,而最大值考虑维护值最大、次大的两棵子树,分讨即可。

代码
#include<iostream>
using namespace std;
const int N = 2E5 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Edge {
    int v, w, next;
} pool[2 * N];
int nn, head[N];

void addEdge(int u, int v, int w) {
    pool[++nn] = {v, w, head[u]};
    head[u] = nn;
}

int n, ans;
int f[N][2], g[N][2], mx[N][2];

void dfs1(int u, int fa) {
    g[0][1] = -INF;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v, w = pool[i].w;
        if(v == fa) continue;
        dfs1(v, u);
        g[v][0] = max(f[v][0], f[v][1] + w);
        g[v][1] = f[v][0] + w - g[v][0];
        f[u][0] += g[v][0];
        // 维护最大和次大值
        if(g[v][1] >= g[mx[u][0]][1]) {
            mx[u][1] = mx[u][0];
            mx[u][0] = v;
        } else if(g[v][1] > g[mx[u][1]][1]) {
            mx[u][1] = v;
        }
    }
    f[u][1] = f[u][0] + g[mx[u][0]][1];
}

// 代码中 g[u][1] 实际表示 g[u][1] - g[u][0]
void dfs2(int u, int fa) {
    ans = max(ans, f[u][0]);
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v, w = pool[i].w;
        if(v == fa) continue;
        int f2[2][2], g2[2][2];
        f2[0][0] = f[u][0], f2[0][1] = f[u][1];
        f2[1][0] = f[v][0], f2[1][1] = f[v][1];
        g2[0][0] = g[u][0], g2[0][1] = g[u][1];
        g2[1][0] = g[v][0], g2[1][1] = g[v][1];
        // 更新 dp 数组
        f[u][0] -= g[v][0];
        f[u][1] -= g[v][0];
        if(v == mx[u][0]) f[u][1] += g[mx[u][1]][1] - g[v][1];
        g[u][0] = max(f[u][0], f[u][1] + w);
        g[u][1] = f[u][0] + w - g[u][0];
        if(g[u][1] >= g[mx[v][0]][1]) {
            mx[v][1] = mx[v][0];
            mx[v][0] = u;
        } else if(g[u][1] > g[mx[v][1]][1]) {
            mx[v][1] = u;
        }
        f[v][0] += g[u][0];
        f[v][1] = f[v][0] + g[mx[v][0]][1];
        dfs2(v, u);
        // 记得回溯
        f[u][0] = f2[0][0], f[u][1] = f2[0][1];
        f[v][0] = f2[1][0], f[v][1] = f2[1][1];
        g[u][0] = g2[0][0], g[u][1] = g2[0][1];
        g[v][0] = g2[1][0], g[v][1] = g2[1][1];
    }
}

int main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n;
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        addEdge(u, v, w);
        addEdge(v, u, w);
    }

    dfs1(1, 0);
    dfs2(1, 0);

    cout << ans << endl;

    return 0;
}

P3047 [USACO12FEB] Nearby Cows G

题意

给定一棵树,点有点权。对每个点求出距离它不超过 \(k\) 的点的点权之和。

\(f_{u,i}\) 表示 \(u\) 子树内距离 \(u\) 节点不超过 \(i\) 的节点点权和,根节点的答案容易求出。考虑换根,由于加法满足差分性,去除子树的贡献是容易的,无需前缀和/后缀和解决。

代码
#include<iostream>
using namespace std;
const int N = 1E5 + 10;
const int K = 22;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int n, k;
int w[N], ans[N];

int f[N][K];

void dfs1(int u, int fa) {
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == fa) continue;
        dfs1(v, u);
        for(int i = 1; i <= k; i++) {
            f[u][i] += f[v][i - 1];
        }
    }
    for(int i = 0; i <= k; i++) {
        f[u][i] += w[u];
    }
}

void dfs2(int u, int fa) {
    ans[u] = f[u][k];
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == fa) continue;
        for(int i = k; i >= 1; i--) {
            f[v][i] += f[u][i - 1] - (i >= 2 ? f[v][i - 2] : 0);
        }
        dfs2(v, u);
    }
}

int main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> k;
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }
    for(int i = 1; i <= n; i++) cin >> w[i];

    dfs1(1, 0);
    dfs2(1, 0);

    for(int i = 1; i <= n; i++) {
        cout << ans[i] << '\n';
    }

    return 0;
}

P3478 [POI2008] STA-Station

题意

给定一个 n 个点的树,请求出一个结点,使得以这个结点为根时,所有结点的深度之和最大。

一个结点的深度之定义为该节点到根的简单路径上边的数量。

假如已知根节点,写出求深度之和的 dp 转移:

\[ f_{u}=\sum_{v\in to[u]}{f_{v}+sz[v]} \]

简单的换根可以不依赖上面的转移方程。直接写出换根方程:

\[ f_{v}=f_{u}-sz[v]+(n-sz[v]) \]
代码
#include<iostream>
#define int long long
using namespace std;
const int N = 1E6 + 10;

struct Edge {
    int v, next;
} pool[2 * N];
int nn, head[N];

void addEdge(int u, int v) {
    pool[++nn] = {v, head[u]};
    head[u] = nn;
}

int n, ans;
int sz[N], f[N];

void dfs1(int u, int fa) {
    sz[u] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == fa) continue;
        dfs1(v, u);
        sz[u] += sz[v];
        f[u] += f[v] + sz[v];
    }
}

void dfs2(int u, int fa) {
    if(f[u] > f[ans]) ans = u;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == fa) continue;
        f[v] = f[u] - sz[v] + (n - sz[v]);
        dfs2(v, u);
    }
}

signed main() {

    cin >> n;
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }

    dfs1(1, 0);
    dfs2(1, 0);

    cout << ans << endl;

    return 0;
}