CountDownLatch使用原理

先看一个例子,来说明下CountDownLatch运用场景

public class SpiderDemo {
	private static Logger logger = LoggerFactory.getLogger(SpiderDemo.class);
	private static ExecutorService spiderPool = Executors.newFixedThreadPool(5);

	public static void main(String[] args) {

		CountDownLatch countDown = new CountDownLatch(5);
		for (int i=0;i<5;i++) {
			spiderPool.execute(new SpiderThread(countDown, i));
		}

		try {
			countDown.await();
		} catch (InterruptedException e) {
			e.printStackTrace();
		}

		logger.info("spider end");
		spiderPool.shutdown();
	}

	static class SpiderThread implements Runnable{
		private CountDownLatch countDown;
		private int index;

		public SpiderThread(CountDownLatch countDown, int index) {
			this.countDown = countDown;
			this.index = index;
		}

		@Override
		public void run() {
			// TODO Auto-generated method stub
			logger.info("spider"+index+" data begin.. ");
			try {
				Thread.sleep(3000);
			} catch (InterruptedException e) {
				e.printStackTrace();
			}

			logger.info("spider"+index+" data complete.. ");

			countDown.countDown();
		}

	}

}

我们来看看上面的例子,这里创建了一个CountDownLatch对象,传入一个数值,使用5个线程去执行SpiderThread,结束后使用countDown()方法将计数器递减,最后使用await()方法堵塞,直到所有的线程执行完成。

所以CountDownLatch的使用场景是可以异步去执行多个任务,主线程可以等待所有任务执行完成后,再继续执行。

下面从源码方面分析下CountDownLatch的实现原理

首先查看构造方法CountDownLatch(int count )

    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

	Sync(int count) {
            setState(count);
        }

当count小于0的时候,直接抛出异常,大于0则创建了一个Sync对象。Sync构造函数则是直接设置了一个变量值state来保存count数组

再看看countDown()方法,很显然它的目的是将count数组-1

    public void countDown() {
        sync.releaseShared(1);
    }

    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }

它里面是直接调用了releaseShared(1)方法,而releaseShared又调用了tryReleaseShared方法,如果尝试执行成功,则再调用doReleaseShared()方法

	protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }

tryReleaseShared是干什么的呢?从上面的例子看它的作用就是将count值-1,如果当前count值就是0了,那就啥都不干,直接返回false了,否则使用CAS算法尝试将count值-1,最后判断当前的count值是否为0,若为0则返回true,返回true就需要执行doReleaseShared方法了

    private void doReleaseShared() {
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

doReleaseShared又干了些啥?doReleaseShared的作用就是唤醒主线程继续执行,这里调用了unparkSuccessor方法

最后来看下await()方法

    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }

    private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

可以看到最终会调用doAcquireSharedInterruptibly(若中途线程有被打断,则抛出异常了)

首先会添加一个Node节点(Node.SHARED看源码实质就是new Node()),这个节点就是保存所有调用await()方法的线程(因为可能会有多个线程调用了await()方法),所以这里每调用一次就保存下来一个节点,最后一个一个的释放,如果这里判断count不为0,则会使用park方法将线程挂起。

 

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注