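- Time: O(n), where n is the number of nodes; each node is visited once.
- Space: O(h), where h is the height of the tree, for the traversal stack.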

## C++

``````class Solution {
public:
int diameterOfBinaryTree(TreeNode* root) {
int ans = 0;
maxDepth(root, ans);
return ans;
}

private:
int maxDepth(TreeNode* root, int& ans) {
if (!root)
return 0;

const int l = maxDepth(root->left, ans);
const int r = maxDepth(root->right, ans);
ans = max(ans, l + r);
return 1 + max(l, r);
}
};
``````

## Java

``````class Solution {
public int diameterOfBinaryTree(TreeNode root) {
maxDepth(root);
return ans;
}

private int ans = 0;

int maxDepth(TreeNode root) {
if (root == null)
return 0;

final int l = maxDepth(root.left);
final int r = maxDepth(root.right);
ans = Math.max(ans, l + r);
return 1 + Math.max(l, r);
}
}
``````

## Python

``````class Solution:
def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
ans = 0

def maxDepth(root: Optional[TreeNode]) -> int:
nonlocal ans
if not root:
return 0

l = maxDepth(root.left)
r = maxDepth(root.right)
ans = max(ans, l + r)
return 1 + max(l, r)

maxDepth(root)
return ans
``````