Maximum Product of Splitted Binary Tree: Subtree Sum DFS
You are given a binary tree where each node holds a positive integer. Remove one edge to split the tree into two subtrees. The product of the sums of those two subtrees should be as large as possible. Return the maximum product modulo 10^9 + 7. This is LeetCode 1339: Maximum Product of Splitted Binary Tree, and it boils down to one clean observation: if you know the total sum and every subtree sum, you can evaluate every possible cut in constant time.
Why this problem matters
At first glance, it looks like you need to try removing every edge, recompute both sums from scratch, and track the best product. That would be O(n^2). But the relationship between the two parts is simple. If the total sum of the tree is S and you cut an edge producing a subtree with sum T, the other part has sum S - T. The product is T * (S - T). You never need to recompute the other side.
This insight turns the problem into a subtree-sum collection task. Once you have all the subtree sums, you just iterate through them and find the one that maximizes T * (S - T). That is a single pass over an array.
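As a minimal sketch of that scan, assume the subtree sums have already been collected into a plain list. The function name best_cut_product and the sample numbers are illustrative, not part of the original solution:

```python
def best_cut_product(total, subtree_sums):
    """Return the largest T * (S - T) over the given subtree sums."""
    best = 0
    for t in subtree_sums:
        # The other side of the cut is whatever remains of the total.
        best = max(best, t * (total - t))
    return best


# Subtree sums of a six-node tree whose values add up to 21.
print(best_cut_product(21, [4, 5, 11, 6, 9, 21]))  # prints 110
```

Each candidate costs one subtraction and one multiplication, which is why the scan is linear in the number of nodes.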
The pattern here, computing subtree sums via post-order DFS, is one of the most reusable building blocks in tree problems. You will see it again in problems like Binary Tree Maximum Path Sum, Sum of Distances in Tree, and Diameter of Binary Tree. Getting comfortable with this pattern pays off quickly.
The approach
The solution uses two passes over the tree:
Pass 1: Compute the total sum. Walk the entire tree and add up all node values. You can do this with a simple DFS or even a BFS. Store the result as total.
Pass 2: Collect subtree sums. Run a post-order DFS. At each node, compute the sum of its subtree (node value plus left subtree sum plus right subtree sum). Every subtree sum you compute below the root corresponds to one possible edge cut. If you cut the edge above that subtree, the two parts have sums subtree_sum and total - subtree_sum. The root's own sum equals total, so its product is 0 and it never affects the maximum.
After both passes, iterate through all collected subtree sums and find the maximum value of subtree_sum * (total - subtree_sum). Return that maximum modulo 10^9 + 7.
def max_product(root):
    MOD = 10**9 + 7
    subtree_sums = []

    def dfs(node):
        if node is None:
            return 0
        s = node.val + dfs(node.left) + dfs(node.right)
        subtree_sums.append(s)
        return s

    total = dfs(root)
    best = 0
    for s in subtree_sums:
        best = max(best, s * (total - s))
    return best % MOD
You only need one DFS, not two. The first call to dfs(root) already returns the total sum (it is the subtree sum of the root). There is no need for a separate pass to compute the total. The trick is collecting all subtree sums during that single DFS, then doing the product comparison afterward.
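For comparison, the two-pass version described earlier could be sketched roughly as follows. The TreeNode class and the name max_product_two_pass are assumptions made for this illustration. Because the total is known before the second pass starts, each cut can be scored inline and no list is needed:

```python
class TreeNode:
    """Minimal node type, assumed here for illustration."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_product_two_pass(root):
    MOD = 10**9 + 7

    # Pass 1: total sum of all node values.
    def tree_sum(node):
        if node is None:
            return 0
        return node.val + tree_sum(node.left) + tree_sum(node.right)

    total = tree_sum(root)
    best = 0

    # Pass 2: post-order DFS. The total is already known, so each cut
    # is scored as soon as its subtree sum is available.
    def dfs(node):
        nonlocal best
        if node is None:
            return 0
        s = node.val + dfs(node.left) + dfs(node.right)
        best = max(best, s * (total - s))
        return s

    dfs(root)
    return best % MOD


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6)))
print(max_product_two_pass(root))  # prints 110
```

The trade-off is two traversals instead of one, in exchange for dropping the list of sums. Inline scoring is only safe here because total is fixed before the second DFS begins.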
Walking through the tree
Step 1: Compute the total sum of the tree.
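The example tree has root 1 with children 2 and 3; node 2 has children 4 and 5, and node 3 has the single child 6. First pass: walk the entire tree summing all node values. Total = 1 + 2 + 3 + 4 + 5 + 6 = 21.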
Step 2: DFS computes subtree sums bottom-up.
Second pass: post-order DFS returns the subtree sum at each node. Node 2's subtree: 4 + 5 + 2 = 11. Node 3's subtree: 6 + 3 = 9.
Step 3: Try cutting the edge to node 2. Product = 11 x 10.
Cut the edge to node 2. Two parts: subtree_sum = 11, total - subtree_sum = 21 - 11 = 10. Product = 110. Best so far = 110.
Step 4: Try all other cuts and track the maximum.
Every edge is a candidate cut. Node 2: 11x10=110. Node 3: 9x12=108. Node 4: 4x17=68. Node 5: 5x16=80. Node 6: 6x15=90. Max = 110.
Step 5: Return 110 mod (10^9 + 7) = 110.
The maximum product is 110. Since the product can be very large in general, return it modulo 10^9 + 7. For this example, 110 % (10^9+7) = 110.
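The numbers in the walkthrough can be reproduced with a short sketch. It assumes the tree shape implied by the subtree sums above (root 1 with children 2 and 3, node 2 with children 4 and 5, node 3 with child 6); the TreeNode class and the helper cut_products are hypothetical names used only for this check:

```python
class TreeNode:
    """Minimal node type, assumed here for illustration."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def cut_products(root):
    """Map each non-root node's value to the product from cutting the edge above it."""
    sums = {}

    def dfs(node):
        if node is None:
            return 0
        s = node.val + dfs(node.left) + dfs(node.right)
        sums[node.val] = s  # relies on node values being unique in this example
        return s

    total = dfs(root)
    # The root has no edge above it, so it is not a candidate cut.
    return {v: s * (total - s) for v, s in sums.items() if v != root.val}


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6)))
products = cut_products(root)
for value in sorted(products):
    print(f"cut above node {value}: {products[value]}")
print("best:", max(products.values()))  # best: 110
```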
The key takeaway from the walkthrough is that every non-root node's subtree sum represents one possible cut. You are not searching for the right edge to cut. You are computing all subtree sums and then picking the best one. The DFS does the heavy lifting, and the final loop is just a scan through a list.
Notice how the best cut (edge to node 2, product 110) wins over the cut to node 3 (product 108). This happens because 11 and 10 are closer to each other than 9 and 12. In general, the product T * (S - T) is maximized when T is as close to S / 2 as possible. You do not need to use this fact explicitly in the algorithm, but it gives you good intuition for why certain cuts win.
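The intuition follows from the identity T * (S - T) = S^2 / 4 - (T - S / 2)^2: the product shrinks as T moves away from S / 2. A small sketch, with hypothetical helper names, checks that picking the sum closest to S / 2 selects the same cut as maximizing the product directly:

```python
def best_by_product(total, sums):
    """Subtree sum that maximizes T * (S - T)."""
    return max(sums, key=lambda t: t * (total - t))


def best_by_closeness(total, sums):
    """Subtree sum closest to S / 2."""
    # Compare |2T - S| instead of |T - S/2| to stay in integers.
    return min(sums, key=lambda t: abs(2 * t - total))


total = 21
sums = [4, 5, 11, 6, 9]  # non-root subtree sums from the walkthrough
print(best_by_product(total, sums), best_by_closeness(total, sums))  # prints 11 11
```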
Here is the final clean solution:
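def max_product(root):
    MOD = 10**9 + 7
    subtree_sums = []  # one entry per node, filled in post-order

    def dfs(node):
        # A missing child contributes 0 to its parent's sum.
        if node is None:
            return 0
        s = node.val + dfs(node.left) + dfs(node.right)
        subtree_sums.append(s)
        return s

    total = dfs(root)  # the root's subtree sum is the total
    best = 0
    for s in subtree_sums:
        # Cutting above a subtree with sum s leaves total - s on the other side.
        best = max(best, s * (total - s))
    return best % MOD  # apply the modulus once, after the true maximum is known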
Complexity analysis
| Approach | Time | Space |
|---|---|---|
| Two-pass DFS (compute total, collect subtree sums) | O(n) | O(n) |
| Single-pass DFS (collect sums, then scan) | O(n) | O(n) |
Time is O(n). You visit every node exactly once during the DFS. The final loop over subtree sums is also O(n). Total work is O(n).
Space is O(n). You store one subtree sum per node, so the list has n entries. The recursion stack adds O(h) where h is the height of the tree, but O(n) dominates. In the worst case (skewed tree), both the list and the stack are O(n).
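The O(h) recursion stack has a practical consequence in Python: the default recursion limit is 1000 frames, so the recursive DFS can raise RecursionError on a deep, skewed tree. One way around that is an explicit stack. The following is a sketch only; the TreeNode class and the name max_product_iterative are assumptions for this example:

```python
class TreeNode:
    """Minimal node type, assumed here for illustration."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_product_iterative(root):
    MOD = 10**9 + 7
    subtree_sum = {None: 0}  # a missing child contributes 0
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if node is None:
            continue
        if children_done:
            # Both children were processed earlier, so their sums are ready.
            subtree_sum[node] = (
                node.val + subtree_sum[node.left] + subtree_sum[node.right]
            )
        else:
            # Revisit this node after its children (post-order).
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))
    total = subtree_sum[root]
    best = max(s * (total - s) for s in subtree_sum.values())
    return best % MOD


# A left-skewed chain of 5,000 nodes, each with value 1.
chain = None
for _ in range(5000):
    chain = TreeNode(1, left=chain)
print(max_product_iterative(chain))  # prints 6250000
```

The asymptotic cost is unchanged: O(n) time and O(n) space, with the stack living on the heap instead of the call stack.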
Edge cases to watch for
- Apply mod only at the end. The problem asks for the result modulo 10^9 + 7. Do not apply the mod during the product comparison. If you mod intermediate products, you might get a smaller number that incorrectly appears to be the max. Compute the true maximum product using regular integers, then mod once at the very end when returning.
- Single child. A node with only one child still produces a valid subtree sum. The missing child contributes 0. The DFS handles this naturally because dfs(None) returns 0.
- Two-node tree. The only possible cut splits the tree into two single nodes. The product is simply root.val * child.val. The algorithm handles this correctly.
- All nodes have value 1. The total sum is n and the best cut splits the tree as evenly as possible. The algorithm still works since it tries every subtree sum.
- Large values. Node values can be up to 10^4 and the tree can have up to 50,000 nodes. The total sum can be up to 5 * 10^8, and the product can be up to roughly 6.25 * 10^16. This fits in a 64-bit integer (Python handles big integers natively), so overflow is not an issue in Python. In languages like Java or C++, you need long for the product computation.
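The mod-at-the-end rule is easy to see with concrete numbers. The sketch below uses an illustrative total and two illustrative subtree sums (not taken from a specific tree) and compares the correct order of operations with a deliberately buggy one; both function names are hypothetical:

```python
MOD = 10**9 + 7


def best_mod_at_end(total, sums):
    """Correct: find the true maximum product, then reduce once."""
    return max(t * (total - t) for t in sums) % MOD


def best_mod_too_early(total, sums):
    """Buggy on purpose: compares residues instead of true products."""
    return max((t * (total - t)) % MOD for t in sums)


total = 100_000
sums = [10_000, 50_000]
# True products: 900,000,000 and 2,500,000,000. The second is larger,
# but its residue modulo 10^9 + 7 is smaller than the first product.
print(best_mod_at_end(total, sums))     # prints 499999986
print(best_mod_too_early(total, sums))  # prints 900000000
```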
The building blocks
This problem rests on two core patterns:
1. Post-order DFS to compute subtree sums. You process children before the current node, combine their results, and return the subtree sum upward. This is the same traversal order used in Diameter of Binary Tree and Binary Tree Maximum Path Sum. The skeleton is always the same: recurse left, recurse right, combine, return.
2. Decompose a global problem into local subtree decisions. Every edge in the tree is a candidate cut. Instead of evaluating each cut independently (which would be expensive), you collect all the information you need in a single traversal, then make the global decision afterward. This "collect then decide" pattern shows up whenever you need to compare all subtrees against each other or against the whole tree.
Together, these building blocks let you solve the problem in linear time. The post-order DFS collects subtree sums. The loop over those sums picks the maximum product. Two simple steps, one clean solution.
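The shared skeleton (recurse left, recurse right, combine, return) can be written once and reused. This is an illustrative sketch rather than code from any of the related problems; TreeNode, postorder, and the two lambdas are names invented for the example:

```python
class TreeNode:
    """Minimal node type, assumed here for illustration."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def postorder(node, combine, empty):
    """Recurse left, recurse right, combine, return."""
    if node is None:
        return empty
    left = postorder(node.left, combine, empty)
    right = postorder(node.right, combine, empty)
    return combine(node, left, right)


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6)))

# Subtree sum: node value plus both children's sums.
tree_total = postorder(root, lambda n, l, r: n.val + l + r, 0)
# Height in nodes: one more than the taller child.
height = postorder(root, lambda n, l, r: 1 + max(l, r), 0)
print(tree_total, height)  # prints 21 3
```

Only the combine step changes between problems; the traversal order stays the same.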
From understanding to recall
You just traced how a single post-order DFS collects every subtree sum, and how the formula T * (S - T) lets you evaluate any edge cut without recomputing. The mod-at-the-end detail and the "closer to half means bigger product" intuition are both fresh in your mind.
But two weeks from now, will you remember to collect subtree sums in a list instead of trying to evaluate cuts inline during the DFS? Will you remember to mod only at the return statement, not during the max comparison? These small details are easy to mix up under pressure.
Spaced repetition fixes this. You practice reconstructing the solution from scratch at increasing intervals. After a few repetitions, the pattern becomes automatic. You see "split tree, maximize product" and your hands type the post-order DFS without hesitation.
Related posts
- Binary Tree Maximum Path Sum - Another tree problem requiring DFS to track sums across subtrees
- Sum of Distances in Tree - A harder tree problem using similar subtree-sum techniques
- Diameter of Binary Tree - DFS-based tree traversal tracking subtree properties