跳转至

250428 D8 模拟赛 T2 题解

题意

有一棵 \(n\) 个点的无根树,点有点权,节点 \(i\) 的权值是 \(a_i\)​。定义一条边的权值 \(w\) 为:删去这条边得到的两棵树,点权和之差的绝对值。

树上有一系列关键边。树的代价定义为所有关键边权值的最大值。

\(q\) 次询问,每次询问给定 \(k\) 条关键边和 \(L,R\)。你需要进行若干次操作,每次操作可以选择一个点 \(x\),将 \(a_x\) 加一。你需要保证操作次数介于 \([L,R]\) 之间。问树的代价最小是多少。

\(n,q\le 3\times 10^5,\ L,R\le 10^9,\ \sum k\le n\)

题解

对于每组询问,不包含关键边的连通块可以缩为一个点,这样就形成了一棵只包含关键边的树。然而,询问的数量和 \(n\) 同阶。如果我们对每组询问都进行一次 dfs 对图进行缩点,时间复杂度将达到 \(O(n^2)\)

考虑树的代价有什么性质。我们发现,只有连向叶子节点的边才有可能产生贡献;否则一定可以通过调整法增加边的权值。现在,我们只需要处理 \(deg=1\) 的连通块即可。

对于每条关键边 \((u,v)\),钦定 \(u=fa[v]\)。如果 \(v\) 子树内没有其他关键边,则 \(v\) 子树就是一个叶子连通块,否则一定不是。我们可以使用数据结构维护 dfn 序,实现子树查询。

另外的,如果根节点 \(rt\) 所在的连通块也有 \(deg=1\),但无法被统计到。我们考虑哪些关键边会和 \(rt\) 相连。我们发现,当 \(rt\to u\) 的路径上没有其它关键边时,\(u\) 会和 \(rt\) 相连,否则一定会被隔断。我们记录根节点连通块的 \(deg\) 和权值和,如果确实是叶子,就加入根节点连通块。

接下来考虑如何计算叶节点的贡献。记所有节点的权值和为 \(sum\),某个叶节点的权值为 \(v\),那么它会产生 \(|sum-2v|\) 的贡献。进一步的,我们只需考虑 \(2v\le sum\) 的情况即可;否则一定可以找出另一个叶节点满足 \(2v\le sum\),并且 \(|sum-2v|\) 更大。

考虑 \(v\) 最小的那个叶子。如果我们只增加它的权值,由于 \(2v\) 的增速比 \(sum\) 快,因此一定更优。当 \(v\) 的最小值和次小值相等时,一定不能更优。但由于 \(L\) 的限制,代价可能会被强行增加。

记当前已经进行了 \(cur\) 次操作。我们有策略:每次选前 \(k\) 小值进行 \(+1\)。如果已经不优(\(k>1\)),并且操作次数已经超过了 \(L\),则停止操作,只选择前 \(L-cur\) 个进行操作(为了保证 \(sum\) 最小)。

然而,如果发现 \((L-cur+1)\bmod k=0\),并且 \(R\ge L+1\),此时我们再额外进行一次操作,可以将最小值增加 \(1\),答案减小 \(2\),但 \(sum\) 只增加了 \(1\)。因此这种情况取 \(L+1\) 会更优,注意特判。

AC 代码
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const int N = 3e5 + 10;
const ll INF = 0x3f3f3f3f3f;

struct Edge2 {
    int u, v;
} edg[N];

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

int n, q;
int k;
ll L, R;

ll a[N], s[N];
int sz[N], fa[N], dfn[N], dt;
int p[N];

void dfs(int u, int a0) {
    dfn[u] = ++dt;
    fa[u] = a0;
    s[u] = a[u];
    sz[u] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(v == a0) continue;
        dfs(v, u);
        s[u] += s[v];
        sz[u] += sz[v];
    }
}

namespace BIT1 {
    int sum[N];
    inline int lowbit(int x) { return x & -x; }
    inline void add(int p, int v) {
        for(int i = p + 1; i <= n + 3; i += lowbit(i)) sum[i] += v;
    }
    inline int query(int p) {
        int res = 0;
        for(int i = p + 1; i > 0; i -= lowbit(i)) res += sum[i];
        return res;
    }
    inline int query(int l, int r) {
        return query(r) - query(l - 1);
    }
}

namespace BIT2 {
    int sum[N];
    inline int lowbit(int x) { return x & -x; }
    inline void add(int p, int v) {
        for(int i = p + 1; i <= n + 3; i += lowbit(i)) sum[i] += v;
    }
    inline void add(int l, int r, int v) {
        add(l, v);
        add(r + 1, -v);
    }
    inline int query(int p) {
        int res = 0;
        for(int i = p + 1; i > 0; i -= lowbit(i)) res += sum[i];
        return res;
    }
}

ll val[N], nn;

ll work() {
    sort(val + 1, val + 1 + nn);
    val[nn + 1] = INF;
    ll cur = 0, ans = s[1] - 2 * val[1];
    for(int i = 1; i <= nn; i++) {
        if(i == 1) {
            ll tmp = min(R, val[2] - val[1]);
            cur += tmp;
            ans -= tmp;
            if(cur == R) break;
            continue;
        }
        if(cur + i * (val[i + 1] - val[i]) < L) {
            ll tmp = val[i + 1] - val[i];
            ans += tmp * i - tmp * 2;
            cur += tmp * i;
        } else {
            if(L != R && (L - cur + 1) % i == 0) ++L;
            ans += (L - cur) - (L - cur) / i * 2;
            break;
        }
    }
    return ans;
}

// #define FIO

int main() {

    #ifdef FIO
    freopen("balance.in", "r", stdin);
    freopen("balance.out", "w", stdout);
    #endif

    cin >> n >> q;
    for(int i = 1; i <= n; i++) cin >> a[i];
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
        edg[i] = {u, v};
    }

    dfs(1, 0);

    for(int i = 1; i <= n - 1; i++) {
        int &u = edg[i].u, &v = edg[i].v;
        if(u != fa[v]) swap(u, v);
    }

    while(q--) {
        nn = 0;
        cin >> k >> L >> R;
        for(int i = 1; i <= k; i++) cin >> p[i];
        for(int i = 1; i <= k; i++) {
            int u = edg[p[i]].u, v = edg[p[i]].v;
            BIT1::add(dfn[u], 1);
            BIT2::add(dfn[v], dfn[v] + sz[v] - 1, 1);
        }
        ll rtw = s[1];
        int rtd = 0;
        for(int i = 1; i <= k; i++) {
            int u = edg[p[i]].u, v = edg[p[i]].v;
            if(!BIT1::query(dfn[v], dfn[v] + sz[v] - 1)) val[++nn] = s[v];
            if(!BIT2::query(dfn[u])) rtw -= s[v], ++rtd;
        }
        for(int i = 1; i <= k; i++) {
            int u = edg[p[i]].u, v = edg[p[i]].v;
            BIT1::add(dfn[u], -1);
            BIT2::add(dfn[v], dfn[v] + sz[v] - 1, -1);
        }
        if(rtd == 1) val[++nn] = rtw;
        cout << work() << '\n';
    }

    return 0;
}