对比学习系列(三)-----SimCLR

SimCLR

SimCLR通过隐藏空间的对比损失最大化相同数据在不同增广下的一致性来学习表达。SimCLR框架有四个主要的组件,分别是:数据增广,encode网络,projection head网络和对比学习函数。

在这里插入图片描述
对于数据x xx,从同一个数据增广族中抽取两个独立的数据增广算子(t ∼ T t \sim TtTt ′ ∼ T {t}' \sim TtT),以获得两个相关的视图x ^ i \hat{x}_{i}x^ix ^ j \hat{x}_{j}x^jx ^ i \hat{x}_{i}x^ix ^ j \hat{x}_{j}x^j是一对正样本,然后一个神经网络编码器f ( ⋅ ) f\left( \cdot \right)f()从增广的数据中提取特征h i = f ( x ^ i ) , h j = f ( x ^ j ) , h_{i}=f\left( \hat{x}_{i} \right), h_{j}=f\left( \hat{x}_{j} \right),hi=f(x^i),hj=f(x^j),。再然后一个小的神经网络project head g ( ⋅ ) g\left( \cdot \right)g()将特征映射到对比损失的空间。project head采用带有一个隐含层的MLP获取z i = g ( h i ) = W ( 2 ) σ ( W ( 1 ) h i ) z_{i} = g\left( h_{i} \right) = W^{\left( 2 \right)} \sigma \left( W^{\left( 1 \right)} h_{i}\right)zi=g(hi)=W(2)σ(W(1)hi)

对于包含一对正样本x ^ i \hat{x}_{i}x^ix ^ j \hat{x}_{j}x^j的集合{ x ^ k } \{ \hat{x}_{k} \}{x^k},对比预测任务目的是对于给定的x ^ i \hat{x}_{i}x^i{ x ^ } k ≠ i \{ \hat{x} \}_{k \neq i}{x^}k=i中识别出x ^ j \hat{x}_{j}x^j。随机挑选N NN个样本组成一个minibatch,这个minibatch中则有2 N 2N2N个数据样本,将其他2 ( N − 1 ) 2\left( N - 1\right)2(N1)个扩增的样本作为这个minibatch中的负样本,设s i m ( u , v ) = u T v / ∥ u ∥ ∥ v ∥ sim\left( u, v\right) = u^{T}v / \| u\| \| v\|sim(u,v)=uTv/∥u∥∥v表示l 2 l_{2}l2正则化后你的u uuv vv的点积,那么对一对正样本( i , j ) \left( i, j \right)(i,j),损失函数如下定义:

l i , j = − l o g e x p ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N 1 [ k ≠ i ] e x p ( s i m ( z i , z k ) / τ ) l_{i,j} = - log \frac{exp\left( sim \left( z_{i}, z_{j}\right) / \tau \right)}{\sum_{k=1}^{2N} \mathbb{1}_{[ k \neq i]} exp\left( sim \left( z_{i}, z_{k}\right) / \tau \right)}li,j=logk=12N1[k=i]exp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)

最后的损失函数计算一个minibatch中的所有的正样本对,包括( i , j ) \left( i, j \right)(i,j)( j , i ) \left( j,i \right)(j,i)。下面是SimCLR的伪代码。从伪代码中可以看出,编码器f ( ⋅ ) f\left( \cdot \right)f()和project head g ( ⋅ ) g\left( \cdot \right)g() 在训练时都会被更新参数,但是只有编码器f ( ⋅ ) f\left( \cdot \right)f()用于下游任务。
在这里插入图片描述
simCLR不采用memory bank的形式进行训练,而是加大batchsize,bacth size为8192,对于每一个正样本,将会有16382张负样本实例。增大batch size其实相当于每个minibatch时动态生成一个memory bank。论文中发现使用标准的SGD/Momentum,大batch size训练时是不稳定的,论文中采用LARS优化器。

参考

  1. The Illustrated SimCLR Framework
  2. SimCLR

版权声明:本文为weixin_42111770原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。