JDK7的ConcurrentHashMap源码分析总结
@author:Jingdai
@date:2021.07.15
花了一天多时间把JDK7的ConcurrentHashMap源码研究了一下,现总结如下。由于水平有限,不免有错误,欢迎大佬指正。
整体思路
ConcurrentHashMap名字太长,后文用CHM代替。不同于HashTable用一把锁加锁所有的桶,CHM使用分段锁,每次加锁只锁整个Map的一部分,这样就大大提高了并发量,当不同的线程想同时修改CHM的不同部分时,不会阻塞。CHM中用Segment表示每个加锁单元,可以把这个Segment当成HashTable,即可以简单的理解为CHM等于多个HashTable,当然CHM做了很多优化。
在初始化CHM的时候,会指定Segment的数量,这个根据传入的 concurrencyLevel 计算得出(后面会讲),在算出这个Segment的数量后,Segment的数量就不会再发生变化了,随着元素的增多,需要rehash的时候也仅仅会在每个Segment中扩容,Segment并不会增多。
在找每个key对应的桶的时候,计算出hash后,利用hash的高位来计算这个key在哪一个Segment,再利用它的低位来计算它在Segment的哪一个桶中。
源码中利用了很多UNSAFE类的操作,这个我也不是很清楚,但是大概意思就是直接从主内存中去取最新的值,而不会从线程的工作线程中取,想要了解细节的可以看看别的博客。(狗头)CHM在读的时候不会加锁,直接利用UNSAFE去主内存中取值,而CHM在写的时候就是加锁去写。
主要类结构
CHM 内部有两个重要的类,Segment 和 HashEntry,下面分别来看一下。
ConcurrentHashMap 类
public class ConcurrentHashMap<K, V> extends AbstractMap<K, V>
implements ConcurrentMap<K, V>, Serializable {
// 默认初始容量是16, 和 HashMap 一样
static final int DEFAULT_INITIAL_CAPACITY = 16;
// 默认加载因子是 0.75, 和 HashMap 一样
static final float DEFAULT_LOAD_FACTOR = 0.75f;
// 默认并发级别时16,这个并发级别决定了Segment数组的长度
static final int DEFAULT_CONCURRENCY_LEVEL = 16;
// 每个Segment的table的最小容量
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
// 用于确定哪一个Segment的掩码值,hash的高几位用于选择Segment
final int segmentMask;
// 确定哪一个Segment的时候 hash偏移的位数
final int segmentShift;
// 最重要的属性,每个Segment可以看出一个Hash表,键值对是存在Segment中的
final Segment<K,V>[] segments;
// XXX
}
上面仅仅列出了看源码比较重要的属性,其他属性没有列出。
Segment 类
static final class Segment<K,V> extends ReentrantLock implements Serializable {
// 真正存储数据的数组,每个Segment独自有一个 table 存储键值对
transient volatile HashEntry<K,V>[] table;
// table 中元素的个数
transient int count;
// table 修改的次数
transient int modCount;
// 扩容阈值,它等于 capacity * loadFactor
// capacity 就是数组 table 的长度
transient int threshold;
// 负载因子,每个Segment都相同
final float loadFactor;
// xxx
}
从上面可以看出,Segment内部自己维护一个Hash表,它有自己的扩容阈值和负载因子,负载因子每个Segment 都相同,之后扩容的时候是每个Segment自己扩容,不会影响到 CHM 其他的 Segment。同时,可以看出Segment 是继承自 ReentrantLock接口的,所以相当于每个Segment自己有一把锁,想要对Segment进行修改的时候需要先得到这个锁。
HashEntry 类
static final class HashEntry<K,V> {
final int hash;
final K key;
volatile V value;
volatile HashEntry<K,V> next;
// xxx
}
这个类就没什么要说的了,就是真正的 Entry对象,和HashMap没什么区别。
细节
segmentShift 和 segmentMask 作用?
看构造函数。
@SuppressWarnings("unchecked") public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) { // xxx if (concurrencyLevel > MAX_SEGMENTS) concurrencyLevel = MAX_SEGMENTS; // Find power-of-two sizes best matching arguments int sshift = 0; int ssize = 1; while (ssize < concurrencyLevel) { ++sshift; ssize <<= 1; } this.segmentShift = 32 - sshift; this.segmentMask = ssize - 1; // xxx Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize]; // xxx }为了看得更清楚,删除其他代码,这个 concurrencyLevel 是用来计算 segment 数组的长度ssize,因为求下标用的是 & 操作,所以数组长度必须是2的次幂,所以不能直接用 concurrencyLevel 来当数组长度,得到ssize 是一个大于 concurrencyLevel 的最小的2的次幂。这个 segmentMask 就简单一点了,就是ssize - 1,这样方便后面求元素在哪个segment中,即 x % 2n = x & (2n-1) = x & segmentMask,这点和HashMap思路一致。
这个 segmentShift 就麻烦一点了,它等于 32 - sshift。假设现在算出来 ssize 是16,那么sshift 就是 4,那么segmentShift 就是28,segmentMask 就是就是 15,我们知道 int一共有32位,假设如下hash。
hash:10101101 10101101 10101101 10101101我们要取高 sshift (4)位的值的话就是正好将 hash 右移segmentShift (28) 位再和 segmentMask 按位&。即
高 sshift 位 % ssize = 高 sshift 位 & segmentMask = (hash >>> segmentShift) & segmentMask,后面会知道这个就是用来求一个 key 是属于哪一个segment的,即用 hash 的高 sshift 位来求 key 在哪一个segment中。
map 加入元素时放哪一个 segment?又在 segment 的 table 的哪一个index?
@SuppressWarnings("unchecked") public V put(K key, V value) { Segment<K,V> s; if (value == null) throw new NullPointerException(); int hash = hash(key); int j = (hash >>> segmentShift) & segmentMask; if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck (segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment s = ensureSegment(j); return s.put(key, hash, value, false); }由前面讲的
int j = (hash >>> segmentShift) & segmentMask;就是取高 sshift 位 % ssize,即由高sshift 位取余来确定在哪一个segment中。其中UNSAFE.getObject(segments, (j << SSHIFT) + SBASE)是去取segments的第 j 个元素。上面又会去调用 segment 的put 方法,如下。
final V put(K key, int hash, V value, boolean onlyIfAbsent) { HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value); V oldValue; try { HashEntry<K,V>[] tab = table; int index = (tab.length - 1) & hash; // xxx } finally { unlock(); } return oldValue; }删去了其它代码,可以看出 index 是通过 hash 对 table的length取余得到,这里并没有移位,所以是用hash的低位得到 key 在 table中的下标。
即通过hash的高 sshift 位取余来确定在哪一个segment中,低位取余得到在segment 的 table的哪个下标中。
初始Segment中的table的大小?
看构造函数,去掉其它多余的代码。
@SuppressWarnings("unchecked") public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) { // xxx if (initialCapacity > MAXIMUM_CAPACITY) initialCapacity = MAXIMUM_CAPACITY; int c = initialCapacity / ssize; // 向上取整 if (c * ssize < initialCapacity) ++c; int cap = MIN_SEGMENT_TABLE_CAPACITY; while (cap < c) cap <<= 1; // create segments and segments[0] Segment<K,V> s0 = new Segment<K,V>(loadFactor, (int)(cap * loadFactor), (HashEntry<K,V>[])new HashEntry[cap]); // xxx }前面已经分析过了,ssize就是segments数组的长度,就是有几个segment。而cap是初始的Segment中table的大小,现在就是求cap怎么得到。首先用 initialCapacity / ssize 得到每个segment中分到的元素个数 c,但是Java中整数除法都是截断后面的小数,所以如果截断了需要向上取整。然后找出大于 c 的最小的2的次幂就得到cap,即 table的大小。
构造函数中初始化第0个segment的原因?
把这个segment作为prototype,以后再创建segment时按这个segment的属性进行创建。代码如下。
构造函数
@SuppressWarnings("unchecked") public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) { // xxx // create segments and segments[0] // 仅仅初始化了segments[0] Segment<K,V> s0 = new Segment<K,V>(loadFactor, (int)(cap * loadFactor), (HashEntry<K,V>[])new HashEntry[cap]); Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize]; UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0] this.segments = ss; }ensureSegment函数
@SuppressWarnings("unchecked") private Segment<K,V> ensureSegment(int k) { final Segment<K,V>[] ss = this.segments; long u = (k << SSHIFT) + SBASE; // raw offset Segment<K,V> seg; if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { // 将 segment 0 作为模板来创建其他的 segment Segment<K,V> proto = ss[0]; // use segment 0 as prototype int cap = proto.table.length; float lf = proto.loadFactor; int threshold = (int)(cap * lf); HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap]; if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { // recheck Segment<K,V> s = new Segment<K,V>(lf, threshold, tab); while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s)) break; } } } return seg; }如何线程安全的创建segment?
源码中使用CAS来保证安全的创建segment。使用的是 ensureSegment 函数。
ensureSegment 函数
@SuppressWarnings("unchecked") private Segment<K,V> ensureSegment(int k) { final Segment<K,V>[] ss = this.segments; long u = (k << SSHIFT) + SBASE; // raw offset Segment<K,V> seg; if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { Segment<K,V> proto = ss[0]; // use segment 0 as prototype int cap = proto.table.length; float lf = proto.loadFactor; int threshold = (int)(cap * lf); HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap]; if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { // recheck Segment<K,V> s = new Segment<K,V>(lf, threshold, tab); // 不断重试 while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { // 利用CAS创建 if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s)) break; } } } return seg; }如上面代码,首先判断 k 位置的 segment 是否还是null,如果还是 null就根据第0个segment创建一个新的HashEntry数组,否则直接返回(说明其他线程已经创建了segment)。接着在此判断是否为 null,如果还是的话才真的创建一个segment,并利用自旋加CAS重试的方式设置这个 segment。
如何线程安全的加入元素?
源码通过加锁的方式线程安全的put元素。前面讲了通过CAS安全的创建segment,而 map 的 put 操作在得到对应的 segment 后,会调用 segment 的 put 方法去加入元素。加入时使用 ReentrantLock 加锁的方式保证线程安全。前面已经说了 Segment 继承了 ReentrantLock ,所以可以直接加锁。代码如下。
final V put(K key, int hash, V value, boolean onlyIfAbsent) { // 尝试加锁,加锁成功直接 node 为 null // 否则调用 scanAndLockForPut 方法(后面会讲) // 这里知道要不加锁成功,要不调用 scanAndLockForPut // 方法使得加锁成功,所以执行完第一句一定会加锁成功 HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value); V oldValue; try { HashEntry<K,V>[] tab = table; // 求下标 int index = (tab.length - 1) & hash; HashEntry<K,V> first = entryAt(tab, index); for (HashEntry<K,V> e = first;;) { if (e != null) { K k; // 判断是否已经有了这个 key if ((k = e.key) == key || (e.hash == hash && key.equals(k))) { oldValue = e.value; if (!onlyIfAbsent) { e.value = value; ++modCount; } break; } e = e.next; } else { // 没有这个key,进行添加操作(头插法) if (node != null) node.setNext(first); else node = new HashEntry<K,V>(hash, key, value, first); int c = count + 1; if (c > threshold && tab.length < MAXIMUM_CAPACITY) rehash(node); else setEntryAt(tab, index, node); ++modCount; count = c; oldValue = null; break; } } } finally { unlock(); } return oldValue; }除了最开始的加锁操作,后面的基本和 JDK7 的 HashMap的put操作思路一致,都是头插法。下面看这个加锁的函数。
Segment 的 scanAndLockForPut 方法
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) { HashEntry<K,V> first = entryForHash(this, hash); HashEntry<K,V> e = first; HashEntry<K,V> node = null; int retries = -1; // negative while locating node while (!tryLock()) { HashEntry<K,V> f; // to recheck first below // 首先retries = -1,必须要等遍历一遍这个链表,发现没有 // 这个key,则创建这个node,或者遍历过程中发现有这个key,不需 // 要创建这个node时,才将retries改为0,进到其他分支 if (retries < 0) { if (e == null) { if (node == null) // speculatively create node node = new HashEntry<K,V>(hash, key, value, null); retries = 0; } else if (key.equals(e.key)) retries = 0; else e = e.next; } // 限制重试次数,当重试次数超了时阻塞获得锁 else if (++retries > MAX_SCAN_RETRIES) { lock(); break; } // 偶数时判断头结点是否变化,变化时重新赋值 else if ((retries & 1) == 0 && (f = entryForHash(this, hash)) != first) { e = first = f; // re-traverse if entry changed retries = -1; } } return node; }其中 MAX_SCAN_RETRIES 值为。
static final int MAX_SCAN_RETRIES = Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;根据注释,代码的目的就是为了获得锁,首先在尝试获得锁时,先判断这个链上有没有这个key,判断完之后再进行重试次数的累加,累加次数到了指定值后就会阻塞获得锁,同时在过程中如果头节点发生了变化会重新进行赋值,然后重试。在判断这个链上有没有这个key的同时,如果发现这个链上没有这个 key,会直接 new 出这个 node 供调用的 put 方法使用,相当于在尝试获取锁的时候做了一些其它事情,没有白白浪费CPU。
如何安全的扩容?
首先注意扩容不是 CHM 的 Segment[] 数组的扩容,Segment[] 数组的大小在初始化之后就不会再发生变化了,扩容是 Segment 内部的 HashEntry<K,V>[] table 的扩容。扩容的阈值和加载因子都和HashMap一致。由于 rehash 方法仅仅在 put 方法内部被调用,所以 rehash 执行的时候一定拥有 Segment 的锁。具体代码如下。
Segment的 rehash 方法
@SuppressWarnings("unchecked") private void rehash(HashEntry<K,V> node) { HashEntry<K,V>[] oldTable = table; int oldCapacity = oldTable.length; int newCapacity = oldCapacity << 1; threshold = (int)(newCapacity * loadFactor); HashEntry<K,V>[] newTable = (HashEntry<K,V>[]) new HashEntry[newCapacity]; int sizeMask = newCapacity - 1; for (int i = 0; i < oldCapacity ; i++) { // e是链表头 HashEntry<K,V> e = oldTable[i]; if (e != null) { HashEntry<K,V> next = e.next; int idx = e.hash & sizeMask; if (next == null) // Single node on list newTable[idx] = e; else { // Reuse consecutive sequence at same slot HashEntry<K,V> lastRun = e; int lastIdx = idx; // 找到最后一串移动后下标相同的起始节点 // 这些节点不用复制,直接赋值 for (HashEntry<K,V> last = next; last != null; last = last.next) { int k = last.hash & sizeMask; if (k != lastIdx) { lastIdx = k; lastRun = last; } } // 将最后那串相同下标节点移到新的桶中 newTable[lastIdx] = lastRun; // Clone remaining nodes // 浅克隆最后那串节点之前的每个节点,并头插到新的桶中 for (HashEntry<K,V> p = e; p != lastRun; p = p.next) { V v = p.value; int h = p.hash; int k = h & sizeMask; HashEntry<K,V> n = newTable[k]; newTable[k] = new HashEntry<K,V>(h, p.key, v, n); } } } } // 插入要头插的传入的node节点 int nodeIndex = node.hash & sizeMask; // add the new node node.setNext(newTable[nodeIndex]); newTable[nodeIndex] = node; // 最后将 newTable 赋值给table table = newTable; }看注释,整个过程其实还算比较简单,和 JDK7 的HashMap非常类似。但是要注意,由于HashMap不用考虑线程安全问题,所以HashMap直接用的是原来的节点,不用复制,而 CHM 要考虑到可能还有线程在读(读不用获得锁),不能直接用原来的节点,所以代码中浅克隆了原来的节点并利用头插法插入新的table中。同时,为了减少克隆的节点数,作者做了一点点优化,找出原始链表最后一串会 rehash 到相同坐标的起始节点,直接将这一串复制过去,减少了克隆节点的个数。
如何安全的获取元素?
CHM 的 get 方法
public V get(Object key) { Segment<K,V> s; // manually integrate access methods to reduce overhead HashEntry<K,V>[] tab; int h = hash(key); // 获取 segment 在segments数组中的下标 long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE; if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null && (tab = s.table) != null) { // 获取 key 在 HashEntry数组 table 的下标,获取其头节点,然后遍历整个链表 for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE); e != null; e = e.next) { K k; if ((k = e.key) == key || (e.hash == h && key.equals(k))) return e.value; } } return null; }注意这里直接在 CHM 类中获取了 value,不同于 put 还要去 Segment 中的put方法中去put,而 Segment 类中没有 get 方法。这里get方法思路很简单,就是找到对应的segment,然后去segment的数组中的对应下标的链表中查找,只是这里利用的是UNSAFE类直接去获取主内存中的值,而不是从线程的工作内存中获取。
如何获取CHM 的 size?
CHM 的 size 方法
public int size() { // Try a few times to get accurate count. On failure due to // continuous async changes in table, resort to locking. final Segment<K,V>[] segments = this.segments; int size; boolean overflow; // true if size overflows 32 bits long sum; // sum of modCounts // 用于记录上一次 modCount 的和 long last = 0L; // previous sum int retries = -1; // first iteration isn't retry try { for (;;) { if (retries++ == RETRIES_BEFORE_LOCK) { for (int j = 0; j < segments.length; ++j) ensureSegment(j).lock(); // force creation } sum = 0L; size = 0; overflow = false; for (int j = 0; j < segments.length; ++j) { Segment<K,V> seg = segmentAt(segments, j); if (seg != null) { sum += seg.modCount; int c = seg.count; if (c < 0 || (size += c) < 0) overflow = true; } } // 这次和上次一样就退出 if (sum == last) break; last = sum; } } finally { if (retries > RETRIES_BEFORE_LOCK) { for (int j = 0; j < segments.length; ++j) segmentAt(segments, j).unlock(); } } return overflow ? Integer.MAX_VALUE : size; }其中 RETRIES_BEFORE_LOCK 定义:
static final int RETRIES_BEFORE_LOCK = 2;整个代码的思路就是记录本次的modCount和,如果和上次的一样就代表得到了准确的size,如果重试两次不成功就直接给所有Segment加锁,然后计算和。同时,加锁的时候如果Segment还没有创建,就会在这里强制创建这个Segment。但是这里应该不是完全的准确的size,如果 last 和 sum 两次一样,得到了size,跳出循环,这时失去了时间片,在结果返回前别的线程又增加了元素,这样size就不对了。
CHM 判断是否元素个数为0?
CHM 的 isEmpty 方法
public boolean isEmpty() { long sum = 0L; final Segment<K,V>[] segments = this.segments; for (int j = 0; j < segments.length; ++j) { Segment<K,V> seg = segmentAt(segments, j); if (seg != null) { if (seg.count != 0) return false; sum += seg.modCount; } } if (sum != 0L) { // recheck unless no modifications for (int j = 0; j < segments.length; ++j) { Segment<K,V> seg = segmentAt(segments, j); if (seg != null) { if (seg.count != 0) return false; sum -= seg.modCount; } } if (sum != 0L) return false; } return true; }这个方法也比较简单,就是循环两次判断,如果任何一次 segment 中的值不为0,直接返回 false。在第一次循环中记录modCount和,第二次再依次减去,如果最后等于0代表没有更改过,就返回true。
其他方法如 remove 和 replace 思路都差不多,就不写了。