动态 DP
前置知识:线段树维护矩阵乘法,线段树优化 DP,重链剖分。
题意
给定一棵 \(n\) 个点的树,点有点权。有 \(m\) 次操作,每次操作给定 \(x,y\),表示修改点 \(x\) 的权值为 \(y\)。你需要在每次操作之后求出这棵树的最大权独立集的权值大小。
\(n,m\le 10^5\)
如果修改权值之后暴力 DP 求出答案,则时间复杂度为 \(O(nm)\),不能通过本题。
容易发现修改 \(x\) 节点的权值只会改变 \(x\) 到根节点路径上所有节点的 dp 值,其它位置则保持不变。然而返根链的长度之和仍然是 \(O(nm)\) 量级的,考虑优化。
我们知道,重链剖分通过将树剖分为若干重链,使得任意一个节点到根节点都只经过不超过 \(\log_2n\) 条不同的重链。由于一条重链上所有节点的 \(dfn\) 序是连续的,可以使用线段树维护每条重链的信息,实现 \(O(\log^2 n)\) 修改和查询路径信息。
这启示我们可以通过将树剖分为若干重链,以此减少操作返根链的时间复杂度。但是一个节点的 dp 值还受其轻儿子影响,应该如何使用线段树维护重链信息呢?
我们写出朴素的 dp 转移式:
\[
\begin{align*}
f_{u,0}&=\left(\sum_{v\in to[u]}{\max(f_{v,0},f_{v,1})}\right)\\
f_{u,1}&=\left(\sum_{v\in to[u]}{f_{v,0}}\right)+w[u]
\end{align*}
\]
为了方便书写,我们定义 \(0\) 号节点为一个虚点,是所有叶子节点的唯一“儿子”。其 dp 值 \(f_{u,0}=0\),\(f_{u,1}=-\infty\)。
记 \(v_0\) 表示 \(u\) 的重儿子;分离重儿子和轻儿子:
\[
\begin{align*}
f_{u,0}&=\left(\sum_{v\in to[u]/\{v_0\}}{\max(f_{v,0},f_{v,1})}\right)+\max(f_{v_0,0},f_{v_0,1})\\
f_{u,1}&=\left(\sum_{v\in to[u]/\{v_0\}}{f_{v,0}}\right)+w[u]+f_{v_0,0}
\end{align*}
\]
树上动态 DP 的核心思想就是把所有轻儿子的信息浓缩为一个矩阵,将这个矩阵与重儿子的 dp 向量相乘,就能得到当前节点的 dp 向量。根据转移方程,我们定义此处的矩阵乘法为 \(\left<\max,+\right>\) 的广义矩阵乘法。
记
\[
\begin{align*}
g_{u,0}&=\sum_{v\in to[u]/\{v_0\}}{f_{v,0}}\\
g_{u,1}&=\sum_{v\in to[u]/\{v_0\}}{\max(f_{v,0},f_{v,1})}
\end{align*}
\]
则转移式可以写成
\[
\begin{align*}
f_{u,0}&=g_{u,1}+\max(f_{v_0,0},f_{v_0,1})\\
f_{u,1}&=g_{u,0}+w[u]+f_{v_0,0}
\end{align*}
\]
写成矩阵乘法的形式:
\[
\left[
\begin{matrix}
f_{u,0}\\
f_{u,1}
\end{matrix}
\right]=
\left[
\begin{matrix}
g_{u,1}& g_{u,1}\\
g_{u,0}+w[u]& -\infty
\end{matrix}
\right]
\left[
\begin{matrix}
f_{v_0,0}\\
f_{v_0,1}
\end{matrix}
\right]
\]
由于 \(g\) 只和轻儿子有关,因此每个节点的转移矩阵也就和重儿子无关。使用线段树维护每一条重链上转移矩阵的乘积。对于节点 \(u\),记 \(\operatorname{bot}[u]\) 为其链底节点,我们用线段树查询 \([u,\operatorname{bot}[u]]\) 中节点的转移矩阵的乘积 \(op\),将 \(op\) 左乘 \(0\) 号节点的 dp 向量,即可得到 \(u\) 节点的 dp 向量。
修改时,先修改 \(x\) 节点的权值 \(w[x]\) 和其转移矩阵;然后从下往上考虑 \(x\) 到根节点经过的所有轻边 \(fa[x']\rightarrow x'\),使用线段树查询 \(x'\) 的 dp 向量,然后更新 \(fa[x']\) 的转移矩阵即可。
注意矩阵乘法的左右顺序
矩阵乘法没有交换律,使用线段树维护时一定要注意乘法顺序。可以根据
\[
(AB)^{\operatorname{T}}=B^{\operatorname{T}}A^{\operatorname{T}}
\]
转置矩阵并交换顺序。
代码
| #include<iostream>
using namespace std;
const int N = 1E5 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;
struct Edge {
int v, next;
} pool[2 * N];
int ne, head[N];
void addEdge(int u, int v) {
pool[++ne] = {v, head[u]};
head[u] = ne;
}
struct Matrix {
int a[2][2];
inline Matrix() { a[0][0] = a[0][1] = a[1][0] = a[1][1] = 0; }
inline int* operator[](int index) { return a[index]; }
inline const int* operator[](int index) const { return a[index]; }
inline Matrix operator*(const Matrix& b) const {
Matrix res;
res[0][0] = max(a[0][0] + b[0][0], a[0][1] + b[1][0]);
res[0][1] = max(a[0][0] + b[0][1], a[0][1] + b[1][1]);
res[1][0] = max(a[1][0] + b[0][0], a[1][1] + b[1][0]);
res[1][1] = max(a[1][0] + b[0][1], a[1][1] + b[1][1]);
return res;
}
};
Matrix a[N];
namespace SegT {
Matrix tr[4 * N];
inline int lc(int x) { return x << 1; }
inline int rc(int x) { return x << 1 | 1; }
inline void push_up(int p) {
tr[p] = tr[lc(p)] * tr[rc(p)];
}
void build(int p, int l, int r) {
if(l == r) {
tr[p] = a[l];
return;
}
int mid = (l + r) >> 1;
build(lc(p), l, mid);
build(rc(p), mid + 1, r);
push_up(p);
}
void update(int p, int l, int r, int q) {
if(l == r) {
tr[p] = a[l];
return;
}
int mid = (l + r) >> 1;
if(q <= mid) update(lc(p), l, mid, q);
else update(rc(p), mid + 1, r, q);
push_up(p);
}
Matrix query(int p, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) {
return tr[p];
}
int mid = (l + r) >> 1;
if(mid >= qr) return query(lc(p), l, mid, ql, qr);
if(mid < ql) return query(rc(p), mid + 1, r, ql, qr);
return query(lc(p), l, mid, ql, qr) * query(rc(p), mid + 1, r, ql, qr);
}
}
int n, m;
int w[N];
int sz[N], son[N], fa[N];
int dfn[N], id[N], top[N], bot[N], dt;
int f[N][2], g[N][2];
void get_sz(int u, int a0) {
sz[u] = 1;
fa[u] = a0;
for(int i = head[u]; i; i = pool[i].next) {
int v = pool[i].v;
if(v == a0) continue;
get_sz(v, u);
sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}
}
void init(int u, int a0, int tp) {
dfn[u] = ++dt;
id[dt] = u;
top[u] = tp;
if(!son[u]) {
f[u][0] = 0;
f[u][1] = w[u];
Matrix& cur = a[dfn[u]];
cur[1][0] = w[u];
cur[1][1] = -INF;
bot[u] = u;
return;
}
init(son[u], u, tp);
bot[u] = bot[son[u]];
f[u][0] += max(f[son[u]][0], f[son[u]][1]);
f[u][1] += f[son[u]][0] + w[u];
for(int i = head[u]; i; i = pool[i].next) {
int v = pool[i].v;
if(v == a0 || v == son[u]) continue;
init(v, u, v);
g[u][0] += f[v][0];
g[u][1] += max(f[v][0], f[v][1]);
}
f[u][0] += g[u][1];
f[u][1] += g[u][0];
Matrix& cur = a[dfn[u]];
cur[0][0] = cur[0][1] = g[u][1];
cur[1][0] = g[u][0] + w[u];
cur[1][1] = -INF;
}
void modify(int x, int y) {
a[dfn[x]][1][0] += y - w[x];
w[x] = y;
SegT::update(1, 1, n, dfn[x]);
x = top[x];
while(x != 1) {
Matrix cur = SegT::query(1, 1, n, dfn[x], dfn[bot[x]]);
a[dfn[fa[x]]][0][0] -= max(f[x][0], f[x][1]);
a[dfn[fa[x]]][1][0] -= f[x][0];
f[x][0] = cur[0][0];
f[x][1] = cur[1][0];
a[dfn[fa[x]]][0][0] += max(f[x][0], f[x][1]);
a[dfn[fa[x]]][0][1] = a[dfn[fa[x]]][0][0];
a[dfn[fa[x]]][1][0] += f[x][0];
SegT::update(1, 1, n, dfn[fa[x]]);
x = top[fa[x]];
}
}
int query() {
Matrix rt = SegT::query(1, 1, n, dfn[1], dfn[bot[1]]);
return max(rt[0][0], rt[1][0]);
}
int main() {
cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> w[i];
for(int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
addEdge(u, v);
addEdge(v, u);
}
get_sz(1, 0);
init(1, 0, 1);
SegT::build(1, 1, n);
while(m--) {
int x, y;
cin >> x >> y;
modify(x, y);
cout << query() << '\n';
}
return 0;
}
|
题意
给定一棵 \(n\) 个点的树,点有点权,有 \(m\) 次询问;每次询问给定 \((a,x,b,y)\),保证 \(x,y\in \{0,1\}\),表示钦定节点 \(a,b\) 选或不选,树的最小权点覆盖。
\(n,m\le 10^5\)
注意到,对于一次询问,我们只需要修改 \(a,b\) 的权值为 \(+\infty\) 或 \(-\infty\),然后查询全局最小点覆盖即可。
代码
| #include<cstdio>
#include<ctype.h>
#define ll long long
using namespace std;
const int N = 1e5 + 10;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll INF2 = 0x3f3f3f3f3f3f3f;
template<typename _Tp>
inline const _Tp& min(const _Tp& a, const _Tp& b) {
if(a < b) return a;
return b;
}
template<typename _Tp>
inline const _Tp& max(const _Tp& a, const _Tp& b) {
if(a < b) return b;
return a;
}
struct my_istream {
template<typename _Tp>
my_istream& operator>>(_Tp& x) {
char ch;
while(!isdigit(ch = getchar_unlocked()));
x = ch - 48;
while(isdigit(ch = getchar_unlocked())) x = x * 10 + ch - 48;
return *this;
}
} cin;
struct my_ostream {
char buf[60], nb;
inline my_ostream() { nb = 0; }
my_ostream& operator<<(int x) {
while(x) buf[++nb] = x % 10, x /= 10;
while(nb) putchar(buf[nb--] + 48);
return *this;
}
my_ostream& operator<<(ll x) {
while(x) buf[++nb] = x % 10, x /= 10;
while(nb) putchar(buf[nb--] + 48);
return *this;
}
my_ostream& operator<<(const char* s) {
while(*s) putchar(*(s++));
return *this;
}
} cout;
struct Edge {
int v, next;
} pool[2 * N];
int ne, head[N];
void addEdge(int u, int v) {
pool[++ne] = {v, head[u]};
head[u] = ne;
}
struct Matrix {
ll a[2][2];
inline ll* operator[](int index) { return a[index]; }
inline const ll* operator[](int index) const { return a[index]; }
inline Matrix() { a[0][0] = a[0][1] = a[1][0] = a[1][1] = 0; }
inline Matrix operator*(const Matrix& b) const {
Matrix res;
res[0][0] = min(a[0][0] + b[0][0], a[0][1] + b[1][0]);
res[0][1] = min(a[0][0] + b[0][1], a[0][1] + b[1][1]);
res[1][0] = min(a[1][0] + b[0][0], a[1][1] + b[1][0]);
res[1][1] = min(a[1][0] + b[0][1], a[1][1] + b[1][1]);
return res;
}
};
int n, m;
ll w[N];
int sz[N], son[N], top[N], bot[N], fa[N];
int dfn[N], dt;
ll f[N][2], g[N][2];
Matrix a[N];
namespace SegT {
Matrix tr[4 * N];
inline int lc(int x) { return x << 1; }
inline int rc(int x) { return x << 1 | 1; }
inline void push_up(int p) {
tr[p] = tr[lc(p)] * tr[rc(p)];
}
void build(int p, int l, int r) {
if(l == r) {
tr[p] = a[l];
return;
}
int mid = (l + r) >> 1;
build(lc(p), l, mid);
build(rc(p), mid + 1, r);
push_up(p);
}
void update(int p, int l, int r, int q) {
if(l == r) {
tr[p] = a[l];
return;
}
int mid = (l + r) >> 1;
if(mid >= q) update(lc(p), l, mid, q);
else update(rc(p), mid + 1, r, q);
push_up(p);
}
void query(int p, int l, int r, int ql, int qr, Matrix& res) {
if(ql <= l && r <= qr) {
res = res * tr[p];
return;
}
int mid = (l + r) >> 1;
if(mid >= ql) query(lc(p), l, mid, ql, qr, res);
if(mid < qr) query(rc(p), mid + 1, r, ql, qr, res);
}
}
void init1(int u, int a0) {
sz[u] = 1;
fa[u] = a0;
for(int i = head[u]; i; i = pool[i].next) {
int v = pool[i].v;
if(v == a0) continue;
init1(v, u);
if(sz[v] > sz[son[u]]) son[u] = v;
sz[u] += sz[v];
}
}
void init2(int u, int a0, int tp) {
dfn[u] = ++dt;
top[u] = tp;
if(son[u]) init2(son[u], u, tp), bot[u] = bot[son[u]];
else bot[u] = u;
for(int i = head[u]; i; i = pool[i].next) {
int v = pool[i].v;
if(v == a0 || v == son[u]) continue;
init2(v, u, v);
g[u][0] += min(f[v][0], f[v][1]);
g[u][1] += f[v][1];
}
f[u][0] = g[u][1] + f[son[u]][1];
f[u][1] = g[u][0] + min(f[son[u]][0], f[son[u]][1]) + w[u];
Matrix &cur = a[dfn[u]];
cur[0][0] = INF;
cur[0][1] = g[u][1];
cur[1][0] = g[u][0] + w[u];
cur[1][1] = g[u][0] + w[u];
}
void update(int p) {
{
Matrix &cur = a[dfn[p]];
cur[1][0] = g[p][0] + w[p];
cur[1][1] = g[p][0] + w[p];
SegT::update(1, 1, n, dfn[p]);
}
p = top[p];
while(p != 1) {
int ff = fa[p];
Matrix cur;
cur[0][1] = cur[1][0] = INF;
SegT::query(1, 1, n, dfn[p], dfn[bot[p]], cur);
g[ff][0] -= min(f[p][0], f[p][1]);
g[ff][1] -= f[p][1];
f[p][0] = cur[0][1];
f[p][1] = cur[1][1];
g[ff][0] += min(f[p][0], f[p][1]);
g[ff][1] += f[p][1];
Matrix &fm = a[dfn[ff]];
fm[0][1] = g[ff][1];
fm[1][0] = g[ff][0] + w[ff];
fm[1][1] = g[ff][0] + w[ff];
SegT::update(1, 1, n, dfn[ff]);
p = top[fa[p]];
}
}
ll query() {
Matrix cur;
SegT::query(1, 1, n, dfn[1], dfn[bot[1]], cur);
return cur[1][1];
}
int main() {
cin >> n >> m;
while(getchar() != '\n');
for(int i = 1; i <= n; i++) cin >> w[i];
for(int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
addEdge(u, v);
addEdge(v, u);
}
init1(1, 0);
init2(1, 0, 1);
SegT::build(1, 1, n);
while(m--) {
int a, x, b, y;
ll ta, tb, c = 0;
cin >> a >> x >> b >> y;
ta = w[a], tb = w[b];
if(x == 0) w[a] = INF2;
else c += w[a] + INF2, w[a] = -INF2;
if(y == 0) w[b] = INF2;
else c += w[b] + INF2, w[b] = -INF2;
update(a);
update(b);
ll res = query() + c;
if(res >= INF2) {
cout << "-1\n";
} else cout << res << "\n";
w[a] = ta, w[b] = tb;
update(a);
update(b);
}
return 0;
}
/*
5 3
2 4 1 3 9
1 5
5 2
5 3
3 4
2 1 3 1
1 0 3 0
1 0 5 0
*/
|