CountDownLatch: How It Works and How to Use It

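An interview question that trips candidates up

The interviewer asks: "What is CountDownLatch? What can you use it for?"

The candidate, Xiao Zhang, answers: "It's a countdown counter. Once it counts down to 0, all the waiting threads can continue."

The interviewer follows up: "How is it different from CyclicBarrier?"

Xiao Zhang: "Uh... CyclicBarrier can be reused?"

The interviewer presses on: "Can countDown() on a CountDownLatch be called more than once?"

Xiao Zhang hesitates: "I think... so?"

The question looks simple, but few people answer it well. CountDownLatch is one of the most commonly used tools in the JDK concurrency package (java.util.concurrent), and understanding how it works also helps when learning AQS (AbstractQueuedSynchronizer) and CyclicBarrier.

This article covers CountDownLatch in depth.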

What is CountDownLatch

Basic concept

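import java.util.concurrent.CountDownLatch;

public class CountDownLatchDemo {
    public static void main(String[] args) throws InterruptedException {
        // Create a latch with an initial count of 3
        CountDownLatch latch = new CountDownLatch(3);

        // Start 3 worker threads
        for (int i = 0; i < 3; i++) {
            new Thread(() -> {
                try {
                    // Simulate work
                    Thread.sleep((long) (Math.random() * 1000));
                    System.out.println(Thread.currentThread().getName() + " done");
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();  // restore the interrupt flag
                } finally {
                    latch.countDown();  // decrement the count by 1
                }
            }).start();
        }

        // The main thread waits until the count reaches zero
        latch.await();
        System.out.println("All tasks finished; main thread continues");
    }
}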

Use cases

// Scenario 1: wait for several services to start
public class ServiceStarter {
    public void startAll() throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(3);

        new Thread(() -> startDatabase(latch), "DB").start();
        new Thread(() -> startCache(latch), "Cache").start();
        new Thread(() -> startMessageQueue(latch), "MQ").start();

        // Wait until every service has started
        latch.await();
        System.out.println("All services started; the system can begin handling requests");
    }

    private void startDatabase(CountDownLatch latch) {
        // Start the database
        try {
            Thread.sleep(2000);  // simulate startup time
            System.out.println("Database started");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();  // restore the interrupt flag
        } finally {
            latch.countDown();
        }
    }

    // startCache and startMessageQueue follow the same pattern...
}

How CountDownLatch works

Built on the shared mode of AQS

// CountDownLatch uses the shared mode of AQS (simplified from the JDK source)
public class CountDownLatch {
    private final Sync sync;

    // Inner class: extends AQS
    private static final class Sync extends AbstractQueuedSynchronizer {
        Sync(int count) {
            setState(count);  // state = current count
        }

        // Shared-mode acquire: tryAcquireShared
        // Negative return value: acquire failed, the caller must wait
        // Zero or positive return value: acquire succeeded
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        // Shared-mode release: tryReleaseShared
        protected boolean tryReleaseShared(int releases) {
            for (;;) {
                int c = getState();
                if (c == 0) {
                    return false;  // already 0, nothing left to decrement
                }
                int nextc = c - 1;
                if (compareAndSetState(c, nextc)) {
                    return nextc == 0;  // true only for the call that reaches 0
                }
            }
        }
    }
}
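The effect of the `c == 0` check in tryReleaseShared can be observed from the outside through the public getCount() method. The following is a minimal sketch, not part of the JDK: the class name LatchStateDemo and the method observeCounts are made up for illustration. It calls countDown() more often than the initial count and records the count after every call.

```java
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;

// Hypothetical demo class; only CountDownLatch itself is a JDK API.
class LatchStateDemo {
    // Calls countDown() `calls` times and records getCount() after each call.
    static long[] observeCounts(int initial, int calls) {
        CountDownLatch latch = new CountDownLatch(initial);
        long[] seen = new long[calls];
        for (int i = 0; i < calls; i++) {
            latch.countDown();           // has no effect once the count is already 0
            seen[i] = latch.getCount();  // reads the current count
        }
        return seen;
    }

    public static void main(String[] args) {
        // Initial count 2, four countDown() calls
        System.out.println(Arrays.toString(observeCounts(2, 4)));  // prints [1, 0, 0, 0]
    }
}
```

The count stops at 0 and never becomes negative, so extra countDown() calls are harmless.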

How await() is implemented

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

// AQS implementation (simplified, JDK 8-style)
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted()) {
        throw new InterruptedException();
    }
    if (tryAcquireShared(arg) < 0) {
        // Acquire failed: join the wait queue
        doAcquireSharedInterruptibly(arg);
    }
}

private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
    // Append a shared-mode node to the tail of the queue
    Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    // Become the new head and pass the wake-up on to the next waiter
                    setHeadAndPropagate(node, r);
                    p.next = null;  // help GC
                    failed = false;
                    return;
                }
            }
            // Check whether to block, then park
            if (shouldParkAfterFailedAcquire(p, node)) {
                LockSupport.park(this);
                if (Thread.interrupted()) {
                    throw new InterruptedException();
                }
            }
        }
    } finally {
        if (failed) {
            cancelAcquire(node);
        }
    }
}
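Because await() goes through acquireSharedInterruptibly, a thread blocked on the latch responds to interruption. A minimal sketch of that behavior follows; the class InterruptibleAwaitDemo and its method name are hypothetical, chosen only for this example. A second latch, `started`, is used so the interrupt is sent only after the waiter thread is running.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical demo class, not taken from the article or the JDK.
class InterruptibleAwaitDemo {
    // Returns true if await() ended with an InterruptedException.
    static boolean awaitThrowsOnInterrupt() {
        CountDownLatch never = new CountDownLatch(1);    // nobody ever counts this down
        CountDownLatch started = new CountDownLatch(1);  // signals that the waiter is running
        AtomicBoolean sawInterrupt = new AtomicBoolean(false);

        Thread waiter = new Thread(() -> {
            started.countDown();
            try {
                never.await();              // blocks until interrupted
            } catch (InterruptedException e) {
                sawInterrupt.set(true);
            }
        });
        waiter.start();

        try {
            started.await();
            waiter.interrupt();             // await() throws whether or not the waiter has parked yet
            waiter.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return sawInterrupt.get();
    }

    public static void main(String[] args) {
        System.out.println(awaitThrowsOnInterrupt());
    }
}
```

The interrupt is honored in both cases: if it arrives before await() is entered, the up-front Thread.interrupted() check throws; if it arrives while the thread is parked, the check after park() throws.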

How countDown() is implemented

public void countDown() {
    sync.releaseShared(1);
}

// AQS implementation (simplified, JDK 8-style)
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        // The count just reached 0: wake the waiting threads
        doReleaseShared();
        return true;
    }
    return false;
}

private void doReleaseShared() {
    for (;;) {
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                if (!h.compareAndSetWaitStatus(ws, 0)) {
                    continue;  // CAS lost a race: loop and recheck
                }
                unparkSuccessor(h);  // wake the successor node
            } else if (ws == 0 &&
                       !h.compareAndSetWaitStatus(0, Node.PROPAGATE)) {
                continue;  // CAS failed: retry
            }
        }
        if (h == head) {  // head unchanged: done; otherwise loop again
            break;
        }
    }
}
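One consequence of the shared-mode release is that a single countDown() that brings the count to 0 lets every waiting thread through, not just one. The sketch below illustrates this with a "start gate"; the class StartGateDemo and the method releaseAll are hypothetical names introduced for the example.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo class: several threads wait on one latch with count 1.
class StartGateDemo {
    // Starts `workers` threads that wait on a gate, opens the gate once,
    // and returns how many threads got past it.
    static int releaseAll(int workers) {
        CountDownLatch startGate = new CountDownLatch(1);
        CountDownLatch done = new CountDownLatch(workers);
        AtomicInteger passed = new AtomicInteger();

        for (int i = 0; i < workers; i++) {
            new Thread(() -> {
                try {
                    startGate.await();          // blocks while the gate is closed
                    passed.incrementAndGet();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    done.countDown();
                }
            }).start();
        }

        startGate.countDown();                  // one call opens the gate for everyone
        try {
            done.await();                       // wait for all workers to finish
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return passed.get();
    }

    public static void main(String[] args) {
        System.out.println(releaseAll(5) + " workers passed the gate");
    }
}
```

Workers that reach await() after the gate is already open simply return immediately, so the result does not depend on thread timing.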

One-shot vs. reusable

A one-shot barrier

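public class OneTimeBarrier {
    public void demo() throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(1);

        latch.countDown();  // count goes 1 -> 0; the latch is now open for good
        latch.await();      // returns immediately because the count is 0

        // The count cannot be raised again, so every later await() also
        // returns immediately. Had countDown() never been called, the first
        // await() would block forever.
        latch.await();
    }
}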

❌ Common mistakes

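public class CommonMistakes {
    public void mistake1_reuse() throws InterruptedException {
        // ❌ Mistake: a CountDownLatch cannot be reset
        CountDownLatch latch = new CountDownLatch(3);

        // First use
        for (int i = 0; i < 3; i++) {
            new Thread(() -> latch.countDown()).start();
        }
        latch.await();

        // Second use: does not work!
        latch.await();  // returns immediately because the count is still 0
    }

    public void mistake2_forgetCountDown() throws InterruptedException {
        // ❌ Mistake: forgetting countDown() leaves the waiter blocked forever
        CountDownLatch latch = new CountDownLatch(1);

        new Thread(() -> {
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            // latch.countDown() is missing here; it belongs in a finally block
        }).start();

        latch.await();  // waits forever!
    }
}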
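If the same wait-for-N-tasks step has to run repeatedly and CountDownLatch is still the tool of choice, one workable approach is to create a fresh latch for every round. The following is a rough sketch under that assumption; FreshLatchPerRound and runRounds are names invented for this example.

```java
import java.util.concurrent.CountDownLatch;

// Hypothetical demo class: one new CountDownLatch per round instead of reuse.
class FreshLatchPerRound {
    // Runs `rounds` batches of `tasks` threads and returns the number of
    // rounds whose await() completed normally.
    static int runRounds(int rounds, int tasks) {
        int completedRounds = 0;
        for (int r = 0; r < rounds; r++) {
            // A new instance starts at `tasks` again; the old one stays at 0.
            CountDownLatch latch = new CountDownLatch(tasks);
            for (int t = 0; t < tasks; t++) {
                new Thread(latch::countDown).start();
            }
            try {
                latch.await();
                completedRounds++;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                break;
            }
        }
        return completedRounds;
    }

    public static void main(String[] args) {
        System.out.println(runRounds(2, 3) + " rounds completed");
    }
}
```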

The right alternative: CyclicBarrier

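public class CyclicBarrierAlternative {
    public void correctReuse() {
        // ✅ If you need reuse, use CyclicBarrier
        CyclicBarrier barrier = new CyclicBarrier(3);

        // First use
        for (int i = 0; i < 3; i++) {
            new Thread(() -> {
                try {
                    barrier.await();
                } catch (InterruptedException | BrokenBarrierException e) {
                    System.err.println("Barrier wait failed: " + e);
                }
            }).start();
        }

        // The barrier resets itself each time it trips, so it can be used again
        for (int i = 0; i < 3; i++) {
            new Thread(() -> {
                try {
                    barrier.await();
                } catch (InterruptedException | BrokenBarrierException e) {
                    System.err.println("Barrier wait failed: " + e);
                }
            }).start();
        }
    }
}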
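To see the reuse more directly, the sketch below lets the same threads pass through one CyclicBarrier several times and counts the trips with a barrier action (the optional Runnable accepted by the two-argument CyclicBarrier constructor). The class BarrierReuseDemo and the method runRounds are hypothetical names for this illustration.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo class: counts how often one barrier trips.
class BarrierReuseDemo {
    // Runs `rounds` rounds on a single barrier and returns the number of trips.
    static int runRounds(int parties, int rounds) {
        AtomicInteger trips = new AtomicInteger();
        // The barrier action runs once per trip, before the waiting threads are released.
        CyclicBarrier barrier = new CyclicBarrier(parties, trips::incrementAndGet);

        Thread[] threads = new Thread[parties];
        for (int i = 0; i < parties; i++) {
            threads[i] = new Thread(() -> {
                try {
                    for (int r = 0; r < rounds; r++) {
                        barrier.await();   // same barrier object in every round
                    }
                } catch (InterruptedException | BrokenBarrierException e) {
                    System.err.println("Barrier wait failed: " + e);
                }
            });
            threads[i].start();
        }
        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        return trips.get();
    }

    public static void main(String[] args) {
        System.out.println("Barrier tripped " + runRounds(3, 2) + " times");
    }
}
```

No explicit reset() call is needed here: the barrier starts a new generation on its own after each trip.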

Waiting with a timeout

await with a timeout

public class TimedAwaitDemo {
    public void awaitWithTimeout() throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(3);

        // Wait for at most 5 seconds
        boolean completed = latch.await(5, TimeUnit.SECONDS);

        if (completed) {
            System.out.println("All tasks finished within 5 seconds");
        } else {
            System.out.println("Timed out waiting; continuing anyway");
        }
    }
}

// Practical use: service startup with a timeout
public class ServiceStartup {
    public boolean startWithTimeout(long timeout, TimeUnit unit) 
            throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(3);

        // Start the services (startServices is not shown; it is assumed to
        // launch three tasks that each call latch.countDown() when done)
        startServices(latch);

        // Wait for startup to finish, for at most `timeout`
        return latch.await(timeout, unit);
    }
}
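When a timed await returns false, getCount() can tell you how many tasks were still outstanding, which is handy for logging. A small sketch follows, assuming a helper of our own named pendingAfter in a made-up class TimedAwaitReport; neither is a JDK API.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

// Hypothetical helper class built on await(timeout, unit) and getCount().
class TimedAwaitReport {
    // Waits up to timeoutMillis and returns the number of tasks still pending
    // (0 means the latch opened in time).
    static long pendingAfter(CountDownLatch latch, long timeoutMillis) {
        try {
            if (latch.await(timeoutMillis, TimeUnit.MILLISECONDS)) {
                return 0;
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();  // restore the flag and fall through
        }
        return latch.getCount();  // a snapshot; the count may still drop afterwards
    }

    public static void main(String[] args) {
        CountDownLatch latch = new CountDownLatch(3);
        latch.countDown();
        latch.countDown();
        System.out.println("Pending tasks: " + pendingAfter(latch, 50));  // prints "Pending tasks: 1"
    }
}
```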

Real-world applications in production

Scenario 1: parallel computation

public class ParallelComputation {
    public int parallelSum(int[] array) throws InterruptedException {
        int n = Runtime.getRuntime().availableProcessors();
        int chunkSize = array.length / n;
        
        CountDownLatch latch = new CountDownLatch(n);
        int[] results = new int[n];
        
        for (int i = 0; i < n; i++) {
            final int start = i * chunkSize;
            final int end = (i == n - 1) ? array.length : start + chunkSize;
            final int index = i;
            
            new Thread(() -> {
                int sum = 0;
                for (int j = start; j < end; j++) {
                    sum += array[j];
                }
                results[index] = sum;
                latch.countDown();
            }).start();
        }
        
        latch.await();
        
        // Merge the partial results
        int total = 0;
        for (int result : results) {
            total += result;
        }
        return total;
    }
}
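Production code usually runs such work on a thread pool rather than on freshly created threads. The following variant is a sketch under that assumption, not the article's original method: it swaps `new Thread` for an ExecutorService, accumulates into `long` to reduce the overflow risk, and moves countDown() into a finally block. The class PooledParallelSum and its parameters are hypothetical.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical variant of the parallel sum that uses a fixed thread pool.
class PooledParallelSum {
    // Splits the array into `workers` chunks and sums them in parallel.
    // If the caller is interrupted while waiting, the result may be incomplete.
    static long sum(int[] array, int workers) {
        ExecutorService pool = Executors.newFixedThreadPool(workers);
        CountDownLatch latch = new CountDownLatch(workers);
        long[] partial = new long[workers];
        int chunk = array.length / workers;
        try {
            for (int i = 0; i < workers; i++) {
                final int index = i;
                final int start = i * chunk;
                // The last worker also takes the remainder
                final int end = (i == workers - 1) ? array.length : start + chunk;
                pool.execute(() -> {
                    try {
                        long s = 0;
                        for (int j = start; j < end; j++) {
                            s += array[j];
                        }
                        partial[index] = s;
                    } finally {
                        latch.countDown();  // always signal, even if the task fails
                    }
                });
            }
            latch.await();  // writes to `partial` are visible after this returns
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            pool.shutdown();
        }
        long total = 0;
        for (long p : partial) {
            total += p;
        }
        return total;
    }

    public static void main(String[] args) {
        int[] data = new int[100];
        for (int i = 0; i < data.length; i++) {
            data[i] = i + 1;
        }
        System.out.println(sum(data, 4));  // prints 5050
    }
}
```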

Scenario 2: loading from multiple data sources

// User, Product and Order are assumed to be defined elsewhere
public class MultiDataSourceLoader {
    // ConcurrentHashMap, because three threads write to this map concurrently
    private final Map<String, Object> data = new ConcurrentHashMap<>();
    private final CountDownLatch latch = new CountDownLatch(3);

    public void loadAll() throws InterruptedException {
        new Thread(() -> {
            try {
                data.put("users", loadUsers());
            } finally {
                latch.countDown();
            }
        }, "LoadUsers").start();

        new Thread(() -> {
            try {
                data.put("products", loadProducts());
            } finally {
                latch.countDown();
            }
        }, "LoadProducts").start();

        new Thread(() -> {
            try {
                data.put("orders", loadOrders());
            } finally {
                latch.countDown();
            }
        }, "LoadOrders").start();

        latch.await();
        System.out.println("All data loaded: " + data.keySet());
    }
    
    private List<User> loadUsers() {
        return new ArrayList<>();
    }
    
    private List<Product> loadProducts() {
        return new ArrayList<>();
    }
    
    private List<Order> loadOrders() {
        return new ArrayList<>();
    }
}

Scenario 3: timeout control

public class TimeoutControl {
    public Response fetchWithTimeout(String url, long timeout) 
            throws Exception {
        CountDownLatch latch = new CountDownLatch(1);
        Response[] response = new Response[1];
        Exception[] error = new Exception[1];
        
        CompletableFuture.runAsync(() -> {
            try {
                response[0] = fetch(url);
            } catch (Exception e) {
                error[0] = e;
            } finally {
                latch.countDown();
            }
        });
        
        boolean completed = latch.await(timeout, TimeUnit.SECONDS);

        if (!completed) {
            throw new TimeoutException("Request timeout");
        }
        if (error[0] != null) {
            throw error[0];  // the fetch failed: surface the real cause, not a timeout
        }
        return response[0];
    }
    
    private Response fetch(String url) {
        return new Response();
    }
}
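Since the task above already runs through CompletableFuture, the latch can also be left out entirely: the future's own get(timeout, unit) reports a timeout and a failure as different exceptions. The sketch below shows that alternative technique, not the latch-based one from the article; the class FutureTimeoutSketch, the Supplier parameter, and the string-encoded result are all assumptions made to keep the example self-contained.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;

// Hypothetical alternative that relies on CompletableFuture.get(timeout, unit).
class FutureTimeoutSketch {
    // Returns "ok:<value>", "timeout", "failed:<message>" or "interrupted",
    // so the different outcomes stay distinguishable.
    static String fetchWithTimeout(Supplier<String> fetch, long timeoutMillis) {
        CompletableFuture<String> future = CompletableFuture.supplyAsync(fetch);
        try {
            return "ok:" + future.get(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            future.cancel(true);  // marks the future cancelled; the running task is not interrupted
            return "timeout";
        } catch (ExecutionException e) {
            return "failed:" + e.getCause().getMessage();  // the task itself threw
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return "interrupted";
        }
    }

    public static void main(String[] args) {
        System.out.println(fetchWithTimeout(() -> "body", 5000));
    }
}
```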

CountDownLatch vs CyclicBarrier

Comparison table

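| Feature | CountDownLatch | CyclicBarrier |
| --- | --- | --- |
| Resettable | ❌ No | ✅ Yes |
| Counting direction | Count decrements to 0 | Arrived threads accumulate up to parties |
| Who waits | Threads calling await() wait for the count to reach zero | Each thread waits for all the others to arrive |
| Reuse | Create a new instance | Resets automatically after each trip; reset() forces a reset |
| Typical scenario | Main thread waits for subtasks | Several threads wait for one another |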

Code comparison

// CountDownLatch: the main thread waits
public class CountDownLatchUsage {
    public void demo() throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(3);

        for (int i = 0; i < 3; i++) {
            final int id = i;  // a lambda can only capture effectively final variables
            new Thread(() -> {
                try {
                    Thread.sleep(100);
                    System.out.println("Task " + id + " done");
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    latch.countDown();
                }
            }).start();
        }

        latch.await();  // the main thread waits
        System.out.println("All done");
    }
}

// CyclicBarrier: threads wait for one another
public class CyclicBarrierUsage {
    public void demo() {
        CyclicBarrier barrier = new CyclicBarrier(3);

        for (int i = 0; i < 3; i++) {
            final int id = i;
            new Thread(() -> {
                try {
                    System.out.println("Task " + id + " ready");
                    barrier.await();  // wait for the other threads
                    System.out.println("Task " + id + " continue");
                } catch (InterruptedException | BrokenBarrierException e) {
                    System.err.println("Barrier wait failed: " + e);
                }
            }).start();
        }
    }
}

Common follow-up questions in interviews

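Follow-up 1: Can the count of a CountDownLatch be greater than 1?

Yes. The initial count can be any non-negative int, but each countDown() call decrements it by exactly 1. There is no method that subtracts N in one call, so call countDown() N times. Calls made after the count has reached 0 have no effect:

CountDownLatch latch = new CountDownLatch(5);
latch.countDown();      // 5 -> 4
latch.countDown();      // 4 -> 3
// There is no method that decrements by N in one call;
// invoke countDown() N times instead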
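If decrementing by several steps comes up often, a tiny helper can wrap the loop. The sketch below assumes a helper of our own, countDownBy in a made-up class LatchUtil; it is not part of the JDK.

```java
import java.util.concurrent.CountDownLatch;

// Hypothetical utility class; CountDownLatch has no built-in "decrement by N".
class LatchUtil {
    // Calls countDown() n times; calls beyond zero have no effect.
    static void countDownBy(CountDownLatch latch, int n) {
        for (int i = 0; i < n; i++) {
            latch.countDown();
        }
    }

    public static void main(String[] args) {
        CountDownLatch latch = new CountDownLatch(5);
        countDownBy(latch, 3);
        System.out.println(latch.getCount());  // prints 2
    }
}
```

The N decrements are not applied as a single atomic step, so another thread may observe intermediate counts; for a latch this rarely matters because only the transition to 0 releases waiters.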

Follow-up 2: Can countDown() be called before await()?

Yes. If the count is already 0, await() returns immediately:

CountDownLatch latch = new CountDownLatch(1);
latch.countDown();  // called first
latch.await();      // returns immediately

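Follow-up 3: Why is it called a "countdown latch"?

Because the count runs backwards, from N down to 0, just like a countdown. Once it reaches 0 the latch opens and stays open.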

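Follow-up 4: Can await() wait indefinitely?

Yes. The no-argument await() blocks until the count reaches zero or the waiting thread is interrupted.

latch.await();  // waits with no time limit
latch.await(5, TimeUnit.SECONDS);  // waits for at most 5 seconds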

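[Key takeaways]

  1. CountDownLatch: a one-shot countdown counter whose count decrements to 0
  2. Not resettable: once the count reaches zero it cannot be reset
  3. Built on the AQS shared mode: state holds the count, and tryAcquireShared checks whether it has reached zero
  4. Use cases: waiting for several tasks to finish, service startup, parallel computation
  5. countDown(): call it in a finally block so that a failed task cannot cause a permanent wait
  6. Timed waits: use await(timeout, unit) to guard against blocking forever
  7. Versus CyclicBarrier: CountDownLatch = the main thread waits for other threads to finish; CyclicBarrier = threads wait for one another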

Further reading