来源:力扣(LeetCode)
链接:https://leetcode.cn/problems/median-of-two-sorted-arrays

Description

给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数 。

算法的时间复杂度应该为 O(log (m+n)) 。

Input

两个 vector<int>,表示两个正序的数组

Output

两个数组合并后的中位数(double 类型)

Sample Input

1
2
示例1: nums1=[1,3], nums2=[2]
示例2: nums1=[1,2], nums2=[3,4]

Sample Output

1
2
示例1: 2
示例2: 2.5

Hint

nums1.length == m
nums2.length == n
0 <= m <= 1000
0 <= n <= 1000
1 <= m + n <= 2000
$-10^6\leq nums1[i],\;nums2[i]\leq 10^6$

分析

做完这个题才知道自己OI水平下滑确实厉害,一开始居然没想到用二分。。。

首先说暴力做法(时间复杂度$O(m+n)$)

首先排除合并后排序然后直接输出中位数的原始做法,因为两个数组都是正序的,所以可以想到设两个指针分别指向两个数组的头部,随指针的移动筛选出较小的数字,两个指针一共移动 $\frac{m+n}{2}$ 次的时候就可以结束了,此时指针指向的数字就是中位数

然后是满足题目条件的二分做法

把原题推广至求合并后数组第 $k$ 小的数,那么:

由于两个数组(设为 $A$,$B$)均正序,在一次二分中,考虑比较 $A[\frac{k}{2}]$ 和$B[\frac{k}{2}]$ 的大小,若 $A[\frac{k}{2}]\geq B[\frac{k}{2}]$,则说明合并之后的第 $k$ 小不可能是 $B[\frac{k}{2}]$ 及其之前的数字中某一个($B[\frac{k}{2}]$ 最大也只能是第 $k-1$ 小),所以把这部分数字(设有 $d$ 个)剔除,剩下的数中寻找第 $k-d$ 小,如此重复直到某一个数组为空,或者 $k=1$ 即可直接判断中位数

求中位数,也就是 $k=\frac{m+n}{2}$,时间复杂度理论 $O(log_2(m+n))$

说实话数据量太小体现不出多大差异,要是像OI那种动辄上亿的量级就能看出来了

Codes

暴力解法(24ms)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Solution {
public:
double findMedianSortedArrays(std::vector<int>& nums1,
std::vector<int>& nums2) {
int total = nums1.size() + nums2.size();
int result, result_2, mid;
mid = total & 1 ? (total + 1) >> 1 : total >> 1;
mid--;
if (nums1.empty())
return total & 1 ? (double)nums2[mid]
: (double)(nums2[mid] + nums2[mid + 1]) / 2.0;
if (nums2.empty())
return total & 1 ? (double)nums1[mid]
: (double)(nums1[mid] + nums1[mid + 1]) / 2.0;
std::vector<int>::iterator it1 = nums1.begin();
std::vector<int>::iterator it2 = nums2.begin();
for (int i = 0; i <= mid; i++) {
if (it1 == nums1.end())
result = *it2++;
else if (it2 == nums2.end())
result = *it1++;
else if (*it1 <= *it2)
result = *it1++;
else
result = *it2++;
}
if (total & 1)
return (double)result;
else {
if (it1 == nums1.end())
result_2 = *it2++;
else if (it2 == nums2.end())
result_2 = *it1++;
else if (*it1 <= *it2)
result_2 = *it1++;
else
result_2 = *it2++;
return (double)(result + result_2) / 2.0;
}
}
};

二分解法(16ms)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution {
public:
int findKth(std::vector<int>& arr1, std::vector<int>& arr2, int st1, int st2,
int len1, int len2, int k) {
if (st1 > len1) return arr2[st2 + k - 2];
if (st2 > len2) return arr1[st1 + k - 2];
if (k == 1) return std::min(arr1[st1 - 1], arr2[st2 - 1]);
int next1 = std::min(len1, st1 + (k >> 1) - 1);
int next2 = std::min(len2, st2 + (k >> 1) - 1);
if (arr1[next1 - 1] >= arr2[next2 - 1])
return findKth(arr1, arr2, st1, next2 + 1, len1, len2,
k - (next2 - st2 + 1));
else
return findKth(arr1, arr2, next1 + 1, st2, len1, len2,
k - (next1 - st1 + 1));
}

double findMedianSortedArrays(std::vector<int>& nums1,
std::vector<int>& nums2) {
int len1 = nums1.size(), len2 = nums2.size();
int total = len1 + len2;
if (total & 1) {
int mid = (total + 1) >> 1;
if (nums1.empty()) return (double)nums2[mid - 1];
if (nums2.empty()) return (double)nums1[mid - 1];
return (double)findKth(nums1, nums2, 1, 1, len1, len2,
(total + 1) >> 1);
} else {
int mid = total >> 1;
if (nums1.empty()) return (double)(nums2[mid - 1] + nums2[mid]) / 2.0;
if (nums2.empty()) return (double)(nums1[mid - 1] + nums1[mid]) / 2.0;
return (double)(findKth(nums1, nums2, 1, 1, len1, len2, mid) +
findKth(nums1, nums2, 1, 1, len1, len2, mid + 1)) /
2.0;
}
}
};