Problem
Given an array of integers, find all the triplets that sum to zero.
For example, given the array {2, 1, -1, 3, 0, -2, -3, -1}, one possible output is:
{-3,0,3},{-3,1,2},{-2,-1,3},{-2,0,2},{-1,-1,2},{-1,0,1}
Solution
Approach 1.
The brute-force approach is to enumerate all the triplets and test whether each one sums to zero. We need one loop for each of the three numbers, so the time complexity is O(N^3).
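As a minimal sketch of the brute-force idea, the three nested loops could look like this in Java. The class name BruteForceThreeSum and the method find are illustrative, and sorting a copy of the array plus collecting into a Set (so that repeated triplets are reported once) are choices made for this sketch rather than part of the description above.

```java
import java.util.*;

class BruteForceThreeSum {
    // Returns every distinct triplet (in ascending order) that sums to zero.
    static Set<List<Integer>> find(int[] input) {
        int[] arr = input.clone();   // leave the caller's array untouched
        Arrays.sort(arr);            // equal triplets now have the same element order
        Set<List<Integer>> triplets = new LinkedHashSet<>();
        int n = arr.length;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                for (int k = j + 1; k < n; k++) {
                    if (arr[i] + arr[j] + arr[k] == 0) {
                        triplets.add(Arrays.asList(arr[i], arr[j], arr[k]));
                    }
                }
            }
        }
        return triplets;
    }
}
```

Calling BruteForceThreeSum.find on the example array {2, 1, -1, 3, 0, -2, -3, -1} yields the six triplets listed in the problem statement.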
Approach 2.
We know a fast way to solve two sum. Given a + b = sum, we iterate through the array and store each element's value and index in a HashMap. When we meet a value a, we check whether (sum - a) is already in the HashMap; if it is, we pair a's index with (sum - a)'s index to form a two sum pair. Since HashMap put and get cost O(1) on average, two sum takes O(N) time, and the extra space is the HashMap, which is also O(N).
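The two sum lookup described above can be sketched in a single pass, as below. The names TwoSum, findPairs and indexOf are made up for this illustration. One simplifying assumption: the map keeps only the latest index for each value, so when a value repeats, each element is paired with at most one earlier partner.

```java
import java.util.*;

class TwoSum {
    // Returns index pairs {i, j} with i < j and arr[i] + arr[j] == sum.
    static List<int[]> findPairs(int[] arr, int sum) {
        Map<Integer, Integer> indexOf = new HashMap<>(); // value -> latest index seen so far
        List<int[]> pairs = new ArrayList<>();
        for (int j = 0; j < arr.length; j++) {
            Integer i = indexOf.get(sum - arr[j]);       // average O(1) lookup
            if (i != null) {
                pairs.add(new int[] {i, j});
            }
            indexOf.put(arr[j], j);                      // make arr[j] available to later elements
        }
        return pairs;
    }
}
```

For instance, findPairs(new int[] {2, 7, 11, 15}, 9) returns the single index pair {0, 1}.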
Similarly, we can solve three sum by reducing it to the two sum problem.
We iterate through the array and, for each element a, look for a two sum pair that sums to (-a). To find it, we run two sum over the rest of the elements and combine every pair found with a to form a triplet. One problem is that there could be duplicates. We can solve it by sorting each triplet and adding it to a Set, which removes the duplicates. The time cost is O(N^2): one factor of N for looping through the array, the other for the two sum calculation. The HashMap for two sum takes O(N) space; the Set holds the results, so it can grow beyond O(N) when the array contains many zero-sum triplets (up to O(N^2) in the worst case).
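A rough Java sketch of this reduction follows. The class name HashThreeSum is hypothetical, and the sketch uses a HashSet of values instead of a HashMap of indices for the inner two sum, since only the values are needed to build a triplet.

```java
import java.util.*;

class HashThreeSum {
    // Returns every distinct triplet (in ascending order) that sums to zero.
    static Set<List<Integer>> find(int[] arr) {
        Set<List<Integer>> triplets = new HashSet<>();
        for (int i = 0; i < arr.length; i++) {
            Set<Integer> seen = new HashSet<>();       // values met between index i and index j
            for (int j = i + 1; j < arr.length; j++) {
                int needed = -arr[i] - arr[j];         // third value that completes the triplet
                if (seen.contains(needed)) {
                    List<Integer> t = Arrays.asList(arr[i], arr[j], needed);
                    Collections.sort(t);               // canonical order so the Set drops duplicates
                    triplets.add(t);
                }
                seen.add(arr[j]);
            }
        }
        return triplets;
    }
}
```

The input does not need to be sorted here; each triplet is sorted individually before it goes into the Set, which is what makes {2, -1, -1} and {-1, 2, -1} collapse into one entry.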
Approach 3.
If the array is sorted, we can avoid the extra HashMap. We need 3 pointers: pointer i to the smallest element in the triplet and, initially, pointer lo = i + 1 and pointer hi = arr.length - 1.
if arr[i] + arr[lo] + arr[hi] < 0
    we increase lo
else if arr[i] + arr[lo] + arr[hi] > 0
    we decrease hi
else
    we found a three sum; record it, then increase lo and decrease hi
We continue to increase lo and decrease hi in this way until they cross.
In this way, we can find all the three sum triplets with a for loop and a nested while loop, which costs O(N^2). We also need to sort the array, which costs O(N log N), so the total time is bounded by O(N^2). Besides a few pointers, no extra space is needed, so the space cost is O(1).
-3, -2, -1, -1, 0, 1, 2, 3
 i  lo                  hi
import java.util.*;

class ThreeSum {
    private static final int TARGET = 0;

    public static List<String> getThreeSum(int[] arr) {
        List<String> sums = new ArrayList<>();
        Arrays.sort(arr); // dual-pivot quicksort for int[], O(N log N); sorts the caller's array in place
        // -3, -2, -1, -1, 0, 1, 2, 3
        //  i  lo                  hi
        // expected: {-3,0,3},{-3,1,2},{-2,-1,3},{-2,0,2},{-1,-1,2},{-1,0,1}
        int N = arr.length;
        for (int i = 0; i < N; i++) {
            if (i > 0 && arr[i] == arr[i - 1]) // same smallest element as the previous round: skip to prevent duplicates
                continue;
            if (arr[i] > TARGET) break; // valid because TARGET is 0: three positive numbers cannot sum to 0
            int lo = i + 1;
            int hi = N - 1;
            while (lo < hi) {
                int cmp = arr[i] + arr[lo] + arr[hi]; // e.g. 0 + 1 + 3 = 4
                if (cmp == TARGET) {
                    sums.add(String.format("[%d, %d, %d]", arr[i], arr[lo], arr[hi]));
                    while (lo < hi && arr[lo] == arr[lo + 1]) lo++; // skip duplicate values of arr[lo]
                    while (lo < hi && arr[hi] == arr[hi - 1]) hi--; // skip duplicate values of arr[hi]
                    lo++;
                    hi--;
                } else if (cmp < TARGET) {
                    lo++; // sum too small: move to a larger element
                } else {
                    hi--; // sum too large: move to a smaller element
                }
            }
        }
        return sums;
    }

    public static void main(String... args) {
        int[] arr = new int[] {2, 1, -1, 3, 0, -2, -3, -1};
        List<String> sums = getThreeSum(arr);
        sums.forEach(System.out::println);
    }
}
The time cost is O(N^2); the extra space cost is O(1), not counting the list that holds the results.