Site Search:

BST.java


import java.util.LinkedList;
import java.util.Queue;

//space: 56N. time: insert(search) worst. O(N) avg. O(logN)
public class BST<Key extends Comparable<Key>, Value> {
    private Node root;
    private class Node {
        private Key key;
        private Value val;
        private Node left;
        private Node right;
        private int N;
        public Node(Key key, Value val, int N) {
            this.key = key;
            this.val = val;
            this.N = N;
        }
    }
    
    public int size() {
        return size(root);
    }
    
    private int size(Node x) {
        if(x == null) return 0;
        else return x.N;
    }
    
    public Value get(Key key) {
        return get(root, key);
    }
    
    private Value get(Node x, Key key) {
        if(x == null) return null;
        int cmp = key.compareTo(x.key);
        if(cmp < 0) {
            return get(x.left, key);
        } else if (cmp > 0) {
            return get(x.right, key);
        } else {
            return x.val;
        }
    }
    
    public void put(Key key, Value val) {
        root = put(root, key, val);
    }
    
    private Node put(Node x, Key key, Value val) {
        if(x == null) return new Node(key, val, 1);
        int cmp = key.compareTo(x.key);
        if(cmp < 0) {
            x.left = put(x.left, key, val);
        } else if(cmp > 0) {
            x.right = put(x.right, key, val);
        } else {
            x.val = val;
        }
        x.N = size(x.left) + size(x.right) + 1;
        return x;
    }
    
    public Key getMin() {
        Node x = getMin(root);
        if(x == null) return null;
        else return x.key;
    }
    
    private Node getMin(Node x) {
        if(x.left == null) return x;
        else return getMin(x.left);
    }
    
    public Key getMax() {
        Node x = getMax(root);
        if(x == null) return null;
        else return x.key;
    }
    
    private Node getMax(Node x) {
        if(x.right == null) return x;
        else return getMax(x.right);
    }
    
    public Key floor(Key key) { //Key smaller than or equal to the key
        Node x = floor(root, key);
        if(x == null) return null;
        else return x.key;
    }
    
    private Node floor(Node x, Key key) {
        if(x == null) return null;
        if(key.compareTo(x.key) < 0) return floor(x.left, key);
        else if(key.compareTo(x.key) == 0) return x;
        else {
            Node t = floor(x.right, key);
            if(t == null) return x;
            else return t;
        }
    }
    
    public int rank(Key key) { //number of nodes smaller than the key
        return rank(root, key);
    }
    
    private int rank(Node x, Key key) { 
        if(x == null) return 0;
        if(key.compareTo(x.key) == 0) return size(x.left);
        else if(key.compareTo(x.key) < 0) return rank(x.left, key);
        else return size(x.left) + 1 + rank(x.right, key);
    }
    
    public void deleteMin() {
        root = deleteMin(root);
    }
    
    private Node deleteMin(Node x) {
        if(x.left == null) return x.right;
        x.left = deleteMin(x.left);
        x.N = size(x.left) + 1 + size(x.right);
        return x;
    }
    
    public void delete(Key key) {
        root = delete(root, key);
    }
    
    private Node delete(Node x, Key key) {
        if(x == null) return null;
        int cmp = key.compareTo(x.key);
        if(cmp < 0) x.left = delete(x.left, key);
        else if(cmp > 0) x.right = delete(x.right, key);
        else {
            if(x.right == null) return x.left;
            if(x.left == null) return x.right;
            Node t = x;
            x = getMin(t.right);
            x.right = deleteMin(t.right);
            x.left = t.left;
        }
        x.N = size(x.left) + size(x.right) + 1;
        return x;
    }
    
    public Key select(int k) { //return key of rank k
        Node x = select(root, k);
        if(x == null) return null;
        else return x.key;
    }
    
    private Node select(Node x, int k) { 
        if(x == null) return null;
        int t = size(x.left);
        if(t > k) return select(x.left, k);
        else if(t < k) return select(x.right, k - t - 1);
        else return x;
    }
    
    public void print() {
        System.out.println();
        print(root);
        System.out.println();
    }
    
    private void print(Node x) {
        if(x == null) return;
        print(x.left);
        System.out.print(x.val + " ");
        print(x.right);
    }
    
    public Iterable<Key> keys() {
        return keys(getMin(), getMax());
    }
    
    private Iterable<Key> keys(Key lo, Key hi) {
        Queue<Key> queue = new LinkedList<>();
        keys(root, queue, lo, hi);
        return queue;
    }
    
    private void keys(Node x, Queue<Key> queue, Key lo, Key hi) {
        if(x == null) return;
        int cmplo = lo.compareTo(x.key);
        int cmphi = hi.compareTo(x.key);
        if(cmplo < 0) keys(x.left, queue, lo, hi);
        if(cmplo <= 0 && cmphi >= 0) queue.add(x.key);
        if(cmphi > 0) keys(x.right, queue, lo, hi);
    }
    
    public static void main(String...args) {
        BST<Integer, String> bst = new BST<>();
        bst.put(5, "five");
        bst.put(6, "six");
        bst.put(11, "eleven");
        bst.put(9, "nine");
        bst.put(4, "four");
        bst.put(7, "seven");
        System.out.println(bst.size());
        System.out.println(bst.get(6));
        System.out.println(bst.getMin());
        System.out.println(bst.getMax());
        System.out.println(bst.floor(9));
        System.out.println(bst.floor(10));
        System.out.println(bst.rank(9));
        System.out.println(bst.rank(10));
        bst.print();
        bst.deleteMin();
        bst.print();
        System.out.println(bst.size());
        System.out.println(bst.select(4));
        System.out.println(bst.select(1));
        bst.delete(7);
        bst.print();
        System.out.println(bst.size());
        for(Integer i : bst.keys()) {
            System.out.println(i);
        }
    }
}