Given two nodes in a binary search tree, find the distance between them. The distance is the minimum number of edges on the path from one node to the other.
For example, consider the following binary search tree:
        7
       / \
      5   10
     / \    \
    3   6    13
The distance between nodes 3 and 6 is 2, and the distance between nodes 3 and 10 is 3.
Walking from the root (7) to node 3 takes 2 edges, and walking from the root to node 6 also takes 2 edges. Their lowest common ancestor, node 5, is 1 edge from the root. So the formula is dist(3, 6) = dist(7, 3) + dist(7, 6) - 2*dist(7, 5) = 2 + 2 - 2*1 = 2.
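The key is to find the lowest common ancestor (LCA) of the two nodes (node 5 for nodes 3 and 6). In a general binary tree, the LCA can be found with a post-order traversal. In a binary search tree, it can also be found by walking down from the root and comparing each node's value with the two targets.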
/**
 * Computes the distance (number of edges on the connecting path) between two values
 * stored in a binary search tree.
 */
class NodeDistance {

    /**
     * A binary search tree node. Package-private (rather than private) so that callers
     * and tests in the same package can build trees to pass to {@link #findDistance}.
     */
    static final class Node {
        int value;
        Node left;
        Node right;

        public Node(int value, Node left, Node right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    public static void main(String... args) {
        Node root = new Node(7, null, null);
        root.left = new Node(5, new Node(3, null, null), new Node(6, null, null));
        root.right = new Node(10, null, new Node(13, null, null));
        NodeDistance nd = new NodeDistance();
        System.out.println(nd.findDistance(root, 3, 6));
    }

    /**
     * Returns the number of edges on the path between the nodes holding {@code w} and
     * {@code v}.
     *
     * <p>The tree must obey BST ordering (smaller values to the left, larger to the right).
     * The result is {@code dist(lca, w) + dist(lca, v)}, where {@code lca} is the lowest
     * common ancestor of the two nodes. Each helper only walks down one root-to-leaf path,
     * so this runs in O(h) time with O(1) extra space, where h is the tree height.
     *
     * @param root root of the BST; may be {@code null}
     * @param w value of the first node
     * @param v value of the second node
     * @return the distance ({@code 0} when {@code w == v} and that value is present), or
     *     {@code -1} if either value is not in the tree
     */
    public int findDistance(Node root, int w, int v) {
        Node common = findCommonParent(root, w, v);
        if (common == null) {
            return -1;
        }
        int toW = dist(common, w);
        int toV = dist(common, v);
        // A split node can exist even when one of the values is absent below it.
        if (toW < 0 || toV < 0) {
            return -1;
        }
        return toW + toV;
    }

    // Example tree used in main():
    //        7
    //       / \
    //      5   10
    //     / \    \
    //    3   6    13

    /**
     * Returns how many edges lie between {@code root} and the node holding {@code target},
     * following BST ordering to choose the direction at each step.
     *
     * @return the number of edges, or {@code -1} if {@code target} is not in the subtree
     */
    private int dist(Node root, int target) {
        int edges = 0;
        Node cur = root;
        while (cur != null) {
            if (cur.value == target) {
                return edges;
            }
            cur = cur.value < target ? cur.right : cur.left;
            edges++;
        }
        return -1;
    }

    /**
     * Returns the BST split node for {@code w} and {@code v}: the first node on the way
     * down whose value lies between them (inclusive). When both values are present this
     * is their lowest common ancestor.
     *
     * @return the split node, or {@code null} if the walk falls off the tree (including
     *     when {@code root} is {@code null})
     */
    private Node findCommonParent(Node root, int w, int v) {
        int lo = Math.min(w, v);
        int hi = Math.max(w, v);
        Node cur = root;
        while (cur != null) {
            if (hi < cur.value) {
                cur = cur.left;       // both targets are in the left subtree
            } else if (lo > cur.value) {
                cur = cur.right;      // both targets are in the right subtree
            } else {
                return cur;           // targets split here (or one equals cur)
            }
        }
        return null;
    }
}
private static class Node {
int value;
Node left;
Node right;
public Node(int value, Node left, Node right) {
this.value = value;
this.left = left;
this.right = right;
public static void main(String...args) {
Node root = new Node(7, null, null);
root.left = new Node(5, new Node(3, null, null), new Node(6, null, null));
root.right = new Node(10, null, new Node(13, null, null));
NodeDistance nd = new NodeDistance();
System.out.println(nd.findDistance(root, 3, 6));
public int findDistance(Node root, int w, int v) {//13, 6
Node common = findCommonParent(root, w, v);
if(common == null) return 0;
return dist(common, w) + dist(common, v);
/ \
5 10
/ \ \
3 6 13
private int dist(Node root, int target) {//5, 3
if(root.value == target) return 0;
else if(root.value < target) {
return 1 + dist(root.right, target);
} else {
return 1 + dist(root.left, target);
private Node findCommonParent(Node root, int w, int v) {//3, 6
if(root == null || root.value == w || root.value == v) {
return root;
Node lchild = findCommonParent(root.left, w, v);
Node rchild = findCommonParent(root.right, w, v);
if(lchild == null) {
return rchild;
} else if(rchild == null) {
return lchild;
} else {
return root;
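Complexity: a post-order LCA search visits every node, so it takes O(N) time and O(h) recursion-stack space, where h is the tree height. Walking down the BST by comparing values instead takes O(h) time and O(1) extra space. Note that h is O(log N) only when the tree is balanced; for a degenerate, list-shaped tree it is O(N).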