Reference: LeetCode
Difficulty: Easy

Problem

You are given a binary tree in which each node contains an integer value. Find the number of paths that sum to a given value.

The path does not need to start or end at the root or a leaf, but it must go downwards (traveling only from parent nodes to child nodes).

The tree has no more than 1,000 nodes and the values are in the range -1,000,000 to 1,000,000.

Example:

root = [10,5,-3,3,2,null,11,3,-2,null,1], sum = 8

      10
     /  \
    5   -3
   / \    \
  3   2   11
 / \   \
3  -2   1

Return 3. The paths that sum to 8 are:

1. 5 -> 3
2. 5 -> 2 -> 1
3. -3 -> 11
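
To experiment with this example, the level-order array can be turned into an actual tree. The sketch below is not from the original post: it assumes a minimal `TreeNode` class in the shape LeetCode usually provides, and a hypothetical `build` helper that reads the array in LeetCode's level-order format, where `null` marks a missing child.

```java
import java.util.ArrayDeque;
import java.util.Queue;

class LevelOrderDemo {
    // Minimal node type, assumed to mirror the usual LeetCode definition.
    static class TreeNode {
        int val;
        TreeNode left;
        TreeNode right;

        TreeNode(int val) {
            this.val = val;
        }
    }

    // Hypothetical helper: builds a tree from a LeetCode-style level-order array.
    static TreeNode build(Integer... values) {
        if (values.length == 0 || values[0] == null) {
            return null;
        }
        TreeNode root = new TreeNode(values[0]);
        Queue<TreeNode> queue = new ArrayDeque<>();
        queue.add(root);
        int i = 1;
        while (!queue.isEmpty() && i < values.length) {
            TreeNode parent = queue.remove();
            // The next two array slots are this parent's left and right children.
            if (values[i] != null) {
                parent.left = new TreeNode(values[i]);
                queue.add(parent.left);
            }
            i++;
            if (i < values.length && values[i] != null) {
                parent.right = new TreeNode(values[i]);
                queue.add(parent.right);
            }
            i++;
        }
        return root;
    }

    public static void main(String[] args) {
        TreeNode root = build(10, 5, -3, 3, 2, null, 11, 3, -2, null, 1);
        // Walks 10 -> 5 -> 2 -> 1, the second path listed above.
        System.out.println(root.left.right.right.val); // prints 1
    }
}
```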

Follow up:

  • Reduce the time complexity.

Analysis

By fudonglai:
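The trick to writing recursion is to understand what a function does and trust that it can do that job. Never jump into the function to dig for more details, or you will get lost in endless details. Even if you were made of iron, how many stack frames could you hold in your head?

Following that technique, first define clearly what each recursive function should do:

pathSum function: given a node and a target value, it returns the total number of paths that sum to the target in the tree rooted at that node.
count function: given a node and a target value, it returns the number of paths in the tree rooted at that node that start at that node and sum to the target.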

Methods:

  1. Typical Recursion
    • First, we should define each recursive function’s job.
    • pathSum(x, target): Return the number of paths that sum to target anywhere in the tree rooted at x.
    • count(x, target): Return the number of paths that sum to target and start at x, going downwards within the tree rooted at x.
    • Time: 11 ms
      • In a balanced tree, the subtree size $N$ is halved at each level as you go down, until pathSum reaches the leaves, where count(x, target) costs $T(1)$. pathSum calls count at every node, and count is linear in the size of the subtree it starts from, so each level does $O(N)$ work in total. In the best case there are $\log{N}$ levels.
      • According to the master theorem,
        • pathSum(): $T(N) = 2T(N/2) + O(N) = O(N\log{N})$
        • count(): $T(N) = 2T(N/2) + O(1) = O(N)$
      • Best: $O(N\log{N})$
      • Worst: $O(N^2)$
    • Space: $O(h)$
  2. HashMap + Prefix Sum (tankztc & kekezi)
    • Time: $O(N)$ 4 ms
    • Space: $O(N)$
    • The idea is similar to Two Sum: use a HashMap whose keys are prefix sums and whose values are the number of ways to reach each prefix sum.
    • The prefix sum stores the sum from the root to the current node in the recursion.
    • Whenever we reach a node, we check if prefix sum - target exists in the HashMap.
      • If it does, we add the number of ways to reach prefix sum - target to the result.
      • For example:
        • In one path we have: [1, 2, -1, -1, 2],
        • then the prefix sum will be: [1, 3, 2, 1, 3].
        • Let’s say we want to find target sum $2$; then we have $4$ paths: [2], [1, 2, -1], [2, -1, -1, 2], and [2] (the last element).
  • The sum from any node in the middle of the path to the current node
    • $=$ (the sum from the root to the current node) $-$ (the prefix sum of the node in the middle)
  • We want the difference above equal to the target value. In addition, we need to know how many differences are equal to the target value.
  • Use a HashMap. The map stores the frequency of every prefix sum on the path from the root to the current node. If the difference we want exists in the map, there must be a node in the middle of the path such that the sum from that node to the current node equals the target value.
  • There might be multiple nodes in the middle that satisfy the requirement. In each recursion, the map stores all the information we need to count the ranges that sum to target (starting at a middle node and ending at the current node).
  • To get the total path count, we add up the number of valid paths ending at each node in the tree.
  • Each recursion returns the total count of valid paths in the subtree rooted at the current node. And this sum can be divided into three parts:
    • The total number of valid paths in the subtree rooted at the current node’s left child.
    • The total number of valid paths in the subtree rooted at the current node’s right child.
    • The number of valid paths ending at the current node.
  • The interesting part of this solution is that the prefix is counted from root to leaves, and the result of total count is calculated from the bottom to the top.
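
The worked example above ([1, 2, -1, -1, 2] with target 2) can be checked on a plain array before moving to trees. This is a minimal sketch under that simplification; the class name `PrefixSumArrayDemo` and the method `countRuns` are illustrative and not part of the original solution.

```java
import java.util.HashMap;
import java.util.Map;

class PrefixSumArrayDemo {
    // Counts contiguous runs of `path` whose values sum to `target`.
    static int countRuns(int[] path, int target) {
        Map<Integer, Integer> seen = new HashMap<>();
        seen.put(0, 1); // the empty prefix, so runs starting at index 0 are counted
        int prefix = 0;
        int count = 0;
        for (int value : path) {
            prefix += value;
            // Every earlier prefix equal to (prefix - target) marks the start of a valid run.
            count += seen.getOrDefault(prefix - target, 0);
            seen.merge(prefix, 1, Integer::sum);
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countRuns(new int[] {1, 2, -1, -1, 2}, 2)); // prints 4
    }
}
```

On a tree, the same bookkeeping is applied along each root-to-node path, which is why the map entry for a node has to be undone when the recursion leaves that node.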

Code

Typical Recursion

Incorrect Version

First, let’s examine an incorrect version:

public int pathSum(TreeNode root, int sum) {
    return pathSum(root, sum, sum);
}

private int pathSum(TreeNode root, int sum, int oldSum) {
    if (root == null) {
        return 0;
    }

    // root
    int r0 = root.val == sum ? 1 : 0;
    // left
    int r1 = pathSum(root.left, sum - root.val, oldSum);
    int r2 = pathSum(root.left, oldSum, oldSum); // start from left
    // right
    int r3 = pathSum(root.right, sum - root.val, oldSum);
    int r4 = pathSum(root.right, oldSum, oldSum);

    return r0 + r1 + r2 + r3 + r4;
}

Example:

Input = [1,null,2,null,3], sum = 3
Output = 2 (the code outputs 3)

1
 \
  2
   \
    3

Call Stack:

p(1, 3, 3)
    r0=0, r1=0, r2=0
    r3 = p(2, 2, 3) // continue with node 1
        r0=1, r1=0, r2=0
        r3 = p(3, 0, 3)
            r0=0, r1=0, r2=0, r3=0
            = 0
        r4 = p(3, 3, 3) // here
            r0=1, r1=0, r2=0, r3=0
            = 1
        = r0 + r4 = 2
    r4 = p(2, 3, 3) // start over at node 2
        r0=0, r1=0, r2=0
        r3 = p(3, 1, 3)
            r0=0, r1=0, r2=0, r3=0
            = 0
        r4 = p(3, 3, 3) // here
            r0=1, r1=0, r2=0, r3=0
            = 1
        = r4 = 1
    r3 + r4 = 3
    = 3

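We can see that there are repeated calculations. The single-node path 3 is counted twice, because p(3, 3, 3) is reached through two different call chains: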

  • p(3, 3, 3) called by p(2, 2, 3) called by p(1, 3, 3).
  • p(3, 3, 3) called by p(2, 3, 3) called by p(1, 3, 3).
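
The trace can be reproduced by running the incorrect version on the three-node chain. The sketch below wraps that code in a hypothetical `IncorrectPathSumDemo` class with an assumed minimal `TreeNode`; only the wrapper and the tree construction are new.

```java
class IncorrectPathSumDemo {
    // Minimal node type, assumed to mirror the usual LeetCode definition.
    static class TreeNode {
        int val;
        TreeNode left;
        TreeNode right;

        TreeNode(int val) {
            this.val = val;
        }
    }

    static int pathSum(TreeNode root, int sum) {
        return pathSum(root, sum, sum);
    }

    // Same logic as the incorrect version: every call both continues a path
    // and restarts one, so restarts are triggered from more than one caller.
    private static int pathSum(TreeNode root, int sum, int oldSum) {
        if (root == null) {
            return 0;
        }
        int r0 = root.val == sum ? 1 : 0;
        int r1 = pathSum(root.left, sum - root.val, oldSum);
        int r2 = pathSum(root.left, oldSum, oldSum);
        int r3 = pathSum(root.right, sum - root.val, oldSum);
        int r4 = pathSum(root.right, oldSum, oldSum);
        return r0 + r1 + r2 + r3 + r4;
    }

    // Builds the chain 1 -> 2 -> 3 using right children only.
    static TreeNode chain() {
        TreeNode root = new TreeNode(1);
        root.right = new TreeNode(2);
        root.right.right = new TreeNode(3);
        return root;
    }

    public static void main(String[] args) {
        // Only 1 -> 2 and 3 sum to 3, so the expected answer is 2.
        System.out.println(pathSum(chain(), 3)); // prints 3
    }
}
```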

Correct Version

// pathSum(): how many paths anywhere in the tree rooted at root sum to the target?
public int pathSum(TreeNode root, int sum) {
    if (root == null) {
        return 0;
    }
    int rootNum = count(root, sum); // paths that start at root
    int leftNum = pathSum(root.left, sum); // trust the recursion; don't need to know the details
    int rightNum = pathSum(root.right, sum);
    return rootNum + leftNum + rightNum;
}

// count(): how many paths that start at root sum to the target?
private int count(TreeNode root, int sum) {
    if (root == null) {
        return 0;
    }
    int rootNum = (root.val == sum) ? 1 : 0;
    int leftNum = count(root.left, sum - root.val);
    int rightNum = count(root.right, sum - root.val);
    return rootNum + leftNum + rightNum;
}
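
To make the best-case and worst-case bounds from the analysis concrete, the sketch below counts how many non-null nodes count() visits on two trees of 7 nodes: a right-leaning chain (worst case) and a perfect tree (best case). The `visits` counter, the `chain` and `perfect` builders, and the class name are illustrative assumptions; the two recursive functions follow the correct version above.

```java
class CountVisitsDemo {
    // Minimal node type, assumed to mirror the usual LeetCode definition.
    static class TreeNode {
        int val;
        TreeNode left;
        TreeNode right;

        TreeNode(int val) {
            this.val = val;
        }
    }

    static long visits; // number of non-null nodes visited by count()

    static int pathSum(TreeNode root, int sum) {
        if (root == null) {
            return 0;
        }
        return count(root, sum) + pathSum(root.left, sum) + pathSum(root.right, sum);
    }

    static int count(TreeNode root, int sum) {
        if (root == null) {
            return 0;
        }
        visits++;
        int here = (root.val == sum) ? 1 : 0;
        return here + count(root.left, sum - root.val) + count(root.right, sum - root.val);
    }

    // Right-leaning chain of n nodes, each holding 1.
    static TreeNode chain(int n) {
        TreeNode root = null;
        for (int i = 0; i < n; i++) {
            TreeNode node = new TreeNode(1);
            node.right = root;
            root = node;
        }
        return root;
    }

    // Perfect binary tree with the given number of levels, each node holding 1.
    static TreeNode perfect(int levels) {
        if (levels == 0) {
            return null;
        }
        TreeNode node = new TreeNode(1);
        node.left = perfect(levels - 1);
        node.right = perfect(levels - 1);
        return node;
    }

    // Resets the counter, runs pathSum, and reports how many nodes count() touched.
    static long visitsFor(TreeNode root) {
        visits = 0;
        pathSum(root, 2);
        return visits;
    }

    public static void main(String[] args) {
        System.out.println(visitsFor(chain(7)));   // prints 28 = 7 * 8 / 2
        System.out.println(visitsFor(perfect(3))); // prints 17 = 7 + 2 * 3 + 4 * 1
    }
}
```

The chain total grows as $N(N+1)/2$, while the perfect tree total is the sum of all subtree sizes, which grows as $N\log{N}$.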

HashMap + Prefix Sum

Note:

  • I am dead.
import java.util.HashMap;
import java.util.Map;

public int pathSum(TreeNode root, int sum) {
    if (root == null) {
        return 0;
    }
    Map<Integer, Integer> map = new HashMap<>();
    map.put(0, 1); // the empty prefix, so paths starting at the root are counted
    return findPathSum(root, 0, sum, map);
}

private int findPathSum(TreeNode curr, int sum, int target, Map<Integer, Integer> map) {
    if (curr == null) {
        return 0;
    }
    // update the prefix sum by adding the current value
    sum += curr.val;
    // get the number of valid paths ending at the current node
    int numPathToCurr = map.getOrDefault(sum - target, 0);
    // update the map with the current sum, so the map is ready to be passed to the next recursion
    map.put(sum, map.getOrDefault(sum, 0) + 1); // prefix sum
    // add the 3 parts
    int res = numPathToCurr + findPathSum(curr.left, sum, target, map) + findPathSum(curr.right, sum, target, map);
    // restore the map, as the recursion goes from the bottom to the top
    map.put(sum, map.get(sum) - 1);
    return res;
}
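
As a usage check, the prefix-sum solution can be run on the example tree from the problem statement. The sketch below repeats the solution as static methods inside a hypothetical `PrefixSumTreeDemo` class, with an assumed minimal `TreeNode` and a hand-built `exampleTree()` helper that are not part of the original post.

```java
import java.util.HashMap;
import java.util.Map;

class PrefixSumTreeDemo {
    // Minimal node type, assumed to mirror the usual LeetCode definition.
    static class TreeNode {
        int val;
        TreeNode left;
        TreeNode right;

        TreeNode(int val) {
            this.val = val;
        }
    }

    static int pathSum(TreeNode root, int sum) {
        Map<Integer, Integer> map = new HashMap<>();
        map.put(0, 1); // the empty prefix
        return findPathSum(root, 0, sum, map);
    }

    private static int findPathSum(TreeNode curr, int sum, int target, Map<Integer, Integer> map) {
        if (curr == null) {
            return 0;
        }
        sum += curr.val;
        int numPathToCurr = map.getOrDefault(sum - target, 0);
        map.merge(sum, 1, Integer::sum); // record this prefix for the subtree below
        int res = numPathToCurr
                + findPathSum(curr.left, sum, target, map)
                + findPathSum(curr.right, sum, target, map);
        map.merge(sum, -1, Integer::sum); // undo it before returning to the parent
        return res;
    }

    // Hand-built copy of root = [10,5,-3,3,2,null,11,3,-2,null,1].
    static TreeNode exampleTree() {
        TreeNode root = new TreeNode(10);
        root.left = new TreeNode(5);
        root.right = new TreeNode(-3);
        root.left.left = new TreeNode(3);
        root.left.right = new TreeNode(2);
        root.right.right = new TreeNode(11);
        root.left.left.left = new TreeNode(3);
        root.left.left.right = new TreeNode(-2);
        root.left.right.right = new TreeNode(1);
        return root;
    }

    public static void main(String[] args) {
        System.out.println(pathSum(exampleTree(), 8)); // prints 3
    }
}
```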