3

前言

ThreadLocal用于多线程环境下每个线程存储和获取线程的局部变量,这些局部变量与线程绑定,线程之间互不影响。本篇文章将对ThreadLocal的使用和原理进行学习。

正文

一. ThreadLocal的使用

以一个简单例子对ThreadLocal的使用进行说明。

通常,ThreadLocal的使用是将其声明为类的私有静态字段,如下所示。

public class ThreadLocalLearn {

    private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();

    public void setThreadName(String threadName) {
        threadLocal.set(threadName);
    }

    public String getThreadName() {
        return threadLocal.get();
    }

}

ThreadLocalLearn类具有一个声明为private staticThreadLocal对象字段,每个线程通过ThreadLocalLearn提供的
setThreadName()方法存放线程名,通过getThreadName()方法获取线程名。

编写一个测试程序如下所示。

class ThreadLocalLearnTest {

    private ThreadLocalLearn threadLocalLearn;

    @BeforeEach
    public void setUp() {
        threadLocalLearn = new ThreadLocalLearn();
    }

    @Test
    void givenMultiThreads_whenSetThreadNameToThreadLocal_thenCanGetThreadNameFromThreadLocal() {
        Thread threadApple = new Thread(() -> {
            threadLocalLearn.setThreadName("Thread-Apple");
            System.out.println(Thread.currentThread().getName() + ": " + threadLocalLearn.getThreadName());
        }, "Thread-Apple");

        Thread threadPeach = new Thread(() -> {
            threadLocalLearn.setThreadName("Thread-Peach");
            System.out.println(Thread.currentThread().getName() + ": " + threadLocalLearn.getThreadName());
        }, "Thread-Peach");

        threadApple.start();
        threadPeach.start();
    }

}

测试程序中启用了两个线程,两个线程执行了同样的操作,即将线程的名字通过编写的ThreadLocalLearn存放,然后又取出。打印结果如下所示。

Thread-Apple: Thread-Apple
Thread-Peach: Thread-Peach

打印结果显示,两个线程的局部变量与线程绑定,线程之间互不影响。

二. ThreadLocal的原理。

首先分析一下ThreadLocalset()方法,其源码如下所示。

public void set(T value) {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程的ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null)
        // 以ThreadLocal对象为键,将value存到当前线程的ThreadLocalMap中
        map.set(this, value);
    else
        // 如果当前线程没有ThreadLocalMap,则先创建,再存值
        createMap(t, value);
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

由上面源码可知,ThreadLocalset()方法实际上是ThreadLocal以自身对象为键,将value存放到当前线程的ThreadLocalMap中。每个线程对象都有一个叫做threadLocals的字段,该字段是一个ThreadLocalMap类型的对象。ThreadLocalMap类是ThreadLocal类的一个静态内部类,用于线程对象存储线程独享的变量副本。

ThreadLocalMap实际上并不是一个Map,关于ThreadLocalMap是如何存储线程独享的变量副本,将在后一小节进行分析。下面再看一下ThreadLocalget()方法。

public T get() {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程的ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        // 以ThreadLocal对象为键,从当前线程的ThreadLocalMap中获取value
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            T result = (T)e.value;
            return result;
        }
    }
    // 如果当前线程没有ThreadLocalMap,则创建ThreadLocalMap,并以ThreadLocal对象为键存入一个初始值到创建的ThreadLocalMap中
    // 如果有ThreadLocalMap,但获取不到value,则以ThreadLocal对象为键存入一个初始值到ThreadLocalMap中
    // 返回初始值,且初始值一般为null
    return setInitialValue();
}

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

由上面源码可知,ThreadLocalget()方法实际上是ThreadLocal以自身对象为键,从当前线程的ThreadLocalMap中获取value

通过分析ThreadLocalset()get()方法可知,ThreadLocal能够在多线程环境下存储和获取线程的局部变量,实质是将局部变量值存放在每个线程对象的ThreadLocalMap中,因此线程之间互不影响。

三. ThreadLocalMap的原理

ThreadLocalMap本身不是Map,但是可以实现以key-value的形式存储线程的局部变量。与Map类似,ThreadLocalMap中将键值对的关系封装为了一个Entry对象,EntryThreadLocalMap的静态内部类,源码如下所示。

static class Entry extends WeakReference<ThreadLocal<?>> {
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

Entry继承于WeakReference,因此Entry是一个弱引用对象,而作为键的ThreadLocal对象是被弱引用的对象。

首先分析ThreadLocalMap的构造函数。ThreadLocalMap有两个构造函数,这里只分析签名为ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue)的构造函数。

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    // 创建一个容量为16的Entry数组
    table = new Entry[INITIAL_CAPACITY];
    // 使用散列算法计算第一个键值对在数组中的索引
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    // 创建Entry对象,并存放在Entry数组的索引对应位置
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    // 根据Entry数组初始容量大小设置扩容阈值
    setThreshold(INITIAL_CAPACITY);
}

ThreadLocalMap的散列算法为将ThreadLocal的哈希码与Entry数组长度减一做相与操作,由于Entry数组长度为2的幂次方,因此上述散列算法实质是ThreadLocal的哈希码对Entry数组长度取模。通过散列算法计算得到初始键值对在Entry数组中的位置后,会创建一个Entry对象并存放在数组的对应位置。最后根据公式:len * 2 / 3计算扩容阈值。

由上述分析可知,创建ThreadLocalMap对象时便会初始化存放键值对关系的Entry数组。现在看一下ThreadLocalMapset()方法。

// 调用set()方法时会传入一对键值对
private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    // 通过散列算法计算键值对的索引位置
    int i = key.threadLocalHashCode & (len-1);

    // 遍历Entry数组
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        // 获取当前Entry的键
        ThreadLocal<?> k = e.get();

        // 当前Entry的键与键值对的键相等(即指向同一个ThreadLocal对象),则更新当前Entry的value为键值对的值
        if (k == key) {
            e.value = value;
            return;
        }

        // 当前Entry的键被垃圾回收了,这样的Entry称为陈旧项,则根据键值对创建Entry并替换陈旧项
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    // 此时i表示遍历Entry数组时遇到的第一个空槽的索引
    // 程序运行到这里,说明遍历Entry数组时,在遇到第一个空槽前,遍历过的Entry的键与键值对的键均不相等,同时也没有陈旧项
    // 此时根据键值对创建Entry对象并存放在索引为i的位置(即空槽的位置)
    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        // Entry数组中键值对数量大于等于阈值,则触发rehash()
        // rehash()会先遍历Entry数组并删除陈旧项,如果删除陈旧项之后,键值对数量还大于等于阈值的3/4,则进行扩容
        // 扩容后,Entry数组长度应该为扩容前的两倍
        rehash();
}

private static int nextIndex(int i, int len) {
    return ((i + 1 < len) ? i + 1 : 0);
}

set()方法中,首先通过散列算法计算键值对的索引位置,然后从计算得到的索引位置开始往后遍历Entry数组,一直遍历到第一个空槽为止。在遍历的过程中,如果遍历到某个Entry的键与键值对的键相等,则更新这个Entry的值为键值对的值;如果遍历到某个Entry并且这个Entry被判定为陈旧项(键被垃圾回收的Entry对象),那么执行清除陈旧项的逻辑;如果遍历遇到空槽了,但没有发现有键与键值对的键相等的Entry,也没有陈旧项,则根据键值对生成Entry对象并存放在空槽的位置。

set()方法中,需要清除陈旧项时调用了replaceStaleEntry()方法,该方法会根据键值对创建Entry对象并替换陈旧项,同时触发一次清除陈旧项的逻辑。replaceStaleEntry()方法的实现如下所示。

// table[staleSlot]为陈旧项
// 该方法实际就是从索引staleSlot开始向后遍历Entry数组直到遇到空槽,如果找到某一个Entry的键与键值对的键相等,那么将这个Entry的值更新为键值对的值,并将这个Entry与陈旧项互换位置
// 如果遇到空槽也没有找到键与键值对的键相等的Entry,则直接将陈旧项清除,然后根据键值对创建一个Entry对象存放在索引为staleSlot的位置
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    int slotToExpunge = staleSlot;
    // 从索引为staleSlot的槽位向前遍历Entry数组直到遇到空槽,并记录遍历时遇到的最后一个陈旧项的索引,用slotToExpunge表示
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // 从索引为staleSlot的槽位向后遍历Entry数组
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // 如果遍历到某个Entry的键与键值对的键相等
        if (k == key) {
            // 将遍历到的Entry的值更新
            e.value = value;

            // 将更新后的Entry与索引为staleSlot的陈旧项互换位置
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // 如果向前遍历Entry数组时没有发现陈旧项,那么这里将slotToExpunge的值更新为陈旧项的新位置的索引
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            // expungeStaleEntry(int i)能够清除i位置的陈旧项,以及从i位置的槽位到下一个空槽之间的所有陈旧项
            // cleanSomeSlots(int i, int n)可以从i位置开始向后扫描log2(n)个槽位,如果发现了陈旧项,则清除陈旧项,并再向后扫描log2(table.length)个槽位
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // 如果遍历到的Entry是陈旧项,并且向前遍历Entry数组时没有发现陈旧项,则将slotToExpunge的值更新为当前遍历到的陈旧项的索引
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // 从索引为staleSlot的槽位向后遍历Entry数组时,直到遇到了空槽也没有找到键与键值对的键相等的Entry
    // 此时将staleSlot位置的陈旧项直接清除,并根据键值对创建一个Entry对象存放在索引为staleSlot的位置
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // 一开始时,staleSlot与slotToExpunge是相等的,一旦staleSlot与slotToExpunge不相等,表明从staleSlot位置向前或向后遍历Entry数组时,发现了除staleSlot位置的陈旧项之外的陈旧项
    // 此时需要清除这些陈旧项
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

replaceStaleEntry()中调用了两个关键方法,expungeStaleEntry(int i)能够清除i位置的陈旧项,以及从i位置的槽位到下一个空槽之间的所有陈旧项;cleanSomeSlots(int i, int n)可以从i位置开始向后扫描log2(n)个槽位,如果发现了陈旧项,则清除陈旧项,并再向后扫描log2(table.length)个槽位。其实现如下。

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 删除staleSlot位置的陈旧项
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // 从staleSlot位置开始往后遍历Entry数组,直到遍历到空槽
    // 如果遍历到陈旧项,则清除陈旧项
    // 如果遍历到非陈旧项,则将该Entry重新通过散列算法计算索引位置
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    // 返回遍历到的空槽的索引
    return i;
}

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            // 一旦扫描到陈旧项,则重置n为Entry数组长度,然后清除扫描到的陈旧项到下一个空槽之间的所有陈旧项,最后从空槽的位置向后再扫描log2(table.length)个槽位
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

由于replaceStaleEntry()方法中对应了很多种情况,因此单纯根据代码不能很直观的了解ThreadLocalMap是如何清除陈旧项的,所以下面结合图进行学习。这里默认Entry数组长度为16。

场景一:Entry数组槽位分布如下所示。

staleSlot向前遍历时,会将slotToExpunge值置为2,从staleSlot向后遍历时,由于索引为6的Entry对象的键与键值对的键相等,因此会更新这个Entry对象的值,并与staleSlot位置(索引为4)的陈旧项互换位置。互换位置后,Entry数组槽位分布如下所示。

因此最后会触发一次清除陈旧项的逻辑。先清除slotToExpunge到下一个空槽之间的所有陈旧项,即索引2和索引6的槽位的陈旧项会被清除;然后从空槽的下一个槽位,往后扫描log2(16) = 4个槽位,即依次扫描索引为8,9,10,11的槽位,并在扫描到索引为10的槽位时发现陈旧项,此时清除索引10槽位到下一个空槽之间的所有陈旧项,即索引10槽位的陈旧项会被清除,再然后从空槽的下一个槽位往后扫描log2(16) = 4个槽位,即依次扫描索引为13,14,15,0的槽位,没有发现陈旧项,扫描结束,并返回true,表示扫描到了陈旧项并清除了。

场景二:Entry数组槽位分布如下所示。

staleSlot向前遍历时,直到遇到空槽为止,也没有陈旧项,因此向前遍历结束后,slotToExpungestaleSlot相等。向后遍历到索引5的槽位时,发现了陈旧项,由于此时slotToExpungestaleSlot相等,因此将slotToExpunge置为5。继续向后遍历,由于索引为6的Entry对象的键与键值对的键相等,因此会更新这个Entry对象的值,并与staleSlot位置(索引为4)的陈旧项互换位置。互换位置后,Entry数组槽位分布如下所示。

因此最后会触发一次清除陈旧项的逻辑,清除逻辑与场景一相同,这里不再赘述。

场景三:Entry数组槽位分布如下所示。

staleSlot向前遍历时,会将slotToExpunge值置为2,从staleSlot向后遍历时,直到遇到空槽为止,也没有发现键与键值对的键相等的Entry,因此会将索引为staleSlot的槽位的陈旧项直接清除,并根据键值对创建一个Entry对象存放在索引为staleSlot的位置。staleSlot槽位的陈旧项被清除后的槽位分布如下所示。

之后清除陈旧项的逻辑与场景一相同,这里不再赘述。

实际场景下可能不会出现上述的槽位分布,这里只是举个例子,对replaceStaleEntry()方法的执行流程进行说明。

下面再看一下getEntry()方法。

private Entry getEntry(ThreadLocal<?> key) {
    // 使用散列算法计算索引
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        // 如果Entry数组索引位置的Entry的键与key相等,则返回这个Entry
        return e;
    else
        // 没有找到key对应的Entry时会执行getEntryAfterMiss()方法
        return getEntryAfterMiss(key, i, e);
}

// 该方法一边遍历Entry数组寻找键与key相等的Entry,一边清除陈旧项
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

无论是set()还是getEntry()方法,一旦发现了陈旧项,便会触发清除Entry数组中的陈旧项的逻辑,这是ThreadLocal为了防止发生内存泄漏的保护机制。

ThreadLocal如何防止内存泄漏?
已知,每个线程有一个ThreadLocalMap字段,ThreadLocalMap中将键值对的关系封装为了一个Entry对象,EntryThreadLocalMap的静态内部类,其实现如下。

static class Entry extends WeakReference<ThreadLocal<?>> {
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

当正常使用ThreadLocal时,虚拟机栈和堆上对象的引用关系可以用下图表示。

因此Entry是一个弱引用对象,key引用的ThreadLocal为被弱引用的对象,value引用的对象(上图中的Object)为被强引用的对象,那么在这种情况下,key引用的ThreadLocal不存在其它引用后,在下一次垃圾回收时key引用的ThreadLocal会被回收,防止了ThreadLocal对象的内存泄漏。key引用的ThreadLocal被回收后,此时这个Entry就成为了一个陈旧项,如果不对陈旧项做清除,那么陈旧项的value引用的对象就永远不会被回收,也会产生内存泄漏,所以ThreadLocal采用了线性探测来清除陈旧项,从而防止了内存泄漏。

总结

合理使用ThreadLocal可以在多线程环境下存储和获取线程的局部变量,并且将ThreadLocalMap中的Entry设计成了一个弱引用对象,可以防止ThreadLocal对象的内存泄漏,同时也采用了线性探测方法来清除陈旧项,防止了Entry中的值的内存泄漏,不过还是建议在每次使用完ThreadLocal后,及时调用ThreadLocalremove()方法,及时释放内存。


半夏之沫
65 声望32 粉丝