Type like pro

Find kth smallest element in a binary search tree

Problem


Given a binary search tree, find its kth smallest element (k counted from 1).

Solution


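We augment the binary search tree so that every node keeps an extra count, leftWeight, equal to one plus the number of nodes in its left subtree. In other words, leftWeight is the node's rank within the subtree rooted at it: a node with an empty left subtree has leftWeight 1. A binary search tree augmented with subtree counts like this is called an order statistic tree. The count is maintained during insertion: whenever the new value goes into a node's left subtree, that node's leftWeight is incremented.

With this augmentation the kth smallest element is found by a single walk down from the root, in O(h) time, where h is the height of the tree. That is O(log n) for a balanced tree (and in expectation when keys arrive in random order), but O(n) in the worst case, since this simple tree does no rebalancing.

Suppose we want the 6th smallest element. We start at the root. If the root has leftWeight 4, then exactly 3 elements are smaller than the root, so the root is the 4th smallest and the 6th smallest cannot be in its left subtree. We therefore search the right subtree, now for the 6 - 4 = 2nd smallest element, because the 3 elements in the left subtree and the root itself all come before every element of the right subtree in sorted order. So we call the recursive function on root.right with k = 2. If k is less than leftWeight, the answer lies in the left subtree and we recurse on root.left with the same k. If k equals leftWeight, the current node is the answer. If the walk runs into an empty subtree, k was out of range.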
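The walk described above can also be written as a loop instead of recursion. The sketch below is self-contained and hypothetical: the class KthSmallestWalk and its helpers insert() and kthSmallest() are illustrative stand-ins, not the article's BST and Node classes, and it assumes the same convention that leftWeight counts the node itself plus its left subtree. Inserting 4, 2, 6, 1, 3, 5, 7 gives a root of 4 with leftWeight 4, so asking for the 6th smallest reproduces the example: the search moves right with k = 6 - 4 = 2 and stops at 6.

```java
// Hypothetical iterative sketch of the order statistic walk.
// leftWeight = 1 + number of nodes in the left subtree (the node's rank in its subtree).
public class KthSmallestWalk
{
    static class Node
    {
        int value;
        int leftWeight = 1; // a new node has an empty left subtree
        Node left, right;

        Node(int value)
        {
            this.value = value;
        }
    }

    // Plain BST insert that bumps leftWeight on every node whose left side it enters.
    // Equal values go to the right.
    static Node insert(Node root, int value)
    {
        Node fresh = new Node(value);
        if (root == null)
            return fresh;
        Node cur = root;
        while (true)
        {
            if (value < cur.value)
            {
                cur.leftWeight++;
                if (cur.left == null) { cur.left = fresh; break; }
                cur = cur.left;
            }
            else
            {
                if (cur.right == null) { cur.right = fresh; break; }
                cur = cur.right;
            }
        }
        return root;
    }

    // Returns the kth smallest value (1-based), or null if k is out of range.
    static Integer kthSmallest(Node root, int k)
    {
        Node cur = root;
        while (cur != null)
        {
            if (k < cur.leftWeight)
                cur = cur.left;           // answer lies in the left subtree
            else if (k > cur.leftWeight)
            {
                k -= cur.leftWeight;      // skip cur and everything to its left
                cur = cur.right;
            }
            else
                return cur.value;         // cur is exactly the kth smallest
        }
        return null;
    }

    public static void main(String[] args)
    {
        Node root = null;
        for (int v : new int[] { 4, 2, 6, 1, 3, 5, 7 })
            root = insert(root, v);
        System.out.println(root.leftWeight);       // 4: three values are smaller than the root
        System.out.println(kthSmallest(root, 6));  // 6: found as the 2nd smallest of the right subtree
    }
}
```

Looping instead of recursing keeps the call stack flat, which only matters for very tall, unbalanced trees; the number of nodes visited is the same O(h).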


Code

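// Finds the kth smallest element in a BST whose nodes store leftWeight (an order statistic tree).
public class KthSmallestOnline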
{
 public static void main(String[] args)
 {
  BST bst = new BST();
  int[] arr =
  { 12, 4, 5, 6, 2, 7, 8, 11, 2, 3 };
  for (int num : arr)
   bst.add(num);
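  // Sorted so far: 2 2 3 4 5 6 7 8 11 12, so this prints 4.
  System.out.println(bst.getOrdered(4));
  arr = new int[]
  { 12, 1, 9, 14, 25 };
  for (int num : arr)
   bst.add(num);
  // Sorted now: 1 2 2 3 4 5 6 7 8 9 11 12 12 14 25, so this prints 5.
  System.out.println(bst.getOrdered(6));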
 }

 private static class BST
 {
  Node root;

  // Inserts num, keeping leftWeight up to date; equal values go to the right.
  public void add(int num)
  {
   if (root == null)
   {
    root = new Node(num);
   } else
    add(root, num);
  }

  private void add(Node node, int num)
  {
   if (num < node.value)
   {
    // num goes into node's left subtree, so node's rank grows by one.
    node.leftWeight++;
    if (node.left == null)
     node.left = new Node(num);
    else
     add(node.left, num);
   } else
   {
    if (node.right == null)
     node.right = new Node(num);
    else
     add(node.right, num);
   }
  }

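  // Returns the kth smallest value (1-based), or null if k is out of range.
  // Returning Integer avoids a NullPointerException from unboxing null.
  public Integer getOrdered(int k)
  {
   return getOrdered(root, k);
  }

  private Integer getOrdered(Node node, int k)
  {
   if (node == null)
    return null;
   if (k < node.leftWeight)
   {
    // Fewer than leftWeight elements precede the target: it is in the left subtree.
    return getOrdered(node.left, k);
   } else if (k > node.leftWeight)
   {
    // Skip node and its left subtree (leftWeight elements in total).
    return getOrdered(node.right, k - node.leftWeight);
   } else
   {
    // node is exactly the kth smallest element of this subtree.
    return node.value;
   }
  }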
 }

 private static class Node
 {
  int value;
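  // 1 + number of nodes in the left subtree: this node's rank within its own subtree.
  int leftWeight;
  Node left;
  Node right;

  public Node(int value)
  {
   this.value = value;
   this.leftWeight = 1; // a new node has an empty left subtree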
  }
 }

}
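To cross-check getOrdered, it helps to have a reference that needs no augmentation: an in-order traversal visits the nodes in sorted order, so the kth node it visits is the kth smallest. That costs O(h + k) time instead of O(h), which is exactly the work leftWeight saves. The sketch below is a hypothetical, self-contained helper (KthSmallestInorder is an illustrative name, not part of the article's code) that walks the tree with an explicit stack and checks the result against a sorted copy of random inputs; the same loop could compare getOrdered(k) against sorted[k - 1].

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.Random;

// Hypothetical baseline without augmentation: in-order traversal that stops at the kth node.
public class KthSmallestInorder
{
    static class Node
    {
        int value;
        Node left, right;

        Node(int value)
        {
            this.value = value;
        }
    }

    // Plain BST insert; equal values go right, like the article's add().
    static Node insert(Node root, int value)
    {
        if (root == null)
            return new Node(value);
        if (value < root.value)
            root.left = insert(root.left, value);
        else
            root.right = insert(root.right, value);
        return root;
    }

    // Iterative in-order walk; returns the kth visited value (1-based) or null.
    static Integer kthSmallest(Node root, int k)
    {
        Deque<Node> stack = new ArrayDeque<>();
        Node cur = root;
        while (cur != null || !stack.isEmpty())
        {
            while (cur != null)
            {
                stack.push(cur);   // descend as far left as possible
                cur = cur.left;
            }
            cur = stack.pop();     // next node in sorted order
            if (--k == 0)
                return cur.value;
            cur = cur.right;
        }
        return null;
    }

    public static void main(String[] args)
    {
        Random rnd = new Random(42);
        for (int trial = 0; trial < 100; trial++)
        {
            int[] values = rnd.ints(20, 0, 50).toArray(); // duplicates are likely
            Node root = null;
            for (int v : values)
                root = insert(root, v);
            int[] sorted = values.clone();
            Arrays.sort(sorted);
            for (int k = 1; k <= sorted.length; k++)
                if (kthSmallest(root, k) != sorted[k - 1])
                    throw new AssertionError("mismatch at k=" + k);
        }
        System.out.println("in-order baseline agrees with sorting");
    }
}
```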