JDK7的ConcurrentHashMap源码分析总结

JDK7的ConcurrentHashMap源码分析总结

@author:Jingdai
@date:2021.07.15

花了一天多时间把JDK7的ConcurrentHashMap源码研究了一下,现总结如下。由于水平有限,不免有错误,欢迎大佬指正。

整体思路

ConcurrentHashMap名字太长,后文用CHM代替。不同于HashTable用一把锁加锁所有的桶,CHM使用分段锁,每次加锁只锁整个Map的一部分,这样就大大提高了并发量,当不同的线程想同时修改CHM的不同部分时,不会阻塞。CHM中用Segment表示每个加锁单元,可以把这个Segment当成HashTable,即可以简单的理解为CHM等于多个HashTable,当然CHM做了很多优化。

在初始化CHM的时候,会指定Segment的数量,这个根据传入的 concurrencyLevel 计算得出(后面会讲),在算出这个Segment的数量后,Segment的数量就不会再发生变化了,随着元素的增多,需要rehash的时候也仅仅会在每个Segment中扩容,Segment并不会增多。

在找每个key对应的桶的时候,计算出hash后,利用hash的高位来计算这个key在哪一个Segment,再利用它的低位来计算它在Segment的哪一个桶中。

源码中利用了很多UNSAFE类的操作,这个我也不是很清楚,但是大概意思就是直接从主内存中去取最新的值,而不会从线程的工作线程中取,想要了解细节的可以看看别的博客。(狗头)CHM在读的时候不会加锁,直接利用UNSAFE去主内存中取值,而CHM在写的时候就是加锁去写。

主要类结构

CHM 内部有两个重要的类,Segment 和 HashEntry,下面分别来看一下。

ConcurrentHashMap 类

public class ConcurrentHashMap<K, V> extends AbstractMap<K, V>
        implements ConcurrentMap<K, V>, Serializable {
    
    // 默认初始容量是16, 和 HashMap 一样
    static final int DEFAULT_INITIAL_CAPACITY = 16;
    
    // 默认加载因子是 0.75, 和 HashMap 一样
    static final float DEFAULT_LOAD_FACTOR = 0.75f;
    
    // 默认并发级别时16,这个并发级别决定了Segment数组的长度
    static final int DEFAULT_CONCURRENCY_LEVEL = 16;
    
    // 每个Segment的table的最小容量
    static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
    
    // 用于确定哪一个Segment的掩码值,hash的高几位用于选择Segment
    final int segmentMask;
    
    // 确定哪一个Segment的时候 hash偏移的位数
    final int segmentShift;
    
    // 最重要的属性,每个Segment可以看出一个Hash表,键值对是存在Segment中的
    final Segment<K,V>[] segments;
    
    // XXX 
}

上面仅仅列出了看源码比较重要的属性,其他属性没有列出。

Segment 类

static final class Segment<K,V> extends ReentrantLock implements Serializable {
    
    // 真正存储数据的数组,每个Segment独自有一个 table 存储键值对
    transient volatile HashEntry<K,V>[] table;
    
    // table 中元素的个数
    transient int count;
    
    // table 修改的次数
    transient int modCount;
    
    // 扩容阈值,它等于 capacity * loadFactor
    // capacity 就是数组 table 的长度
    transient int threshold;
    
    // 负载因子,每个Segment都相同
    final float loadFactor;
    
    // xxx
}   

从上面可以看出,Segment内部自己维护一个Hash表,它有自己的扩容阈值和负载因子,负载因子每个Segment 都相同,之后扩容的时候是每个Segment自己扩容,不会影响到 CHM 其他的 Segment。同时,可以看出Segment 是继承自 ReentrantLock接口的,所以相当于每个Segment自己有一把锁,想要对Segment进行修改的时候需要先得到这个锁。

HashEntry 类

static final class HashEntry<K,V> {
    final int hash;
    final K key;
    volatile V value;
    volatile HashEntry<K,V> next;

    // xxx
}

这个类就没什么要说的了,就是真正的 Entry对象,和HashMap没什么区别。

细节

  • segmentShift 和 segmentMask 作用?

    看构造函数。

    @SuppressWarnings("unchecked")
    public ConcurrentHashMap(int initialCapacity,
                             float loadFactor, int concurrencyLevel) {
    
        // xxx
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
        // Find power-of-two sizes best matching arguments
        int sshift = 0;
        int ssize = 1;
        while (ssize < concurrencyLevel) {
            ++sshift;
            ssize <<= 1;
        }
        this.segmentShift = 32 - sshift;
        this.segmentMask = ssize - 1;
        // xxx
        
        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
       // xxx
    }
    

    为了看得更清楚,删除其他代码,这个 concurrencyLevel 是用来计算 segment 数组的长度ssize,因为求下标用的是 & 操作,所以数组长度必须是2的次幂,所以不能直接用 concurrencyLevel 来当数组长度,得到ssize 是一个大于 concurrencyLevel 的最小的2的次幂。这个 segmentMask 就简单一点了,就是ssize - 1,这样方便后面求元素在哪个segment中,即 x % 2n = x & (2n-1) = x & segmentMask,这点和HashMap思路一致。

    这个 segmentShift 就麻烦一点了,它等于 32 - sshift。假设现在算出来 ssize 是16,那么sshift 就是 4,那么segmentShift 就是28,segmentMask 就是就是 15,我们知道 int一共有32位,假设如下hash。

    hash:10101101  10101101 10101101 10101101
    

    我们要取高 sshift (4)位的值的话就是正好将 hash 右移segmentShift (28) 位再和 segmentMask 按位&。即

    高 sshift 位 % ssize = 高 sshift 位 & segmentMask = (hash >>> segmentShift) & segmentMask,后面会知道这个就是用来求一个 key 是属于哪一个segment的,即用 hash 的高 sshift 位来求 key 在哪一个segment中。

  • map 加入元素时放哪一个 segment?又在 segment 的 table 的哪一个index?

    @SuppressWarnings("unchecked")
    public V put(K key, V value) {
        Segment<K,V> s;
        if (value == null)
            throw new NullPointerException();
        int hash = hash(key);
        int j = (hash >>> segmentShift) & segmentMask;
        if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
             (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
            s = ensureSegment(j);
        return s.put(key, hash, value, false);
    }
    

    由前面讲的 int j = (hash >>> segmentShift) & segmentMask; 就是取高 sshift 位 % ssize,即由高sshift 位取余来确定在哪一个segment中。其中 UNSAFE.getObject(segments, (j << SSHIFT) + SBASE) 是去取segments的第 j 个元素。

    上面又会去调用 segment 的put 方法,如下。

    final V put(K key, int hash, V value, boolean onlyIfAbsent) {
        HashEntry<K,V> node = tryLock() ? null :
        scanAndLockForPut(key, hash, value);
        V oldValue;
        try {
            HashEntry<K,V>[] tab = table;
            int index = (tab.length - 1) & hash;
            // xxx
        } finally {
            unlock();
        }
        return oldValue;
    }
    

    删去了其它代码,可以看出 index 是通过 hash 对 table的length取余得到,这里并没有移位,所以是用hash的低位得到 key 在 table中的下标。

    即通过hash的高 sshift 位取余来确定在哪一个segment中,低位取余得到在segment 的 table的哪个下标中。

  • 初始Segment中的table的大小?

    看构造函数,去掉其它多余的代码。

    @SuppressWarnings("unchecked")
    public ConcurrentHashMap(int initialCapacity,
                             float loadFactor, int concurrencyLevel) {
        // xxx
        
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        int c = initialCapacity / ssize;
        // 向上取整
        if (c * ssize < initialCapacity)
            ++c;
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        while (cap < c)
            cap <<= 1;
        // create segments and segments[0]
        Segment<K,V> s0 =
            new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                             (HashEntry<K,V>[])new HashEntry[cap]);
        
        // xxx
    }
    

    前面已经分析过了,ssize就是segments数组的长度,就是有几个segment。而cap是初始的Segment中table的大小,现在就是求cap怎么得到。首先用 initialCapacity / ssize 得到每个segment中分到的元素个数 c,但是Java中整数除法都是截断后面的小数,所以如果截断了需要向上取整。然后找出大于 c 的最小的2的次幂就得到cap,即 table的大小。

  • 构造函数中初始化第0个segment的原因?

    把这个segment作为prototype,以后再创建segment时按这个segment的属性进行创建。代码如下。

    构造函数

    @SuppressWarnings("unchecked")
    public ConcurrentHashMap(int initialCapacity,
                             float loadFactor, int concurrencyLevel) {
       
        // xxx
        // create segments and segments[0]
        // 仅仅初始化了segments[0]
        Segment<K,V> s0 =
            new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                             (HashEntry<K,V>[])new HashEntry[cap]);
        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
        UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
        this.segments = ss;
    }
    

    ensureSegment函数

    @SuppressWarnings("unchecked")
    private Segment<K,V> ensureSegment(int k) {
        final Segment<K,V>[] ss = this.segments;
        long u = (k << SSHIFT) + SBASE; // raw offset
        Segment<K,V> seg;
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
            // 将 segment 0 作为模板来创建其他的 segment
            Segment<K,V> proto = ss[0]; // use segment 0 as prototype
            int cap = proto.table.length;
            float lf = proto.loadFactor;
            int threshold = (int)(cap * lf);
            HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
            if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                == null) { // recheck
                Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
                while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                       == null) {
                    if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                        break;
                }
            }
        }
        return seg;
    }
    
  • 如何线程安全的创建segment?

    源码中使用CAS来保证安全的创建segment。使用的是 ensureSegment 函数。

    ensureSegment 函数

    @SuppressWarnings("unchecked")
    private Segment<K,V> ensureSegment(int k) {
        final Segment<K,V>[] ss = this.segments;
        long u = (k << SSHIFT) + SBASE; // raw offset
        Segment<K,V> seg;
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
            Segment<K,V> proto = ss[0]; // use segment 0 as prototype
            int cap = proto.table.length;
            float lf = proto.loadFactor;
            int threshold = (int)(cap * lf);
            HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
            if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                == null) { // recheck
                Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
                // 不断重试
                while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                       == null) {
                    // 利用CAS创建
                    if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                        break;
                }
            }
        }
        return seg;
    }
    

    如上面代码,首先判断 k 位置的 segment 是否还是null,如果还是 null就根据第0个segment创建一个新的HashEntry数组,否则直接返回(说明其他线程已经创建了segment)。接着在此判断是否为 null,如果还是的话才真的创建一个segment,并利用自旋加CAS重试的方式设置这个 segment。

  • 如何线程安全的加入元素?

    源码通过加锁的方式线程安全的put元素。前面讲了通过CAS安全的创建segment,而 map 的 put 操作在得到对应的 segment 后,会调用 segment 的 put 方法去加入元素。加入时使用 ReentrantLock 加锁的方式保证线程安全。前面已经说了 Segment 继承了 ReentrantLock ,所以可以直接加锁。代码如下。

    final V put(K key, int hash, V value, boolean onlyIfAbsent) {
        // 尝试加锁,加锁成功直接 node 为 null
        // 否则调用 scanAndLockForPut 方法(后面会讲)
        // 这里知道要不加锁成功,要不调用 scanAndLockForPut
        // 方法使得加锁成功,所以执行完第一句一定会加锁成功
        HashEntry<K,V> node = tryLock() ? null :
        scanAndLockForPut(key, hash, value);
        V oldValue;
        try {
            HashEntry<K,V>[] tab = table;
            // 求下标
            int index = (tab.length - 1) & hash;
            HashEntry<K,V> first = entryAt(tab, index);
            for (HashEntry<K,V> e = first;;) {
                if (e != null) {
                    K k;
                    // 判断是否已经有了这个 key
                    if ((k = e.key) == key ||
                        (e.hash == hash && key.equals(k))) {
                        oldValue = e.value;
                        if (!onlyIfAbsent) {
                            e.value = value;
                            ++modCount;
                        }
                        break;
                    }
                    e = e.next;
                }
                else {
                    // 没有这个key,进行添加操作(头插法)
                    if (node != null)
                        node.setNext(first);
                    else
                        node = new HashEntry<K,V>(hash, key, value, first);
                    int c = count + 1;
                    if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                        rehash(node);
                    else
                        setEntryAt(tab, index, node);
                    ++modCount;
                    count = c;
                    oldValue = null;
                    break;
                }
            }
        } finally {
            unlock();
        }
        return oldValue;
    }
    

    除了最开始的加锁操作,后面的基本和 JDK7 的 HashMap的put操作思路一致,都是头插法。下面看这个加锁的函数。

    Segment 的 scanAndLockForPut 方法

    private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
        HashEntry<K,V> first = entryForHash(this, hash);
        HashEntry<K,V> e = first;
        HashEntry<K,V> node = null;
        int retries = -1; // negative while locating node
        while (!tryLock()) {
            HashEntry<K,V> f; // to recheck first below
            // 首先retries = -1,必须要等遍历一遍这个链表,发现没有
            // 这个key,则创建这个node,或者遍历过程中发现有这个key,不需
            // 要创建这个node时,才将retries改为0,进到其他分支
            if (retries < 0) {
                if (e == null) {
                    if (node == null) // speculatively create node
                        node = new HashEntry<K,V>(hash, key, value, null);
                    retries = 0;
                }
                else if (key.equals(e.key))
                    retries = 0;
                else
                    e = e.next;
            }
            // 限制重试次数,当重试次数超了时阻塞获得锁
            else if (++retries > MAX_SCAN_RETRIES) {
                lock();
                break;
            }
            // 偶数时判断头结点是否变化,变化时重新赋值
            else if ((retries & 1) == 0 &&
                     (f = entryForHash(this, hash)) != first) {
                e = first = f; // re-traverse if entry changed
                retries = -1;
            }
        }
        return node;
    }
    

    其中 MAX_SCAN_RETRIES 值为。

    static final int MAX_SCAN_RETRIES =
        Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;
    

    根据注释,代码的目的就是为了获得锁,首先在尝试获得锁时,先判断这个链上有没有这个key,判断完之后再进行重试次数的累加,累加次数到了指定值后就会阻塞获得锁,同时在过程中如果头节点发生了变化会重新进行赋值,然后重试。在判断这个链上有没有这个key的同时,如果发现这个链上没有这个 key,会直接 new 出这个 node 供调用的 put 方法使用,相当于在尝试获取锁的时候做了一些其它事情,没有白白浪费CPU。

  • 如何安全的扩容?

    首先注意扩容不是 CHM 的 Segment[] 数组的扩容,Segment[] 数组的大小在初始化之后就不会再发生变化了,扩容是 Segment 内部的 HashEntry<K,V>[] table 的扩容。扩容的阈值和加载因子都和HashMap一致。由于 rehash 方法仅仅在 put 方法内部被调用,所以 rehash 执行的时候一定拥有 Segment 的锁。具体代码如下。

    Segment的 rehash 方法

    @SuppressWarnings("unchecked")
    private void rehash(HashEntry<K,V> node) {
    
        HashEntry<K,V>[] oldTable = table;
        int oldCapacity = oldTable.length;
        int newCapacity = oldCapacity << 1;
        threshold = (int)(newCapacity * loadFactor);
        HashEntry<K,V>[] newTable =
            (HashEntry<K,V>[]) new HashEntry[newCapacity];
        int sizeMask = newCapacity - 1;
        for (int i = 0; i < oldCapacity ; i++) {
            // e是链表头
            HashEntry<K,V> e = oldTable[i];
            if (e != null) {
                HashEntry<K,V> next = e.next;
                int idx = e.hash & sizeMask;
                if (next == null)   //  Single node on list
                    newTable[idx] = e;
                else { // Reuse consecutive sequence at same slot
                    HashEntry<K,V> lastRun = e;
                    int lastIdx = idx;
                    // 找到最后一串移动后下标相同的起始节点
                    // 这些节点不用复制,直接赋值
                    for (HashEntry<K,V> last = next;
                         last != null;
                         last = last.next) {
                        int k = last.hash & sizeMask;
                        if (k != lastIdx) {
                            lastIdx = k;
                            lastRun = last;
                        }
                    }
                    // 将最后那串相同下标节点移到新的桶中
                    newTable[lastIdx] = lastRun;
                    // Clone remaining nodes
                    // 浅克隆最后那串节点之前的每个节点,并头插到新的桶中
                    for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                        V v = p.value;
                        int h = p.hash;
                        int k = h & sizeMask;
                        HashEntry<K,V> n = newTable[k];
                        newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                    }
                }
            }
        }
        // 插入要头插的传入的node节点
        int nodeIndex = node.hash & sizeMask; // add the new node
        node.setNext(newTable[nodeIndex]);
        newTable[nodeIndex] = node;
        // 最后将 newTable 赋值给table
        table = newTable;
    }
    

    看注释,整个过程其实还算比较简单,和 JDK7 的HashMap非常类似。但是要注意,由于HashMap不用考虑线程安全问题,所以HashMap直接用的是原来的节点,不用复制,而 CHM 要考虑到可能还有线程在读(读不用获得锁),不能直接用原来的节点,所以代码中浅克隆了原来的节点并利用头插法插入新的table中。同时,为了减少克隆的节点数,作者做了一点点优化,找出原始链表最后一串会 rehash 到相同坐标的起始节点,直接将这一串复制过去,减少了克隆节点的个数。

  • 如何安全的获取元素?

    CHM 的 get 方法

    public V get(Object key) {
        Segment<K,V> s; // manually integrate access methods to reduce overhead
        HashEntry<K,V>[] tab;
        int h = hash(key);
        // 获取 segment 在segments数组中的下标
        long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
        if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
            (tab = s.table) != null) {
            // 获取 key 在 HashEntry数组 table 的下标,获取其头节点,然后遍历整个链表 
            for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                 (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
                 e != null; e = e.next) {
                K k;
                if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                    return e.value;
            }
        }
        return null;
    }
    

    注意这里直接在 CHM 类中获取了 value,不同于 put 还要去 Segment 中的put方法中去put,而 Segment 类中没有 get 方法。这里get方法思路很简单,就是找到对应的segment,然后去segment的数组中的对应下标的链表中查找,只是这里利用的是UNSAFE类直接去获取主内存中的值,而不是从线程的工作内存中获取。

  • 如何获取CHM 的 size?

    CHM 的 size 方法

    public int size() {
        // Try a few times to get accurate count. On failure due to
        // continuous async changes in table, resort to locking.
        final Segment<K,V>[] segments = this.segments;
        int size;
        boolean overflow; // true if size overflows 32 bits
        long sum;         // sum of modCounts
        // 用于记录上一次 modCount 的和
        long last = 0L;   // previous sum
        int retries = -1; // first iteration isn't retry
        try {
            for (;;) {
                if (retries++ == RETRIES_BEFORE_LOCK) {
                    for (int j = 0; j < segments.length; ++j)
                        ensureSegment(j).lock(); // force creation
                }
                sum = 0L;
                size = 0;
                overflow = false;
                for (int j = 0; j < segments.length; ++j) {
                    Segment<K,V> seg = segmentAt(segments, j);
                    if (seg != null) {
                        sum += seg.modCount;
                        int c = seg.count;
                        if (c < 0 || (size += c) < 0)
                            overflow = true;
                    }
                }
                // 这次和上次一样就退出
                if (sum == last)
                    break;
                last = sum;
            }
        } finally {
            if (retries > RETRIES_BEFORE_LOCK) {
                for (int j = 0; j < segments.length; ++j)
                    segmentAt(segments, j).unlock();
            }
        }
        return overflow ? Integer.MAX_VALUE : size;
    }
    

    其中 RETRIES_BEFORE_LOCK 定义:

    static final int RETRIES_BEFORE_LOCK = 2;
    

    整个代码的思路就是记录本次的modCount和,如果和上次的一样就代表得到了准确的size,如果重试两次不成功就直接给所有Segment加锁,然后计算和。同时,加锁的时候如果Segment还没有创建,就会在这里强制创建这个Segment。但是这里应该不是完全的准确的size,如果 last 和 sum 两次一样,得到了size,跳出循环,这时失去了时间片,在结果返回前别的线程又增加了元素,这样size就不对了。

  • CHM 判断是否元素个数为0?

    CHM 的 isEmpty 方法

    public boolean isEmpty() {
    
        long sum = 0L;
        final Segment<K,V>[] segments = this.segments;
        for (int j = 0; j < segments.length; ++j) {
            Segment<K,V> seg = segmentAt(segments, j);
            if (seg != null) {
                if (seg.count != 0)
                    return false;
                sum += seg.modCount;
            }
        }
        if (sum != 0L) { // recheck unless no modifications
            for (int j = 0; j < segments.length; ++j) {
                Segment<K,V> seg = segmentAt(segments, j);
                if (seg != null) {
                    if (seg.count != 0)
                        return false;
                    sum -= seg.modCount;
                }
            }
            if (sum != 0L)
                return false;
        }
        return true;
    }
    

    这个方法也比较简单,就是循环两次判断,如果任何一次 segment 中的值不为0,直接返回 false。在第一次循环中记录modCount和,第二次再依次减去,如果最后等于0代表没有更改过,就返回true。

    其他方法如 remove 和 replace 思路都差不多,就不写了。


版权声明:本文为qq_41512783原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。