Reference: LeetCode EPI 14.3
Difficulty: Medium

Problem

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note: You may assume $k$ is always valid, i.e. $1 \le k \le$ the number of elements in the BST.

Example:

Input: root = [3,1,4,null,2], k = 1

```
   3
  / \
 1   4
  \
   2
```

Output: 1
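The input notation [3,1,4,null,2] lists nodes level by level, with null marking a missing child. To run the solutions locally, a small builder is handy. This is a minimal sketch: the nested TreeNode stands in for LeetCode's class, and LevelOrderBuilder.build is a hypothetical helper name.

```java
import java.util.ArrayDeque;
import java.util.Queue;

// Hypothetical helper: builds a tree from LeetCode-style level-order input,
// where null marks a missing child.
public class LevelOrderBuilder {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    public static TreeNode build(Integer[] values) {
        if (values.length == 0 || values[0] == null) return null;
        TreeNode root = new TreeNode(values[0]);
        Queue<TreeNode> queue = new ArrayDeque<>();  // nodes still awaiting children
        queue.add(root);
        int i = 1;
        while (!queue.isEmpty() && i < values.length) {
            TreeNode node = queue.poll();
            if (values[i] != null) {                       // left child slot
                node.left = new TreeNode(values[i]);
                queue.add(node.left);
            }
            i++;
            if (i < values.length && values[i] != null) {  // right child slot
                node.right = new TreeNode(values[i]);
                queue.add(node.right);
            }
            i++;
        }
        return root;
    }

    // In-order walk; for a valid BST this lists the keys in ascending order.
    static void inorder(TreeNode x, StringBuilder out) {
        if (x == null) return;
        inorder(x.left, out);
        out.append(x.val).append(' ');
        inorder(x.right, out);
    }

    public static void main(String[] args) {
        TreeNode root = build(new Integer[] {3, 1, 4, null, 2});
        StringBuilder out = new StringBuilder();
        inorder(root, out);
        System.out.println(out.toString().trim());  // 1 2 3 4
    }
}
```

The in-order listing 1 2 3 4 confirms the shape: the kth value in that listing is exactly what kthSmallest must return.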

Follow up: What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

Evercode: I think we can keep both the kth smallest element and the (k-1)th smallest element. If we insert or delete an element larger than the kth smallest element, the result is unaffected. If an element smaller than the kth smallest is inserted, compare it with the (k-1)th smallest element: the larger of the two becomes the new kth smallest element, and the (k-1)th element is adjusted accordingly.

Or maybe you can add another attribute, such as size, to each tree node, recording how many nodes are in the subtree rooted at that node?


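```java
// Assumes each TreeNode has a `size` field: the number of nodes
// in the subtree rooted at that node (including the node itself).
private int size(TreeNode x) {
    return x == null ? 0 : x.size;
}

// Returns the node holding the kth smallest key (k is 1-based).
public TreeNode kthSmallest(TreeNode root, int k) {
    if (root == null) return null;

    int t = size(root.left);          // number of nodes in the left subtree

    if (k <= t) {                     // answer lies in the left subtree
        return kthSmallest(root.left, k);
    } else if (k == t + 1) {          // root itself is the kth smallest
        return root;
    } else {                          // skip the left subtree and root, go right
        return kthSmallest(root.right, k - t - 1);
    }
}
```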
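For the follow-up, the size-based lookup only pays off if every update keeps size correct. Below is a minimal sketch, assuming a hypothetical SizedBST class (duplicates go to the right): insert recomputes size on the way back up the insertion path, and kthSmallest walks down iteratively in $O(H)$ time. Deletion (not shown) would refresh size along its path in the same way.

```java
// Hypothetical size-augmented BST; each node stores the size of its subtree.
public class SizedBST {
    static class TreeNode {
        int val;
        int size = 1;                // a new node is a subtree of one
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    private TreeNode root;

    private static int size(TreeNode x) {
        return x == null ? 0 : x.size;
    }

    public void insert(int val) {
        root = insert(root, val);
    }

    // Recursive insert; sizes are recomputed on every node of the path.
    private static TreeNode insert(TreeNode x, int val) {
        if (x == null) return new TreeNode(val);
        if (val < x.val) x.left = insert(x.left, val);
        else x.right = insert(x.right, val);
        x.size = 1 + size(x.left) + size(x.right);
        return x;
    }

    // O(H) lookup using left-subtree sizes; k is 1-based.
    public int kthSmallest(int k) {
        TreeNode x = root;
        while (x != null) {
            int t = size(x.left);
            if (k <= t) {
                x = x.left;
            } else if (k == t + 1) {
                return x.val;
            } else {
                k -= t + 1;          // skip the left subtree and x itself
                x = x.right;
            }
        }
        throw new IllegalArgumentException("k out of range");
    }

    public static void main(String[] args) {
        SizedBST bst = new SizedBST();
        for (int v : new int[] {3, 1, 4, 2}) bst.insert(v);
        System.out.println(bst.kthSmallest(1));  // 1
        bst.insert(0);
        System.out.println(bst.kthSmallest(1));  // 0
        System.out.println(bst.kthSmallest(3));  // 2
    }
}
```

Each query then costs $O(H)$ regardless of k, at the price of one extra int per node and an $O(H)$ size refresh on every update.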

Analysis

Methods ($H$ is the height of the tree):

  1. Recursion
    • Use a field variable to keep track of the count and result.
    • Or pass an array through the recursion to keep track of the result.
    • Time: $O(H + k)$
    • Space: $O(H)$ for the recursion stack
  2. Iteration
    • With the help of a stack, we can convert the recursive version into an iterative one.
    • The stack holds the nodes on the current path that are still waiting to be visited.
    • Time: $O(H + k)$
    • Space: $O(H)$ for the stack
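The array-based option for the recursion can be sketched as follows. It assumes a plain TreeNode with val, left and right; packing the remaining count and the answer into a two-element int[] is an illustrative choice, not the only possible layout.

```java
public class KthSmallestArrayState {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val; this.left = left; this.right = right;
        }
    }

    public static int kthSmallest(TreeNode root, int k) {
        // state[0]: visits remaining until the answer; state[1]: the answer
        int[] state = {k, 0};
        inorder(root, state);
        return state[1];
    }

    private static void inorder(TreeNode x, int[] state) {
        if (x == null || state[0] == 0) return;  // empty subtree, or already found
        inorder(x.left, state);
        if (state[0] == 0) return;               // found inside the left subtree
        if (--state[0] == 0) {                   // x is the kth visited node
            state[1] = x.val;
            return;
        }
        inorder(x.right, state);
    }

    // The example tree [3,1,4,null,2].
    static TreeNode exampleTree() {
        return new TreeNode(3, new TreeNode(1, null, new TreeNode(2)), new TreeNode(4));
    }

    public static void main(String[] args) {
        TreeNode root = exampleTree();
        System.out.println(kthSmallest(root, 1));  // 1
        System.out.println(kthSmallest(root, 3));  // 3
    }
}
```

Because the state lives in a local array rather than in fields, the method is static and safe to call repeatedly without resetting anything.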

Code

Recursion

Note:

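```java
private int count;   // nodes visited so far, in inorder
private int result;  // value of the most recently visited node

public int kthSmallest(TreeNode root, int k) {
    count = 0;
    inorder(root, k);
    return result;
}

// In-order traversal that stops as soon as k nodes have been visited.
private void inorder(TreeNode x, int k) {
    if (count == k) return;  // answer already found
    if (x == null) return;

    inorder(x.left, k);
    if (count < k) {
        result = x.val;      // x is the (count + 1)th smallest
        ++count;
        inorder(x.right, k);
    }
}
```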

An alternative is to store the visited elements in a list, as in the EPI solution.
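A sketch of the list-based variant, reconstructed from the description rather than copied from EPI: collect values in inorder until the list holds k of them, then return the last one. The nested TreeNode class is an assumption. The list adds $O(k)$ space on top of the $O(H)$ recursion stack.

```java
import java.util.ArrayList;
import java.util.List;

public class KthSmallestList {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val; this.left = left; this.right = right;
        }
    }

    public static int kthSmallest(TreeNode root, int k) {
        List<Integer> seen = new ArrayList<>();
        collect(root, k, seen);
        return seen.get(k - 1);  // k is assumed valid
    }

    // In-order traversal that stops once k values have been collected.
    private static void collect(TreeNode x, int k, List<Integer> seen) {
        if (x == null || seen.size() >= k) return;
        collect(x.left, k, seen);
        if (seen.size() < k) {
            seen.add(x.val);
            collect(x.right, k, seen);
        }
    }

    // The example tree [3,1,4,null,2].
    static TreeNode exampleTree() {
        return new TreeNode(3, new TreeNode(1, null, new TreeNode(2)), new TreeNode(4));
    }

    public static void main(String[] args) {
        System.out.println(kthSmallest(exampleTree(), 2));  // 2
    }
}
```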

Iteration

Note:

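```java
import java.util.ArrayDeque;
import java.util.Deque;

public int kthSmallest(TreeNode root, int k) {
    int count = 0;
    int result = Integer.MIN_VALUE;  // placeholder; k is assumed valid
    Deque<TreeNode> stack = new ArrayDeque<>();
    TreeNode p = root;
    while (p != null || !stack.isEmpty()) {
        // Push the whole left spine of the current subtree.
        while (p != null) {
            stack.push(p);
            p = p.left;
        }
        p = stack.pop();             // next node in ascending order
        if (count == k) break;       // the kth node was visited in the previous round
        result = p.val;
        ++count;

        p = p.right;                 // may be null
    }
    return result;
}
```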

Improvement:

  • Count k down at each visit and return as soon as it reaches 0, so neither a separate counter nor a result variable is needed.
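```java
import java.util.ArrayDeque;
import java.util.Deque;

public int kthSmallest(TreeNode root, int k) {
    Deque<TreeNode> stack = new ArrayDeque<>();
    TreeNode p = root;
    while (p != null || !stack.isEmpty()) {
        while (p != null) {          // descend to the leftmost unvisited node
            stack.push(p);
            p = p.left;
        }
        p = stack.pop();             // visit nodes in ascending order
        if (--k == 0) return p.val;  // the kth visit is the answer
        p = p.right;                 // may be null
    }
    return -1;                       // unreachable when k is valid
}
```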