跳转至

Lagrange 插值(拉格朗日插值)

考虑一个 \(n-1\) 次的多项式函数 \(f(x)=\sum_{i=0}^{n-1}{a_ix^i}\) 和它的 \(n\) 个点值 \((x_1,y_1),(x_2,y_2)\cdots (x_n,y_n)\)

定理 \(1\)

\(x_i\) 两两不同的 \(n\) 个点值唯一确定一个 \(n-1\) 次多项式函数。

证明

设多项式函数的系数从低到高依次为 \(a_0,a_1\cdots a_{n-1}\);我们可以通过 \(n\) 个点值列出一个线性方程组;写成矩阵的形式如下:

\[ \begin{bmatrix} 1& {x_0}& {x_0}^2& & {x_0}^{n-1}\\ 1& {x_1}& {x_1}^2& \cdots& {x_1}^{n-1}\\ 1& {x_2}& {x_2}^2& & {x_2}^{n-1}\\ & \vdots & & \ddots&\vdots\\ 1& {x_{n-1}}& {x_{n-1}}^2& & {x_{n-1}}^{n-1}\\ \end{bmatrix} \begin{bmatrix} a_0\\ a_1\\ a_2\\ \vdots\\ a_{n-1}\\ \end{bmatrix}= \begin{bmatrix} y_0\\ y_1\\ y_2\\ \vdots\\ y_{n-1}\\ \end{bmatrix} \]

线性方程组有唯一解的充要条件是系数矩阵满秩。注意到系数矩阵是一个范德蒙德矩阵,它的行列式不为 \(0\) 当且仅当 \(x_i\) 两两不同。得证。

给出 \(n\) 个点值(\(x\) 互不相同),我们可以确定一个 \(n-1\) 次多项式 \(f(x)\);再给定一个横坐标 \(a\)\(a\ne x_i\)),我们可以计算出这个多项式在 \(a\) 处的取值。这个过程叫做拉格朗日插值。

接下来我们讲述如何在低于高消 \(O(n^3)\) 的时间内完成插值。考虑多项式

\[ f(x)=\sum_{i=0}^{n-1}{y_i\prod_{j\ne i}{\frac{x-x_j}{x_i-x_j}}} \]

它后面的连乘会在 \(x=x_i\) 时取为 \(1\),其余情况取 \(0\)。不难发现这个多项式会在每个 \(x_i\) 处取到 \(y_i\),并且它的次数是 \(n-1\),因此它就是 \(n\) 个点唯一确定出来的那一个多项式。这个形式称为拉格朗日插值多项式。

在给定 \(a\) 的情况下,这个多项式显然可以 \(O(n^2)\) 计算。

P4781 【模板】拉格朗日插值

题意

给定 \(n\) 个点值,请你求出插值多项式在 \(k\) 处的点值模 \(998244353\) 的结果。

\(n\le 2000, x_i,y_i,k< 998244353\)

代码
for(int i = 1; i <= n; i++) {
    int s1 = 1, s2 = 1;
    for(int j = 1; j <= n; j++) {
        if(i == j) continue;
        s1 = s1 * (k - x[j] + MOD) % MOD;
        s2 = s2 * (x[i] - x[j] + MOD) % MOD;
    }
    int s = s1 * inv(s2) % MOD;
    (ans += s * y[i] % MOD) %= MOD;
}

CF622F The Sum of the k-th Powers

题意

给定 \(n,k\),求

\[ \sum_{i=1}^{n}i^k \]

\(k\le 10^6,\ n\le 10^9\)

定义

对于一个无穷数列 \(\{a_n\}\),如果它的插值多项式次数是有限的,则称它是多项式数列。

多项式数列的封闭性

两个多项式数列进行对应位置相加,对应位置相减,对应位置相乘,得到的数列仍然是多项式数列。

一个多项式数列进行前缀和、差分,得到的新数列仍然是多项式数列。

证明

加法、减法、点值乘法的封闭性不难证明,我们现在证明差分和前缀和的封闭性。

差分(后向)

显然,我们只需证明 \((n+1)^k-n^k\) 是关于 \(n\)\(k-1\) 次多项式。

\[ (n+1)^k-n^k=\sum_{i=0}^{k-1}\binom{k}{i}n^k \]

前缀和

显然,我们只需要证明 \(S_k(n)=\sum_{i=1}^ni^k\) 是关于 \(n\)\(k+1\) 次多项式 \(P_{k+1}(n)\)。考虑数学归纳法,\(k=0\) 时显然成立。现在我们假设对于所有 \(k<k_0\),都有 \(S_k(n)=P_{k+1}(n)\) 成立,我们尝试证明 \(k=k_0\) 时成立。

考虑一个恒等式:

\[ \begin{align*} \sum_{i=1}^n(i+1)^{k+1}-i^{k+1}&=\sum_{i=1}^n\sum_{j=0}^{k}i^j\binom{k+1}{j}\\ &=\sum_{j=0}^k\binom{k+1}{j}\sum_{i=1}^n i^j\\ &=\sum_{j=0}^k\binom{k+1}{j}S_j(n)\\ &=(k+1)S_k(n)+\sum_{j=0}^{k-1}\binom{k+1}{j}S_j(n) \end{align*} \]

注意到等式左边等于

\[ (n+1)^{k+1}-1 \]

因此

\[ \begin{align*} (k+1)S_k(n)&=(n+1)^{k+1}-1-\sum_{j=0}^{k-1}\binom{k+1}{j}S_j(n)\\ S_k(n)&=\frac{1}{{k+1}}\Bigg[(n+1)^{k+1}-1-\sum_{j=0}^{k-1}\binom{k+1}{j}S_j(n)\Bigg] \end{align*} \]

根据归纳假设容易得到,\(\sum_{j=0}^{k-1}\binom{k+1}{j}S_j(n)\) 是关于 \(n\)\(k\) 次多项式 \(P_k(n)\),因此等式右边是关于 \(n\)\(k+1\) 次多项式。

我们先求出 \(S_k(n)\) 的前 \(k+1\) 项,然后跑 Lagrange 插值即可。时间复杂度 \(O(k\log k)\)

代码
#include<iostream>
#include<cassert>
#define int long long
using namespace std;
const int N = 1e6 + 10;
const int MOD = 1e9 + 7;

int n, k;
int s[N];

inline int qpow(int a, int b) {
    int res = 1;
    while(b) {
        if(b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return res;
}

int lagrange(int a[], int n, int t) {
    if(t <= n) return a[t];
    static int fact[N], ifact[N], inv[N], s[N], is;
    fact[0] = ifact[0] = ifact[1] = 1;
    for(int i = 1; i <= n; i++) fact[i] = fact[i - 1] * i % MOD;
    for(int i = 2; i <= n; i++) ifact[i] = MOD - ifact[MOD % i] * (MOD / i) % MOD;
    for(int i = 1; i <= n; i++) ifact[i] = ifact[i - 1] * ifact[i] % MOD;
    s[0] = 1;
    for(int i = 1; i <= n; i++) s[i] = s[i - 1] * (t - i) % MOD;
    is = qpow(s[n], MOD - 2);
    for(int i = n; i >= 1; i--) inv[i] = is * s[i - 1] % MOD, is = is * (t - i) % MOD;
    int ans = 0;
    for(int i = 1; i <= n; i++) {
        assert(inv[i] * (t - i) % MOD == 1);
        if((n - i) & 1) ans = (ans + MOD - a[i] % MOD * ifact[i - 1] % MOD * ifact[n - i] % MOD * inv[i] % MOD) % MOD;
        else ans = (ans + a[i] % MOD * ifact[i - 1] % MOD * ifact[n - i] % MOD * inv[i] % MOD) % MOD;
    }
    return ans * s[n] % MOD;
}

signed main() {

    cin >> n >> k;
    for(int i = 1; i <= k + 2; i++) s[i] = (s[i - 1] + qpow(i, k)) % MOD;

    cout << lagrange(s, k + 2, n) << '\n';

    return 0;
}

P4463 [集训队互测 2012] calc

题意

称一个长度为 \(n\) 的序列 \(a\) 是合法的,当且仅当所有 \(a_i\in [1,k]\) 且互不相同。设 \(b\) 取遍所有合法的序列,求

\[ \sum_{b}\prod_{i=1}^n{b_i} \]

对质数 \(p\) 取模的结果。

\(n\le 500,\ k\le 10^9\)

显然,我们可以先钦定 \(a\) 升序,最后将答案乘以 \(n!\) 即可。考虑朴素 dp:设 \(f_{i,j}\) 表示前 \(i\) 个数字,第 \(i\) 个数字为 \(j\) 的总贡献。容易写出转移:

\[ f_{i,j}=\sum_{p=1}^{j-1}f_{i-1,p}\times j \]

时间复杂度为 \(O(nV)\),考虑将 \(V\) 优化掉。

注意到转移式本质上是对上一行的 \(f\) 数组进行了前缀和,然后又乘以了一个 \(j\)。我们可以将每一行的 \(f\) 看成一个关于 \(j\) 的函数;显然,\(f_{1,j}=j\) 是一个关于 \(j\) 的多项式;再经过前缀和,并乘以 \(j\) 之后,仍然还是一个多项式,并且次数增加了 \(2\)

因此,\(f_{n,j}\) 是一个关于 \(j\) 的多项式,次数不超过 \(2n\)。我们先暴力 dp 算出一行 \(f_{n}\) 总共 \(2n\) 个数,然后使用拉格朗日插值得到 \(f_n\)\(j=k\) 时的取值。

时间复杂度 \(O(n^2)\)

P3643 [APIO2016] 划艇

题意

给定 \(n\) 以及 \(n\)\((l_i,r_i)\)。称一个序列二元组 \((a,b)\) 是合法的当且仅当

  • \(|a|=|b|\)
  • \(\forall i\in \big[1,|b|-1\big],\ a_i<a_{i+1}\)
  • \(\forall i\in \big[1,|b|\big],\ a_i\in \big[l[b_i],r[b_i]\big]\)

请你计数合法序列对的数量。

\(n\le 500,\ l_i,r_i\le 10^9\)

\(f_{i,j}\) 表示考虑前 \(i\) 个位置,最后一个数字是 \(j\) 的方案数。容易写出转移:

\[ f_{i,j}=f_{i-1,j}+\sum_{k<j}{f_{i-1,k}}\big[j\in [l_i,r_i]\big] \]

如果没有区间的限制,那么 dp 数组显然是一个多项式。然而,加上这个限制之后,有很多位置变成了 \(0\),整体不能保证多项式的性质。

我们先将 \(a_i\)\(b_i\) 离散化,将值域分段,考察每段内部是不是多项式。对于 \(j\in [l_i,r_i]\),转移变为:

\[ f_{i,j}=f_{i-1,j}+\sum_{l\le k<j}f_{i-1,k}+\sum_{k<l}{f_{i-1,k}} \]

其中,第三项对于当前段来说是一个常量。因此,每一段内部始终是一个多项式。如果某一段的长度比较长,则我们维护该段前缀和数组的前 \(m\) 项,使用拉插即可在 \(O(n)\) 的时间内得到该段的区间和。累加向后转移即可。

时间复杂度 \(O(n^3)\),注意常数优化。

代码
#include<iostream>
#include<algorithm>
#include<cassert>
#define re register
#define int long long
using namespace std;
const int N = 1040;
const int MOD = 1e9 + 7;

inline int mod(int x) { return (x % MOD + MOD) % MOD; }

inline int qpow(re int a, re int b) {
    re int res = 1;
    while(b) {
        if(b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return res;
}

#define inv(x) qpow(x, MOD - 2)

inline void get_inv(int n, int ia[], int a[]) {
    static int s[N] = {1}, isn;
    for(int i = 1; i <= n; i++) s[i] = s[i - 1] * a[i] % MOD;
    isn = inv(s[n]);
    for(int i = n; i >= 1; i--) { ia[i] = isn * s[i - 1] % MOD; isn = isn * a[i] % MOD; }
}

int F[N], iF[N];

inline int lagrange(int n, int y[], int x) {
    int s = 1, ans = 0;
    static int sp[N] = {1}, ia[N], isn;
    for(int i = 1; i <= n; i++) sp[i] = sp[i - 1] * (x - i + MOD) % MOD;
    isn = inv(sp[n]);
    for(int i = n; i >= 1; i--) { ia[i] = isn * sp[i - 1] % MOD; isn = isn * (x - i) % MOD; }
    s = sp[n];
    for(re int i = n - 1; i >= 1; i -= 2) {
        (ans += MOD - y[i] * ia[i] % MOD * iF[n - i] % MOD * iF[i - 1] % MOD) %= MOD;
    }
    for(re int i = n; i >= 1; i -= 2) {
        (ans += y[i] * ia[i] % MOD * iF[n - i] % MOD * iF[i - 1] % MOD) %= MOD;
    }
    return ans * s % MOD;
}

int n, m;
int l[N], r[N];
int num[N], nn;

void lisanhua() {
    for(int i = 1; i <= n; i++) num[++nn] = l[i];
    for(int i = 1; i <= n; i++) num[++nn] = r[i];
    sort(num + 1, num + 1 + nn);
    nn = unique(num + 1, num + 1 + nn) - (num + 1);
    for(int i = 1; i <= n; i++) l[i] = lower_bound(num + 1, num + 1 + nn, l[i]) - num;
    for(int i = 1; i <= n; i++) r[i] = lower_bound(num + 1, num + 1 + nn, r[i]) - num;
}

int f[N][N];

int calc(int j) {
    if(num[j + 1] - num[j] <= m) return f[j][num[j + 1] - num[j]];
    return lagrange(m, f[j], num[j + 1] - num[j]);
}

signed main() {

    cin >> n;
    for(int i = 1; i <= n; i++) cin >> l[i] >> r[i];
    for(int i = 1; i <= n; i++) ++r[i];

    lisanhua();

    m = n + 5;

    F[0] = iF[0] = 1;
    for(int i = 1; i <= m; i++) F[i] = F[i - 1] * i % MOD;
    get_inv(m, iF, F);

    for(int k = 1; k <= min(num[1], m); k++) {
        f[0][k] = 1;
    }

    for(int i = 1; i <= n; i++) {
        int sum = 0;
        for(int j = 0; j < l[i]; j++) {
            sum = (sum + calc(j)) % MOD;
        }
        for(int j = l[i]; j < r[i]; j++) {
            int tmp = calc(j);
            for(int k = 1; k <= min(num[j + 1] - num[j], m); k++) {
                f[j][k] = (f[j][k] + f[j][k - 1] + sum) % MOD;
            }
            sum = (sum + tmp) % MOD;
        }
    }

    int ans = 0;
    for(int j = 0; j < nn; j++) {
        (ans += calc(j)) %= MOD;
    }

    cout << (ans + MOD - 1) % MOD << endl;

    return 0;
}