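• Time: O(n), where n is the number of nodes — each node is visited once.
• Space: O(h), where h is the height of the tree — the depth of the recursion stack.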

C++

``````class Solution {
public:
int averageOfSubtree(TreeNode* root) {
int ans = 0;
dfs(root, ans);
return ans;
}

private:
pair<int, int> dfs(TreeNode* root, int& ans) {
if (!root)
return {0, 0};
const auto [leftSum, leftCount] = dfs(root->left, ans);
const auto [rightSum, rightCount] = dfs(root->right, ans);
const int sum = root->val + leftSum + rightSum;
const int count = 1 + leftCount + rightCount;
if (sum / count == root->val)
++ans;
return {sum, count};
}
};
``````

Java

``````
class Solution {
  /**
   * Counts the nodes whose value equals the average of the values in their own subtree,
   * where the average is (sum of values) / (node count) using integer division.
   *
   * @param root root of the tree; may be null, in which case the result is 0
   * @return the number of matching nodes
   */
  public int averageOfSubtree(TreeNode root) {
    // The counter is an instance field, so reset it on every call; otherwise a second
    // call on the same Solution object would add to the previous call's result.
    ans = 0;
    dfs(root);
    return ans;
  }

  // Number of matching nodes found so far by the current averageOfSubtree call.
  private int ans = 0;

  /**
   * Post-order traversal of the subtree rooted at {@code root}.
   *
   * <p>Increments {@code ans} for every visited node whose value equals its subtree average.
   *
   * @return a two-element array {sum of values, node count}; {0, 0} for an empty subtree
   */
  private int[] dfs(TreeNode root) {
    if (root == null)
      return new int[] {0, 0};
    // A plain int[] replaces Pair, which is not a java.* class (javafx.util.Pair is not
    // bundled with JDK 11+), so this compiles on a stock JDK.
    final int[] left = dfs(root.left);
    final int[] right = dfs(root.right);
    final int sum = root.val + left[0] + right[0];
    final int count = 1 + left[1] + right[1];
    // NOTE(review): integer division truncates toward zero, which equals "rounded down"
    // only for non-negative sums; node values are presumably non-negative — confirm.
    if (sum / count == root.val)
      ++ans;
    return new int[] {sum, count};
  }
}
``````

Python

``````class Solution:
def averageOfSubtree(self, root: Optional[TreeNode]) -> int:
ans = 0

def dfs(root: Optional[TreeNode]) -> Tuple[int, int]:
nonlocal ans
if not root:
return (0, 0)
leftSum, leftCount = dfs(root.left)
rightSum, rightCount = dfs(root.right)
summ = root.val + leftSum + rightSum
count = 1 + leftCount + rightCount
if summ // count == root.val:
ans += 1
return (summ, count)

dfs(root)
return ans
``````