Find the Kth Smallest Sum of a Matrix With Sorted Rows
You have a matrix where every row is sorted in ascending order. You need to pick exactly one element from each row, sum them up, and find the kth smallest such sum. This is LeetCode 1439: Find the Kth Smallest Sum of a Matrix With Sorted Rows, a hard problem that combines heap-based enumeration with a clever pruning strategy.
The problem
Given an m x n matrix mat where each row is sorted in non-decreasing order, and an integer k, find the kth smallest sum you can form by picking one element from each row.
Example: mat = [[1, 3, 11], [2, 4, 6]], k = 5.
The possible sums (sorted) are: 3, 5, 5, 7, 7, 9, 13, 15, 17. The 5th smallest is 7.
The approach
The brute force way is to generate every combination of one element per row, compute each sum, sort them, and return the kth. With m rows and n columns, that is n^m combinations. For m = 40 and n = 40, that is astronomically large.
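For small inputs, the brute force is still worth writing down, because it gives you something to check a faster solution against. A minimal sketch using itertools.product (the name kth_smallest_brute is ours, not part of the problem):

```python
from itertools import product

def kth_smallest_brute(mat, k):
    # product(*mat) yields every way to pick one element per row: n**m tuples.
    sums = sorted(sum(choice) for choice in product(*mat))
    return sums[k - 1]

mat = [[1, 3, 11], [2, 4, 6]]
print(sorted(sum(c) for c in product(*mat)))  # [3, 5, 5, 7, 7, 9, 13, 15, 17]
print(kth_smallest_brute(mat, 5))             # 7
```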
The key insight is that the rows are sorted. The smallest possible sum always uses the first element of every row. From there, you can incrementally generate the next smallest sums by advancing one row's index at a time, using a min-heap to always pick the globally smallest candidate next.
This is the same pattern used in Merge K Sorted Lists, where a heap tracks the smallest element across multiple sources. Here, each "source" is a different way to advance one row's index.
Think of it like a multi-dimensional sorted search. The heap lets you explore sums in order from smallest to largest without generating all n^m combinations. You stop as soon as you have popped k values.
How the heap works
Start by pushing the sum formed by taking the first element of every row, along with the index tuple [0, 0, ..., 0]. Then repeat k times: pop the smallest sum, and for each row, try advancing that row's index by one (creating a new candidate sum). Use a visited set to avoid pushing duplicate index tuples.
After k pops, the last popped value is your answer.
For a matrix with only two rows, this simplifies to the same approach used in "Find K Pairs with Smallest Sums." For more rows, you can also apply pairwise merging: merge the first two rows into a virtual sorted list of sums (keeping only the k smallest), then merge that result with the third row, and so on. Both strategies work. The pairwise approach is often simpler to implement.
Step-by-step walkthrough
Step 1: Seed the heap with the smallest possible sum.
Pick the first element from each row: 1 + 2 = 3. Push (sum=3, indices=[0,0]) into the min-heap.
Step 2: Pop sum=3 (result #1). Expand neighbors.
Pop 3. Advance each row index by one: push (3+2=5, [1,0]) and (1+4=5, [0,1]). Both sums are 5.
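The expansion step can be pulled out into a small helper. This is a sketch with a hypothetical name, expand; it builds each neighbor tuple by slicing rather than by converting to a list and back, and it updates the sum incrementally because only one row changes:

```python
def expand(mat, total, indices):
    """Return (new_sum, new_indices) for every one-step advance of a single row."""
    candidates = []
    for row, col in enumerate(indices):
        if col + 1 < len(mat[row]):  # this row still has a next column
            new_indices = indices[:row] + (col + 1,) + indices[row + 1:]
            # Only the advanced row changes, so swap its old value for the new one.
            new_sum = total - mat[row][col] + mat[row][col + 1]
            candidates.append((new_sum, new_indices))
    return candidates

mat = [[1, 3, 11], [2, 4, 6]]
print(expand(mat, 3, (0, 0)))  # [(5, (1, 0)), (5, (0, 1))]
```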
Step 3: Pop sum=5 (result #2). Expand from [0,1].
Pop 5 from [0,1]. Push (3+4=7, [1,1]) and (1+6=7, [0,2]). Heap now has [5, 7, 7].
Step 4: Pop sum=5 (result #3). Expand from [1,0].
Pop the other 5 from [1,0]. Push (11+2=13, [2,0]). The other neighbor, [1,1], was already pushed in Step 3, so the visited set skips it. The heap now contains sums 7, 7, 13.
Step 5: Pop sum=7 (result #4). Getting close to k=5.
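Pop 7 from [0,2] (ties on the sum are broken by the index tuple, and [0,2] comes before [1,1]). Advancing row 0 pushes (3+6=9, [1,2]); row 1 is already at its last column. The heap now holds 7, 9, 13, and we need one more pop to reach k=5.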
Step 6: Pop sum=7 (result #5). This is our answer!
The 5th value popped from the heap is 7. That is the kth smallest sum. Return 7.
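If you want to confirm the walkthrough, a tracing variant of the heap loop can record every popped state. A sketch, assuming Python's heapq (pop_order is a hypothetical name). When two sums tie, Python compares the index tuples next, which is why [0,1] pops before [1,0] and [0,2] before [1,1]:

```python
import heapq

def pop_order(mat, k):
    """Return the first k (sum, indices) pairs in the order the heap pops them."""
    start = tuple(0 for _ in mat)
    heap = [(sum(row[0] for row in mat), start)]
    visited = {start}
    popped = []
    while heap and len(popped) < k:
        total, indices = heapq.heappop(heap)
        popped.append((total, indices))
        for row, col in enumerate(indices):
            if col + 1 < len(mat[row]):
                nxt = indices[:row] + (col + 1,) + indices[row + 1:]
                if nxt not in visited:  # e.g. (1, 1) is skipped in Step 4
                    visited.add(nxt)
                    new_sum = total - mat[row][col] + mat[row][col + 1]
                    heapq.heappush(heap, (new_sum, nxt))
    return popped

for total, indices in pop_order([[1, 3, 11], [2, 4, 6]], 5):
    print(total, indices)
```

It prints the five pops in order: 3 at (0, 0), 5 at (0, 1), 5 at (1, 0), 7 at (0, 2), and 7 at (1, 1).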
The code
```python
import heapq

def kth_smallest(mat, k):
    m, n = len(mat), len(mat[0])
    # Start from the smallest possible sum: column 0 in every row.
    indices = tuple([0] * m)
    initial_sum = sum(mat[i][0] for i in range(m))
    heap = [(initial_sum, indices)]
    visited = {indices}  # index tuples already pushed onto the heap

    for _ in range(k):
        # Pop the smallest remaining sum; after k pops it is the answer.
        current_sum, current_indices = heapq.heappop(heap)
        for row in range(m):
            col = current_indices[row]
            if col + 1 < n:
                # Advance only this row's column index by one.
                new_indices = list(current_indices)
                new_indices[row] = col + 1
                new_indices = tuple(new_indices)
                if new_indices not in visited:
                    visited.add(new_indices)
                    # Incremental update: swap the old column value for the new one.
                    new_sum = current_sum - mat[row][col] + mat[row][col + 1]
                    heapq.heappush(heap, (new_sum, new_indices))

    return current_sum
```
The function starts with the all-zeros index tuple and its corresponding sum. On each iteration, it pops the smallest sum from the heap, then tries advancing each row's column index by one. It computes the new sum incrementally by subtracting the old column value and adding the new one. The visited set prevents the same index tuple from being pushed multiple times.
After exactly k pops, current_sum holds the kth smallest sum.
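Duplicate sums and tied index tuples are where this kind of enumeration usually breaks, so a randomized cross-check against brute force is a cheap sanity test. A sketch, assuming tiny matrices so that n^m stays small; kth_smallest_heap and kth_smallest_brute are hypothetical names, and the heap function is a compact restatement of the solution above:

```python
import heapq
import random
from itertools import product

def kth_smallest_heap(mat, k):
    # Same state-space search: pop k times, pushing each unseen one-row advance.
    start = (0,) * len(mat)
    heap = [(sum(row[0] for row in mat), start)]
    visited = {start}
    for _ in range(k):
        total, idx = heapq.heappop(heap)
        for r, c in enumerate(idx):
            if c + 1 < len(mat[r]):
                nxt = idx[:r] + (c + 1,) + idx[r + 1:]
                if nxt not in visited:
                    visited.add(nxt)
                    heapq.heappush(heap, (total - mat[r][c] + mat[r][c + 1], nxt))
    return total

def kth_smallest_brute(mat, k):
    return sorted(sum(p) for p in product(*mat))[k - 1]

rng = random.Random(0)  # fixed seed so the check is reproducible
for _ in range(200):
    m, n = rng.randint(1, 4), rng.randint(1, 4)
    # Sorted rows with small values, so duplicate sums (ties) occur often.
    mat = [sorted(rng.randint(1, 10) for _ in range(n)) for _ in range(m)]
    k = rng.randint(1, n ** m)
    assert kth_smallest_heap(mat, k) == kth_smallest_brute(mat, k)
print("all checks passed")
```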
Pairwise alternative. You can also solve this by merging two rows at a time. Treat the first row as the initial list of sums, then merge it with the second row to get the k smallest pairwise sums, then merge that result with the third row, and repeat:
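```python
import heapq

def kth_smallest(mat, k):
    prev = mat[0]  # sorted sums over the rows merged so far
    for row in range(1, len(mat)):
        heap = []
        visited = set()
        # Seed with each of the (at most k) smallest prefix sums plus column 0 of this row.
        for j in range(min(k, len(prev))):
            heapq.heappush(heap, (prev[j] + mat[row][0], j, 0))
            visited.add((j, 0))
        merged = []
        # Pop up to k sums in ascending order, advancing only the column in this row.
        while heap and len(merged) < k:
            s, i, j = heapq.heappop(heap)
            merged.append(s)
            if j + 1 < len(mat[row]) and (i, j + 1) not in visited:
                visited.add((i, j + 1))
                heapq.heappush(heap, (prev[i] + mat[row][j + 1], i, j + 1))
        prev = merged  # keep only the k smallest sums
    return prev[k - 1]
```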
The two approaches are close in cost, but the pairwise version avoids building and hashing an m-length index tuple for every candidate, and it is more natural when you are already comfortable with the "k smallest pairs" pattern.
Complexity analysis
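| Approach | Time | Space |
|---|---|---|
| Brute force | O(n^m * m * log n) | O(n^m) |
| Heap-based (direct) | O(k * m * (m + log(k * m))) | O(k * m^2) |
| Pairwise merge | O(m * k * log k) | O(k) |

The log n in the brute force comes from sorting n^m sums. The extra factor of m in the direct heap approach comes from copying and hashing an m-length index tuple for each of the up to k * m candidates it stores.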
The brute force is clearly impractical for large inputs. Both heap-based approaches keep the work proportional to k rather than the total number of combinations, which is the whole point.
The building blocks
Min-heap for ordered enumeration
The core idea here is the same one behind Kth Largest Element and Merge K Sorted Lists: use a min-heap to efficiently track the smallest candidate across multiple sources. Instead of sorting everything upfront, you lazily generate candidates in order. This pattern appears whenever you need the kth smallest or largest from a structured search space.
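Python's standard library packages this pattern as heapq.merge, which lazily merges already-sorted iterables using a heap. A sketch for two rows (the helper names are hypothetical): each element of the first row, added to the whole second row, forms one sorted source, and islice stops after k values so nothing beyond that is generated:

```python
import heapq
from itertools import islice

def shifted(a, row):
    # Adding a constant to a sorted row keeps it sorted, so this is one sorted source.
    # Passing `a` as a parameter binds it per call (avoids late-binding surprises).
    return (a + b for b in row)

def k_smallest_pair_sums(first, second, k):
    # heapq.merge holds one pending value per source in a heap and yields
    # values lazily in ascending order; islice stops after k of them.
    sources = [shifted(a, second) for a in first]
    return list(islice(heapq.merge(*sources), k))

print(k_smallest_pair_sums([1, 3, 11], [2, 4, 6], 5))  # [3, 5, 5, 7, 7]
```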
Pairwise row merging
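When you have m rows, you do not need to reason about all m dimensions at once. Merge two rows at a time, keeping only the k smallest sums from each merge. Discarding the rest is safe: if a partial sum already has k smaller partial sums ahead of it, each of those can be completed with the same choices from the remaining rows, giving k full sums that are no larger, so dropping it cannot change the kth smallest value. This reduces the m-dimensional problem to m-1 two-dimensional merges. Each merge works like the merge step of merge sort, applied to sum enumeration, with the rows folded in one at a time.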
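The m-1 merges read naturally as a left fold. This sketch swaps the heap for a plain sort-and-truncate inside each merge, which costs O(k * n * log(k * n)) per row instead of O(k log k); it is easy to get right but slower. merge_rows and kth_smallest_fold are hypothetical names:

```python
from functools import reduce

def merge_rows(prev, row, k):
    # Every kept prefix sum combined with every element of the next row, pruned to k.
    return sorted(p + x for p in prev for x in row)[:k]

def kth_smallest_fold(mat, k):
    # Rows are already sorted, so the first row's k smallest "sums" are its first k values.
    # Then perform m - 1 two-row merges, left to right.
    sums = reduce(lambda prev, row: merge_rows(prev, row, k), mat[1:], mat[0][:k])
    return sums[k - 1]

print(kth_smallest_fold([[1, 10, 10], [1, 4, 5], [2, 3, 6]], 7))  # 9
```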
Edge cases
- Single row. If m = 1, the answer is simply mat[0][k - 1]. No heap needed.
- Single column. If every row has exactly one element, there is only one possible sum. Since k ≤ n^m = 1, k must be 1, so return that sum.
- k = 1. The smallest sum is always the sum of the first elements of every row. This is the starting point of both approaches.
- Large k near the total number of combinations. The heap-based approach still works, but performance degrades as k approaches n^m. In practice, the constraints guarantee k is at most 200.
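The first three cases can be answered before any heap work. A sketch of such a guard (kth_smallest_edge_cases is a hypothetical name; it returns None when no shortcut applies, leaving the general case to either solver):

```python
def kth_smallest_edge_cases(mat, k):
    """Return the answer for the trivial cases, or None to defer to a full solver."""
    if len(mat) == 1:
        # Single row: the row itself is the sorted list of all sums.
        return mat[0][k - 1]
    if len(mat[0]) == 1 or k == 1:
        # Single column (only one sum exists) or k == 1: take column 0 everywhere.
        return sum(row[0] for row in mat)
    return None

print(kth_smallest_edge_cases([[1, 3, 11]], 2))             # 3
print(kth_smallest_edge_cases([[1, 3, 11], [2, 4, 6]], 1))  # 3
print(kth_smallest_edge_cases([[1, 3, 11], [2, 4, 6]], 5))  # None
```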
From understanding to recall
The logic makes sense once you see it: seed the heap with the smallest sum, pop and expand, repeat k times. But in an interview, the tricky parts are the index management and the visited set. Which row do you advance? How do you compute the new sum without recalculating from scratch? How do you avoid pushing the same state twice?
These details are exactly what slip under pressure. Spaced repetition locks them in. You write the heap loop from scratch, get the index tuple manipulation right, and remember the incremental sum update (current_sum - mat[row][col] + mat[row][col + 1]). A few reps at increasing intervals and the implementation becomes automatic.
Related posts
- Kth Largest Element - The foundational heap-based selection problem, using a min-heap of size k
- Merge K Sorted Lists - Same min-heap pattern applied to merging multiple sorted sources into one
- Kth Smallest Element in a Sorted Matrix - A related matrix problem that uses binary search on the value space instead of a heap