树状数组

简介

树状数组是一种支持 单点修改 和 区间查询 的,代码量小的数据结构。

普通树状数组维护的信息及运算要满足 结合律 且 可差分,如加法(和)、乘法(积)、异或等。

事实上,树状数组能解决的问题是线段树能解决的问题的子集:树状数组能做的,线段树一定能做;线段树能做的,树状数组不一定可以。然而,树状数组的代码要远比线段树短,时间效率常数也更小。

有时,在差分数组和辅助数组的帮助下,树状数组还可解决更强的 区间加单点值 和 区间加区间和 问题。

树状数组能快速求解信息的原因:我们总能将一段前缀 $[1,n]$ 拆成 不多于 $log\ n$ 段区间,使得这 $log\ n$ 段区间的信息是已知的。

于是,我们只需合并这 $log\ n$ 段区间的信息,就可以得到答案。相比于原来直接合并 $n$ 个信息,效率有了很大的提高。

不难发现信息必须满足结合律,否则就不能像上面这样合并了。

1

管辖区间

树状数组中,规定 $c[x]$ 管辖的区间长度为 $2^k$ ,其中:

  • 设二进制最低位为第 $0$ 位,则 $k$ 恰好为二进制表示中,最低位的 $1$ 所在的二进制位数。
  • $2^k$( $c[x]$ 的管辖区间长度)恰好为 $x$ 二进制表示中,最低位的 $1$ 以及后面所有 $0$ 组成的数。

举个例子,$c_{88}$ 管辖的是哪个区间?
因为 $88_{(10)} = 01011000_{(2)}$,其二进制最低位的 $1$ 以及后面的 $0$ 组成的二进制是 $1000$ ,即 $8$ ,所以 $c_{88}$ 管辖 $8$ 个 $a$ 数组中的元素。因此,$c_{88}$ 代表 $a[81…88]$ 的区间信息。

我们记:$x$ 二进制最低位 $1$ 以及后面的 $0$ 组成的数为 $lowbit(a)$,那么 $c[x]$ 管辖的区间就是 $[x-lowbit(x)+ 1,x]$。

这里注意:lowbit 指的不是最低位 $1$ 所在的位数 $k$ ,而是这个 $1$ 和后面所有 $0$ 组成的 $2^k$。

实现代码:

1
2
3
4
int lowbit(int x)
{
return x & (-x);
}

区间查询

$c$ 数组是用来储存原始数组 $a$ 某段区间的和的,也就是说,这些区间的信息是已知的,我们的目标就是把查询前缀拆成这些小区间。

举例:计算 $a[4…7]$ 的和。

我们还是从 $c_7$ 开始跳,跳到 $c_6$ 再跳到 $c_4$ 。此时我们发现它管理了 $a[1…4]$ 的和,但是我们不想
要 $a[1…3]$ 这一部分,怎么办呢?很简单,减去 $a[1…3]$ 的和就行了。

那不妨考虑最开始,就将査询 $a[4…7]$ 的和转化为査询 $a[1…7]$ 的和,以及査询 $a[1…3]$ 的和,最终将两个结果作差。

2

其实任何一个区间査询都可以这么做:査询 $a[l…r]$ 的和,就是 $a[1…r]$ 的和减去 $a[1..l-1]$ 的和,从而把区间问题转化为前缀问题,更方便处理。
我们可以写出査询 $a[1…x]$ 的过程:

  • 从 $c[x]$ 开始往前跳,有 $c[x]$ 管辖 $a[x-lowbit(x)+ 1…x]$。
  • 令$x←x-lowbit(x)$,如果 $x=0$ 说明已经跳到尽头了,终止循环,否则回到第一步
  • 将跳到的c合并。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int getsum(int x) // 整体前缀和
{
ll tot = 0;
while (x > 0)
{
tot += tr[x];
x -= lowbit(x);
}
return tot;
}

int query(int l, int r) // 区间查询
{
return getsum(r) - getsum(l - 1);
}

单点修改

设 $n$ 表示 $a$ 的大小,不难写出单点修改 $a[x]$ 的过程:

  • 初始令 $x’=x$ 。
  • 修改 $c[x’]$。
  • 令 $x’ ← x’+ lowbit(x’)$,如果 $x’>n$ 说明已经跳到尽头了,终止循环,否则回到第二步。

区间信息和单点修改的种类,共同决定 $c[x’]$ 的修改方式。下面给几个例子:

  • 若 $c[x’]$ 维护区间和,修改种类是将 $a[x]$ 加上 $p$,则修改方式则是将所有 $c[x’]$ 也加上 $p$。
  • 若 $c[x’]$ 维护区间积,修改种类是将 $a[x]$ 乘上 $p$,则修改方式则是将所有 $c[x’]$ 也乘上 $p$。

实现代码:

1
2
3
4
5
6
7
8
void update(int x, int k) // 单点修改
{
while (x <= n)
{
tr[x] += k;
x += lowbit(x);
}
}

建树

也就是根据最开始给出的序列,将树状数组建出来(c全部预处理好)

一般可以直接转化为 $n$ 次单点修改,时间复杂度 $O(n\ log\ n)$。

$O(n)$建树

以维护区间和为例。

方法一:

每一个节点的值是由所有与自己直接相连的儿子的值求和得到的。因此可以倒着考虑贡献,即每次确定完儿子的值后,用自己的值更新自己的直接父亲。

1
2
3
4
5
6
7
void init() {
for (int i = 1; i <= n; ++i) {
t[i] += a[i];
int j = i + lowbit(i);
if (j <= n) t[j] += t[i];
}
}

方法二:

前面讲到 $c_i$ 表示的区间是 $[i-lowbit(i)+1,i]$,那么我们可以先预处理一个 $sum$ 前缀和数组,再计算 $c$ 数组。

1
2
3
4
5
void init() {
for (int i = 1; i <= n; ++i) {
t[i] = sum[i] - sum[i - lowbit(i)];
}
}

模板例题

单点修改与区间查询

【模板】树状数组 1 - 洛谷

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;

ll n, m;
ll tr[500005] = {0};

void debug()
{
return;
}

int lowbit(int x)
{
return x & (-x);
}

void update(int x, int k) // 单点修改
{
while (x <= n)
{
tr[x] += k;
x += lowbit(x);
}
}

int getsum(int x) // 整体前缀和
{
ll tot = 0;
while (x > 0)
{
tot += tr[x];
x -= lowbit(x);
}
return tot;
}

int query(int l, int r) // 区间查询
{
return getsum(r) - getsum(l - 1);
}

void solve()
{
cin >> n >> m;
int temp;

// 建树状数组
for (int i = 1; i <= n; i++)
{
cin >> temp;
update(i, temp);
}

// 接下来m次操作
int op;
int x, y;
for (int i = 1; i <= m; i++)
{
cin >> op >> x >> y;
if (op == 1)
{
update(x, y);
}
else
{
cout << query(x, y) << endl;
}
}
return;
}

int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);

int t;
t = 1;
// cin >> t;
while (t--)
{
solve();
}
return 0;
}

区间修改与单点查询

【模板】树状数组 2 - 洛谷

我们用树状数组保存差分即可实现区间修改。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;

ll n;
ll a[500005] = {0};
ll tr[500005] = {0};

void debug()
{
for (int i = 1; i <= n; i++)
{
cout << tr[i] << " ";
}
cout << endl;
return;
}

int lowbit(int x)
{
return x & (-x);
}

void update(int x, int k) // 树状数组保存差分信息
{
while (x <= n)
{
tr[x] += k;
x += lowbit(x);
}
}

int getsum(int x) // 整体前缀和,差分和
{
ll tot = 0;
while (x > 0)
{
tot += tr[x];
x -= lowbit(x);
}
return tot;
}

void solve()
{
ll m;
cin >> n >> m;
ll temp;
for (int i = 1; i <= n; i++)
{
cin >> a[i];
}
int op;
ll x, y;
ll k;
for (int i = 1; i <= m; i++)
{
cin >> op;
if (op == 1)
{
cin >> x >> y >> k;
update(x, k);
update(y + 1, -k);
}
else
{
cin >> x;
cout << getsum(x) + a[x] << endl;
}
}
return;
}

int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);

int t;
t = 1;
// cin >> t;
while (t--)
{
solve();
}
return 0;
}

另一个写法是树状数组存相邻个两数之间的差值,在此不做演示。

二维树状数组

[JSOI2009] 计数问题 - 洛谷

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;
int n, m;
ll a[305][305] = {0};
ll tr[305][305][105] = {0};

void debug()
{
return;
}

int lowbit(int x)
{
return x & (-x);
}

void update(int x, int y, int c, int k)
{
for (int i = x; i <= n; i += lowbit(i))
{
for (int j = y; j <= m; j += lowbit(j))
{
tr[i][j][c] += k;
}
}
}

int getsum(int x, int y, int c)
{
int tot = 0;
for (int i = x; i > 0; i -= lowbit(i))
{
for (int j = y; j > 0; j -= lowbit(j))
{
tot += tr[i][j][c];
}
}
return tot;
}

int query(int x0, int y0, int x2, int y2, int c)
{
return getsum(x2, y2, c) - getsum(x0 - 1, y2, c) - getsum(x2, y0 - 1, c) + getsum(x0 - 1, y0 - 1, c);
}

void solve()
{
cin >> n >> m;

// 建树状数组
for (int i = 1; i <= n; i++)
{
for (int j = 1; j <= m; j++)
{
cin >> a[i][j];
update(i, j, a[i][j], 1);
}
}

int q;
cin >> q;
int x, y, c;
int op;
for (int i = 1; i <= q; i++)
{
cin >> op;
if (op == 1)
{
cin >> x >> y >> c;
update(x, y, a[x][y], -1);
a[x][y] = c;
update(x, y, c, 1);
}
else
{
int x1, y1, x2, y2;
cin >> x1 >> x2 >> y1 >> y2 >> c;
cout << query(x1, y1, x2, y2, c) << endl;
}
}
return;
}

int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);

int t;
t = 1;
// cin >> t;
while (t--)
{
solve();
}
return 0;
}

二维树状数组差分

上帝造题的七分钟 - 洛谷

由二维差分的知识可知,我们需要维护四个数组。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;
int n, m;
void debug()
{
return;
}
int lowbit(int x)
{
return x & (-x);
}
struct P
{
int tr[2050][2050] = {0};
void update(int x, int y, int k)
{
while (x <= n)
{
int a = y;
while (a <= m)
{
tr[x][a] += k;
a += lowbit(a);
}
x += lowbit(x);
}
}

int getsum(int x, int y)
{
int ans = 0;
while (x >= 1)
{
int a = y;
while (a >= 1)
{
ans += tr[x][a];
a -= lowbit(a);
}
x -= lowbit(x);
}
return ans;
}
} my1, my2, my3, my4; // 分别维护tr[i][j],tr[i][j]*i,tr[i][j]*j,tr[i][j]*i*j;

void updateall(int x, int y, int k)
{
my1.update(x, y, k);
my2.update(x, y, k * x);
my3.update(x, y, k * y);
my4.update(x, y, k * x * y);
}

int getsumall(int x, int y)
{
int ans = 0;
ans += my1.getsum(x, y) * (x * y + x + y + 1);
ans -= my2.getsum(x, y) * (y + 1);
ans -= my3.getsum(x, y) * (x + 1);
ans += my4.getsum(x, y);
return ans;
}

void solve()
{
char op;
cin >> op >> n >> m;
int a, b, c, d, k;
while (cin >> op)
{
if (op == 'L')
{
cin >> a >> b >> c >> d >> k;
updateall(a, b, k);
updateall(a, d + 1, -k);
updateall(c + 1, b, -k);
updateall(c + 1, d + 1, k);
}
else if (op == 'k')
{
cin >> a >> b >> c >> d;
cout << getsumall(c, d) - getsumall(a - 1, d) - getsumall(c, b - 1) + getsumall(a - 1, b - 1) << endl;
}
}
return;
}

int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);

int t;
t = 1;
// cin >> t;
while (t--)
{
solve();
}
return 0;
}

树状数组求逆序对

逆序对 - 洛谷

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;

int n;
int tr[500005] = {0};
struct P
{
int num;
int id;
} a[500005];

bool cmp(P a1, P a2)
{
if (a1.num == a2.num)
{
return a1.id > a2.id;
}
return a1.num > a2.num;
}

int lowbit(int x)
{
return x & (-x);
}

void update(int x, int k)
{
while (x <= n)
{
tr[x] += k;
x += lowbit(x);
}
}

int getsum(int x)
{
int tot = 0;
while (x > 0)
{
tot += tr[x];
x -= lowbit(x);
}
return tot;
}

void solve()
{
cin >> n;
ll ans = 0;
for (int i = 1; i <= n; i++)
{
cin >> a[i].num; // 当前数字
a[i].id = i; // 数字对应的下标
}
// 实现降序排列,数值大的在前面
sort(a + 1, a + n + 1, cmp);
for (int i = 1; i <= n; i++)
{
update(a[i].id, 1);
// 对于这个数而言,比他大的数都已经在数组内,只需查询下表比他小的即可
ans += getsum(a[i].id - 1);
}
cout << ans;
return;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);

int t;
t = 1;
// cin >> t;
while (t--)
{
solve();
}
return 0;
}

树状数组
https://serendipity565.github.io/posts/977cdf95ad38/
作者
Serendipity
发布于
2024年7月24日
许可协议
BY-SERENDIPITY565