我理解的数据结构(八)—— 线段树(SegmentTree)

一、什么是线段树

1.最经典的线段树问题:区间染色
有一面墙,长度为n,每次选择一段墙进行染色,m次操作后,我们可以看见多少种颜色?m次操作后,我们可以在[i, j]区间内看见多少种颜色?
图片描述

数据结构 染色操作 查询操作
数组 O(n) O(n)

2.其他应用场景
2017年注册用户中消费最高的用户?消费最少的用户?学习时间最长的用户?

3.复杂度比较

数据结构 更新 查询
数组 O(n) O(n)
线段树 O(logn) O(logn)

4.线段树原理图
以求和为例:
图片描述

二、线段树基础

1.基础

  • 线段树不是完全二叉树
  • 线段树是平衡二叉树(整棵树最大深度和最小深度最大差为1)
  • 线段树依然可以用数组表示

    • 把不存在的节点看成null,线段树即是完全二叉树

2.数组存储线段树的空间需求

0层 1层 ... h-1层
1 2 ... 2^(h-1)
  • 对满二叉树:

    • h层,一共有2^h-1个节点(大约2^h)
    • 最后一层(h-1)层,有2^(h-1)个节点
    • 最后一层的节点数大致等于前面所有层节点之和
  • 如果需要存储n个元素

    • n=2^k,只需要2n的空间
    • n=2^k+1,需要4n的空间

三、线段树基础代码

public class SegmentTree<E> {

    private E[] data;
    private E[] tree;

    public SegmentTree(E[] arr) {
        data = (E[])new Object[arr.length];

        for (int i = 0; i < arr.length; i++) {
            data[i] = arr[i];
        }
        tree = (E[])new Object[4 * arr.length];
    }

    // 获取元素个数
    public int getSize() {
        return data.length;
    }

    // 获取某个索引上的值
    private E get(int index) {
        if (index < 0 || index >= data.length) {
            throw new IllegalArgumentException("index is illegal");
        }

        return data[index];
    }

    // 返回完全二叉树的数组表示中,一个索引所表示的元素的左子树的索引
    private int leftChild(int index) {
        return 2 * index + 1;
    }

    // 返回完全二叉树的数组表示中,一个索引所表示的元素的右子树的索引
    private int rightChild(int index) {
        return 2 * index + 2;
    }
}

四、创建线段树代码

1.定义融合器接口

public interface Merge<E> {
    // 区间的元素如何定义由用户决定
    E merge(E a, E b);
}

2.创建线段树代码


private Merge<E> merge;

public SegmentTree(E[] arr, Merge<E> merge) {

    // 线段树的融合器,用于定义线段树的区间元素到底如何存储
    this.merge = merge;

    data = (E[])new Object[arr.length];

    for (int i = 0; i < arr.length; i++) {
        data[i] = arr[i];
    }
    tree = (E[])new Object[4 * arr.length];
    buildSegmentTree(0, 0, data.length - 1);
}

// 递归:在treeIndex的位置创建表示区间[l,r]的线段树
private void buildSegmentTree(int treeIndex, int l, int r) {
    if (l == r) {
        tree[treeIndex] = data[l];
        return;
    }

    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    int min = (l + r) / 2;

    buildSegmentTree(leftTreeIndex, l, min);
    buildSegmentTree(rightTreeIndex, min + 1, r);
    tree[treeIndex] = merge.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}

@Override
public String toString() {

    StringBuilder res = new StringBuilder();

    res.append('[');
    for (int i = 0; i < tree.length; i++) {
        if (tree[i] == null) {
            res.append("null");
        } else {
            res.append(tree[i]);
        }

        if (i != tree.length - 1) {
            res.append(", ");
        }

    }
    res.append(']');

    return res.toString();
}

五、线段树的查询

图片描述

// 线段树的查询操作,区间[queryL, queryR]
public E query(int queryL, int queryR) {

   if (queryL < 0 || queryL >= data.length || queryR < 0 || queryR >= data.length) {
       throw new IllegalArgumentException("queryL or queryR is illegal");
   }

   return query(0, 0, data.length - 1, queryL, queryR);
}

// 递归,以treeIndex为根节点,区间为[l, r],查询区间为[queryL, queryR]
private E query(int treeIndex, int l, int r, int queryL, int queryR) {

   if (l == queryL && r == queryR) {
       return tree[treeIndex];
   }

   int leftChildIndex = leftChild(treeIndex);
   int rightChildIndex= rightChild(treeIndex);
   int mid = (l + r) / 2;

   if (queryL >= mid + 1) {
       return query(rightChildIndex, mid + 1, r, queryL, queryR);
   } else if (queryR <= mid) {
       return query(leftChildIndex, l, mid, queryL, queryR);
   }

   E left = query(leftChildIndex, l, mid, queryL, mid);
   E right = query(rightChildIndex, mid + 1, r, mid + 1, queryR);
   return merge.merge(left, right);
}

六、LeetCode上303号问题

题目:303. 区域和检索 - 数组不可变
描述:给定一个整数数组 nums,求出数组从索引 i 到 j (i ≤ j) 范围内元素的总和,包含 i, j 两点。
示例:

给定 nums = [-2, 0, 3, -5, 2, -1],求和函数为 sumRange()

sumRange(0, 2) -> 1
sumRange(2, 5) -> -1
sumRange(0, 5) -> -3

解题代码:

// 注意,如果要在leetcode上提交解答,必须把Merge接口和SegmentTree类的代码一并提交,这里并没有在写NumArray类中
public class NumArray {

    private SegmentTree<Integer> segTree;

    public NumArray(int[] nums) {

        if (nums.length > 0) {

            Integer[] data = new Integer[nums.length];
            for (int i = 0; i < nums.length; i++) {
                data[i] = nums[i];
            }

            segTree = new SegmentTree<>(data, (a, b) -> a + b);
        }
    }

    public int sumRange(int i, int j) {
        if (segTree == null) {
                throw new IllegalArgumentException("segment tree is null");
        }
        return segTree.query(i, j);
    }

}

七、线段树的更新

public void set(int index, E e) {

    if (index < 0 || index >= data.length) {
        throw new IllegalArgumentException("index is illegal");
    }

    set(0, 0, data.length - 1, index, e);
}

private void set(int treeIndex, int l, int r, int index, E e) {
    if (l == r) {
        tree[treeIndex] = e;
        return;
    }

    int leftChildIndex = leftChild(treeIndex);
    int rightChildIndex= rightChild(treeIndex);
    int mid = (l + r) / 2;

    if (index >= mid + 1) {
        set(rightChildIndex, mid + 1, r, index, e);
    } else if (index <= mid) {
        set(leftChildIndex, l, mid, index, e);
    }

    tree[treeIndex] = merge.merge(tree[leftChildIndex], tree[rightChildIndex]);
}

八、LeetCode上307号问题

题目:307. 区域和检索 - 数组可修改
描述:给定一个整数数组 nums,求出数组从索引 i 到 j (i ≤ j) 范围内元素的总和,包含 i, j 两点。
示例:

Given nums = [1, 3, 5]

sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8

解题代码:

class NumArray {

    // 注意,如果要在leetcode上提交解答,必须把Merge接口和SegmentTree类的代码一并提交,这里并没有在写NumArray类中    
    private SegmentTree<Integer> segTree;
    
    public NumArray(int[] nums) {
        if (nums.length > 0) {

            Integer[] data = new Integer[nums.length];
            for (int i = 0; i < nums.length; i++) {
                data[i] = nums[i];
            }

            segTree = new SegmentTree<>(data, (a, b) -> a + b);
        }
    }

    public void update(int i, int val) {
        if (segTree == null) {
            throw new IllegalArgumentException("segment tree is null");
        }
        segTree.set(i, val);
    }

    public int sumRange(int i, int j) {
        if (segTree == null) {
            throw new IllegalArgumentException("segment tree is null");
        }
        return segTree.query(i, j);
    }
}

罗纳尔多Coder
300 声望33 粉丝

everyday hardwork ? 1.1^365 : 0.9^365