apollo源码解读:无锁哈希表AtomicHashMap

直接上源代码吧,代码是阿波罗团队写的源代码,我这边给加了注释。想要看源代码请看:apollo/cyber/base at master · ApolloAuto/apollo · GitHubAn open autonomous driving platform. Contribute to ApolloAuto/apollo development by creating an account on GitHub.https://github.com/ApolloAuto/apollo/tree/master/cyber/base

/******************************************************************************
 * Copyright 2018 The Apollo Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *****************************************************************************/

#ifndef CYBER_BASE_ATOMIC_HASH_MAP_H_
#define CYBER_BASE_ATOMIC_HASH_MAP_H_

#include <atomic>
#include <cstdint>
#include <type_traits>
#include <utility>

//整体数据结构采用固定数组+单链表方式实现。一个Bucket代表一个链表,有TableSize个这样的链表。为什么要这样设计?
//因为整个数据按照key值来算出数据存放数据位置,此哈希表是不允许key值一样的,但是即使这样两个key也可能存放相同的位置
//就比如哈希表是128大小,那么但是如果存在两个key分别是1和129那么他们两个将会在同一个位置存储。用链表存储他们。

namespace apollo {
namespace cyber {
namespace base {
/**
 * @brief A implementation of lock-free fixed size hash map
 *
 * @tparam K Type of key, must be integral
 * @tparam V Type of value
 * @tparam 128 Size of hash table
 * @tparam 0 Type traits, use for checking types of key & value
 */
 //AtomicHashMap是一个模板类型,下面这个template就是他的规则,看到他有4个模板形参,
 //第一个为key的泛型,第二个为value泛型,第三个为map的容量大小,默认为128可以不传递
 //注意这个大小并不代表容器真正能够存储的key-value个数,只是会影响map的存取效率
 //第四个写法可以总结为整个模板的“约束”,专业点讲就是模板偏特化,模板实例化的时候是不
 //需要传这个参的,他的作用就是进行编译时期的匹配约束,符合条件的才能匹配这个类型,
 //下面这个约束的大概意思是:TableSize必须是整数类型,并且他的值必须是2的N次方,
 //否则编译时期匹配失败。
template <typename K, typename V, std::size_t TableSize = 128,
          typename std::enable_if<std::is_integral<K>::value &&
                                      (TableSize & (TableSize - 1)) == 0,
                                  int>::type = 0>
class AtomicHashMap {
 public:
  //capacity_初始化为TableSize很容易理解,mode_num_初始化为capacity_ - 1一开始还是比较难理解的
  //比如此值默认是127,二级制就是:1111111,此值为固定值,初始化之后不再更改,主要用于k存放位置的计算
  AtomicHashMap() : capacity_(TableSize), mode_num_(capacity_ - 1) {}
  //禁用拷贝构造和赋值构造函数,增强程序执行效率
  AtomicHashMap(const AtomicHashMap &other) = delete;
  AtomicHashMap &operator=(const AtomicHashMap &other) = delete;

  bool Has(K key) {
    //根据key值确定该值坐落的位置,和mode_num_进行与的操作是处理整数保证其在0 -(TableSize-1)之间
    uint64_t index = key & mode_num_; 
    return table_[index].Has(key);
  }

  bool Get(K key, V **value) {
    uint64_t index = key & mode_num_;
    return table_[index].Get(key, value);
  }

  bool Get(K key, V *value) {
    uint64_t index = key & mode_num_;
    V *val = nullptr;
    bool res = table_[index].Get(key, &val);
    if (res) {
      *value = *val;
    }
    return res;
  }

  void Set(K key) {
    uint64_t index = key & mode_num_;
    table_[index].Insert(key);
  }

  void Set(K key, const V &value) {
    uint64_t index = key & mode_num_;
    table_[index].Insert(key, value);
  }

  void Set(K key, V &&value) {
    //std::cout << "&& value " << value << std::endl;
    uint64_t index = key & mode_num_;
    table_[index].Insert(key, std::forward<V>(value));
  }

 private:
  //该类名为Entry,其实他表达的意思是一组key-value数据的节点,注意他是struct不是class,意为所有成员public
  struct Entry {
    Entry() {}
    explicit Entry(K key) : key(key) {
      value_ptr.store(new V(), std::memory_order_release);
    }
    Entry(K key, const V &value) : key(key) {
      value_ptr.store(new V(value), std::memory_order_release);
    }
    Entry(K key, V &&value) : key(key) {
      value_ptr.store(new V(std::forward<V>(value)), std::memory_order_release);
    }
    ~Entry() { delete value_ptr.load(std::memory_order_acquire); }

    K key = 0;
    std::atomic<V *> value_ptr = {nullptr};
    std::atomic<Entry *> next = {nullptr};
  };

  //由Entry多个节点组成的单链表,改链表里面的数据是有序排列的,按照key值由小变大
  class Bucket {
   public:
    Bucket() : head_(new Entry()) {}
    ~Bucket() {
      Entry *ite = head_;
      while (ite) {
        auto tmp = ite->next.load(std::memory_order_acquire);
        delete ite;
        ite = tmp;
      }
    }

    bool Has(K key) {
      Entry *m_target = head_->next.load(std::memory_order_acquire);
      while (Entry *target = m_target) {
        if (target->key < key) {
          m_target = target->next.load(std::memory_order_acquire);
          continue;
        } else {
          return target->key == key;
        }
      }
      return false;
    }
    
    //这个Find函数是一个很重要的接口下面的Insert和Get接口都依赖此接口,乍一看Find接口和上面的Has接口不是重复了吗?
    //其实不然,这个Find接口不仅会判断存不存在还会找这个key应该存放的位置(即使不存在),这个寻找位置算法其实就是key从小到大的排序
    //参数1:需要的key。参数2:所在位置的上个节点。参数3:寻找到的节点,有可能并不是key一样的,target的key值大于参数里面的key,这样链表插入
    //的时候下一个节点是target(有可能是空)。参数2和3其实是返回值类型
    //为什么这么麻烦事儿又是前节点,又是后节点,主要是单链表插入操作就是这样,要用变量记录,否则链表就断了
    bool Find(K key, Entry **prev_ptr, Entry **target_ptr) {
      Entry *prev = head_;
      Entry *m_target = head_->next.load(std::memory_order_acquire);
      while (Entry *target = m_target) { //注意这里不仅仅是一个赋值操作,也是一个判空操作
        if (target->key == key) { //这是找到了,确认存在该key值
          *prev_ptr = prev;
          *target_ptr = target;
          return true;
        } else if (target->key > key) { //这是没找到key,链表里面接下来的key都比寻找的key值大,没有找下去的意义
          *prev_ptr = prev;
          *target_ptr = target;
          return false; 
        } else { //还没找到,换下个节点
          prev = target;
          m_target = target->next.load(std::memory_order_acquire);
        }
      }
      *prev_ptr = prev;
      *target_ptr = nullptr;
      return false;
    }

    void Insert(K key, const V &value) {
      Entry *prev = nullptr;
      Entry *target = nullptr;
      Entry *new_entry = nullptr;
      V *new_value = nullptr;
      //这个循环是有点难理解的,存在的意义是什么?从代码上来讲只要compare_exchange_strong返回false就会再次进入循环
      //这个操作主要是保证多线程安全吧,下面两次compare_exchange_strong一般来讲都是返回true的,如果出现多线程问题可能返回false
      while (true) { 
        if (Find(key, &prev, &target)) { //注意:这里无论返回false还是true,prev和target都是被成功赋值了的
          // key exists, update value
          if (!new_value) {  
            new_value = new V(value); //存在节点,只需更新value值就好,所以new一个value值
          }
          //实事求是讲还是没读懂下面两行代码的意义,先读一下,再确认一下,有意义吗?不懂,
          //又仔细思考了一下,注意,比较的target->value_ptr的值,target是本接口的一个临时变量,不会出现多线程读写问题,但是
          //target->value_ptr就不一样了,他是链表里面的一个值,是会出现数据竞争问题的。所以采用compare_exchange_strong,
          //用来保证,在我改变数值的时候其他人没有同时改
          auto old_val_ptr = target->value_ptr.load(std::memory_order_acquire);
          if (target->value_ptr.compare_exchange_strong(
                  old_val_ptr, new_value, std::memory_order_acq_rel,
                  std::memory_order_relaxed)) {
            delete old_val_ptr;
            if (new_entry) {
              delete new_entry;
              new_entry = nullptr;
            }
            return;
          }
          continue;
        } else {
          if (!new_entry) {
            new_entry = new Entry(key, value); //存在节点,所以要新建一个节点
          }
          new_entry->next.store(target, std::memory_order_release);
          //这个compare_exchange_strong还是有意义的,他大概的作用就是多线程安全,改之前比较一下,是不是被人改过了,如果都被人改过了
          //证明发生了多线程写时间,此次Find不作数,重新Find一次。
          if (prev->next.compare_exchange_strong(target, new_entry,
                                                 std::memory_order_acq_rel,
                                                 std::memory_order_relaxed)) {
            if (new_value) {
              delete new_value;
              new_value = nullptr;
            }
            return;
          }
          // another entry has been inserted, retry
        }
      }
    }

    void Insert(K key, V &&value) {
      Entry *prev = nullptr;
      Entry *target = nullptr;
      Entry *new_entry = nullptr;
      V *new_value = nullptr;
      while (true) {
        if (Find(key, &prev, &target)) {
          // key exists, update value
          if (!new_value) {
            new_value = new V(std::forward<V>(value));
          }
          auto old_val_ptr = target->value_ptr.load(std::memory_order_acquire);
          if (target->value_ptr.compare_exchange_strong(
                  old_val_ptr, new_value, std::memory_order_acq_rel,
                  std::memory_order_relaxed)) {
            delete old_val_ptr;
            if (new_entry) {
              delete new_entry;
              new_entry = nullptr;
            }
            return;
          }
          continue;
        } else {
          if (!new_entry) {
            new_entry = new Entry(key, value);
          }
          new_entry->next.store(target, std::memory_order_release);
          if (prev->next.compare_exchange_strong(target, new_entry,
                                                 std::memory_order_acq_rel,
                                                 std::memory_order_relaxed)) {
            // Insert success
            if (new_value) {
              delete new_value;
              new_value = nullptr;
            }
            return;
          }
          // another entry has been inserted, retry
        }
      }
    }

    void Insert(K key) {
      Entry *prev = nullptr;
      Entry *target = nullptr;
      Entry *new_entry = nullptr;
      V *new_value = nullptr;
      while (true) {
        if (Find(key, &prev, &target)) {
          // key exists, update value
          if (!new_value) {
            new_value = new V();
          }
          auto old_val_ptr = target->value_ptr.load(std::memory_order_acquire);
          if (target->value_ptr.compare_exchange_strong(
                  old_val_ptr, new_value, std::memory_order_acq_rel,
                  std::memory_order_relaxed)) {
            delete old_val_ptr;
            if (new_entry) {
              delete new_entry;
              new_entry = nullptr;
            }
            return;
          }
          continue;
        } else {
          if (!new_entry) {
            new_entry = new Entry(key);
          }
          new_entry->next.store(target, std::memory_order_release);
          if (prev->next.compare_exchange_strong(target, new_entry,
                                                 std::memory_order_acq_rel,
                                                 std::memory_order_relaxed)) {
            // Insert success
            if (new_value) {
              delete new_value;
              new_value = nullptr;
            }
            return;
          }
          // another entry has been inserted, retry
        }
      }
    }

    bool Get(K key, V **value) {
      Entry *prev = nullptr;
      Entry *target = nullptr;
      if (Find(key, &prev, &target)) {
        *value = target->value_ptr.load(std::memory_order_acquire);
        return true;
      }
      return false;
    }

    Entry *head_;
  };

 private:
  Bucket table_[TableSize];
  uint64_t capacity_;
  uint64_t mode_num_;
};

}  // namespace base
}  // namespace cyber
}  // namespace apollo

#endif  // CYBER_BASE_ATOMIC_HASH_MAP_H_

简单写了个main.cpp对一些功能进行测试。

#include <iostream>
#include <thread>
#include "atomic_hash_map.h"

using namespace std;
using namespace apollo::cyber::base;

AtomicHashMap<int, int64_t> map;

void fun() {
    int64_t value;
    map.Get(100, &value);
    map.Set(100, value + 1);
}

int main() {
   // 模板声明测试
   AtomicHashMap<int, string> map1;
   AtomicHashMap<char, string> map2;
   // AtomicHashMap<string, string> map3; //编译不过,应为key不是整形
   AtomicHashMap<char, string, 1024> map4;
   // AtomicHashMap<char, string, 84> map5; //编译不过,第三个参数不是2的N次方

   // 模板第三个参数测试
   AtomicHashMap<int, string, 2> map6;
   map6.Set(1, "11111");
   map6.Set(2, "22222");
   map6.Set(3, "33333");
   map6.Set(4, "44444");
   std::string v1;
   map6.Get(1, &v1);
   std::cout << v1 << std::endl;
   std::string v4;
   map6.Get(4, &v4);
   std::cout << v4 << std::endl;

   // 模板右值引用测试,std::move用途为将值改为右值,一般用于常规代码使用,std::forward用途为探索真实值为左值还是右值,一般用于模板类的编写
   AtomicHashMap<int, string> map7;
   map7.Set(9, std::string("aaaaaaa"));
   map7.Set(10, std::move(std::string("bbbbbbb")));
   std::string t1("cccccccc");
   map7.Set(11, t1);
   std::string t2("dddddddd");
   map7.Set(11, std::move(t2));

   // 多线程测试
   // 这是个反面教材实例,无锁map是指他的Get,Set操作都是原子操作,而fun函数不是原子操作,显然得不到想要的结果,可以多线程同时Set或者Get,但是同一个线程
   // 有个多个Get,Set操作,加上多线程就不行
   map.Set(100, 0);
   std::thread t[2];
   for (size_t i = 0; i < 2; i++) {
       t[i] = std::thread([&] {
           for (size_t j = 0; j < 100; j++) {
               fun();
           }
       });
   }
   t[0].join();
   t[1].join();
   int64_t value;
   map.Get(100, &value);
   std::cout << "value : " <<value << std::endl;


    return 0;
}


版权声明:本文为liujiayu2原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。