【算法】寻找两个正序数组的中位数

本文参考于:寻找两个有序数组的中位数

题目:

给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的中位数 。

LeetCode 4.

最简单的方法就是先将两个数组归并,然后根据m+n的奇偶性取新数组的中位数即可,时间复杂度和空间复杂度均为 O(m+n) 。也可以只归并到数组中间,然后取此时的一个或两个元素计算中位数即可。

归并法(这里是先全部归并)代码如下:

public class Solution {
    //寻找中位数
    public double FindMedianSortedArrays(int[] nums1, int[] nums2) {

        int len = nums1.Length + nums2.Length;
        int[] nums = new int[len];
        int i = 0, j = 0, k = 0;
        while(k < nums.Length)
        {
            if(i >= nums1.Length)
            {
                nums[k] = nums2[j++];
            }
            else if(j >= nums2.Length)
            {
                nums[k] = nums1[i++];
            }
            else if(nums1[i] <= nums2[j])
            {
                nums[k] = nums1[i++];
            }
            else
            {
                nums[k] = nums2[j++];
            }
            k++;
        }
        return len % 2 == 0 ? (nums[len / 2] + nums[len / 2 - 1]) / 2.0
        : nums[len / 2];
    }
}

在此方法基础上,可以通过双指针的方法降低空间复杂度,即分别用两个指针在这两个数组上进行滑动,来模拟归并后的数组的遍历,只需每次让指向较小元素的指针移动即可,每次指针移动前指向的元素的顺序即两个数组归并后从小到大的顺序。双指针方法的空间复杂度为 O(1)

如果想要降低时间复杂度,就必须得避免对两个数组的全部遍历,实际上如果我们能不全部遍历就解决 寻找两个正序数组的第k小的数(1\le k \le m+n)这个问题的话,那么求中位数就十分简单了。考虑用二分法取第k小的数:

假设两个数组分别为A和B,假如 k = 1,那么第k小的数就是

    \[ Min(A[0], B[0]) \]

k \ge 2 ,假如 A[ \frac{k}{2}-1 ] \le B[ \frac{k}{2}-1 ] ,由于A和B是升序排列的,我们有\forall 0 \le i \le \frac{k}{2} - 2

    \[ A[i] \le A[ \frac{k}{2}-1 ] \quad AND \]

    \[ B[i] \le B[ \frac{k}{2}-1 ] \]

即最多只有 (\frac{k}{2} - 1) + (\frac{k}{2} - 1) = k - 2  个在A或B内的元素小于 A[ \frac{k}{2}-1 ] 。那么 A[ \frac{k}{2}-1 ] 和它之前的所有A的元素都不可能是第k小的数,那么我们就可以排除掉A中 \frac{k}{2} 个元素。同理,若是 B[ \frac{k}{2}-1 ] \le A[ \frac{k}{2}-1 ] ,我们可以排除掉B中 \frac{k}{2} 个元素。

每次排除后,我们可以调整A或B的起始位置(两个数组的起始位置通过两个变量记录,初始值为0),以跳过被排除的元素,并将k减去排除的元素数量,那么问题就和上面一样了,只是起始位置和k变了。

如果A[ \frac{k}{2}-1 ]B[ \frac{k}{2}-1 ] 越界了,那么我们可以选取数组的最后一个元素,但是要注意的是这次排除的数量就不是 \frac{k}{2} 个了,需要另外计算。

最终我们有以下两种情况:

  • k为1,上面已经说明。
  • 一个数组“为空”(起始位置即数组末尾),那么只需取另一个数组的第k小的元素即可。

无论是哪种情况,第k小的数已经找到。

解决了这个问题后,取中位数就简单了:

  • 若m+n为偶数,那么中位数为第 \frac{m+n}{2}+1 小的数和第 \frac{m+n}{2} 小的数的平均值
  • 若m+n为奇数,那么中位数为第 \frac{m+n}{2} + 1 小的数(整除)

由于k的初始值为 \frac{m+n}{2}+1\frac{m+n}{2} ,每次排除都能减少一半范围,所以时间复杂度为O(log(m+n)),空间复杂度为 O(1)

C#代码如下:

public class Solution {
    // 寻找第k小的数
    public static int Find(int[] nums1, int[] nums2, int k)
    {
        //起始位置
        int index1 = 0, index2 = 0;
        while(true)
        {
            //A数组为空
            if(index1 == nums1.Length)
            {
                return nums2[index2 + k - 1];
            }
            //B数组为空
            if(index2 == nums2.Length)
            {
                return nums1[index1 + k - 1];
            }
            //k为1
            if(k == 1)
            {
                return Math.Min(nums1[index1], nums2[index2]);
            }
            //取元素的位置,如果没越界就是k/2 - 1,否则为数组最后一个元素
            int next1 = Math.Min(index1 + k / 2 - 1, nums1.Length - 1);
            int next2 = Math.Min(index2 + k / 2 - 1, nums2.Length - 1);
            //判断元素大小
            if(nums1[next1] < nums2[next2])
            {
                //改变起始位置,跳过排除元素,并减少k的值
                k -= next1- index1 + 1;
                index1 = next1 + 1;
            }
            else
            {
                k -= next2 - index2 + 1;
                index2 = next2 + 1;
            }
        }
    }

    //寻找中位数
    public double FindMedianSortedArrays(int[] nums1, int[] nums2) {
        int len = nums1.Length + nums2.Length;
        if(len % 2 == 0)
        {
            return (Find(nums1, nums2, len / 2) + 
            Find(nums1, nums2, len / 2 + 1)) / 2.0;
        }
        else
        {
            return Find(nums1, nums2, len / 2 + 1);
        }
    }
}

对于这个问题,有没有更快的办法呢?答案是有的,可以见文章开头引用的LeetCode官方题解的方法二,该方法的时间复杂度为log(min(m, n))

C#代码如下:

public double FindMedianSortedArrays(int[] nums1, int[] nums2) {
    //划分法 log(min(m, n))
    //确保 m <= n
    if(nums1.Length > nums2.Length)
        return FindMedianSortedArrays(nums2, nums1);

    int m = nums1.Length;
    int n = nums2.Length;
    int left = 0 , right = m;
    int lmax = 0, rmin = 0;
    while(left <= right)
    {
        int i = (left + right) / 2;
        int j = (m + n + 1) / 2 - i;
        int aim1 = i == 0 ? int.MinValue : nums1[i - 1];
        int bim1 = j == 0 ? int.MinValue : nums2[j - 1];
        int ai = i == m ? int.MaxValue : nums1[i];
        int bi = j == n ? int.MaxValue : nums2[j];
        if(aim1 <= bi)
        {
            lmax = Math.Max(aim1, bim1);
            rmin = Math.Min(ai, bi);
            left = i + 1;
        }
        else
        {
            right = i - 1;
        }
    }
    return (m + n) % 2 == 0 ? (lmax + rmin) / 2.0 : lmax;
}