LeetCode

Count Univalue Subtrees

  • Time: O(n), where n is the number of nodes
  • Space: O(h), where h is the height of the tree (recursion stack)

C++

class Solution {
 public:
  int countUnivalSubtrees(TreeNode* root) {
    int ans = 0;
    isUnival(root, INT_MAX, ans);
    return ans;
  }

 private:
  bool isUnival(TreeNode* root, int val, int& ans) {
    if (!root)
      return true;

    if (isUnival(root->left, root->val, ans) &
        isUnival(root->right, root->val, ans)) {
      ++ans;
      return root->val == val;
    }

    return false;
  }
};

Java

class Solution {
  /**
   * Counts the subtrees of {@code root} in which every node holds the same value.
   *
   * @param root root of the binary tree; may be {@code null}
   * @return number of univalue subtrees (0 for an empty tree)
   */
  public int countUnivalSubtrees(TreeNode root) {
    // The accumulator is an instance field. Reset it so that a second call on
    // the same Solution does not add onto the count from the previous call.
    ans = 0;
    isUnival(root, Integer.MAX_VALUE);
    return ans;
  }

  // Number of univalue subtrees found during the current traversal.
  private int ans = 0;

  /**
   * Post-order check that counts univalue subtrees as a side effect.
   *
   * @param root subtree to inspect; may be {@code null}
   * @param val value of {@code root}'s parent. For the tree root a sentinel is
   *     passed, and the top-level return value is ignored.
   * @return true iff the subtree is univalue and its value equals {@code val};
   *     an empty subtree returns true
   */
  private boolean isUnival(TreeNode root, int val) {
    if (root == null)
      return true;

    // The non-short-circuit '&' is deliberate. Both subtrees must be visited so
    // that all of their univalue subtrees are counted, even when the left one fails.
    if (isUnival(root.left, root.val) & isUnival(root.right, root.val)) {
      ++ans;
      return root.val == val;
    }

    return false;
  }
}

Python

class Solution:
  def countUnivalSubtrees(self, root: Optional[TreeNode]) -> int:
    """Return how many subtrees of `root` consist of nodes sharing one value.

    A recursive post-order traversal decides, for each node, whether its
    subtree is univalue; recursion depth grows with the height of the tree.
    An empty tree yields 0.
    """
    count = 0

    def check(node: Optional[TreeNode]) -> bool:
      """Report whether the subtree at `node` is univalue, counting it if so.

      An empty subtree is treated as univalue but is not counted.
      """
      nonlocal count
      if not node:
        return True

      # Evaluate both children before combining the results so that every
      # subtree gets visited and counted, even when the left side fails.
      left_ok = check(node.left)
      right_ok = check(node.right)
      if not (left_ok and right_ok):
        return False

      # A univalue child only extends upward if it shares this node's value.
      if node.left and node.left.val != node.val:
        return False
      if node.right and node.right.val != node.val:
        return False

      count += 1
      return True

    check(root)
    return count