AtCoder Beginner Contest 276 G - Count Sequences 差分

题意:

问你有多少序列 A = ( a 1 , a 2 , . . . , a n ) A=(a_1,a_2,...,a_n)A=(a1,a2,...,an)
满足:

  • a [ i ] % 3 ≠ a [ i − 1 ] % 3 a[i]\%3 \neq a[i-1]\%3a[i]%3=a[i1]%3
  • a [ i − 1 ] ≤ a [ i ] ≤ M a[i-1] \leq a[i] \le Ma[i1]a[i]M

前言:

乍一看,好像要dp,可能可以用数据结构瞎搞,但是1 e 7 1e71e7的数据,只能想其他方法

思路:

step0 转换为差分数组

B = ( b 1 , b 2 , . . . , b n ) B = (b_1,b_2,...,b_n)B=(b1,b2,...,bn)
b 1 = a 1 b_1 = a_1b1=a1
b i = a [ i ] − a [ i − 1 ] ( i ≥ 2 ) b_i = a[i] - a[i-1](i \ge 2)bi=a[i]a[i1](i2)

  • b i ≥ 0 b_i \ge 0bi0
  • b [ i ] % 3 b[i]\%3b[i]%3, 有b 2 , b 3 , . . . , b n ∈ { 1 , 2 } b2,b3,...,b_n \in \{1,2\}b2,b3,...,bn{1,2}
  • ∑ b i ≤ M \sum b_i \le MbiM

那么问题就转换为了 B BB有多少种方案。

step1

b [ i ] % 3 b[i]\%3b[i]%3, 有b 2 , b 3 , . . . , b n ∈ { 1 , 2 } b2,b3,...,b_n \in \{1,2\}b2,b3,...,bn{1,2}
b i = { 1 ∣ 2 } + k i ∗ 3 ( i ≥ 2 ) b_i = \{1|2\} + k_i*3 (i \ge 2)bi={1∣2}+ki3(i2)
b 1 = { 0 ∣ 1 ∣ 2 } + k 1 ∗ 3 b_1 = \{0|1|2\} + k_1*3b1={0∣1∣2}+k13
那么我们先计算出,每个位置的可能结果,之后再乘上对于3的分配

step2

用dp去分配1,2

  • 通过枚举 b 1 % 3 b_1\%3b1%3和余数
long dp1[] = new long[Math.max(n*3,m+1)];
        for(int o = 0; o < 3;o ++ ) for(int p = 0; p < 3; p ++ ){
            for(int i = 0; i <= n - 1; i ++ ){
                int nxt = i + o + p + (n-1);
                dp1[nxt] += C(i,n-1);
                dp1[nxt] %= mod;
            }
        }

step3

计算 k i k_iki:的分配方法
可以转换为经典的分小球问题

x个小球,放入n个盒子(盒子中允许有空球)
ans = C x + n − 1 n − 1 C_{x+n-1}^{n-1}Cx+n1n1
小球数为0... x 0...x0...x
ans = C n − 1 n − 1 C_{n-1}^{n-1}Cn1n1 + C n n − 1 C_{n}^{n-1}Cnn1 + C n + 1 n − 1 C_{n+1}^{n-1}Cn+1n1+…+C x + n − 1 n − 1 C_{x+n-1}^{n-1}Cx+n1n1
ans = C n + x n = C n + x x C_{n+x}^{n}= C_{n+x}^{x}Cn+xn=Cn+xx

for(int i = 0; i*3 <= M; i ++) {
	dp2[i*3] = C(i,n+i);
}

What‘s more

组合数学常用公式总结-更新中
这位大佬的MINT,的一些相关的函数

AC:

package com.hgs.atcoder.abc.contest276.g;

/**
 * @author youtsuha
 * @version 1.0
 * Create by 2022/11/5 23:51
 */
import java.util.*;
import java.io.*;
public class Main {
    static FastScanner cin;
    static PrintWriter cout;
    static long mod = 998244353;
    static long finv[];
    static long fact[];
    static long qpow(long a, long k, long p){
        long res = 1;
        while(k > 0){
            if(k%2==1) res = res*a%p;
            a = a*a%p;
            k >>= 1;
        }
        return res;
    }
    private static void init()throws IOException {
        cin = new FastScanner(System.in);
        cout = new PrintWriter(System.out);
        int mx = (int) (3e7+10);
        fact = new long[mx];
        finv = new long[mx];
        fact[0] = 1;
        for(int i = 1; i < mx; i ++ ) fact[i] = fact[i-1]*i%mod;
        finv[mx-1] = qpow(fact[mx-1], mod-2,mod);
        for(int i = mx - 2; i >= 0; i -- ) finv[i] = finv[i+1]*(i+1)%mod;
    }
    static long C(int a, int b){
        return fact[b]*finv[a]%mod*finv[b-a]%mod;
    }
    private static void close(){
        cout.close();
    }
    private static void sol()throws IOException {
        int n = cin.nextInt(), m = cin.nextInt();
        //dp1
        long dp1[] = new long[m+1];
        for(int o = 0; o < 3;o ++ ) for(int p = 0; p < 3; p ++ ){
            for(int i = 0; i <= n - 1; i ++ ){
                int nxt = i + o + p + (n-1);
                if(nxt > m) continue;
                dp1[nxt] += C(i,n-1);
                dp1[nxt] %= mod;
            }
        }
        //dp2
        long dp2[] = new long[m+1];
        for(int i = 0; i*3<= m; i ++ ) {
            dp2[i*3] = C(i,n+i);
        }
        //cal
        long ans = 0;
        for(int i = 0; i <= m; i ++ ) {
            ans += dp1[i]*dp2[m-i]%mod;
            ans %= mod;
        }
        cout.println(ans);
    }
    public static void main(String[] args) throws IOException {
        init();
        sol();
        close();
    }
}
class FastScanner {
    BufferedReader br;
    StringTokenizer st = new StringTokenizer("");

    public FastScanner(InputStream s) {
        br = new BufferedReader(new InputStreamReader(s));
    }

    public FastScanner(String s) throws FileNotFoundException {
        br = new BufferedReader(new FileReader(new File(s)));
    }

    public String next() throws IOException {
        while (!st.hasMoreTokens()){
            try {
                st = new StringTokenizer(br.readLine());
            } catch (IOException e) { e.printStackTrace(); }
        }
        return st.nextToken();
    }

    public int nextInt() throws IOException {
        return Integer.parseInt(next());
    }

    public long nextLong() throws IOException {
        return Long.parseLong(next());
    }

    public double nextDouble() throws IOException {
        return Double.parseDouble(next());
    }
}


版权声明:本文为qq_45377553原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。