LeetCode 4 寻找两个正序数组的中位数

即便是简单的算法,如果不能窥见本质,那么在加上一定额外变型之后都会变得非常困难。本题如果不能理解二分法的抽象本质,就难以处理题目的输入数据,而理解之后就能快速写出简洁的代码。

原题链接如下: LeetCode 4 寻找两个正序数组的中位数

本题的题干非常简单,只是把常规求中位数的输入数据变为 2 个有序数组。而且要注意,数组的长度不定,两个数组元素之间也没有特殊关系。假设两个数组的长度分别为 m 和 n,则算法的时间复杂度为 O(log(m + n))。

示例输入为:

1
2
3
4
5
6
7
Input 1:
nums1 = [1, 3]
nums2 = [2]

Input 2:
nums1 = [1, 2]
nums2 = [3, 4]

对应输出为:

1
2
3
4
5
Output 1:
2.0

Output 2:
2.5

为了保证一定的精度,本题的结果使用 Double 类型储存。

中位数本身就按时这答案算法可能涉及二分法,另外题目明确规定了算法的时间复杂度是 O(log(m + n)),这就说明必须对所有数据使用二分法而非其他解法。当然,即使同样是二分法也会因为思考角度的不同而构建出完全不一样的解法。

实际上求中位数就是求第 k 个数的特别版,而之所以一定要把特例问题转化为一般问题不是为了把问题复杂化——在二分数据的过程中必定会不断丢弃一半的数据,这样新的数据中的中位数就不是第 k 个。为了在新数据中也能使用同样的代码求解,就必须进行一般化处理。

假设两个有序数组分别是 A 和 B。要找到第 k 个元素,我们可以比较 A[k/2-1]B[k/2−1],其中 / 表示整数除法,可能的比较结果如下:

  • A[k/2−1] < B[k/2−1]:A 中 A[k/2-1] 前面的数字一定比它小,而 B 中最多也只可能是 B[k/2−1] 前面的数比它小,即一共最多有 k - 2 个数比它小,那么 A[k/2-1] 以及它前面的数字就都不可能是答案,可以舍去
  • A[k/2−1] > B[k/2−1]:同上可排除 B[k/2−1] 和它前面的数字
  • A[k/2−1] = B[k/2−1]:舍去那一部分都可以,任选一个即可

然后,我们在对一个数组进行舍去操作后,原题的第 k 个数就不是新的数据里面的第 k 个数,必须减去已经舍去的数据的个数。

当然,数组的长度和 k 的值都在不断变化,所以必须注意各种可能发生的边界条件:

  • A[k/2−1]B[k/2−1] 越界:那么,就只能选择 A 或 B 的最后一个元素
  • 一个数组为空:直接再另一个数组中选择第 k 个元素即可
  • k 为 1:直接比较 A 和 B 的第一个元素,返回较小的那个即可

用一个例子说明上述算法。假设两个有序数组如下:

1
2
A: 1 3 4 9
B: 1 2 3 4 5 6 7 8 9

两个有序数组的长度分别是 4 和 9,长度之和是 13,中位数是两个有序数组中的第 7 个元素,因此 k = 7。

比较两个有序数组中下标为 k/2-1=2 的数,即 A[2]B[2],如下面所示:

1
2
3
4
A: 1 3 4 9
B: 1 2 3 4 5 6 7 8 9

由于 A[2] > B[2],因此排除 B[0]B[2],即数组 B 的下一次选取操作从 3 开始,同时更新 k 的值:k = k - k/2 = 4。

下一步寻找,比较两个有序数组中下标为 k/2 - 1 = 1 的数,即 A[1]B[4],如下面所示:

1
2
3
4
A: 1 3 4 9
B: [1 2 3] 4 5 6 7 8 9

其中方括号部分表示已经被排除的数,由于 A[1] < B[4],因此排除 A[0]A[1],即 A 的下一次选取操作从 2 开始,同时更新 k 的值:k = k − k/2 = 2。

下一步寻找,比较两个有序数组中下标为 k/2 − 1 = 0 的数,即比较 A[2]B[3],如下面所示,其中方括号部分表示已经被排除的数。

1
2
3
4
A: [1 3] 4 9
B: [1 2 3] 4 5 6 7 8 9

由于 A[2] = B[3],可以选择排除 A 中的元素,因此排除 A[2],即 A 的下一次选取操作从 3 开始,同时更新 k 的值: k = k - k/2 = 1。

由于 k 的值变成 1,因此比较两个有序数组中的未排除下标范围内的第一个数,其中较小的数即为第 k 个数,由于 A[3] > B[3],因此第 k 个数是 B[3] 即 4。

1
2
3
4
A: [1 3 4] 9
B: [1 2 3] 4 5 6 7 8 9

本算法每一轮循环可以将查找范围 k 减少一半,因此时间复杂度是 O(log(m+n))。

从另一个角度看,中位数把一组数据划分为两部分,一部分的所有数据都必定不大于另一部分,所以本题可以使用划分法解决。

一个无序数组中求第 k 个数的“快排划分法”,同样会使用了类似的二分法。

首先,我们不妨在任意位置 i 将数组 A 划分成两个部分,在任意位置 j 将 B 划分成两个部分,划分结果如下:

1
2
3
4
5
6
7
数组 A:
           left_A            |          right_A
    A[0], A[1], ..., A[i-1]  |  A[i], A[i+1], ..., A[m-1]

数组 B:
           left_B            |          right_B
    B[0], B[1], ..., B[j-1]  |  B[j], B[j+1], ..., B[n-1]

这时进行数据组合,假设将 left_A 和 left_B 放入一个集合,并将 right_A 和 right_B 放入另一个集合,结果如下所示:

1
2
3
          left_part          |         right_part
    A[0], A[1], ..., A[i-1]  |  A[i], A[i+1], ..., A[m-1]
    B[0], B[1], ..., B[j-1]  |  B[j], B[j+1], ..., B[n-1]

如果我们能够保证 left_part 的元素都小于 right_part 的元素,并且两部分元素个数相等,那么显然就找到了中位数。而其充要条件就是下面所有条件必须全部满足:

  • A[i-1] <= B[j]
  • B[i-1] <= A[j]
  • i + j = (m + n + 1) / 2

此处我们让 m + n 为奇数时 left_part 的元素多一个,反之亦可。

其实就是把问题转化为:在 A 中找到一个下标 i 使得 A[i-1] <= B[j] 成立,其中 j = (m + n + 1) / 2 - i,而找下标的过程即可使用二分法。一旦下标找到,那么只要根据总长度的奇偶选取不同的元素即可得到问题答案。

本解法可以在长度较短的数组上进行二分法,保证算法时间复杂度是 O(log(min(m, n))),优于第一种解法。

解法一(迭代法):

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class Solution {
    fun findMedianSortedArrays(nums1: IntArray, nums2: IntArray): Double {
        val len = nums1.size + nums2.size
        val left = (len + 1) / 2
        val right = (len + 2) / 2
        return if (len % 2 != 0)
            getKth(nums1, nums2, left).toDouble()
        else
            (getKth(nums1, nums2, left) +
                    getKth(nums1, nums2, right)) * 0.5
    }

    private fun getKth(nums1: IntArray, nums2: IntArray, k: Int): Int {
        var left1 = 0
        var left2 = 0
        val right1 = nums1.size
        val right2 = nums2.size
        var tmpK = k
        while (left1 < right1 || left2 < right2) {
            if (left1 == right1) return nums2[left2 + tmpK - 1]
            if (left2 == right2) return nums1[left1 + tmpK - 1]
            if (tmpK == 1) return minOf(nums1[left1], nums2[left2])
            val mid1 = minOf(left1 + tmpK / 2, right1) - 1
            val mid2 = minOf(left2 + tmpK / 2, right2) - 1
            if (nums1[mid1] > nums2[mid2]) {
                tmpK -= mid2 - left2 + 1
                left2 = mid2 + 1
            } else {
                tmpK -= mid1 - left1 + 1
                left1 = mid1 + 1
            }
        }
        return 0
    }
}

解法一(递归法):

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution {
    fun findMedianSortedArrays(nums1: IntArray, nums2: IntArray): Double {
        val len = nums1.size + nums2.size
        val left = (len + 1) / 2
        val right = (len + 2) / 2
        return if (len % 2 != 0)
            getKth(nums1, 0, nums1.size, nums2, 0, nums2.size, left).toDouble()
        else
            (getKth(nums1, 0, nums1.size, nums2, 0, nums2.size, left) +
                    getKth(nums1, 0, nums1.size, nums2, 0, nums2.size, right)) * 0.5
    }

    private fun getKth(nums1: IntArray, start1: Int, end1: Int, nums2: IntArray, start2: Int, end2: Int, k: Int): Int {
        val len1 = end1 - start1
        val len2 = end2 - start2
        if (len1 > len2) return getKth(nums2, start2, end2, nums1, start1, end1, k)
        if (len1 == 0) return nums2[start2 + k - 1]
        if (k == 1) return minOf(nums1[start1], nums2[start2])
        val mid1 = start1 + minOf(len1, k / 2) - 1
        val mid2 = start2 + minOf(len2, k / 2) - 1
        return if (nums1[mid1] > nums2[mid2])
            getKth(nums1, start1, end1, nums2, mid2 + 1, end2, k - minOf(len2, k / 2))
        else
            getKth(nums1, mid1 + 1, end1, nums2, start2, end2, k - minOf(len1, k / 2))
    }
}

注意:递归法第一次递归只是为了保证两个输入数组的长度关系一定,最多执行一次,而之后的两次递归都是尾递归,编译器可能进行优化,所以理论上空间复杂度会是 O(1)。

解法二:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution {
    fun findMedianSortedArrays(nums1: IntArray, nums2: IntArray): Double {
        if (nums1.size > nums2.size) return findMedianSortedArrays(nums2, nums1)

        val m = nums1.size
        val n = nums2.size
        val half = (m + n + 1) / 2
        var left = 0
        var right = m
        while (left < right) {
            val i = (left + right + 1) / 2
            val j = half - i
            if (nums1[i - 1] > nums2[j]) right = i - 1
            else left = i
        }
        val i = left
        val j = half - left
        val leftMax1 = if (i == 0) Int.MIN_VALUE else nums1[i - 1]
        val leftMax2 = if (j == 0) Int.MIN_VALUE else nums2[j - 1]
        val rightMin1 = if (i == m) Int.MAX_VALUE else nums1[i]
        val rightMin2 = if (j == n) Int.MAX_VALUE else nums2[j]

        return if ((n + m) % 2 == 1) maxOf(leftMax1, leftMax2).toDouble()
        else (maxOf(leftMax1, leftMax2) + minOf(rightMin1, rightMin2)).toDouble() / 2
    }
}