1

零 准备

0 FBI WARNING

文章异常啰嗦且绕弯。

1 TransmittableThreadLocal 是什么

当开发人员需要在线程池的线程中传递某些参数的时候,jdk 的 ThreadLocal 很难实现,静态变量则会面临不够灵活和出现线程安全等问题。
TransmittableThreadLocal 是阿里开源工具包,用于解决这一问题。

2 版本

  • jdk 版本
    Azul JDK 17.0.2
  • transmittable-thread-local

    <dependency>
      <groupId>com.alibaba</groupId>
      <artifactId>transmittable-thread-local</artifactId>
      <version>2.13.0-Beta1</version>
    </dependency>
  • junit-jupiter

    <dependency>
      <groupId>org.junit.jupiter</groupId>
      <artifactId>junit-jupiter</artifactId>
      <version>5.8.2</version>
    </dependency>

    一 Demo

    import com.alibaba.ttl.TransmittableThreadLocal;
    import com.alibaba.ttl.threadpool.TtlExecutors;
    import org.junit.jupiter.api.Test;
    
    import java.util.concurrent.Executor;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    
    public class TreadLocalTest {
    
      @Test
      public void transmittableThreadLocal() {
    
          TransmittableThreadLocal<Integer> tl = new TransmittableThreadLocal<>();
    
          tl.set(6);
          System.out.println("父线程获取数据:" + tl.get()); // 第一次输出:6
    
          // 使用 jdk 的 Executors 工具创建一个线程池
          // 注意,这个线程池里只有一个线程
          Executor realPool = Executors.newFixedThreadPool(1);
          
          // 使用 TtlExecutors 创建一个 Ttl 框架封装的线程池
          Executor pool = TtlExecutors.getTtlExecutor(realPool);
    
          // 使用线程池跑一个任务
          pool.execute(() -> {
              Integer i = tl.get();
              System.out.println("第一次获取数据:" + i); // 第二次输出:6
          });
    
          // 修改一下 tl 里的值,并再跑一次任务
          tl.set(7);
          pool.execute(() -> {
              Integer i = tl.get();
              System.out.println("第二次获取数据:" + i); // 第三次输出:7
          });
      }
    }

    二 先从 InheritableThreadLocal 说起

    1 Thread

    InheritableThreadLocal 是 jdk 中自带的 ThreadLocal 的子类,在 jdk 的 Thread 对象中,会对它有单独的支持。
    首先来看 Thread 的构造方法:

    // java.lang.Thread 的核心构造方法
    private Thread(ThreadGroup g, Runnable target, String name,
                     long stackSize, AccessControlContext acc,
                     boolean inheritThreadLocals) {
      
      // 此处省略一大段无关代码...
      
      // inheritThreadLocals 是一个 boolean 类型的值,是一个 “是否启用 inheritableThreadLocals” 的开关
      // parent 是创造此线程的父线程
      if (inheritThreadLocals && parent.inheritableThreadLocals != null)
          // 如果父线程的 inheritableThreadLocals 存在,则此处会将它挪到当前线程里
          // ThreadLocal.createInheritedMap 是一个深拷贝,会创建新的 Entry
          this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
      
      // 此处省略一大段无关代码...
    }

    2 ThreadLocalMap

    来看一下 ThreadLocal.createInheritedMap:

    // java.lang.ThreadLocal
    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
      return new ThreadLocalMap(parentMap);
    }

    这个方法会创建一个 ThreadLocalMap,再来追踪一下 ThreadLocalMap 的构造器:
    (值得注意的是,ThreadLocalMap 是 ThreadLocal 的内部类,所以其实代码逻辑还是在 ThreadLocal.java 中)

    // java.lang.ThreadLocal
    private ThreadLocalMap(ThreadLocalMap parentMap) {
      Entry[] parentTable = parentMap.table;
      int len = parentTable.length;
      setThreshold(len);
      table = new Entry[len];
    
      // 此处把 ThreadLocalMap 里的元素都遍厉一遍
      // 然后都创建成新的 Entry 并塞到新的 ThreadLocalMap 里
      for (Entry e : parentTable) {
          if (e != null) {
              // 此处获取了 Entry 的 key,本质上就是 ThreadLocal 本身
              ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
              if (key != null) {
                  // 用 key 取获取 value,这行代码重点关注,下文会提到
                  Object value = key.childValue(e.value);
                  // 此处创建新的 Entry
                  Entry c = new Entry(key, value);
                  // 处理 hash 碰撞问题并存入
                  int h = key.threadLocalHashCode & (len - 1);
                  while (table[h] != null)
                      h = nextIndex(h, len);
                  table[h] = c;
                  size++;
              }
          }
      }
    }

    3 childValue

    这里需要重点关注一行代码:

    Object value = key.childValue(e.value);

    这个方法是 ThreadLocal 中的:

    // java.lang.ThreadLocal
    T childValue(T parentValue) {
      throw new UnsupportedOperationException();
    }

    由上文可见,这是一个没有被实现的预留模板方法。在 InheritableThreadLocal 中对其进行了实现:

    // java.lang.InheritableThreadLocal
    protected T childValue(T parentValue) {
      return parentValue;
    }

    5 initialValue

    initialValue 同样是 ThreadLocal 提供的一个空方法:

    // java.lang.ThreadLocal
    protected T initialValue() {
      return null;
    }

    这个方法会作用在 ThreadLocal 的 get() 方法里:

    // step 1
    // java.lang.ThreadLocal
    public T get() {
      Thread t = Thread.currentThread();
      ThreadLocalMap map = getMap(t);
      if (map != null) {
          ThreadLocalMap.Entry e = map.getEntry(this);
          if (e != null) {
              // 如果 Entry 存在,则此处会返回 Entry 的 value
              T result = (T)e.value;
              return result;
          }
      }
      // 如果 Entry 不存在,或者 ThreadLocalMap 不存在,会在这里初始化一个 value
      // 这个方法见 step 2
      return setInitialValue();
    }
    
    // step 2
    // java.lang.ThreadLocal
    private T setInitialValue() {
      // 这里初始化一个值
      T value = initialValue();
      Thread t = Thread.currentThread();
      ThreadLocalMap map = getMap(t);
      if (map != null) {
          // 将初始化出来的值存进去
          map.set(this, value);
      } else {
          // 初始化 ThreadLocalMap
          createMap(t, value);
      }
      
      // 此处忽略这段代码
      if (this instanceof TerminatingThreadLocal) {
          TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
      }
      
      // 返回
      return value;
    }

    5 InheritableThreadLocal 的作用和问题

    假设 Thread A 是 Thread B 的父线程,由上述代码可知:

  • A 的 InheritableThreadLocal 内的数据可以被 B 继承
  • 继承方式是在创建 B 的时候,在构造方法里直接 copy 一份 InheritableThreadLocal 内的元素
  • copy 是一个快照机制,一旦结束,再去修改 A 中的 InheritableThreadLocal 中的元素,就不会同步给 B 了

那么问题来了:
如果系统中需要做到 A 和 B 的 InheritableThreadLocal 实时同步,应该如何解决?

三 TransmittableThreadLocal

先来看下列三行代码:

// 创建一个 TransmittableThreadLocal
ThreadLocal<Integer> tl = new TransmittableThreadLocal<>();

tl.set(6);

Integer i = tl.get();

1 构造器

TransmittableThreadLocal 的构造器非常简单。

// 是否要忽略 null value,如果这个参数为 false,则哪怕 value 是 null,也会存储下来
private final boolean disableIgnoreNullValueSemantics;

// 这个参数默认为 false
public TransmittableThreadLocal() {
    this(false);
}


public TransmittableThreadLocal(boolean disableIgnoreNullValueSemantics) {
    this.disableIgnoreNullValueSemantics = disableIgnoreNullValueSemantics;
}

2 holder

holder 是 TransmittableThreadLocal 的静态成员变量,是一个 InheritableThreadLocal。

// com.alibaba.ttl.TransmittableThreadLocal
private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
    new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
    
    // 复写这个方法应该没有别的深意,只是为了防止在调用 holder.get().xxx() 的时候报空指针
    // 应该是开发人员觉得这样比较优雅
    @Override
    protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
        return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
    }

    // 这个方法实现了子线程和父线程之间的信息传递
    @Override
    protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
        return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
    }
};

由上述可知:

  • holder 是一个记录的 value 是 WeakHashMap<TransmittableThreadLocal> 的 InheritableThreadLocal
  • WeakHashMap 的 value 并没有被使用到,可以将其视为一个 WeakHashSet
  • holder 复写了 initialValue 和 childValue 两个方法

holder 最重要的方法是 addThisToHolder:

// com.alibaba.ttl.TransmittableThreadLocal
// 如果当前 TransmittableThreadLocal 没有被记录在 holder 中,则会在此处 put 进去
private void addThisToHolder() {
    if (!holder.get().containsKey(this)) {
        holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
    }
}

同样还有移除方法:

// com.alibaba.ttl.TransmittableThreadLocal
private void removeThisFromHolder() {
    holder.get().remove(this);
}

3 set

存入 value 的方法。

// com.alibaba.ttl.TransmittableThreadLocal
@Override
public final void set(T value) {
    if (!disableIgnoreNullValueSemantics && null == value) {
        // 如果 value 是 null,且不忽略 null value,则此处进入删除逻辑
        remove();
    } else {
        // 存储逻辑
        super.set(value);
        // 将当前的 TransmittableThreadLocal 注册到 holder 里
        addThisToHolder();
    }
}

4 get

获取 value 的方法。

// com.alibaba.ttl.TransmittableThreadLocal
@Override
public final T get() {
    T value = super.get();
    // 尝试注册到 holder
    if (disableIgnoreNullValueSemantics || null != value) 
        addThisToHolder();
    return value;
}

5 Snapshot

Snapshot 是 TransmittableThreadLocal 的内部类,用来存放当前线程内的 ThreadLocal 和 TransmittableThreadLocal 数据。

// com.alibaba.ttl.TransmittableThreadLocal
private static class Snapshot {
    final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value;
    final HashMap<ThreadLocal<Object>, Object> threadLocal2Value;

    private Snapshot(HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value, HashMap<ThreadLocal<Object>, Object> threadLocal2Value) {
        this.ttl2Value = ttl2Value;
        this.threadLocal2Value = threadLocal2Value;
    }
}

6 Transmitter

Transmitter 是 TransmittableThreadLocal 的内部类,本质上是一组静态工具。

6.1 获取一个快照

// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
public static Object capture() {
    // captureTtlValues()  会将当前线程的 TransmittableThreadLocal 数据做成一个 HashMap
    // captureThreadLocalValues() 会将当前线程的 ThreadLocal 数据做成一个 HashMap
    return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}
6.1.1 获取 holder 中所有的 TransmittableThreadLocal 数据
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
private static HashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
    
    HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<TransmittableThreadLocal<Object>, Object>();
    
    for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
        ttl2Value.put(threadLocal, threadLocal.copyValue());
    }
    
    return ttl2Value;
}
6.1.2 获取 threadLocalHolder 中所有 ThreadLocal 数据
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
private static HashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
    
    final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap<ThreadLocal<Object>, Object>();
    
    for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        final TtlCopier<Object> copier = entry.getValue();

        threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
    }
    
    return threadLocal2Value;
}

6.2 重放

6.2.1 replay
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
// 本质上是对一个 snapshot 进行拷贝
public static Object replay(Object captured) {
    final Snapshot capturedSnapshot = (Snapshot) captured;
    return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}
6.2.2 replayTtlValues
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
// 本质上是对一个 map 进行深拷贝
private static HashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(HashMap<TransmittableThreadLocal<Object>, Object> captured) {
    
    // 创建一个新的 map
    HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<TransmittableThreadLocal<Object>, Object>();

    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

        // 将原来的 map 复制到新的 map 中
        backup.put(threadLocal, threadLocal.get());

        // 此处比较 holder 和 captured 的 key
        // 如果对应不一致,则将 holder 里的数据清空
        if (!captured.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // 将 value 和 key 对应起来
    // 这是一个保底纠错逻辑
    setTtlValuesTo(captured);

    // 这是一个暂时没有用的扩展方法
    doExecuteCallback(true);

    return backup;
}
6.2.3 replayThreadLocalValues
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
// 本质上是对一个 map 进行深拷贝
private static HashMap<ThreadLocal<Object>, Object> replayThreadLocalValues(HashMap<ThreadLocal<Object>, Object> captured) {
    final HashMap<ThreadLocal<Object>, Object> backup = new HashMap<ThreadLocal<Object>, Object>();

    for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        backup.put(threadLocal, threadLocal.get());

        // threadLocalClearMark 是一个空对象,用于占位
        // 如果此处的 value 就是这个空对象,则此处代表这个 ttl 里的 value 已经被 clear 了
        final Object value = entry.getValue();
        if (value == threadLocalClearMark) 
            threadLocal.remove();
        else 
            threadLocal.set(value);
    }

    return backup;
}

6.3 恢复

6.3.1 restore
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
// 用快照来恢复当前线程的 ttl 数据
public static void restore(Object backup) {
    final Snapshot backupSnapshot = (Snapshot) backup;
    restoreTtlValues(backupSnapshot.ttl2Value);
    restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}
6.3.2 restoreTtlValues

这个方法与 replayTtlValues(...) 方法比较像

// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
private static void restoreTtlValues(HashMap<TransmittableThreadLocal<Object>, Object> backup) {
    doExecuteCallback(false);

    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

        if (!backup.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    setTtlValuesTo(backup);
}
6.3.3 restoreThreadLocalValues
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
private static void restoreThreadLocalValues(HashMap<ThreadLocal<Object>, Object> backup) {
    for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        threadLocal.set(entry.getValue());
    }
}

四 ExecutorTtlWrapper

1 ExecutorTtlWrapper

ExecutorTtlWrapper 的代码非常少:

// com.alibaba.ttl.threadpool.ExecutorTtlWrapper
class ExecutorTtlWrapper implements Executor, TtlWrapper<Executor>, TtlEnhanced {
    
    // 这个变量代表了一个线程池
    private final Executor executor;
    // 这个变量是一个幂等标识符
    protected final boolean idempotent;

    ExecutorTtlWrapper(Executor executor, boolean idempotent) {
        this.executor = executor;
        this.idempotent = idempotent;
    }

    @Override
    public void execute(Runnable command) {
        executor.execute(TtlRunnable.get(command, false, idempotent));
    }

    @Overrid
    public Executor unwrap() {
        return executor;
    }

    // 其它方法不重要,这里省略...
}

ExecutorTtlWrapper 本质上是一个线程池的代理,在执行 execute(...) 方法的时候,会将 Runnable 任务包装成 TtlRunnable。

2 TtlEnhanced

// 这是一个单纯的空接口,用来标识一个类
public interface TtlEnhanced {
    
}

3 TtlWrapper

// TtlWrapper 用来标识一个包装类
// 需要实现获取被包装对象的 unwrap 方法
public interface TtlWrapper<T> extends TtlEnhanced {
    T unwrap();
}

4 TtlExecutors

TtlExecutors 是一个静态工具类,用来生成 ExecutorTtlWrapper。

// com.alibaba.ttl.threadpool.TtlExecutors
public static Executor getTtlExecutor(Executor executor) {
    // 如果已经包装过了,那么此处直接返回
    if (TtlAgent.isTtlAgentLoaded() || null == executor || executor instanceof TtlEnhanced) {
        return executor;
    }
    
    // 如果没有包装过,那么此处包装一下
    // 幂等标识符,此处默认为 true
    return new ExecutorTtlWrapper(executor, true);
}

TtlAgent 是对探针技术的应用,暂时不展开讲解。

五 TtlRunnable

1 TtlRunnable

首先来看一下 class:

// com.alibaba.ttl.TtlRunnable
public final class TtlRunnable implements Runnable, TtlWrapper<Runnable>, TtlEnhanced, TtlAttachments {
    
    private final AtomicReference<Object> capturedRef;
    private final Runnable runnable;
    private final boolean releaseTtlValueReferenceAfterRun;
    
    private TtlRunnable(Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        // capture() 方法见上面 第三 part 的 Transmitter 部分
        // 本质上这是当前线程所存储的 TransmittableThreadLocal 和 ThreadLocal 的快照
        this.capturedRef = new AtomicReference<Object>(capture());
        // 真实的业务逻辑
        this.runnable = runnable;
        // 当前 TtlRunnable 是否可以重复执行
        // true 的情况下,只要执行完,就不能重复执行了
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }
    
    // 其它方法先省略...   
}

2 get

TtlRunnable.get(...) 是一个静态方法,用于创建一个 TtlRunnable 对象。

// com.alibaba.ttl.TtlRunnable
public static TtlRunnable get(Runnable runnable) {
    return get(runnable, false, false);
}

public static TtlRunnable get(Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
    return get(runnable, releaseTtlValueReferenceAfterRun, false);
}

public static TtlRunnable get(Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
    // 空判断
    if (null == runnable) 
        return null;

    // 如果当前为幂等,则此处复用
    if (runnable instanceof TtlEnhanced) {
        if (idempotent) 
            return (TtlRunnable) runnable;
        else 
            throw new IllegalStateException("Already TtlRunnable!");
    }
    
    // 创建对象
    return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}

3 run

TtlRunnable.run() 是核心方法,是对业务逻辑的封装。

// com.alibaba.ttl.TtlRunnable
public void run() {
    
    // 获取当前快照
    final Object captured = capturedRef.get();
    
    // 有效性判断
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }

    // replay 方法来自 Transmitter
    // 用于创建一个当前线程的 ThreadLocal 的备份
    final Object backup = replay(captured);
    try {
        runnable.run();
    } finally {
        // restore 方法来自 Transmitter
        // 使用备份来恢复当前线程的 ThreadLocal 数据
        restore(backup);
    }
}

captured 实际上是一个备忘录模式,用于确保子线程内的数据修改不影响到父线程。

六 一点唠叨

  • 封装的很有意思,但是很多细节还是没太看懂
  • 仅为个人的学习笔记,可能存在错误或者表述不清的地方,有缘补充

三流
57 声望16 粉丝

三流程序员一枚,立志做保姆级教程。