跳转至

P8499 [NOI2022] 挑战 NPC II 题解

题意

给定两棵有根树 \(T_1\)\(T_2\),保证 \(|T_2|\le |T_1|\le |T_2|+5\)。问你能否通过删除 \(T_1\) 的一些节点,使得 \(T_1\)\(T_2\) 同构。删除节点不能改变树其他部分的连通性。两棵树同构当且仅当存在一组重标号方案,使得两棵树变得相等,并且根节点位置不变。

多测,\(\sum |T_1|\le 5\times 10^5\)

题解

注意到两棵树 \(T_1,T_2\) 同构等价于 \(rt_1,rt_2\) 的所有子节点存在一组匹配,使得匹配的子树都同构。同时注意到删除节点的数量不太多,最多只会改变 \(5\) 棵子树。

考虑一种暴力,状态 \((x,y)\) 表示能否使 \(T_1\) 中的 \(x\) 树和 \(T_2\) 中的 \(y\) 树同构。注意到如果 \(x\)\(y\) 初始不匹配的子树已经超过了 \(5\) 棵,那么一定不行。其余情况,我们向下递归尝试匹配,产生最多 \(5\times 5\) 个子问题,然后在本层跑最大匹配。

注意到暴力有很大的优化空间。例如,由于树的形态改变得不多,因此大部分子树都是匹配的,这部分我们可以用树哈希优化。同时,我们再加上一些必要的剪枝,例如保证 \(|T_1|\ge |T_2|\) 等等。

同时,随着不断的递归,\(|T_1|-|T_2|\) 可能已经变得更小。记 \(d=|T_1|-|T_2|\),那么不匹配子树的数量 \(t\) 最多是 \(d\) 棵,同时子问题的 \(d\) 更不可能超过 \(d-t+1\)

\(\Delta d\) 子问题数量
\(0\) \(\times 1\)
\(-1\) \(\times 4\)
\(-2\) \(\times 9\)
\(-3\) \(\times 16\)
\(-4\) \(\times 25\)

子问题数量最大翻 \(1024\) 倍,其实已经差不多能 AC 了。又注意到并不是每棵子树都能卡到 \(d-t+1\) 的极限,因此肯定跑不满(手玩尝试卡到 \(1024\) 确实做不到)。

#include<iostream>
#include<random>
#include<cassert>
#include<set>
#define ull unsigned long long
using namespace std;
const int N = 1e5 + 20;

struct myPair {
    int u;
    ull hsh;
    inline myPair(int _u, ull _hsh) : u(_u), hsh(_hsh) {}
    inline bool operator==(const myPair &b) const {
        return hsh == b.hsh;
    }
    inline bool operator<(const myPair &b) const {
        return hsh < b.hsh;
    }
};

struct myPair_Hash {
    inline ull operator()(const myPair &p) const { return p.hsh; }
};

struct Edge2 {
    int u, v;
};

struct Edge {
    int v, next;
} pool[2][2 * N];
int ne[2], head[2][N], cs[2][N];

void clear_Edge(int n) {
    ne[0] = ne[1] = 0;
    for(int i = 1; i <= n; i++) head[0][i] = head[1][i] = cs[0][i] = cs[1][i] = 0;
}

void addEdge(int id, int u, int v) {
    pool[id][++ne[id]] = {v, head[id][u]};
    head[id][u] = ne[id];
    ++cs[id][u];
}

int T, k;
int n, m;
int rt0, rt1;

ull seed = (mt19937_64){(random_device){}()}();
ull hsh[2][N], sz[2][N];

ull shift(ull x) {
    x ^= x << 2;
    x ^= x >> 3;
    x ^= x << 21;
    x ^= seed;
    return x;
}

void dfs(int u, int fa, int id) {
    hsh[id][u] = 1;
    sz[id][u] = 1;
    for(int e = head[id][u]; e; e = pool[id][e].next) {
        int v = pool[id][e].v;
        dfs(v, u, id);
        hsh[id][u] += shift(hsh[id][v]);
        sz[id][u] += sz[id][v];
    }
}

namespace match {
    Edge pool[30];
    int ne, head[7];
    int link[7], vis[7];
    void clear() {
        ne = 0;
        head[1] = head[2] = head[3] = head[4] = head[5] = 0;
        link[1] = link[2] = link[3] = link[4] = link[5] = 0;
        vis[1] = vis[2] = vis[3] = vis[4] = vis[5] = 0;
    }
    void addEdge(int u, int v) {
        pool[++ne] = {v, head[u]};
        head[u] = ne;
    }
    bool dfs(int u) {
        for(int e = head[u]; e; e = pool[e].next) {
            int v = pool[e].v;
            if(vis[v]) continue;
            vis[v] = 1;
            if(!link[v] || dfs(link[v])) {
                link[v] = u;
                return true;
            }
        }
        return false;
    }
    int calc() {
        int res = 0;
        for(int i = 1; i <= 5; i++) {
            vis[1] = vis[2] = vis[3] = vis[4] = vis[5] = 0;
            res += dfs(i);
        }
        return res;
    }
}

bool solve(int x, int y) {
    if(hsh[0][x] == hsh[1][y]) return true;
    int delta = sz[0][x] - sz[1][y];
    if(delta <= 0 || delta > k) return false;
    if(sz[1][y] == 0) return true;
    multiset<myPair> st1, st2;
    for(int e = head[0][x]; e; e = pool[0][e].next) {
        int v = pool[0][e].v;
        st1.insert({v, hsh[0][v]});
    }
    for(int e = head[1][y]; e; e = pool[1][e].next) {
        int v = pool[1][e].v;
        if(st1.count({0, hsh[1][v]})) st1.erase(st1.find({0, hsh[1][v]}));
        else st2.insert({v, hsh[1][v]});
    }
    if(st1.size() < st2.size()) return false;
    if(st1.size() > delta) return false;
    vector<Edge2> edg;
    int i = 1, j;
    for(auto it1 = st1.begin(); it1 != st1.end(); ++it1, i++) {
        j = 1;
        for(auto it2 = st2.begin(); it2 != st2.end(); ++it2, j++) {
            if(solve(it1->u, it2->u)) {
                edg.push_back({i, j});
            }
        }
    }
    match::clear();
    for(Edge2 &e : edg) {
        assert(1 <= e.u && e.u <= 5);
        match::addEdge(e.u, e.v);
    }
    if(match::calc() >= st2.size()) return true;
    return false;
}

int main() {

    cin >> T >> T >> k;
    while(T--) {
        cin >> n;
        clear_Edge(n);
        for(int i = 1; i <= n; i++) {
            int x;
            cin >> x;
            if(~x) addEdge(0, x, i);
            else rt0 = i;
        }
        cin >> m;
        for(int i = 1; i <= m; i++) {
            int x;
            cin >> x;
            if(~x) addEdge(1, x, i);
            else rt1 = i;
        }
        dfs(rt0, 0, 0);
        dfs(rt1, 0, 1);
        cout << (solve(rt0, rt1) ? "Yes\n" : "No\n");
    }

    return 0;
}