DoorKickers的博客

博客

[AH2017/HNOI2017] 礼物

2020-04-19 12:22:55 By DoorKickers

首先, 两个手环各自增加非负整数亮度, 等价于只给其中一个手环增加一个整数亮度 $c$ ($c$ 可以为负).

对于一个固定的旋转, 答案应该是 $$ \sum_{i = 1}^{n}(x_i - y_i + c)^2 $$ 展开 $$ \sum_{i = 1}^{n}(x_i^2 + y_i^2 - 2x_iy_i + 2x_ic - 2y_ic + c^2) \\= \sum_{i = 1}^{n}(x_i^2 + y_i^2 + c^2) + 2c\sum_{i = 1}^{n}{(x_i - y_i)} - 2\sum_{i = 1}^{n}x_iy_i $$ 前两项与旋转无关 (其中含 $c$ 的部分是关于 $c$ 的二次函数, 可以单独取最小值), 所以我们只需要在所有旋转中最大化最后一项. $$ \sum_{i = 1}^{n}x_iy_i $$ 设$x_0 = y_0 = 0$

考虑将 $x$ 翻转, 然后左移一位, 则原式等于 $$ \sum_{i = 0}^{n}{x_{n - i}y_i} $$ 这是一个卷积.

考虑将x 破环成链(主要是卷积的性质)

假设

1 2 3 1 2 3 0 :x

0 4 5 6 0 0 0 :y

设结果为$z(x)$ $$ z_i = \sum_{k = 0}^{i}x_{i - k}y_k $$ 卷积之后, $z_n, \dots, z_{2n - 1}$ 恰好对应 $n$ 种旋转下的 $\sum x_iy_i$, 在其中取 max 再代入上面的式子就是答案. 代码:

#include <bits/stdc++.h>

using namespace std;

// From here on every `int` in this file is really a 64-bit `long long`.
// This is also why the entry point is spelled `signed main()`.
#define int long long

// Base from which init() derives the roots of unity
// (presumably a primitive root of `mod` - not verified in this file).
const int g = 3;
// Capacity of each global work array below.
const int N = 1e6 + 10;
// Modulus used for all NTT arithmetic.
const int mod = 1004535809;

// omega / omegaInv: power tables filled by init() and consumed by ntt().
// a / b: input sequences, also used in place as transform buffers.
// rev: index permutation filled by init() and applied at the start of ntt().
int omega[N], omegaInv[N], a[N], b[N], rev[N];

/**
 * Fast modular exponentiation by repeated squaring.
 *
 * @param a   base; may be negative or larger than `mod`, it is normalised
 *            into [0, mod) before any multiplication takes place
 * @param n   exponent; values <= 0 are treated as 0 (negative powers are not
 *            supported here - use inv() for modular inverses)
 * @param mod modulus, expected to be positive
 * @return    a^n reduced into [0, mod); in particular 0 when mod == 1
 */
inline int fp(int a, int n, int mod)
{
    // Normalise the base first so the very first product cannot overflow
    // and the result is never negative.
    a %= mod;
    if (a < 0) a += mod;
    // `1 % mod` instead of `1` so that a modulus of 1 yields 0 even for n == 0.
    int res = 1 % mod;
    // `n > 0` instead of `n`: a negative exponent would otherwise never reach
    // zero under arithmetic right shift and loop forever.
    while (n > 0)
    {
        if (n & 1) res = (res * a) % mod;
        a = (a * a) % mod;
        n >>= 1;
    }
    return res;
}

/**
 * Modular inverse of x with respect to the global `mod`.
 *
 * Computed as x^(mod - 2); this relies on `mod` being prime and x not being
 * a multiple of `mod` (assumption - not checked here).
 */
inline int inv(int x)
{
    const int exponent = mod - 2;
    return fp(x, exponent, mod);
}

/**
 * Fill the global tables used by an n-point transform.
 *
 * After the call, for every i in [0, n):
 *   - rev[i]      is i with its lowest k bits reversed, where k is the
 *                 smallest value with 2^k >= n;
 *   - omega[i]    is x^i  (mod `mod`), with x = g^((mod - 1) / n);
 *   - omegaInv[i] is x^-i (mod `mod`).
 *
 * @param n transform length; expected to be a power of two that divides
 *          mod - 1 and does not exceed the capacity of the global arrays
 *          (not checked here).
 */
inline void init(const int n)
{
    int k = 0;
    while ((1 << k) < n) k++;
    const int x = fp(g, (mod - 1) / n, mod);
    // Invert the root once and build the inverse table as a running product,
    // instead of paying one modular exponentiation per table entry.
    const int xInv = inv(x);
    rev[0] = 0;
    omega[0] = omegaInv[0] = 1;
    for (int i = 1; i < n; i++)
    {
        // Reuse the already-reversed value of i / 2 and move i's lowest bit
        // to the top position.
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
        omega[i] = (omega[i - 1] * x) % mod;
        omegaInv[i] = (omegaInv[i - 1] * xInv) % mod;
    }
}

/**
 * In-place iterative butterfly transform of a[0..n) modulo the global `mod`.
 *
 * @param a     buffer of n values, overwritten with the transform
 * @param n     transform length; the global `rev` table must already have
 *              been prepared for this length
 * @param omega table of root powers to use (the forward or the inverse one)
 */
inline void ntt(int *a, const int n, int *omega)
{
    // Reorder the input according to rev[]; each pair is exchanged once.
    for (int idx = 0; idx < n; ++idx)
    {
        const int partner = rev[idx];
        if (idx <= partner)
            swap(a[idx], a[partner]);
    }

    // Combine blocks of length `len` from two halves of length `half`.
    for (int len = 2; len <= n; len <<= 1)
    {
        const int half = len >> 1;
        const int stride = n / len;
        for (int start = 0; start != n; start += len)
        {
            int *lo = a + start;
            int *hi = lo + half;
            for (int j = 0; j < half; ++j)
            {
                const int twiddled = (omega[stride * j] * hi[j]) % mod;
                const int base = lo[j];
                // Adding `mod` keeps the difference non-negative.
                hi[j] = (base - twiddled + mod) % mod;
                lo[j] = (base + twiddled) % mod;
            }
        }
    }
}

/**
 * Forward transform of a[0..n) in place, using the global `omega` table.
 */
inline void dft(int *a, const int n)
{
    int *const forwardRoots = omega;
    ntt(a, n, forwardRoots);
}

/**
 * Inverse transform of a[0..n) in place: runs the butterfly with the
 * `omegaInv` table and then multiplies every entry by the inverse of n.
 */
inline void idft(int *a, const int n)
{
    ntt(a, n, omegaInv);

    const int scale = inv(n);
    int pos = 0;
    while (pos < n)
    {
        a[pos] = (a[pos] * scale) % mod;
        ++pos;
    }
}

/**
 * Multiply two polynomials modulo the global `mod`; the product is written
 * into a1.
 *
 * Both buffers are transformed in place with a length that is the smallest
 * power of two not below n1 + n2 + 2, so they must be at least that long and
 * zero beyond their meaningful entries. a2 is transformed back at the end,
 * i.e. it is restored rather than left in the transformed domain.
 *
 * @param a1 first factor on entry, product on exit
 * @param n1 size bound for a1 used to choose the transform length
 * @param a2 second factor
 * @param n2 size bound for a2 used to choose the transform length
 */
inline void multiply(int *a1, const int n1, int *a2, const int n2)
{
    const int needed = n1 + n2 + 2;
    int size = 1;
    while (size < needed)
        size <<= 1;

    init(size);
    dft(a1, size);
    dft(a2, size);

    // Pointwise product in the transformed domain.
    for (int k = 0; k < size; ++k)
    {
        a1[k] *= a2[k];
        a1[k] %= mod;
    }

    idft(a1, size);
    idft(a2, size);
}

/**
 * Square a polynomial in place modulo the global `mod`.
 *
 * The transform length is the smallest power of two not below 2 * n1 + 2;
 * the buffer must be at least that long and zero beyond its meaningful
 * entries.
 */
inline void squared(int *a, int n1)
{
    const int needed = n1 * 2 + 2;
    int size = 1;
    while (size < needed)
        size <<= 1;

    init(size);
    dft(a, size);

    // Pointwise square in the transformed domain.
    for (int k = 0; k < size; ++k)
    {
        const int value = a[k];
        a[k] = (value * value) % mod;
    }

    idft(a, size);
}

/**
 * Debug helper: write a[0], ..., a[n] (n inclusive) to standard output,
 * each followed by a single space, then end the line and flush.
 */
inline void poly_print(int *a, int n)
{
    int idx = 0;
    while (idx <= n)
    {
        cout << a[idx] << ' ';
        ++idx;
    }
    cout << endl;
}

/**
 * Debug helper for checking double -> integer conversions.
 *
 * Reads one floating-point number from standard input, prints its value
 * after implicit conversion to an integer (prefixed with "fk "), then prints
 * its floor twice, and finally terminates the program with exit code 0.
 */
inline void test()
{
    double value;
    cin >> value;

    const int truncated = value;
    cout << "fk " << truncated << endl;

    const int floored = floor(value);
    cout << floored << endl;
    cout << floored << endl;

    exit(0);
}


/**
 * Reads n and m, then the sequences x (n values) and y (n values), and prints
 * the minimum over all rotations and all integer offsets c of
 *     sum_i (x_i - y_i + c)^2.
 *
 * The sum expands to
 *     (sum x_i^2 + sum y_i^2) + (n*c^2 + 2*c*(X - Y)) - 2 * sum x_i*y_i,
 * where X and Y are the plain sums. The first part is fixed, the second
 * depends only on c, and the third depends only on the rotation, so the two
 * are optimised independently.
 *
 * Declared `signed main` because `int` is redefined as `long long` in this
 * file.
 *
 * NOTE(review): the transform length is the smallest power of two >= 4n and
 * must fit into the global arrays; the cross term must also stay below `mod`
 * for the modular convolution to equal the true value - confirm against the
 * input limits.
 */
signed main()
{
    int n, m;
    // m (the value bound from the input) is read but not needed by the algorithm.
    // Bail out on malformed input or n <= 0, which would divide by zero below.
    if (!(cin >> n >> m) || n <= 0)
        return 0;

    int X = 0;   // sum of x
    int Y = 0;   // sum of y
    int XX = 0;  // sum of x^2
    int YY = 0;  // sum of y^2

    for (int i = 0; i < n; i++)
    {
        cin >> a[i];
        X += a[i];
        // Exact square: this sum must not be reduced modulo the NTT prime.
        XX += a[i] * a[i];
        // Duplicate x so that every rotation appears as a contiguous window.
        a[i + n] = a[i];
    }
    // Reversing turns the sliding dot product into an ordinary convolution.
    reverse(a, a + n * 2);

    // y is stored 1-based; b[0] stays 0.
    for (int i = 1; i <= n; i++)
    {
        cin >> b[i];
        Y += b[i];
        YY += b[i] * b[i];
    }

    // Minimise n*c^2 + 2*c*(X - Y) over integers c. The real optimum is
    // (Y - X) / n; integer division truncates toward zero, so the best
    // integer is the truncated quotient or one of its two neighbours.
    const int diff = X - Y;
    const int c0 = (Y - X) / n;
    int offsetCost = 2 * c0 * diff + n * c0 * c0;
    for (int c = c0 - 1; c <= c0 + 1; c += 2)
        offsetCost = min(offsetCost, 2 * c * diff + n * c * c);

    // Convolve the doubled, reversed x with y.
    const int len = n * 2 - 1;
    multiply(a, len, b, len);

    // Coefficients n .. 2n-1 hold the dot product for each of the n rotations.
    int bestCross = 0;
    for (int i = n; i < 2 * n; i++)
        bestCross = max(bestCross, a[i]);

    cout << XX + YY + offsetCost - 2 * bestCross << '\n';
    return 0;
}

评论

暂无评论

发表评论

可以用@mike来提到mike这个用户,mike会被高亮显示。如果你真的想打“@”这个字符,请用“@@”。