Preface
This article walks through a simplified implementation of a work-stealing thread pool.
Executors
The Executors class ships with several factory methods by default; the work-stealing variants look like this:
/**
* Creates a thread pool that maintains enough threads to support
* the given parallelism level, and may use multiple queues to
* reduce contention. The parallelism level corresponds to the
* maximum number of threads actively engaged in, or available to
* engage in, task processing. The actual number of threads may
* grow and shrink dynamically. A work-stealing pool makes no
* guarantees about the order in which submitted tasks are
* executed.
*
* @param parallelism the targeted parallelism level
* @return the newly created thread pool
* @throws IllegalArgumentException if {@code parallelism <= 0}
* @since 1.8
*/
public static ExecutorService newWorkStealingPool(int parallelism) {
    return new ForkJoinPool
        (parallelism,
         ForkJoinPool.defaultForkJoinWorkerThreadFactory,
         null, true);
}
/**
* Creates a work-stealing thread pool using all
* {@link Runtime#availableProcessors available processors}
* as its target parallelism level.
* @return the newly created thread pool
* @see #newWorkStealingPool(int)
* @since 1.8
*/
public static ExecutorService newWorkStealingPool() {
    return new ForkJoinPool
        (Runtime.getRuntime().availableProcessors(),
         ForkJoinPool.defaultForkJoinWorkerThreadFactory,
         null, true);
}
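As a quick illustration (not from the original article), here is a minimal usage sketch of such a pool; the squaring task and the class name are purely illustrative assumptions:

import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class WorkStealingPoolUsage {
    public static void main(String[] args) throws Exception {
        ExecutorService pool = Executors.newWorkStealingPool();
        // ten independent tasks; the pool makes no guarantee about execution order
        List<Callable<Integer>> tasks = IntStream.rangeClosed(1, 10)
                .mapToObj(i -> (Callable<Integer>) () -> i * i)
                .collect(Collectors.toList());
        // invokeAll blocks until every task has completed
        for (Future<Integer> f : pool.invokeAll(tasks)) {
            System.out.println(f.get());
        }
        pool.shutdown();
    }
}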
Approach
ForkJoinPool is built around double-ended queues (deques). For the rough implementation here we do not need a hand-rolled work-stealing deque; the JDK's LinkedBlockingDeque is good enough.
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadLocalRandom;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.util.concurrent.AtomicLongMap;

public class WorkStealingChannel<T> {

    private static final Logger LOGGER = LoggerFactory.getLogger(WorkStealingChannel.class);

    // one blocking deque per "lane": the owner takes from the head, thieves steal from the tail
    BlockingDeque<T>[] managedQueues;

    // per-queue counter of how many takes each queue served (Guava)
    AtomicLongMap<Integer> stat = AtomicLongMap.create();

    @SuppressWarnings("unchecked")
    public WorkStealingChannel() {
        int nCPU = Runtime.getRuntime().availableProcessors();
        int queueCount = nCPU / 2 + 1;
        managedQueues = new LinkedBlockingDeque[queueCount];
        for (int i = 0; i < queueCount; i++) {
            managedQueues[i] = new LinkedBlockingDeque<T>();
        }
    }

    public void put(T item) throws InterruptedException {
        // floorMod avoids a negative index when hashCode() is negative
        int targetIndex = Math.floorMod(item.hashCode(), managedQueues.length);
        managedQueues[targetIndex].put(item);
    }

    public T take() throws InterruptedException {
        int rdnIdx = ThreadLocalRandom.current().nextInt(managedQueues.length);
        // exactly one pass over all queues: the randomly chosen "own" queue is
        // polled from the head, every other queue is stolen from at the tail
        for (int i = 0; i < managedQueues.length; i++) {
            int idx = (rdnIdx + i) % managedQueues.length;
            T item = (idx == rdnIdx) ? managedQueues[idx].poll() : managedQueues[idx].pollLast();
            if (item != null) {
                LOGGER.info("take ele from queue {}", idx);
                stat.addAndGet(idx, 1);
                return item;
            }
        }
        // the whole round came up empty: block on the randomly chosen queue
        LOGGER.info("wait for queue:{}", rdnIdx);
        stat.addAndGet(rdnIdx, 1);
        return managedQueues[rdnIdx].take();
    }

    public AtomicLongMap<Integer> getStat() {
        return stat;
    }
}
Here we create a handful of deques based on the CPU count. On each put, the item's hashCode is taken modulo the queue count to pick a target queue. On take, a random queue is chosen and polled from its head first; if that is empty, the other queues are tried round-robin and stolen from at the tail; if a full round still yields nothing, the call blocks waiting on the randomly chosen queue.
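To make the head/tail convention concrete, here is a minimal sketch (not part of the original code) of how LinkedBlockingDeque behaves: put appends at the tail, poll removes from the head, and pollLast removes from the tail.

import java.util.concurrent.BlockingDeque;
import java.util.concurrent.LinkedBlockingDeque;

public class DequeStealSketch {
    public static void main(String[] args) throws InterruptedException {
        BlockingDeque<String> deque = new LinkedBlockingDeque<>();
        deque.put("a");                   // put appends at the tail
        deque.put("b");
        String own = deque.poll();        // the owner takes "a" from the head
        String stolen = deque.pollLast(); // a thief steals "b" from the tail
        System.out.println(own + " / " + stolen);
    }
}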
Test example
import java.util.UUID;

public class WorkStealingDemo {

    static final WorkStealingChannel<String> channel = new WorkStealingChannel<>();

    static volatile boolean running = true;

    static class Producer extends Thread {
        @Override
        public void run() {
            while (running) {
                try {
                    channel.put(UUID.randomUUID().toString());
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    static class Consumer extends Thread {
        @Override
        public void run() {
            while (running) {
                try {
                    String value = channel.take();
                    System.out.println(value);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    public static void stop() {
        running = false;
        System.out.println(channel.getStat());
    }

    public static void main(String[] args) throws InterruptedException {
        int nCPU = Runtime.getRuntime().availableProcessors();
        int consumerCount = nCPU / 2 + 1;
        // one producer per CPU, one consumer per queue
        for (int i = 0; i < nCPU; i++) {
            new Producer().start();
        }
        for (int i = 0; i < consumerCount; i++) {
            new Consumer().start();
        }
        Thread.sleep(30 * 1000);
        stop();
    }
}
Output
{0=660972, 1=660613, 2=661537, 3=659846, 4=659918}
Judging by these numbers, the load across the queues is fairly even.