
Count Subtrees With Max Distance Between Cities: Bitmask Enumeration


LeetCode Count Subtrees With Max Distance Between Cities (problem 1617) gives you n cities numbered 1 to n and an array of n - 1 edges that form a tree. You need to return an array answer of length n - 1, where answer[i] is the number of subtrees in which the maximum distance between any two cities equals i + 1. A subtree is a subset of cities in which every city can reach every other selected city along a path that passes only through cities in the subset.

Figure: A tree of 4 cities. The highlighted subtree contains cities 1, 2, and 4 with a diameter (max distance) of 2.

Why this problem matters

This problem blends three important algorithmic ideas: bitmask enumeration, tree connectivity checks, and tree diameter computation. It teaches you to recognize when a brute-force approach is actually the right one. The constraint that n is at most 15 is a classic signal that enumerating all 2^n subsets with bitmasks is feasible. Problems like this train you to spot small input bounds and reach for bit manipulation instead of hunting for a cleverer polynomial-time approach that the constraints do not require.

The key insight

Since n is at most 15, you can enumerate all 2^n subsets of cities using bitmasks. For each subset, you check whether it forms a valid subtree by counting the edges between selected cities. A valid subtree with k nodes must have exactly k - 1 edges. If it does, the subset is connected and forms a tree. Then you compute the diameter using double BFS: run BFS from any node in the subset to find the farthest node, then run BFS again from that farthest node to find the actual diameter. Increment the count for that diameter value in your result array.
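The edge-count test can also be phrased purely with bit operations. The sketch below is an illustration rather than the solution's own code: a hypothetical helper `is_subtree_mask` checks each tree edge (u, v) for both bits being set in the mask and compares the total with the number of selected cities minus 1. The four-city tree in the example (edges 1-2, 1-3, 2-4) is an assumption inferred from the visual walkthrough later in this post.

```python
def is_subtree_mask(mask, edges):
    """Hypothetical helper: True if the cities selected by `mask` form a subtree.

    Bit i of mask represents city i + 1, as in the main solution.
    Assumes `edges` describes a tree, so the selected cities cannot
    contain a cycle and "k - 1 internal edges" implies "connected".
    """
    k = bin(mask).count("1")  # number of selected cities
    inside = 0
    for u, v in edges:
        # An edge is internal when both of its endpoints are selected.
        if (mask >> (u - 1)) & 1 and (mask >> (v - 1)) & 1:
            inside += 1
    return k >= 1 and inside == k - 1


edges = [[1, 2], [1, 3], [2, 4]]  # assumed example tree
print(is_subtree_mask(0b0011, edges))  # True:  {1, 2} shares edge 1-2
print(is_subtree_mask(0b0110, edges))  # False: {2, 3} has no internal edge
```

Because it only loops over the n - 1 edges, this form avoids building a set of selected cities for every mask.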

The solution

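```python
from collections import deque

def count_subtrees_with_max_distance(n, edges):
    # Adjacency list for cities 1..n (index 0 is unused).
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # answer[d - 1] counts subtrees whose diameter is d, for d = 1..n-1.
    answer = [0] * (n - 1)

    for mask in range(1, 1 << n):
        # Bit i of mask selects city i + 1.
        cities = []
        for i in range(n):
            if mask & (1 << i):
                cities.append(i + 1)

        # A single city has no pair of cities to measure.
        if len(cities) < 2:
            continue

        # Count tree edges with both endpoints selected.
        # "neighbor > c" counts each undirected edge exactly once.
        city_set = set(cities)
        edge_count = 0
        for c in cities:
            for neighbor in adj[c]:
                if neighbor in city_set and neighbor > c:
                    edge_count += 1

        # k selected cities form a subtree exactly when k - 1 edges lie inside.
        if edge_count != len(cities) - 1:
            continue

        def bfs(start):
            # BFS that only steps onto cities in the current subset.
            # Returns the farthest city from start and its distance.
            dist = {start: 0}
            queue = deque([start])
            farthest = start
            max_dist = 0
            while queue:
                node = queue.popleft()
                for neighbor in adj[node]:
                    if neighbor in city_set and neighbor not in dist:
                        dist[neighbor] = dist[node] + 1
                        queue.append(neighbor)
                        if dist[neighbor] > max_dist:
                            max_dist = dist[neighbor]
                            farthest = neighbor
            return farthest, max_dist

        # Double BFS: any city -> farthest city -> diameter.
        far_node, _ = bfs(cities[0])
        _, diameter = bfs(far_node)

        answer[diameter - 1] += 1

    return answer
```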

The solution iterates over every non-empty subset of cities using a bitmask from 1 to 2^n - 1. For each subset, it extracts which cities are included, skips subsets with fewer than 2 cities (since a single city has no meaningful distance), and counts the edges in the original tree that connect two cities both present in the subset. If the edge count equals k - 1 for k selected cities, the subset forms a valid tree.

Once you know the subset is a valid subtree, you compute its diameter with double BFS. The first BFS from any city in the subset finds the farthest city. The second BFS from that farthest city measures the actual diameter. This two-pass trick works because, in a tree, the node farthest from any starting node is always an endpoint of some longest path (a diameter). A subset that passed the validity check is itself a tree, so the property holds inside it.

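The validity check (edge count equals node count minus 1) is enough on its own because the selected cities, together with the tree edges between them, can never contain a cycle: any such cycle would also be a cycle in the original tree. An acyclic graph with k nodes and c connected components has exactly k - c edges, so seeing k - 1 edges means c = 1, which is to say the subset is connected. You do not need a separate connectivity check like BFS or union-find. Just count the edges.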

Visual walkthrough

Step 1: Subset {1, 2} (bitmask 0011)

mask = 0011

nodes = 2, edges = 1

connected: yes

diameter = 1

Cities 1 and 2 are connected by a direct edge. This is a valid subtree with 2 nodes and 1 edge. The diameter is 1.

Step 2: Subset {1, 2, 3} (bitmask 0111)

mask = 0111

nodes = 3, edges = 2

connected: yes

diameter = 2

Cities 1, 2, and 3 form a valid subtree with 3 nodes and 2 edges. The farthest pair is 2 to 3 (distance 2 through city 1).

Step 3: Subset {2, 3} (bitmask 0110)

mask = 0110

nodes = 2, edges = 0

connected: no

not a valid subtree

Cities 2 and 3 are not directly connected (they connect only through city 1). This subset has 2 nodes but 0 internal edges, so it does not form a valid subtree. Skip it.

Step 4: Subset {1, 2, 4} (bitmask 1011)

mask = 1011

nodes = 3, edges = 2

connected: yes

diameter = 2

Cities 1, 2, and 4 form a valid subtree (1-2 and 2-4 are edges). BFS from city 1 reaches city 4 at distance 2. BFS from city 4 confirms the diameter is 2.

Each step above examines a different subset. Notice how subset {2, 3} fails the validity check because cities 2 and 3 have no direct edge between them in the original tree. The only path from 2 to 3 goes through city 1, which is not in the subset. Valid subtrees like {1, 2, 3} and {1, 2, 4} both have diameter 2, so they both contribute to answer[1].
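The four steps can be replayed in code. This sketch assumes the example tree has edges 1-2, 1-3, and 2-4, which is what the walkthrough implies (cities 2 and 3 meet only through city 1, and cities 2 and 4 share an edge). The hypothetical `describe` helper returns the node count, the internal edge count, and the diameter (or None for an invalid subset) shown in each step.

```python
from collections import deque

EDGES = [[1, 2], [1, 3], [2, 4]]  # assumed example tree from the walkthrough
N = 4


def describe(mask):
    cities = [i + 1 for i in range(N) if mask & (1 << i)]
    chosen = set(cities)
    inside = [(u, v) for u, v in EDGES if u in chosen and v in chosen]
    if len(inside) != len(cities) - 1:
        return len(cities), len(inside), None  # not a valid subtree

    adj = {c: [] for c in cities}
    for u, v in inside:
        adj[u].append(v)
        adj[v].append(u)

    def farthest(start):
        # BFS over the subtree; returns (farthest city, its distance).
        dist = {start: 0}
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nxt in adj[node]:
                if nxt not in dist:
                    dist[nxt] = dist[node] + 1
                    queue.append(nxt)
        node = max(dist, key=dist.get)
        return node, dist[node]

    end, _ = farthest(cities[0])
    _, diameter = farthest(end)
    return len(cities), len(inside), diameter


for mask in (0b0011, 0b0111, 0b0110, 0b1011):
    print(format(mask, "04b"), describe(mask))
# 0011 (2, 1, 1)
# 0111 (3, 2, 2)
# 0110 (2, 0, None)
# 1011 (3, 2, 2)
```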

Complexity analysis

| Approach | Time | Space |
|---|---|---|
| Bitmask enumeration + BFS | O(2^n * n) | O(n) |

You enumerate 2^n subsets. For each subset, extracting the selected cities takes O(n) time, and counting edges and running BFS also take O(n): only the adjacency lists of selected cities are scanned, and across the whole tree those lists hold at most 2(n - 1) entries because there are n - 1 edges. The total time is O(2^n * n). With n at most 15, that is about 15 * 32,768 = 491,520 steps (times a small constant), which runs very quickly.

The space complexity is O(n) for the BFS queue, the distance dictionary, and the city set for each subset. The adjacency list uses O(n) space as well.

The building blocks

1. Bitmask subset enumeration

```python
for mask in range(1, 1 << n):
    cities = []
    for i in range(n):
        if mask & (1 << i):
            cities.append(i + 1)
```

Each integer from 1 to 2^n - 1 represents a subset. Bit i being set means city i + 1 is included. This is the standard way to enumerate all subsets when n is small enough (typically n is at most 20). You extract the selected elements by checking each bit position.
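For a concrete feel of the bit-to-city mapping, here is a small sketch with two hypothetical helpers. `cities_from_mask` scans every bit position, as the solution does. `cities_from_mask_lowbit` is an alternative (not used in the solution above) that repeatedly isolates the lowest set bit with `mask & -mask`, so its loop runs once per selected city rather than once per bit.

```python
def cities_from_mask(mask, n):
    # Scan every bit position: bit i selects city i + 1.
    return [i + 1 for i in range(n) if mask & (1 << i)]


def cities_from_mask_lowbit(mask):
    # Peel off the lowest set bit each round.
    cities = []
    while mask:
        low = mask & -mask               # isolate the lowest set bit, 1 << i
        cities.append(low.bit_length())  # (1 << i).bit_length() == i + 1
        mask ^= low                      # clear that bit
    return cities


print(cities_from_mask(0b1011, 4))     # [1, 2, 4]
print(cities_from_mask_lowbit(0b0110))  # [2, 3]
print(len(range(1, 1 << 15)))          # 32767 non-empty subsets when n = 15
```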

2. Tree diameter via double BFS

```python
far_node, _ = bfs(cities[0])
_, diameter = bfs(far_node)
```

The double BFS technique finds the diameter of a tree in two passes. Start BFS from any node and find the farthest node f. Then start BFS from f and the distance to the farthest node from f is the diameter. This works because in a tree, the farthest node from any node is always an endpoint of some longest path, and the farthest node from that endpoint gives the full diameter.
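As a standalone sketch, the same two-pass idea applied to a whole tree might look like the following. The names `farthest_from` and `tree_diameter` are hypothetical, and cities are assumed to be numbered 1 to n as in the problem. The sample trees are the four-city path 3-1-2-4 (assumed from the walkthrough) and a five-city star.

```python
from collections import deque


def farthest_from(start, adj):
    # Plain BFS over the tree; returns (farthest node, its distance).
    dist = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for nxt in adj[node]:
            if nxt not in dist:
                dist[nxt] = dist[node] + 1
                queue.append(nxt)
    far = max(dist, key=dist.get)
    return far, dist[far]


def tree_diameter(n, edges):
    adj = {c: [] for c in range(1, n + 1)}
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    end, _ = farthest_from(1, adj)          # pass 1: start anywhere
    _, diameter = farthest_from(end, adj)   # pass 2: start at that endpoint
    return diameter


print(tree_diameter(4, [[1, 2], [1, 3], [2, 4]]))          # 3 (path 3-1-2-4)
print(tree_diameter(5, [[1, 2], [1, 3], [1, 4], [1, 5]]))  # 2 (star)
```

Starting pass 1 from city 1 is arbitrary; any starting city yields the same diameter, which is exactly the property the technique relies on.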

Edge cases

  • n = 2. There is exactly one edge and one possible subtree (both cities). The diameter is 1, so the answer is [1].
  • Star graph. All cities connect to a single hub. Every pair of non-hub cities has distance 2 through the hub, and every subtree including the hub and at least two leaves has diameter 2.
  • Path graph. Cities form a straight line. Every connected subset of a path is a contiguous segment, and a segment of d + 1 cities has diameter d, so there are exactly n - d subtrees with diameter d.
  • Single city subsets. A single city has diameter 0, and the answer only covers distances 1 through n - 1, so these subsets are skipped.
  • Full tree selected. The mask (1 << n) - 1 selects all cities. Its diameter is the diameter of the entire tree.
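These edge cases can be checked against an independent brute force. The sketch below swaps in a different technique from the post's solution: it precomputes all pairwise distances in the full tree once, then takes the largest distance among the selected cities. That shortcut is only valid for subsets that already passed the subtree check, because a subtree contains the unique tree path between any two of its cities. The function name and test trees are illustrative assumptions.

```python
from collections import deque
from itertools import combinations


def count_by_pairwise_distance(n, edges):
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # dist[a][b]: distance in the full tree, via one BFS per city.
    dist = [[0] * (n + 1) for _ in range(n + 1)]
    for src in range(1, n + 1):
        seen = {src}
        queue = deque([src])
        while queue:
            node = queue.popleft()
            for nxt in adj[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    dist[src][nxt] = dist[src][node] + 1
                    queue.append(nxt)

    answer = [0] * (n - 1)
    for mask in range(1, 1 << n):
        cities = [i + 1 for i in range(n) if mask & (1 << i)]
        inside = sum(
            1 for u, v in edges
            if (mask >> (u - 1)) & 1 and (mask >> (v - 1)) & 1
        )
        if len(cities) < 2 or inside != len(cities) - 1:
            continue
        # Inside a subtree, tree distance equals distance within the subtree.
        diameter = max(dist[a][b] for a, b in combinations(cities, 2))
        answer[diameter - 1] += 1
    return answer


print(count_by_pairwise_distance(2, [[1, 2]]))                          # [1]
print(count_by_pairwise_distance(5, [[1, 2], [1, 3], [1, 4], [1, 5]]))  # star: [4, 11, 0, 0]
print(count_by_pairwise_distance(5, [[1, 2], [2, 3], [3, 4], [4, 5]]))  # path: [4, 3, 2, 1]
```

For the five-city star, the 4 hub-leaf pairs have diameter 1, and the 2^4 - 5 = 11 subsets with the hub and at least two leaves have diameter 2. For the path, the counts follow n - d.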

From understanding to recall

The two ideas worth committing to memory are the bitmask enumeration trigger and the double BFS diameter trick. When you see n is at most 15, that is your cue to think about enumerating all 2^n subsets. When you need the diameter of a tree (or subtree), remember the two-pass BFS pattern: pick any node, BFS to find the farthest, BFS again from there.

Practice the validity check separately. The fact that k - 1 edges among k nodes guarantees a connected tree (when all edges come from a tree) is a useful shortcut that avoids writing a full connectivity check. Drill this on a few small examples until it feels obvious.


Count Subtrees With Max Distance Between Cities combines bitmask enumeration with tree diameter computation in a clean and self-contained way. The small input bound makes brute force not just acceptable but the simplest correct choice. Once you internalize the subset enumeration pattern and the double BFS trick, problems like this become a matter of assembling known pieces. CodeBricks helps you drill these building blocks with spaced repetition so they stay sharp when you need them.