【题解&总结】gmoj 5740. 幻想世界

题目大意

给定 p , q , α , β p,q,\alpha,\betap,q,α,β ,以及 f i , 0 = a i , f 0 , i = b i f_{i,0}=a_i,f_{0,i}=b_ifi,0=ai,f0,i=bi ,设 γ = p α + q β \gamma=p\alpha+q\betaγ=pα+qβ
f i , j = p f i − 1 , j + q f i , j − 1 + γ f_{i,j}=pf_{i-1,j}+qf_{i,j-1}+\gammafi,j=pfi1,j+qfi,j1+γ, 已知 h hh,求
∑ i = 0 n ∑ j = 0 n f i , j × h i ( n + 1 ) + j \sum_{i=0}^n\sum_{j=0}^nf_{i,j}\times h^{i(n+1)+j}i=0nj=0nfi,j×hi(n+1)+j
模数为998244353

题解

考虑将 f ff 分开三部分处理。

γ \gammaγ 的贡献

容易发现一个点 ( i , j ) (i,j)(i,j)( x , y ) (x,y)(x,y) 的贡献其实就是 p pp 的向下走的步数次方乘 q qq 的向右走的步数次方再乘上路径数。即
∑ x = 1 n ∑ y = 1 n ∑ i = 0 x − 1 ∑ j = 0 y − 1 γ ( i + j i ) p i q j h x ( n + 1 ) + y \sum_{x=1}^n\sum_{y=1}^n\sum_{i=0}^{x-1}\sum_{j=0}^{y-1}\gamma\begin{pmatrix} i+j\\i \end{pmatrix}p^iq^jh^{x(n+1)+y}x=1ny=1ni=0x1j=0y1γ(i+ji)piqjhx(n+1)+y
为了方便处理,故 i , j i,ji,j 枚举的是 x − i , y − j x-i,y-jxi,yj 。同时我们设 H = h n + 1 H=h^{n+1}H=hn+1 。接下来有个常用的小技巧,就是改变一下枚举顺序。同时,我们拆开组合数,即
γ ∑ i = 0 n − 1 ∑ j = 0 n − 1 ( i + j ) ! [ p i ∑ x = i + 1 n H x i ! ] [ q j ∑ y = j + 1 n h y j ! ] \gamma\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}(i+j)!\biggl[\frac{p^i\sum_{x=i+1}^nH^x}{i!}\biggl]\biggl[\frac{q^j\sum_{y=j+1}^nh^y}{j!}\biggl]γi=0n1j=0n1(i+j)![i!pix=i+1nHx][j!qjy=j+1nhy]
用等比数列求和后,我们可以发现最后两项分别只含 i , j i,ji,j ,于是可以用NTT优化卷积实现。
具体来说就是枚举 i + j i+ji+j 求出 F FF 后,将 F i F_iFi 乘上 i ! ⋅ γ i!\cdot \gammai!γ

a aa 的贡献

与上式同理,即
∑ k = 1 n a k ∑ x = k n ∑ y = 1 n ( x − k + y − 1 x − k ) p x − k q y H x h y \sum_{k=1}^na_k\sum_{x=k}^n\sum_{y=1}^n\begin{pmatrix}x-k+y-1\\x-k\end{pmatrix} p^{x-k}q^yH^xh^yk=1nakx=kny=1n(xk+y1xk)pxkqyHxhy
y − 1 y-1y1 的原因是第一步一定往右走。接着我们考虑枚举 T = x − k T=x-kT=xk,即
∑ T = 0 n − 1 ∑ y = 1 n ( T + y − 1 T ) p T q y ∑ x = T + 1 n H x a x − T \sum_{T=0}^{n-1}\sum_{y=1}^n\begin{pmatrix}T+y-1\\T\end{pmatrix}p^Tq^y\sum_{x=T+1}^nH^xa_{x-T}T=0n1y=1n(T+y1T)pTqyx=T+1nHxaxT
A T = ∑ x = T + 1 n H x a x − T A_T=\sum_{x=T+1}^nH^xa_{x-T}AT=x=T+1nHxaxT,可以用卷积预处理。
接着就跟上面一样,将只含 T TT 和只含 y yy 的放一起,就可以 开卷了 用NTT优化矩阵乘法。

b bb 的贡献

与求 a aa 的贡献同理,不再赘述。

注意

  1. 上面的贡献指的是对答案的贡献
  2. 不要忘记加上 f i , 0 , f 0 , i f_{i,0},f_{0,i}fi,0,f0,i 的贡献

Code

#include <cstdio>
#include <cstring>
#include <iostream>
#define ll long long
using namespace std;
const ll P = 998244353;
const int N = 1e6 + 5;
int n, m, rev[N];
ll h, H, p, q, alpha, beta, C, ans;
ll a[N], b[N], fa[N], fb[N], jc[N], F[N];
template <typename T> inline void read(T &ss) {
	ss = 0;
	char ch = getchar();
	for (; !isdigit(ch); ch = getchar());
	for (; isdigit(ch); ch = getchar()) ss = (ss << 1) + (ss << 3) + (ch  ^ 48);
}
ll ksm(ll a, ll b) {
	ll res = 1;
	for (; b; b >>= 1, a = a * a % P)
		if (b & 1) res = res * a % P;
	return res;
}
ll calc(ll p, ll a, ll b) {
	return (ksm(p, b + 1) - ksm(p, a) + P) % P * ksm(p - 1, P - 2) % P;
}
void NTT(ll *a, int type) {
	for (int i = 0; i < m; ++i)
		if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (int j = 1; j < m; j <<= 1) {
		int tot = 0;
		ll W = ksm(3, (P - 1) / (j << 1));
		if (type == -1) W = ksm(W, P - 2);
		for (int k = 0; k < m; k += (j << 1)) {
			ll w = 1;
			for (int i = 0; i < j; ++i) {
				++tot;
				ll ye = a[i + k], yo = a[i + j + k] * w % P;
				a[i + k] = (ye + yo) % P;
				a[i + j + k] = (ye - yo + P) % P;
				w = w * W % P;
			}
		}
	}
}
void work(ll a[], ll b[]) {
	NTT(a, 1); NTT(b, 1);
	for (int i = 0; i <= m; ++i)
		F[i] = a[i] * b[i] % P;
	NTT(F, -1);
	for (int i = 0; i < m; ++i)
		F[i] = F[i] * ksm(m, P - 2) % P;
}
int main() {
	freopen("dream.in", "r", stdin);
	freopen("dream.out", "w", stdout);
	read(n), read(h), read(alpha), read(beta);
	m = 2 * n;
	for (int i = 1; i <= 30; ++i) {
		if ((1 << i) > m) {
			m = 1 << i;
			break;
		}
	}
	for (int i = 0; i < m; ++i)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (m >> 1) : 0);
	ll x, y;
	read(x), read(y); p = x * ksm(y, P - 2) % P;
	read(x), read(y); q = x * ksm(y, P - 2) % P;
	for (int i = 1; i <= n; ++i) read(a[i]);
	for (int i = 1; i <= n; ++i) read(b[i]);
	C = (p * alpha % P + q * beta % P) % P;
	H = ksm(h, n + 1);
	ans = 0;
	jc[0] = 1;
	for (int i = 1; i <= 2 *n; ++i)
		jc[i] = jc[i - 1] * (ll)i % P;
		
	for (int i = 1; i <= n; ++i)
		ans = (ans + a[i] * ksm(H, i) % P) % P;
	for (int i = 1; i <= n; ++i)
		ans = (ans + b[i] * ksm(h, i) % P) % P;
	
	ll sum = 1;
	for (int i = 0; i < n; ++i, sum = sum * p % P)
		fa[i] = sum * ksm(jc[i], P - 2) % P * calc(H, i + 1, n) % P;
	sum = 1;
	for (int i = 0; i < n; ++i, sum = sum * q % P)
		fb[i] = sum * ksm(jc[i], P - 2) % P * calc(h, i + 1, n);
	work(fa, fb);
	for (int i = 0; i < 2 * n; ++i)
		ans = (ans + F[i] * jc[i] % P * C % P) % P;
	memset(fa, 0, sizeof(fa)); memset(fb, 0, sizeof(fb));
	for (int i = 0; i <= n; ++i) fa[i] = a[n - i];
	fb[0] = 1;
	for (int i = 1; i <= n; ++i) fb[i] = fb[i - 1] * H % P;
	work(fa, fb);
	memset(fa, 0, sizeof(fa)); memset(fb, 0, sizeof(fb));
	sum = 1;
	for (int i = 0; i < n; ++i, sum = sum * p % P)
		fa[i] = F[i + n] * sum % P * ksm(jc[i], P - 2) % P;
	ll sum2 = q; sum = h;
	for (int i = 1; i <= n; ++i, sum = sum * h % P, sum2 = sum2 * q % P)
		fb[i] = sum * sum2 % P * ksm(jc[i - 1], P - 2) % P;
	work(fa, fb);
	for (int i = 1; i < 2 * n; ++i)
		ans = (ans + F[i] * jc[i - 1] % P) % P;
	
	memset(fa, 0, sizeof(fa)); memset(fb, 0, sizeof(fb));
	for (int i = 0; i <= n; ++i) fa[i] = b[n - i];
	fb[0] = 1;
	for (int i = 1; i <= n; ++i) fb[i] = fb[i - 1] * h % P;
	work(fa, fb);
	memset(fa, 0, sizeof(fa)); memset(fb, 0, sizeof(fb));
	sum = 1;
	for (int i = 0; i < n; ++i, sum = sum * q % P)
		fa[i] = F[i + n] * sum % P * ksm(jc[i], P - 2) % P;
	sum2 = p; sum = H;
	for (int i = 1; i <= n; ++i, sum = sum * H % P, sum2 = sum2 * p % P)
		fb[i] = sum * sum2 % P * ksm(jc[i - 1], P - 2) % P;
	work(fa, fb);
	for (int i = 1; i < 2 * n; ++i)
		ans = (ans + F[i] * jc[i - 1] % P) % P;
	printf("%lld", ans);
	return 0;
}

版权声明:本文为LeeCongWei原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。