这是一个再简单不过的组合问题:
编号 0-9 的 10 个球里面拿取任意 5 个,有多少种不同的组合?
答案是可以用公式算出来的,也就是 (10!) / ((5!) ^ 2) = 252
个。但是如果要把它们全部遍历出来呢?
下面是一种效率比较高的遍历方式,原理是将所有结果集看作是树节点(准确的说是叶子节点),然后去遍历这棵树即可。树的生成规则是:
- 一级节点的值是第一个球的编号,二级节点是第二个球的编号,依此类推;
- 任何一级节点的值必须大于上级节点的值。
这样能做到所有的叶子节点刚好覆盖所有的解,没有多余没有缺失。
如何用多线程遍历这棵树呢?按一级节点不同的值,分别放到线程里面遍历即可。每个节点代表一个子树,先计算该树的起始和终止节点,作为解空间的边界,然后从起始节点开始遍历直到终止节点为止即可。
代码如下:
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.IntStream;
/**
* 多线程遍历组合树
*/
public class CombinationIterator {
public static void main(String[] args) throws Exception {
int itemCount = 50;
int pickCount = 10;
AtomicLong answerCount = new AtomicLong();
long start = System.currentTimeMillis();
// 根据一级节点拆分解空间
int[] level1Values = IntStream.range(0, itemCount - pickCount + 1).toArray();
ExecutorService threadPool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
// 对拆分的解空间多线程遍历
for (int level1Value : level1Values) {
// 第一个解和最后一个解
int[] firstPicks = IntStream.range(level1Value, pickCount + level1Value).toArray();
int[] lastPicks =
IntStream.concat(
IntStream.range(level1Value, level1Value + 1),
IntStream.range(itemCount - pickCount + 1, itemCount)
).toArray();
// 遍历解区间
threadPool.submit(() -> answerCount.addAndGet(iterateSubTree(firstPicks, lastPicks)));
}
threadPool.shutdown();
threadPool.awaitTermination(1, TimeUnit.HOURS);
System.out.printf("%d 取 %d 的组合遍历完成,共有 %d 个解。%n", itemCount, pickCount, answerCount.get());
System.out.printf("执行时间 %dms", (System.currentTimeMillis() - start));
}
/**
* 遍历区间的组合
*
* @param firstPicks 区间的第一个解
* @param lastPicks 区间的最后一个解
*
* @return 区间的解数量
*/
private static long iterateSubTree(int[] firstPicks, int[] lastPicks) {
long counter = 0;
int[] picks = firstPicks;
do {
if (picks != null) {
// System.out.println(Arrays.toString(picks));
counter++;
}
picks = getNextPick(picks, lastPicks);
} while (picks != null);
System.out.println("区间 " + Arrays.toString(lastPicks) + " 遍历完成,共 " + counter + " 个解");
return counter;
}
/**
* 根据当前解计算下一个解,直到遇到最终解,则返回 null
*
* @param picks 当前解
* @param lastPicks 最终解
*
* @return 下一个解或 null
*/
private static int[] getNextPick(int[] picks, int[] lastPicks) {
if (Arrays.equals(picks, lastPicks)) {
return null;
}
int[] nextPick = Arrays.copyOf(picks, picks.length);
int movable = findMovable(nextPick, lastPicks);
nextPick[movable]++;
if (movable != nextPick.length - 1) {
// 将 movable 后面的点移回到贴紧 movable 的右侧
partialReset(nextPick, movable);
}
return nextPick;
}
// 在 picks 中寻找第一个可以右移的点
private static int findMovable(int[] picks, int[] lastPicks) {
for (int i = picks.length - 1; i >= 0; i--) {
if (picks[i] < lastPicks[i]) {
return i;
}
}
return -1; // 实际上不会返回 -1
}
// 指定位置后面的点都移回到贴紧该位置的右侧
private static void partialReset(int[] picks, int resetStart) {
for (int i = resetStart + 1; i < picks.length; i++) {
picks[i] = picks[i - 1] + 1;
}
}
}
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。