A Generic Solution to K-Sum Problems

K-sum problems are a type of algorithm problem I encountered while preparing for job interviews.

Problem statement:

Given an array of integers, return all unique combinations of k numbers that add up to a specific target. (Assume k is less than the number of elements in the given array.)
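To pin down what "unique combinations" means, here is a deliberately naive reference sketch. It is not the approach of this post, and the class name BruteForceKSum is hypothetical. It sorts a copy of the input, tries every k-element choice of indices, and keeps each distinct list of values once. It runs in exponential time, so it is only useful as a specification or a test oracle.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

final class BruteForceKSum {
    // Exhaustive reference: every k-element index combination of a sorted copy.
    static List<List<Integer>> kSum(int[] nums, int k, int target) {
        int[] sorted = nums.clone();
        Arrays.sort(sorted); // each chosen combination is then in ascending order
        Set<List<Integer>> found = new LinkedHashSet<>(); // dedups equal value lists
        choose(sorted, k, target, 0, new ArrayList<>(), found);
        return new ArrayList<>(found);
    }

    private static void choose(int[] nums, int k, long remaining, int start,
                               List<Integer> current, Set<List<Integer>> found) {
        if (current.size() == k) {
            if (remaining == 0) {
                found.add(new ArrayList<>(current)); // copy: current is reused
            }
            return;
        }
        for (int i = start; i < nums.length; i++) {
            current.add(nums[i]);
            choose(nums, k, remaining - nums[i], i + 1, current, found);
            current.remove(current.size() - 1); // backtrack
        }
    }
}
```

For example, with nums = {-1, 0, 1, 2, -1, -4}, k = 3 and target = 0, two different index choices produce the values [-1, 0, 1], but the result lists that combination only once: [[-1, -1, 2], [-1, 0, 1]].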

In Java, for example, the function signature might look like this:

List<List<Integer>> kSum(int[] nums, int k, int target)

The problem is actually not as difficult as it seems. My approach was to use recursion to reduce the K-sum problem to a (K-1)-sum problem, and so on, all the way down to 2-sum.

First, sort the array:

Arrays.sort(nums);
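Sorting does three jobs in the rest of the solution: it makes the two-pointer scan valid, it puts equal values next to each other so duplicates can be skipped by comparing neighbours, and it enables the pruning bounds in the final class. Note that Arrays.sort(nums) reorders the caller's array in place. If that side effect is unwanted, one option is to sort a copy. A minimal sketch (SortedCopy and sortedCopy are hypothetical names):

```java
import java.util.Arrays;

final class SortedCopy {
    // Returns a sorted copy and leaves the caller's array untouched.
    static int[] sortedCopy(int[] nums) {
        int[] copy = Arrays.copyOf(nums, nums.length);
        Arrays.sort(copy);
        return copy;
    }
}
```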

Start from k = 1 and k = 2

Let’s start with the base cases. If k = 1, the problem becomes a simple binary search in the sorted array:

if (k == 1) {
  List<List<Integer>> res = new ArrayList<>();
  // binarySearch returns a non-negative index only if target is present
  if (Arrays.binarySearch(nums, target) >= 0) {
    res.add(Collections.singletonList(target));
  }
  return res;
}

If k = 2, use the two-pointer approach. Put one pointer at the first element and one at the last element of the sorted array, and move them toward each other. At each step, compare the sum of the two pointed-to elements, nums[i] + nums[j] (with i < j), against the target value:

  • if nums[i] + nums[j] == target, great, add these elements to our results, and move both pointers (++i, --j);
  • if nums[i] + nums[j] > target, we need a smaller sum, so move the right pointer leftward (--j);
  • similarly, if nums[i] + nums[j] < target, move the left pointer rightward (++i).

Repeat until the two pointers meet. Pseudo code:

// 2 Sum
List<List<Integer>> res = new ArrayList<>();
int left = 0, right = nums.length - 1;
while (left < right) {
  int sum = nums[left] + nums[right];
  if (sum == target) {
    res.add(new ArrayList<>(Arrays.asList(nums[left], nums[right])));
    // de-dup: skip equal neighbours so each pair of values is added only once
    while (left < right && nums[left] == nums[left + 1]) ++left;
    while (left < right && nums[right] == nums[right - 1]) --right;
    ++left;
    --right;
  } else if (sum < target) {
    ++left;
  } else {
    --right;
  }
}
return res;
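The fragment above assumes nums is already sorted and always scans the whole array. A self-contained version that also takes a beginIdx left bound, as the recursive helper in the next section does, might look like this (TwoSumSketch and twoSum are hypothetical names):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class TwoSumSketch {
    // Two-pointer 2-sum over nums[beginIdx..]; nums must already be sorted.
    static List<List<Integer>> twoSum(int[] nums, int target, int beginIdx) {
        List<List<Integer>> res = new ArrayList<>();
        int left = beginIdx, right = nums.length - 1;
        while (left < right) {
            int sum = nums[left] + nums[right];
            if (sum == target) {
                res.add(new ArrayList<>(Arrays.asList(nums[left], nums[right])));
                // Skip equal neighbours so each pair of values is reported once.
                while (left < right && nums[left] == nums[left + 1]) ++left;
                while (left < right && nums[right] == nums[right - 1]) --right;
                ++left;
                --right;
            } else if (sum < target) {
                ++left;   // sum too small: move to a larger left value
            } else {
                --right;  // sum too large: move to a smaller right value
            }
        }
        return res;
    }
}
```

For example, with the sorted array {-2, -1, -1, 0, 1, 1, 2, 3} and target 1, twoSum(nums, 1, 0) returns [[-2, 3], [-1, 2], [0, 1]]: each value pair appears once even though -1 and 1 each occur twice. The pairs are returned as mutable ArrayLists so a caller can later prepend elements to them.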

Recursion

Now that the 2-sum problem is solved, we can move on to the recursive cases.

If k = 3, we can iterate over each element and reuse the 2-sum solution. For each number num (at index i) in the array, assume there is a combination {num, m, n} that adds up to target, with num as its smallest element. Then, for each num, we try to solve a 2-sum problem whose new target is target - num and whose source array is the subarray of the original one starting at index i + 1. Let’s define a helper function for the recursion, with an additional parameter beginIdx that specifies the left bound of the range to search:

private List<List<Integer>> kSumHelper(int[] nums, int k, int target, int beginIdx)

The pseudo code for 3-sum looks like this:

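List<List<Integer>> res = new ArrayList<>();
for (int i = 0; i <= nums.length - 3; i++) {
  // Skip repeated values of nums[i]; they would produce the same triples
  if (i > 0 && nums[i] == nums[i - 1]) continue;
  // Get the 2-sum results for the subarray starting at i + 1
  List<List<Integer>> sub = kSumHelper(nums, 2, target - nums[i], i + 1);
  // Prepend nums[i] to each pair to form a triple
  for (List<Integer> list : sub) {
    list.add(0, nums[i]);
  }
  res.addAll(sub);
}
return res;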
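Putting the pieces together, a runnable 3-sum sketch could look like the following. It is an illustration under assumptions: ThreeSumSketch is a hypothetical name, it sorts a copy of the input rather than the caller's array, and its private twoSum is the two-pointer routine described earlier. Skipping a[i] when it equals a[i - 1] is what keeps the triples unique.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class ThreeSumSketch {
    // 3-sum: fix the smallest element a[i], then solve 2-sum on a[i+1..].
    static List<List<Integer>> threeSum(int[] nums, int target) {
        int[] a = nums.clone();
        Arrays.sort(a);
        List<List<Integer>> res = new ArrayList<>();
        for (int i = 0; i <= a.length - 3; i++) {
            if (i > 0 && a[i] == a[i - 1]) continue; // same smallest value, same triples
            for (List<Integer> pair : twoSum(a, target - a[i], i + 1)) {
                pair.add(0, a[i]); // pair is a mutable ArrayList
                res.add(pair);
            }
        }
        return res;
    }

    // Two-pointer 2-sum over the sorted range a[beginIdx..].
    private static List<List<Integer>> twoSum(int[] a, int target, int beginIdx) {
        List<List<Integer>> res = new ArrayList<>();
        int left = beginIdx, right = a.length - 1;
        while (left < right) {
            int sum = a[left] + a[right];
            if (sum == target) {
                res.add(new ArrayList<>(Arrays.asList(a[left], a[right])));
                while (left < right && a[left] == a[left + 1]) ++left;
                while (left < right && a[right] == a[right - 1]) --right;
                ++left;
                --right;
            } else if (sum < target) {
                ++left;
            } else {
                --right;
            }
        }
        return res;
    }
}
```

For example, threeSum(new int[]{-1, 0, 1, 2, -1, -4}, 0) yields [[-1, -1, 2], [-1, 0, 1]].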

In the same way, we can always reduce a K-sum problem to a (K-1)-sum problem. After some pruning and optimization, a complete KSum utility class looks like this:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * General K-Sum problem.
 * Created by Sky on 014, 7, 14, 2016.
 */
public class KSum {
    /**
     * Find all groups of k elements that add up to the given target.
     *
     * @param nums   input array (sorted in place)
     * @param k      number of elements in each group
     * @param target target value
     * @return all groups
     */
    public List<List<Integer>> kSum(int[] nums, int k, int target) {
        Arrays.sort(nums);
        if (k <= 1) {
            List<List<Integer>> res = new ArrayList<>();
            if (k == 1 && Arrays.binarySearch(nums, target) >= 0) {
                res.add(Collections.singletonList(target));
            }
            return res;
        }
        return kSum(nums, k, target, 0);
    }

    /**
     * Helper function for K-sum.
     *
     * @param nums     input array
     * @param k        number of elements in each group
     * @param target   target value
     * @param beginIdx index of the first element that may be used
     * @return all groups of k elements from nums[beginIdx..] that add up to target.
     */
    private List<List<Integer>> kSum(int[] nums, int k, int target, int beginIdx) {
        int len = nums.length;
        if (k == 2) {
            List<List<Integer>> res = new ArrayList<>();
            int left = beginIdx, right = len - 1;
            while (left < right) {
                int sum = nums[left] + nums[right];
                if (sum == target) {
                    res.add(new ArrayList<>(Arrays.asList(nums[left], nums[right])));
                    while (left < right && nums[left] == nums[left + 1]) {
                        ++left;
                    }
                    while (left < right && nums[right] == nums[right - 1]) {
                        --right;
                    }
                    ++left;
                    --right;
                } else if (sum < target) {
                    ++left;
                } else {
                    --right;
                }
            }
            return res;
        }
        List<List<Integer>> res = new ArrayList<>();
        for (int i = beginIdx; i <= len - k; i++) {
            if (i > beginIdx && nums[i] == nums[i - 1] || nums[i] + (k - 1) * nums[len - 1] < target) {
                continue;
            }
            if (nums[i] + (k - 1) * nums[i + 1] > target) {
                break;
            }
            List<List<Integer>> sub = kSum(nums, k - 1, target - nums[i], i + 1);
            for (List<Integer> list : sub) {
                list.add(0, nums[i]);
            }
            res.addAll(sub);
        }
        return res;
    }
}
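The two pruning checks in the k > 2 loop follow from the array being sorted. For a fixed index i, the other k - 1 elements of any group whose smallest element is nums[i] come from nums[i+1..len-1], so each of them lies between nums[i+1] and nums[len-1]. Writing their indices as j_1 < ... < j_{k-1} (notation introduced here only for the bound):

$$
\mathrm{nums}[i] + (k-1)\,\mathrm{nums}[i+1] \;\le\; \mathrm{nums}[i] + \sum_{t=1}^{k-1} \mathrm{nums}[j_t] \;\le\; \mathrm{nums}[i] + (k-1)\,\mathrm{nums}[\mathit{len}-1],
\qquad i < j_1 < \dots < j_{k-1} \le \mathit{len}-1
$$

If the right-hand bound is below target, no group starting at nums[i] can reach it, so the loop moves on to the next i (continue). If the left-hand bound is above target, the same holds for every later i too, because both terms of that bound can only grow as i increases, so the loop can stop (break).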

The code above has not been fully tested, but it passed all test cases for the LeetCode 2 Sum, 3 Sum, and 4 Sum problems.
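Since the code has not been fully tested, one way to gain more confidence is a randomized cross-check against a brute-force enumerator on small inputs. The sketch below is an illustration, not part of the original post: KSumCrossCheck, bruteForce and run are hypothetical names, and the value ranges and seed are arbitrary. It also deliberately departs from the KSum class in one respect: sums and pruning bounds are computed in long, because expressions such as nums[i] + (k - 1) * nums[len - 1] can overflow int when the inputs are large.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

final class KSumCrossCheck {
    // Same recursion and pruning as the KSum class, but sums and bounds use long.
    // Expects `sorted` in ascending order and 2 <= k <= sorted.length - begin.
    static List<List<Integer>> kSum(int[] sorted, int k, long target, int begin) {
        List<List<Integer>> res = new ArrayList<>();
        int len = sorted.length;
        if (k == 2) {
            int left = begin, right = len - 1;
            while (left < right) {
                long sum = (long) sorted[left] + sorted[right];
                if (sum == target) {
                    res.add(new ArrayList<>(Arrays.asList(sorted[left], sorted[right])));
                    while (left < right && sorted[left] == sorted[left + 1]) ++left;
                    while (left < right && sorted[right] == sorted[right - 1]) --right;
                    ++left;
                    --right;
                } else if (sum < target) {
                    ++left;
                } else {
                    --right;
                }
            }
            return res;
        }
        for (int i = begin; i <= len - k; i++) {
            if (i > begin && sorted[i] == sorted[i - 1]) continue;                // duplicate
            if (sorted[i] + (long) (k - 1) * sorted[len - 1] < target) continue; // max too small
            if (sorted[i] + (long) (k - 1) * sorted[i + 1] > target) break;      // min too large
            for (List<Integer> sub : kSum(sorted, k - 1, target - sorted[i], i + 1)) {
                sub.add(0, sorted[i]);
                res.add(sub);
            }
        }
        return res;
    }

    // Brute force over every index subset of size k (only viable for small arrays).
    static Set<List<Integer>> bruteForce(int[] sorted, int k, long target) {
        Set<List<Integer>> res = new HashSet<>();
        for (int mask = 0; mask < (1 << sorted.length); mask++) {
            if (Integer.bitCount(mask) != k) continue;
            List<Integer> combo = new ArrayList<>();
            long sum = 0;
            for (int b = 0; b < sorted.length; b++) {
                if ((mask & (1 << b)) != 0) {
                    combo.add(sorted[b]); // ascending, since the array is sorted
                    sum += sorted[b];
                }
            }
            if (sum == target) res.add(combo);
        }
        return res;
    }

    // Counts trials where kSum returns exactly the brute-force set, with no repeats.
    static int run(long seed, int trials) {
        Random rnd = new Random(seed);
        int agreed = 0;
        for (int t = 0; t < trials; t++) {
            int n = 2 + rnd.nextInt(9);                                 // 2..10 elements
            int k = 2 + rnd.nextInt(n - 1);                             // 2..n
            int[] nums = new int[n];
            for (int j = 0; j < n; j++) nums[j] = rnd.nextInt(11) - 5;  // values in -5..5
            Arrays.sort(nums);
            int target = rnd.nextInt(21) - 10;                          // targets in -10..10
            List<List<Integer>> fast = kSum(nums, k, target, 0);
            Set<List<Integer>> slow = bruteForce(nums, k, target);
            if (fast.size() == slow.size() && new HashSet<>(fast).equals(slow)) agreed++;
        }
        return agreed;
    }
}
```

With a fixed seed the run is reproducible; a return value equal to the number of trials means every trial agreed with the brute force. Comparing sizes as well as sets also catches a combination reported twice.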

8 Replies to “A Generic Solution to K-Sum Problems”

    1. Right, it has not been fully tested or proven yet. I’m not an expert in algorithms and am also looking for a simpler solution. Let me know if you find one!

  1. This solution is quite inefficient and very verbose. This is an optimisation problem and as such can be solved with DP in a simpler way.

        1. I’ll bite:

          public static int kSum(int[] nums, int target, int numberOfItemsToTake) {
              // dp[i][j][k] = number of ways to pick k items from nums[i..n-1] summing to j.
              // Assumes non-negative nums and target; equal values at different
              // indices are counted as different ways.
              int n = nums.length;
              int[][][] dp = new int[n + 1][target + 1][numberOfItemsToTake + 1];
              dp[n][0][0] = 1;

              for (int i = n - 1; i >= 0; i--) {
                  int num = nums[i];
                  for (int j = 0; j <= target; j++) {
                      for (int k = 0; k <= numberOfItemsToTake; k++) {
                          dp[i][j][k] = dp[i + 1][j][k];                // skip nums[i]
                          if (j >= num && k > 0) {
                              dp[i][j][k] += dp[i + 1][j - num][k - 1]; // take nums[i]
                          }
                      }
                  }
              }

              return dp[0][target][numberOfItemsToTake];
          }

          Of course, here I’m just returning the count; if you need to return a list of lists, you can traverse the dp table and construct it. But this is more optimal, since it is O(NSK), where N is the number of items, S is the target amount, and K is the number of items to take. This is pseudo-polynomial time, since the runtime depends on how large the target is, and it’s the best known complexity we can achieve on this problem (same as Knapsack, except that we have the K factor in there as well).

  2. This code passed and beat 99% on the LeetCode tests. Awesome answer, especially the pruning part. Thank you!
