From Wikipedia: a binary search tree is a rooted binary tree whose internal nodes each store a key (and optionally an associated value) and each have two distinguished subtrees, commonly denoted left and right. The tree additionally satisfies the binary search tree property, which states that the key in each node must be greater than all keys stored in the left subtree and smaller than all keys in the right subtree. (The leaves (final nodes) of the tree contain no key and have no structure to distinguish them from one another. Leaves are commonly represented by a special leaf or nil symbol, a NULL pointer, etc.)
package test; public class TreeNode { int key; int height; int size; TreeNode left; TreeNode right; TreeNode parent; Object data = null; public TreeNode(final int key) { this.key = key; this.size = 1; this.height = 1; this.left = null; this.right = null; } public TreeNode(final int key, Object val) { this.key = key; this.size = 1; this.height = 1; this.left = null; this.right = null; this.data = val; } //insert an element - iterative public TreeNode insert(TreeNode root, int key, int val){ TreeNode node = new TreeNode(key, val); TreeNode parent = null; if(root == null){ return node; } while(root != null){ parent = root; if(key < root.key){ root = root.left; } else{ root = root.right; } parent.size++; } if(key < parent.key){ parent.left = node; node.parent = parent; } else{ parent.right = node; node.parent = parent; } return node; } //insert an element - recursive public static TreeNode insert2(TreeNode root, int key, int val){ if(root == null){ return new TreeNode(key, val); } if(root.key > key){ root.left = insert2(root.left, key, val); } else if(root.key < key){ root.right = insert2(root.right, key, val); } else{ root.data = val; } root.size = 1+size(root.left)+size(root.right); return root; } //max key element of the tree public static TreeNode max(TreeNode root){ if(root == null){ return null; } while(root.right != null){ root = root.right; } return root; } //min key element of the tree public static TreeNode min(TreeNode root){ if(root == null){ return null; } while(root.left != null){ root = root.left; } return root; } //inorder successor using parent node public static TreeNode successor(TreeNode node){ if(node.right != null){ return min(node.right); } else{ TreeNode parent = node.parent; while(parent != null){ //without using key comparison -- only using left, right pointer //if(node == parent.right){ // node = parent; //} //else break; if(parent.key > node.key){ break; } parent = parent.parent; } return parent; } } //inorder successor without using parent 
node public static TreeNode successor2(TreeNode root, TreeNode node){ if(node.right != null){ return min(node.right); } else{ TreeNode successor = null; while(root != null){ if(root.key > node.key){ successor = root; root = root.left; } else if(root.key < node.key){ root = root.right; } else{ break; } } return successor; } } //inorder predecessor using parent node public static TreeNode predecessor(TreeNode node){ if(node.left != null){ return max(node.left); } else{ TreeNode parent = node.parent; while(parent != null){ //without using key comparison -- only using left, right pointer //if(node == parent.left){ // node = parent; //} //else break; if(parent.key < node.key){ break; } parent = parent.parent; } return parent; } } //inorder predecessor without using parent node public static TreeNode predecessor2(TreeNode root, TreeNode node){ if(node.left != null){ return max(node.left); } else{ TreeNode pred = null; while(root != null){ if(root.key > node.key){ root = root.left; } else if(root.key < node.key){ pred = root; root = root.right; } else{ break; } } return pred; } } //delete without using parent public static TreeNode delete(TreeNode root, int key){ if(root == null){ return root; } if(key < root.key){ root.left = delete(root.left, key); } else if(key > root.key){ root.right = delete(root.right, key); } else{ if(root.left == null){ TreeNode temp = root.right; root = null; return temp; } else if(root.right == null){ TreeNode temp = root.left; root = null; return temp; } TreeNode successor = min(root.right); root.key = successor.key; root.right = delete(root.right, successor.key); } root.size = size(root.left)+size(root.right)+1; return root; } //delete using parent public static void delete2(TreeNode node){ if(node == null){ return; } if(node.left == null && node.right == null){ if(node == node.parent.left){ node.parent.left = null; } else{ node.parent.right = null; } node = null; } else if(node.left == null || node.right == null){ TreeNode parent = 
node.parent; node = node.left == null ? node.right : node.left; node.parent = parent; } else{ TreeNode successor = successor(node); node.key = successor.key; delete2(successor); } } private static int size(TreeNode node){ return node == null ? 0 : node.size; } //largest key less than equal to given key public static TreeNode floor(TreeNode root, int key){ if(root == null){ return root; } if(root.key > key){ return floor(root.left, key); } else if(root.key < key){ TreeNode floor = floor(root.right, key); if(floor == null){ return root; } else{ return floor; } } else{ return root; } } //smallest key greater than equal to given key public static TreeNode ceiling(TreeNode root, int key){ if(root == null){ return root; } if(root.key < key){ return ceiling(root.right, key); } else if(root.key > key){ TreeNode floor = ceiling(root.left, key); if(floor == null){ return root; } else{ return floor; } } else{ return root; } } //select kth smallest element in the BST public static TreeNode select(TreeNode root, int k){ if(root == null){ return root; } int n = size(root); if(n > k){ return select(root.left, k); } else if(n < k){ return select(root.right, k-n-1); } else{ return root; } } //rank of a given key : number of nodes in the subtree less than the key public static int rank(TreeNode root, int key){ if(root == null){ return 0; } if(root.key > key){ return rank(root.left, key); } else if(root.key < key){ return 1+size(root.left)+rank(root.right, key); } else{ return size(root.left); } } //check if a tree is a BST public static boolean isBST(TreeNode node){ return isBST(node, Integer.MAX_VALUE, Integer.MIN_VALUE); } private static boolean isBST(TreeNode node, int max, int min){ if(node == null){ return true; } if(node.key >= max || node.key <= min){ return false; } return isBST(node.left, node.key, min) && isBST(node.right, max, node.key); } //height of the subtree rooted at given node public static int height(TreeNode node){ if(node == null){ return -1; } return 
1+Math.max(height(node.left), height(node.right)); } //binary search -- iterative public static TreeNode search(TreeNode root, int key){ if(root == null){ return null; } while(root != null){ if(root.key == key){ return root; } else if(root.key > key){ root = root.left; } else{ root = root.right; } } return root; } //binary search -- recursive public static TreeNode search2(TreeNode root, int key){ if(root == null){ return root; } if(root.key > key){ return search(root.left, key); } else if (root.key > key){ return search(root.right, key); } else{ return root; } } public static void PrintTreeInorder(TreeNode root){ if(root == null){ return; } PrintTreeInorder(root.left); System.out.print(" "+root.key); PrintTreeInorder(root.right); } }
How do we find the kth smallest element without changing the original data structure (for example, without storing a subtree size in every node)?
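We can do an in-order traversal (recursive or iterative) and decrement a counter, initialized to k, each time we visit a node. When the counter reaches 0, the node being visited is the kth smallest element. The Morris traversal below does this with O(1) extra space. However, it temporarily rewires right pointers, so it must not stop before those pointers are restored; otherwise the tree is left modified.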
public int kthSmallest(TreeNode root, int k) { TreeNode kth = MorrisInorderTraversal(root, k); return kth != null ? kth.val : -1; } public static TreeNode MorrisInorderTraversal(TreeNode root, int k){ if(root == null){ return null; } TreeNode cur = root; TreeNode pre = null; while(cur != null){ //if no left subtree the visit right subtree right away after printing current node if(cur.left == null){ k--; if(k == 0){ return cur; } cur = cur.right; } else{ //otherwise we will traverse the left subtree and come back to current //node by using threaded pointer from predecessor of current node //first find the predecessor of cur pre = cur.left; while(pre.right != null && pre.right != cur){ pre = pre.right; } //threaded pointer not added - add it and go to left subtree to traverse if(pre.right == null){ pre.right = cur; cur = cur.left; } else{ //we traversed left subtree through threaded pointer and reached cur again //so revert the threaded pointer and print out current node before traversing right subtree pre.right = null; k--; if(k == 0){ return cur; } //now traverse right subtree cur = cur.right; } } } return null; }