直接上源代码吧。代码是 Apollo(阿波罗)团队编写的,我在其中加了注释。想看原始源代码请看:apollo/cyber/base at master · ApolloAuto/apollo(GitHub:An open autonomous driving platform. Contribute to ApolloAuto/apollo development by creating an account on GitHub.)https://github.com/ApolloAuto/apollo/tree/master/cyber/base
/******************************************************************************
* Copyright 2018 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/
#ifndef CYBER_BASE_ATOMIC_HASH_MAP_H_
#define CYBER_BASE_ATOMIC_HASH_MAP_H_
#include <atomic>
#include <cstdint>
#include <type_traits>
#include <utility>
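// Overall layout: a fixed-size array of singly linked lists. One Bucket is one list, and there are TableSize of them.
// Why this design? The storage slot is computed from the key. The map never holds two entries with the same key,
// but two different keys can still land in the same slot: with a table size of 128, keys 1 and 129 map to the
// same slot. Entries that collide like this are chained together in that slot's linked list.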
namespace apollo {
namespace cyber {
namespace base {
/**
 * @brief A fixed-size hash map for integral keys that synchronizes with
 *        atomic compare-and-swap operations instead of locks.
 *
 * Layout: a fixed array of TableSize buckets. Each bucket is a singly linked
 * list of entries kept in ascending key order behind a sentinel head node.
 * A key is mapped to its bucket with `key & (TableSize - 1)`, so TableSize
 * only influences how long the chains get, not how many keys can be stored.
 * Entries are never unlinked; they are released when the map is destroyed.
 *
 * Concurrency caveat: replacing the value of an existing key deletes the
 * previous value object immediately. A pointer obtained through
 * Get(K, V**), and the copy performed inside Get(K, V*), is therefore only
 * safe while no other thread calls Set() for the same key.
 *
 * @tparam K          Key type; must be integral.
 * @tparam V          Value type. Set(K) needs it default-constructible,
 *                    Set(K, const V&) copy-constructible, Set(K, V&&)
 *                    move-constructible and Get(K, V*) copy-assignable.
 * @tparam TableSize  Number of buckets; must be a non-zero power of two
 *                    (default 128).
 *
 * The trailing unnamed template parameter is an enable_if (SFINAE)
 * constraint and is never supplied by users: instantiation fails at compile
 * time unless K is integral and TableSize is a power of two.
 */
template <typename K, typename V, std::size_t TableSize = 128,
          typename std::enable_if<std::is_integral<K>::value &&
                                      (TableSize & (TableSize - 1)) == 0,
                                  int>::type = 0>
class AtomicHashMap {
  // (0 & (0 - 1)) == 0 also satisfies the power-of-two test above, but a
  // table with zero buckets would make every bucket index out of range.
  static_assert(TableSize != 0, "AtomicHashMap: TableSize must be non-zero");

 public:
  // mode_num_ is TableSize - 1, i.e. a mask of low bits (127 == 0b1111111 for
  // the default size). It is fixed after construction and is only used to
  // turn a key into a bucket index.
  AtomicHashMap() : capacity_(TableSize), mode_num_(capacity_ - 1) {}

  // The buckets own raw heap nodes, so copying the map is not supported.
  AtomicHashMap(const AtomicHashMap &other) = delete;
  AtomicHashMap &operator=(const AtomicHashMap &other) = delete;

  /// @return true if an entry for `key` exists.
  bool Has(K key) { return table_[IndexOf(key)].Has(key); }

  /**
   * @brief Looks up `key` and exposes a pointer to the stored value.
   * @param key    Key to look up.
   * @param value  Out: set to the stored value object when found; left
   *               untouched otherwise. The pointee is owned by the map.
   * @return true if the key exists.
   */
  bool Get(K key, V **value) { return table_[IndexOf(key)].Get(key, value); }

  /**
   * @brief Looks up `key` and copies the stored value into `*value`.
   * @param key    Key to look up.
   * @param value  Out: receives a copy of the stored value when found; left
   *               untouched otherwise.
   * @return true if the key exists.
   */
  bool Get(K key, V *value) {
    V *stored = nullptr;
    const bool found = table_[IndexOf(key)].Get(key, &stored);
    if (found) {
      *value = *stored;
    }
    return found;
  }

  /// Inserts `key` with a value-initialized V, or resets the existing value.
  void Set(K key) { table_[IndexOf(key)].Insert(key); }

  /// Inserts `key` with a copy of `value`, or replaces the existing value.
  void Set(K key, const V &value) { table_[IndexOf(key)].Insert(key, value); }

  /// Inserts `key` by moving from `value`, or replaces the existing value.
  void Set(K key, V &&value) {
    // V is not a deduced parameter here, so `value` is a plain rvalue
    // reference and std::move is the appropriate cast.
    table_[IndexOf(key)].Insert(key, std::move(value));
  }

 private:
  // Maps a key to its bucket. Masking with mode_num_ keeps the result in
  // [0, TableSize - 1]; negative keys are converted to uint64_t first.
  uint64_t IndexOf(K key) const {
    return static_cast<uint64_t>(key) & mode_num_;
  }

  // One key/value node of a bucket's linked list. The node owns the object
  // that value_ptr points to and deletes it in its destructor.
  struct Entry {
    // Used for the sentinel head node: key 0, no value.
    Entry() {}
    // Takes ownership of `owned_value`, which must come from `new V(...)`.
    Entry(K entry_key, V *owned_value) : key(entry_key) {
      value_ptr.store(owned_value, std::memory_order_release);
    }
    ~Entry() { delete value_ptr.load(std::memory_order_acquire); }

    K key = 0;
    std::atomic<V *> value_ptr = {nullptr};
    std::atomic<Entry *> next = {nullptr};
  };

  // A singly linked list of Entry nodes sorted by ascending key, preceded by
  // a sentinel head node that carries no value.
  class Bucket {
   public:
    Bucket() : head_(new Entry()) {}
    ~Bucket() {
      Entry *node = head_;
      while (node) {
        Entry *following = node->next.load(std::memory_order_acquire);
        delete node;
        node = following;
      }
    }
    // The bucket owns head_ and every node behind it; a copy would delete
    // the same nodes twice.
    Bucket(const Bucket &) = delete;
    Bucket &operator=(const Bucket &) = delete;

    /// @return true if the list contains `key`.
    bool Has(K key) {
      Entry *node = head_->next.load(std::memory_order_acquire);
      // The list is sorted, so skip every node with a smaller key and decide
      // on the first node whose key is not smaller.
      while (node && node->key < key) {
        node = node->next.load(std::memory_order_acquire);
      }
      return node != nullptr && node->key == key;
    }

    /**
     * @brief Locates `key`, or the position where it would have to be linked.
     *
     * Unlike Has(), this also reports the neighbours needed for a
     * singly-linked insertion.
     *
     * @param key         Key to search for.
     * @param prev_ptr    Out: the node directly before the reported position
     *                    (the sentinel head if the position is the front).
     * @param target_ptr  Out: the node holding `key` when found; otherwise
     *                    the first node with a larger key, or nullptr when
     *                    every key in the list is smaller.
     * @return true if `key` is present. Both out-parameters are written in
     *         every case.
     */
    bool Find(K key, Entry **prev_ptr, Entry **target_ptr) {
      Entry *prev = head_;
      Entry *node = head_->next.load(std::memory_order_acquire);
      while (node) {
        if (node->key == key) {
          *prev_ptr = prev;
          *target_ptr = node;
          return true;
        }
        if (node->key > key) {
          // All remaining keys are larger: `key` is absent and belongs
          // between prev and node.
          *prev_ptr = prev;
          *target_ptr = node;
          return false;
        }
        prev = node;
        node = node->next.load(std::memory_order_acquire);
      }
      *prev_ptr = prev;
      *target_ptr = nullptr;
      return false;
    }

    /// Inserts `key` with a copy of `value`, or replaces the stored value.
    void Insert(K key, const V &value) { Upsert(key, new V(value)); }

    /// Inserts `key` by moving from `value`, or replaces the stored value.
    void Insert(K key, V &&value) { Upsert(key, new V(std::move(value))); }

    /// Inserts `key` with a value-initialized V, or resets the stored value.
    void Insert(K key) { Upsert(key, new V()); }

    /**
     * @brief Exposes a pointer to the value stored for `key`.
     * @param key    Key to look up.
     * @param value  Out: set to the stored value object when found.
     * @return true if the key exists.
     */
    bool Get(K key, V **value) {
      Entry *prev = nullptr;
      Entry *target = nullptr;
      if (!Find(key, &prev, &target)) {
        return false;
      }
      *value = target->value_ptr.load(std::memory_order_acquire);
      return true;
    }

   private:
    /**
     * @brief Installs `fresh_value` for `key`, taking ownership of it.
     *
     * The value is allocated once by the caller and the same pointer is
     * offered to whichever compare-and-swap applies, so a retry never has to
     * rebuild it from the caller's (possibly moved-from) argument.
     *
     * Each iteration re-runs Find() because either CAS can fail when another
     * thread changed the list or the value in the meantime.
     */
    void Upsert(K key, V *fresh_value) {
      Entry *prev = nullptr;
      Entry *target = nullptr;
      Entry *new_entry = nullptr;
      while (true) {
        if (Find(key, &prev, &target)) {
          // Key exists: swap the value pointer in place. The CAS fails if
          // another thread replaced the value after the load below.
          V *old_value = target->value_ptr.load(std::memory_order_acquire);
          if (target->value_ptr.compare_exchange_strong(
                  old_value, fresh_value, std::memory_order_acq_rel,
                  std::memory_order_relaxed)) {
            delete old_value;
            if (new_entry) {
              // A node was prepared in an earlier iteration but never
              // published. The value now belongs to `target`, so detach it
              // before destroying the spare node.
              new_entry->value_ptr.store(nullptr, std::memory_order_relaxed);
              delete new_entry;
            }
            return;
          }
          continue;
        }

        // Key absent: link a new node between prev and target.
        if (!new_entry) {
          new_entry = new Entry(key, fresh_value);
        }
        new_entry->next.store(target, std::memory_order_release);
        if (prev->next.compare_exchange_strong(target, new_entry,
                                               std::memory_order_acq_rel,
                                               std::memory_order_relaxed)) {
          return;
        }
        // Another node was linked after prev in the meantime; retry.
      }
    }

    Entry *head_;
  };

 private:
  Bucket table_[TableSize];
  uint64_t capacity_;
  uint64_t mode_num_;
};
} // namespace base
} // namespace cyber
} // namespace apollo
#endif // CYBER_BASE_ATOMIC_HASH_MAP_H_
简单写了个main.cpp对一些功能进行测试。
#include <iostream>
#include <thread>
#include "atomic_hash_map.h"
using namespace std;
using namespace apollo::cyber::base;
// Global map (int key, int64_t value) shared by fun() and the multi-threaded part of main().
AtomicHashMap<int, int64_t> map;
// Increments the counter stored under key 100 of the global map.
//
// This is a read-modify-write made of two separate map operations (Get, then
// Set), so it is NOT atomic as a whole even though each call is.
void fun() {
  // Start from 0 so that a failed lookup does not read an indeterminate value.
  int64_t value = 0;
  map.Get(100, &value);
  map.Set(100, value + 1);
}
// Small demo of AtomicHashMap: template arguments, bucket collisions,
// lvalue/rvalue Set overloads, and a multi-threaded counter-example.
int main() {
  // Template declaration checks.
  AtomicHashMap<int, string> map1;
  AtomicHashMap<char, string> map2;
  // AtomicHashMap<string, string> map3;  // does not compile: key is not an integral type
  AtomicHashMap<char, string, 1024> map4;
  // AtomicHashMap<char, string, 84> map5;  // does not compile: third argument is not a power of two

  // Third template argument: with only 2 buckets, keys 1/3 and 2/4 collide
  // and must still be stored and retrieved independently.
  AtomicHashMap<int, string, 2> map6;
  map6.Set(1, "11111");
  map6.Set(2, "22222");
  map6.Set(3, "33333");
  map6.Set(4, "44444");
  std::string v1;
  map6.Get(1, &v1);
  std::cout << v1 << '\n';
  std::string v4;
  map6.Get(4, &v4);
  std::cout << v4 << '\n';

  // Rvalue-reference overloads. std::move casts a value to an rvalue and is
  // what ordinary code uses; std::forward preserves the original value
  // category and is meant for forwarding references in templates.
  AtomicHashMap<int, string> map7;
  map7.Set(9, std::string("aaaaaaa"));
  // The std::move here is redundant (the temporary is already an rvalue); it
  // is kept only to show that both spellings select the V&& overload.
  map7.Set(10, std::move(std::string("bbbbbbb")));
  std::string t1("cccccccc");
  map7.Set(11, t1);  // lvalue: const V& overload
  std::string t2("dddddddd");
  map7.Set(11, std::move(t2));  // rvalue: V&& overload, replaces key 11

  // Multi-threaded test. This is a deliberate counter-example: each Get and
  // each Set is individually atomic, but fun() performs a Get followed by a
  // Set, which is not atomic as a whole. Two threads running 100 iterations
  // each can therefore lose increments, so the printed value may be below 200.
  // Concurrent Get or concurrent Set calls are fine; a multi-step sequence
  // per thread is not.
  // NOTE(review): the map deletes a replaced value immediately, so the Get in
  // one thread may also race with the delete triggered by the other thread's
  // Set. Treat this demo as illustrative only.
  map.Set(100, 0);
  std::thread t[2];
  for (auto &worker : t) {
    worker = std::thread([] {
      for (size_t j = 0; j < 100; j++) {
        fun();
      }
    });
  }
  for (auto &worker : t) {
    worker.join();
  }
  int64_t value = 0;
  map.Get(100, &value);
  std::cout << "value : " << value << '\n';
  return 0;
}
版权声明:本文为liujiayu2原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。