LeetCode 480 滑动窗口中位数

本题的绝大多数难点都集中在如何设计一个合理、易用的数据结构,而非设计复杂精巧的算法,这在 OJ 习题中非常珍贵,同时又有用到了“延迟操作”的设计技巧,非常值得反复钻研。

原题链接如下: LeetCode 480 滑动窗口中位数

假设在一个整数数组 nums 上使用固定长度为 k 的滑动窗口扫描,每次窗口移动都会遇到新的数组子序列,而每个这样的子序列都有各自的中位数,要求输出所有这些中位数。

示例输入为:

1
2
3
Input:
nums = [1,3,-1,-3,5,3,6,7]
k = 3

可以假设 k 始终有效,即:k 始终小于等于输入的非空数组的元素个数。

对应输出为:

1
2
Output:
[1,-1,-1,3,5,6]

与真实值误差在 10 ^ -5 以内的答案将被视作正确答案。

LeetCode 239 滑动窗口最大值 不同,中位数需要找到子序列排序结果的中间位置,因此需要保留所有的窗口数据,需要通过改变数据的组织形式以避免重复计算。我们可以引入两个堆:

  • 一个是大顶堆:保存窗口元素中较小的一半元素
  • 一个是小顶堆:保存窗口元素中较大的一半元素

这样,直接获取两个堆顶的数据,就可以很轻松的求出中位数。如果窗口尺寸是奇数,那么只需要人为指定一个堆的个数多一个,中位数就是那个元素数量更多的堆的堆顶。

但是,必须要注意的是这样的数据结构查询操作简单,但是增添和删除元素都会较为复杂,原因如下:

  • 增添和删除元素都会破坏两个堆的元素个数的约定
  • 堆的随机删除操作时间复杂度并非 $O(1)$

我们可以引入“延迟删除”的内部操作,让我们借助这一操作能够保证便利数组时,只要操作数量够多,删除操作的平均时间复杂度为 $O(1)$。 为了实现这一目标,我们需要创立一个哈希表,统计加入两个堆的元素(因为元素可能重复,故不可使用集合)。一旦我们需要从任何一个堆中删除指定的元素,我们边令它在哈希表内的计数器增加一。 然后,我们可以设计一个“修剪”操作,把堆的顶端的需要删除的元素删去:

  1. 检查堆顶部是否是需要延迟删除的元素
  2. 弹出堆顶需要延迟删除的元素,重复第 1 步

这样,我们在进行删除操作时,如果一个堆的堆顶就是需要删除的元素,那么就在这个堆触发“修剪”操作。

为了保证两个堆的数量约定,我们每次增、删元素时后都要堆两个堆的元素数量做再平衡操作。 首先,因为我们需要保证延迟删除的成立,所以每个堆中可能存储多个需要延迟删除的元素,那么无法直接使用堆的标准 size 属性获知正确的元素数量。所以我们需要自己额外准备两个内部变量存储两个堆的有效元素数量,另外在增、删操作时注意手动维护两个变量符合实际情况。 然后,每次需要再平衡时,必须检查两个堆的元素数量,保证把一旦一个堆元素超标则移动到另一个堆中。

注意:

  • 超标的堆需要删除元素,同样要考虑是否进行“修剪”操作
  • 保证转移元素的同时,堆的数量变量也要符合实际情况

这样,每次我们进行增、删操作都要内部调用这个“再平衡”操作。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import java.util.PriorityQueue

class Solution {
    fun medianSlidingWindow(nums: IntArray, k: Int): DoubleArray {
        val re = DoubleArray(nums.size - k + 1)
        val cache = MedianList(k)
        for (i in 0 until k) {
            cache.add(nums[i])
        }
        re[0] = cache.getMedian()
        for (r in k until nums.size) {
            val l = r - k
            cache.remove(nums[l])
            cache.add(nums[r])
            re[l + 1] = cache.getMedian()
        }
        return re
    }

}

class MedianList(private val n: Int) {
    private val low = PriorityQueue<Int>(reverseOrder())
    private val high = PriorityQueue<Int>()
    private var lowCnt = 0
    private var highCnt = 0
    private val isOdd = n % 2 == 1
    private val delayed = HashMap<Int, Int>()

    fun add(data: Int) {
        if (low.isEmpty() || data <= low.peek()) {
            low.offer(data)
            lowCnt++
        } else {
            high.offer(data)
            highCnt++
        }
        reBalance()
    }

    fun remove(data: Int) {
        delayed[data] = (delayed[data] ?: 0) + 1
        if (data <= low.peek()) {
            --lowCnt
            if (low.peek() == data) prune(low)
        } else {
            --highCnt
            if (high.peek() == data) prune(high)
        }
        reBalance()
    }

    private fun reBalance() {
        if (lowCnt > highCnt + 1) {
            high.offer(low.poll())
            lowCnt--
            highCnt++
            prune(low)
        } else if (lowCnt < highCnt) {
            low.offer(high.poll())
            lowCnt++
            highCnt--
            prune(high)
        }
    }

    private fun prune(heap: PriorityQueue<Int>) {
        while (heap.isNotEmpty() && delayed.containsKey(heap.peek())) {
            val k = heap.poll()
            val newVal = delayed[k]!! - 1
            if (newVal == 0) delayed.remove(k)
            else delayed[k] = newVal
        }
    }

    fun getMedian() = if (isOdd) low.peek().toDouble() else (low.peek().toDouble() + high.peek()) / 2
}