## Two Sum IV - Input is a BST

- Time: $O(n)$
- Space: $O(h)$, where $h$ is the height of the tree; this degrades to $O(n)$ when the tree is skewed ($h = n$)

## C++

class BSTIterator {
public:
BSTIterator(TreeNode* root, bool leftToRight) : leftToRight(leftToRight) {
pushUntilNull(root);
}

int next() {
TreeNode* root = stack.top();
stack.pop();
pushUntilNull(leftToRight ? root->right : root->left);
return root->val;
}

private:
stack<TreeNode*> stack;
bool leftToRight;

void pushUntilNull(TreeNode* root) {
while (root) {
stack.push(root);
root = leftToRight ? root->left : root->right;
}
}
};

// Two-pointer search over the BST's sorted order: one iterator ascends from
// the minimum, the other descends from the maximum, and they close in on each
// other until a pair summing to k is found or they meet.
class Solution {
 public:
  bool findTarget(TreeNode* root, int k) {
    if (!root)
      return false;

    BSTIterator left(root, true);    // ascending
    BSTIterator right(root, false);  // descending

    // While l < r, each iterator still has an unvisited value between l and r
    // (r is a successor of l and vice versa), so next() is never called on an
    // exhausted iterator.
    for (int l = left.next(), r = right.next(); l < r;) {
      // Widen before adding: l + r in int can overflow (undefined behavior)
      // for values near the int limits.
      const long long sum = static_cast<long long>(l) + r;
      if (sum == k)
        return true;
      if (sum < k)
        l = left.next();
      else
        r = right.next();
    }

    return false;
  }
};


## Java

/**
 * Lazy in-order iterator over a BST that can run in either direction.
 * With {@code leftToRight == true} values are returned in ascending order;
 * with {@code false} they are returned in descending order. The stack only
 * holds nodes along one root-to-leaf path at a time.
 */
class BSTIterator {
  public BSTIterator(TreeNode root, boolean leftToRight) {
    this.leftToRight = leftToRight;
    pushUntilNull(root);
  }

  /**
   * Returns the next value in iteration order.
   *
   * @throws java.util.NoSuchElementException if the iterator is exhausted
   *     (propagated from {@link ArrayDeque#pop()}); check {@link #hasNext()}.
   */
  public int next() {
    TreeNode node = stack.pop();
    // The following node lives in the subtree on the far side of `node`.
    pushUntilNull(leftToRight ? node.right : node.left);
    return node.val;
  }

  /** Returns true while there are values left to return. */
  public boolean hasNext() {
    return !stack.isEmpty();
  }

  private final Deque<TreeNode> stack = new ArrayDeque<>();
  private final boolean leftToRight;

  // Pushes `root` and then its chain of near-side children (left children when
  // ascending, right children when descending). Formerly named
  // pushLeftsUntilNull, which was wrong for the descending direction.
  private void pushUntilNull(TreeNode root) {
    while (root != null) {
      stack.push(root);
      root = leftToRight ? root.left : root.right;
    }
  }
}

/**
 * Two-pointer search over the BST's sorted order: one iterator ascends from the
 * minimum, the other descends from the maximum, and they move toward each other.
 */
class Solution {
  public boolean findTarget(TreeNode root, int k) {
    if (root == null)
      return false;

    BSTIterator left = new BSTIterator(root, true);   // ascending
    BSTIterator right = new BSTIterator(root, false); // descending

    // While l < r both iterators still have an unvisited value between l and r,
    // so next() never runs on an empty iterator.
    for (int l = left.next(), r = right.next(); l < r;) {
      // Sum in long: int addition silently wraps for values near the int
      // limits, which would make the comparison with k wrong.
      final long sum = (long) l + r;
      if (sum == k)
        return true;
      if (sum < k)
        l = left.next();
      else
        r = right.next();
    }

    return false;
  }
}


## Python

class BSTIterator:
    """Lazy in-order traversal of a BST in either direction.

    With ``leftToRight=True`` values come out in ascending order; with
    ``False`` they come out in descending order.  ``self.stack`` holds the
    pending nodes, at most one root-to-leaf path at a time.
    """

    def __init__(self, root: Optional[TreeNode], leftToRight: bool):
        self.stack = []
        self.leftToRight = leftToRight
        self.pushUntilNone(root)

    def next(self) -> int:
        """Pop and return the next value in iteration order.

        Assumes the iterator is not exhausted; popping an empty stack raises
        ``IndexError``.
        """
        node = self.stack.pop()
        # The successor lives in the subtree on the far side of ``node``.
        far_child = node.right if self.leftToRight else node.left
        self.pushUntilNone(far_child)
        return node.val

    def pushUntilNone(self, root: Optional[TreeNode]) -> None:
        """Push ``root`` and its chain of near-side children onto the stack."""
        current = root
        while current:
            self.stack.append(current)
            if self.leftToRight:
                current = current.left
            else:
                current = current.right

class Solution:
    def findTarget(self, root: Optional[TreeNode], k: int) -> bool:
        """Return True if two nodes with different values sum to ``k``.

        Two-pointer walk over the sorted values: one iterator ascends from
        the minimum, the other descends from the maximum, and they move
        toward each other until a match is found or they meet.
        """
        if not root:
            return False

        lo_iter = BSTIterator(root, True)
        hi_iter = BSTIterator(root, False)
        lo, hi = lo_iter.next(), hi_iter.next()

        # While lo < hi each iterator still has an unvisited value between
        # them, so next() is never called on an exhausted iterator.
        while lo < hi:
            total = lo + hi
            if total == k:
                return True
            if total < k:
                lo = lo_iter.next()
            else:
                hi = hi_iter.next()

        return False