问题分析

树状数组和线段树是两种非常常用的数据结构,那么他们分别是用于什么问题呢?这里借用一位力扣大佬的总结:

  1. 数组不变,求区间和:「前缀和」、「树状数组」、「线段树」
  2. 多次修改某个数(单点),求区间和:「树状数组」、「线段树」
  3. 多次修改某个区间,输出最终结果:「差分」
  4. 多次修改某个区间,求区间和:「线段树」、「树状数组」(看修改区间范围大小)
  5. 多次将某个区间变成同一个数,求区间和:「线段树」、「树状数组」(看修改区间范围大小)

这样看来,「线段树」能解决的问题是最多的,那我们是不是无论什么情况都写「线段树」呢?

答案并不是,而且恰好相反,只有在我们遇到的问题,不得不写「线段树」的时候,我们才考虑线段树。

因为「线段树」代码很长,而且常数很大,实际表现不算很好。我们只有在不得不用的时候才考虑「线段树」。

总结一下,我们应该按这样的优先级进行考虑:

  1. 简单求区间和,用「前缀和」
  2. 只要求最终结果,用「差分」
  3. 多次将某个区间变成同一个数,用「线段树」
  4. 其他情况,用「树状数组」

(作者:AC_OIer,链接:https://leetcode-cn.com/problems/range-sum-query-mutable/solution/guan-yu-ge-lei-qu-jian-he-wen-ti-ru-he-x-41hv/,来源:力扣(LeetCode))

树状数组

Go语言模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
type NumArray struct {
tr, nums []int
len int
}

func Constructor(nums []int) NumArray {
N := NumArray{
len: len(nums),
tr: make([]int, N.len+1),
nums: nums,
}
for i := 0; i < N.len; i++ {
add(i+1, N.len, nums[i], N.tr)
}
return N
}

func lowbit(x int) (res int) {
return x & (-x)
}

func add(x, n, v int, tr []int) {
for i := x; i <= n; i += lowbit(i) {
tr[i] += v
}
}

func query(x int, tr []int) (res int) {
for i := x; i > 0; i -= lowbit(i) {
res += tr[i]
}
return
}

线段树

Go语言模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
type Node struct {
l, r, v, add int
}

func Constructor(left, right int) Node {
Nd := Node{
l: left,
r: right,
}
return Nd
}

const N = 20009 //这里改成具体问题里的数据长度
var tr []Node = make([]Node, N*4, N*4)

func pushup(i int) {
tr[i].v = tr[i << 1].v + tr[i << 1 | 1].v
}

func pushdown(i int) {
diff := tr[i].add
tr[i << 1].v += diff
tr[i << 1].add += diff
tr[i << 1 | 1].v += diff
tr[i << 1 | 1].add += diff
tr[i].add = 0
}

func build(i, l, r int) {
tr[i] = Constructor(l, r)
if l != r {
mid := (l + r) >> 1
build(i << 1, l, mid)
build(i << 1 | 1, mid + 1, r)
}
}

func update(i, l, r, v int) {
if l <= tr[i].l && r >= tr[i].r {
tr[i].v += v
tr[i].add += v
} else {
pushdown(i)
mid := (tr[i].l + tr[i].r) >> 1
if l <= mid {
update(i << 1, l, r, v)
}
if r > mid {
update(i << 1 | 1, l, r, v)
}
pushup(i)
}
}

func query(i, l, r int) (res int) {
if l <= tr[i].l && r >= tr[i].r {
return tr[i].v
} else {
pushdown(i)
mid := (tr[i].l + tr[i].r) >> 1
if l <= mid {
res += query(i << 1, l, r)
}
if r > mid {
res += query(i << 1 | 1, l, r)
}
return
}
}