# A Generic Solution to K-Sum Problems

K-sum problems are a class of algorithm problems I encountered while preparing for job interviews.

**Problem statement:**

Given an array of integers, return all unique combinations of k numbers such that they add up to a specific target. (Assuming k < number of elements in the given integer array)

The signature of the function will be something like this (in `Java`, for example):

`List<List<Integer>> kSum(int[] nums, int k, int target)`

The problem is actually not as difficult as it seems. My approach was to use recursion to reduce the K-sum problem to (K-1) sum, and all the way to 2 sum.

First of all, a sort is needed.

`Arrays.sort(nums);`

## Start from 1 and 2

Let’s start from base cases. If k = 1, it becomes a simple search.

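```
if (k == 1) {
    List<List<Integer>> res = new ArrayList<>();
    // nums is sorted, so a binary search tells us whether target is present
    if (Arrays.binarySearch(nums, target) >= 0) {
        res.add(Collections.singletonList(target));
    }
    // return the (possibly empty) result whether or not target was found
    return res;
}
```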

If k = 2, take the two-pointer approach. Start from the first and the last elements of the sorted array and move the pointers towards each other. At each step, compare the sum at the two pointers, `nums[i] + nums[j]` (assuming i < j), with the `target` value:

- if `nums[i] + nums[j] == target`, great: add these elements to our results and move both pointers (`++i`, `--j`);
- if `nums[i] + nums[j] > target`, we want a smaller sum, so we move the right pointer leftward (`--j`);
- similarly, if `nums[i] + nums[j] < target`, we move the left pointer rightward (`++i`).

Repeat this process until the two pointers meet. Pseudo code:

```
// 2 Sum
List<List<Integer>> res = new ArrayList<>();
int left = 0, right = nums.length - 1;
while (left < right) {
    int sum = nums[left] + nums[right];
    if (sum == target) {
        res.add(new ArrayList<>(Arrays.asList(nums[left], nums[right])));
        // de-dup: skip over repeated values on both sides
        while (left < right && nums[left] == nums[left + 1]) ++left;
        while (left < right && nums[right] == nums[right - 1]) --right;
        ++left;
        --right;
    } else if (sum < target) {
        ++left;
    } else {
        --right;
    }
}
return res;
```
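To see the de-dup step in action, here is a minimal runnable sketch of the same loop on an input with repeated values. The class name `TwoSumDemo` and the `beginIdx` parameter are assumptions made for this illustration (the parameter anticipates searching only a suffix of the array); the method expects an already sorted array.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class TwoSumDemo {
    // Returns all unique pairs in sorted[beginIdx..] that add up to target.
    // Assumes `sorted` is in ascending order (hypothetical helper for illustration).
    static List<List<Integer>> twoSum(int[] sorted, int target, int beginIdx) {
        List<List<Integer>> res = new ArrayList<>();
        int left = beginIdx, right = sorted.length - 1;
        while (left < right) {
            int sum = sorted[left] + sorted[right];
            if (sum == target) {
                res.add(new ArrayList<>(Arrays.asList(sorted[left], sorted[right])));
                // skip repeated values so the same pair is not reported twice
                while (left < right && sorted[left] == sorted[left + 1]) ++left;
                while (left < right && sorted[right] == sorted[right - 1]) --right;
                ++left;
                --right;
            } else if (sum < target) {
                ++left;
            } else {
                --right;
            }
        }
        return res;
    }

    public static void main(String[] args) {
        int[] nums = {3, 1, 2, 2, 4, 3, 5};
        Arrays.sort(nums); // [1, 2, 2, 3, 3, 4, 5]
        System.out.println(twoSum(nums, 6, 0)); // prints [[1, 5], [2, 4], [3, 3]]
    }
}
```

Although 2 and 3 each appear twice, `[2, 4]` and `[3, 3]` are reported only once, because the inner `while` loops step past the repeated values before both pointers move on.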

## Recursion

Now that the 2 Sum problem is solved, we can move on to the recursive cases.

If k = 3, we can iterate over each element and leverage the 2 Sum solution. For each number `num` (at index `i`) in the array, assume there is a combination {`num`, `m`, `n`} that adds up to `target`, and `num` is the smallest element in the combination. Therefore, for each `num`, we try to find solutions to a 2 Sum problem where the new target equals `target - num` and the source array is the sub-array (starting at element `i+1`) of the original one. Let’s define a helper function for the recursion, with an additional parameter `beginIdx` that specifies the left bound of the array to search:

`private List<List<Integer>> kSumHelper(int[] nums, int k, int target, int beginIdx)`

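Pseudo code for 3 Sum looks like:

```
List<List<Integer>> res = new ArrayList<>();
for (int i = 0; i <= nums.length - 3; i++) {
    // Skip repeated values of nums[i] so the same triple is not produced twice
    if (i > 0 && nums[i] == nums[i - 1]) continue;
    // Get the 2 Sum results to the right of element i
    List<List<Integer>> sub = kSumHelper(nums, 2, target - nums[i], i + 1);
    // Prepend nums[i] to each pair to form a triple
    for (List<Integer> list : sub) {
        list.add(0, nums[i]);
    }
    res.addAll(sub);
}
return res;
```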
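The reduction from 3 Sum to 2 Sum can be tried out with the following self-contained sketch. It is only an illustration under a few assumptions: the class name `ThreeSumDemo` is made up, the two-pointer step is inlined rather than delegated to `kSumHelper`, and the input is copied before sorting so the caller's array is left untouched.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class ThreeSumDemo {
    // Returns all unique triples from nums that add up to target.
    static List<List<Integer>> threeSum(int[] nums, int target) {
        int[] sorted = nums.clone();
        Arrays.sort(sorted);
        List<List<Integer>> res = new ArrayList<>();
        for (int i = 0; i <= sorted.length - 3; i++) {
            // skip repeated first elements to avoid duplicate triples
            if (i > 0 && sorted[i] == sorted[i - 1]) continue;
            // 2 Sum on the sub-array to the right of i, with the reduced target
            int rest = target - sorted[i];
            int left = i + 1, right = sorted.length - 1;
            while (left < right) {
                int sum = sorted[left] + sorted[right];
                if (sum == rest) {
                    res.add(Arrays.asList(sorted[i], sorted[left], sorted[right]));
                    while (left < right && sorted[left] == sorted[left + 1]) ++left;
                    while (left < right && sorted[right] == sorted[right - 1]) --right;
                    ++left;
                    --right;
                } else if (sum < rest) {
                    ++left;
                } else {
                    --right;
                }
            }
        }
        return res;
    }

    public static void main(String[] args) {
        // prints [[-1, -1, 2], [-1, 0, 1]]
        System.out.println(threeSum(new int[]{-1, 0, 1, 2, -1, -4}, 0));
    }
}
```

Because the search always starts at `i + 1`, each triple is generated with its smallest element first, which is what keeps the results in ascending order and free of permutations.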

Similarly, we can always reduce a K Sum problem to a (K-1) Sum problem. After some trimming and optimization, a complete KSum utility class looks like:

```
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * General K-Sum problem.
 * Created by Sky on 014, 7, 14, 2016.
 */
public class KSum {
    /**
     * Find all groups of k elements that add up to the given target.
     * Note: sums are computed with int arithmetic, so extreme inputs may overflow.
     *
     * @param nums   input array (sorted in place)
     * @param k      number of elements per group
     * @param target target value
     * @return all groups
     */
    public List<List<Integer>> kSum(int[] nums, int k, int target) {
        Arrays.sort(nums);
        if (k <= 1) {
            List<List<Integer>> res = new ArrayList<>();
            if (k == 1 && Arrays.binarySearch(nums, target) >= 0) {
                res.add(Collections.singletonList(target));
            }
            return res;
        }
        return kSum(nums, k, target, 0);
    }

    /**
     * Helper function for K-sum.
     *
     * @param nums     input array (already sorted)
     * @param k        number of elements per group, at least 2
     * @param target   target value
     * @param beginIdx begin index
     * @return all groups of k elements from beginIdx onward that add up to target.
     */
    private List<List<Integer>> kSum(int[] nums, int k, int target, int beginIdx) {
        int len = nums.length;
        if (k == 2) {
            // Base case: two pointers over nums[beginIdx..len - 1]
            List<List<Integer>> res = new ArrayList<>();
            int left = beginIdx, right = len - 1;
            while (left < right) {
                int sum = nums[left] + nums[right];
                if (sum == target) {
                    res.add(new ArrayList<>(Arrays.asList(nums[left], nums[right])));
                    // skip repeated values on both sides
                    while (left < right && nums[left] == nums[left + 1]) {
                        ++left;
                    }
                    while (left < right && nums[right] == nums[right - 1]) {
                        --right;
                    }
                    ++left;
                    --right;
                } else if (sum < target) {
                    ++left;
                } else {
                    --right;
                }
            }
            return res;
        }
        List<List<Integer>> res = new ArrayList<>();
        for (int i = beginIdx; i <= len - k; i++) {
            // Skip duplicate first elements, and elements too small to reach
            // target even when combined with (k - 1) copies of the largest value.
            if ((i > beginIdx && nums[i] == nums[i - 1]) || nums[i] + (k - 1) * nums[len - 1] < target) {
                continue;
            }
            // The remaining k - 1 elements are each at least nums[i + 1]; if that
            // lower bound already exceeds target, no later i can work either.
            if (nums[i] + (k - 1) * nums[i + 1] > target) {
                break;
            }
            // Reduce to (k - 1)-sum on the sub-array to the right of i
            List<List<Integer>> sub = kSum(nums, k - 1, target - nums[i], i + 1);
            for (List<Integer> list : sub) {
                list.add(0, nums[i]);
            }
            res.addAll(sub);
        }
        return res;
    }
}
```

The above code has not been fully tested, but it passed all test cases of the LeetCode 2 Sum, 3 Sum and 4 Sum problems.
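One way to gain more confidence would be to compare the results against a brute-force reference on small random inputs. The following is only a sketch of such a reference, not part of the original solution: the names `BruteForceKSum` and `collect` are made up for this illustration, and it simply enumerates every k-element choice of indices, so it is exponential and only suitable for tiny arrays.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

class BruteForceKSum {
    // Returns all unique k-element combinations (each in ascending order)
    // that add up to target, by trying every choice of k indices.
    static List<List<Integer>> kSum(int[] nums, int k, int target) {
        int[] sorted = nums.clone();
        Arrays.sort(sorted);
        // a set removes combinations that differ only in which duplicate was picked
        Set<List<Integer>> seen = new LinkedHashSet<>();
        collect(sorted, k, target, 0, new ArrayList<>(), seen);
        return new ArrayList<>(seen);
    }

    private static void collect(int[] sorted, int k, long remaining, int start,
                                List<Integer> current, Set<List<Integer>> seen) {
        if (k == 0) {
            if (remaining == 0) {
                seen.add(new ArrayList<>(current));
            }
            return;
        }
        // leave room for the k - 1 elements still to be chosen
        for (int i = start; i <= sorted.length - k; i++) {
            current.add(sorted[i]);
            collect(sorted, k - 1, remaining - sorted[i], i + 1, current, seen);
            current.remove(current.size() - 1);
        }
    }
}
```

For example, `BruteForceKSum.kSum(new int[]{1, 0, -1, 0, -2, 2}, 4, 0)` yields `[[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]`, which could then be compared with the output of `new KSum().kSum(...)` for the same input.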

Reference:

- K Sum Problem Analysis: Recursive Implementation and Lower Bound
- My KSum java code on GitHub

## 8 Replies to “A Generic Solution to K-Sum Problems”

have to prove both the correctness and the efficiency?

Right, it has not been fully tested or proven yet. I’m not an expert in algorithms and am also looking for a simpler solution. Let me know if you find one!

this solution is quite inefficient and very verbose. This is an optimisation problem and as such can be resolved with DP in a simpler way

I’m more than happy to learn a better solution! please let me know

How to do DP here?

I’ll bite:

    public static int kSum(int[] nums, int target, int numberOfItemsToTake) {
        // Assumes non-negative numbers and target, since sums are used as array indices.
        int n = nums.length;
        // dp[i][j][k]: number of ways to pick k items from nums[i..] that sum to j
        int[][][] dp = new int[n + 1][target + 1][numberOfItemsToTake + 1];
        dp[n][0][0] = 1;
        for (int i = n - 1; i >= 0; i--) {
            int num = nums[i];
            for (int j = 0; j <= target; j++) {
                for (int k = 0; k <= numberOfItemsToTake; k++) {
                    // either skip nums[i] ...
                    dp[i][j][k] = dp[i + 1][j][k];
                    // ... or take it, if it fits
                    if (j >= num && k > 0) {
                        dp[i][j][k] += dp[i + 1][j - num][k - 1];
                    }
                }
            }
        }
        return dp[0][target][numberOfItemsToTake];
    }

Of course, here I’m just returning the count; if you need to return a list of lists, you can traverse the dp table and construct it. But this is more optimal since it is `O(N*S*K)`, where N is the number of items, S is the target amount, and K is the number of items to take. This is pseudo-polynomial time since the runtime depends on how large the target is, and it’s the best known complexity we can achieve on this problem (same as Knapsack, except we have the K factor in there as well).

I should also mention, you can optimize the memory of this solution by one dimension.
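The one-dimension memory saving mentioned above could look roughly like the sketch below. It rests on several assumptions: all numbers and the target are non-negative (sums are used as array indices), the names `KSumCount` and `countKSum` are made up for illustration, and the result counts selections by index, so repeated values are counted separately rather than de-duplicated.

```java
class KSumCount {
    // Counts the k-element index subsets of nums whose values sum to target.
    // Assumes non-negative nums and target (hypothetical helper for illustration).
    static long countKSum(int[] nums, int k, int target) {
        // dp[c][s]: number of ways to pick c of the items seen so far with sum s
        long[][] dp = new long[k + 1][target + 1];
        dp[0][0] = 1;
        for (int num : nums) {
            // iterate the item count downward so dp[c - 1] still holds the
            // values from before this item, i.e. each item is used at most once
            for (int c = k; c >= 1; c--) {
                for (int s = target; s >= num; s--) {
                    dp[c][s] += dp[c - 1][s - num];
                }
            }
        }
        return dp[k][target];
    }
}
```

The table now needs `O(S*K)` memory instead of `O(N*S*K)`, while the running time stays the same. As a quick check, `countKSum(new int[]{1, 2, 3, 4, 5}, 2, 5)` gives 2, for the pairs (1, 4) and (2, 3).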

This code passed the LeetCode tests and beat 99%. Awesome answer, especially the pruning part. Thank you!