Data Structure

In JDK 1.7, ConcurrentHashMap is built on arrays plus linked lists. It holds a Segment array, and each Segment in turn contains its own array-plus-linked-list structure (essentially a small HashMap); the array slots and the linked lists store HashEntry objects.

    static final class Segment<K,V> extends ReentrantLock implements Serializable {
        private static final long serialVersionUID = 2249069246763182397L;
        //maximum number of tryLock() spins before blocking: 64 on multi-core machines, 1 on single-core
        static final int MAX_SCAN_RETRIES =
            Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;
        //this Segment's bucket array; volatile so readers always see the latest table
        transient volatile HashEntry<K,V>[] table;
        //number of key-value pairs in this Segment
        transient int count;
        //number of structural modifications, used by size() and containsValue()
        transient int modCount;
        //resize when count exceeds this value (capacity * loadFactor)
        transient int threshold;
        final float loadFactor;
    }

    static final class HashEntry<K,V> {
        final int hash;
        final K key;
        //value and next are volatile so that get() can traverse buckets without locking
        volatile V value;
        volatile HashEntry<K,V> next;
    }

Common Methods
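
This article focuses on put(), get(), size() and the internal rehash(). Beyond those, the operations most often used in practice are the regular Map methods plus the atomic operations ConcurrentHashMap inherits from ConcurrentMap: putIfAbsent(key, value), remove(key, value), replace(key, oldValue, newValue), replace(key, value), containsKey(key) and isEmpty().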

Usage
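
A minimal, runnable example of the operations listed above (the class name and the values here are illustrative, not taken from the JDK source analyzed below):

    import java.util.concurrent.ConcurrentHashMap;

    public class ConcurrentHashMapDemo {
        public static void main(String[] args) {
            //explicit defaults: initialCapacity 16, loadFactor 0.75, concurrencyLevel 16
            ConcurrentHashMap<String, Integer> map = new ConcurrentHashMap<>(16, 0.75f, 16);

            map.put("a", 1);                  //plain insert
            map.putIfAbsent("a", 2);          //no effect: "a" is already present
            map.replace("a", 1, 3);           //atomic compare-and-replace: 1 -> 3
            System.out.println(map.get("a")); //3, read without locking

            map.remove("a", 3);               //atomic conditional remove
            System.out.println(map.size());   //0
        }
    }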

Source Code Analysis

Main Fields

    //Default total capacity, i.e. the combined capacity of the per-Segment HashEntry arrays;
    //at construction time it is divided evenly among the Segments
    static final int DEFAULT_INITIAL_CAPACITY = 16;
    //Default load factor
    static final float DEFAULT_LOAD_FACTOR = 0.75f;
    //Default concurrency level, which determines the length of the Segment array
    static final int DEFAULT_CONCURRENCY_LEVEL = 16;
    //Maximum capacity
    static final int MAXIMUM_CAPACITY = 1 << 30;
    //Minimum capacity of the HashEntry array inside each Segment
    static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
    //Maximum number of Segments = 65536
    static final int MAX_SEGMENTS = 1 << 16;
    //Number of unsynchronized retries in size() and containsValue() before resorting to locking
    static final int RETRIES_BEFORE_LOCK = 2;

Constructors

    public ConcurrentHashMap(int initialCapacity, float loadFactor) {
        this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL);
    }

    public ConcurrentHashMap(int initialCapacity) {
        this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    }

    public ConcurrentHashMap() {
        this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    }

    public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
        this(Math.max((int) (m.size() / DEFAULT_LOAD_FACTOR) + 1,
                      DEFAULT_INITIAL_CAPACITY),
             DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
        putAll(m);
    }

    @SuppressWarnings("unchecked")
    public ConcurrentHashMap(int initialCapacity,
                             float loadFactor, int concurrencyLevel) {
        if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
        // Find power-of-two sizes best matching arguments
        int sshift = 0;
        int ssize = 1;
        //ssize is the length of the Segment array; with the default concurrencyLevel of 16, ssize = 16
        while (ssize < concurrencyLevel) {
            ++sshift;
            ssize <<= 1; //double ssize
        }
        this.segmentShift = 32 - sshift;
        this.segmentMask = ssize - 1;
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        int c = initialCapacity / ssize;
        if (c * ssize < initialCapacity)
            ++c;
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        //cap is the length of each Segment's HashEntry array; with the default arguments cap stays at 2
        while (cap < c)
            cap <<= 1;
        //Allocate the Segment array (length ssize, 16 by default) and eagerly fill only segments[0].
        //With the defaults each Segment's table has length 2 and load factor 0.75, so its resize
        //threshold is (int)(2 * 0.75) = 1 and the Segment is resized when its second entry is inserted
        Segment<K,V> s0 =
            new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                             (HashEntry<K,V>[])new HashEntry[cap]);
        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
        UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
        this.segments = ss;
    }
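
As a sanity check on the sizing arithmetic above, the following standalone sketch (not part of the JDK source; the class name is made up for illustration) replays the constructor's computation for the default arguments:

    //Standalone sketch (not JDK code): replay the constructor's sizing logic for
    //initialCapacity = 16, loadFactor = 0.75f, concurrencyLevel = 16.
    public class SegmentSizingSketch {
        public static void main(String[] args) {
            int initialCapacity = 16, concurrencyLevel = 16;
            float loadFactor = 0.75f;

            int sshift = 0, ssize = 1;
            while (ssize < concurrencyLevel) { ++sshift; ssize <<= 1; }
            int segmentShift = 32 - sshift;   //28
            int segmentMask  = ssize - 1;     //15

            int c = initialCapacity / ssize;  //1
            if (c * ssize < initialCapacity) ++c;
            int cap = 2;                      //MIN_SEGMENT_TABLE_CAPACITY
            while (cap < c) cap <<= 1;        //stays 2 for the defaults

            System.out.printf("segments=%d, segmentShift=%d, segmentMask=%d, cap=%d, threshold=%d%n",
                    ssize, segmentShift, segmentMask, cap, (int) (cap * loadFactor));
            //prints: segments=16, segmentShift=28, segmentMask=15, cap=2, threshold=1
        }
    }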

put() Method

    public V put(K key, V value) {
        Segment<K,V> s;
        if (value == null)
            throw new NullPointerException();
        //1. compute the hash of the key via hash()
        //2. derive the Segment array index from the hash
        int hash = hash(key);
        int j = (hash >>> segmentShift) & segmentMask;
        if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
             (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
            //3. if segments[j] == null, initialize it
            s = ensureSegment(j);
        //4. delegate the insertion to segments[j]
        return s.put(key, hash, value, false);
    }
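
Note that the Segment index uses the high-order bits of the hash ((hash >>> segmentShift) & segmentMask, i.e. the top 4 bits with the defaults), while the bucket index inside the Segment, computed in the inner put() below, uses the low-order bits ((tab.length - 1) & hash), so the two levels of indexing do not reuse the same bits.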


        final V put(K key, int hash, V value, boolean onlyIfAbsent) {
            //try to acquire the lock with tryLock(); on success node stays null, otherwise spin in scanAndLockForPut()
            HashEntry<K,V> node = tryLock() ? null :
                scanAndLockForPut(key, hash, value);
            V oldValue;
            try {

                //1. read this Segment's HashEntry array
                //2. compute the bucket index from the hash and read the head node of that bucket
                HashEntry<K,V>[] tab = table;
                int index = (tab.length - 1) & hash;
                HashEntry<K,V> first = entryAt(tab, index);
                //traverse the linked list headed by that node
                for (HashEntry<K,V> e = first;;) {
                    //if the node exists, walk the list; if the key is already present, replace the old value with the new one (unless onlyIfAbsent)
                    if (e != null) {
                        K k;
                        if ((k = e.key) == key ||
                            (e.hash == hash && key.equals(k))) {
                            oldValue = e.value;
                            if (!onlyIfAbsent) {
                                e.value = value;
                                ++modCount;
                            }
                            break;
                        }
                        e = e.next;
                    }
                    //the bucket is empty or the end of the list has been reached
                    else {
                        //node was pre-created in scanAndLockForPut: link it in at the head of the list
                        if (node != null)
                            node.setNext(first);
                        //otherwise create the node now, with the old head as its successor (head insertion)
                        else
                            node = new HashEntry<K,V>(hash, key, value, first);
                        int c = count + 1;
                        //the new count exceeds the resize threshold and the table can still grow
                        if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                            //resize and add the new node
                            rehash(node);
                        else
                            setEntryAt(tab, index, node);
                        ++modCount;
                        count = c;
                        oldValue = null;
                        break;
                    }
                }
            } finally {
                //release the lock
                unlock();
            }
            return oldValue;
        }
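
Also worth noting: the new head of a bucket is published with setEntryAt() (which in the JDK source performs an ordered write via UNSAFE.putOrderedObject), and get() reads bucket heads with UNSAFE.getObjectVolatile, which together keep the lock-free readers safe while writers hold only the Segment lock.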

        //spin on tryLock(); while spinning, scan the bucket for the key and speculatively create the node if absent; after MAX_SCAN_RETRIES failed attempts, fall back to a blocking lock()
        private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
            HashEntry<K,V> first = entryForHash(this, hash);
            HashEntry<K,V> e = first;
            HashEntry<K,V> node = null;
            int retries = -1; // negative while locating node
            while (!tryLock()) {
                HashEntry<K,V> f; // to recheck first below
                if (retries < 0) {
                    if (e == null) {
                        if (node == null) // speculatively create node
                            node = new HashEntry<K,V>(hash, key, value, null);
                        retries = 0;
                    }
                    else if (key.equals(e.key))
                        retries = 0;
                    else
                        e = e.next;
                }
                else if (++retries > MAX_SCAN_RETRIES) {
                    lock();
                    break;
                }
                else if ((retries & 1) == 0 &&
                         (f = entryForHash(this, hash)) != first) {
                    e = first = f; // re-traverse if entry changed
                    retries = -1;
                }
            }
            return node;
        }
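
MAX_SCAN_RETRIES (shown in the Segment class above) is 64 on multi-core machines and 1 on single-core ones. The scan performed while spinning is essentially a warm-up: it pre-creates the node and touches the bucket that put() will traverse again once the lock is finally held.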

rehash() Method

        //Double the HashEntry array. When an entry moves to the new array its index either stays the
        //same or becomes index + oldCapacity. The trailing run of nodes that already map to the same
        //new bucket (lastRun onward) is reused as-is; the nodes before it are cloned and inserted at
        //the head of their new buckets
        private void rehash(HashEntry<K,V> node) {
            HashEntry<K,V>[] oldTable = table;
            int oldCapacity = oldTable.length;
            int newCapacity = oldCapacity << 1;
            threshold = (int)(newCapacity * loadFactor);
            HashEntry<K,V>[] newTable =
                (HashEntry<K,V>[]) new HashEntry[newCapacity];
            int sizeMask = newCapacity - 1;
            for (int i = 0; i < oldCapacity ; i++) {
                HashEntry<K,V> e = oldTable[i];
                if (e != null) {
                    HashEntry<K,V> next = e.next;
                    int idx = e.hash & sizeMask;
                    if (next == null)   //  Single node on list
                        newTable[idx] = e;
                    else { // Reuse consecutive sequence at same slot
                        HashEntry<K,V> lastRun = e;
                        int lastIdx = idx;
                        for (HashEntry<K,V> last = next;
                             last != null;
                             last = last.next) {
                            int k = last.hash & sizeMask;
                            if (k != lastIdx) {
                                lastIdx = k;
                                lastRun = last;
                            }
                        }
                        newTable[lastIdx] = lastRun;
                        // Clone remaining nodes
                        for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                            V v = p.value;
                            int h = p.hash;
                            int k = h & sizeMask;
                            HashEntry<K,V> n = newTable[k];
                            newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                        }
                    }
                }
            }
            int nodeIndex = node.hash & sizeMask; // add the new node
            node.setNext(newTable[nodeIndex]);
            newTable[nodeIndex] = node;
            table = newTable;
        }
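
The claim that an entry either keeps its index or moves to index + oldCapacity follows from masking with newCapacity - 1 when newCapacity is exactly twice oldCapacity: the new mask has one extra bit. A small standalone check (not JDK code; names and values are illustrative):

    //Standalone check (not JDK code): with newCapacity = oldCapacity * 2, the new
    //bucket index differs from the old one only by the single extra mask bit.
    public class RehashIndexCheck {
        public static void main(String[] args) {
            int oldCapacity = 8, newCapacity = oldCapacity << 1;
            for (int hash : new int[]{3, 7, 11, 15}) {
                int oldIdx = hash & (oldCapacity - 1);
                int newIdx = hash & (newCapacity - 1);
                //newIdx == oldIdx when the oldCapacity bit of the hash is 0,
                //otherwise newIdx == oldIdx + oldCapacity
                System.out.printf("hash=%d: %d -> %d%n", hash, oldIdx, newIdx);
            }
        }
    }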

get() Method

    //Because HashEntry.value is declared volatile, memory visibility is guaranteed and every read returns the most recent value. get() is very efficient: the whole lookup is performed without taking a lock.
    public V get(Object key) {
        Segment<K,V> s; // manually integrate access methods to reduce overhead
        HashEntry<K,V>[] tab;
        //1. compute the hash of the key via hash()
        int h = hash(key);
        //2. derive the Segment array index from the hash and read that Segment with a volatile read
        long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
        if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
            (tab = s.table) != null) {
            //3. compute the HashEntry array index from the hash and traverse that bucket
            for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                     (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
                 e != null; e = e.next) {
                K k;
                //4. if the key matches, return its value
                if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                    return e.value;
            }
        }
        return null;
    }
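
Two details are worth calling out. The Segment and the bucket head are both read with UNSAFE.getObjectVolatile, i.e. as volatile reads, so a reader never sees a stale Segment or table reference. And because put() rejects null values (it throws NullPointerException), a null return from get() unambiguously means the key is not present.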

size() Method

    //Sum the per-Segment counts without locking, comparing the total modCount across consecutive passes; if two passes agree, the result is returned. After RETRIES_BEFORE_LOCK unsuccessful retries, every Segment is locked and the sum is computed under the locks.
    public int size() {
        // Try a few times to get accurate count. On failure due to
        // continuous async changes in table, resort to locking.
        final Segment<K,V>[] segments = this.segments;
        int size;
        boolean overflow; // true if size overflows 32 bits
        long sum;         // sum of modCounts
        long last = 0L;   // previous sum
        int retries = -1; // first iteration isn't retry
        try {
            for (;;) {
                if (retries++ == RETRIES_BEFORE_LOCK) {
                    for (int j = 0; j < segments.length; ++j)
                        ensureSegment(j).lock(); // force creation
                }
                sum = 0L;
                size = 0;
                overflow = false;
                for (int j = 0; j < segments.length; ++j) {
                    Segment<K,V> seg = segmentAt(segments, j);
                    if (seg != null) {
                        sum += seg.modCount;
                        int c = seg.count;
                        if (c < 0 || (size += c) < 0)
                            overflow = true;
                    }
                }
                if (sum == last)
                    break;
                last = sum;
            }
        } finally {
            if (retries > RETRIES_BEFORE_LOCK) {
                for (int j = 0; j < segments.length; ++j)
                    segmentAt(segments, j).unlock();
            }
        }
        return overflow ? Integer.MAX_VALUE : size;
    }

Summary

1. In JDK 1.7, ConcurrentHashMap is built on arrays plus linked lists: it holds a Segment array, and each Segment in turn contains its own array-plus-linked-list structure (essentially a small HashMap) whose slots and lists store HashEntry objects.
2. Segment extends ReentrantLock, so in theory ConcurrentHashMap supports as many concurrent writers as there are Segments (the concurrencyLevel). A thread holding the lock of one Segment does not block threads working on other Segments.
3. When adding a key-value pair, the key's hash is computed and used to pick a Segment; the thread calls tryLock() on that Segment and, if that fails, spins in scanAndLockForPut() until the lock is acquired. The rest of the insertion follows the same logic as HashMap.
4. Because HashEntry.value is declared volatile, memory visibility is guaranteed and reads always see the latest value. ConcurrentHashMap's get() is therefore very efficient, since the whole lookup needs no locking.

