3745. Problem A
题目大意
给定一棵树,每条边初始颜色为0。有三个操作:
- 将某条路径上所有边反色;
- 将有且只有一个点在某条路径上的边反色;
- 求一条路径上的黑边个数。
总结
容易想到重链剖分。
我们用线段树维护 s u m 0 sum_0sum0 表示区间内白边个数, s u m 1 sum_1sum1 同理,v a l valval 表示连向该点的所有轻边是否反转。
先看操作1:在链中跳的时候显然是把每个点的 s u m 0 , s u m 1 sum_0,sum_1sum0,sum1 翻转;在链之间跳时给跳之前链头标记取反,表示该链头到其父亲的边取反。
再看操作2:在链中跳的时候就是取反 v a l valval ;在链之间跳时(包括未开始跳时)将其连向重儿子的重边取反;在最后两个点都在同一条重链上时,设上面的为 x xx ,若 x xx 不为链顶,则需要将 x xx 连向父亲的边取反。
操作3就较为简单:在链中跳时查询区间和,对于链与链之间的边,考虑将两个端点求出,将其 v a l valval 值与下面的点的标记异或,若为1,则为黑。
(好久没有打过这么长的代码了)
Code
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1e5 + 5;
struct node {
int to, nxt;
}e[N << 1];
int n, cnt, tot;
int bz[N], size[N], fa[N], son[N], top[N], head[N], dfn[N], dep[N];
int sum[N << 2][2], val[N << 2], tag[N << 2];
template <typename T>inline void read(T &ss) {
ss = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar());
for (; isdigit(ch); ch = getchar()) ss = (ss << 1) + (ss << 3) + (ch ^ 48);
return;
}
void pushdown(int x) {
int p = x << 1, q = x << 1 | 1;
if (tag[x]) {
swap(sum[p][0], sum[p][1]);
swap(sum[q][0], sum[q][1]);
tag[p] ^= 1, tag[q] ^= 1;
tag[x] = 0;
}
if (val[x])
val[x] = 0, val[p] ^= 1, val[q] ^= 1;
return;
}
void update(int x) {
sum[x][0] = sum[x << 1][0] + sum[x << 1 | 1][0];
sum[x][1] = sum[x << 1][1] + sum[x << 1 | 1][1];
return;
}
void change(int x, int l, int r, int p) {
++sum[x][0];
if (l == r) return;
int mid = (l + r) >> 1;
if (p <= mid) change(x << 1, l, mid, p);
else change(x << 1 | 1, mid + 1, r, p);
return;
}
void change2(int x, int l, int r, int st, int en, int id) {
if (st > en || !st || !en) return;
if (l == st && r == en) {
if (id == 0) tag[x] ^= 1, swap(sum[x][0], sum[x][1]);
else val[x] ^= 1;
return;
}
pushdown(x);
int mid = (l + r) >> 1;
if (en <= mid) change2(x << 1, l, mid, st, en, id);
else if (st > mid) change2(x << 1 | 1, mid + 1, r, st, en, id);
else change2(x << 1, l, mid, st, mid, id), change2(x << 1 | 1, mid + 1, r, mid + 1, en, id);
update(x);
return;
}
int query(int x, int l, int r, int st, int en) {
if (st > en || !st || !en) return 0;
if (l == st && r == en) return sum[x][1];
pushdown(x);
int mid = (l + r) >> 1;
if (en <= mid) return query(x << 1, l, mid, st, en);
else if (st > mid) return query(x << 1 | 1, mid + 1, r, st, en);
else return query(x << 1, l, mid, st, mid) + query(x << 1 | 1, mid + 1, r, mid + 1, en);
}
int query2(int x, int l, int r, int p) {
if (l == r) return val[x];
pushdown(x);
int mid = (l + r) >> 1;
if (p <= mid) return query2(x << 1, l, mid, p);
else return query2(x << 1 | 1, mid + 1, r, p);
}
void add(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = head[u];
head[u] = cnt;
return;
}
void dfs(int u) {
size[u] = 1;
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (v == fa[u]) continue;
fa[v] = u; dep[v] = dep[u] + 1;
dfs(v);
size[u] += size[v];
if (size[son[u]] < size[v]) son[u] = v;
}
return;
}
void dfs2(int u, int t) {
dfn[u] = ++tot;
top[u] = t;
if (u != t) change(1, 1, n, tot);
if (!son[u]) return;
dfs2(son[u], t);
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
return;
}
void jump1(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] > dep[top[y]]) swap(x, y);
change2(1, 1, n, dfn[top[y]] + 1, dfn[y], 0);
bz[top[y]] ^= 1;
y = fa[top[y]];
}
if (dep[x] > dep[y]) swap(x, y);
change2(1, 1, n, dfn[x] + 1, dfn[y], 0);
return;
}
void jump2(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] > dep[top[y]]) swap(x, y);
change2(1, 1, n, dfn[top[y]], dfn[y], 1);
change2(1, 1, n, dfn[son[y]], dfn[son[y]], 0);
y = fa[top[y]];
}
if (dep[x] > dep[y]) swap(x, y);
change2(1, 1, n, dfn[x], dfn[y], 1);
change2(1, 1, n, dfn[son[y]], dfn[son[y]], 0);
if (x != top[x]) change2(1, 1, n, dfn[x], dfn[x], 0);
return;
}
int jump3(int x, int y) {
int res = 0;
while (top[x] != top[y]) {
if (dep[top[x]] > dep[top[y]]) swap(x, y);
res += query(1, 1, n, dfn[top[y]] + 1, dfn[y]);
res += query2(1, 1, n, dfn[top[y]]) ^ query2(1, 1, n, dfn[fa[top[y]]]) ^ bz[top[y]];
y = fa[top[y]];
}
if (dep[x] > dep[y]) swap(x, y);
res += query(1, 1, n, dfn[x] + 1, dfn[y]);
return res;
}
int main() {
read(n);
for (int i = 1; i < n; ++i) {
int x, y;
read(x), read(y);
add(x, y), add(y, x);
}
dep[1] = 1;
dfs(1);
dfs2(1, 1);
int Q;
read(Q);
while (Q--) {
int x, y, z;
read(z), read(x), read(y);
if (z == 1) jump1(x, y);
if (z == 2) jump2(x, y);
if (z == 3) printf("%d\n", jump3(x, y));
}
return 0;
}
版权声明:本文为LeeCongWei原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。