题目大意
给定 p , q , α , β p,q,\alpha,\betap,q,α,β ,以及 f i , 0 = a i , f 0 , i = b i f_{i,0}=a_i,f_{0,i}=b_ifi,0=ai,f0,i=bi ,设 γ = p α + q β \gamma=p\alpha+q\betaγ=pα+qβ 。
f i , j = p f i − 1 , j + q f i , j − 1 + γ f_{i,j}=pf_{i-1,j}+qf_{i,j-1}+\gammafi,j=pfi−1,j+qfi,j−1+γ, 已知 h hh,求
∑ i = 0 n ∑ j = 0 n f i , j × h i ( n + 1 ) + j \sum_{i=0}^n\sum_{j=0}^nf_{i,j}\times h^{i(n+1)+j}i=0∑nj=0∑nfi,j×hi(n+1)+j
模数为998244353
题解
考虑将 f ff 分开三部分处理。
求 γ \gammaγ 的贡献
容易发现一个点 ( i , j ) (i,j)(i,j) 给 ( x , y ) (x,y)(x,y) 的贡献其实就是 p pp 的向下走的步数次方乘 q qq 的向右走的步数次方再乘上路径数。即
∑ x = 1 n ∑ y = 1 n ∑ i = 0 x − 1 ∑ j = 0 y − 1 γ ( i + j i ) p i q j h x ( n + 1 ) + y \sum_{x=1}^n\sum_{y=1}^n\sum_{i=0}^{x-1}\sum_{j=0}^{y-1}\gamma\begin{pmatrix} i+j\\i \end{pmatrix}p^iq^jh^{x(n+1)+y}x=1∑ny=1∑ni=0∑x−1j=0∑y−1γ(i+ji)piqjhx(n+1)+y
为了方便处理,故 i , j i,ji,j 枚举的是 x − i , y − j x-i,y-jx−i,y−j 。同时我们设 H = h n + 1 H=h^{n+1}H=hn+1 。接下来有个常用的小技巧,就是改变一下枚举顺序。同时,我们拆开组合数,即
γ ∑ i = 0 n − 1 ∑ j = 0 n − 1 ( i + j ) ! [ p i ∑ x = i + 1 n H x i ! ] [ q j ∑ y = j + 1 n h y j ! ] \gamma\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}(i+j)!\biggl[\frac{p^i\sum_{x=i+1}^nH^x}{i!}\biggl]\biggl[\frac{q^j\sum_{y=j+1}^nh^y}{j!}\biggl]γi=0∑n−1j=0∑n−1(i+j)![i!pi∑x=i+1nHx][j!qj∑y=j+1nhy]
用等比数列求和后,我们可以发现最后两项分别只含 i , j i,ji,j ,于是可以用NTT优化卷积实现。
具体来说就是枚举 i + j i+ji+j 求出 F FF 后,将 F i F_iFi 乘上 i ! ⋅ γ i!\cdot \gammai!⋅γ。
求 a aa 的贡献
与上式同理,即
∑ k = 1 n a k ∑ x = k n ∑ y = 1 n ( x − k + y − 1 x − k ) p x − k q y H x h y \sum_{k=1}^na_k\sum_{x=k}^n\sum_{y=1}^n\begin{pmatrix}x-k+y-1\\x-k\end{pmatrix} p^{x-k}q^yH^xh^yk=1∑nakx=k∑ny=1∑n(x−k+y−1x−k)px−kqyHxhy
y − 1 y-1y−1 的原因是第一步一定往右走。接着我们考虑枚举 T = x − k T=x-kT=x−k,即
∑ T = 0 n − 1 ∑ y = 1 n ( T + y − 1 T ) p T q y ∑ x = T + 1 n H x a x − T \sum_{T=0}^{n-1}\sum_{y=1}^n\begin{pmatrix}T+y-1\\T\end{pmatrix}p^Tq^y\sum_{x=T+1}^nH^xa_{x-T}T=0∑n−1y=1∑n(T+y−1T)pTqyx=T+1∑nHxax−T
设 A T = ∑ x = T + 1 n H x a x − T A_T=\sum_{x=T+1}^nH^xa_{x-T}AT=∑x=T+1nHxax−T,可以用卷积预处理。
接着就跟上面一样,将只含 T TT 和只含 y yy 的放一起,就可以 开卷了 用NTT优化矩阵乘法。
求 b bb 的贡献
与求 a aa 的贡献同理,不再赘述。
注意
- 上面的贡献指的是对答案的贡献
- 不要忘记加上 f i , 0 , f 0 , i f_{i,0},f_{0,i}fi,0,f0,i 的贡献
Code
#include <cstdio>
#include <cstring>
#include <iostream>
#define ll long long
using namespace std;
const ll P = 998244353;
const int N = 1e6 + 5;
int n, m, rev[N];
ll h, H, p, q, alpha, beta, C, ans;
ll a[N], b[N], fa[N], fb[N], jc[N], F[N];
template <typename T> inline void read(T &ss) {
ss = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar());
for (; isdigit(ch); ch = getchar()) ss = (ss << 1) + (ss << 3) + (ch ^ 48);
}
ll ksm(ll a, ll b) {
ll res = 1;
for (; b; b >>= 1, a = a * a % P)
if (b & 1) res = res * a % P;
return res;
}
ll calc(ll p, ll a, ll b) {
return (ksm(p, b + 1) - ksm(p, a) + P) % P * ksm(p - 1, P - 2) % P;
}
void NTT(ll *a, int type) {
for (int i = 0; i < m; ++i)
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int j = 1; j < m; j <<= 1) {
int tot = 0;
ll W = ksm(3, (P - 1) / (j << 1));
if (type == -1) W = ksm(W, P - 2);
for (int k = 0; k < m; k += (j << 1)) {
ll w = 1;
for (int i = 0; i < j; ++i) {
++tot;
ll ye = a[i + k], yo = a[i + j + k] * w % P;
a[i + k] = (ye + yo) % P;
a[i + j + k] = (ye - yo + P) % P;
w = w * W % P;
}
}
}
}
void work(ll a[], ll b[]) {
NTT(a, 1); NTT(b, 1);
for (int i = 0; i <= m; ++i)
F[i] = a[i] * b[i] % P;
NTT(F, -1);
for (int i = 0; i < m; ++i)
F[i] = F[i] * ksm(m, P - 2) % P;
}
int main() {
freopen("dream.in", "r", stdin);
freopen("dream.out", "w", stdout);
read(n), read(h), read(alpha), read(beta);
m = 2 * n;
for (int i = 1; i <= 30; ++i) {
if ((1 << i) > m) {
m = 1 << i;
break;
}
}
for (int i = 0; i < m; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (m >> 1) : 0);
ll x, y;
read(x), read(y); p = x * ksm(y, P - 2) % P;
read(x), read(y); q = x * ksm(y, P - 2) % P;
for (int i = 1; i <= n; ++i) read(a[i]);
for (int i = 1; i <= n; ++i) read(b[i]);
C = (p * alpha % P + q * beta % P) % P;
H = ksm(h, n + 1);
ans = 0;
jc[0] = 1;
for (int i = 1; i <= 2 *n; ++i)
jc[i] = jc[i - 1] * (ll)i % P;
for (int i = 1; i <= n; ++i)
ans = (ans + a[i] * ksm(H, i) % P) % P;
for (int i = 1; i <= n; ++i)
ans = (ans + b[i] * ksm(h, i) % P) % P;
ll sum = 1;
for (int i = 0; i < n; ++i, sum = sum * p % P)
fa[i] = sum * ksm(jc[i], P - 2) % P * calc(H, i + 1, n) % P;
sum = 1;
for (int i = 0; i < n; ++i, sum = sum * q % P)
fb[i] = sum * ksm(jc[i], P - 2) % P * calc(h, i + 1, n);
work(fa, fb);
for (int i = 0; i < 2 * n; ++i)
ans = (ans + F[i] * jc[i] % P * C % P) % P;
memset(fa, 0, sizeof(fa)); memset(fb, 0, sizeof(fb));
for (int i = 0; i <= n; ++i) fa[i] = a[n - i];
fb[0] = 1;
for (int i = 1; i <= n; ++i) fb[i] = fb[i - 1] * H % P;
work(fa, fb);
memset(fa, 0, sizeof(fa)); memset(fb, 0, sizeof(fb));
sum = 1;
for (int i = 0; i < n; ++i, sum = sum * p % P)
fa[i] = F[i + n] * sum % P * ksm(jc[i], P - 2) % P;
ll sum2 = q; sum = h;
for (int i = 1; i <= n; ++i, sum = sum * h % P, sum2 = sum2 * q % P)
fb[i] = sum * sum2 % P * ksm(jc[i - 1], P - 2) % P;
work(fa, fb);
for (int i = 1; i < 2 * n; ++i)
ans = (ans + F[i] * jc[i - 1] % P) % P;
memset(fa, 0, sizeof(fa)); memset(fb, 0, sizeof(fb));
for (int i = 0; i <= n; ++i) fa[i] = b[n - i];
fb[0] = 1;
for (int i = 1; i <= n; ++i) fb[i] = fb[i - 1] * h % P;
work(fa, fb);
memset(fa, 0, sizeof(fa)); memset(fb, 0, sizeof(fb));
sum = 1;
for (int i = 0; i < n; ++i, sum = sum * q % P)
fa[i] = F[i + n] * sum % P * ksm(jc[i], P - 2) % P;
sum2 = p; sum = H;
for (int i = 1; i <= n; ++i, sum = sum * H % P, sum2 = sum2 * p % P)
fb[i] = sum * sum2 % P * ksm(jc[i - 1], P - 2) % P;
work(fa, fb);
for (int i = 1; i < 2 * n; ++i)
ans = (ans + F[i] * jc[i - 1] % P) % P;
printf("%lld", ans);
return 0;
}