poj3468 A Simple Problem with Integers

本题是一道线段树的经典例题,主要考察线段树的延迟操作,难度不高,但解题思路要清晰,数据结构写法要熟练

题目描述

假设有 n 个数排成一列,且位置序号从 1 开始递增。而且,会对这个排列进行两种操作:

  • 查询某个位置区间内的数字之和
  • 对某个位置区间内的所有数字全都加上给定的数值

要注意的是,数字个数可能很多,且操作也会有很多次。要求输出每次查询操作的结果。

输入有 3 部分组成,第一部分输入两个整数 n 和 q,代表数字个数和操作次数,用空格分开。紧接着,输入 n 个整数为原始数据,用空格作为各自之间的间隔。 最后,不断一行一行地输入操作指令,操作指令的格式为“指令类型+操作对象”:

  • 查询类:Q a b 表示查询区间 [a, b] 内的数字之和
  • 加法类:C a b c 表示对区间 [a, b] 内的所有数字都加上数值 c

示例输入为:

1
2
3
4
5
6
7
10 5
1 2 3 4 5 6 7 8 9 10
Q 4 4
Q 1 10
Q 2 4
C 3 6 3
Q 2 4

对每个查询命令,输出查询的结果,注意结果可能会超出 32 位二进制整数。

对应的示例输出为:

1
2
3
4
4
55
9
15

知识点解析

线段树是处理分段型数据查询的利器,尤其当数据量非常大,并且作用在数据上的操作次数十分频繁时,使用线段树可以极大地缩短运算时间。 数组下标天然就能当作线段树分段依据,即线段数区间 [a, b](a <= b) 表示数组下标 a 到 b 之间的元素。这样,就能把一些需要应用线段树解构,但是数据本身性质不符合分段表示的数据集,转化为合法的线段树形式。 另外,应注意树状数组中的元素内容和树的节点序号没有特别的关系,节点序号仅仅是为了方便遍历时使用。 比如:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
int[] arr = {...}; // 原始数据

class Node{ // 树的节点
    public int left, right;
}

Node[] tree = new Node[arr.length << 2]; // 存储树

int leftChild(int i) { // 左子节点序号
        return i << 1;
}

rightChild(int i) { // 右子节点序号
        return i << 1 | 1;
}

最后,如果数据需要多次对不同区间数据进行修改,为了程序运行地更高效,可以在节点内容中加入延迟标记,无需每次操作都修改区间内所有元素。 比如,对应需要进行加法操作的线段数可以设置延迟标记如下:

1
2
3
4
class Node {
    int left, right;
    int sum, add;
}

解题思路

本题基本上是考察线段树的直接操作,不需要进行任何转换操作。 本题注意设立延迟标记时,要考虑到结果可能超过 int 的取值范围,所以延迟标记的类型应该为 long

代码展示

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;

import static java.lang.System.in;

/**
 *
 */
public class Main {

    public static void main(String[] args) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(in));

        String tmp = reader.readLine();
        String[] t = tmp.split(" ");
        int n = Integer.parseInt(t[0]);
        int q = Integer.parseInt(t[1]);
        SegTree segTree = new SegTree();
        segTree.build(1, n, 1);

        int[] nums = new int[n];
        tmp = reader.readLine();
        t = tmp.split(" ");
        for (int i = 0; i < n; i++) {
            nums[i] = Integer.parseInt(t[i]);
            segTree.update(i + 1, i + 1, nums[i], 1);
        }

        for (int i = 0; i < q; i++) {
            tmp = reader.readLine();
            t = tmp.split(" ");

            switch (t[0].charAt(0)) {
                case 'Q':
                    int x = Integer.parseInt(t[1]);
                    int y = Integer.parseInt(t[2]);
                    System.out.println(segTree.query(x, y, 1));
                    break;
                case 'C':
                    int a = Integer.parseInt(t[1]);
                    int b = Integer.parseInt(t[2]);
                    int c = Integer.parseInt(t[3]);
                    segTree.update(a, b, c, 1);
                    break;
                default:
                    break;
            }
        }
    }
}

class Node {

    public int left, right;
    public long sum;
    public long add;
}

class SegTree {

    public static final int MAX = 100010;
    private Node[] buffer = new Node[MAX << 2];

    private static int mid(int l, int r) {
        return (l + r) >> 1;
    }

    private static int leftChild(int i) {
        return i << 1;
    }

    private static int rightChild(int i) {
        return i << 1 | 1;
    }

    public void build(int l, int r, int i) {
        if (buffer[i] == null) buffer[i] = new Node();
        buffer[i].left = l;
        buffer[i].right = r;
        buffer[i].add = 0;

        if (l == r) return;
        int mid = mid(l, r);
        build(l, mid, leftChild(i));
        build(mid + 1, r, rightChild(i));
        buffer[i].sum = buffer[leftChild(i)].sum + buffer[rightChild(i)].sum;
    }

    public void update(int l, int r, long add, int i) {
        if (buffer[i].right < l || buffer[i].left > r) return;

        if (buffer[i].left >= l && buffer[i].right <= r) {
            buffer[i].add += add;
            buffer[i].sum += (buffer[i].right - buffer[i].left + 1) * add;
            return;
        }

        int leftChild = leftChild(i);
        int rightChild = rightChild(i);

        pushDown(i);

        update(l, r, add, leftChild);
        update(l, r, add, rightChild);
        buffer[i].sum = buffer[leftChild].sum + buffer[rightChild].sum;
    }

    public long query(int l, int r, int i) {
        if (buffer[i].right < l || buffer[i].left > r) return 0;

        if (buffer[i].left >= l && buffer[i].right <= r) {
            return buffer[i].sum;
        }

        int leftChild = leftChild(i);
        int rightChild = rightChild(i);

        pushDown(i);

        return query(l, r, leftChild) + query(l, r, rightChild);
    }

    private void pushDown(int i) {
        int leftChild = leftChild(i);
        int rightChild = rightChild(i);

        if (leftChild == rightChild) {
            buffer[i].add = 0;
            return;
        }

        if (buffer[i].add != 0) {
            buffer[leftChild].sum += (buffer[leftChild].right - buffer[leftChild].left + 1)
                                     * buffer[i].add;
            buffer[leftChild].add += buffer[i].add;

            buffer[rightChild].sum += (buffer[rightChild].right - buffer[rightChild].left + 1)
                                      * buffer[i].add;
            buffer[rightChild].add += buffer[i].add;

            buffer[i].add = 0;
        }
    }
}

参考

一步一步理解线段树