• Time: $O(mn^2)$
• Space: $O(mn)$

## C++

```cpp
class Solution {
public:
// Returns the minimum possible value of the largest group sum when `nums` is
// split into `m` contiguous groups. Sets up the memo table and prefix sums,
// then delegates to the recursive overload.
int splitArray(vector<int>& nums, int m) {
  const int n = nums.size();
  // dp[i][k] := min of largest sum to split first i nums into k groups.
  // INT_MAX marks an entry that has not been computed yet.
  //
  // assign() is used instead of resize(): resize() keeps elements that are
  // already present, so calling this method a second time on the same
  // Solution object would read memo entries left over from the previous
  // input (and rows sized for the previous m).
  dp.assign(n + 1, vector<int>(m + 1, INT_MAX));
  prefix.assign(n + 1, 0);

  // prefix[i] = nums[0] + ... + nums[i - 1]; prefix[0] stays 0.
  partial_sum(begin(nums), end(nums), begin(prefix) + 1);
  return splitArray(nums, n, m);
}

private:
// Memo table: dp[i][k] is the min of the largest sum when the first i nums
// are split into k groups. Entries equal to INT_MAX have not been computed.
vector<vector<int>> dp;
// prefix[i] is the sum of the first i nums, filled by the public splitArray.
vector<int> prefix;

// Memoized recursion: smallest achievable largest-group sum when the first
// `i` nums are split into `k` groups. All sums are read from `prefix`.
int splitArray(const vector<int>& nums, int i, int k) {
  // With a single group the whole prefix is that group.
  if (k == 1)
    return prefix[i];

  int& memo = dp[i][k];
  // Any value other than the INT_MAX sentinel is a stored answer.
  if (memo != INT_MAX)
    return memo;

  // The last group covers nums[cut..i-1]; the first `cut` nums form the
  // remaining k - 1 groups.
  for (int cut = k - 1; cut < i; ++cut) {
    const int head = splitArray(nums, cut, k - 1);
    const int tail = prefix[i] - prefix[cut];
    memo = min(memo, max(head, tail));
  }

  return memo;
}
};
``````

## JAVA

```java
class Solution {
/**
 * Returns the minimum possible value of the largest group sum when
 * {@code nums} is split into {@code m} contiguous groups.
 *
 * <p>Builds the prefix sums and a fresh memo table, then delegates to the
 * recursive overload.
 */
public int splitArray(int[] nums, int m) {
  final int n = nums.length;

  // dp[i][k] := min of largest sum to split first i nums into k groups.
  // Integer.MAX_VALUE marks an entry that has not been computed yet.
  dp = new int[n + 1][m + 1];
  for (final int[] row : dp)
    Arrays.fill(row, Integer.MAX_VALUE);

  // prefix[i] is the sum of the first i nums.
  prefix = new int[n + 1];
  for (int i = 1; i <= n; ++i)
    prefix[i] = prefix[i - 1] + nums[i - 1];

  return splitArray(nums, n, m);
}

// Memo table: dp[i][k] is the min of the largest sum when the first i nums
// are split into k groups. Integer.MAX_VALUE means "not computed yet".
private int[][] dp;
// prefix[i] is the sum of the first i nums.
private int[] prefix;

/**
 * Memoized recursion: smallest achievable largest-group sum when the first
 * {@code i} nums are split into {@code k} groups.
 */
private int splitArray(int[] nums, int i, int k) {
  // With a single group the whole prefix is that group.
  if (k == 1)
    return prefix[i];
  // Any value other than the sentinel is a stored answer.
  if (dp[i][k] != Integer.MAX_VALUE)
    return dp[i][k];

  int best = Integer.MAX_VALUE;
  // The last group covers nums[cut..i-1]; the first `cut` nums form the
  // remaining k - 1 groups.
  for (int cut = k - 1; cut < i; ++cut) {
    final int head = splitArray(nums, cut, k - 1);
    final int tail = prefix[i] - prefix[cut];
    best = Math.min(best, Math.max(head, tail));
  }

  dp[i][k] = best;
  return best;
}
}
``````

## Python

```python
class Solution:
def splitArray(self, nums: List[int], m: int) -> int:
    """Return the minimum possible largest group sum.

    ``nums`` is split into ``m`` contiguous groups; the result is the
    smallest value the largest group sum can take (top-down DP).
    """
    n = len(nums)
    # prefix[i] is the sum of the first i nums; prefix[0] == 0.
    prefix = [0] + list(accumulate(nums))

    # dp(i, k) := min of largest sum to split first i nums into k groups
    @lru_cache(None)
    def dp(i: int, k: int) -> int:
        # With a single group the whole prefix is that group.
        if k == 1:
            return prefix[i]

        # Try every position j for the start of the last group nums[j:i];
        # the first j nums then form k - 1 groups. An empty range (i < k)
        # yields math.inf, i.e. "no valid split".
        return min(
            (max(dp(j, k - 1), prefix[i] - prefix[j]) for j in range(k - 1, i)),
            default=math.inf,
        )

    return dp(n, m)
``````

• Time: $O(mn^2)$
• Space: $O(mn)$

## C++

```cpp
class Solution {
public:
// Returns the minimum possible value of the largest group sum when `nums` is
// split into `m` contiguous groups (bottom-up DP over prefix length and
// group count).
int splitArray(vector<int>& nums, int m) {
  const int size = nums.size();

  // sums[len] = nums[0] + ... + nums[len - 1]; sums[0] is 0.
  vector<long> sums(size + 1);
  partial_sum(nums.begin(), nums.end(), sums.begin() + 1);

  // best[len][groups] := min of largest sum to split the first `len` nums
  // into `groups` groups. INT_MAX marks entries that were never filled.
  vector<vector<long>> best(size + 1, vector<long>(m + 1, INT_MAX));

  // One group: the whole prefix is that group.
  for (int len = 1; len <= size; ++len)
    best[len][1] = sums[len];

  for (int groups = 2; groups <= m; ++groups) {
    for (int len = groups; len <= size; ++len) {
      long& cell = best[len][groups];
      // The last group covers nums[cut..len-1]; the first `cut` nums form the
      // other groups - 1 groups.
      for (int cut = groups - 1; cut < len; ++cut) {
        const long lastGroup = sums[len] - sums[cut];
        const long candidate = max(best[cut][groups - 1], lastGroup);
        cell = min(cell, candidate);
      }
    }
  }

  return best[size][m];
}
};
``````

## JAVA

```java
class Solution {
/**
 * Returns the minimum possible value of the largest group sum when
 * {@code nums} is split into {@code m} contiguous groups (bottom-up DP).
 */
public int splitArray(int[] nums, int m) {
  final int n = nums.length;

  // dp[len][groups] := min of largest sum to split the first `len` nums into
  // `groups` groups. Integer.MAX_VALUE marks entries that were never filled.
  final int[][] dp = new int[n + 1][m + 1];
  for (final int[] row : dp)
    Arrays.fill(row, Integer.MAX_VALUE);

  // prefix[len] is the sum of the first `len` nums.
  final int[] prefix = new int[n + 1];
  for (int len = 1; len <= n; ++len)
    prefix[len] = prefix[len - 1] + nums[len - 1];

  // One group: the whole prefix is that group.
  for (int len = 1; len <= n; ++len)
    dp[len][1] = prefix[len];

  for (int groups = 2; groups <= m; ++groups) {
    for (int len = groups; len <= n; ++len) {
      int best = Integer.MAX_VALUE;
      // The last group covers nums[cut..len-1]; the first `cut` nums form
      // the other groups - 1 groups.
      for (int cut = groups - 1; cut < len; ++cut) {
        final int lastGroup = prefix[len] - prefix[cut];
        best = Math.min(best, Math.max(dp[cut][groups - 1], lastGroup));
      }
      dp[len][groups] = best;
    }
  }

  return dp[n][m];
}
}
``````

## Python

```python
class Solution:
def splitArray(self, nums: List[int], m: int) -> int:
    """Return the minimum possible largest group sum.

    ``nums`` is split into ``m`` contiguous groups; the result is the
    smallest value the largest group sum can take (bottom-up DP).
    """
    n = len(nums)
    # dp[i][k] := min of largest sum to split first i nums into k groups;
    # math.inf marks entries that were never filled.
    dp = [[math.inf] * (m + 1) for _ in range(n + 1)]
    # prefix[i] is the sum of the first i nums; prefix[0] == 0.
    prefix = [0] + list(accumulate(nums))

    # One group: the whole prefix is that group.
    for i in range(1, n + 1):
        dp[i][1] = prefix[i]

    for k in range(2, m + 1):
        for i in range(k, n + 1):
            # The last group is nums[j:i]; the first j nums form k - 1
            # groups. range(k - 1, i) is never empty because i >= k.
            dp[i][k] = min(
                max(dp[j][k - 1], prefix[i] - prefix[j])
                for j in range(k - 1, i)
            )

    return dp[n][m]
``````
• Time: $O(n \log(\Sigma |\texttt{nums}|))$
• Space: $O(1)$

## C++

```cpp
class Solution {
public:
// Returns the minimum possible value of the largest group sum when `nums` is
// split into `m` contiguous groups, by binary-searching the smallest cap for
// which numGroups() needs at most `m` groups.
int splitArray(vector<int>& nums, int m) {
  // Search window [l, r): l starts at the largest element, r one past the
  // total sum.
  int l = *max_element(begin(nums), end(nums));
  int r = accumulate(begin(nums), end(nums), 0) + 1;

  while (l < r) {
    // l + (r - l) / 2 rather than (l + r) / 2: the sum l + r can exceed
    // INT_MAX for large inputs, which is undefined behaviour for signed int.
    // Both forms give the same midpoint for non-negative l <= r.
    const int mid = l + (r - l) / 2;
    if (numGroups(nums, mid) > m)
      l = mid + 1;  // cap too small: more than m groups are required
    else
      r = mid;  // cap is feasible: try a smaller one
  }

  return l;
}

private:
// Greedily packs `nums` left to right into groups whose sums stay at or below
// `maxSumInGroup`, and returns how many groups that takes (at least 1).
int numGroups(const vector<int>& nums, int maxSumInGroup) {
  int groups = 1;
  int running = 0;

  for (const int value : nums) {
    const bool fits = running + value <= maxSumInGroup;
    if (!fits) {
      // Close the current group; `value` opens the next one.
      ++groups;
      running = 0;
    }
    running += value;
  }

  return groups;
}
};
``````

## JAVA

```java
class Solution {
/**
 * Returns the minimum possible value of the largest group sum when
 * {@code nums} is split into {@code m} contiguous groups, by binary-searching
 * the smallest cap for which {@code numGroups} needs at most {@code m} groups.
 */
public int splitArray(int[] nums, int m) {
  // Search window [l, r): l starts at the largest element, r one past the
  // total sum.
  int l = Arrays.stream(nums).max().getAsInt();
  int r = Arrays.stream(nums).sum() + 1;

  while (l < r) {
    // l + (r - l) / 2 rather than (l + r) / 2: the sum l + r can wrap past
    // Integer.MAX_VALUE for large inputs and produce a negative midpoint.
    // Both forms give the same midpoint for non-negative l <= r.
    final int mid = l + (r - l) / 2;
    if (numGroups(nums, mid) > m)
      l = mid + 1; // cap too small: more than m groups are required
    else
      r = mid; // cap is feasible: try a smaller one
  }

  return l;
}

/**
 * Greedily packs {@code nums} left to right into groups whose sums stay at
 * or below {@code maxSumInGroup}, and returns how many groups that takes
 * (at least 1).
 */
private int numGroups(int[] nums, int maxSumInGroup) {
  int groups = 1;
  int running = 0;

  for (final int value : nums) {
    if (running + value > maxSumInGroup) {
      // Close the current group; `value` opens the next one.
      ++groups;
      running = value;
    } else {
      running += value;
    }
  }

  return groups;
}
}
``````

## Python

```python
class Solution:
def splitArray(self, nums: List[int], m: int) -> int:
    """Return the minimum possible largest group sum.

    ``nums`` is split into ``m`` contiguous groups. The answer is found by
    binary-searching the smallest cap for which a greedy left-to-right
    packing needs at most ``m`` groups.
    """

    def numGroups(maxSumInGroup: int) -> int:
        """Count the groups a greedy packing needs under the given cap."""
        groupCount = 1
        sumInGroup = 0

        for num in nums:
            if sumInGroup + num <= maxSumInGroup:
                sumInGroup += num
            else:
                # Close the current group; num opens the next one.
                groupCount += 1
                sumInGroup = num

        return groupCount

    # Search window [l, r): l starts at the largest element, r one past the
    # total sum.
    l = max(nums)
    r = sum(nums) + 1

    while l < r:
        mid = (l + r) // 2
        if numGroups(mid) > m:
            l = mid + 1  # cap too small: more than m groups are required
        else:
            r = mid  # cap is feasible: try a smaller one

    return l
``````