مدونة fast_photon

المدونات

一种不太快速计算多个少项式集合幂级数异或卷积的方法

2025-12-18 16:23:45 By fast_photon

首先求 ${\prod}_\oplus(a_ix^{\emptyset}+b_ix^{S_i})$,其中二项式个数为 $n$,$S_i\subseteq \{1,2,\cdots,m\}$。

考虑 FWT,要将 $\operatorname{FWT}(a_ix^{\emptyset}+b_ix^{S_i})$ 对位乘得到 $g(x)$ 之后 IFWT。

$\begin{aligned} & [x^T]g(x) \\ = & \prod [x^T]\operatorname{FWT}(a_i x^{\emptyset}+b_ix^{S_i}) \\ = & \prod (a_i+(-1)^{|S_i\cap T|}b_i) \\ = & \exp(\sum \ln (a_i+(-1)^{|S_i\cap T|}b_i)) \\ = & \exp(\sum [x^T] \operatorname{FWT}(\dfrac {\ln(a_i+b_i)+\ln(a_i-b_i)}{2} x^{\emptyset} + \dfrac {\ln(a_i+b_i)-\ln(a_i-b_i)}{2} x^{S_i})) \\ = & \exp([x^T]\operatorname{FWT}(\sum(\dfrac {\ln(a_i+b_i)+\ln(a_i-b_i)}{2} x^{\emptyset} + \dfrac {\ln(a_i+b_i)-\ln(a_i-b_i)}{2} x^{S_i})))\end{aligned}$

只要对数能求就能直接做。

如果有较多项呢?对于 $\sum a_{i,j}x^{S_{i,j}}$,对于每个 $i$,找一组 $F_2^m$ 中的基底,它们张成一个秩为 $2^{k_i}$ 的子空间。把 $S_{i,j}$ 变换到 $S'_{i,j}$,在子空间里做 FWT,取对数,做 IFWT,再把点值加到原空间的对应位置(具体来说,对于一个子空间的向量 $v$,取出所有 $v_j=1$ 的 $S_{i,j}$,将其 $\operatorname{xor}$ 起来),做 FWT,取指数,做 IFWT。

复杂度 $O(m2^m+\sum_i k_i2^{k_i}+L(\sum_i 2^{k_i}))$,其中 $L(x)$ 是计算 $x$ 个数的对数的复杂度,$k_i$ 是所有 $S_{i,j}$ 张成子空间的大小。

不能取对数的时候,考虑额外记录 $0$ 的指数。对于 $x\neq 0$,记 $\log x=(\ln x,0)$,记 $\ln 0=(0,1)$ 即可。做完原空间上的 FWT 后,若某个向量是 $(x,0)$,将其替换为数值 $\exp(x)$,否则替换为数值 $0$,再 IFWT 即可。

以下是示例代码。由于没有高消,其累加复杂度是 $O(2^{m_i})$,$m_i$ 是第 $i$ 个少项式的项数。可以在全局常量里调 $m$ 和模数,也可以改成输入模数。使用了较为朴素的方法求原根,在示例代码中这可能成为复杂度瓶颈。

#include<iostream>
#include<cstdio>
#include<fstream>
#include<queue>
#include<algorithm>
#include<cstring>
#include<random>
#include<map>
#include<cassert>
#include<stack>
#define fopen(x, y) freopen(x".in", "r", stdin); freopen(y".out", "w", stdout);
#define int long long
#define pii pair<int, int>
#define fi first
#define se second
#ifdef int
#define inf 0x3f3f3f3f3f3f3f3fll
#else
#define inf 0x3f3f3f3f
#endif

#define maxB 65536
#define maxm 1048576 

using namespace std;

//int m, M, mod, g, b, B, Mod;
const int m = 20, M = 1ll << m, mod = 998244353, g = 3, b = 16, B = 65536, Mod = 998244352ll * M;

inline int qpow(int x, int y) { int xum = 1; while(y) { if(y & 1) (xum *= x) %= mod; (x *= x) %= mod; y >>= 1; } return xum; }

int n, a, s[maxm], x[105], w[105], hfi[maxm], hse[maxm], ffi[maxm], fse[maxm];
int l3, l_1;

//void init_0();
void init_1();
int gxp(int x);
int gln(int x);

void fwt(int *f, int m, const int mod) {
    const int M = 1ll << m;
    for(int j = 0; j < m; j++) {
        int L = 1ll << (j + 1), bL = 1ll << j;
        for(int t = 0; t < M; t += L) {
            int *x = &f[t], *y = x + bL;
            for(int i = 0; i < bL; i++, x++, y++) {
                (*x += *y) %= mod;
                *y = (*x - (*y << 1)) % mod;
            }
        }
    }
}

void work() {
//    mod = 998244353; init_0();

    init_1();

    cin >> n;

    for(int i = 1; i <= n; i++) {
        int k;
        cin >> k;
        const int K = 1ll << k;
        for(int j = 0; j < K; j++) hfi[j] = hse[j] = 0;
        for(int j = 0; j < k; j++) {
            cin >> x[j] >> w[j];
            hfi[1ll << j] = w[j];
        }
        s[0] = 0;
        for(int j = 1; j < K; j++) {
            s[j] = s[j ^ (j & -j)] ^ x[__builtin_ctzll(j & -j)];
        }
        fwt(hfi, k, mod);
        for(int j = 0; j < K; j++) {
            (hfi[j] += mod) %= mod;
            if(hfi[j] == 0) hse[j] = 1ll << (m - k);
            else hfi[j] = (gln(hfi[j]) << (m - k)) % Mod;
        }
        fwt(hfi, k, Mod);
        fwt(hse, k, Mod);
        for(int j = 0; j < K; j++) {
            (ffi[s[j]] += hfi[j]) %= Mod;
            (fse[s[j]] += hse[j]) %= Mod;
        }
    }

    fwt(ffi, m, Mod);
    fwt(fse, m, Mod);

    int sum = 0, iv = qpow(M, mod - 2);
    for(int i = 0; i < M; i++) {
        if(fse[i]) {
            ffi[i] = 0;
            continue;
        }
        (ffi[i] >>= m) %= (mod - 1);
        if(ffi[i] < 0) ffi[i] += mod - 1;
        ffi[i] = gxp(ffi[i]);
    }
    fwt(ffi, m, mod);
    for(int i = 0; i < M; i++) {
        (ffi[i] *= iv) %= mod;
        cout << (ffi[i] + mod) % mod << ' ';
    }
    cout << '\n';
}

signed main() {
    ios::sync_with_stdio(false); cin.tie(0);
    int _ = 1;
    while(_--) work();
}

/*
void init_0() {
    Mod = (mod - 1) * maxn;
    vector<int> t;
    int s = mod - 1;
    B = 1, b = 0;
    while(B * B < mod) B <<= 1, b++;
    for(int i = 2; i * i <= s; i++) {
        if(s % i == 0) {
            while(s % i == 0) s /= i;
            t.push_back(i);
        }
    }
    for(int i = 2; i < mod; i++) {
        int flg = 1;
        for(int d : t) {
            if(qpow(i, (mod - 1) / d) == 1) {
                flg = 0;
                break;
            }
        }
        if(flg) {
            g = i;
            break;
        }
    }
}
//*/

int gb, gk[maxB], gbk[maxB];
pii gt[maxB + 1];
int find(int x) {
    pii *p = lower_bound(gt, gt + B, make_pair(x, 0ll));
    if(p->fi == x) return p->se;
    return -1;
}
int gln(int x) {
    int cnt = 0, p = find(x);
    while(p == -1) {
        (x *= gb) %= mod;
        cnt++;
        p = find(x);
    }
    return ((p - cnt * B) % (mod - 1) + (mod - 1)) % (mod - 1);
}
void init_1() {
    gb = 1;
    gk[0] = 1;
    for(int i = 1; i < B; i++) gk[i] = gk[i - 1] * g % mod;
    gb = gk[B - 1] * g % mod;
    gbk[0] = 1;
    for(int i = 1; i < B; i++) gbk[i] = gbk[i - 1] * gb % mod;

    for(int i = 0; i < B; i++) gt[i] = make_pair(gk[i], i);
    sort(gt, gt + B); 
    gt[B] = make_pair(mod, -1);
}
int gxp(int x) {
    return gk[x & (B - 1)] * gbk[x >> b] % mod;
}
fast_photon Avatar

fast_photon