跳转至

P5311 [Ynoi2011] 成都七中 题解

题意

给你一棵 \(n\) 个节点的树,每个节点有一种颜色,有 \(m\) 次查询操作。

查询操作给定参数 \(L,R,x\),需输出:

将树中编号在 \([L,R]\) 内的所有节点保留,\(x\) 所在连通块中颜色种类数。

每次查询操作独立。

题解

我们发现一个点 \(u\) 能对 \(x\) 处的查询产生贡献,当且仅当 \(u\rightarrow x\) 的路径上节点编号的最小值 \(mn\) 满足 \(L\le mn\),最大值 \(mx\) 满足 \(mx\le R\)因此我们可以从连通块的任意一个节点出发,统计连通块的大小。

我们简化问题,先考虑 \(x\) 固定的情况:将 \(x\) 旋转到根,一遍 dfs 求出 \(x\) 到所有节点的路径的 \(mx\)\(mn\)。然后直接二维数颜色即可。

接下来考虑查询位置 \(x\) 不固定的情况。如果使用换根则无法维护路径最值。这里我们注意到,如果一个询问 \(x\) 和当前的根节点 \(rt\) “连通”(只保留询问的 \([L,R]\) 的节点情况下,两点处于同一连通块),则该询问等价于在当前的根节点上询问

否则这个询问所在的子树就和根节点以及其他子树分隔开了,成为一个较小的子问题,可以递归处理。这让我们联想到点分治。通过每次找到重心,可以将遍历子树的总时间复杂度降低到 \(O(n\log n)\)。算上二维数点的 \(\log\),总时间复杂度为 \(O(n\log^2 n)\)

AC 代码
#include<iostream>
#include<vector>
#include<cstring>
#include<algorithm>
#define cint const int &
using namespace std;
const int N = 1E5 + 10;
const int INF = (int)0x3f3f3f3f3f3f3f3f;

struct Edge {
    int v, next;
} pool[2 * N];
int ne, head[N];

void addEdge(int u, int v) {
    pool[++ne] = {v, head[u]};
    head[u] = ne;
}

struct Query {
    int l, r, id;
};

struct Op {
    int tp, k1, k2, id, col;
    inline bool operator<(const Op &other) const {
        if(k1 != other.k1) return k1 > other.k1;
        return tp > other.tp;
    }
};

int n, m;
int c[N], ans[N];
vector<Query> q[N];

int vis[N], mn[N], mx[N], sz[N];

namespace bs {
    int mxs[N] = {INF};
    void get_sz(cint u, cint f) {
        sz[u] = 1;
        for(int i = head[u]; i; i = pool[i].next) {
            cint v = pool[i].v;
            if(v == f || vis[v]) continue;
            get_sz(v, u);
            sz[u] += sz[v];
        }
    }
    void get_rt(cint u, cint f, cint tot, int &rt) {
        mxs[u] = 0;
        for(int i = head[u]; i; i = pool[i].next) {
            cint v = pool[i].v;
            if(v == f || vis[v]) continue;
            get_rt(v, u, tot, rt);
            mxs[u] = max(mxs[u], sz[v]);
        }
        mxs[u] = max(mxs[u], tot - sz[u]);
        if(mxs[u] < mxs[rt]) rt = u;
    }
    void add_op(cint u, cint f, vector<Op> &op) {
        mn[u] = min(u, mn[f]);
        mx[u] = max(u, mx[f]);
        op.push_back({1, mn[u], mx[u], 0, c[u]});
        for(Query &qry : q[u]) {
            if(qry.id && qry.l <= mn[u] && mx[u] <= qry.r) {
                op.push_back({0, qry.l, qry.r, qry.id, 0});
                qry.id = 0;
            }
        }
        for(int i = head[u]; i; i = pool[i].next) {
            cint v = pool[i].v;
            if(v == f || vis[v]) continue;
            add_op(v, u, op);
        }
    }
}

namespace BIT {
    int sum[N];
    inline int lowbit(cint x) { return x & -x; }
    inline void modify(cint p, cint v) {
        for(register int i = p; i <= n; i += lowbit(i)) sum[i] += v;
    }
    inline int query(cint p) {
        int res = 0;
        for(register int i = p; i > 0; i -= lowbit(i)) res += sum[i];
        return res;
    }
    inline void clear(cint p) {
        for(register int i = p; i <= n; i += lowbit(i)) sum[i] = 0;
    }
}

namespace cntPt {

    vector<Op> op;
    int opCnt;

    int lst[N]; // col=i 的节点中 mx 的最小值

    void calc(int u) {
        mn[0] = INF;
        mx[0] = 0;
        op.clear();
        bs::add_op(u, 0, op);
        sort(op.begin(), op.end());
        for(Op &o : op) {
            if(o.tp == 1) {
                if(o.k2 < lst[o.col]) {
                    if(lst[o.col] <= n) BIT::modify(lst[o.col], -1);
                    BIT::modify(o.k2, 1);
                    lst[o.col] = o.k2;
                }
            } else {
                ans[o.id] = BIT::query(o.k2);
            }
        }
        for(Op &o : op) {
            if(o.tp == 1) {
                lst[o.col] = INF;
                BIT::clear(o.k2);
            }
        }
    }

}

void init() {
    memset(cntPt::lst, 0x3f, sizeof(cntPt::lst));
}

void solve(int u) {
    bs::get_sz(u, 0);
    cntPt::calc(u);
    vis[u] = 1;
    for(int i = head[u]; i; i = pool[i].next) {
        int v = pool[i].v;
        if(vis[v]) continue;
        int rt = 0;
        bs::get_rt(v, u, sz[v], rt);
        solve(rt);
    }
}

int main() {

    cin >> n >> m;
    for(int i = 1; i <= n; i++) {
        cin >> c[i];
    }
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }
    for(int i = 1; i <= m; i++) {
        int l, r, x;
        cin >> l >> r >> x;
        q[x].push_back({l, r, i});
    }

    int rt = 0;

    init();
    bs::get_sz(1, 0);
    bs::get_rt(1, 0, sz[1], rt);

    solve(rt);

    for(int i = 1; i <= m; i++) {
        cout << ans[i] << '\n';
    }

    return 0;
}