跳转至

exCRT / 扩展中国剩余定理

对于同余方程:

\[ \begin{cases} x\equiv c_1 \pmod{m_1}\\ x\equiv c_2 \pmod{m_2}\\ \cdots\\ x\equiv c_n \pmod{m_n} \end{cases} \]

\(m_1, m_2, m_3\ldots m_n\) 互质,可以用普通CRT做(虽然我不会)。但若不互质,可以用exCRT求解。

考虑其中两个方程:

\[ \begin{cases} x\equiv c_1 \text{ } \pmod {m_1}\\ x\equiv c_2 \text{ } \pmod {m_2} \end{cases} \]

它等价于:

\[ \begin{cases} x = c_1 + m_1 * k_1\\ x = c_2 + m_2 * k_2 \end{cases} \]

其中 \(k1,k2 \in \Z\) . 两式相减:

\[ m_1 * k_1 - m_2 * k_2 = c_2 - c_1 \]

这符合二元一次不定方程的形式。同时因为 \(k_1,k_2\) 是整数,可以用扩展欧几里得求出一组 \(k_1,k_2\) ,这样就能求出一组特解 \(x\) 了。

这里要注意:由于 \(-m2\) 是负数,\(k1\) 要对 \(\frac{k_2}{gcd(k_1, k_2)}\) 取模并化为正数。 至于模数为什么是这个,很好证明:\(k_1\) 减去 \(\frac{k_2}{gcd(k_1, k_2)}\)\(k_2\) 加上 \(\frac{k_1}{gcd(k_1, k_2)}\),左边 \(\frac{k_1 \times k_2}{gcd(k_1, k_2)}\) 一加一减就抵消了。

得出特解 \(x_0\) 之后,可以证明:两个方程等价于一个新方程:

\[ x \equiv x_0 \pmod {\text{lcm}(m_1, m_2)} \]

再经过 \(n-2\) 次合并(总共 \(n-1\) 次),最后得到的 \(x_0\) 就是原方程组的解。

例题

P4777 【模板】扩展中国剩余定理(EXCRT)

  • 这里要写“快速乘”,一边乘一边取模,否则会爆 long long 。
代码
#include<iostream>
#include<algorithm>
#define int long long
using namespace std;
const int N = 1E5 + 10;

int n;
int a1, a2, b1, b2;

int exgcd(int a, int b, int &x, int &y){
    if(b == 0){
        x = 1;
        y = 0;
        return a;
    }
    int x1, y1;
    int d = exgcd(b, a % b, x1, y1);
    x = y1;
    y = x1 - a / b * y1;
    return d;
}

int lcm(int x, int y){
    return x / __gcd(x, y) * y;
}

int mul(int a, int b, int mod){
    int f = 1;
    if(b < 0){
        f = -1;
        b = -b;
    }
    int res = 0;
    while(b){
        if(b & 1) res = (res + a) % mod;
        b >>= 1;
        a = (a + a) % mod;
    }
    return res * f;
}

signed main(){

    cin >> n;
    cin >> a1 >> b1;
    for(int i = 2; i <= n; i++){
        cin >> a2 >> b2;
        int k1, k2, c, d;
        c = b2 - b1;
        d = exgcd(a1, a2, k1, k2);
        if(c % d != 0){
            cout << -1 << endl;
            return 0;
        }
        k1 %= a2 / d;
        k1 = mul(k1, c / d, a2 / d);
        if(k1 < 0) k1 += a2 / d;
        int mod = lcm(a1, a2);
        int tmp = mul(k1, a1, mod);
        b1 = tmp + b1;
        b1 %= mod;
        a1 = mod;
    }
    cout << b1 << endl;

    return 0;
}