366 Find Leaves of Binary Tree

Given a binary tree, collect the tree's nodes as if you were doing the following: collect and remove all leaves, then repeat until the tree is empty.
Example:
Given binary tree
    1
   / \
  2   3
 / \
4   5
Returns [4, 5, 3], [2], [1].
Explanation:
Removing the leaves [4, 5, 3] would result in this tree:
  1
 /
2
Now removing the leaf [2] would result in this tree:
1
Now removing the leaf [1] would result in the empty tree:
[]
Returns [4, 5, 3], [2], [1].
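The removal process described above can be simulated literally. The following is a minimal sketch, not the approach these notes settle on: `TreeNode`, `is_leaf`, `prune_leaves`, and `find_leaves_by_simulation` are hypothetical names introduced here for illustration, and the sketch mutates the tree it is given.

```python
class TreeNode:
    """Minimal binary tree node (assumed; the notes do not define one)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_leaf(node):
    return node.left is None and node.right is None


def prune_leaves(node, collected):
    """Detach the current leaves under `node`, appending their values to `collected`.

    Returns what is left of the subtree (None if `node` itself was a leaf).
    """
    if node is None:
        return None
    if is_leaf(node):
        collected.append(node.val)
        return None
    # Leaf status is checked before recursing, so a node whose children are
    # removed in this round is only collected in the next round.
    node.left = prune_leaves(node.left, collected)
    node.right = prune_leaves(node.right, collected)
    return node


def find_leaves_by_simulation(root):
    rounds = []
    while root is not None:
        collected = []
        root = prune_leaves(root, collected)
        rounds.append(collected)
    return rounds


# The example tree from the problem statement
example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
result = find_leaves_by_simulation(example)
print(result)  # [[4, 5, 3], [2], [1]]
```

Each round walks every remaining node, so a skewed tree (a single chain) costs O(n^2) time in total. That cost is what motivates looking for a single-pass approach.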
The Idea: The height of a node is the number of edges on the longest path from that node down to a leaf, so a leaf has height 0. A node becomes a leaf only after all of its descendants have been removed, which means a node of height h is removed in round h + 1. The height can therefore be used as a direct index into the solution: map each height to the values of the nodes that have it, then list those groups in increasing order of height.
Complexity: O(n) time and space
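import collections

class Solution:
    def findLeaves(self, root):
        """
        :type root: TreeNode
        :rtype: List[List[int]]
        """
        # heights[h] collects the values of every node whose height is h
        heights = collections.defaultdict(list)

        def dfs(node):
            # An empty subtree has height -1, so a leaf gets height 0
            if not node:
                return -1
            height = 1 + max(dfs(node.left), dfs(node.right))
            heights[height].append(node.val)
            return height

        if not root:
            return []
        dfs(root)
        # Heights are contiguous from 0, so index them in increasing order
        return [heights[lvl] for lvl in range(len(heights))]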
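To try the height-grouping idea outside the judge, here is a self-contained sketch that runs on the example tree. It is a variant rather than the solution above: it appends to a plain list instead of using a `defaultdict`, and `TreeNode` and `find_leaves` are assumed names defined here for illustration.

```python
class TreeNode:
    """Minimal binary tree node (assumed; the notes do not define one)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves(root):
    groups = []  # groups[h] holds the values of all nodes with height h

    def height(node):
        if node is None:
            return -1
        h = 1 + max(height(node.left), height(node.right))
        # Post-order traversal: a child of height h - 1 has already been
        # recorded, so h is at most len(groups) at this point.
        if h == len(groups):
            groups.append([])
        groups[h].append(node.val)
        return h

    height(root)
    return groups


# The example tree from the problem statement
example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
result = find_leaves(example)
print(result)  # [[4, 5, 3], [2], [1]]
```

Because every node is visited once and each value is stored once, this variant keeps the O(n) time and space stated above, and an empty tree falls out naturally as an empty list without a separate check.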