题目链接:https://codeforces.com/problemset/problem/772/C
13.1 题意
给定一个整数 m ( 1 ≤ m ≤ 2 ⋅ 1 0 5 ) m(1 \le m \le 2 \cdot 10^5)m(1≤m≤2⋅105),和 n ( 0 ≤ n < m ) n(0 \le n < m)n(0≤n<m) 个整数。
现在需要构造出一种满足以下要求的序列:
- 每个元素都在 [ 0 , m − 1 ] [0,m-1][0,m−1] 之间取值;
- 每一种前缀积对 m mm 取余之后都不同;
- 每一种前缀积都不可以在 n nn 个整数中出现过;
- 序列的长度最长。
13.2 解题过程
首先我们可以构造一个合法的前缀积序列 p r e f i x i prefix_iprefixi,之后通过扩展欧几里得求出答案序列 a aa。
求解答案序列时,我们需要反复求解下面这个同余方程:
p r e f i x i − 1 ⋅ a i ≡ p r e f i x i m o d m prefix_{i-1} \cdot a_i \equiv prefix_i \mod mprefixi−1⋅ai≡prefiximodm
这个同余方程有解的充要条件为 gcd ( p r e f i x i − 1 , m ) ∣ p r e f i x i \gcd(prefix_{i-1}, m)|prefix_igcd(prefixi−1,m)∣prefixi。
所以问题转化为求最长的一串序列 p r e f i x prefixprefix,使得相邻的 p r e f i x prefixprefix 值均可满足整除关系。
显然,gcd ( x , m ) ∣ m \gcd(x, m)|mgcd(x,m)∣m。
因此我们可以枚举 m mm 的因子 i ii 和 j jj,假设 i < j i < ji<j,则如果 i ∣ j i | ji∣j,建立有向边 < i , j > <i,j><i,j>,表明 j jj 可以接续在 g c d ( x , m ) = i gcd(x,m)=igcd(x,m)=i 之后。
任何数均可以和 0 00 建立有向边。
很显然,我们得到了一张有向无环图。
当 gcd ( x 1 , m ) = gcd ( x 2 , m ) = k \gcd(x_1,m)=\gcd(x_2,m)=kgcd(x1,m)=gcd(x2,m)=k 时,若 x 1 ⋅ y ≡ z m o d m x_1 \cdot y \equiv z \mod mx1⋅y≡zmodm 有解,则 x 2 ⋅ y ≡ z m o d m x_2 \cdot y \equiv z \mod mx2⋅y≡zmodm 一定有解。因此对于同一个 k kk 下的任意两个 x xx,代入到方程
p r e f i x i − 1 ⋅ a i ≡ p r e f i x i m o d m prefix_{i-1} \cdot a_i \equiv prefix_i \mod mprefixi−1⋅ai≡prefiximodm
中,一定有解。这表明我们可以将 gcd ( x , m ) \gcd(x,m)gcd(x,m) 相同的点缩成一个点来看待,这些点如果被选中,就可以全部加入到序列中。
设满足 gcd ( x , m ) = k \gcd(x,m)=kgcd(x,m)=k 的 x xx 的个数为 c n t k cnt_kcntk,d p [ u ] dp[u]dp[u] 表示以 u uu 为 p r e f i x prefixprefix 的最后一个元素时,在其之前最多有多少个数字。在计算 c n t k cnt_kcntk 时,被禁止的节点直接跳过,不参与到计算中。
则对于有向边 < u , v > <u,v><u,v>,有转移 d p [ v ] = max ( d p [ v ] , d p [ u ] + c n t u ) dp[v] = \max(dp[v], dp[u] + cnt_u)dp[v]=max(dp[v],dp[u]+cntu),通过拓扑排序进行 dp 即可(类似于求最长路),在 dp 时需要记录每个点的前驱,即从哪个点转移过来的。
之后从 0 00 向前回溯前驱,加入到 p r e f i x prefixprefix 序列中,并进行逆转。
最后根据扩展欧几里得算法求解同余方程,即可得到原序列 a aa。
注意要特判 0 00 是否被禁止,若没有被禁止,则必须将其加入到 a aa 序列的结尾。
时间复杂度:O ( m log m ) O(m \log m)O(mlogm)。
13.3 错误点
- 没有特判 0 00 是否被禁止,而是直接将 0 00 加入到答案序列的末尾。
- 答案会爆
int,因此必须使用long long来存储。
13.4 代码
int cnt, head[maxn], col[maxn], indeg[maxn], topo[maxn], w[maxn];
int tot, from[maxn];
bool ban[maxn];
vector<int> ve[maxn], factor;
struct edge
{
int v, nxt;
} Edge[maxn];
void init()
{
for(int i = 0; i <= 200000; i++) head[i] = -1;
cnt = 0;
}
void addedge(int u, int v)
{
Edge[cnt].v = v;
Edge[cnt].nxt = head[u];
head[u] = cnt++;
}
ll gcd(ll a, ll b)
{
return b == 0 ? a : gcd(b, a % b);
}
ll ex_gcd(ll a, ll b, ll &x, ll &y)
{
if (b == 0)
{
x = 1;
y = 0;
return a;
}
else
{
ll r = ex_gcd(b, a % b, y, x);
y -= x * (a / b);
return r;
}
}
ll mod(ll number, ll k)
{
if(number < 0)
return (-((-number) % k) + k) % k;
else
return number % k;
}
void toposort()
{
queue<int> que;
for(int i = 0; i <= 200000; i++)
{
if(!indeg[i]) que.push(i);
}
while(!que.empty())
{
int now = que.front();
que.pop();
topo[++tot] = now;
for(int i = head[now]; i != -1; i = Edge[i].nxt)
{
int v = Edge[i].v;
--indeg[v];
if(!indeg[v]) que.push(v);
}
}
}
ll ans[maxn], lst[maxn], dp[maxn];
int main()
{
int n, m, num;
scanf("%d%d", &n, &m);
init();
for(int i = 1; i <= n; i++)
{
scanf("%d", &num);
ban[num] = true;
}
for(int i = 0; i <= m - 1; i++)
{
if(ban[i]) continue;
col[i] = gcd(i, m);
ve[col[i]].pb(i);
}
for(int i = 1; i * i <= m; i++)
{
if(m % i) continue;
factor.pb(i);
if(i * i != m) factor.pb(m / i);
}
sort(factor.begin(), factor.end());
factor.pb(0);
for(int i = 0; i < factor.size(); i++)
{
for(int j = i + 1; j < factor.size(); j++)
{
int a = factor[i], b = factor[j];
if(b % a == 0)
{
addedge(a, b);
indeg[b]++;
}
}
}
toposort();
for(int i = 0; i <= m - 1; i++) w[i] = ve[i].size();
w[0] = ve[m].size();
for(int i = 1; i <= tot; i++)
{
int u = topo[i];
for(int j = head[u]; j != -1; j = Edge[j].nxt)
{
int v = Edge[j].v;
if(dp[u] + w[u] > dp[v])
{
dp[v] = dp[u] + w[u];
from[v] = u;
}
}
}
from[1] = -1;
int pos = 0;
int numm = 0;
if(from[0])
{
while(pos != -1)
{
for(int i = 0; i < ve[pos].size(); i++)
{
lst[++numm] = ve[pos][i];
}
pos = from[pos];
}
}
reverse(lst + 1, lst + numm + 1);
for(int i = 1; i < numm; i++)
{
ll x = lst[i], y = lst[i + 1];
ll l, k;
ll d = gcd(x, 1LL * m);
ll k1 = y / d;
ll k2 = m / d;
ex_gcd(x, 1LL * m, l, k);
l = mod(l * k1, k2);
ans[i + 1] = l;
}
ans[1] = lst[1];
printf("%d\n", numm + (!ban[0]));
for(int i = 1; i <= numm; i++)
{
printf("%lld ", ans[i]);
}
if(!ban[0]) printf("0");
return 0;
}