BEN2のブログ

たまに書いています。

ABC 156 D - Bouquet

ABC 156 D - Bouquet に関連する、逆元、二項係数についてのメモを残します。

参考

mod の世界における "逆数"

はじめに、mod とは何なのか、については、

この動画で解説されています。

mod の世界で割り算することを考えたい。
実数の世界では、 a \div b を計算するとき、\displaystyle a \times \frac{1}{b} という掛け算に変えることができる。
mod の世界での割り算も、mod の世界における "逆数" を用意して掛け算に変えたい。

1 次方程式  ax \equiv b \, ({\rm{mod}} \, p) において、
 aa^{-1} \equiv 1 \, ({\rm{mod}} \, p) を満たす  a^{-1} (これを 逆元 という) が存在するならば、
 a^{-1} \cdot ax \equiv a^{-1} \cdot b \, ({\rm{mod}} \, p) から  x \equiv a^{-1}b \, ({\rm{mod}} \, p) となる。
このように、 a で割ることは、 a^{-1} を掛けることに変えることができる。
では、 a^{-1} はどう求めるのか。

ここで、フェルマーの小定理を利用する。

フェルマーの小定理 :
素数  p、任意の整数  a に対して、 a^{p} \equiv a \, ({\rm{mod}} \, p) が成り立つ。
とくに、 a p で割り切れないとき、 a^{p-1} \equiv 1 \, ({\rm{mod}} \, p) が成り立つ。

これより  a \cdot a^{p-2} \equiv 1 \, ({\rm{mod}} \, p) となり、 a \, ({\rm{mod}} \, p) の逆元が  a^{p-2} \, ({\rm{mod}} \, p) であることがわかる。
つまり  {\rm{mod}} \, p の世界において、 a で割ることは、 a^{p-2} を掛けることに変えることができる。
整数  a素数  p で割り切れない、という制約を忘れず。

 {}_n {\rm{C}} _k \, ({\rm{mod}} \, p)

逆元を利用して  {}_n {\rm{C}} _k \, ({\rm{mod}} \, p) を計算する実装を考える。

\displaystyle {}_n {\rm{C}} _k = \frac{n !}{k ! \, (n - k)!} = n ! \cdot (k !)^{-1} \cdot ((n - k)!)^{-1}

であるから、前処理として階乗、階乗の逆元をメモしておけば、 {}_n {\rm{C}} _k \, ({\rm{mod}} \, p) O(1) で計算できる。
では、 n = 1, \, 2, \, 3, \, \cdots の順にメモしていく。
階乗のメモの説明は省略。
階乗の逆元は、 (n !)^{-1} = ((n - 1)!)^{-1} \cdot n^{-1} の関係から順に求まるが、 n の逆元の計算が必要である。
 {\rm{mod}} \, p における  n の逆元  n^{p-2} は、繰り返し自乗法により  O(\log p) で計算できるので、都度計算することにする。
ちなみに、 n の逆元を拡張 Euclid の互除法により  O(1) で計算できるらしく、

この記事で解説されています。

以下、ABC 034 C - 経路 の提出コード。
問題概要 :  {}_{W+H-2} {\rm{C}} _{W-1} \, ({\rm{mod}} \, 10^{9} + 7) を出力  ( 2 \leq W, H \leq 10^{5} )

#include <bits/stdc++.h>
using namespace std;

// 階乗、階乗の逆元のメモ
vector<long long> fac, fac_inv;

// a^b mod p を返す (繰り返し自乗法)
long long modpow(long long a, long long b, long long p) {
  if (b == 0) return 1;
  long long res = modpow(a, b / 2, p);
  if (b % 2 == 0) res = (res * res) % p;
  else res = (((res * res) % p) * a) % p;
  return res;
}

// mod p における a の逆元 a^{p-2} mod p を返す
long long inverse(long long a, long long p) {
  return modpow(a, p-2, p);
}

// nCk (mod p) の前処理 O(n log p)
void comInit(int n, int p) {
  fac.resize(n+1);
  fac_inv.resize(n+1);
  fac.at(0) = fac.at(1) = 1;
  fac_inv.at(0) = fac_inv.at(1) = 1;
  for (int i = 2; i <= n; i++) {
    fac.at(i) = fac.at(i-1) * i % p;
    fac_inv.at(i) = fac_inv.at(i-1) * inverse(i, p) % p;
  }
}

// nCk mod p を返す O(1)
long long com(int n, int k, int p) {
  if (n < k || n < 0 || k < 0) return 0;
  return fac.at(n) * (fac_inv.at(k) * fac_inv.at(n-k) % p) % p;
}

int main() {
  const int MOD = 1000000007;
  int W, H; cin >> W >> H;
  comInit(W+H, MOD);  // 前処理
  cout << com(W+H-2, W-1, MOD) << endl;  // W+H-2_C_W-1 (mod MOD)
}

 n \leq 10^{9}, \, k \leq 2 \times 10^{5}

出力すべき値は、 2^{n} - 1 - {}_n {\rm{C}} _a - {}_n {\rm{C}} _b \, ({\rm{mod}} \, 10^{9} + 7)

先ほどのコードは前処理の計算量が  O( n \log p) であるため、この問題のように  n が最大で  10^{9} となる場合には使えない。
しかし、 n は固定であり、 k は最大でも  2 \times 10^{5} と小さいため、

\displaystyle {}_n {\rm{C}} _k = \frac{n}{1} \cdot \frac{n-1}{2} \cdot \frac{n-2}{3} \cdots \frac{n-k+1}{k}

とみることで、

\displaystyle {}_n {\rm{C}} _2 = {}_n {\rm{C}} _1 \cdot \frac{n-1}{2}, \, {}_n {\rm{C}} _3 = {}_n {\rm{C}} _2 \cdot \frac{n-2}{3}, \, \cdots

のようにして、 {}_n {\rm{C}} _{1}, \, {}_n {\rm{C}} _{2}, \cdots , \, {}_n {\rm{C}} _{k} の順にメモしていけばよい。

#include <bits/stdc++.h>
using namespace std;

// com[k] := nCk の値をメモ
vector<long long> com;

// a^b mod p を返す
long long modpow(long long a, long long b, long long p) {
  if (b == 0) return 1;
  long long res = modpow(a, b / 2, p);
  if (b % 2 == 0) res = (res * res) % p;
  else res = (((res * res) % p) * a) % p;
  return res;
}

// mod p における a の逆元 a^{p-2} mod p を返す
long long inverse(long long a, long long p) {
  return modpow(a, p-2, p);
}

// nCk (mod p) の前処理 O(k log p)
void comInit(long long n, long long k, long long p) {
  com.resize(k+1);
  com.at(0) = 1;
  long long tmp = 1;
  for (int i = 1; i <= k; i++) {
    tmp = ((tmp * (n-i+1) % p) * inverse(i, p)) % p;
    com.at(i) = tmp;
  }
}

// val を p で割った余り (>= 0) を返す
long long mod(long long val, long long p) {
  long long res = val % p;
  if (res < 0) res += p;
  return res;
}

int main() {
  const int MOD = 1000000007;
  int n, a, b; cin >> n >> a >> b;
  comInit(n, 200000, MOD);  // 前処理
  long long ans = modpow(2, n, MOD);
  cout << mod(ans - 1 - com.at(a) - com.at(b), MOD) << endl;
}