import java.util.LinkedList;
import java.util.Queue;
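// Cost model: a search miss examines about log N / log R (= log_R N) nodes on average for
// random keys; a search hit or insert follows one link per key character.
// Space: between (8R + 56)N and (8R + 56)Nw bytes, where N is the number of keys, R the
// alphabet size, and w the average key length.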
/**
 * An R-way trie symbol table mapping {@code String} keys to values.
 *
 * <p>Every key character must have a code in {@code [0, R)}, where {@code R = 256} (extended
 * ASCII). Each node holds one child link per alphabet character. A {@code null} value marks a node
 * that does not end a stored key, so {@code null} is never a stored value: {@code put(key, null)}
 * clears {@code key}'s value. No synchronization is performed.
 *
 * @param <Value> the value type
 */
public class Trie<Value> {
    /** Alphabet size: every key character must have a code below R. */
    private static final int R = 256;

    private Node root;
    /** Number of nodes holding a non-null value; maintained by put so size() is O(1). */
    private int n;

    /** A trie node: an optional value plus one child link per alphabet character. */
    private static class Node {
        // Stored as Object: a static nested class cannot refer to the type parameter Value.
        private Object val;
        private final Node[] next = new Node[R];
    }

    /**
     * Returns the value associated with {@code key}.
     *
     * @param key the key to look up; must not be null
     * @return the associated value, or {@code null} if {@code key} is absent (including when it
     *     contains a character outside the alphabet, since such a key can never be stored)
     */
    @SuppressWarnings("unchecked") // put only ever stores Value instances in Node.val
    public Value get(String key) {
        Node x = get(root, key, 0);
        return x == null ? null : (Value) x.val;
    }

    /** Returns the node reached from {@code x} by following {@code key} from index {@code d}, or null. */
    private Node get(Node x, String key, int d) {
        if (x == null) return null;
        if (d == key.length()) return x;
        char c = key.charAt(d);
        if (c >= R) return null; // no link can exist for an out-of-alphabet character
        return get(x.next[c], key, d + 1);
    }

    /**
     * Associates {@code val} with {@code key}, replacing any previous value. Passing {@code null}
     * clears the key's value; the nodes on its path are not pruned.
     *
     * @param key the key; must not be null
     * @param val the value, or {@code null} to clear the key
     * @throws IllegalArgumentException if {@code key} contains a character with code {@code >= R}
     */
    public void put(String key, Value val) {
        // Validate before mutating so a bad key leaves the trie untouched.
        checkAlphabet(key);
        root = put(root, key, val, 0);
    }

    /** Inserts along the path for {@code key} from index {@code d}, creating nodes as needed. */
    private Node put(Node x, String key, Value val, int d) {
        if (x == null) x = new Node();
        if (d == key.length()) {
            // Keep the key count equal to the number of non-null values.
            if (x.val == null && val != null) n++;
            else if (x.val != null && val == null) n--;
            x.val = val;
            return x;
        }
        char c = key.charAt(d);
        x.next[c] = put(x.next[c], key, val, d + 1);
        return x;
    }

    /** Rejects keys containing a character that cannot index a node's R-slot link array. */
    private static void checkAlphabet(String key) {
        for (int i = 0; i < key.length(); i++) {
            char c = key.charAt(i);
            if (c >= R) {
                throw new IllegalArgumentException("character U+" + String.format("%04X", (int) c)
                        + " at index " + i + " is outside the " + R + "-character alphabet");
            }
        }
    }

    /** Returns the number of keys with a non-null value, in constant time. */
    public int size() {
        return n;
    }

    /** Returns all stored keys in ascending lexicographic (char-code) order. */
    public Iterable<String> keys() {
        return keysWithPrefix("");
    }

    /**
     * Returns every stored key that starts with {@code pre}, in ascending lexicographic order.
     * The result is empty if no key has that prefix.
     *
     * @param pre the prefix; must not be null
     */
    public Iterable<String> keysWithPrefix(String pre) {
        Queue<String> queue = new LinkedList<>();
        collect(get(root, pre, 0), new StringBuilder(pre), queue);
        return queue;
    }

    /** Adds every key in the subtrie at {@code x} to {@code queue}; {@code prefix} spells the path to {@code x}. */
    private void collect(Node x, StringBuilder prefix, Queue<String> queue) {
        if (x == null) return;
        if (x.val != null) queue.add(prefix.toString());
        for (char c = 0; c < R; c++) {
            if (x.next[c] == null) continue; // skip empty links without building a string
            prefix.append(c);
            collect(x.next[c], prefix, queue);
            prefix.setLength(prefix.length() - 1);
        }
    }

    /**
     * Returns every stored key matching {@code pat}, in ascending lexicographic order. A
     * {@code '.'} in the pattern matches any single character; any other character matches only
     * itself. A key matches only if it has exactly {@code pat.length()} characters.
     *
     * @param pat the pattern; must not be null
     */
    public Iterable<String> keysThatMatch(String pat) {
        Queue<String> queue = new LinkedList<>();
        collectMatches(root, pat, new StringBuilder(), queue);
        return queue;
    }

    /** Adds keys below {@code x} that match {@code pat}; {@code prefix} spells the path to {@code x}. */
    private void collectMatches(Node x, String pat, StringBuilder prefix, Queue<String> queue) {
        if (x == null) return;
        int d = prefix.length();
        if (d == pat.length()) {
            if (x.val != null) queue.add(prefix.toString());
            return;
        }
        char p = pat.charAt(d);
        if (p == '.') {
            for (char c = 0; c < R; c++) {
                matchChild(x, c, pat, prefix, queue);
            }
        } else if (p < R) {
            // A literal character has exactly one candidate link; no need to scan all R.
            matchChild(x, p, pat, prefix, queue);
        }
        // A literal outside the alphabet cannot match any stored key.
    }

    /** Descends into child {@code c} of {@code x} (if present) while continuing the pattern match. */
    private void matchChild(Node x, char c, String pat, StringBuilder prefix, Queue<String> queue) {
        if (x.next[c] == null) return;
        prefix.append(c);
        collectMatches(x.next[c], pat, prefix, queue);
        prefix.setLength(prefix.length() - 1);
    }

    /**
     * Returns the longest stored key that is a prefix of {@code s}. Returns {@code ""} if no
     * non-empty key is a prefix of {@code s}.
     *
     * @param s the query string; must not be null
     */
    public String longestPrefixOf(String s) {
        int length = search(root, s, 0, 0);
        return s.substring(0, length);
    }

    /**
     * Walks the path spelled by {@code s} from depth {@code d}. Returns the length of the longest
     * stored key seen on that path; {@code length} is the best length found so far.
     */
    private int search(Node x, String s, int d, int length) {
        if (x == null) return length;
        if (x.val != null) length = d;
        if (d == s.length()) return length;
        char c = s.charAt(d);
        if (c >= R) return length; // the path cannot continue past an out-of-alphabet character
        return search(x.next[c], s, d + 1, length);
    }

    public static void main(String... args) {
        Trie<Integer> trie = new Trie<>();
        trie.put("apple", 9);
        trie.put("pear", 2);
        trie.put("cherry", 5);
        trie.put("churry", 1);
        trie.put("orchard", 20);
        trie.put("orange", 4);
        trie.put("avocado", 0);
        System.out.println(trie.get("cherry"));
        System.out.println(trie.get("cucumber"));
        System.out.println("=========size()========");
        System.out.println(trie.size());
        System.out.println("======keys()=======");
        for (String s : trie.keys()) {
            System.out.println(s);
        }
        System.out.println("======keysWithPrefix(\"or\")====");
        for (String s : trie.keysWithPrefix("or")) {
            System.out.println(s);
        }
        System.out.println("======keysThatMatch=======");
        for (String s : trie.keysThatMatch("ch.rry")) {
            System.out.println(s);
        }
        for (String s : trie.keysThatMatch("orange")) {
            System.out.println(s);
        }
        System.out.println("======longestPrefixOf====");
        System.out.println(trie.longestPrefixOf("organic"));
        System.out.println(trie.longestPrefixOf("avocadojam"));
    }
}