跳转至

状压 DP

枚举子集

有时,对于一个二进制状态 \(i\),我们只希望从它的子集转移,因为这样能有效的减少时间复杂度。

Y555 最优组队

给定 \(n\) 个人(\(n\le 16\)),你需要把他们分成若干组,满足所有组的权值和最大。题目给定所有可能出现的组对应的权值。

我们注意到 \(n\) 很小,可以指数复杂度通过,因此考虑状压。设二进制状态 \(i\),表示对应集合中的人都已组好队,最大的权值和。很容易写出状态转移方程:

\[ f_i=\max_{j\subset i}\{f_j+w_{i-j}\} \]

对于状态 \(i\),暴力枚举所有状态 \(j\),再判断是否满足 \(j\subset i\),决定是否转移。时间复杂度为 \(O(4^n)\),“不能”通过本题。

1
2
3
4
5
6
7
8
9
for(int i = 1; i <= (1 << n) - 1; i++) {
    for(int j = 0; j < i; j++) {
        if((i | j) == i){
            f[i] = max(f[i], f[j] + a[i ^ j]);
        }
    }
}

cout << f[(1 << n) - 1] << endl;

我也不知道这段暴力为什么就水过了,可能是带一个小常数

我们注意到,大量的时间被浪费在了枚举那些不是 \(i\) 的子集的 \(j\) 上面。在此我们提供一种可以只遍历到所有真子集的方法:

1
2
3
4
5
6
7
for(int i = 1; i <= (1 << n) - 1; i++) {
    int j = i;
    while(j) {
        j = (j - 1) & i;
        f[i] = max(f[i], f[j] + a[i ^ j]);
    }
}

通过此种优化,时间复杂度降为 \(O(3^n)\),可以在更短时间内通过本题。

时间复杂度证明

设一个状态有 \(i\) 个元素,其子集一共有 \(2^i\) 个。有 \(i\) 个元素的状态一共有 \(C_n^i\) 种方案,总转移时间复杂度为:

\[ \sum_{i=1}^n{C_n^i\times 2^i}=3^n \]

此种方法比暴力转移要更优。

图上状压

Y556 最短路径

给定一个 \(n\) 个点 \(m\) 条边的有向图,有 \(k\) 个标记点,要求从规定的起点按任意顺序经过所有标记点到达规定的终点,问最短的距离是多少。

\(n\le 5\times 10^4,\ m\le 10^5,\ k \le 10\)

我们可以在图上跑一遍 Dijkstra,同时通过状压记录经过了哪些标记点。时间复杂度 \(O((n\times 2^k + m\times 2 ^ k)\log (m\times 2^k))\),不能通过本题,考虑优化。

通过仔细分析该算法,我们发现各个标记点之间的最短路被重复计算了 \(2^k\) 次,这是完全不必要的,考虑优化这些重复计算。

我们可以先求出这 \(k\) 个点之间两两的最短路,再在这 \(k\) 个点的简化图上跑状压 Dijkstra。这样,时间复杂度就被优化为 \(O\big(k(n+m)\log m+k^2\log k^2\big)\),可以通过本题。

代码

合理利用 namespace 区分名称相似的函数可以避免很多不必要的麻烦。

#include<iostream>
#include<cstring>
#include<queue>
#define int long long

using namespace std;
const int N = 5E4 + 10;
const int M = 1E5 + 10;
const int K = 15;

struct Edge {
    int v, w;
    int next;
} pool[2 * M];
int nn, head[N];

void addEdge(int u, int v, int w) {
    pool[++nn] = {v, w, head[u]};
    head[u] = nn;
}

int n, m, k, s, t;
int tag[N];
int b[K], dis[K][K];

namespace Dij1 {

    struct Node {
        int u, d;
        inline bool operator<(const Node &other) const {
            return d > other.d;
        }
    };

    int vis[N], d[N];
    priority_queue<Node> que;

    void Dijkstra(int x) {

        while(!que.empty()) que.pop();
        memset(d, 0x3f, sizeof(d));
        memset(vis, 0, sizeof(vis));
        d[b[x]] = 0;
        que.push({b[x], 0});
        while(!que.empty()) {
            int u = que.top().u;
            que.pop();
            if(vis[u]) continue;
            vis[u] = 1;
            for(int i = head[u]; i; i = pool[i].next) {
                int v = pool[i].v, w = pool[i].w;
                if(d[u] + w < d[v]) {
                    d[v] = d[u] + w;
                    que.push({v, d[v]});
                }
            }
        }
        for(int i = 1; i <= k + 2; i++) {
            dis[x][i] = d[b[i]];
        }

    }

}

namespace Dij2 {

    struct Node {
        int u, sta;
        int d;
        inline bool operator<(const Node &other) const {
            return d > other.d;
        }
    };

    int vis[K][1 << 12], d[K][1 << 12];
    priority_queue<Node> que;

    void Dijkstra() {

        if(s == t && k == 0) {
            cout << 0 << endl;
            exit(0);
        }

        memset(d, 0x3f, sizeof(d));
        d[k + 1][tag[s] ? (1 << tag[s] - 1) : 0] = 0;
        que.push({k + 1, tag[s] ? (1 << tag[s] - 1) : 0, 0});
        while(!que.empty()) {
            int u = que.top().u;
            int sta = que.top().sta;
            que.pop();
            if(vis[u][sta]) continue;
            vis[u][sta] = 1;
            if(u == k + 2 && (sta & ((1 << k) - 1)) == (1 << k) - 1) {
                cout << d[u][sta] << endl;
                exit(0);
            }
            for(int v = 1; v <= k + 2; v++) {
                if(u == v) continue;
                int tmp = sta | (1 << v - 1);
                if(d[u][sta] + dis[u][v] < d[v][tmp]) {
                    d[v][tmp] = d[u][sta] + dis[u][v];
                    que.push({v, tmp, d[v][tmp]});
                }
            }
        }

    }

}

signed main() {

    cin >> n >> m >> k >> s >> t;
    for(int i = 1; i <= m; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        addEdge(u, v, w);
    }
    for(int i = 1; i <= k; i++) {
        int x;
        cin >> x;
        tag[x] = i;
        b[i] = x;
    }

    b[k + 1] = s;
    b[k + 2] = t;
    for(int i = 1; i <= k + 2; i++) {
        Dij1::Dijkstra(i);
    }

    Dij2::Dijkstra();

    cout << -1 << endl;

    return 0;
}

连通性状压 & 最小表示法

P2109 [NOI2007] 生成树计数

给定排成一行的 \(n\) 个节点,相邻节点之间的距离都为 \(1\)。每个点和所有与其距离不超过 \(k\) 的节点连边。问这个图的生成树的数量对 \(65521\) 取模的结果。

\(n\le 10^{15},k\le 5\)

\(n\le 10^{15}\),但不易找到通项公式,因此考虑矩阵快速幂优化 dp。

考虑第 \(i\) 个节点,枚举其向 \([i-k,i-1]\) 中的哪些节点有连边。注意到 \(k\le 5\),因此当前节点最多只能与前面的 \(5\) 个节点连边,dp 也只需要记录最近 \(5\) 个节点的状态。然而生成树要求图上不能出现环,因此我们需要保证第 \(i\) 个节点连接的 \([i-k,i-1]\) 里的点原本是不连通的。这就需要我们记录最后 \(k\) 个点的连通性。

要想记录所有可能出现的连通性状态,可以先暴力搜索出所有合法的连边方案,再把连通性相同的状态合并到一起,重新编号(这叫做最小表示法)。经过实测,\(k=5\) 时只有 \(52\) 种连通性状态。

int id[55560], st[60], cnt[60], scnt;
int fa[K];

int find(int x) {
    if(fa[x] == x) return x;
    return fa[x] = find(fa[x]);
}

// 参数 s 表示边集 E 的二进制状压形式。
void check(int s) {
    for(int i = 1; i <= k; i++) fa[i] = i;
    for(int i = 1, c = 0; i <= k; i++) {
        for(int j = 1; j < i; j++, c++) {
            // 判环
            if(s & (1 << c)) {
                if(find(i) == find(j)) return;
                fa[max(find(i), find(j))] = min(find(i), find(j));
            }
        }
    }
    int sta = 0;
    for(int i = 1; i <= k; i++) {
        sta = sta * 10 + find(i);
    }
    if(!id[sta]) {
        id[sta] = ++scnt;
        st[scnt] = sta;
    }
    cnt[id[sta]]++;
}

int main() {
    ...
    for(int i = 0; i < (1 << (k * (k - 1) / 2)); i++) {
        check(i);
    }
    ...
}

这样,我们就得到了所有连通性状态的最小表示法出现次数。接下来,我们只需要枚举第 \(i\) 个节点向 \([i-k,i-1]\) 的连边,判断是否有环,再找出新的连通性状态的最小表示法,就可以知道 \([i-k-1,i-2]\) 的每一个连通性状态 \(s_1\)\([i-k,i-1]\) 的连通性状态 \(s_2\) 的贡献(转移系数)。

我们将转移系数写成一个矩阵,并将 \([i-k,i-1]\) 的每种连通性状态对应的方案数写在一个 \(1\times n\) 的向量里,将向量和矩阵相乘,即可得到后一个位置 \(i+1\) 的状态。

最后使用矩阵快速幂计算出 \(op^{n-k}\),再以 \(cnt\) 为初始状态,得到最终状态。输出 \(k\) 个节点全部连通的状态对应的方案数即可。

代码
#include<iostream>
#include<cstring>
#include<cmath>
#include<vector>
#define cint const int&
#define int long long
using namespace std;
const int K = 7;
const int MOD = 65521;
const int F[] = {-1, -1, 11, 111, 1111, 11111};

struct Matrix {
    int n;
    int a[60][60];
    inline Matrix(int _n) { n = _n; memset(a, 0, sizeof(a)); }
    inline int* operator [] (int index) { return a[index]; }
    inline const int* operator [] (int index) const { return a[index]; }
    inline Matrix operator * (const Matrix &b) const {
        Matrix res(n);
        for(int i = 1; i <= n; i++) {
            for(int j = 1; j <= n; j++) {
                for(int k = 1; k <= n; k++) {
                    res[i][j] = (res[i][j] + a[i][k] * b[k][j] % MOD) % MOD;
                }
            }
        }
        return res;
    }
};

inline vector<int> operator*(const vector<int> &a, const Matrix &b) {
    vector<int> res(b.n + 1);
    for(int i = 1; i <= b.n; i++) {
        for(int j = 1; j <= b.n; j++) {
            res[i] = (res[i] + a[j] * b[j][i]) % MOD;
        }
    }
    return res;
}

int n, k, lim;

int id[55560], st[60], cnt[60], scnt;
vector<int> beg, ed;

int p[K], fa[K];

int find(int x) {
    if(fa[x] == x) return x;
    return fa[x] = find(fa[x]);
}

void check(int s) {
    for(int i = 1; i <= k; i++) fa[i] = i;
    for(int i = 1, c = 0; i <= k; i++) {
        for(int j = 1; j < i; j++, c++) {
            if(s & (1 << c)) {
                if(find(i) == find(j)) return;
                fa[max(find(i), find(j))] = min(find(i), find(j));
            }
        }
    }
    int sta = 0;
    for(int i = 1; i <= k; i++) {
        sta = sta * 10 + find(i);
    }
    if(!id[sta]) {
        id[sta] = ++scnt;
        st[scnt] = sta;
    }
    cnt[id[sta]]++;
}

inline void add(Matrix &op, cint s1, cint s2) {
    op[id[s1]][id[s2]]++;
}

Matrix qpow(Matrix a, int k) {
    Matrix res(a.n);
    for(int i = 1; i <= a.n; i++) res[i][i] = 1;
    while(k) {
        if(k & 1) res = res * a;
        a = a * a;
        k >>= 1;
    }
    return res;
}

signed main() {

    cin >> k >> n;

    for(int i = 0; i < (1 << (k * (k - 1) / 2)); i++) {
        check(i);
    }

    beg.resize(scnt + 1);
    for(int i = 1; i <= scnt; i++) beg[i] = cnt[i];

    Matrix op(scnt);
    for(int s = 1; s <= scnt; s++) {
        int sta = st[s];
        for(int i = k; i >= 1; i--) p[i] = sta % 10, sta /= 10;
        for(int i = 0; i < (1 << k); i++) {
            for(int j = 1; j <= k; j++) fa[j] = p[j];
            fa[k + 1] = k + 1;
            bool flag = true;
            for(int j = 1; j <= k; j++) {
                if(i & (1 << j - 1)) {
                    if(find(k + 1) == find(j)) {
                        flag = false;
                        break;
                    }
                    fa[max(find(k + 1), find(j))] = min(find(k + 1), find(j));
                }
            }
            if(!flag) continue;
            int sta2 = 0, tag = 0;
            for(int i = 2; i <= k + 1; i++) if(find(i) == 1) { tag = i; break; }
            if(!tag) continue;
            for(int i = 2; i <= k + 1; i++) {
                sta2 = sta2 * 10 + (find(i) == 1 ? tag : find(i)) - 1;
            }
            add(op, st[s], sta2);
        }
    }

    op = qpow(op, n - k);
    ed = beg * op;

    cout << ed[id[F[k]]] << endl;

    return 0;
}

多行状态

P2704 [NOI2001] 炮兵阵地

YbtOj 链接

题面请参考洛谷链接。

我们发现第 \(i\) 行的状态受到第 \(i-1\) 行和 \(i-2\) 行的影响。状态应当同时包含 \(i\) 行和 \(i-1\) 行的信息。

我们设状态 \((i,j_1,j_2)\) 表示第 \(i\) 行的状态为 \(j_1\),第 \(i-1\) 行的状态为 \(j_2\),最多能摆放多少个大炮。

\[ f_{i,j_1,j_2}=\max_{j_3\cap j_1=\emptyset\text{ and }j_3\cap j_2=\emptyset}\{f_{i-1,j_2,j_3}+cnt(j_1)\} \]

再通过预处理每行的 valid 数组,找到最小表示法,能极大优化时间复杂度(因为每行的合法状态不超过 \(70\) 个)。

多维状压

有时,单个点的状态数可能大于 \(2\),此时无法用二进制简单的表示出此状态。为此,对于一个状态 \(i\),其第 \(j\) 个位置的状态可以表示为 \(\lfloor\frac{i}{k^j}\rfloor \bmod k\)。这样,对于 \(k\) 维的信息,我们也能使用状压处理了。只不过此种方法无法像二进制一样可以进行位运算。