3
大家好,我是半夏之沫 😁😁 一名金融科技领域的JAVA系统研发😊😊
我希望将自己工作和学习中的经验以最朴实最严谨的方式分享给大家,共同进步👉💓👈
👉👉👉👉👉👉👉👉💓写作不易,期待大家的关注和点赞💓👈👈👈👈👈👈👈👈
👉👉👉👉👉👉👉👉💓关注微信公众号【技术探界】 💓👈👈👈👈👈👈👈👈

前言

ThreadLocal用于多线程环境下每个线程存储和获取线程的局部变量,这些局部变量与线程绑定,线程之间互不影响。本篇文章将对ThreadLocal的使用和原理进行学习。

正文

一. ThreadLocal的使用

以一个简单例子对ThreadLocal的使用进行说明。

通常,ThreadLocal的使用是将其声明为类的私有静态字段,如下所示。

public class ThreadLocalLearn {

    private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();

    public void setThreadName(String threadName) {
        threadLocal.set(threadName);
    }

    public String getThreadName() {
        return threadLocal.get();
    }

}

ThreadLocalLearn类具有一个声明为private staticThreadLocal对象字段,每个线程通过ThreadLocalLearn提供的
setThreadName()方法存放线程名,通过getThreadName()方法获取线程名。

编写一个测试程序如下所示。

class ThreadLocalLearnTest {

    private ThreadLocalLearn threadLocalLearn;

    @BeforeEach
    public void setUp() {
        threadLocalLearn = new ThreadLocalLearn();
    }

    @Test
    void givenMultiThreads_whenSetThreadNameToThreadLocal_thenCanGetThreadNameFromThreadLocal() {
        Thread threadApple = new Thread(() -> {
            threadLocalLearn.setThreadName("Thread-Apple");
            System.out.println(Thread.currentThread().getName() + ": " + threadLocalLearn.getThreadName());
        }, "Thread-Apple");

        Thread threadPeach = new Thread(() -> {
            threadLocalLearn.setThreadName("Thread-Peach");
            System.out.println(Thread.currentThread().getName() + ": " + threadLocalLearn.getThreadName());
        }, "Thread-Peach");

        threadApple.start();
        threadPeach.start();
    }

}

测试程序中启用了两个线程,两个线程执行了同样的操作,即将线程的名字通过编写的ThreadLocalLearn存放,然后又取出。打印结果如下所示。

Thread-Apple: Thread-Apple
Thread-Peach: Thread-Peach

打印结果显示,两个线程的局部变量与线程绑定,线程之间互不影响。

二. ThreadLocal的原理。

首先分析一下ThreadLocalset()方法,其源码如下所示。

public void set(T value) {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程的ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null)
        // 以ThreadLocal对象为键,将value存到当前线程的ThreadLocalMap中
        map.set(this, value);
    else
        // 如果当前线程没有ThreadLocalMap,则先创建,再存值
        createMap(t, value);
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

由上面源码可知,ThreadLocalset()方法实际上是ThreadLocal以自身对象为键,将value存放到当前线程的ThreadLocalMap中。每个线程对象都有一个叫做threadLocals的字段,该字段是一个ThreadLocalMap类型的对象。ThreadLocalMap类是ThreadLocal类的一个静态内部类,用于线程对象存储线程独享的变量副本。

ThreadLocalMap实际上并不是一个Map,关于ThreadLocalMap是如何存储线程独享的变量副本,将在后一小节进行分析。下面再看一下ThreadLocalget()方法。

public T get() {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程的ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        // 以ThreadLocal对象为键,从当前线程的ThreadLocalMap中获取value
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            T result = (T)e.value;
            return result;
        }
    }
    // 如果当前线程没有ThreadLocalMap,则创建ThreadLocalMap,并以ThreadLocal对象为键存入一个初始值到创建的ThreadLocalMap中
    // 如果有ThreadLocalMap,但获取不到value,则以ThreadLocal对象为键存入一个初始值到ThreadLocalMap中
    // 返回初始值,且初始值一般为null
    return setInitialValue();
}

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

由上面源码可知,ThreadLocalget()方法实际上是ThreadLocal以自身对象为键,从当前线程的ThreadLocalMap中获取value

通过分析ThreadLocalset()get()方法可知,ThreadLocal能够在多线程环境下存储和获取线程的局部变量,实质是将局部变量值存放在每个线程对象的ThreadLocalMap中,因此线程之间互不影响。

三. ThreadLocalMap的原理

ThreadLocalMap本身不是Map,但是可以实现以key-value的形式存储线程的局部变量。与Map类似,ThreadLocalMap中将键值对的关系封装为了一个Entry对象,EntryThreadLocalMap的静态内部类,源码如下所示。

static class Entry extends WeakReference<ThreadLocal<?>> {
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

Entry继承于WeakReference,因此Entry是一个弱引用对象,而作为键的ThreadLocal对象是被弱引用的对象。

首先分析ThreadLocalMap的构造函数。ThreadLocalMap有两个构造函数,这里只分析签名为ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue)的构造函数。

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    // 创建一个容量为16的Entry数组
    table = new Entry[INITIAL_CAPACITY];
    // 使用散列算法计算第一个键值对在数组中的索引
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    // 创建Entry对象,并存放在Entry数组的索引对应位置
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    // 根据Entry数组初始容量大小设置扩容阈值
    setThreshold(INITIAL_CAPACITY);
}

ThreadLocalMap的散列算法为将ThreadLocal的哈希码与Entry数组长度减一做相与操作,由于Entry数组长度为2的幂次方,因此上述散列算法实质是ThreadLocal的哈希码对Entry数组长度取模。通过散列算法计算得到初始键值对在Entry数组中的位置后,会创建一个Entry对象并存放在数组的对应位置。最后根据公式:len * 2 / 3计算扩容阈值。

由上述分析可知,创建ThreadLocalMap对象时便会初始化存放键值对关系的Entry数组。现在看一下ThreadLocalMapset()方法。

// 调用set()方法时会传入一对键值对
private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    // 通过散列算法计算键值对的索引位置
    int i = key.threadLocalHashCode & (len-1);

    // 遍历Entry数组
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        // 获取当前Entry的键
        ThreadLocal<?> k = e.get();

        // 当前Entry的键与键值对的键相等(即指向同一个ThreadLocal对象),则更新当前Entry的value为键值对的值
        if (k == key) {
            e.value = value;
            return;
        }

        // 当前Entry的键被垃圾回收了,这样的Entry称为陈旧项,则根据键值对创建Entry并替换陈旧项
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    // 此时i表示遍历Entry数组时遇到的第一个空槽的索引
    // 程序运行到这里,说明遍历Entry数组时,在遇到第一个空槽前,遍历过的Entry的键与键值对的键均不相等,同时也没有陈旧项
    // 此时根据键值对创建Entry对象并存放在索引为i的位置(即空槽的位置)
    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        // Entry数组中键值对数量大于等于阈值,则触发rehash()
        // rehash()会先遍历Entry数组并删除陈旧项,如果删除陈旧项之后,键值对数量还大于等于阈值的3/4,则进行扩容
        // 扩容后,Entry数组长度应该为扩容前的两倍
        rehash();
}

private static int nextIndex(int i, int len) {
    return ((i + 1 < len) ? i + 1 : 0);
}

set()方法中,首先通过散列算法计算键值对的索引位置,然后从计算得到的索引位置开始往后遍历Entry数组,一直遍历到第一个空槽为止。在遍历的过程中,如果遍历到某个Entry的键与键值对的键相等,则更新这个Entry的值为键值对的值;如果遍历到某个Entry并且这个Entry被判定为陈旧项(键被垃圾回收的Entry对象),那么执行清除陈旧项的逻辑;如果遍历遇到空槽了,但没有发现有键与键值对的键相等的Entry,也没有陈旧项,则根据键值对生成Entry对象并存放在空槽的位置。

set()方法中,需要清除陈旧项时调用了replaceStaleEntry()方法,该方法会根据键值对创建Entry对象并替换陈旧项,同时触发一次清除陈旧项的逻辑。replaceStaleEntry()方法的实现如下所示。

// table[staleSlot]为陈旧项
// 该方法实际就是从索引staleSlot开始向后遍历Entry数组直到遇到空槽,如果找到某一个Entry的键与键值对的键相等,那么将这个Entry的值更新为键值对的值,并将这个Entry与陈旧项互换位置
// 如果遇到空槽也没有找到键与键值对的键相等的Entry,则直接将陈旧项清除,然后根据键值对创建一个Entry对象存放在索引为staleSlot的位置
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    int slotToExpunge = staleSlot;
    // 从索引为staleSlot的槽位向前遍历Entry数组直到遇到空槽,并记录遍历时遇到的最后一个陈旧项的索引,用slotToExpunge表示
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // 从索引为staleSlot的槽位向后遍历Entry数组
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // 如果遍历到某个Entry的键与键值对的键相等
        if (k == key) {
            // 将遍历到的Entry的值更新
            e.value = value;

            // 将更新后的Entry与索引为staleSlot的陈旧项互换位置
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // 如果向前遍历Entry数组时没有发现陈旧项,那么这里将slotToExpunge的值更新为陈旧项的新位置的索引
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            // expungeStaleEntry(int i)能够清除i位置的陈旧项,以及从i位置的槽位到下一个空槽之间的所有陈旧项
            // cleanSomeSlots(int i, int n)可以从i位置开始向后扫描log2(n)个槽位,如果发现了陈旧项,则清除陈旧项,并再向后扫描log2(table.length)个槽位
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // 如果遍历到的Entry是陈旧项,并且向前遍历Entry数组时没有发现陈旧项,则将slotToExpunge的值更新为当前遍历到的陈旧项的索引
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // 从索引为staleSlot的槽位向后遍历Entry数组时,直到遇到了空槽也没有找到键与键值对的键相等的Entry
    // 此时将staleSlot位置的陈旧项直接清除,并根据键值对创建一个Entry对象存放在索引为staleSlot的位置
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // 一开始时,staleSlot与slotToExpunge是相等的,一旦staleSlot与slotToExpunge不相等,表明从staleSlot位置向前或向后遍历Entry数组时,发现了除staleSlot位置的陈旧项之外的陈旧项
    // 此时需要清除这些陈旧项
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

replaceStaleEntry()中调用了两个关键方法,expungeStaleEntry(int i)能够清除i位置的陈旧项,以及从i位置的槽位到下一个空槽之间的所有陈旧项;cleanSomeSlots(int i, int n)可以从i位置开始向后扫描log2(n)个槽位,如果发现了陈旧项,则清除陈旧项,并再向后扫描log2(table.length)个槽位。其实现如下。

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 删除staleSlot位置的陈旧项
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // 从staleSlot位置开始往后遍历Entry数组,直到遍历到空槽
    // 如果遍历到陈旧项,则清除陈旧项
    // 如果遍历到非陈旧项,则将该Entry重新通过散列算法计算索引位置
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    // 返回遍历到的空槽的索引
    return i;
}

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            // 一旦扫描到陈旧项,则重置n为Entry数组长度,然后清除扫描到的陈旧项到下一个空槽之间的所有陈旧项,最后从空槽的位置向后再扫描log2(table.length)个槽位
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

由于replaceStaleEntry()方法中对应了很多种情况,因此单纯根据代码不能很直观的了解ThreadLocalMap是如何清除陈旧项的,所以下面结合图进行学习。这里默认Entry数组长度为16。

场景一:Entry数组槽位分布如下所示。

staleSlot向前遍历时,会将slotToExpunge值置为2,从staleSlot向后遍历时,由于索引为6的Entry对象的键与键值对的键相等,因此会更新这个Entry对象的值,并与staleSlot位置(索引为4)的陈旧项互换位置。互换位置后,Entry数组槽位分布如下所示。

因此最后会触发一次清除陈旧项的逻辑。先清除slotToExpunge到下一个空槽之间的所有陈旧项,即索引2和索引6的槽位的陈旧项会被清除;然后从空槽的下一个槽位,往后扫描log2(16) = 4个槽位,即依次扫描索引为8,9,10,11的槽位,并在扫描到索引为10的槽位时发现陈旧项,此时清除索引10槽位到下一个空槽之间的所有陈旧项,即索引10槽位的陈旧项会被清除,再然后从空槽的下一个槽位往后扫描log2(16) = 4个槽位,即依次扫描索引为13,14,15,0的槽位,没有发现陈旧项,扫描结束,并返回true,表示扫描到了陈旧项并清除了。

场景二:Entry数组槽位分布如下所示。

staleSlot向前遍历时,直到遇到空槽为止,也没有陈旧项,因此向前遍历结束后,slotToExpungestaleSlot相等。向后遍历到索引5的槽位时,发现了陈旧项,由于此时slotToExpungestaleSlot相等,因此将slotToExpunge置为5。继续向后遍历,由于索引为6的Entry对象的键与键值对的键相等,因此会更新这个Entry对象的值,并与staleSlot位置(索引为4)的陈旧项互换位置。互换位置后,Entry数组槽位分布如下所示。

因此最后会触发一次清除陈旧项的逻辑,清除逻辑与场景一相同,这里不再赘述。

场景三:Entry数组槽位分布如下所示。

staleSlot向前遍历时,会将slotToExpunge值置为2,从staleSlot向后遍历时,直到遇到空槽为止,也没有发现键与键值对的键相等的Entry,因此会将索引为staleSlot的槽位的陈旧项直接清除,并根据键值对创建一个Entry对象存放在索引为staleSlot的位置。staleSlot槽位的陈旧项被清除后的槽位分布如下所示。

之后清除陈旧项的逻辑与场景一相同,这里不再赘述。

实际场景下可能不会出现上述的槽位分布,这里只是举个例子,对replaceStaleEntry()方法的执行流程进行说明。

下面再看一下getEntry()方法。

private Entry getEntry(ThreadLocal<?> key) {
    // 使用散列算法计算索引
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        // 如果Entry数组索引位置的Entry的键与key相等,则返回这个Entry
        return e;
    else
        // 没有找到key对应的Entry时会执行getEntryAfterMiss()方法
        return getEntryAfterMiss(key, i, e);
}

// 该方法一边遍历Entry数组寻找键与key相等的Entry,一边清除陈旧项
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

无论是set()还是getEntry()方法,一旦发现了陈旧项,便会触发清除Entry数组中的陈旧项的逻辑,这是ThreadLocal为了防止发生内存泄漏的保护机制。

ThreadLocal如何防止内存泄漏?
已知,每个线程有一个ThreadLocalMap字段,ThreadLocalMap中将键值对的关系封装为了一个Entry对象,EntryThreadLocalMap的静态内部类,其实现如下。

static class Entry extends WeakReference<ThreadLocal<?>> {
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

当正常使用ThreadLocal时,虚拟机栈和堆上对象的引用关系可以用下图表示。

因此Entry是一个弱引用对象,key引用的ThreadLocal为被弱引用的对象,value引用的对象(上图中的Object)为被强引用的对象,那么在这种情况下,key引用的ThreadLocal不存在其它引用后,在下一次垃圾回收时key引用的ThreadLocal会被回收,防止了ThreadLocal对象的内存泄漏。key引用的ThreadLocal被回收后,此时这个Entry就成为了一个陈旧项,如果不对陈旧项做清除,那么陈旧项的value引用的对象就永远不会被回收,也会产生内存泄漏,所以ThreadLocal采用了线性探测来清除陈旧项,从而防止了内存泄漏。

总结

合理使用ThreadLocal可以在多线程环境下存储和获取线程的局部变量,并且将ThreadLocalMap中的Entry设计成了一个弱引用对象,可以防止ThreadLocal对象的内存泄漏,同时也采用了线性探测方法来清除陈旧项,防止了Entry中的值的内存泄漏,不过还是建议在每次使用完ThreadLocal后,及时调用ThreadLocalremove()方法,及时释放内存。


大家好,我是半夏之沫 😁😁 一名金融科技领域的JAVA系统研发😊😊
我希望将自己工作和学习中的经验以最朴实最严谨的方式分享给大家,共同进步👉💓👈
👉👉👉👉👉👉👉👉💓写作不易,期待大家的关注和点赞💓👈👈👈👈👈👈👈👈
👉👉👉👉👉👉👉👉💓关注微信公众号【技术探界】 💓👈👈👈👈👈👈👈👈

半夏之沫
65 声望32 粉丝