diff --git a/hutool-core/src/main/java/cn/hutool/core/math/Combination.java b/hutool-core/src/main/java/cn/hutool/core/math/Combination.java
index 7da8d3b42a..868ac83ae5 100644
--- a/hutool-core/src/main/java/cn/hutool/core/math/Combination.java
+++ b/hutool-core/src/main/java/cn/hutool/core/math/Combination.java
@@ -1,13 +1,13 @@
package cn.hutool.core.math;
+import cn.hutool.core.util.StrUtil;
+
import java.io.Serializable;
+import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import cn.hutool.core.util.NumberUtil;
-import cn.hutool.core.util.StrUtil;
-
/**
* 组合,即C(n, m)
* 排列组合相关类 参考:http://cgs1999.iteye.com/blog/2327664
@@ -32,18 +32,71 @@ public class Combination implements Serializable {
/**
* 计算组合数,即C(n, m) = n!/((n-m)! * m!)
- *
+ *
注意:此方法内部使用 BigInteger 修复了旧版 factorial 的计算错误, + * 但最终仍以 long 返回,因此当结果超过 long 范围时仍会溢出。
+ *建议使用 {@link #countBig(int, int)} 获取精确结果,或使用 + * {@link #countSafe(int, int)} 获取安全 long 版本。
* @param n 总数 * @param m 选择的个数 * @return 组合数 */ + @Deprecated public static long count(int n, int m) { - if (0 == m || n == m) { - return 1; - } - return (n > m) ? NumberUtil.factorial(n, n - m) / NumberUtil.factorial(m) : 0; + BigInteger big = countBig(n, m); + return big.longValue(); } + /** + * 计算组合数 C(n, m) 的 BigInteger 精确版本。 + * 使用逐步累乘除法(非阶乘)保证不溢出、性能好。 + *+ * 数学定义: + * C(n, m) = n! / (m! (n - m)!) + *
+ * 优化方式: + * 1. 利用对称性 m = min(m, n-m) + * 2. 每一步先乘 BigInteger,再除以当前 i,保证数值不暴涨 + * + * @param n 总数 n(必须 >= 0) + * @param m 取出 m(必须 >= 0) + * @return C(n, m) 的 BigInteger 精确值;当 m > n 时返回 BigInteger.ZERO + */ + public static BigInteger countBig(int n, int m) { + if (n < 0 || m < 0) { + throw new IllegalArgumentException("n and m must be non-negative. got n=" + n + ", m=" + m); + } + if (m > n) { + return BigInteger.ZERO; + } + if (m == 0 || n == m) { + return BigInteger.ONE; + } + // 使用对称性:C(n, m) = C(n, n-m) + m = Math.min(m, n - m); + BigInteger result = BigInteger.ONE; + // 从 1 → m 累乘 + for (int i = 1; i <= m; i++) { + int numerator = n - m + i; + result = result.multiply(BigInteger.valueOf(numerator)) + .divide(BigInteger.valueOf(i)); + } + + return result; + } + + /** + * 安全组合数 long 版本。 + * + * @param n 总数 n(必须 >= 0) + * @param m 取出 m(必须 >= 0) + *
若结果超出 long 范围,会抛 ArithmeticException,而非溢出。
+ */ + public static long countSafe(int n, int m) { + BigInteger big = countBig(n, m); + return big.longValueExact(); + } + + /** * 计算组合总数,即C(n, 1) + C(n, 2) + C(n, 3)... * @@ -104,4 +157,5 @@ public class Combination implements Serializable { select(i + 1, resultList, resultIndex + 1, result); } } + } diff --git a/hutool-core/src/test/java/cn/hutool/core/math/CombinationTest.java b/hutool-core/src/test/java/cn/hutool/core/math/CombinationTest.java index 3cc6d84613..134bcb39a0 100644 --- a/hutool-core/src/test/java/cn/hutool/core/math/CombinationTest.java +++ b/hutool-core/src/test/java/cn/hutool/core/math/CombinationTest.java @@ -1,10 +1,12 @@ package cn.hutool.core.math; -import static org.junit.jupiter.api.Assertions.*; import org.junit.jupiter.api.Test; +import java.math.BigInteger; import java.util.List; +import static org.junit.jupiter.api.Assertions.*; + /** * 组合单元测试 * @@ -51,4 +53,96 @@ public class CombinationTest { List