Skip to main content

All Nodes Distance K in Binary Tree

Question

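Given the root of a binary tree, a target node in that tree, and an integer K, return the values of all nodes that are exactly K edges away from the target node. The answer may be returned in any order.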

Example 1
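For the tree [3,5,1,6,2,0,8,null,null,7,4] with target = 5 and K = 2, the nodes at distance 2 from the target have values 7, 4 and 1, so the output is [7,4,1].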

Solution

all//All Nodes Distance K in Binary Tree.py
import collections

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """Find the values of all tree nodes a fixed distance from a target node."""

    def distanceK(self, root, target, K):
        """Return the values of every node exactly ``K`` edges from ``target``.

        The tree is first converted into an undirected graph (each
        parent/child link becomes an edge in both directions), then a
        level-by-level breadth-first search is run outward from ``target``.

        Args:
            root: Root node of the binary tree, or ``None`` for an empty tree.
                Nodes must expose ``val``, ``left`` and ``right``.
            target: The node (the object itself, not its value) to measure
                distances from.
            K: Required distance in edges.

        Returns:
            A list of ``val`` attributes of the nodes at distance ``K``, in
            BFS discovery order. Empty when the tree is empty, ``target`` is
            ``None``, ``K`` is negative, or no node lies that far away.
        """
        # Nothing can match in these cases; returning early also avoids
        # dereferencing a missing root/target.
        if root is None or target is None or K < 0:
            return []

        # Step 1: Construct an undirected adjacency list. An explicit stack
        # replaces recursion so a degenerate (linked-list shaped) tree cannot
        # hit the interpreter's recursion limit.
        graph = collections.defaultdict(list)
        stack = [root]
        while stack:
            node = stack.pop()
            for child in (node.left, node.right):
                if child is not None:
                    graph[node].append(child)
                    graph[child].append(node)
                    stack.append(child)

        # Step 2: BFS from the target one level at a time. After i iterations
        # `frontier` holds exactly the nodes at distance i, so the search
        # stops at level K instead of walking the rest of the tree.
        visited = {target}
        frontier = [target]
        for _ in range(K):
            next_frontier = []
            for node in frontier:
                # .get() avoids inserting empty entries for leaf-less lookups
                # (e.g. a single-node tree or a target outside the tree).
                for nei in graph.get(node, ()):
                    if nei not in visited:
                        visited.add(nei)
                        next_frontier.append(nei)
            if not next_frontier:
                # Ran out of nodes before reaching distance K.
                return []
            frontier = next_frontier
        return [node.val for node in frontier]