【题解】3745. Problem A

3745. Problem A

题目大意

给定一棵树,每条边初始颜色为0。有三个操作:

  1. 将某条路径上所有边反色;
  2. 将有且只有一个点在某条路径上的边反色;
  3. 求一条路径上的黑边个数。

总结

容易想到重链剖分。

我们用线段树维护 s u m 0 sum_0sum0 表示区间内白边个数, s u m 1 sum_1sum1 同理,v a l valval 表示连向该点的所有轻边是否反转。

先看操作1:在链中跳的时候显然是把每个点的 s u m 0 , s u m 1 sum_0,sum_1sum0,sum1 翻转;在链之间跳时给跳之前链头标记取反,表示该链头到其父亲的边取反。

再看操作2:在链中跳的时候就是取反 v a l valval ;在链之间跳时(包括未开始跳时)将其连向重儿子的重边取反;在最后两个点都在同一条重链上时,设上面的为 x xx ,若 x xx 不为链顶,则需要将 x xx 连向父亲的边取反。

操作3就较为简单:在链中跳时查询区间和,对于链与链之间的边,考虑将两个端点求出,将其 v a l valval 值与下面的点的标记异或,若为1,则为黑。

(好久没有打过这么长的代码了)

Code

#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1e5 + 5;
struct node {
	int to, nxt;
}e[N << 1];
int n, cnt, tot;
int bz[N], size[N], fa[N], son[N], top[N], head[N], dfn[N], dep[N];
int sum[N << 2][2], val[N << 2], tag[N << 2];
template <typename T>inline void read(T &ss) {
	ss = 0;
	char ch = getchar();
	for (; !isdigit(ch); ch = getchar());
	for (; isdigit(ch); ch = getchar()) ss = (ss << 1) + (ss << 3) + (ch ^ 48);
	return;
}
void pushdown(int x) {
	int p = x << 1, q = x << 1 | 1;
	if (tag[x]) {
		swap(sum[p][0], sum[p][1]);
		swap(sum[q][0], sum[q][1]);
		tag[p] ^= 1, tag[q] ^= 1;
		tag[x] = 0;
	}
	if (val[x])
		val[x] = 0, val[p] ^= 1, val[q] ^= 1;
	return;
}
void update(int x) {
	sum[x][0] = sum[x << 1][0] + sum[x << 1 | 1][0];
	sum[x][1] = sum[x << 1][1] + sum[x << 1 | 1][1];
	return;
}
void change(int x, int l, int r, int p) {
	++sum[x][0];
	if (l == r) return;
	int mid = (l + r) >> 1;
	if (p <= mid) change(x << 1, l, mid, p);
	else change(x << 1 | 1, mid + 1, r, p);
	return;
}
void change2(int x, int l, int r, int st, int en, int id) {
	if (st > en || !st || !en) return;
	if (l == st && r == en) {
		if (id == 0) tag[x] ^= 1, swap(sum[x][0], sum[x][1]);
		else val[x] ^= 1;
		return;
	}
	pushdown(x);
	int mid = (l + r) >> 1;
	if (en <= mid) change2(x << 1, l, mid, st, en, id);
	else if (st > mid) change2(x << 1 | 1, mid + 1, r, st, en, id);
	else change2(x << 1, l, mid, st, mid, id), change2(x << 1 | 1, mid + 1, r, mid + 1, en, id);
	update(x);
	return;
}
int query(int x, int l, int r, int st, int en) {
	if (st > en || !st || !en) return 0;
	if (l == st && r == en) return sum[x][1];
	pushdown(x);
	int mid = (l + r) >> 1;
	if (en <= mid) return query(x << 1, l, mid, st, en);
	else if (st > mid) return query(x << 1 | 1, mid + 1, r, st, en);
	else return query(x << 1, l, mid, st, mid) + query(x << 1 | 1, mid + 1, r, mid + 1, en);
}
int query2(int x, int l, int r, int p) {
	if (l == r) return val[x];
	pushdown(x);
	int mid = (l + r) >> 1;
	if (p <= mid) return query2(x << 1, l, mid, p);
	else return query2(x << 1 | 1, mid + 1, r, p);
}
void add(int u, int v) {
	e[++cnt].to = v;
	e[cnt].nxt = head[u];
	head[u] = cnt;
	return;
}
void dfs(int u) {
	size[u] = 1;
	for (int i = head[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if (v == fa[u]) continue;
		fa[v] = u; dep[v] = dep[u] + 1;
		dfs(v);
		size[u] += size[v];
		if (size[son[u]] < size[v]) son[u] = v;
	}
	return;
}
void dfs2(int u, int t) {
	dfn[u] = ++tot;
	top[u] = t;
	if (u != t) change(1, 1, n, tot);
	if (!son[u]) return;
	dfs2(son[u], t);
	for (int i = head[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if (v == fa[u] || v == son[u]) continue;
		dfs2(v, v);
	}
	return;
}
void jump1(int x, int y) {
	while (top[x] != top[y]) {
		if (dep[top[x]] > dep[top[y]]) swap(x, y);
		change2(1, 1, n, dfn[top[y]] + 1, dfn[y], 0);
		bz[top[y]] ^= 1;
		y = fa[top[y]];
	}
	if (dep[x] > dep[y]) swap(x, y);
	change2(1, 1, n, dfn[x] + 1, dfn[y], 0);
	return;
}
void jump2(int x, int y) {
	while (top[x] != top[y]) {
		if (dep[top[x]] > dep[top[y]]) swap(x, y);
		change2(1, 1, n, dfn[top[y]], dfn[y], 1);
		change2(1, 1, n, dfn[son[y]], dfn[son[y]], 0);
		y = fa[top[y]];
	}
	if (dep[x] > dep[y]) swap(x, y);
	change2(1, 1, n, dfn[x], dfn[y], 1);
	change2(1, 1, n, dfn[son[y]], dfn[son[y]], 0);
	if (x != top[x]) change2(1, 1, n, dfn[x], dfn[x], 0);
	return;
}
int jump3(int x, int y) {
	int res = 0;
	while (top[x] != top[y]) {
		if (dep[top[x]] > dep[top[y]]) swap(x, y);
		res += query(1, 1, n, dfn[top[y]] + 1, dfn[y]);
		res += query2(1, 1, n, dfn[top[y]]) ^ query2(1, 1, n, dfn[fa[top[y]]]) ^ bz[top[y]]; 
		y = fa[top[y]];
	}
	if (dep[x] > dep[y]) swap(x, y);
	res += query(1, 1, n, dfn[x] + 1, dfn[y]);
	return res;
} 
int main() {
	read(n);
	for (int i = 1; i < n; ++i) {
		int x, y;
		read(x), read(y);
		add(x, y), add(y, x);
	}
	dep[1] = 1;
	dfs(1);
	dfs2(1, 1); 
	int Q;
	read(Q);
	while (Q--) {
		int x, y, z;
		read(z), read(x), read(y);
		if (z == 1) jump1(x, y);
		if (z == 2) jump2(x, y);
		if (z == 3) printf("%d\n", jump3(x, y));
	}
	return 0;
} 

版权声明:本文为LeeCongWei原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。