Reference: LeetCode EPI 11.6
Difficulty: Medium

Problem

Given an n x n matrix where each row and each column is sorted in ascending order, find the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

Note: You may assume $k$ is always valid, $1 ≤ k ≤ n^2$.

Example:

matrix = [
[ 1, 5, 9],
[10, 11, 13],
[12, 13, 15]
],
k = 8,

return 13 (sorted order: 1, 5, 9, 10, 11, 12, 13, 13, 15; the 8th element is 13).
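As a reference point (not one of the approaches analyzed below), the definition can be checked directly by flattening and sorting. This is a minimal sketch; the class name `SortBaseline` is hypothetical. It keeps duplicates, so it returns the kth element in sorted order rather than the kth distinct value, at $O(n^2\log{n})$ time and $O(n^2)$ space.

```java
import java.util.Arrays;

// Hypothetical baseline: flatten all n^2 values, sort, and index k - 1 (k is 1-based).
class SortBaseline {
    static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        int[] all = new int[n * n];
        for (int i = 0; i < n; ++i) {
            // copy row i into its slot of the flat array
            System.arraycopy(matrix[i], 0, all, i * n, n);
        }
        Arrays.sort(all); // duplicates are kept, e.g. both 13s
        return all[k - 1];
    }

    public static void main(String[] args) {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println(kthSmallest(matrix, 8)); // prints 13
        System.out.println(kthSmallest(matrix, 7)); // prints 13
    }
}
```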

Analysis

Brute-Force

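import java.util.Collections;
import java.util.PriorityQueue;

// Brute-Force: keep the k smallest values seen so far in a max-heap of size k
public int kthSmallest(int[][] matrix, int k) {
    int n = matrix.length;
    // reverseOrder() gives a max-heap; the lambda (n2 - n1) can overflow for large values
    PriorityQueue<Integer> pq = new PriorityQueue<>(Collections.reverseOrder());
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            pq.add(matrix[i][j]);
            if (pq.size() > k) {
                pq.poll(); // evict the largest, so pq holds the k smallest values
            }
        }
    }
    return pq.peek(); // the largest of the k smallest values is the kth smallest
}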

Time: $O(n^2\log{k})$
Space: $O(k)$

Row Heap

  • Add the first row to a min-heap (priority queue). Poll k times; the last value polled is the kth smallest.
  • Each time we poll a value, offer the next value in its column (the element directly below it), if there is one.

Note: We can also start from the first column and advance along the rows instead.
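The column variant mentioned in the note can be sketched as follows; it is the mirror image of the row heap, with the heap seeded by the first column and each poll offering the element to the right. The class name `ColumnHeap` is hypothetical, and `Integer.compare` is used instead of subtraction to avoid int overflow.

```java
import java.util.PriorityQueue;

// Hypothetical column-seeded variant of the row heap.
class ColumnHeap {
    static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        // Heap entries are {row, col}, ordered by the value they point at.
        PriorityQueue<int[]> minPQ = new PriorityQueue<>(
                (a, b) -> Integer.compare(matrix[a[0]][a[1]], matrix[b[0]][b[1]]));
        for (int row = 0; row < n; ++row) {
            minPQ.add(new int[] {row, 0}); // seed with the first column
        }
        int kthNum = matrix[0][0];
        for (int i = 0; i < k; ++i) {
            int[] cell = minPQ.poll();
            kthNum = matrix[cell[0]][cell[1]];
            // offer the next value to the right in the same row, if any
            if (cell[1] + 1 < n) {
                minPQ.add(new int[] {cell[0], cell[1] + 1});
            }
        }
        return kthNum; // the kth value polled
    }
}
```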

import java.util.Comparator;
import java.util.PriorityQueue;

// PQ (min-heap)
public int kthSmallest(int[][] matrix, int k) {
    int n = matrix.length;
    // Heap entries are {row, col}, ordered by the value they point at.
    // Integer.compare avoids the overflow that subtracting two ints can cause.
    Comparator<int[]> comp = (a, b) -> Integer.compare(matrix[a[0]][a[1]], matrix[b[0]][b[1]]);
    PriorityQueue<int[]> minPQ = new PriorityQueue<>(comp);
    // add first row
    for (int col = 0; col < n; ++col) {
        minPQ.add(new int[] { 0, col });
    }

    // poll k times; the kth value polled is the kth smallest
    int kthNum = -1;
    while (k > 0) {
        int[] index = minPQ.poll();
        kthNum = matrix[index[0]][index[1]];
        // add the next value in the same column, if any
        int nextRow = index[0] + 1;
        if (nextRow < n) {
            minPQ.add(new int[] { nextRow, index[1] });
        }
        k -= 1;
    }

    return kthNum;
}

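Time: $O((n + k)\log{n})$: the heap never holds more than $n$ entries (one per column), and we do $n$ initial insertions plus $k$ polls and at most $k$ offers, each $O(\log{n})$.
Space: $O(n)$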

With bottom-up heapification, building the initial heap of $n$ entries takes $O(n)$ instead of $O(n\log{n})$, so the total becomes $O(n + k\log{n})$.
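`java.util.PriorityQueue` has no constructor that takes both a collection and a comparator, so a bottom-up build needs a hand-rolled heap. The sketch below (class and method names are hypothetical) stores `{row, col}` cells in an array, heapifies the first row from index `n / 2 - 1` down to 0 in $O(n)$, and then combines each poll and offer into a single sift-down of the root. Because the first row is already sorted, it happens to be a valid min-heap, so here the heapify loop performs no swaps; it is kept to show the general $O(n)$ step.

```java
// Hypothetical array-based min-heap of {row, col} cells, ordered by matrix[row][col].
class HeapifyRowHeap {
    private final int[][] matrix;
    private final int[][] heap;
    private int size;

    private HeapifyRowHeap(int[][] matrix) {
        this.matrix = matrix;
        this.heap = new int[matrix.length][];
    }

    private int valueAt(int i) {
        return matrix[heap[i][0]][heap[i][1]];
    }

    // Move heap[i] down until neither child is smaller: O(log n).
    private void siftDown(int i) {
        while (true) {
            int smallest = i;
            int left = 2 * i + 1, right = 2 * i + 2;
            if (left < size && valueAt(left) < valueAt(smallest)) smallest = left;
            if (right < size && valueAt(right) < valueAt(smallest)) smallest = right;
            if (smallest == i) return;
            int[] tmp = heap[i];
            heap[i] = heap[smallest];
            heap[smallest] = tmp;
            i = smallest;
        }
    }

    static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        HeapifyRowHeap h = new HeapifyRowHeap(matrix);
        for (int col = 0; col < n; ++col) {
            h.heap[col] = new int[] {0, col};
        }
        h.size = n;
        // Bottom-up heapify: O(n) in total.
        for (int i = n / 2 - 1; i >= 0; --i) {
            h.siftDown(i);
        }
        int kthNum = matrix[0][0];
        for (int step = 0; step < k; ++step) {
            int[] top = h.heap[0];
            kthNum = matrix[top[0]][top[1]];
            if (top[0] + 1 < n) {
                // Replace the root with the cell below it (poll + offer in one sift).
                h.heap[0] = new int[] {top[0] + 1, top[1]};
            } else {
                // Column exhausted: shrink the heap by moving the last cell to the root.
                h.size -= 1;
                h.heap[0] = h.heap[h.size];
            }
            h.siftDown(0);
        }
        return kthNum;
    }
}
```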

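If $k < n$, we can prune: the kth smallest value always lies in the top-left $k \times k$ block (every element in column $k$ or beyond has at least $k$ entries to its left in the same row that are no larger than it, and likewise for rows), so only the first $k$ entries of the first row need to go into the heap, and the heap then holds at most $k$ entries.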
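A minimal sketch of that pruning, assuming the same `{row, col}` heap as the row-heap code: both the seeded columns and the rows we advance into are capped at `m = min(n, k)`, so the heap holds at most `m` entries and the work drops to $O((m + k)\log{m})$. The class name `PrunedRowHeap` is hypothetical.

```java
import java.util.PriorityQueue;

// Hypothetical row heap restricted to the top-left m x m block, m = min(n, k).
class PrunedRowHeap {
    static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        // Only the top-left m x m block can hold the kth smallest value.
        int m = Math.min(n, k);
        PriorityQueue<int[]> minPQ = new PriorityQueue<>(
                (a, b) -> Integer.compare(matrix[a[0]][a[1]], matrix[b[0]][b[1]]));
        for (int col = 0; col < m; ++col) {
            minPQ.add(new int[] {0, col}); // seed only the first m columns
        }
        int kthNum = matrix[0][0];
        for (int i = 0; i < k; ++i) {
            int[] cell = minPQ.poll();
            kthNum = matrix[cell[0]][cell[1]];
            // advance down the column, but never past row m - 1
            if (cell[0] + 1 < m) {
                minPQ.add(new int[] {cell[0] + 1, cell[1]});
            }
        }
        return kthNum;
    }
}
```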

Binary Search

Reference:

The key point for any binary search is to figure out the search space. There are two kinds of search space: index and range (values). Usually, when the array is sorted in one direction, we can use the index as the search space; otherwise, we use the range of values as the search space.

In this case, we cannot use the index as the search space: the matrix is sorted in two directions, so there is no linear way to map a value to its position in sorted order.
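To see why searching the value range works even without a usable index order, here is a sketch on a plain unsorted `int[]` (the class name `ValueRangeSearch` is hypothetical). It looks for the smallest value `v` with at least `k` elements `<= v`; that value is always an element of the array, because the count only changes at element values. It uses the `lo < hi` / `hi = mid` loop form, which returns the same value as a `lo <= hi` / `hi = mid - 1` loop, and computes `mid` in `long` so that any pair of int bounds is safe.

```java
// Hypothetical sketch: kth smallest in an UNSORTED array by binary search over values.
class ValueRangeSearch {
    static int kthSmallest(int[] nums, int k) {
        int lo = Integer.MAX_VALUE, hi = Integer.MIN_VALUE;
        for (int x : nums) {
            lo = Math.min(lo, x);
            hi = Math.max(hi, x);
        }
        // Invariant: the answer lies in [lo, hi].
        while (lo < hi) {
            int mid = (int) (lo + ((long) hi - lo) / 2); // long math avoids overflow
            int count = 0;
            for (int x : nums) {
                if (x <= mid) count++; // O(len) count, no ordering needed
            }
            if (count >= k) {
                hi = mid;     // mid may be the answer
            } else {
                lo = mid + 1; // too few values <= mid
            }
        }
        return lo;
    }
}
```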

public int kthSmallest(int[][] matrix, int k) {
    int n = matrix.length;
    int lo = matrix[0][0];         // min value
    int hi = matrix[n - 1][n - 1]; // max value

    // Find the smallest value v in [lo, hi] with count(v) >= k; that v is the answer.
    while (lo <= hi) {
        int mid = lo + (hi - lo) / 2;
        int count = count(matrix, mid); // count values that are <= mid
        if (count >= k) {
            hi = mid - 1;
        } else {
            lo = mid + 1;
        }
    }
    // e.g. k = 5: the last iteration has lo = hi = 11, then hi drops to 10 and 11 is returned
    return lo;
}
/*
[ 1,  5,  9],
[10, 11, 13],
[12, 13, 15]
*/
// count the number of values that are <= mid
private int count(int[][] matrix, int mid) {
    int n = matrix.length;
    int count = 0;
    int j = n - 1;
    for (int i = 0; i < n; ++i) { // one pass over the rows
        // Move j left past the values in row i that are > mid; afterwards
        // j + 1 values of row i are <= mid. Columns are sorted, so row i + 1's
        // boundary is never to the right of row i's: j carries over and only
        // moves left, at most n times in total across all rows.
        while (j >= 0 && matrix[i][j] > mid) {
            --j;
        }
        count += (j + 1);
    }
    return count;
}

Time: $O(n\log{N})$

  • $N$ is the size of the search space, i.e. the range of values from the smallest element matrix[0][0] to the largest matrix[n - 1][n - 1], so the binary search runs $O(\log{N})$ iterations.
  • Each count takes $O(n)$ regardless of mid: i advances once per row, and j only moves left, so the inner while loop runs at most $n$ times in total across all rows.

Space: $O(1)$
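The per-call cost of count can be checked by instrumenting it. The sketch below (the class name `StaircaseSteps` and the `{count, steps}` return shape are hypothetical) mirrors the count above and records one step per row plus one per left move of `j`. Because `j` is never reset between rows, `steps` stays at most $2n$ for every `mid`.

```java
// Hypothetical instrumented copy of the row-by-row count; returns {count, steps}.
class StaircaseSteps {
    static int[] countWithSteps(int[][] matrix, int mid) {
        int n = matrix.length;
        int count = 0, steps = 0;
        int j = n - 1; // shared across rows, only ever decremented
        for (int i = 0; i < n; ++i) {
            steps++; // one step per row
            while (j >= 0 && matrix[i][j] > mid) {
                --j;
                steps++; // one step per left move
            }
            count += j + 1;
        }
        return new int[] {count, steps};
    }
}
```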

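Alternative count functions:

The count can equivalently be written as a single two-pointer walk from a corner of the matrix. Each step moves either one row or one column, so a count takes at most $2n$ steps and the overall bound stays the same.

Time: $O(n\log{N})$
Space: $O(1)$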

Count row by row, starting from the top-right corner:

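// staircase walk from the top-right corner (a two-pointer walk, not a binary search)
private int count(int[][] matrix, int mid) {
    int n = matrix.length;
    int r = 0, c = n - 1;
    int count = 0;
    while (r < n && c >= 0) {
        if (mid >= matrix[r][c]) {
            // matrix[r][0..c] are all <= mid: count them and move down a row
            count += c + 1;
            r += 1;
        } else {
            // matrix[r..n-1][c] are all > mid: drop this column
            c -= 1;
        }
    }
    return count;
}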
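For contrast, a count that binary searches each row for its upper bound (the first index whose value is `> mid`) uses only row sortedness and costs $O(n\log{n})$ per call, slower than the $O(n)$ corner walk. This is a sketch with a hypothetical class name, `RowBinarySearchCount`.

```java
// Hypothetical per-row binary-search count: O(n log n) per call.
class RowBinarySearchCount {
    static int count(int[][] matrix, int mid) {
        int n = matrix.length;
        int count = 0;
        for (int[] row : matrix) {
            int lo = 0, hi = n; // search for the upper bound in [lo, hi)
            while (lo < hi) {
                int m = lo + (hi - lo) / 2;
                if (row[m] <= mid) {
                    lo = m + 1; // row[0..m] are all <= mid
                } else {
                    hi = m;     // row[m] > mid, so the bound is at or left of m
                }
            }
            count += lo; // lo = number of values in this row that are <= mid
        }
        return count;
    }
}
```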

Count column by column, starting from the bottom-left corner:

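// staircase walk from the bottom-left corner
private int count(int[][] matrix, int mid) {
    int n = matrix.length;
    int r = n - 1, c = 0;
    int count = 0;
    while (r >= 0 && c < n) {
        if (mid >= matrix[r][c]) {
            // matrix[0..r][c] are all <= mid: count them and move right a column
            c += 1;
            count += r + 1;
        } else {
            // matrix[r][c..n-1] are all > mid: drop this row
            r -= 1;
        }
    }
    return count;
}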
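A small self-contained harness ties the pieces together: it runs the value-range binary search with the top-right count and keeps the bottom-left count and an $O(n^2)$ brute-force count for cross-checking. The class name `KthSmallestBinarySearch` is hypothetical, and the `mid` computation assumes `hi - lo` fits in an int (true when all values lie within $\pm 10^9$).

```java
// Hypothetical harness: binary search over values plus both corner-walk counts.
class KthSmallestBinarySearch {
    static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        int lo = matrix[0][0], hi = matrix[n - 1][n - 1];
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2; // assumes hi - lo fits in an int
            if (countTopRight(matrix, mid) >= k) {
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        return lo; // smallest value whose count reaches k
    }

    // Walk from the top-right corner: down when matrix[r][c] <= mid, left otherwise.
    static int countTopRight(int[][] matrix, int mid) {
        int n = matrix.length, r = 0, c = n - 1, count = 0;
        while (r < n && c >= 0) {
            if (mid >= matrix[r][c]) { count += c + 1; r++; } else { c--; }
        }
        return count;
    }

    // Walk from the bottom-left corner: right when matrix[r][c] <= mid, up otherwise.
    static int countBottomLeft(int[][] matrix, int mid) {
        int n = matrix.length, r = n - 1, c = 0, count = 0;
        while (r >= 0 && c < n) {
            if (mid >= matrix[r][c]) { count += r + 1; c++; } else { r--; }
        }
        return count;
    }

    // O(n^2) reference count, used only for cross-checking.
    static int countBrute(int[][] matrix, int mid) {
        int count = 0;
        for (int[] row : matrix) {
            for (int x : row) {
                if (x <= mid) count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        StringBuilder sb = new StringBuilder();
        for (int k = 1; k <= 9; ++k) {
            sb.append(kthSmallest(matrix, k)).append(' ');
        }
        System.out.println(sb.toString().trim()); // 1 5 9 10 11 12 13 13 15
    }
}
```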