2

内容目录

深度分析ConcurrentHashMap中的并发扩容机制

说到扩容,相比各位读者都不陌生,无非就是创建一个扩容目标大小的数组,把原来老数组中的数据迁移到新数组中来即可,这种方式比较适合在没有多线程并发的场景中完成,但是在ConcurrentHashMap中并没有那么简单,因为在多线程环境下进行扩容时,会存在其他线程同时往集合中添加元素。

可能有些读者会想,这个很简单,把整个扩容过程加一把同步锁,保证扩容过程中不存在其他线程对数据进行操作。很显然,这种方式对性能的损耗非常大,特别是如果涉及到数据量比较多的扩容时,会导致非常多的线程被阻塞。

ConcurrentHashMap中扩容部分的设计非常巧妙,它通过使用CAS机制实现无锁的并发同步策略,同时对于同步锁synchronized,也只把粒度控制到了单个数据节点做数据迁移的这个范围,并且利用多个线程来进行并行扩容,大大提高了数据迁移的效率。

多线程并发扩容原理图解

首先,如下图所示,通过一个简略图来整体了解一下并发扩容是怎么一回事,当存在多个线程并行进行扩容以及数据迁移时,默认情况下会给每个线程分配一个区间,这个区间默认值是16。每个线程负责自己区间的数据迁移工作。需要注意的是,在下图所示中有一个transferIndex的属性,这个是一个转移索引,如果当前只有两个线程要对64位长度的数组做数据迁移,意味着每个线程需要做多次迁移,而这个过程就依赖于transferIndex来更新每个线程迁移的数据区间。

在这里插入图片描述

transfer数据迁移

transfer这个方法的代码非常多,代码如下。

private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
    int n = tab.length, stride;
    if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
        stride = MIN_TRANSFER_STRIDE; // subdivide range
    if (nextTab == null) {            // initiating
        try {
            @SuppressWarnings("unchecked")
            Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
            nextTab = nt;
        } catch (Throwable ex) {      // try to cope with OOME
            sizeCtl = Integer.MAX_VALUE;
            return;
        }
        nextTable = nextTab;
        transferIndex = n;
    }
    int nextn = nextTab.length;
    ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
    boolean advance = true;
    boolean finishing = false; // to ensure sweep before committing nextTab
    for (int i = 0, bound = 0;;) {
        Node<K,V> f; int fh;
        while (advance) {
            int nextIndex, nextBound;
            if (--i >= bound || finishing)
                advance = false;
            else if ((nextIndex = transferIndex) <= 0) {
                i = -1;
                advance = false;
            }
            else if (U.compareAndSwapInt
                     (this, TRANSFERINDEX, nextIndex,
                      nextBound = (nextIndex > stride ?
                                   nextIndex - stride : 0))) {
                bound = nextBound;
                i = nextIndex - 1;
                advance = false;
            }
        }
        if (i < 0 || i >= n || i + n >= nextn) {
            int sc;
            if (finishing) {
                nextTable = null;
                table = nextTab;
                sizeCtl = (n << 1) - (n >>> 1);
                return;
            }
            if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
                if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
                    return;
                finishing = advance = true;
                i = n; // recheck before commit
            }
        }
        else if ((f = tabAt(tab, i)) == null)
            advance = casTabAt(tab, i, null, fwd);
        else if ((fh = f.hash) == MOVED)
            advance = true; // already processed
        else {
            synchronized (f) {
                if (tabAt(tab, i) == f) {
                    Node<K,V> ln, hn;
                    if (fh >= 0) {
                        int runBit = fh & n;
                        Node<K,V> lastRun = f;
                        for (Node<K,V> p = f.next; p != null; p = p.next) {
                            int b = p.hash & n;
                            if (b != runBit) {
                                runBit = b;
                                lastRun = p;
                            }
                        }
                        if (runBit == 0) {
                            ln = lastRun;
                            hn = null;
                        }
                        else {
                            hn = lastRun;
                            ln = null;
                        }
                        for (Node<K,V> p = f; p != lastRun; p = p.next) {
                            int ph = p.hash; K pk = p.key; V pv = p.val;
                            if ((ph & n) == 0)
                                ln = new Node<K,V>(ph, pk, pv, ln);
                            else
                                hn = new Node<K,V>(ph, pk, pv, hn);
                        }
                        setTabAt(nextTab, i, ln);
                        setTabAt(nextTab, i + n, hn);
                        setTabAt(tab, i, fwd);
                        advance = true;
                    }
                    else if (f instanceof TreeBin) {
                        TreeBin<K,V> t = (TreeBin<K,V>)f;
                        TreeNode<K,V> lo = null, loTail = null;
                        TreeNode<K,V> hi = null, hiTail = null;
                        int lc = 0, hc = 0;
                        for (Node<K,V> e = t.first; e != null; e = e.next) {
                            int h = e.hash;
                            TreeNode<K,V> p = new TreeNode<K,V>
                                (h, e.key, e.val, null, null);
                            if ((h & n) == 0) {
                                if ((p.prev = loTail) == null)
                                    lo = p;
                                else
                                    loTail.next = p;
                                loTail = p;
                                ++lc;
                            }
                            else {
                                if ((p.prev = hiTail) == null)
                                    hi = p;
                                else
                                    hiTail.next = p;
                                hiTail = p;
                                ++hc;
                            }
                        }
                        ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
                        (hc != 0) ? new TreeBin<K,V>(lo) : t;
                        hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
                        (lc != 0) ? new TreeBin<K,V>(hi) : t;
                        setTabAt(nextTab, i, ln);
                        setTabAt(nextTab, i + n, hn);
                        setTabAt(tab, i, fwd);
                        advance = true;
                    }
                }
            }
        }
    }
}

为了更清晰的理解transfer方法的代码,我们把它分成五个部分去解读。

第一个部分,创建扩容后的数组

这部分代码主要做两个事情。

  • 计算每个线程处理的区间大小,默认是16。(NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE这段代码的目的是让每个CPU处理的数据区间大小相同,避免出现数据转移任务分配不均匀的现象。如果数组的长度比较小的话,默认一个CPU处理的长度是16。
  • 初始化一个新的数组nt赋值给nextTab,该数组的长度是原来长度的n << 1,并且初始化一个transferIndex,默认值为老的数组长度。
private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
    int n = tab.length, stride;
    if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
        stride = MIN_TRANSFER_STRIDE; // subdivide range
    if (nextTab == null) {            // initiating
        try {
            @SuppressWarnings("unchecked")
            Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
            nextTab = nt;
        } catch (Throwable ex) {      // try to cope with OOME
            sizeCtl = Integer.MAX_VALUE;
            return;
        }
        nextTable = nextTab;
        transferIndex = n;
    }
}
第二个部分,数据迁移区间计算

这部分代码,通过while(advance)循环计算每个线程需要进行数据迁移的数组区间。笔者在前面提到过,如果根据数组长度计算出来的每个CPU处理的区间数小于16的情况下,会设置默认的区间是16,假设数组长度是64,但是只有两个线程在并行做数据迁移时,那这两个线程就需要执行多次区间迁移。

private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
    //省略部分代码....
    int nextn = nextTab.length;
    ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
    boolean advance = true;
    boolean finishing = false; // to ensure sweep before committing nextTab
    for (int i = 0, bound = 0;;) {
        Node<K,V> f; int fh;
        while (advance) {
            int nextIndex, nextBound;
            if (--i >= bound || finishing)
                advance = false;
            else if ((nextIndex = transferIndex) <= 0) {
                i = -1;
                advance = false;
            }
            else if (U.compareAndSwapInt
                     (this, TRANSFERINDEX, nextIndex,
                      nextBound = (nextIndex > stride ?
                                   nextIndex - stride : 0))) {
                bound = nextBound;
                i = nextIndex - 1;
                advance = false;
            }
        }
        //省略部分代码....
    }
}

上面这段代码,有一些关键的东西需要简单分析一下。

  • ForwardingNode这个表示一个正在被迁移的Node,当原数组中位置x节点的数据完成迁移后,会对x位置设置一个ForwardingNode表示该位置已经处理过了。
  • advance字段是用来判断是否还有待处理的数据迁移工作。
  • while循环中的方法就是用来计算区间,假设当前数组长度是32位,需要扩容到64位,此时transferIndex=32nextn=64 n=32

    • 第一次循环,i=0nextIndex=32。进入到U.CompareAndSwapInt方法,修改transferIndex的值,如果transferIndex==nextIndex, 则把transferIndex修改为16nextBound=16. 此时bound=16. i=31,当前线程负责迁移的数组区间为[16,31]
    • 第二次循环,--i=30nextIndex=16transferIndex=16,进入到U.compareAndSwapIndex,修改transferIndex的值为0nextBound=0bound=0i=nextIndex-1=15,当前线程负责迁移的数组区间为[0,15]。

    每次循环,都是通过if (--i >= bound || finishing)来判断数组区间是否分配完成,也就是说,数组从高往低进行迁移,比如第一次循环,处理的区间是[16,31], 那么就会从31位开始往前进行遍历,对每个链表进行数据转移。

第三个部分,更新扩容标记

这部分主要是判断逻辑,有两个点。

  • 如果i所在位置的Node为空,说明当前没有数据,不需要迁移,直接通过casTabAt修改成fwd占位即可。
  • 如果i位置所在的Node数据的hash值为MOVED,说明当前节点已经被迁移过了,继续往下遍历。
private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
    //省略部分代码....
    else if ((f = tabAt(tab, i)) == null)
        advance = casTabAt(tab, i, null, fwd);
    else if ((fh = f.hash) == MOVED)
        advance = true; // already processed
    //省略部分代码....
}
第四个部分,开始数据迁移和扩容

这部分内容就是真正实现数据迁移的逻辑,代码比较长,从大的层面来说就两块。

  • 首先对当前要迁移的节点f增加同步锁synchronized,避免多线程竞争。
  • fh>=0表示f节点为链表或者普通节点,则按照链表或者普通节点的方式来进行数据迁移。
  • f instanceof TreeBin表示f节点为红黑树,按照红黑树的规则进行数据迁移,这里需要注意的是,数据迁移之后可能会存在红黑树转化成链表的情况,就是当链表长度小于等于6的时候,就会转化为链表。
private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
    //省略部分代码....
    for (int i = 0, bound = 0;;) {
        //省略部分代码....
        synchronized (f) {
            if (tabAt(tab, i) == f) {
                Node<K,V> ln, hn;
                if (fh >= 0) {
                    int runBit = fh & n;
                    Node<K,V> lastRun = f;
                    for (Node<K,V> p = f.next; p != null; p = p.next) {
                        int b = p.hash & n;
                        if (b != runBit) {
                            runBit = b;
                            lastRun = p;
                        }
                    }
                    if (runBit == 0) {
                        ln = lastRun;
                        hn = null;
                    }
                    else {
                        hn = lastRun;
                        ln = null;
                    }
                    for (Node<K,V> p = f; p != lastRun; p = p.next) {
                        int ph = p.hash; K pk = p.key; V pv = p.val;
                        if ((ph & n) == 0)
                            ln = new Node<K,V>(ph, pk, pv, ln);
                        else
                            hn = new Node<K,V>(ph, pk, pv, hn);
                    }
                    setTabAt(nextTab, i, ln);
                    setTabAt(nextTab, i + n, hn);
                    setTabAt(tab, i, fwd);
                    advance = true;
                }
                else if (f instanceof TreeBin) {
                    //如果当前节点是红黑树,则按照红黑树的处理逻辑进行迁移。
                }
            }
        }
    }
}

上述代码其实也包含一个比较有意思的设计,就是用到了高低位整体迁移的方式,来提升迁移效率,在分析上述代码之前,先来了解一下什么是高低位迁移。

假设存在这样一个数据存储的结构,如下图所示,在数组下标为4的位置,存在一条由链表组成的节点,其中节点上这些数字表示的是key对应的hash码。

在这里插入图片描述

上述这些hash值4、20、52、68、84、100,他们是怎么计算并且放在数组下标4的位置呢?我们回到putVal方法上可以看到,当前key是通过这个方法tabAt(tab, i = (n - 1) & hash)去数组中查找的,关键的逻辑是(n-1)&hash

final V putVal(K key, V value, boolean onlyIfAbsent) {
    //省略代码....
    int hash = spread(key.hashCode());
    int binCount = 0;
    for (Node<K,V>[] tab = table;;) {
        Node<K,V> f; int n, i, fh;
        if (tab == null || (n = tab.length) == 0)
            tab = initTable();
        else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) { //查找逻辑
            if (casTabAt(tab, i, null,
                         new Node<K,V>(hash, key, value, null)))
                break;                   // no lock when adding to empty bin
        }
        //省略代码....
    }
    //省略代码....
}

我们仔细观察(n-1)&hash这个逻辑,它有一个动态变化的因素n(数组长度),也就是说,随着n的值的变化,原本存储在数组下标4位置的key,在扩容之后计算的下标位置也会变化。

举例来说,在如上图所示的链表中,4、20、52、68、84、100这些hash值,在数组长度为16位的情况下,通过(n-1)&hash得到的下标位置都是4。但是当数组长度扩容到32位时,再通过(n-1)&hash来计算,发现20、52、84这三个hash值对应的下标位置都变成了20,其他值4、68、100计算得到的数组下标位置仍然是4。这就意味着,由一个链表组成的节点中,有可能存在一部分节点在扩容后不需要迁移,一部分节点在扩容后需要迁移的情况。

因此,所谓的高低位迁移,表示的就是上述这种情况,而所谓的低位就是指不需要迁移的元素、高位是表示需要迁移的元素。

继续回到transfer代码的高低位迁移逻辑中来,这里有一个比较有意思的设计,就是通过一定的规则计算出两条链ln(低位链)hn(高位链),然后把这两条链表一次性迁移到新的数组中,这样的方式减少了数据迁移次数。

if (tabAt(tab, i) == f) {
    Node<K,V> ln, hn;
    if (fh >= 0) {
        int runBit = fh & n;
        Node<K,V> lastRun = f;
        for (Node<K,V> p = f.next; p != null; p = p.next) {
            int b = p.hash & n;
            if (b != runBit) {
                runBit = b;
                lastRun = p;
            }
        }
        if (runBit == 0) {
            ln = lastRun;
            hn = null;
        }
        else {
            hn = lastRun;
            ln = null;
        }
        for (Node<K,V> p = f; p != lastRun; p = p.next) {
            int ph = p.hash; K pk = p.key; V pv = p.val;
            if ((ph & n) == 0)
                ln = new Node<K,V>(ph, pk, pv, ln);
            else
                hn = new Node<K,V>(ph, pk, pv, hn);
        }
        setTabAt(nextTab, i, ln);
        setTabAt(nextTab, i + n, hn);
        setTabAt(tab, i, fwd);
        advance = true;
    }

上述代码中,主要分析一下高低位链路的计算方法:

  • 通过for循环遍历当前节点链表计算出当前链表最后一个需要迁移或者不需要迁移的节点位置。遍历每一个节点通过p.hash&n计算一个值,这个值有两个结果,一个是等于0,表示需要迁移的数据,一个是大于0,表示不需要迁移的数据。

    for (Node<K,V> p = f.next; p != null; p = p.next) {
        int b = p.hash & n;
        if (b != runBit) {
            runBit = b;
            lastRun = p;
        }
    }

    为了更好的帮助大家理解,我把前面的那个链表通过上面的计算用图形的方式表达如下,runBit针对头部节点计算得到的值是0,根据不断循环计算最终找到最后高位或者低位的位置所在的节点是100

    需要注意,这里说的最后一位,不是指真正意义上的最后一位,而是指节点中后续不存在高低位变化的节点的最早一个节点。假设在下图中100这个节点后面还存在runBit=0的节点,此时返回的lastRun仍然是100对应的节点。之所以这么设计是因为后续如果不存在需要迁移的节点时,那么它本身就是一个链,不需要再次遍历处理,减少遍历次数。

    在这里插入图片描述

  • 通过runBit进行判断,当前链表中最后一个节点是属于高位还是低位,如果runBit==0表示低位,则把lastRun赋值给ln低位链。否则,赋值给hn高位链。

    if (runBit == 0) {
        ln = lastRun;
        hn = null;
    }
    else {
        hn = lastRun;
        ln = null;
    }

    此时, ln=lastRun=hash值100对应的节点,hn=null。

  • 再一次遍历整个链表,把原本的链表构建出高低链。

    for (Node<K,V> p = f; p != lastRun; p = p.next) {
        int ph = p.hash; K pk = p.key; V pv = p.val;
        if ((ph & n) == 0)
            ln = new Node<K,V>(ph, pk, pv, ln);
        else
            hn = new Node<K,V>(ph, pk, pv, hn);
    }

    通过上述代码执行之后,高低位拆分情况如下图所示。

    在这里插入图片描述

  • 最后,把低位链设置到扩容后的数组i位置,高位链设置到i+n的位置。

    setTabAt(nextTab, i, ln);
    setTabAt(nextTab, i + n, hn);
    setTabAt(tab, i, fwd);
    advance = true;

至此,就完成了扩容以及基于链表结构下的数据迁移工作,整体原理如下图所示。

在这里插入图片描述

第五部分,完成迁移后的判断

在transfer方法中还有一部分代码,就是用来判断是否完成扩容,以及扩容完成之后的后置处理,代码如下。

if (i < 0 || i >= n || i + n >= nextn) {
    int sc;
    if (finishing) {
        nextTable = null;
        table = nextTab;
        sizeCtl = (n << 1) - (n >>> 1);
        return;
    }
    if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
        if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
            return;
        finishing = advance = true;
        i = n; // recheck before commit
    }
}

这部分代码有两个逻辑。

  • 如果数据迁移工作完成,则把扩容后的数组赋值给table
  • 如果还未完成,说明还有其他线程正在执行中,所以当前线程通过U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)修改并发扩容的线程数量(这部分代码在前面章节中分析过了,sizeCtl低16位会记录并发扩容线程数量),如果(sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT满足,说明没有线程在协助扩容,也就是说扩容结束了。

跟着Mic学架构
810 声望1.1k 粉丝

《Spring Cloud Alibaba 微服务原理与实战》、《Java并发编程深度理解及实战》作者。 咕泡教育联合创始人,12年开发架构经验,对分布式微服务、高并发领域有非常丰富的实战经验。