月度归档:2018年02月

k次幂前缀和的线性做法

给定正整数 \(x, k\) 和质数 \(p\),求
\[ \sum_{i=1}^{x} i^k \bmod p \]

这是一道很经典的题目,在杜瑜皓的《多项式及求和》中有详细讨论。比较暴力的做法是对于每个 \(i^k\) 都用快速幂算出其值,然后相加,时间复杂度为 \(O(x \log k) \),当 \(x\) 很大时,就超时了。然而利用拉格朗日插值法,这题最快可以做到 \(O(k)\) 的复杂度,其做法主要由以下几部分构成。

线性预处理逆元

求 1..n 之间每个数的逆元,如果都用费马小定理或者扩展欧几里得算,那么复杂度将会达到 \(O(n \log p)\) 或 \(O(n \log n)\)。利用一些递推式,可以线性地求出 1..n 中每个数的逆元,从而复杂度可以减少一个 log。常用的一个递推公式是

\[ i^{-1} = -\left\lfloor \frac{p}{i} \right\rfloor \cdot (p \bmod i)^{-1} \bmod p \]

这一公式的正确性也很好证明,等式右端可以写成

\[ -\frac{\left\lfloor \frac{p}{i} \right\rfloor}{p-\left\lfloor \frac{p}{i} \right\rfloor i} = \frac{1}{i} \]

利用此公式,可以在 \(O(n)\) 时间内求出 1..n 中每一个数的逆元。

线性筛求积性函数的值

注意到 \(i^k\) 是一个积性函数,这样,利用线性筛,我们只需要用快速幂计算素数 \(i\) 的值,其他点的函数值就可以直接得到。根据素数定理,前 \(n\) 个数中素数的个数约为 \(O(\frac{n}{\log n})\),从而利用快速幂计算出素数处的值,可以做到 \(O(\frac{n}{\log n}\log{k}\) ,再加上筛法本身 \(O(n)\) 的复杂度,我们可以在 \(O(n + \frac{n}{\log n}\log{k})\) 的时间复杂度内计算出 \(i^k\) 的前 \(n\) 项值。

拉格朗日插值法

注意到答案一定是关于 \(x\) 的 \(k+1\) 次多项式。而根据线性代数的知识,我们只要知道这个多项式的任意 \(k+2\) 个点处的值,我们就能确定这个多项式。拉格朗日插值公式就给出了这个多项式

\[ f(x) = \sum_{i=0}^{k+1} y_i l_i(x) \]

其中

\[ l_i(x) = \prod_{0 \leq m \leq k+1, m \neq i} \frac{x – x_m}{x_j – x_m} \]

具体到这道题上,插值公式就是

\[ f(n) = \sum_{i = 0}^{k+1} (-1)^{k+i+1} f(i) \frac{n(n-1) \cdots (n-i+1)(n-i-1) \cdots (n-d)}{(d-i)!i!} \]

其中 \(f(i)\) 就是 \(i^k\) 的前缀和,上式右端求和的每一项,分子都是 n 到 n-d 的前缀积乘后缀积,分母可以用递推线性求出,这样就可以在 \(O(k)\) 时间内解决整个问题。

#include <bits/stdc++.h>
using namespace std;

#define rep(i, n) for (int i=0; i<(n); i++)
#define Rep(i, n) for (int i=1; i<=(n); i++)
#define range(x) (x).begin(), (x).end()
typedef long long LL;
typedef unsigned long long ULL;

#define pow owahgrauhgso

ULL n, k, m;

ULL powmod(ULL b, ULL e) {
  ULL r = 1;
  while (e) {
    if (e & 1) r = r * b % m;
    b = b * b % m;
    e >>= 1;
  }
  return r;
}

const int MAXN = 1000006;
ULL inv[MAXN];

void init_inv() {
  inv[1] = 1;
  for (int i = 2; i <= k+1; i++) { 
    inv[i] = (m - m / i) * inv[m % i] % m; 
    assert(inv[i] * i % m == 1);
  }
}

ULL pow[MAXN];
ULL prime[MAXN], cnt;

void sieve() {
  pow[1] = 1;
  for (int i = 2; i <= k+1; i++) {
    if (!pow[i]) {
      pow[i] = powmod(i, k);
      prime[cnt++] = i;
      for (int j = 0; j < cnt && i*prime[j] <= k+1; j++) {
        pow[i * prime[j]] = pow[i] * pow[prime[j]] % m;
        if (i % prime[j] == 0) break;
      }
    }
  }
}

ULL sum[MAXN];
ULL ans[MAXN];

auto addmod = [](ULL a, ULL b) -> ULL {return (a+b)%m;};

ULL lagrange() {
  ULL p;
  p = 1;
  for (int i=0; i<=k+1; i++) {
    if (i) p = p * inv[i] % m;
    ans[i] = (k+1-i)&1 ? m-sum[i] : sum[i];
    ans[i] = ans[i] * p % m;
    p = p * (m + n - i) % m;
  }
  p = 1;
  for (int i=k+1; i>=0; i--) {
    if (k+1-i) p = p * inv[k+1-i] % m;
    ans[i] = ans[i] * p % m;
    p = p * (m + n - i) % m;
  }
  return accumulate(ans, ans+k+2, 0, addmod); 
}

 

int main() { cin >> n >> k >> m; init_inv(); sieve(); partial_sum(pow, pow+k+2, sum, addmod); cout << lagrange() << endl; return 0; }