跳转至

树上背包

有些时候,我们需要把背包容量(\(m\))分配给树上的节点。并且由于转移的特殊性,我们必须在树上完成 dp(不能摊到序列上)。

树上背包的核心思想是:依次考虑节点的每一棵子树,记录这个前缀的 dp 答案;然后尝试加入一棵新的子树,将这棵子树和原先的前缀合并。

普通背包

直接在树上进行 dp。某些上下界优化可以显著降低时间复杂度。

P1272 重建道路

给定一棵有根树,切掉树上的一些边,分离一些子树,使剩下的主干上节点数量为 \(P\),求最少需要切断多少条边。

考虑树上 DP。设 \(f_{u,i}\) 表示在以 \(u\) 为根的子树的主干上保留 \(i\) 个节点,所需要的最小代价。

我们从左到右考虑每一棵子树。对于 \(v\) 子树,我们需要把它和前面的所有子树合并起来。初始时 \(f_{u}\) 表示不考虑 \(u\) 的所有子树(即把它们都砍断)的答案。由此得到 \(f\) 的初值应为

\[ f_{u,i} = \begin{cases} adj[u].size(),&i=1\\ +\infty,&i\ne 1 \end{cases} \]

考虑加入一棵子树 \(v\)

\[ f'_{u,i} = \min_{k\le i}(f_{u,i}, f_{u,i-k} + f_{v,k} - 1) \]

注意,每次转移时都会考虑到一棵新的子树,此时就需要把预先砍掉的边加回来,所以需要 \(-1\)

模板
#include<iostream>
#include<cstring>
#include<vector>
using namespace std;
const int N = 155;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

int n, p, ans = INF;
int d[N], f[N][N], size[N];
vector<int> adj[N];

void dfs(int u) {
    f[u][1] = adj[u].size();
    for(auto v : adj[u]) {
        dfs(v);
        for(int i = p; i >= 2; i--) {
            for(int k = 1; k < i; k++) {
                f[u][i] = min(f[u][i], f[u][i - k] + f[v][k] - 1);
            }
        }
    }
    if(u == 1) ans = min(ans, f[u][p]);
    else ans = min(ans, f[u][p] + 1);
}

int main(){

    memset(f, 0x3f, sizeof(f));

    cin >> n >> p;
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
    }

    dfs(1);

    cout << ans << endl;

    return 0;
}

P3177 [HAOI2015] 树上染色

给定一棵 \(n\) 个节点的树,边有边权。给定一个常数 \(m\in [0,n]\),你需要在这棵树中选择 \(m\) 个点染成黑色,其余点为白色。每个颜色相同的点对 \((u,v)\) 都会产生 \(dis(u,v)\) 的价值,问最大价值和。

\(n\le 2000\)

我们可以将 \(u\rightarrow v\) 的边的贡献交由 \(v\) 统计。我们先将所有子树的信息合并:

\[ f'_{u,j}=\max_{j_1=0}^{j}\{f_{u,j_1}+f_{v,j-j_1}\} \]

然后统计 \(fa[u]\rightarrow u\) 的边产生的贡献:

\[ f'_{u,j}=f_{u,j}+w(fa[u],u) \Big((m-j)j+\big((n-m)-(sz[u]-j)\big)(sz[u]-j)\Big) \]

朴素的转移为 \(O(n^3)\)。我们通过一些上下界优化,可以将时间复杂度优化至 \(O(nm)\)

具体的,对于 \(f_{u,j}\),我们应该保证 \(j\le sz[u]\)。因此在转移过程中,我们需要保证 \(j_1\) 不超过当前 \(u\) 已经考虑的子树大小,并且 \(j-j_1\) 不超过 \(v\) 的子树大小 \(sz[v]\)

分析为 \(O(n^2)\)

合并两棵子树消耗 \(sz[a]sz[b]\) 的时间,可以看作是将两棵子树内的点两两配对。显然,任意两个点只会在 \(lca\) 处被配对一次。时间复杂度不超过 \(O(n^2)\)

同时,\(j,\ j_1\)\(j-j_1\) 都不能超过 \(m\)。对于子树大小超过 \(m\) 的情况,由于 \(j\) 不能枚举到超过 \(m\) 的位置,因此这种情况等价于 \(sz=m\) 的情况。现在我们只需考虑 \(sz\le m\) 的情况。

分析为 \(O(nm)\)

树上背包

为了方便讨论,不妨将所有树都转换为二叉树的情形,二叉树上的每一个节点都代表一次合并。这个过程会将节点数量增大常数倍。

树上背包2

考虑一次合并,消耗的时间复杂度显然是 \(sz[a]sz[b]\)。不妨设 \(sz[a]\ge sz[b]\)

  • 对于 \(sz[a]=m\) 的情况,合并之后的子树大小 \(sz'=m=sz[a]\),等价于我们消耗 \(m\times sz[b]\) 的时间删掉了 \(sz[b]\) 个点;这里均摊 \(O(nm)\)
  • 同时,如果 \(sz[b]<m\),容易发现 \(b\) 内部的时间消耗为 \(O(sz[b]^2)\le O(m\times sz[b])\),可以作为常数直接消去;
  • 对于 \(sz[a]<m\)\(sz[a]+sz[b]\ge m\) 的情况\(a,b\) 子树内部的复杂度不超过 \(O(m^2)\),本次合并的时间复杂度也不超过 \(O(m^2)\);由于这种情况之间没有包含关系,因此最多出现 \(\frac{n}{m}\) 次,总共消耗 \(O(nm)\) 的时间复杂度。
  • 对于 \(sz[a]+sz[b]<m\) 的情况,不断向上跳祖先,找到最靠上的满足 \(sz<m\) 的祖先 \(p\)\(p\) 子树内部的时间消耗为 \(O(sz[p]^2)\),这部分时间消耗在 \(p\) 参与的下一次合并中被作为常数去掉了;

因此,如下形式的树上背包,时间复杂度不超过 \(O(nm)\)

\[ f_{u,j}=\max_{j_1+j_2=j,\ j_1\le sz[u],\ j_2\le sz[v]}\{f_{u,j_1}+f_{v,j_2}\},\quad j\le m \]

其中 \(sz[u]\) 表示当前 \(u\) 已经考虑的子树大小。

推广

当节点消耗的容量不为 \(1\) 时,一棵子树消耗的容量不超过子树内节点的代价之和,此时我们只需要将每个点拆成 \(w[u]\)(它的代价)个点即可。这样分析出来时间复杂度为 \(O\Bigl(m\sum w[u]\Big)\)

代码
#include<iostream>
#define int long long
using namespace std;
const int N = 2010;

struct Edge {
    int v, w, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v, int w) {
    pool[++ne] = {v, w, head[u]};
    head[u] = ne;
}

int n, k;
int sz[N];
int f[N][N];

void dfs(int u, int fa, int fw) {
    sz[u] = 1;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v, w = pool[e].w;
        if(v == fa) continue;
        dfs(v, u, w);
        for(int i = min(sz[u], k); i >= 0; i--) {
            for(int j = min(sz[v], k - i); j >= 0; j--) {
                f[u][i + j] = max(f[u][i + j], f[u][i] + f[v][j]);
            }
        }
        sz[u] += sz[v];
    }
    for(int i = 0; i <= sz[u] && i <= k; i++) f[u][i] += (i * (k - i) + (sz[u] - i) * (n - k - sz[u] + i)) * fw;
}

signed main() {

    cin >> n >> k;
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        addEdge(u, v, w);
        addEdge(v, u, w);
    }

    dfs(1, 0, 0);

    cout << f[1][k] << endl;

    return 0;
}

qoj5357 芒果冰加了空气

给定一棵树,计数它的点分树,不要求分治中心选择在重心。

\(n\le 5000\)

对于一棵点分树,我们切断原树上的任意一条边,都可以得到一种唯一的两个连通块的点分树方案。具体的,设删掉的边为 \((u,v)\),我们先删去点分树上所有跨过 \((u,v)\) 的边,然后在不改变祖先关系的前提下向两个连通块的点分树中补充一些边,使它们分别连通成一棵树。

考虑在两个连通块之间连接一条边 \((u,v)\),两棵点分树如何合并。根据上面分裂的过程,我们不难发现,只需要将 \(u,v\) 在点分树上的返根链归并起来,就可以得到合并后的结果。这个过程可以类比 FHQ 的 splitmerge 的过程。

显然,合并的过程产生的方案数只与 \(u,v\) 在点分树中的深度有关。因此我们记 \(f_{u,j}\) 表示 \(u\) 节点在 \(u\) 子树的点分树中深度为 \(j\),点分树的方案数。考虑当前点分树中 \(u\) 的返根链中哪些节点是 \(v\) 子树提供的,容易写出转移式:

\[ f'_{u,j}=\sum_{j_1=0}^{j-1}{\Big(f_{u,j-j_1}\big(\sum_{k\ge j_1}f_{v,j_1}\big)\binom{j-1}{j_1}\Big)} \]

这个转移可以做到树上背包 \(O(n^2)\) 的时间复杂度。

代码
#include<iostream>
#define ll long long
using namespace std;
const int N = 5010;
const int MOD = 1e9 + 7;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int n;
int sz[N];
int f[N][N], c[N][N];

void dfs(int u, int fa) {
    f[u][1] = 1;
    sz[u] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
        for(int j = sz[u]; j >= 1; j--) {
            int res = 0;
            for(int j1 = min(j, sz[v]); j1 >= max(0, j - (sz[u] - sz[v])); j1--) {
                (res += (ll)f[u][j - j1] * f[v][j1] % MOD * c[j - 1][j1] % MOD) %= MOD;
            }
            f[u][j] = res;
        }
    }
    for(int i = n - 1; i >= 0; i--) (f[u][i] += f[u][i + 1]) %= MOD;
}

int main() {

    cin >> n;
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }

    c[0][0] = 1;
    for(int i = 1; i <= n; i++) {
        c[i][0] = 1;
        for(int j = 1; j <= i; j++) {
            c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % MOD;
        }
    }

    dfs(1, 0);

    cout << f[1][1] << endl;

    return 0;
}

依赖型背包

若选择一个节点则必须选择它的父亲,这种背包被称为依赖型背包。我们可以将它拍到 dfn 序列上获得更优的复杂度。

250430 D10 模拟赛 T1

给定一棵以 \(1\) 为根的有根树,每个节点对应一个物品,有两个属性:价值 \(a_i\) 和代价 \(b_i\)。一个物品依赖于它的父节点(根节点除外)。你需要选择一些物品,总代价不能超过 \(m\),求最大价值和。

\(n,m\le 2000,\ 1\le |a_i|,b_i\le 10^4\)

普通树上背包只能做到 \(O(nmV)\) 的复杂度。我们可以在 dfn 序列上考虑这个问题。设 \(f_{i,j}\) 表示考虑 dfn 序列 \([i,n]\) 的节点,背包容量为 \(j\) 的最大价值。考虑转移,若选择 dfn\(i\) 的节点,那么从 \(f_{i+1,j-b[i]}\) 转移即可;若不选,那么 \(i\) 子树内的节点一个都不能选,直接跳过即可,从 \(f_{i+sz[i],j}\) 转移。

时间复杂度 \(O(nm)\)

代码
#include<iostream>
#include<cstring>
#include<set>
#include<cassert>
using namespace std;
const int N = 2010;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int n, m;

int a1[N], b1[N], a[N], b[N];
int sz[N], dfn[N], dt;

void dfs(int u, int fa) {
    dfn[u] = ++dt;
    sz[dt] = 1;
    a[dt] = a1[u];
    b[dt] = b1[u];
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == fa) continue;
        dfs(v, u);
        sz[dfn[u]] += sz[dfn[v]];
    }
}

int f[N][N];

// #define FIO

int main() {

    #ifdef FIO
    freopen("b.in", "r", stdin);
    freopen("b.out", "w", stdout);
    #endif

    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> a1[i] >> b1[i];
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }

    dfs(1, 0);

    for(int i = n; i >= 1; i--) {
        for(int j = 0; j <= m; j++) {
            f[i][j] = f[i + sz[i]][j];
            if(j >= b[i]) f[i][j] = max(f[i][j], f[i + 1][j - b[i]] + a[i]);
        }
    }

    cout << f[1][m] << '\n';

    return 0;
}

P6326 Shopping

给定一棵树,第 \(i\) 个节点有一种价格 \(c_i\),价值 \(w_i\),库存 \(d_i\) 的物品,你有 \(m\) 元钱。有一个额外的规定:你购买了至少一件物品的商店必须组成一个连通块。问最多获得多少价值。

\(n\le 500,\ m\le 4000,\ w_i\le 4000,\ c_i\le m,\ d_i\le 2000\)

多测,\(T\le 5\)

普通树上背包容易解决连通块的限制,但是时间复杂度达到 \(O(nm^2)\),不能接受。

对于树上依赖性背包,我们需要对每棵子树都独立跑一遍,复杂度达到 \(O(n^2m)\)(单调队列优化可以不带 \(\log\)),也不能接受。

对于连通块考虑点分治。钦定选择分治中心即可,时间复杂度 \(O(nm\log n)\),可以通过。

代码
#include<iostream>
using namespace std;
const int N = 510;
const int M = 4010;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int n, m;
int ans;
int w[N], c[M], d[N];

int vis[N];
int sz[N], mxp[N] = {N};

void get_sz(int u, int fa) {
    sz[u] = 1;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_sz(v, u);
        sz[u] += sz[v];
    }
}

void get_rt(int u, int fa, int tot, int &rt) {
    mxp[u] = 0;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        get_rt(v, u, tot, rt);
        mxp[u] = max(mxp[u], sz[v]);
    }
    mxp[u] = max(mxp[u], tot - sz[u]);
    if(mxp[u] < mxp[rt]) rt = u;
}

int dfn[N], id[N], dt;

void dfs(int u, int fa) {
    dfn[u] = ++dt;
    id[dt] = u;
    sz[u] = 1;
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(v == fa || vis[v]) continue;
        dfs(v, u);
        sz[u] += sz[v];
    }
}

int f[N][M];
int que[N], hd, tl;

void calc(int u) {
    dt = 0;
    dfs(u, 0);
    for(int j = 0; j <= m; j++) f[dt + 1][j] = 0;
    for(int ii = dt; ii >= 1; ii--) {
        int i = id[ii];
        for(int j = 0; j <= m; j++) f[ii][j] = f[ii + sz[i]][j];
        for(int r = 0; r < c[i]; r++) {
            hd = 1, tl = 0;
            que[++tl] = r;
            for(int j = r + c[i]; j <= m; j += c[i]) {
                while(hd <= tl && que[hd] < j - d[i] * c[i]) ++hd;
                f[ii][j] = max(f[ii][j], f[ii + 1][que[hd]] + (j - que[hd]) / c[i] * w[i]);
                while(hd <= tl && f[ii + 1][que[tl]] - que[tl] / c[i] * w[i] <= f[ii + 1][j] - j / c[i] * w[i]) --tl;
                que[++tl] = j;
            }
        }
    }
    ans = max(ans, f[1][m]);
}

void solve(int u) {
    int rt = 0;
    get_sz(u, 0);
    get_rt(u, 0, sz[u], rt);
    u = rt;
    vis[u] = 1;
    calc(u);
    for(int e = head[u]; e; e = pool[e].next) {
        int v = pool[e].v;
        if(vis[v]) continue;
        solve(v);
    }
}

void clear() {
    ne = 0;
    ans = 0;
    for(int i = 1; i <= n; i++) head[i] = 0;
    for(int i = 1; i <= n; i++) vis[i] = 0;
}

int main() {

    ios::sync_with_stdio(0);
    cin.tie(0);

    int T;
    cin >> T;
    while(T--) {

        cin >> n >> m;
        clear();
        for(int i = 1; i <= n; i++) cin >> w[i];
        for(int i = 1; i <= n; i++) cin >> c[i];
        for(int i = 1; i <= n; i++) cin >> d[i];

        for(int i = 1; i <= n - 1; i++) {
            int u, v;
            cin >> u >> v;
            addEdge(u, v);
            addEdge(v, u);
        }

        solve(1);

        cout << ans << '\n';

    }

    return 0;
}