Sort the Matrix Diagonally: Diagonal-by-Diagonal Sorting
You are given an m x n matrix of integers. Your job is to sort each diagonal of the matrix in non-decreasing order (from top-left to bottom-right) and return the resulting matrix. A diagonal here is a line of cells that starts in the top row or the leftmost column and moves one row down and one column right at each step until it reaches the edge of the matrix.
This is LeetCode 1329: Sort the Matrix Diagonally.
Why this problem matters
Matrix diagonal problems test whether you can identify the hidden structure inside a 2D grid. At first glance, extracting every diagonal seems like it would require a mess of nested loops and special-case logic. But there is a clean mathematical property that makes the whole thing fall into place.
This problem also teaches you to think about grouping. Many matrix problems become simpler when you stop treating the grid as rows and columns and start grouping cells by some shared property. Here, that property is the diagonal index. The same grouping idea shows up in problems about anti-diagonals, zigzag traversals, and diagonal printing.
Finally, this is a good warmup for harder matrix transformation problems. Once you are comfortable extracting, modifying, and reinserting diagonal slices, you will find that problems like rotating images or transposing matrices feel more approachable.
The key insight
Every cell on the same top-left-to-bottom-right diagonal shares the same value of row - col. If you are at position (0, 0), the key is 0 - 0 = 0. At (1, 1), it is 1 - 1 = 0. At (0, 1), it is 0 - 1 = -1. Cells with the same key sit on the same diagonal.
This gives you a direct grouping strategy:
- Walk through every cell in the matrix. For each cell, compute row - col and append the cell's value to a list for that key.
- Sort each list.
- Walk through the matrix again in the same order. For each cell, pop the next value from the sorted list for its key and place it back into the matrix.
The beauty of this approach is that you do not need to figure out where each diagonal starts or how long it is. The row - col key handles all of that automatically.
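To see the key doing that bookkeeping, here is a minimal sketch that groups the cell coordinates of a 3x4 grid by row - col. The helper name `diagonal_cells` is made up for this illustration; it never computes where a diagonal starts or how long it is.

```python
from collections import defaultdict


def diagonal_cells(m: int, n: int) -> dict[int, list[tuple[int, int]]]:
    """Group the (row, col) coordinates of an m x n grid by row - col."""
    groups = defaultdict(list)
    for r in range(m):
        for c in range(n):
            groups[r - c].append((r, c))
    return dict(groups)


groups = diagonal_cells(3, 4)
for key in sorted(groups):
    print(key, groups[key])
# -3 [(0, 3)]
# -2 [(0, 2), (1, 3)]
# -1 [(0, 1), (1, 2), (2, 3)]
# 0 [(0, 0), (1, 1), (2, 2)]
# 1 [(1, 0), (2, 1)]
# 2 [(2, 0)]
```

Each list comes out already ordered from top-left to bottom-right, and the lengths (1, 2, 3, 3, 2, 1) fall out of the grouping on their own.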
The solution
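from collections import defaultdict


def diagonal_sort(mat: list[list[int]]) -> list[list[int]]:
    m, n = len(mat), len(mat[0])

    # Collect: group every value by its diagonal key (row - col).
    diags = defaultdict(list)
    for r in range(m):
        for c in range(n):
            diags[r - c].append(mat[r][c])

    # Sort each diagonal's values independently.
    for key in diags:
        diags[key].sort()

    # Scatter: write the sorted values back in the same row-major order.
    idx = defaultdict(int)  # next unread position in each sorted list
    for r in range(m):
        for c in range(n):
            key = r - c
            mat[r][c] = diags[key][idx[key]]
            idx[key] += 1
    return mat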
The first pass groups every cell's value by its diagonal key (r - c). The second pass sorts each group. The third pass writes the sorted values back into the matrix, using an index counter per diagonal to track which sorted value comes next.
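Because both the collect pass and the write-back pass iterate in row-major order (top to bottom, left to right), cells on the same diagonal are always visited in order of increasing row, which is exactly top-left to bottom-right along that diagonal. The top-left cell of each diagonal is therefore visited first and receives the smallest sorted value, the next cell receives the second smallest, and so on.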
You can simplify the write-back step by using collections.deque and calling popleft() instead of maintaining a separate index counter. Both approaches work, but the index counter avoids the overhead of creating deque objects.
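For comparison, the deque variant might look like the sketch below. The function name `diagonal_sort_deque` is an illustrative choice, not something from the problem statement. A deque is used instead of a plain list because `popleft()` runs in constant time, whereas `list.pop(0)` shifts every remaining element.

```python
from collections import defaultdict, deque


def diagonal_sort_deque(mat: list[list[int]]) -> list[list[int]]:
    m, n = len(mat), len(mat[0])

    # Collect values by diagonal key, exactly as in the index-counter version.
    diags = defaultdict(list)
    for r in range(m):
        for c in range(n):
            diags[r - c].append(mat[r][c])

    # Sort each group and wrap it in a deque so the smallest value is on the left.
    queues = {key: deque(sorted(values)) for key, values in diags.items()}

    # Write back in row-major order, consuming each deque from the left.
    for r in range(m):
        for c in range(n):
            mat[r][c] = queues[r - c].popleft()
    return mat
```

The write-back loop shrinks to a single line per cell, at the cost of building one extra deque per diagonal.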
Visual walkthrough
Step 1: Identify all diagonals using row - col
Each unique (row - col) value defines a diagonal. Here we have 6 diagonals: keys -3, -2, -1, 0, 1, 2.
Step 2: Sort each diagonal group independently
key 0: [3,2,1] becomes [1,2,3]. key -1: [3,1,2] becomes [1,2,3]. key 1: [2,1] becomes [1,2].
Step 3: Place sorted values back into original positions
Walk each diagonal again and fill in the sorted values from left to right, top to bottom.
Step 4: Verify each diagonal is sorted
Every diagonal now reads in non-decreasing order from top-left to bottom-right.
The walkthrough shows the full process on a 3x4 matrix. Each diagonal is identified by its row - col key, sorted independently, and then placed back. Notice how the rest of the matrix (cells not on the highlighted diagonal) stays untouched during each diagonal's sort.
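The same four steps can be traced in code. The input matrix below is an assumption: it is a 3x4 matrix chosen because its diagonals match the groups listed in Step 2 (key 0 is [3, 2, 1], key -1 is [3, 1, 2], key 1 is [2, 1]). The helper name `collect_diagonals` is also just for illustration.

```python
from collections import defaultdict

# Assumed input, picked to match the groups quoted in Step 2.
mat = [
    [3, 3, 1, 1],
    [2, 2, 1, 2],
    [1, 1, 1, 2],
]


def collect_diagonals(grid):
    """Return {row - col: values in top-left to bottom-right order}."""
    diags = defaultdict(list)
    for r, row in enumerate(grid):
        for c, value in enumerate(row):
            diags[r - c].append(value)
    return dict(diags)


# Steps 1 and 2: identify the diagonals, then sort each group.
before = collect_diagonals(mat)
after = {key: sorted(values) for key, values in before.items()}
for key in sorted(before):
    print(f"key {key:>2}: {before[key]} -> {after[key]}")

# Step 3: scatter the sorted values back in row-major order.
position = defaultdict(int)
result = [row[:] for row in mat]
for r in range(len(mat)):
    for c in range(len(mat[0])):
        key = r - c
        result[r][c] = after[key][position[key]]
        position[key] += 1

# Step 4: every diagonal of the result should now be sorted.
assert all(v == sorted(v) for v in collect_diagonals(result).values())
print(result)  # [[1, 1, 1, 1], [1, 2, 2, 2], [1, 2, 3, 3]]
```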
Complexity analysis
| Approach | Time | Space |
|---|---|---|
| Diagonal sorting | O(m * n * log(min(m, n))) | O(m * n) |
Time: You visit every cell twice (once to collect, once to place back), which is O(m * n). Sorting each diagonal takes O(d * log(d)) where d is the diagonal length. The longest diagonal has min(m, n) elements, and there are m + n - 1 diagonals total. The sum of all diagonal lengths is m * n, so the total sorting work across all diagonals is O(m * n * log(min(m, n))).
Space: The dictionary holding all diagonal groups stores m * n values total. The index counter dictionary uses O(m + n) entries. Overall, the space usage is O(m * n).
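The sorting bound can be written out as a short derivation. The symbol d_k is notation introduced here for the length of the diagonal with key k; the keys run from -(n - 1) to m - 1. Since every d_k is at most min(m, n) and the lengths add up to m * n:

```latex
\[
\sum_{k=-(n-1)}^{m-1} d_k \log d_k
\;\le\; \log\bigl(\min(m, n)\bigr) \sum_{k=-(n-1)}^{m-1} d_k
\;=\; m \, n \, \log\bigl(\min(m, n)\bigr)
\]
```

The two O(m * n) passes over the matrix are dominated by this term whenever min(m, n) is greater than 1.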
The building blocks
1. Diagonal grouping by row - col
The core trick is recognizing that row - col is constant along any top-left-to-bottom-right diagonal. This same property is useful in many other matrix problems. For anti-diagonals (top-right to bottom-left), the constant is row + col instead. Knowing both of these grouping keys gives you quick access to any diagonal slice of a matrix.
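As a small illustration of the anti-diagonal key, the sketch below groups values by row + col. The function name `anti_diagonals` and the sample grid are hypothetical.

```python
from collections import defaultdict


def anti_diagonals(mat: list[list[int]]) -> dict[int, list[int]]:
    """Group values by row + col; each list runs top-right to bottom-left."""
    groups = defaultdict(list)
    for r, row in enumerate(mat):
        for c, value in enumerate(row):
            groups[r + c].append(value)
    return dict(groups)


grid = [
    [1, 2, 3],
    [4, 5, 6],
]
print(anti_diagonals(grid))
# {0: [1], 1: [2, 4], 2: [3, 5], 3: [6]}
```

Swapping the `+` for a `-` in the key is the only change needed to get back to the top-left-to-bottom-right grouping.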
2. Collect, transform, and scatter
The three-pass pattern (collect values into groups, transform the groups, scatter them back) is a general technique that works beyond matrices. You will see it in problems that ask you to rearrange elements within specific subsets of an array, sort rows or columns independently, or apply any per-group operation to a structured collection.
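Outside of matrices, the same pattern might look like the sketch below, which sorts the values of a flat list within labeled groups while leaving each group's positions fixed. The function `sort_within_groups` and its label scheme are invented for this example.

```python
from collections import defaultdict


def sort_within_groups(values: list[int], labels: list[str]) -> list[int]:
    """Sort values that share a label, keeping each label's positions fixed."""
    # Collect: bucket the values by label.
    groups = defaultdict(list)
    for value, label in zip(values, labels):
        groups[label].append(value)

    # Transform: sort each bucket independently.
    iterators = {label: iter(sorted(bucket)) for label, bucket in groups.items()}

    # Scatter: revisit positions in the original order and refill them.
    return [next(iterators[label]) for label in labels]


print(sort_within_groups([9, 4, 7, 1, 8, 2], ["a", "b", "a", "b", "a", "b"]))
# [7, 1, 8, 2, 9, 4]
```

In the diagonal problem the label is row - col; here it is an explicit string, but the three passes are the same.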
Edge cases
- Single row or single column: Every cell is on its own diagonal (length 1). Sorting single-element lists does nothing, so the matrix is returned unchanged. Your code handles this automatically.
- 1x1 matrix: A single cell is trivially sorted. No special handling needed.
- Already sorted diagonals: If every diagonal is already in non-decreasing order, the algorithm still runs but produces the same matrix. There is no early exit optimization worth adding for this case.
- All identical values: If every cell holds the same value, every diagonal sorts to the same sequence. The output equals the input.
- Large matrix with small value range: The algorithm's time complexity depends on the diagonal length, not the value range. Even if all values are between 1 and 100, the sort still takes O(d * log(d)) per diagonal.
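The first four cases can be checked with a few assertions. The block repeats the index-counter solution so that it runs on its own; the sample inputs are arbitrary.

```python
from collections import defaultdict


def diagonal_sort(mat: list[list[int]]) -> list[list[int]]:
    # Same collect-sort-scatter approach as the solution above.
    m, n = len(mat), len(mat[0])
    diags = defaultdict(list)
    for r in range(m):
        for c in range(n):
            diags[r - c].append(mat[r][c])
    for key in diags:
        diags[key].sort()
    idx = defaultdict(int)
    for r in range(m):
        for c in range(n):
            key = r - c
            mat[r][c] = diags[key][idx[key]]
            idx[key] += 1
    return mat


# Single row and single column: every diagonal has length 1.
assert diagonal_sort([[4, 1, 3, 2]]) == [[4, 1, 3, 2]]
assert diagonal_sort([[4], [1], [3]]) == [[4], [1], [3]]

# 1x1 matrix.
assert diagonal_sort([[7]]) == [[7]]

# Diagonals already sorted: the only multi-cell diagonal is [1, 2].
assert diagonal_sort([[1, 5], [0, 2]]) == [[1, 5], [0, 2]]

# All identical values.
assert diagonal_sort([[5, 5], [5, 5]]) == [[5, 5], [5, 5]]
```

Note that the single-row input stays unsorted as a row: only diagonals are sorted, and each one here holds a single cell.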
From understanding to recall
You now understand why row - col groups cells by diagonal and how the collect-sort-scatter pattern solves this problem cleanly. But understanding is not the same as fluency. In an interview, you need to write this from memory without hesitating over which key to use or whether to iterate in row-major or column-major order.
Spaced repetition bridges that gap. You practice writing the diagonal grouping logic, the sort step, and the scatter step from scratch. You do it today, again in a few days, and again a week later. After a handful of reps, the row - col insight and the three-pass structure become automatic. You see "sort diagonals" and the solution appears on your screen without effort.
CodeBricks breaks Sort the Matrix Diagonally into its diagonal-grouping and collect-sort-scatter building blocks, then drills them independently with spaced repetition. You practice writing the row - col key, the grouping pass, and the scatter pass from memory until the pattern is automatic. When a diagonal problem shows up in your interview, you do not think about it. You just write it.