Hanwan Space

树链剖分入门


树链剖分就是把一棵树拆成几条链来处理,便于线段树进行区间操作。

引入

我们知道区间操作通常用线段树。但是如果要对一棵树上的路径或子树进行操作呢?这就需要把树拆分成链来处理,然后用线段树来维护这些链。

通常说的树链剖分实际上是指重链剖分,就是以树的重链为基础来对树进行拆分,拆分成几个重链。 实际上就是对树进行标号,使得那些路径上的点标号连续。这样就便于线段树进行区间操作了。

首先需要明白的几个姿势

  • 重(zhòng)儿子:一个非叶子节点的儿子中,子节点最多的那个儿子。
  • 重边(不是chóng边):一个非叶子节点和它重儿子之间的连边。
  • 重链:相邻重(zhòng)边连接起来形成的一条链叫重链。
  • 轻儿子、轻边:就是剩下的儿子和边。

需要注意的是,每一条重链的起点其实是轻儿子或根节点。 对于是轻儿子同时也是叶子节点的节点来说,它自身就是一条长度为 1 的重链。

树链剖分

有了上面的前置姿势,下面就要正式介绍树链剖分啦。

对节点进行标号

需要两个 DFS。

第一次 DFS 统计深度和子树大小,第二次 DFS 对树进行正式的标号。

需要拿一些数组存下这些信息,我们这样规定:

  • siz 记录子树大小
  • fa 记录每个节点的父亲节点
  • dep 记录每个节点的深度

以上就是第一遍 DFS 所需记录的信息。

所以第一遍 DFS 很简单,只需要这样写。

cpp
void dfs1(int u) {
	siz[u] = 1, dep[u] = dep[fa[u]] + 1; // 记录子树大小、深度
	for (int i = head[u]; ~i; i = edge[i].next) {
		int v = edge[i].to;
		if (v != fa[u]) {
			fa[v] = u; // 记录父亲节点
			dfs1(v);
			siz[u] += siz[v]; // 统计子树大小
		}
	}
}

重头戏在第二遍 DFS 上。

我们还需要额外的数组来记录信息。

  • pos 记录经过第二遍 DFS 后节点的标号
  • top 记录每个节点所在重链上的顶端节点

需要注意的是根节点的 top 值是它本身,记得初始化。

所以第二遍 DFS 这样写

cpp
void dfs2(int u) {
	int maxn = 0, nxt = -1; // nxt 记录重儿子
	pos[u] = ++num; // 记录节点的新标号
	for (int i = head[u]; ~i; i = edge[i].next) {
		int v = edge[i].to;
		if (v != fa[u] && siz[v] > maxn) // 找出一个节点子节点中子树最大的节点,说明它是重儿子
			maxn = siz[v], nxt = v;
	}
	if (nxt == -1) return; // 如果它没有重儿子就退出(说明已经到了叶子节点)
	top[nxt] = top[u]; // 同一条重链的顶端节点相同
	dfs2(nxt); // 递归优先处理重儿子
	for (int i = head[u]; ~i; i = edge[i].next) {
		int v = edge[i].to;
		if (v != fa[u] && v != nxt) {
			top[v] = v; // 对于非重儿子再进行处理,新重链的起点即它本身
			dfs2(v); // 找出新的重链
		}
	}
}

这样进行两遍 DFS 中,我们已经完成了链剖分的全部操作,成功地把树拆分成了几条链。

显然同一条重链上的节点的标号是连续的,这样便于我们进行区间操作。

路径操作

对于同一条链上的两点,它们的标号在线段树上是连续的,我们可以很轻易地进行区间操作来实现路径操作。

但更多的情况是不在同一条链上,所以需要一些神奇的操作。

就那路径上的节点求和为例吧。

我们先取出深度较深的那个点,然后求出这个点到它所在重链的顶端节点的和,然后再把这个点更新为它所在重链顶端节点的父亲节点,这样就又调到了顶端节点上面的那条重链上。再次取出深度较深的那个点,如此往复,直到两点在同一条重链上,这样直接进行区间操作即可。

写成伪(Python)代码就是

python
def sum(u, v):
    ans = 0
    while top[u] != top[v]:
        if dep[u] > dep[v]:
            u, v = v, u    # 交换变量,C++ 是 std::swap(u, v);
        ans += query_sum(pos[top[v]], pos[v]) # 这是线段树的区间求和操作
        v = fa[top[v]]
    if pos[u] > pos[v]:
        u, v = v, u
    ans += query_sum(pos[u], pos[v])
    return ans

区间更新也是类似的操作。把求和替换成线段树更新即可。

子树操作

对于每一个子树来说,它们的标号也是连续的(因为是 DFS)。

直接进行区间操作即可。对于以 u 为根的子树,它在区间上的右端点标号为 pos[u] + siz[u] - 1

求 LCA(最近公共祖先)

树链剖分是可以在线求 LCA 的,而且实测比倍增跑得快。

只要两个点在同一条重链上,那么它们的 LCA 一定是深度小的那个节点。

那么如果它们不在一条重链上呢?参考上面路径操作的方法。往上跳直到跳到同一条重链上为止。

伪(Python)代码

python
def sum(u, v):
    while top[u] != top[v]:
        if dep[u] > dep[v]:
            u, v = v, u
        v = fa[top[v]]
    if dep[u] > dep[v]:
        return v
    return u

时间复杂度

(由于本人太蒟蒻了不会证,所以以下内容是网上抄的)

性质一

如果边 (u,v)(u,v) 为轻边,那么 size(v)size(u)/2size(v)\leq size(u)/2

性质二

树中任意两个节点之间的路径中轻边的条数不会超过 log2n\log _{2}n ,重路径的数目不会超过 log2n\log _{2}n

根据以上两点性质以及线段树查询和修改的复杂度 O(log2n)O(\log_2n),可以得知总复杂度为 O(log22n)O(\log_2^2n)

例题

模板题

传送门

操作:路径更新与求和,子树更新与求和。

完整代码

cpp
#include <bits/stdc++.h>
 
using namespace std;
 
const int maxN = 1e5 + 3;
const int INF = 0x3f3f3f3f;
 
struct Edge {
	int next, to;
} edge[maxN << 1];
 
struct node {
	int sum, lazy;
} tree[maxN << 2];
 
int dep[maxN], siz[maxN], fa[maxN];
int pos[maxN], top[maxN], head[maxN], w[maxN], val[maxN];
int cnt, num, mod, n, m, root;
 
void add(int from, int to) {
	edge[++cnt] = (Edge) {head[from], to};
	head[from] = cnt;
}
 
void dfs1(int u) {
	siz[u] = 1, dep[u] = dep[fa[u]] + 1;
	for (int i = head[u]; ~i; i = edge[i].next) {
		int v = edge[i].to;
		if (v != fa[u]) {
			fa[v] = u;
			dfs1(v);
			siz[u] += siz[v];
		}
	}
}
 
void dfs2(int u) {
	int maxn = 0, nxt = -1; // nxt 记录重儿子
	pos[u] = ++num;
	for (int i = head[u]; ~i; i = edge[i].next) {
		int v = edge[i].to;
		if (v != fa[u] && siz[v] > maxn)
			maxn = siz[v], nxt = v;
	}
	if (nxt == -1) return;
	top[nxt] = top[u];
	dfs2(nxt); // 递归重儿子
	for (int i = head[u]; ~i; i = edge[i].next) {
		int v = edge[i].to;
		if (v != fa[u] && v != nxt) {
			top[v] = v;
			dfs2(v);
		}
	}
}
 
struct SegmentTree {
	#define ls (o << 1)
	#define rs (o << 1 | 1)
	#define mid ((l + r) >> 1)
	void push_up(int o) {
		tree[o].sum = (tree[ls].sum + tree[rs].sum) % mod;
	}
	void build(int o, int l, int r) {
		if (l == r) {
			tree[o].sum = w[l];
			return;
		}
		build(ls, l, mid);
		build(rs, mid + 1, r);
		push_up(o);
	}
	void push_down(int o, int l, int r) {
		tree[ls].sum = (tree[ls].sum + tree[o].lazy * (mid - l + 1)) % mod;
		tree[rs].sum = (tree[rs].sum + tree[o].lazy * (r - mid)) % mod;
		tree[ls].lazy = (tree[ls].lazy + tree[o].lazy) % mod;
		tree[rs].lazy = (tree[rs].lazy + tree[o].lazy) % mod;
		tree[o].lazy = 0;
	}
	void update(int o, int l, int r, int sl, int sr, int k) {
		if (sl > r || sr < l) return;
		if (sl <= l && sr >= r) {
			tree[o].sum = (tree[o].sum + k * (r - l + 1)) % mod;
			tree[o].lazy = (tree[o].lazy + k) % mod;
			return;
		}
		push_down(o, l, r);
		if (sl <= mid) update(ls, l, mid, sl, sr, k);
		if (sr > mid) update(rs, mid + 1, r, sl, sr, k);
		push_up(o);
	}
	int query(int o, int l, int r, int sl, int sr) {
		if (sl <= l && sr >= r) return tree[o].sum % mod;
		push_down(o, l, r);
		int ans = 0;
		if (sl <= mid) ans += query(ls, l, mid, sl, sr);
		if (sr > mid) ans += query(rs, mid + 1, r, sl, sr);
		return ans % mod;
	}
} T;
 
void update_path(int u, int v, int k) {
	while (top[u] != top[v]) {
		if (dep[top[u]] > dep[top[v]]) swap(u, v);
		T.update(1, 1, num, pos[top[v]], pos[v], k);
		v = fa[top[v]];
	}
	if (pos[u] > pos[v]) swap(u, v);
	T.update(1, 1, num, pos[u], pos[v], k);
}
 
int query_path(int u, int v) {
	int ans = 0;
	while (top[u] != top[v]) {
		if (dep[top[u]] > dep[top[v]]) swap(u, v);
		ans = (ans + T.query(1, 1, num, pos[top[v]], pos[v])) % mod;
		v = fa[top[v]];
	}
	if (pos[u] > pos[v]) swap(u, v);
	ans = (ans + T.query(1, 1, num, pos[u], pos[v])) % mod;
	return ans;
}
 
void update_tree(int u, int k) {
	T.update(1, 1, num, pos[u], pos[u] + siz[u] - 1, k);
}
 
int query_tree(int u) {
	return T.query(1, 1, num, pos[u], pos[u] + siz[u] - 1) % mod;
}
 
int main() {
	memset(head, -1, sizeof(head));
	scanf("%d%d%d%d", &n, &m, &root, &mod);
	for (int i = 1; i <= n; i++) scanf("%d", &val[i]);
	for (int i = 1, a, b; i < n; i++) {
		scanf("%d%d", &a, &b);
		add(a, b);
		add(b, a);
	}
	dfs1(root); top[root] = root; dfs2(root);
	for (int i = 1; i <= n; i++) w[pos[i]] = val[i];
	T.build(1, 1, num);
	for (int i = 1; i <= m; i++) {
		int x, y, k, opt;
		scanf("%d%d", &opt, &x);
		if (opt == 1) {
			scanf("%d%d", &y, &k);
			update_path(x, y, k);
		} else if (opt == 2) {
			scanf("%d", &y);
			printf("%d\n", query_path(x, y));
		} else if (opt == 3) {
			scanf("%d", &k);
			update_tree(x, k);
		} else printf("%d\n", query_tree(x));
	}
	return 0;
}

[ZJOI2008]树的统计

传送门

操作:单点修改,查询路径最大值与和。

完整代码

cpp
#include <bits/stdc++.h>
 
using namespace std;
 
const int maxN = 1e5 + 10;
const int INF = 0x7f7f7f7f;
 
struct Edge {
	int next, to;
} edge[maxN << 1];
 
struct node {
	int max, sum;
} st[maxN << 2];
 
int head[maxN];
int pos[maxN], top[maxN], siz[maxN], dep[maxN], fa[maxN];
int w[maxN];
int n, q, cnt, num = 0;
string opt;
 
void add(int from, int to) {
	edge[++cnt] = (Edge) {head[from], to};
	head[from] = cnt;
}
 
void dfs1(int u) {
	siz[u] = 1, dep[u] = dep[fa[u]] + 1;
	for (int i = head[u]; ~i; i = edge[i].next) {
		int v = edge[i].to;
		if (v != fa[u]) {
			fa[v] = u;
			dfs1(v);
			siz[u] += siz[v];
		}
	}
}
 
void dfs2(int u) {
	int maxn = 0, nxt = -1;
	pos[u] = ++num;
	for (int i = head[u]; ~i; i = edge[i].next) {
		int v = edge[i].to;
		if (v != fa[u] && siz[v] > maxn)
			maxn = siz[v], nxt = v;
	}
	if (nxt == -1) return;
	top[nxt] = top[u];
	dfs2(nxt);
	for (int i = head[u]; ~i; i = edge[i].next) {
		int v = edge[i].to;
		if (v != fa[u] && v != nxt) {
			top[v] = v;
			dfs2(v);
		}
	}
}
 
struct segmentTree {
	#define ls (o << 1)
	#define rs (o << 1 | 1)
	#define mid ((l + r) >> 1)
	void build(int o, int l, int r) {
		if (l == r) {
			st[o].sum = st[o].max = w[l];
			return;
		}
		build(ls, l, mid);
		build(rs, mid + 1, r);
		st[o].max = max(st[ls].max, st[rs].max);
		st[o].sum = st[ls].sum + st[rs].sum;
	}
	void update(int o, int l, int r, int x, int k) {
		if (l == r) {
			st[o].sum = st[o].max = k;
			return;
		}
		if (x <= mid) update(ls, l, mid, x, k);
		if (x > mid) update(rs, mid + 1, r, x, k);
		st[o].max = max(st[ls].max, st[rs].max);
		st[o].sum = st[ls].sum + st[rs].sum;
	}
	int query_sum(int o, int l, int r, int sl, int sr) {
		if (sr < l || r < sl) return 0;
		if (sl <= l && r <= sr) return st[o].sum;
		int sum = 0;
		if (sl <= mid) sum += query_sum(ls, l, mid, sl, sr);
		if (sr > mid) sum += query_sum(rs, mid + 1, r, sl, sr);
		return sum;
	}
	int query_max(int o, int l, int r, int sl, int sr) {
		if (sr < l || r < sl) return 0;
		if (sl <= l && r <= sr) return st[o].max;
		int maxn = -INF;
		if (sl <= mid) maxn = max(maxn, query_max(ls, l, mid, sl, sr));
		if (sr > mid) maxn = max(maxn, query_max(rs, mid + 1, r, sl, sr));
		return maxn;
	}
} T;
 
int find_sum(int a, int b) {
	int ans = 0;
	while (top[a] != top[b]) {
		if (dep[top[a]] > dep[top[b]]) swap(a, b);
		ans += T.query_sum(1, 1, num, pos[top[b]], pos[b]);
		b = fa[top[b]];
	}
	if (pos[a] > pos[b]) swap(a, b);
	ans += T.query_sum(1, 1, num, pos[a], pos[b]);
	return ans;
}
 
int find_max(int a, int b) {
	int ans = -INF;
	while (top[a] != top[b]) {
		if (dep[top[a]] > dep[top[b]]) swap(a, b);
		ans = max(ans, T.query_max(1, 1, num, pos[top[b]], pos[b]));
		b = fa[top[b]];
	}
	if (pos[a] > pos[b]) swap(a, b);
	ans = max(ans, T.query_max(1, 1, num, pos[a], pos[b]));
	return ans;
}
 
int main() {
	memset(head, -1, sizeof(head));
	scanf("%d", &n);
	for (int i = 1, a, b; i < n; i++) {
		scanf("%d%d", &a, &b);
		add(a, b);
		add(b, a);
	}
	dfs1(1); top[1] = 1; dfs2(1);
	for (int i = 1; i <= n; i++) scanf("%d", &w[pos[i]]);
	T.build(1, 1, num);
	scanf("%d", &q);
	for (int i = 1, a, b; i <= q; i++) {
		cin >> opt;
		scanf("%d%d", &a, &b);
		if (opt == "CHANGE") T.update(1, 1, num, pos[a], b);
		else if (opt == "QMAX")
			printf("%d\n", find_max(a, b));
		else if (opt == "QSUM")
			printf("%d\n", find_sum(a, b));
	}
	return 0;
}

  • 树链剖分
© hawa130转载请注明出处
License: CC BY-NC-SA 4.0

上一篇

差分约束简单入门

下一篇

线性基 — 异或和处理利器