Java 多线程(7): ThreadLocal 的应用及原理

2

在涉及到多线程需要共享变量的时候,一般有两种方法:其一就是使用互斥锁,使得在每个时刻只能有一个线程访问该变量,好处就是便于编码(直接使用 synchronized 关键字进行同步访问),缺点在于这增加了线程间的竞争,降低了效率;其二就是使用本文要讲的 ThreadLocal。如果说 synchronized 是以“时间换空间”,那么 ThreadLocal 就是 “以空间换时间” —— 因为 ThreadLocal 的原理就是为每个线程都提供一个这样的变量,使得这些变量是线程级别的变量,不同线程之间互不影响,从而达到可以并发访问而不出现并发问题的目的。


首先我们来看一个客观的事实:当一个可变对象被多个线程访问时,可能会得到非预期的结果 —— 所以先让我们来看一个例子。在讲到并发访问的问题的时候,SimpleDateFormat 总是会被拿来当成一个绝好的例子(从这点看感谢 JDK 提供了这么一个有设计缺陷的类方便我们当成反面教材 :) )。因为 SimpleDateFormatformatparse 方法共享从父类 DateFormat 继承而来的 Calendar 对象:
DateFormat 的 Calendar 对象

并且在 formatparse 方法中都会改变这个 Calendar 对象:

  • format 方法片段:

format 方法片段

  • parse 方法片段:

parse 方法片段

就拿 format 方法来说,考虑如下的并发情景:

  • 线程A 此时调用 calendar.setTime(date1),然后 线程A 被中断;
  • 接着 线程B 执行,然后调用 calendar.setTime(date2),然后 线程B 被中断;
  • 接着又是 线程A 执行,但是此时的 calendar 已经和之前的不一致了,所以便导致了并发问题。

所以因为这个共享的 calendar 对象,SimpleDateFormat 并不是一个线程安全的类,我们写一段代码来测试下。

(1)定义 DateFormatWrapper 类,来包装对 SimpleDateFormat 的调用:

public class DateFormatWrapper {

    private static final SimpleDateFormat SDF = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");

    public static String format(Date date) {
        return SDF.format(date);
    }

    public static Date parse(String str) throws ParseException {
        return SDF.parse(str);
    }
    
}

(2)然后写一个 DateFormatTest,开启多个线程来使用 DateFormatWrapper

public class DateFormatTest {

    public static void main(String[] args) throws Exception {
        ExecutorService threadPool = Executors.newCachedThreadPool(); // 创建无大小限制的线程池

        List<Future<?>> futures = new ArrayList<>();

        for (int i = 0; i < 9; i++) {
            DateFormatTask task = new DateFormatTask();
            Future<?> future = threadPool.submit(task); // 将任务提交到线程池

            futures.add(future);
        }

        for (Future<?> future : futures) {
            try {
                future.get();
            } catch (ExecutionException ex) { // 运行时如果出现异常则进入 catch 块
                System.err.println("执行时出现异常:" + ex.getMessage());
            }
        }

        threadPool.shutdown();
    }

    static class DateFormatTask implements Callable<Void> {

        @Override
        public Void call() throws Exception {
            String str = DateFormatWrapper.format(
                    DateFormatWrapper.parse("2017-07-17 16:54:54"));
            System.out.printf("Thread(%s) -> %s\n", Thread.currentThread().getName(), str);

            return null;
        }

    }
}

某次运行的结果:
某次运行的结果

可以发现,SimpleDateFormat 在多线程共享的情况下,不仅可能会出现结果错误的情况,还可能会由于并发访问导致运行异常。当然,我们肯定有解决的办法:

  1. DateFormatWrapperformatparse 方法加上 synchronized 关键字,坏处就是前面提到的这会加大线程间的竞争和切换而降低效率;
  2. 不使用全局的 SimpleDateFormat 对象,而是每次使用 formatparse 方法都新建一个 SimpleDateFormat 对象,坏处也很明显,每次调用 format 或者 parse 方法都要新建一个 SimpleDateFormat,这会加大 GC 的负担;
  3. 使用 ThreadLocalThreadLocal<SimpleDateFormat> 可以为每个线程提供一个独立的 SimpleDateFormat 对象,创建的 SimpleDateFormat 对象个数最多和线程个数相同,相比于 (1),使用ThreadLocal不存在线程间的竞争;相比于 (2),使用ThreadLocal创建的 SimpleDateFormat 对象个数也更加合理(不会超过线程的数量)。

我们使用 ThreadLocal 来对 DateFormatWrapper 进行修改,使得每个线程使用单独的 SimpleDateFormat

public class DateFormatWrapper {

    private static final ThreadLocal<SimpleDateFormat> SDF = new ThreadLocal<SimpleDateFormat>() {
        @Override
        protected SimpleDateFormat initialValue() {
            return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        }

    };

    public static String format(Date date) {
        return SDF.get().format(date);
    }

    public static Date parse(String str) throws ParseException {
        return SDF.get().parse(str);
    }

}

如果使用 Java8,则初始化 ThreadLocal 对象的代码可以改为:

private static final ThreadLocal<SimpleDateFormat> SDF
            = ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));

然后再运行 DateFormatTest,便始终是预期的结果:
正确的结果


我们已经看到了 ThreadLocal 的功能,那 ThreadLocal 是如何实现为每个线程提供一份共享变量的拷贝呢?

在使用 ThreadLocal 时,当前线程访问 ThreadLocal 中包含的变量是通过 get() 方法,所以首先来看这个方法的实现:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

通过代码可以猜测:

  • 在某个地方(其实就是在 ThreadLocal 的内部),JDK 实现了一个类似于 HashMap 的类,叫 ThreadLocalMap,该 “Map” 的键类型为 ThreadLocal<T>,值类型为 T
  • 然后每个线程都关联着一个 ThreadLocalMap 对象,并且可以通过 getMap(Thread t) 方法来获得 线程t 关联的 ThreadLocalMap 对象;
  • ThreadLocalMap 类有个以 ThreadLocal 对象为参数的 getEntry(ThreadLocal) 的方法,用来获得当前 ThreadLocal 对象关联的 Entry 对象。一个 Entry 对象就是一个键值对,键(key)是 ThreadLocal 对象,值(value)是该 ThreadLocal 对象包含的变量(即 T)。

查看 getMap(Thread) 方法:
getMap(Thread)

直接返回的就是 t.threadLocals,原来在 Thread 类中有一个就叫 threadLocalsThreadLocalMap 的变量:
 Thread 的 threadLocals 变量

所以每个 Thread 都会拥有一个 ThreadLocalMap 变量,来存放属于该 Thread 的所有 ThreadLocal 变量。这样来看的话,ThreadLocal就相当于一个调度器,每次调用 get 方法的时候,都会先找到当前线程的 ThreadLocalMap,然后再在这个 ThreadLocalMap 中找到对应的线程本地变量。

ThreadLocal 的 get() 方法的流程

然后我们来看看当 mapnull(即第一次调用 get())时调用的 setInitialValue() 方法:

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

该方法首先会调用 initialValue() 方法来获得该 ThreadLocal 对象中需要包含的变量 —— 所以这就是为什么使用 ThreadLocal 是需要继承 ThreadLocal 时并覆写 initialValue() 方法,因为这样才能让 setInitialValue() 调用 initialValue() 从而得到 ThreadLocal 包含的初始变量;然后就是当 map 不为 null 的时候,将该变量(value)与当前ThreadLocal对象(this)在 map 中进行关联;如果 mapnull,则调用 createMap 方法:

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

createMap 会调用 ThreadLocalMap 的构造方法来创建一个 ThreadLocalMap 对象:
ThreadLocalMap 的构造方法

可以看到该方法通过一个 ThreadLocal 对象(firstKey)和该 ThreadLocal 包含的对象(firstValue)构造了一个 ThreadLocalMap 对象,使得该 map 在构造完毕时候就包含了这样一个键值对(firstKey -> firstValue)。


为啥需要使用 Map 呢?因为一个线程可能有多个 ThreadLocal 对象,可能是包含 SimpleDateFormat,也可能是包含一个数据库连接 Connection,所以不同的变量需要通过对应的 ThreadLocal 对象来快速查找 —— 那么 Map 当然是最好的方式。


ThreadLocal 还提供了修改和删除当前包含对象的方法,修改的方法为 set,删除的方法为 remove

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

很好理解,如果当前 ThredLocal 还没有包含值,那么就调用 createMap 来初始化当前线程的 ThreadLocalMap 对象,否则直接在 map 中修改当前 ThreadLocalthis)包含的值。

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

remove 方法就是获得当前线程的 ThreadLocalMap 对象,然后调用这个 mapremove(ThreadLocal) 方法。查看 ThreadLocalMapremove(ThreadLocal) 方法的实现:
remove(ThreadLocal)

逻辑就是先找到参数(ThreadLocal对象)对应的 Entry,然后调用 Entryclear() 方法,再调用 expungeStaleEntry(i)i 为该 EntrymapEntry 数组中的索引。

(1)首先来看看 e.clear() 做了什么。

查看 ThreadLocalMap 的源代码,我们可以发现这个 “Map” 的 Entry 的实现如下:
Entry 的实现

可以看到,该 Entry 类继承自 WeakReference<ThreadLocal<?>>,所以 Entry 是一个 WeakReference(弱引用),而且该 WeakReference 包含的是一个 ThreadLocal 对象 —— 因而每个 Entry 是一个弱引用的 ThreadLocal 对象(又因为 Entry 包括了一个 value 变量,所以该 Entry 构成了一个 ThreadLocal -> Object 的键值对),而 Entryclear() 方法,是继承自 WeakReference,作用就是将 WeakReference 包含的对象的引用设置为 null

clear() 方法

我们知道对于一个弱引用的对象,一旦该对象不再被其他对象引用(比如像 clear() 方法那样将对象引用直接设置为 null),那么在 GC 发生的时候,该对象便会被 GC 回收。所以让 Entry 作为一个 WeakReference,配合 ThreadLocalremove 方法,可以及时清除某个 Entry 中的 ThreadLocalEntrykey)。

(2)expungeStaleEntry(i)的作用

先来看 expungeStaleEntry 的前一半代码:

expungeStaleEntry 的前一半代码

expungeStaleEntry 这部分代码的作用就是将 i 位置上的 Entryvalue 设置为 null,以及将 Entry 的引用设置为 null。为什么要这做呢?因为前面调用 e.clear(),只是将 Entrykey 设置为 null 并且可以使其在 GC 是被快速回收,但是 Entryvalue 在调用 e.clear() 后并不会为 null —— 所以如果不对 value 也进行清除,那么就可能会导致内存泄漏了。因此expungeStaleEntry 方法的一个作用在于可以把需要清除的 Entry 彻底的从 ThreadLocalMap 中清除(keyvalueEntry 全部设置为 null)。但是 expungeStaleEntry 还有另外的功能:看 expungeStaleEntry 的后一半代码:

expungeStaleEntry 的后一半代码

作用就是扫描位置 staleSlot 之后的 Entry 数组(直到某一个为 null 的位置),清除每个 keyThreadLocal) 为 nullEntry,所以使用 expungeStaleEntry 可以降低内存泄漏的概率。但是如果某些 ThreadLocal 变量不需要使用但是却没有调用到 expungeStaleEntry 方法,那么就会导致这些 ThreadLocal 变量长期的贮存在内存中,引起内存浪费或者泄露 —— 所以,如果确定某个 ThreadLocal 变量已经不需要使用,需要及时的使用 ThreadLocalremove() 方法(ThreadLocalgetset 方法也会调用到 expungeStaleEntry),将其从内存中清除。

你可能感兴趣的

scort · 4月27日

写得很好,有一个疑问在expungeStaleEntry方法的后半段的循环条件,如果下一个slot里entry是null的话,就不会继续下去,如果再下一个slot里的entry有key是null的情况呢?这个问题在源码中几处处理key是null的代码中都是存在的,还是我理解错了?

回复

0

你的理解没有错,就是 expungeStaleEntry(i) 每次只扫描 [i, j) 这个范围的数据位置,j 是 i 之后第一个为 null 的 slot,该方法的 javadoc 的注释也是这么说明的。至于为什么这么做,javadoc 有提到一个算法,Knuth 6.4 Algorithm R,应该是 Knuth 写的程序设计的那个书的 6.4 节,可能和这个有关。我的想法是,估计是想在性能和内存之间做个取舍,即每次只扫描底层数组的一个区域,就去掉这个区域中那些已经用不到的 ThreadLocal,而不是每次都扫描整个数组,因为下一次扫描又会扫描到其他区域 —— 可能有概率性的东西在这里面。

mizhoux 作者 · 4月27日
载入中...