563 Binary Tree Tilt

Given a binary tree, return the tilt of the whole tree.

The tilt of a tree node is defined as the absolute difference between the sum of all left subtree node values and the sum of all right subtree node values. A null node has tilt 0.

The tilt of the whole tree is defined as the sum of all nodes' tilts.

Example:

Input:

         1
       /   \
      2     3

Output:
 1

Explanation:

Tilt of node 2 : 0
Tilt of node 3 : 0
Tilt of node 1 : |2-3| = 1
Tilt of binary tree : 0 + 0 + 1 = 1

Note:

  1. The sum of node values in any subtree won't exceed the range of 32-bit integer.

  2. All the tilt values won't exceed the range of 32-bit integer.

Performance: O(N) time and space

The Idea: Because addition is associative, we can use dynamic programming to first build an accumulation tree, in which every node holds the sum of itself and its left and right subtrees. A post-order traversal builds this tree in O(N) time. Next, a second post-order traversal over the accumulation tree adds up the absolute difference between each node's left and right subtree sums. This also runs in O(N) time, because the accumulation tree already stores the sum of every subtree, so each node's tilt is available in O(1).

# Definition for a binary tree node.
# class TreeNode(object):
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

class Solution(object):
    """Compute the tilt of a binary tree (LeetCode 563).

    ``findTilt`` computes every subtree sum and every node's tilt in a single
    iterative post-order pass. It leaves the caller's tree untouched and does
    not depend on the interpreter's recursion limit.

    ``build_cum_tree`` and ``_findTilt`` implement the two-pass
    accumulation-tree approach. They are kept for callers that use them
    directly. Note that ``build_cum_tree`` overwrites node values.
    """

    def build_cum_tree(self, node):
        """Overwrite each node's value, in place, with the sum of its subtree.

        :type node: TreeNode
        :rtype: void
        """
        # Post-order: both children are already accumulated by the time the
        # parent adds their values to its own.
        if node is not None:
            self.build_cum_tree(node.left)
            self.build_cum_tree(node.right)
            if node.left is not None:
                node.val += node.left.val
            if node.right is not None:
                node.val += node.right.val

    def _findTilt(self, node, accumulator):
        """Add the tilt of every node under ``node`` to ``accumulator[0]``.

        Expects a tree already processed by ``build_cum_tree``, so each
        child's ``val`` is the sum of that child's whole subtree.

        :type node: TreeNode
        :type accumulator: list  (one-element list of int, updated in place)
        :rtype: void
        """
        if node is not None:
            self._findTilt(node.left, accumulator)
            self._findTilt(node.right, accumulator)
            left_sum = node.left.val if node.left is not None else 0
            right_sum = node.right.val if node.right is not None else 0
            accumulator[0] += abs(left_sum - right_sum)

    def findTilt(self, root):
        """Return the sum of every node's tilt; ``0`` for an empty tree.

        A node's tilt is ``|sum(left subtree) - sum(right subtree)|``. The
        input tree is not modified, so repeated calls return the same value.

        :type root: TreeNode
        :rtype: int
        """
        tilt = 0
        # Subtree sums of fully processed subtrees, in post-order. Left sum
        # is pushed before right sum, so a parent pops right, then left.
        sums = []
        # (node, expanded): expanded=True means both children are already
        # finished and their sums sit on top of ``sums``.
        stack = [(root, False)]
        while stack:
            node, expanded = stack.pop()
            if node is None:
                sums.append(0)
            elif expanded:
                right_sum = sums.pop()
                left_sum = sums.pop()
                tilt += abs(left_sum - right_sum)
                sums.append(left_sum + right_sum + node.val)
            else:
                stack.append((node, True))
                # Pushed last, so the left child is processed first.
                stack.append((node.right, False))
                stack.append((node.left, False))
        return tilt

Some helpful debugging code:

class TreeNode:
    """Binary-tree node holding a value and two child links.

    Both children start out as ``None``; ``BTree`` attaches new nodes by
    assigning to ``left`` / ``right``.
    """

    def __init__(self, x):
        self.val = x
        # No children yet.
        self.left, self.right = None, None

class BTree:
    """Unbalanced binary search tree of ``TreeNode`` objects for test inputs.

    A value smaller than a node goes into that node's left subtree. A value
    greater than or equal to it goes into the right subtree, so duplicates
    are kept. All traversals are iterative, so degenerate (list-shaped)
    trees built from sorted input do not hit the recursion limit.
    """

    def __init__(self):
        self.root = None

    def getRoot(self):
        """Return the root node, or ``None`` for an empty tree."""
        return self.root

    def add(self, val):
        """Insert ``val`` as a new leaf."""
        if self.root is None:
            self.root = TreeNode(val)
        else:
            self._add(val, self.root)

    def _add(self, val, node):
        """Insert ``val`` as a new leaf below ``node``, which must not be None."""
        while True:
            if val < node.val:
                if node.left is None:
                    node.left = TreeNode(val)
                    return
                node = node.left
            else:
                if node.right is None:
                    node.right = TreeNode(val)
                    return
                node = node.right

    def find(self, val):
        """Return the first node holding ``val``, or ``None`` if absent."""
        if self.root is None:
            return None
        return self._find(val, self.root)

    def _find(self, val, node):
        """Search the subtree rooted at ``node`` for ``val``.

        Returns the matching node, or ``None`` when the search path ends
        without a match.
        """
        while node is not None:
            if val == node.val:
                return node
            node = node.left if val < node.val else node.right
        return None

    def deleteTree(self):
        """Drop the root reference.

        The now-unreachable nodes are reclaimed by Python's garbage
        collection.
        """
        self.root = None

    def printTree(self):
        """Print all values in sorted (in-order) order, one per line."""
        if self.root is not None:
            self._printTree(self.root)

    def _printTree(self, node):
        """Print the subtree rooted at ``node`` in order, one value per line."""
        pending = []
        while pending or node is not None:
            if node is not None:
                # Descend left first, remembering ancestors to visit later.
                pending.append(node)
                node = node.left
            else:
                node = pending.pop()
                print(str(node.val) + ' ')
                node = node.right


def pre_order(root):
    """Print the tree's values in pre-order (node, left, right) on one line.

    Each value is followed by a single space. No trailing newline is written.
    An empty tree prints nothing.

    :type root: TreeNode
    :rtype: void
    """
    values = []
    stack = [root]
    while stack:
        node = stack.pop()
        if node is not None:
            values.append(str(node.val) + ' ')
            # Push right before left so the left subtree is visited first.
            stack.append(node.right)
            stack.append(node.left)
    # One write (and one flush) for the whole traversal.
    print(''.join(values), end='', flush=True)


if __name__ == '__main__':

    # Inserting 3, 4, 0, 8, 2, 10, -2 (in that order) builds this tree:
    #
    #          3
    #        /   \
    #       0     4
    #      / \     \
    #    -2   2     8
    #                \
    #                 10
    tree = BTree()
    for value in (3, 4, 0, 8, 2, 10, -2):
        tree.add(value)
    # In-order print: values come out sorted, one per line.
    tree.printTree()

    print()

Last updated