diff --git a/hutool-core/src/main/java/cn/hutool/core/math/Combination.java b/hutool-core/src/main/java/cn/hutool/core/math/Combination.java index 7da8d3b42a..868ac83ae5 100644 --- a/hutool-core/src/main/java/cn/hutool/core/math/Combination.java +++ b/hutool-core/src/main/java/cn/hutool/core/math/Combination.java @@ -1,13 +1,13 @@ package cn.hutool.core.math; +import cn.hutool.core.util.StrUtil; + import java.io.Serializable; +import java.math.BigInteger; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import cn.hutool.core.util.NumberUtil; -import cn.hutool.core.util.StrUtil; - /** * 组合,即C(n, m)
* 排列组合相关类 参考:http://cgs1999.iteye.com/blog/2327664 @@ -32,18 +32,71 @@ public class Combination implements Serializable { /** * 计算组合数,即C(n, m) = n!/((n-m)! * m!) - * + *

注意:此方法内部使用 BigInteger 修复了旧版 factorial 的计算错误, + * 但最终仍以 long 返回,因此当结果超过 long 范围时仍会溢出。

+ *

建议使用 {@link #countBig(int, int)} 获取精确结果,或使用 + * {@link #countSafe(int, int)} 获取安全 long 版本。

* @param n 总数 * @param m 选择的个数 * @return 组合数 */ + @Deprecated public static long count(int n, int m) { - if (0 == m || n == m) { - return 1; - } - return (n > m) ? NumberUtil.factorial(n, n - m) / NumberUtil.factorial(m) : 0; + BigInteger big = countBig(n, m); + return big.longValue(); } + /** + * 计算组合数 C(n, m) 的 BigInteger 精确版本。 + * 使用逐步累乘除法(非阶乘)保证不溢出、性能好。 + *

+ * 数学定义: + * C(n, m) = n! / (m! (n - m)!) + *

+ * 优化方式: + * 1. 利用对称性 m = min(m, n-m) + * 2. 每一步先乘 BigInteger,再除以当前 i,保证数值不暴涨 + * + * @param n 总数 n(必须 >= 0) + * @param m 取出 m(必须 >= 0) + * @return C(n, m) 的 BigInteger 精确值;当 m > n 时返回 BigInteger.ZERO + */ + public static BigInteger countBig(int n, int m) { + if (n < 0 || m < 0) { + throw new IllegalArgumentException("n and m must be non-negative. got n=" + n + ", m=" + m); + } + if (m > n) { + return BigInteger.ZERO; + } + if (m == 0 || n == m) { + return BigInteger.ONE; + } + // 使用对称性:C(n, m) = C(n, n-m) + m = Math.min(m, n - m); + BigInteger result = BigInteger.ONE; + // 从 1 → m 累乘 + for (int i = 1; i <= m; i++) { + int numerator = n - m + i; + result = result.multiply(BigInteger.valueOf(numerator)) + .divide(BigInteger.valueOf(i)); + } + + return result; + } + + /** + * 安全组合数 long 版本。 + * + * @param n 总数 n(必须 >= 0) + * @param m 取出 m(必须 >= 0) + *

若结果超出 long 范围,会抛 ArithmeticException,而非溢出。

+ */ + public static long countSafe(int n, int m) { + BigInteger big = countBig(n, m); + return big.longValueExact(); + } + + /** * 计算组合总数,即C(n, 1) + C(n, 2) + C(n, 3)... * @@ -104,4 +157,5 @@ public class Combination implements Serializable { select(i + 1, resultList, resultIndex + 1, result); } } + } diff --git a/hutool-core/src/test/java/cn/hutool/core/math/CombinationTest.java b/hutool-core/src/test/java/cn/hutool/core/math/CombinationTest.java index 3cc6d84613..134bcb39a0 100644 --- a/hutool-core/src/test/java/cn/hutool/core/math/CombinationTest.java +++ b/hutool-core/src/test/java/cn/hutool/core/math/CombinationTest.java @@ -1,10 +1,12 @@ package cn.hutool.core.math; -import static org.junit.jupiter.api.Assertions.*; import org.junit.jupiter.api.Test; +import java.math.BigInteger; import java.util.List; +import static org.junit.jupiter.api.Assertions.*; + /** * 组合单元测试 * @@ -51,4 +53,96 @@ public class CombinationTest { List list2 = combination.select(0); assertEquals(1, list2.size()); } + + + // ----------------------------- + // countBig() 正确性测试 + // ----------------------------- + @Test + void testCountBig_basicCases() { + assertEquals(BigInteger.ONE, Combination.countBig(5, 0)); + assertEquals(BigInteger.ONE, Combination.countBig(5, 5)); + assertEquals(BigInteger.valueOf(10), Combination.countBig(5, 3)); + assertEquals(BigInteger.valueOf(10), Combination.countBig(5, 2)); + } + + @Test + void testCountBig_mGreaterThanN() { + assertEquals(BigInteger.ZERO, Combination.countBig(5, 6)); + } + + @Test + void testCountBig_negativeInput() { + assertThrows(IllegalArgumentException.class, () -> Combination.countBig(-1, 3)); + assertThrows(IllegalArgumentException.class, () -> Combination.countBig(5, -2)); + } + + @Test + void testCountBig_symmetry() { + assertEquals(Combination.countBig(20, 3), Combination.countBig(20, 17)); + } + + @Test + void testCountBig_largeNumbers() { + // C(50, 3) = 19600 + assertEquals(new BigInteger("19600"), Combination.countBig(50, 3)); + + // C(100, 50) 的确切值(重要测试) + BigInteger expected = new BigInteger( + "100891344545564193334812497256" + ); + assertEquals(expected, Combination.countBig(100, 50)); + } + + @Test + void testCountBig_veryLargeCombination() { + // 不比较具体值,只断言不要抛错 + BigInteger result = Combination.countBig(2000, 1000); + assertTrue(result.signum() > 0); + } + + // ----------------------------- + // count(long) 兼容性测试 + // ----------------------------- + @Test + void testCount_basic() { + assertEquals(10L, Combination.count(5, 3)); + assertEquals(1L, Combination.count(5, 0)); + assertEquals(0L, Combination.count(5, 6)); + } + + @Test + void testCount_overflowBehavior() { + // C(100, 50) 远超 long 范围,但旧版行为要求不抛异常 + long r = Combination.count(100, 50); + + // longValue() 不抛异常,并且可能溢出 + assertNotNull(r); + } + + @Test + void testCount_noException() { + assertDoesNotThrow(() -> Combination.count(5000, 2500)); + } + + // ----------------------------- + // countSafe() 安全 long 版本测试 + // ----------------------------- + @Test + void testCountSafe_exactFitsLong() { + // C(50, 3) = 19600 fits long + assertEquals(19600L, Combination.countSafe(50, 3)); + } + + @Test + void testCountSafe_overflowThrows() { + // C(100, 50) 超出 long → 应抛 ArithmeticException + assertThrows(ArithmeticException.class, () -> Combination.countSafe(100, 50)); + } + + @Test + void testCountSafe_invalidInput() { + assertThrows(IllegalArgumentException.class, () -> Combination.countSafe(-1, 3)); + assertThrows(IllegalArgumentException.class, () -> Combination.countSafe(3, -1)); + } }