跳转至

树上背包

对于树上的背包问题,需要把背包容量以一定方案分配给子树,暴力阶乘的时间复杂度是不可取的。树上分组背包可以将求出单个节点最优解的时间复杂度降低到 \(O(n^2)\)。对于某些上下界优化,可以把整棵树的均摊时间复杂度优化至 \(O(n^2)\)

例题

P1272 重建道路

题目大意

给定一棵有根树,切掉树上的一些边,分离一些子树,使剩下的主干上节点数量为 \(P\),求最少需要切断多少条边。

考虑树上 DP。设 \(f_{u,i}\) 表示在以 \(u\) 为根的子树的主干上保留 \(i\) 个节点,所需要的最小代价。

我们从左到右考虑每一棵子树。对于 \(v\) 子树,我们需要把它和前面的所有子树合并起来。初始时 \(f_{u}\) 表示不考虑 \(u\) 的所有子树(即把它们都砍断)的答案。由此得到 \(f\) 的初值应为

\[ f_{u,i} = \begin{cases} adj[u].size(),&i=1\\ +\infty,&i\ne 1 \end{cases} \]

考虑加入一棵子树 \(v\)

\[ f'_{u,i} = \min_{k\le i}(f_{u,i}, f_{u,i-k} + f_{v,k} - 1) \]

注意,每次转移时都会考虑到一棵新的子树,此时就需要把预先砍掉的边加回来,所以需要 \(-1\)

模板代码
#include<iostream>
#include<cstring>
#include<vector>
using namespace std;
const int N = 155;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

int n, p, ans = INF;
int d[N], f[N][N], size[N];
vector<int> adj[N];

void dfs(int u) {
    f[u][1] = adj[u].size();
    for(auto v : adj[u]) {
        dfs(v);
        for(int i = p; i >= 2; i--) {
            for(int k = 1; k < i; k++) {
                f[u][i] = min(f[u][i], f[u][i - k] + f[v][k] - 1);
            }
        }
    }
    if(u == 1) ans = min(ans, f[u][p]);
    else ans = min(ans, f[u][p] + 1);
}

int main(){

    memset(f, 0x3f, sizeof(f));

    cin >> n >> p;
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
    }

    dfs(1);

    cout << ans << endl;

    return 0;
}

P3177 [HAOI2015] 树上染色

题目大意

有一棵点数为 \(n\) 的树,树边有边权。给你一个在 \(0 \sim n\) 之内的正整数 \(k_1\) ,你要在这棵树中选择 \(k_1\) 个点,将其染成黑色,并将其他的 \(k_2=n-k_1\) 个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间的距离的和的收益。问收益最大值是多少。

\(n\le 2000\)

我们可以将 \(u\rightarrow v\) 的边的贡献交由 \(v\) 统计。我们写出转移方程:

\[ f'_{u,i}=\max_{k=0}^{k_1}\{f_{u,k}+f_{v,i-k_1}\} \]

处理完所有子树后,统计 \(fa[u]\rightarrow u\) 的边产生的贡献:

\[ f'_{u,i}=f_{u,i}+w(fa[u],u)\times \Big((k_1-i)i+\big(k_2-(sz[u]-i)\big)(sz[u]-i)\Big) \]

此种转移的时间复杂度为 \(O(n^3)\)。通过对背包过程中 \(i,k\) 的上下界剪枝,可以优化至 \(O(n^2)\)

\[ f'_{u,i}=\max_{k\le sz[u],\ i-k\le sz[v]}\{f_{u,k}+f_{v,i-k}\} \]

其中 \(sz[u]\) 为动态统计的子树大小,不包含 \(v\) 及后面还未统计的子树。

时间复杂度分析

注意到枚举的 \(k\le sz[u],\ i-k\le sz[v]\) 可以看作是枚举 \(sz[u]\)\(sz[v]\) 中的每一个点,将它们“两两合并”。因此任选两个点只会在其 lca 处合并,产生 \(O(1)\) 的时间复杂度。点对的数量是 \(O(n^2)\) 的,因此总时间复杂度为 \(O(n^2)\)

代码
#include<iostream>
#include<cstring>
#include<algorithm>
#define int long long
using namespace std;
const int N = 2010;

struct Edge {
    int v, w;
    int next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v, int w) {
    pool[++ne] = {v, w, head[u]};
    head[u] = ne;
}

int n, k1, k2;
int sz[N];
int f[N][N];

void dfs(int u, int fa, int fw) {
    sz[u] = 1;
    f[u][0] = 0;
    f[u][1] = 0;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v, w = pool[e].w;
        if(v == fa) continue;
        dfs(v, u, w);
        for(int i = min(k1, sz[u] + sz[v]); i >= 0; i--) {
            for(int j = max(0ll, i - sz[u]); j <= min(min(i, k1), sz[v]); j++) {
                f[u][i] = max(f[u][i], f[u][i - j] + f[v][j]);
            }
        }
        sz[u] += sz[v];
    }
    for(int i = 0; i <= min(k1, sz[u]); i++) {
        f[u][i] += fw * ((k1 - i) * i + (k2 - (sz[u] - i)) * (sz[u] - i));
    }
}

signed main() {

    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    memset(f, -0x3f, sizeof(f));

    cin >> n >> k1;
    k2 = n - k1;
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        addEdge(u, v, w);
        addEdge(v, u, w);
    }

    dfs(1, 0, 0);

    cout << f[1][k1] << endl;

    return 0;
}