import java.util.LinkedList;
import java.util.Objects;
import java.util.Queue;
// Space: ~56N (presumably bytes) for N keys. Time for insert/search: O(N) worst case, O(log N) average (e.g. random insertion order).
/**
 * Ordered symbol table backed by an unbalanced binary search tree.
 *
 * <p>Keys must be non-null and are ordered by {@link Comparable#compareTo}. Each key maps to
 * exactly one value; putting an existing key overwrites its value. Every node caches the size
 * of its subtree, so {@link #size()} is O(1) and {@link #rank}/{@link #select} run in time
 * proportional to the height of the tree.
 *
 * @param <Key> key type
 * @param <Value> value type
 */
public class BST<Key extends Comparable<Key>, Value> {
    private Node root;

    /** One key-value pair plus links to its subtrees and the size of the subtree it roots. */
    private class Node {
        private final Key key;
        private Value val;
        private Node left;
        private Node right;
        private int count; // number of nodes in the subtree rooted here (including this one)

        public Node(Key key, Value val, int count) {
            this.key = key;
            this.val = val;
            this.count = count;
        }
    }

    /** Returns the number of key-value pairs in the tree. */
    public int size() {
        return size(root);
    }

    /** Returns {@code true} if the tree holds no keys. */
    public boolean isEmpty() {
        return root == null;
    }

    // Size of the subtree rooted at x; 0 for an empty subtree.
    private int size(Node x) {
        return x == null ? 0 : x.count;
    }

    /**
     * Returns the value associated with {@code key}, or {@code null} if the key is absent.
     *
     * @throws NullPointerException if {@code key} is null
     */
    public Value get(Key key) {
        Objects.requireNonNull(key, "key");
        return get(root, key);
    }

    private Value get(Node x, Key key) {
        if (x == null) {
            return null;
        }
        int cmp = key.compareTo(x.key);
        if (cmp < 0) {
            return get(x.left, key);
        } else if (cmp > 0) {
            return get(x.right, key);
        } else {
            return x.val;
        }
    }

    /**
     * Inserts the pair, replacing the existing value if {@code key} is already present.
     *
     * @throws NullPointerException if {@code key} is null; a null key would otherwise be stored
     *     in an empty tree and break every later comparison
     */
    public void put(Key key, Value val) {
        Objects.requireNonNull(key, "key");
        root = put(root, key, val);
    }

    // Inserts into the subtree rooted at x and returns its (possibly new) root.
    private Node put(Node x, Key key, Value val) {
        if (x == null) {
            return new Node(key, val, 1);
        }
        int cmp = key.compareTo(x.key);
        if (cmp < 0) {
            x.left = put(x.left, key, val);
        } else if (cmp > 0) {
            x.right = put(x.right, key, val);
        } else {
            x.val = val;
        }
        x.count = size(x.left) + size(x.right) + 1;
        return x;
    }

    /** Returns the smallest key, or {@code null} if the tree is empty. */
    public Key getMin() {
        Node x = getMin(root);
        return x == null ? null : x.key;
    }

    // Leftmost node of the subtree rooted at x, or null when x is null.
    private Node getMin(Node x) {
        if (x == null) {
            return null;
        }
        while (x.left != null) {
            x = x.left;
        }
        return x;
    }

    /** Returns the largest key, or {@code null} if the tree is empty. */
    public Key getMax() {
        Node x = getMax(root);
        return x == null ? null : x.key;
    }

    // Rightmost node of the subtree rooted at x, or null when x is null.
    private Node getMax(Node x) {
        if (x == null) {
            return null;
        }
        while (x.right != null) {
            x = x.right;
        }
        return x;
    }

    /**
     * Returns the largest key less than or equal to {@code key}, or {@code null} if none exists.
     *
     * @throws NullPointerException if {@code key} is null
     */
    public Key floor(Key key) {
        Objects.requireNonNull(key, "key");
        Node x = floor(root, key);
        return x == null ? null : x.key;
    }

    private Node floor(Node x, Key key) {
        if (x == null) {
            return null;
        }
        int cmp = key.compareTo(x.key);
        if (cmp == 0) {
            return x;
        }
        if (cmp < 0) {
            return floor(x.left, key);
        }
        // x.key < key: the floor is x unless something in the right subtree is closer.
        Node t = floor(x.right, key);
        return t != null ? t : x;
    }

    /**
     * Returns the number of keys strictly smaller than {@code key}.
     *
     * @throws NullPointerException if {@code key} is null
     */
    public int rank(Key key) {
        Objects.requireNonNull(key, "key");
        return rank(root, key);
    }

    private int rank(Node x, Key key) {
        if (x == null) {
            return 0;
        }
        int cmp = key.compareTo(x.key);
        if (cmp < 0) {
            return rank(x.left, key);
        } else if (cmp > 0) {
            return size(x.left) + 1 + rank(x.right, key);
        } else {
            return size(x.left);
        }
    }

    /** Removes the smallest key and its value; does nothing if the tree is empty. */
    public void deleteMin() {
        if (isEmpty()) {
            return;
        }
        root = deleteMin(root);
    }

    // Removes the leftmost node of the non-null subtree x and returns the subtree's new root.
    private Node deleteMin(Node x) {
        if (x.left == null) {
            return x.right;
        }
        x.left = deleteMin(x.left);
        x.count = size(x.left) + 1 + size(x.right);
        return x;
    }

    /**
     * Removes {@code key} and its value; does nothing if the key is absent.
     *
     * @throws NullPointerException if {@code key} is null
     */
    public void delete(Key key) {
        Objects.requireNonNull(key, "key");
        root = delete(root, key);
    }

    private Node delete(Node x, Key key) {
        if (x == null) {
            return null;
        }
        int cmp = key.compareTo(x.key);
        if (cmp < 0) {
            x.left = delete(x.left, key);
        } else if (cmp > 0) {
            x.right = delete(x.right, key);
        } else {
            if (x.right == null) {
                return x.left;
            }
            if (x.left == null) {
                return x.right;
            }
            // Two children: replace x with its successor, the minimum of its right subtree.
            Node t = x;
            x = getMin(t.right);
            x.right = deleteMin(t.right);
            x.left = t.left;
        }
        x.count = size(x.left) + size(x.right) + 1;
        return x;
    }

    /**
     * Returns the key of rank {@code k} (the k-th smallest key, 0-based), or {@code null} if
     * {@code k} is outside {@code [0, size())}.
     */
    public Key select(int k) {
        Node x = select(root, k);
        return x == null ? null : x.key;
    }

    private Node select(Node x, int k) {
        if (x == null) {
            return null;
        }
        int t = size(x.left);
        if (t > k) {
            return select(x.left, k);
        } else if (t < k) {
            return select(x.right, k - t - 1);
        } else {
            return x;
        }
    }

    /**
     * Prints the VALUES (not the keys) in ascending key order, each followed by a space, with a
     * line break before and after.
     */
    public void print() {
        System.out.println();
        print(root);
        System.out.println();
    }

    private void print(Node x) {
        if (x == null) {
            return;
        }
        print(x.left);
        System.out.print(x.val + " ");
        print(x.right);
    }

    /** Returns all keys in ascending order; the result is empty when the tree is empty. */
    public Iterable<Key> keys() {
        if (isEmpty()) {
            return new LinkedList<>();
        }
        return keys(getMin(), getMax());
    }

    /**
     * Returns the keys in the inclusive range {@code [lo, hi]} in ascending order.
     *
     * @throws NullPointerException if {@code lo} or {@code hi} is null
     */
    public Iterable<Key> keys(Key lo, Key hi) {
        Objects.requireNonNull(lo, "lo");
        Objects.requireNonNull(hi, "hi");
        Queue<Key> queue = new LinkedList<>();
        keys(root, queue, lo, hi);
        return queue;
    }

    // In-order walk that only descends into subtrees that can contain keys in [lo, hi].
    private void keys(Node x, Queue<Key> queue, Key lo, Key hi) {
        if (x == null) {
            return;
        }
        int cmplo = lo.compareTo(x.key);
        int cmphi = hi.compareTo(x.key);
        if (cmplo < 0) {
            keys(x.left, queue, lo, hi);
        }
        if (cmplo <= 0 && cmphi >= 0) {
            queue.add(x.key);
        }
        if (cmphi > 0) {
            keys(x.right, queue, lo, hi);
        }
    }

    public static void main(String... args) {
        BST<Integer, String> bst = new BST<>();
        bst.put(5, "five");
        bst.put(6, "six");
        bst.put(11, "eleven");
        bst.put(9, "nine");
        bst.put(4, "four");
        bst.put(7, "seven");
        System.out.println(bst.size());
        System.out.println(bst.get(6));
        System.out.println(bst.getMin());
        System.out.println(bst.getMax());
        System.out.println(bst.floor(9));
        System.out.println(bst.floor(10));
        System.out.println(bst.rank(9));
        System.out.println(bst.rank(10));
        bst.print();
        bst.deleteMin();
        bst.print();
        System.out.println(bst.size());
        System.out.println(bst.select(4));
        System.out.println(bst.select(1));
        bst.delete(7);
        bst.print();
        System.out.println(bst.size());
        for (Integer i : bst.keys()) {
            System.out.println(i);
        }
    }
}