题目大意
我们讲一个悲伤的故事。 从前有一个贫穷的樵夫在河边砍柴。 这时候河里出现了一个水神,夺过了他的斧头,说: “这把斧头,是不是你的?” 樵夫一看:“是啊是啊!” 水神把斧头扔在一边,又拿起一个东西问: “这把斧头,是不是你的?” 樵夫看不清楚,但又怕真的是自己的斧头,只好又答:“是啊是啊!” 水神又把手上的东西扔在一边,拿起第三个东西问: “这把斧头,是不是你的?” 樵夫还是看不清楚,但是他觉得再这样下去他就没法砍柴了。 于是他又一次答:“是啊是啊!真的是!” 水神看着他,哈哈大笑道: “你看看你现在的样子,真是丑陋!” 之后就消失了。 樵夫觉得很坑爹,他今天不仅没有砍到柴,还丢了一把斧头给那个水神。 于是他准备回家换一把斧头。 回家之后他才发现真正坑爹的事情才刚开始。 水神拿着的的确是他的斧头。 但是不一定是他拿出去的那把,还有可能是水神不知道怎么偷偷从他家里拿走的。 换句话说,水神可能拿走了他的一把,两把或者三把斧头。 樵夫觉得今天真是倒霉透了,但不管怎么样日子还得过。 他想统计他的损失。 樵夫的每一把斧头都有一个价值,不同斧头的价值不同。总损失就是丢掉的斧头价值和。 他想对于每个可能的总损失,计算有几种可能的方案。 注意:如果水神拿走了两把斧头a和b,(a,b)和(b,a)视为一种方案。拿走三把斧头时,(a,b,c),(b,c,a),(c,a,b),(c,b,a),(b,a,c),(a,c,b)视为一种方案。
随笔瞎写
考虑一个生成函数 $$ A = a_0 + a_1x + a_2x^2 + a_3x^3 + \cdots + a_nx^n $$ 把系数看成方案数, 把指数看成价值
两个多项式的卷积为 $$ C = A \times B = \sum_{i, j} a_i \cdot b_j $$ 看一个简单的例子 $$ A = 2x + x^2 \\B = 3x + 2x^2 \\C = 6x^2 + 4x^3 + 3x^3 + 2x^4 = 6x^2 + 7x^3+ 2x^4 $$ 显然, 卷积后的多项式可以代表从$A, B$ 里各选出一个物品, 然后总价值达到某一个数的方案数
这个题中, 我们要从$A$里选出两个物品, 就可以自己对自己进行卷积 $$ A = 2x + x^2 \\A \times A = 4x^2 + 4x^3 + x^4 $$ 但是存在一个物品被选中两次的情况
我们设$B$为每个物品取两次的情况 $$ B = 2x^2 + x^4 $$ 显然, 最终答案为$A \times A - B $ .
同理我们考虑从$A$中选3个物品, $$ A \times A \times A = $$ 发现可能会出现一个物品被选中多次的情况
枚举一下哪个物品被多选了, 设为$x$
发现可能会出现$(x, x, x), (x, x, y), (y, x, x), (x, y, x)$ 这几种情况
我们设$D = A \times B$
发现可以得到$(x, x, x), (y, x, x)$ 这两种情况
我们可以把结果乘三, 但发现$(x, x, x) $减多了, 那就再加回来
设$C$为选三个物品的情况, 再加回来两次就好了
最终答案是三种情况加起来, 又因为不同顺序算一种, 所以要除一下 $A + \frac{A \times A - B}{2!} + \frac{A \times A \times A - A \times B \cdot 3 + 2 \cdot C }{3!}$
答案就算出来了
用FFT或者NTT加速
复杂度是$O($不会分析但好像能过把$)$
SB代码
/*
Date&Time: 2:37pm 3/16/2020]
*/
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 3e6 + 10;
const int mod = 998244353;
const int g = 3;
int omega[N], omegaInv[N], rev[N], ta[N], tta[N], tab[N], a[N], b[N], c[N];
int n;
inline int fp(int a, int n, int mod)
{
int res = 1;
while (n)
{
if (n & 1) res = (res * a) % mod;
a = (a * a) % mod;
n >>= 1;
}
return res % mod;
}
inline int inv(int x)
{
return fp(x, mod - 2, mod);
}
inline void init(const int n)
{
int k = 0;
while ((1 << k) < n) k++;
int x = fp(g, (mod - 1) / n, mod);
omega[0] = omegaInv[0] = 1;
for (int i = 1; i < n; i++)
{
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
omega[i] = (omega[i - 1] * x) % mod;
omegaInv[i] = inv(omega[i]);
}
}
inline void ntt(int *a, const int n, int *omega)
{
for (int i = 0; i < n; i++)
{
if (i > rev[i]) continue;
swap(a[i], a[rev[i]]);
}
for (int l = 2; l <= n; l *= 2)
{
int m = l / 2;
for (int *p = a; p != a + n; p += l)
{
for (int i = 0; i < m; i++)
{
int t = (omega[n / l * i] * p[m + i]) % mod;
p[m + i] = (p[i] - t + mod) % mod;
p[i] = (p[i] + t) % mod;
}
}
}
}
inline void dft(int *a, const int n)
{
ntt(a, n, omega);
}
inline void idft(int *a, const int n)
{
ntt(a, n, omegaInv);
int x = inv(n);
for (int i = 0; i < n; i++)
a[i] = (a[i] * x) % mod;
}
inline void poly_multiply(int *a1, int n1, int *a2, int n2)
{
int n = 1;
while (n < n1 + n2 + 2) n *= 2;
init(n);
dft(a1, n), dft(a2, n);
for (int i = 0; i < n; i++)
a1[i] = (a1[i] * a2[i]) % mod;
idft(a1, n);
idft(a2, n);
}
inline void poly_squred(int *a, int n)
{
int n1 = 1;
while (n1 < n * 2 + 2) n1 *= 2;
init(n1);
dft(a, n1);
for (int i = 0; i < n1; i++)
a[i] = (a[i] * a[i]) % mod;
idft(a, n1);
}
inline void poly_plus(int *a1, int n1 , int *a2, int n2, bool fl)
{
int n = max(n1, n2);
for (int i = 0; i <= n; i++)
a1[i] = fl ? (a1[i] + a2[i]) % mod : (a1[i] - a2[i] + mod) % mod;
}
inline void poly_devide(int *a, int n, int v)
{
int x = inv(v);
for (int i = 0; i <= n; i++)
a[i] = (a[i] * x) % mod;
}
inline void poly_times(int *a, int n, int v)
{
for (int i = 0; i <= n; i++)
a[i] = (a[i] * v) % mod;
}
int ans[N];
inline void print_poly(int* a, int n)
{
for (int i = 0; i <= n; i++)
{
cout << a[i] << ' ';
}
cout << endl;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cin >> n;
int siz = 0;
for (int i = 1; i <= n; i++)
{
int x;
cin >> x;
siz = max(siz, x * 3);
a[x]++, b[x * 2]++, c[x * 3]++;
ta[x]++, tta[x]++, tab[x]++;
}
poly_squred(ta, siz);
poly_multiply(tta, siz, ta, siz);
poly_multiply(tab, siz, b, siz);
poly_times(tab, siz, 3);
poly_times(c, siz, 2);
poly_plus(ans, siz, a, siz, 1);
poly_plus(ta, siz, b, siz, 0);
poly_devide(ta, siz, 2);
poly_plus(ans, siz, ta, siz, 1);
poly_plus(tta, siz, tab, siz, 0);
poly_plus(tta, siz, c, siz, 1);
poly_devide(tta, siz, 6);
poly_plus(ans, siz, tta, siz, 1);
for (int i = 0; i <= siz; i++)
{
if (ans[i])
cout << i << ' ' << ans[i] << endl;
}
return 0;
}
/*
AC
*/