feat(Combination): add BigInteger-based combination calculation methods and deprecate old count method

This commit is contained in:
yulin
2025-11-23 22:41:39 +08:00
parent ff32ed0872
commit 5e1110426b
2 changed files with 157 additions and 9 deletions

View File

@@ -1,13 +1,13 @@
package cn.hutool.core.math;
import cn.hutool.core.util.StrUtil;
import java.io.Serializable;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import cn.hutool.core.util.NumberUtil;
import cn.hutool.core.util.StrUtil;
/**
* 组合即C(n, m)<br>
* 排列组合相关类 参考http://cgs1999.iteye.com/blog/2327664
@@ -32,18 +32,71 @@ public class Combination implements Serializable {
/**
* 计算组合数即C(n, m) = n!/((n-m)! * m!)
*
* <p>注意:此方法内部使用 BigInteger 修复了旧版 factorial 的计算错误,
* 但最终仍以 long 返回,因此当结果超过 long 范围时仍会溢出。</p>
* <p>建议使用 {@link #countBig(int, int)} 获取精确结果,或使用
* {@link #countSafe(int, int)} 获取安全 long 版本。</p>
* @param n 总数
* @param m 选择的个数
* @return 组合数
*/
@Deprecated
public static long count(int n, int m) {
if (0 == m || n == m) {
return 1;
}
return (n > m) ? NumberUtil.factorial(n, n - m) / NumberUtil.factorial(m) : 0;
BigInteger big = countBig(n, m);
return big.longValue();
}
/**
* 计算组合数 C(n, m) 的 BigInteger 精确版本。
* 使用逐步累乘除法(非阶乘)保证不溢出、性能好。
* <p>
* 数学定义:
* C(n, m) = n! / (m! (n - m)!)
* <p>
* 优化方式:
* 1. 利用对称性 m = min(m, n-m)
* 2. 每一步先乘 BigInteger再除以当前 i保证数值不暴涨
*
* @param n 总数 n必须 >= 0
* @param m 取出 m必须 >= 0
* @return C(n, m) 的 BigInteger 精确值;当 m > n 时返回 BigInteger.ZERO
*/
public static BigInteger countBig(int n, int m) {
if (n < 0 || m < 0) {
throw new IllegalArgumentException("n and m must be non-negative. got n=" + n + ", m=" + m);
}
if (m > n) {
return BigInteger.ZERO;
}
if (m == 0 || n == m) {
return BigInteger.ONE;
}
// 使用对称性C(n, m) = C(n, n-m)
m = Math.min(m, n - m);
BigInteger result = BigInteger.ONE;
// 从 1 → m 累乘
for (int i = 1; i <= m; i++) {
int numerator = n - m + i;
result = result.multiply(BigInteger.valueOf(numerator))
.divide(BigInteger.valueOf(i));
}
return result;
}
/**
* 安全组合数 long 版本。
*
* @param n 总数 n必须 >= 0
* @param m 取出 m必须 >= 0
* <p>若结果超出 long 范围,会抛 ArithmeticException而非溢出。</p>
*/
public static long countSafe(int n, int m) {
BigInteger big = countBig(n, m);
return big.longValueExact();
}
/**
* 计算组合总数即C(n, 1) + C(n, 2) + C(n, 3)...
*
@@ -104,4 +157,5 @@ public class Combination implements Serializable {
select(i + 1, resultList, resultIndex + 1, result);
}
}
}

View File

@@ -1,10 +1,12 @@
package cn.hutool.core.math;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Test;
import java.math.BigInteger;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
/**
* 组合单元测试
*
@@ -51,4 +53,96 @@ public class CombinationTest {
List<String[]> list2 = combination.select(0);
assertEquals(1, list2.size());
}
// -----------------------------
// countBig() 正确性测试
// -----------------------------
@Test
void testCountBig_basicCases() {
assertEquals(BigInteger.ONE, Combination.countBig(5, 0));
assertEquals(BigInteger.ONE, Combination.countBig(5, 5));
assertEquals(BigInteger.valueOf(10), Combination.countBig(5, 3));
assertEquals(BigInteger.valueOf(10), Combination.countBig(5, 2));
}
@Test
void testCountBig_mGreaterThanN() {
assertEquals(BigInteger.ZERO, Combination.countBig(5, 6));
}
@Test
void testCountBig_negativeInput() {
assertThrows(IllegalArgumentException.class, () -> Combination.countBig(-1, 3));
assertThrows(IllegalArgumentException.class, () -> Combination.countBig(5, -2));
}
@Test
void testCountBig_symmetry() {
assertEquals(Combination.countBig(20, 3), Combination.countBig(20, 17));
}
@Test
void testCountBig_largeNumbers() {
// C(50, 3) = 19600
assertEquals(new BigInteger("19600"), Combination.countBig(50, 3));
// C(100, 50) 的确切值(重要测试)
BigInteger expected = new BigInteger(
"100891344545564193334812497256"
);
assertEquals(expected, Combination.countBig(100, 50));
}
@Test
void testCountBig_veryLargeCombination() {
// 不比较具体值,只断言不要抛错
BigInteger result = Combination.countBig(2000, 1000);
assertTrue(result.signum() > 0);
}
// -----------------------------
// count(long) 兼容性测试
// -----------------------------
@Test
void testCount_basic() {
assertEquals(10L, Combination.count(5, 3));
assertEquals(1L, Combination.count(5, 0));
assertEquals(0L, Combination.count(5, 6));
}
@Test
void testCount_overflowBehavior() {
// C(100, 50) 远超 long 范围,但旧版行为要求不抛异常
long r = Combination.count(100, 50);
// longValue() 不抛异常,并且可能溢出
assertNotNull(r);
}
@Test
void testCount_noException() {
assertDoesNotThrow(() -> Combination.count(5000, 2500));
}
// -----------------------------
// countSafe() 安全 long 版本测试
// -----------------------------
@Test
void testCountSafe_exactFitsLong() {
// C(50, 3) = 19600 fits long
assertEquals(19600L, Combination.countSafe(50, 3));
}
@Test
void testCountSafe_overflowThrows() {
// C(100, 50) 超出 long → 应抛 ArithmeticException
assertThrows(ArithmeticException.class, () -> Combination.countSafe(100, 50));
}
@Test
void testCountSafe_invalidInput() {
assertThrows(IllegalArgumentException.class, () -> Combination.countSafe(-1, 3));
assertThrows(IllegalArgumentException.class, () -> Combination.countSafe(3, -1));
}
}