Given the root of a binary search tree and an integer k, return the kth smallest value (1-indexed) among all node values in the BST. It is guaranteed that k is between 1 and the number of nodes in the tree. The tree can have between 1 and 10,000 nodes.
Example
      5
     / \
    3   6
   / \
  2   4
 /
1
k = 3
Output: 3
The values in sorted order are [1, 2, 3, 4, 5, 6]. The 3rd smallest is 3.
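For comparison, here is a minimal Python sketch of the naive approach, which ignores the BST property entirely: gather every value in any order, sort, and index at k - 1. The helper names `build_example` and `kth_smallest_by_sorting` are hypothetical and exist only for this illustration. It costs O(n log n) time and O(n) space no matter how small k is, which is the cost the inorder approach avoids.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_example() -> TreeNode:
    # The example tree: 5 at the root, 3 and 6 below it, then 2 and 4, then 1.
    return TreeNode(5,
                    TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)),
                    TreeNode(6))


def kth_smallest_by_sorting(root: Optional[TreeNode], k: int) -> int:
    # Collect values breadth-first; the BST ordering is not used at all.
    values: List[int] = []
    queue = deque([root] if root else [])
    while queue:
        node = queue.popleft()
        values.append(node.val)
        for child in (node.left, node.right):
            if child:
                queue.append(child)
    # Sort everything, then pick the kth smallest (k is 1-indexed).
    return sorted(values)[k - 1]


print(kth_smallest_by_sorting(build_example(), 3))  # 3
```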
An inorder traversal of a BST visits nodes in ascending sorted order. Perform an iterative inorder traversal using a stack, and count each node you visit. When you've visited k nodes, you've found the kth smallest element.
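One way to express "count visits and stop at k" in Python is to write the stack-based inorder traversal as a generator and take its kth value with `itertools.islice`. This is a sketch of an alternative factoring, not the solution given below; `inorder_values` and `kth_smallest_lazy` are hypothetical names. Because the generator is lazy, traversal stops as soon as the kth value has been produced.

```python
from itertools import islice
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def inorder_values(root: Optional[TreeNode]) -> Iterator[int]:
    # Stack-based inorder traversal that yields one value per visit.
    stack = []
    current = root
    while current or stack:
        while current:              # go as far left as possible
            stack.append(current)
            current = current.left
        current = stack.pop()       # next node in sorted order
        yield current.val
        current = current.right     # then its right subtree


def kth_smallest_lazy(root: Optional[TreeNode], k: int) -> int:
    # Skip the first k - 1 values and return the kth; the generator is
    # never resumed after that, so the rest of the tree is not visited.
    return next(islice(inorder_values(root), k - 1, None))


example = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(list(inorder_values(example)))   # [1, 2, 3, 4, 5, 6]
print(kth_smallest_lazy(example, 3))   # 3
```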
Why this works
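Inorder traversal visits the left subtree, then the root, then the right subtree. In a BST every value in the left subtree is smaller than the root and every value in the right subtree is larger, so applying this rule recursively produces the values in ascending order. The 1st node visited is therefore the smallest, the 2nd is the next smallest, and so on. By counting visits and stopping at k, you find the kth smallest without sorting anything.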
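The same left, root, right order can be written recursively. In this sketch (the name `kth_smallest_recursive` is hypothetical), a shared counter is decremented at each root visit and the recursion unwinds as soon as it reaches zero. Its call stack grows with the tree height, so a chain-shaped tree with 10,000 nodes would exceed Python's default recursion limit of 1000; the explicit stack in the iterative version avoids that.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_recursive(root: Optional[TreeNode], k: int) -> Optional[int]:
    remaining = k                   # visits left before reaching the kth smallest
    answer: Optional[int] = None

    def visit(node: Optional[TreeNode]) -> bool:
        # Returns True once the answer is found so every caller stops early.
        nonlocal remaining, answer
        if node is None:
            return False
        if visit(node.left):        # left: all values smaller than node.val
            return True
        remaining -= 1              # root: the next value in sorted order
        if remaining == 0:
            answer = node.val
            return True
        return visit(node.right)    # right: all values larger than node.val

    visit(root)
    return answer                   # None only if k exceeds the number of nodes


example = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(kth_smallest_recursive(example, 3))  # 3
```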
Step by step
- Push left: starting from the current node (initially the root), keep moving left and pushing each node onto the stack until you hit null. The top of the stack is now the smallest unvisited node.
- Pop and visit: pop the top of the stack; this is the next node in sorted order. Decrement k.
- Check k: if k reaches 0, this node is the kth smallest, so return its value.
- Go right: move to the popped node's right child and repeat from the first step (go left again from there).
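To see the four steps in action, this sketch instruments the iterative loop with a log. The name `kth_smallest_traced` and the log format are illustrative, and unlike the solutions below it raises an error instead of returning -1 when k is too large. Running it on the example tree with k = 5 exercises every step, including a trip into a right subtree.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_traced(root: Optional[TreeNode], k: int) -> Tuple[int, List[str]]:
    log: List[str] = []
    stack: List[TreeNode] = []
    current = root
    while current or stack:
        while current:                        # step 1: push left
            stack.append(current)
            log.append(f"push {current.val}")
            current = current.left
        current = stack.pop()                 # step 2: pop and visit
        k -= 1
        log.append(f"pop {current.val} (k now {k})")
        if k == 0:                            # step 3: check k
            return current.val, log
        current = current.right               # step 4: go right
    raise ValueError("k exceeds the number of nodes")


example = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
value, log = kth_smallest_traced(example, 5)
print(value)  # 5
print(", ".join(log))
# push 5, push 3, push 2, push 1, pop 1 (k now 4), pop 2 (k now 3),
# pop 3 (k now 2), push 4, pop 4 (k now 1), pop 5 (k now 0)
```

Note that 4 is only pushed after 3 is visited: that is the "go right, then go left again" step.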
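Time: O(h + k), where h is the height of the tree. Every node pushed is either one of the k visited nodes or still on the stack when the answer is found, and the stack never holds more than h nodes.
Space: O(h) for the stack, which only ever holds nodes along one root-to-leaf path. h is about log n for a balanced tree but can be n for a skewed one.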
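The bounds can be checked empirically. This sketch uses hypothetical helpers (`kth_smallest_with_stats`, `height`, `left_chain`) to count how many nodes the iterative loop pushes and how deep the stack gets, measuring h as the number of nodes on the longest root-to-leaf path. Pushes should stay at or below h + k and stack depth at or below h. A left-leaning chain shows the worst case, where h equals n.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_with_stats(root: Optional[TreeNode], k: int) -> Tuple[int, int, int]:
    # Same loop as the solution, plus counters: (value, pushes, max stack depth).
    stack: List[TreeNode] = []
    current = root
    pushes = max_depth = 0
    while current or stack:
        while current:
            stack.append(current)
            pushes += 1
            max_depth = max(max_depth, len(stack))
            current = current.left
        current = stack.pop()
        k -= 1
        if k == 0:
            return current.val, pushes, max_depth
        current = current.right
    raise ValueError("k exceeds the number of nodes")


def height(root: Optional[TreeNode]) -> int:
    # Number of nodes on the longest root-to-leaf path, counted level by level.
    levels = 0
    level = [root] if root else []
    while level:
        levels += 1
        level = [c for node in level for c in (node.left, node.right) if c]
    return levels


def left_chain(n: int) -> Optional[TreeNode]:
    # Degenerate BST: n at the root, n - 1 as its left child, ..., 1 at the bottom.
    root = None
    for v in range(1, n + 1):
        root = TreeNode(v, left=root)
    return root


example = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
h = height(example)  # 4 (the path 5, 3, 2, 1)
for k in range(1, 7):
    _, pushes, depth = kth_smallest_with_stats(example, k)
    assert pushes <= h + k and depth <= h

print(kth_smallest_with_stats(left_chain(10_000), 1))  # (1, 10000, 10000)
```

On the chain, finding even the smallest value requires pushing all 10,000 nodes, which is why the bounds are stated in terms of h rather than log n.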
import java.util.ArrayDeque;
import java.util.Deque;

class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

class Solution {
    public int kthSmallest(TreeNode root, int k) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode current = root;
        while (current != null || !stack.isEmpty()) {
            while (current != null) { // go as far left as possible
                stack.push(current);
                current = current.left;
            }
            current = stack.pop(); // visit the next smallest node
            k--;
            if (k == 0) {
                return current.val; // found the kth smallest
            }
            current = current.right; // explore right subtree next
        }
        return -1; // unreachable when 1 <= k <= number of nodes
    }
}
#include <stack>

struct TreeNode {
    int val;
    TreeNode *left, *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        std::stack<TreeNode*> st;
        TreeNode* current = root;
        while (current || !st.empty()) {
            while (current) { // go as far left as possible
                st.push(current);
                current = current->left;
            }
            current = st.top(); // visit the next smallest node
            st.pop();
            k--;
            if (k == 0) {
                return current->val; // found the kth smallest
            }
            current = current->right; // explore right subtree next
        }
        return -1; // unreachable when 1 <= k <= number of nodes
    }
};
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kthSmallest(root: TreeNode | None, k: int) -> int:
    stack = []
    current = root
    while current or stack:
        while current:  # go as far left as possible (toward smallest values)
            stack.append(current)
            current = current.left
        current = stack.pop()  # visit the next smallest node
        k -= 1
        if k == 0:
            return current.val  # found the kth smallest
        current = current.right  # explore right subtree next
    return -1  # unreachable when 1 <= k <= number of nodes