Semantic Segmentation: Fast-SCNN Principles and Code Walkthrough

Fast-SCNN: Fast Semantic Segmentation Network

Paper: https://arxiv.org/pdf/1902.04502.pdf
GitHub: https://github.com/Tramac/Fast-SCNN-pytorch
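Fast-SCNN is one of the more widely used semantic segmentation models today. It runs in real time: the paper reports 123.5 fps on 1024×2048 images, and in my own test on a P100 a 256×256 image took about 24 ms, i.e. roughly 40 fps. It also delivers good segmentation accuracy, with 68.0% mIoU on Cityscapes, which is adequate for practical applications. Below I walk through the principles of Fast-SCNN and its core code: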
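The post does not say how the 24 ms figure was measured. Here is a minimal timing sketch. It assumes PyTorch, a few warm-up iterations, and `torch.cuda.synchronize()` calls, so that asynchronous CUDA kernels have actually finished before the clock is read. `measure_latency_ms` is a hypothetical helper and is not part of the Fast-SCNN repo:

```python
import time

import torch
import torch.nn as nn


@torch.no_grad()
def measure_latency_ms(model: nn.Module, size=(256, 256), warmup=10, iters=50) -> float:
    """Average forward time in ms for one 3-channel image (hypothetical helper, not from the repo)."""
    device = next(model.parameters()).device
    model.eval()
    x = torch.randn(1, 3, *size, device=device)
    for _ in range(warmup):  # warm-up: lets cuDNN choose algorithms and fills caches
        model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()  # make sure queued kernels are done before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait until the last forward pass has really finished
    return (time.perf_counter() - start) * 1000.0 / iters


if __name__ == "__main__":
    toy = nn.Conv2d(3, 8, 3, padding=1)  # stand-in model; swap in a Fast-SCNN instance
    ms = measure_latency_ms(toy)
    print(f"{ms:.2f} ms/image, about {1000.0 / ms:.0f} fps")
```

At 24 ms per image this works out to 1000 / 24 ≈ 41.7 fps, consistent with the "about 40 fps" quoted above.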

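First, the idea behind Fast-SCNN:
Fast-SCNN draws on two-branch architectures and encoder-decoder networks. It is designed for real-time semantic segmentation of high-resolution (1024×2048) images.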

The Fast-SCNN network architecture is shown in the figure below:
[Figure: Fast-SCNN network architecture]
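As the figure shows, Fast-SCNN, like earlier segmentation models, is built on an overall encoder-decoder structure. Features are extracted by Learning to Down-sample and the Global Feature Extractor. In the Feature Fusion stage, the global features are upsampled by bilinear interpolation followed by a DWConv and a Conv. They are then added directly to the output of Learning to Down-sample and fused further by convolutions, after which each pixel is classified.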

Below, each Fast-SCNN component is introduced in turn:

First, the Learning to Down-sample module:
[Figure: Learning to Down-sample module]

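import torch
import torch.nn as nn
import torch.nn.functional as F


class LearningToDownsample(nn.Module):
    """Learning to downsample module"""

    def __init__(self, dw_channels1=32, dw_channels2=48, out_channels=64, **kwargs):
        super(LearningToDownsample, self).__init__()
        # standard 3x3 conv with stride 2 (_ConvBNReLU comes from the reference repo)
        self.conv = _ConvBNReLU(3, dw_channels1, 3, 2)
        # two depthwise separable convs, each with stride 2
        self.dsconv1 = _DSConv(dw_channels1, dw_channels2, 2)
        self.dsconv2 = _DSConv(dw_channels2, out_channels, 2)

    def forward(self, x):
        # three stride-2 layers: the output is 1/8 of the input resolution
        x = self.conv(x)
        x = self.dsconv1(x)
        x = self.dsconv2(x)
        return x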

The code above is the Learning to Down-sample module. It relies on `_ConvBNReLU` and on the `_DSConv` block below.

class _DSConv(nn.Module):
    """Depthwise Separable Convolutions"""

    def __init__(self, dw_channels, out_channels, stride=1, **kwargs):
        super(_DSConv, self).__init__()
        self.conv = nn.Sequential(
            # depthwise 3x3: groups=dw_channels gives one filter per input channel
            nn.Conv2d(dw_channels, dw_channels, 3, stride, 1, groups=dw_channels, bias=False),
            nn.BatchNorm2d(dw_channels),
            nn.ReLU(True),
            # pointwise 1x1: mixes channels and sets the output width
            nn.Conv2d(dw_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.conv(x)

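The core code of the depthwise separable convolution is shown above.
For an introduction to depthwise separable convolutions, see here.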
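Comparing weight counts shows why depthwise separable convolutions are cheaper. Biases are ignored here, since the code above uses `bias=False`.

- A standard k×k conv from C_in to C_out channels has k·k·C_in·C_out weights.
- The depthwise + pointwise pair has k·k·C_in + C_in·C_out weights.

Here is a small sketch for the `dsconv1` layer (32 → 48 channels). BatchNorm parameters are left out:

```python
def standard_conv_weights(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights of a dense k x k convolution."""
    return k * k * c_in * c_out


def ds_conv_weights(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights of a depthwise k x k conv (one filter per channel) plus a 1 x 1 pointwise conv."""
    depthwise = k * k * c_in
    pointwise = c_in * c_out
    return depthwise + pointwise


if __name__ == "__main__":
    std = standard_conv_weights(32, 48)
    ds = ds_conv_weights(32, 48)
    print(std, ds, round(std / ds, 1))  # 13824 1824 7.6
```

The multiply-adds per output pixel follow the same formulas, so the roughly 7.6× ratio also approximates the compute saving for this layer.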

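As the structure and code above show, Learning to Down-sample is itself a feature extraction network. Typical encoder-decoder segmentation networks use a ResNet or VGG classification network directly as the encoder. This module instead follows MobileNet and uses two depthwise separable convolutions (DS) for the initial feature extraction, again for the sake of speed.
In essence it is one standard convolution plus two depthwise separable convolutions. Each of these layers has stride 2, so the output has 1/8 of the input resolution.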
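The 1/8 factor can be checked with the standard output-size formula floor((n + 2p − k) / s) + 1. The `_DSConv` layers use padding 1, as shown above. The padding of the first `_ConvBNReLU` is not shown in the post, so this sketch treats it as an assumption and traces both variants:

```python
def conv_out(n: int, k: int = 3, s: int = 2, p: int = 1) -> int:
    """Output length of a convolution along one spatial axis."""
    return (n + 2 * p - k) // s + 1


def trace_downsample(n: int, first_padding: int = 1) -> list:
    """Spatial size after conv -> dsconv1 -> dsconv2 (all 3x3, stride 2).

    first_padding is an assumption: the post does not show _ConvBNReLU's padding.
    """
    sizes = [conv_out(n, p=first_padding)]
    for _ in range(2):  # the two _DSConv layers use padding 1
        sizes.append(conv_out(sizes[-1], p=1))
    return sizes


if __name__ == "__main__":
    print(trace_downsample(1024))                   # [512, 256, 128]
    print(trace_downsample(1024, first_padding=0))  # [511, 256, 128]
    print(trace_downsample(256))                    # [128, 64, 32]
```

Either way, a 1024-pixel side ends up at 128, i.e. 1/8.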

Next, the Global Feature Extractor module:
[Figure: Global Feature Extractor module]

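This is again a feature extraction network. A series of convolution layers extracts features, followed by PSP pooling. For PSP pooling, introduced in PSPNet, see my earlier post on PSPNet: pspnet
In the figure above, the Global Feature Extractor has two parts:

- The green blocks are bottlenecks modeled on MobileNetV2.
- The pink block is PyramidPooling.

The code is short and straightforward: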

class LinearBottleneck(nn.Module):
    """LinearBottleneck used in MobileNetV2"""

    def __init__(self, in_channels, out_channels, t=6, stride=2, **kwargs):
        super(LinearBottleneck, self).__init__()
        self.use_shortcut = stride == 1 and in_channels == out_channels
        self.block = nn.Sequential(
            # pw
            _ConvBNReLU(in_channels, in_channels * t, 1),
            # dw
            _DWConv(in_channels * t, in_channels * t, stride),
            # pw-linear
            nn.Conv2d(in_channels * t, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        out = self.block(x)
        if self.use_shortcut:
            out = x + out
        return out

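That is the Bottleneck module. Note the order of operations. It first applies a CBR (Conv-BN-ReLU), then the DepthWiseConv, and finally a 1×1 conv layer and a BN layer. There is no ReLU after the last BN, which is the "linear" in LinearBottleneck. A shortcut is added only when the stride is 1 and the input and output channel counts match.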
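The Global Feature Extractor stacks several of these bottlenecks. Below is a pure-Python sketch of the layer plan. It assumes the paper's settings: three stages with 64, 96 and 128 output channels, 3 blocks each, expansion t = 6, and strides 2, 2, 1. It also assumes the usual MobileNetV2 rule that only the first block of a stage applies the stride. For each block, it lists whether `use_shortcut` would be true, and it reports the overall output stride:

```python
from typing import List, Tuple

# (out_channels, num_blocks, stride) per stage; values assumed from the Fast-SCNN paper
STAGES = [(64, 3, 2), (96, 3, 2), (128, 3, 1)]


def bottleneck_plan(in_channels: int = 64, input_stride: int = 8) -> Tuple[List[tuple], int]:
    """Return (in, out, stride, use_shortcut) per block and the final output stride."""
    plan = []
    total_stride = input_stride  # Learning to Down-sample already reduced the input by 8
    for out_channels, num_blocks, stride in STAGES:
        for i in range(num_blocks):
            s = stride if i == 0 else 1  # assumption: only the first block of a stage downsamples
            use_shortcut = s == 1 and in_channels == out_channels  # same rule as LinearBottleneck
            plan.append((in_channels, out_channels, s, use_shortcut))
            in_channels = out_channels
        total_stride *= stride
    return plan, total_stride


if __name__ == "__main__":
    plan, stride = bottleneck_plan()
    for block in plan:
        print(block)
    print("output stride:", stride)  # 32
```

Under these assumptions, six of the nine blocks carry a residual shortcut. For a 1024×2048 input, pyramid pooling then operates on a 32×64 feature map.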

class _DWConv(nn.Module):
    def __init__(self, dw_channels, out_channels, stride=1, **kwargs):
        super(_DWConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(dw_channels, out_channels, 3, stride, 1, groups=dw_channels, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.conv(x)

The DepthWiseConv module is shown above.
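Unlike `_DSConv`, `_DWConv` is only the depthwise half of the operation. With `groups=dw_channels`, each input channel gets its own 3×3 filter, and the weight shape shows this directly. Here is a quick check in PyTorch, using 384 = 64 × t with t = 6 (the expanded width in the first bottleneck stage, assuming 64 input channels):

```python
import torch
import torch.nn as nn

channels = 384  # 64 input channels expanded by t = 6
dense = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
depthwise = nn.Conv2d(channels, channels, 3, 1, 1, groups=channels, bias=False)

print(tuple(dense.weight.shape))      # (384, 384, 3, 3)
print(tuple(depthwise.weight.shape))  # (384, 1, 3, 3): one 3x3 filter per channel
print(dense.weight.numel() // depthwise.weight.numel())  # 384

x = torch.randn(1, channels, 16, 16)
print(tuple(depthwise(x).shape))      # (1, 384, 16, 16): stride 1 keeps the spatial size
```

PyTorch requires both the input and output channel counts of a grouped conv to be divisible by `groups`. In `LinearBottleneck` both are `in_channels * t`, so this holds.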

Finally, to extract context features, the bottlenecks are followed by a PyramidPooling block borrowed from PSPNet. The code:

class PyramidPooling(nn.Module):
    """Pyramid pooling module"""

    def __init__(self, in_channels, out_channels, **kwargs):
        super(PyramidPooling, self).__init__()
        inter_channels = int(in_channels / 4)
        self.conv1 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.conv2 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.conv3 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.conv4 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.out = _ConvBNReLU(in_channels * 2, out_channels, 1)

    def pool(self, x, size):
        avgpool = nn.AdaptiveAvgPool2d(size)
        return avgpool(x)

    def upsample(self, x, size):
        return F.interpolate(x, size, mode='bilinear', align_corners=True)

    def forward(self, x):
        size = x.size()[2:]
        feat1 = self.upsample(self.conv1(self.pool(x, 1)), size)
        feat2 = self.upsample(self.conv2(self.pool(x, 2)), size)
        feat3 = self.upsample(self.conv3(self.pool(x, 3)), size)
        feat4 = self.upsample(self.conv4(self.pool(x, 6)), size)
        x = torch.cat([x, feat1, feat2, feat3, feat4], dim=1)
        x = self.out(x)
        return x

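Note that PyramidPooling likewise upsamples by convolving first and then applying bilinear interpolation. The original features and the four upsampled branches are then concatenated and fused by a final 1×1 conv (`self.out`).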
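Why does `self.out` take `in_channels * 2` inputs? Each of the four branches reduces the channels to `in_channels / 4`. The concatenation therefore has `in_channels + 4 · (in_channels / 4) = 2 · in_channels` channels.

Below is a quick check of the pooled sizes and the concatenated width, in PyTorch. It assumes a 128-channel 32×64 input, which is what the Global Feature Extractor would produce for a 1024×2048 image. Plain untrained 1×1 convs stand in for `_ConvBNReLU`, because only the shapes matter here:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

in_channels = 128
x = torch.randn(1, in_channels, 32, 64)  # assumed GFE output for a 1024x2048 image
size = x.shape[2:]

feats = [x]
for bins in (1, 2, 3, 6):
    pooled = F.adaptive_avg_pool2d(x, bins)  # (1, 128, bins, bins)
    # untrained 1x1 conv standing in for _ConvBNReLU: 128 -> 32 channels
    reduced = nn.Conv2d(in_channels, in_channels // 4, 1)(pooled)
    feats.append(F.interpolate(reduced, size, mode="bilinear", align_corners=True))
    print(bins, tuple(pooled.shape[2:]))

out = torch.cat(feats, dim=1)
print(tuple(out.shape))  # (1, 256, 32, 64): 2 * in_channels channels
```

The sketch omits BatchNorm, and this matters for training the real module. With the real `_ConvBNReLU`, the 1×1 bin leaves BatchNorm with a single value per channel when the batch size is 1, and PyTorch rejects that in training mode. Training this module therefore needs a batch size greater than 1.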
Next comes the Feature Fusion module:
[Figure: Feature Fusion module]
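In the figure, yellow is an upsampling step, and gray is a DW convolution followed by a conv layer. The result is added directly to the feature map from Learning to Down-sample to obtain the final feature map. That feature map is the other branch of the two-branch design; the first branch is the Global Feature Extractor we just worked through. In the reference implementation, the higher-resolution features first pass through a 1×1 Conv+BN, because Learning to Down-sample outputs 64 channels while the fused features have 128.

This step follows the same idea as a ResNet skip connection. The encoder output should carry shallow features as well as deep ones, which improves the representation of small objects. In the authors' words, global context is learned on a low-resolution input, while a shallow branch learns details at high resolution. Finally, a series of convolutions followed by upsampling to the original image size classifies each pixel, completing the semantic segmentation.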
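The post does not list the Feature Fusion code. Below is a minimal PyTorch sketch of the structure described above:

- The low-resolution branch is bilinearly upsampled to the high-resolution size, then passed through a depthwise 3×3 conv (+BN+ReLU) and a 1×1 conv (+BN).
- The high-resolution branch gets a 1×1 conv (+BN) so that the channel counts match.
- The two branches are added and passed through a ReLU.

The class name `FeatureFusionSketch` and the channel defaults (64 and 128 in, 128 out) are assumptions modeled on the paper, not copied from the repo:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureFusionSketch(nn.Module):
    """Hypothetical sketch: upsample the global branch and add it to the detail branch."""

    def __init__(self, higher_in_channels=64, lower_in_channels=128, out_channels=128):
        super().__init__()
        self.dwconv = nn.Sequential(  # depthwise 3x3 on the upsampled low-res features
            nn.Conv2d(lower_in_channels, lower_in_channels, 3, 1, 1, groups=lower_in_channels, bias=False),
            nn.BatchNorm2d(lower_in_channels),
            nn.ReLU(True),
        )
        self.conv_lower_res = nn.Sequential(  # 1x1 conv, no ReLU before the addition
            nn.Conv2d(lower_in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.conv_higher_res = nn.Sequential(  # 1x1 projection so the channel counts match
            nn.Conv2d(higher_in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU(True)

    def forward(self, higher_res, lower_res):
        # upsample to the detail branch's exact size instead of a fixed factor
        lower_res = F.interpolate(lower_res, size=higher_res.shape[2:], mode="bilinear", align_corners=True)
        lower_res = self.conv_lower_res(self.dwconv(lower_res))
        return self.relu(self.conv_higher_res(higher_res) + lower_res)


if __name__ == "__main__":
    fuse = FeatureFusionSketch()
    high = torch.randn(2, 64, 32, 32)  # 1/8 resolution of a 256x256 image
    low = torch.randn(2, 128, 8, 8)    # 1/32 resolution
    print(tuple(fuse(high, low).shape))  # (2, 128, 32, 32)
```

Interpolating to `higher_res.shape[2:]` rather than scaling by a fixed factor of 4 keeps the sketch working when the input size is not a multiple of 32.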

The forward pass of the full FastSCNN model (a method of the model class) ties these modules together:

    def forward(self, x):
        size = x.size()[2:]
        higher_res_features = self.learning_to_downsample(x)
        x = self.global_feature_extractor(higher_res_features)
        x = self.feature_fusion(higher_res_features, x)
        x = self.classifier(x)  # after the classifier, upsample directly with bilinear interpolation
        outputs = []
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)
        outputs.append(x)
        if self.aux:
            auxout = self.auxlayer(higher_res_features)
            auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
            outputs.append(auxout)
        return tuple(outputs)

The classifier block is essentially a small classification network:

class Classifer(nn.Module):
    """Classifer"""

    def __init__(self, dw_channels, num_classes, stride=1, **kwargs):
        super(Classifer, self).__init__()
        self.dsconv1 = _DSConv(dw_channels, dw_channels, stride)
        self.dsconv2 = _DSConv(dw_channels, dw_channels, stride)
        self.conv = nn.Sequential(
            nn.Dropout(0.1),
            nn.Conv2d(dw_channels, num_classes, 1)
        )

    def forward(self, x):
        x = self.dsconv1(x)
        x = self.dsconv2(x)
        x = self.conv(x)
        return x

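Finally, a softmax over the class dimension gives per-pixel class probabilities, from which the final result is obtained.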
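For the predicted label map alone, the softmax can be skipped. Softmax is monotonic per pixel, so an `argmax` over the class dimension of the logits gives the same labels. Here is a short check in PyTorch. It assumes the `(N, num_classes, H, W)` layout returned as `outputs[0]` by the forward pass above, with 19 classes as in Cityscapes:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(1, 19, 4, 6)     # (N, num_classes, H, W), e.g. outputs[0]
probs = torch.softmax(logits, dim=1)  # per-pixel class probabilities
pred = probs.argmax(dim=1)            # (N, H, W) label map

print(tuple(pred.shape))                                      # (1, 4, 6)
print(torch.equal(pred, logits.argmax(dim=1)))                # True: softmax keeps the argmax
print(torch.allclose(probs.sum(dim=1), torch.ones(1, 4, 6)))  # True
```

During training, `nn.CrossEntropyLoss` takes the raw logits and applies log-softmax internally, so it is correct that the classifier itself ends without a softmax layer.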

Summary:

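The authors propose an efficient semantic segmentation network that borrows from MobileNet, PSPNet and others, using DS and DW convolutions and pyramid pooling. It avoids U-Net-style dense skip connections and uses only a single skip connection. This significantly reduces memory usage and improves real-time performance while keeping accuracy fairly high.

In practice, though, the choice depends on the project. A U-Net with a lightweight backbone and channel counts halved or more can also stay under 10 MB and reach 50+ fps, though its segmentation accuracy may suffer. What matters most is the requirements.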


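Copyright notice: This is an original article by xuzz_498100208, published under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.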