mirror of
https://gitee.com/dromara/hutool.git
synced 2026-02-09 09:16:26 +08:00
feat(Combination): add BigInteger-based combination calculation methods and deprecate old count method
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
package cn.hutool.core.math;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.math.BigInteger;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import cn.hutool.core.util.NumberUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
|
||||
/**
|
||||
* 组合,即C(n, m)<br>
|
||||
* 排列组合相关类 参考:http://cgs1999.iteye.com/blog/2327664
|
||||
@@ -32,18 +32,71 @@ public class Combination implements Serializable {
|
||||
|
||||
/**
|
||||
* 计算组合数,即C(n, m) = n!/((n-m)! * m!)
|
||||
*
|
||||
* <p>注意:此方法内部使用 BigInteger 修复了旧版 factorial 的计算错误,
|
||||
* 但最终仍以 long 返回,因此当结果超过 long 范围时仍会溢出。</p>
|
||||
* <p>建议使用 {@link #countBig(int, int)} 获取精确结果,或使用
|
||||
* {@link #countSafe(int, int)} 获取安全 long 版本。</p>
|
||||
* @param n 总数
|
||||
* @param m 选择的个数
|
||||
* @return 组合数
|
||||
*/
|
||||
@Deprecated
|
||||
public static long count(int n, int m) {
|
||||
if (0 == m || n == m) {
|
||||
return 1;
|
||||
}
|
||||
return (n > m) ? NumberUtil.factorial(n, n - m) / NumberUtil.factorial(m) : 0;
|
||||
BigInteger big = countBig(n, m);
|
||||
return big.longValue();
|
||||
}
|
||||
|
||||
/**
|
||||
* 计算组合数 C(n, m) 的 BigInteger 精确版本。
|
||||
* 使用逐步累乘除法(非阶乘)保证不溢出、性能好。
|
||||
* <p>
|
||||
* 数学定义:
|
||||
* C(n, m) = n! / (m! (n - m)!)
|
||||
* <p>
|
||||
* 优化方式:
|
||||
* 1. 利用对称性 m = min(m, n-m)
|
||||
* 2. 每一步先乘 BigInteger,再除以当前 i,保证数值不暴涨
|
||||
*
|
||||
* @param n 总数 n(必须 >= 0)
|
||||
* @param m 取出 m(必须 >= 0)
|
||||
* @return C(n, m) 的 BigInteger 精确值;当 m > n 时返回 BigInteger.ZERO
|
||||
*/
|
||||
public static BigInteger countBig(int n, int m) {
|
||||
if (n < 0 || m < 0) {
|
||||
throw new IllegalArgumentException("n and m must be non-negative. got n=" + n + ", m=" + m);
|
||||
}
|
||||
if (m > n) {
|
||||
return BigInteger.ZERO;
|
||||
}
|
||||
if (m == 0 || n == m) {
|
||||
return BigInteger.ONE;
|
||||
}
|
||||
// 使用对称性:C(n, m) = C(n, n-m)
|
||||
m = Math.min(m, n - m);
|
||||
BigInteger result = BigInteger.ONE;
|
||||
// 从 1 → m 累乘
|
||||
for (int i = 1; i <= m; i++) {
|
||||
int numerator = n - m + i;
|
||||
result = result.multiply(BigInteger.valueOf(numerator))
|
||||
.divide(BigInteger.valueOf(i));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 安全组合数 long 版本。
|
||||
*
|
||||
* @param n 总数 n(必须 >= 0)
|
||||
* @param m 取出 m(必须 >= 0)
|
||||
* <p>若结果超出 long 范围,会抛 ArithmeticException,而非溢出。</p>
|
||||
*/
|
||||
public static long countSafe(int n, int m) {
|
||||
BigInteger big = countBig(n, m);
|
||||
return big.longValueExact();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 计算组合总数,即C(n, 1) + C(n, 2) + C(n, 3)...
|
||||
*
|
||||
@@ -104,4 +157,5 @@ public class Combination implements Serializable {
|
||||
select(i + 1, resultList, resultIndex + 1, result);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package cn.hutool.core.math;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
/**
|
||||
* 组合单元测试
|
||||
*
|
||||
@@ -51,4 +53,96 @@ public class CombinationTest {
|
||||
List<String[]> list2 = combination.select(0);
|
||||
assertEquals(1, list2.size());
|
||||
}
|
||||
|
||||
|
||||
// -----------------------------
|
||||
// countBig() 正确性测试
|
||||
// -----------------------------
|
||||
@Test
|
||||
void testCountBig_basicCases() {
|
||||
assertEquals(BigInteger.ONE, Combination.countBig(5, 0));
|
||||
assertEquals(BigInteger.ONE, Combination.countBig(5, 5));
|
||||
assertEquals(BigInteger.valueOf(10), Combination.countBig(5, 3));
|
||||
assertEquals(BigInteger.valueOf(10), Combination.countBig(5, 2));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCountBig_mGreaterThanN() {
|
||||
assertEquals(BigInteger.ZERO, Combination.countBig(5, 6));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCountBig_negativeInput() {
|
||||
assertThrows(IllegalArgumentException.class, () -> Combination.countBig(-1, 3));
|
||||
assertThrows(IllegalArgumentException.class, () -> Combination.countBig(5, -2));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCountBig_symmetry() {
|
||||
assertEquals(Combination.countBig(20, 3), Combination.countBig(20, 17));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCountBig_largeNumbers() {
|
||||
// C(50, 3) = 19600
|
||||
assertEquals(new BigInteger("19600"), Combination.countBig(50, 3));
|
||||
|
||||
// C(100, 50) 的确切值(重要测试)
|
||||
BigInteger expected = new BigInteger(
|
||||
"100891344545564193334812497256"
|
||||
);
|
||||
assertEquals(expected, Combination.countBig(100, 50));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCountBig_veryLargeCombination() {
|
||||
// 不比较具体值,只断言不要抛错
|
||||
BigInteger result = Combination.countBig(2000, 1000);
|
||||
assertTrue(result.signum() > 0);
|
||||
}
|
||||
|
||||
// -----------------------------
|
||||
// count(long) 兼容性测试
|
||||
// -----------------------------
|
||||
@Test
|
||||
void testCount_basic() {
|
||||
assertEquals(10L, Combination.count(5, 3));
|
||||
assertEquals(1L, Combination.count(5, 0));
|
||||
assertEquals(0L, Combination.count(5, 6));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCount_overflowBehavior() {
|
||||
// C(100, 50) 远超 long 范围,但旧版行为要求不抛异常
|
||||
long r = Combination.count(100, 50);
|
||||
|
||||
// longValue() 不抛异常,并且可能溢出
|
||||
assertNotNull(r);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCount_noException() {
|
||||
assertDoesNotThrow(() -> Combination.count(5000, 2500));
|
||||
}
|
||||
|
||||
// -----------------------------
|
||||
// countSafe() 安全 long 版本测试
|
||||
// -----------------------------
|
||||
@Test
|
||||
void testCountSafe_exactFitsLong() {
|
||||
// C(50, 3) = 19600 fits long
|
||||
assertEquals(19600L, Combination.countSafe(50, 3));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCountSafe_overflowThrows() {
|
||||
// C(100, 50) 超出 long → 应抛 ArithmeticException
|
||||
assertThrows(ArithmeticException.class, () -> Combination.countSafe(100, 50));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCountSafe_invalidInput() {
|
||||
assertThrows(IllegalArgumentException.class, () -> Combination.countSafe(-1, 3));
|
||||
assertThrows(IllegalArgumentException.class, () -> Combination.countSafe(3, -1));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user