U2NET目标显著性检测,抠图去背景效果倍儿棒

点击上方“AI搞事情”关注我们


论文:U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection
GIT:https://github.com/NathanUA/U-2-Net

U2Net用于显著目标检测(Salient Object Detection, SOD) ,目的是分割出图像中最具吸引力的目标。不同于图像识别,SOD更注重局部细节信息和全局对比信息,而不是深层语义信息,因此,主要的研究方向在于多层次与多尺度特征提取上。

网络结构

U2Net网络结构如下图,整体是一个编码-解码(Encoder-Decoder)结构的U-Net,其中,每个stage由新提出的RSU模块(residual U-block) 组成,即一个两层嵌套的U结构网络。

「其优势在于:」
1.RSU模块,融合了不同尺度感受野的特征,能够捕获更多不同尺度的上下文信息(contextual information)。
2.RSU模块的池化(pooling)操作,可以在不显著增加计算成本的情况下,加深网络结构的深度。

RSU,ReSidual Ublock, 用于捕获intra-stage的多尺度特征. 其结构如图(e)所示:

(a)-(c)显示了具有最小感受野的现有卷积块,但是1x1或者3x3的卷积核的感受野太小而无法捕捉全局信息,(d)通过利用空洞卷积增大感受野来获取全局信息,然而在前期大分辨率的输入特征图计算需要耗费大量的计算和内存资源。

残差模块与RSU模块的对比:主要设计区别在于,RSU用U-Net代替了普通的单流卷积,并用一个权重层构成的局部特征代替了原始特征:

损失函数

U2Net训练损失函数定义:

其中,M=6, 为U2Net 的 Sup1, Sup2, ..., Sup6 stage,为对应输出的显著图(saliency map)的损失函数;为最终融合输出的显著图的损失函数,为每个损失函数的权重。

对于每一项,使用标准二进制交叉熵来计算损失:

其中,(r,c)为像素坐标;(H, W) 为图像尺寸,height 和 width。分别表示 GT 像素值和预测的显著概率图(saliency probability map)。

结果可视化


搞事情

作者开源了代码,最近还公开了一些有趣的基于U2Net的应用,比如人像转素描,抠图、背景去除等。

我们可以根据说明进行一把尝试:

图像转素描

  • 下载源码
    git clone https://github.com/NathanUA/U-2-Net.git

  • 下载转素描模型u2net_portrait.pth放入到./saved_models/u2net_portrait/下面。

  • 执行脚本python u2net_portrait_test.py程序会读取U-2-Net/test_data/test_portrait_images/portrait_im路径下的照片进行转换,并把结果输出在U-2-Net/test_data/test_portrait_images/portrait_results路径下。
    若在CPU环境运行会提示torch.load使用参数map_location='cpu'即:net.load_state_dict(torch.load(model_dir, map_location='cpu'))

项目也提供了任意人脸图像转换的demo,区别在于增加了opencv的人脸检测,以及裁剪到输入的512x512大小,可以通过python u2net_portrait_demo.py执行,
图片放入路径./test_data/test_portrait_images/your_portrait_im/

结果在路径:./test_data/test_portrait_images/your_portrait_results/

抠图

通过U2Net,可以得到精细的前景alpha图像,通过简单的mask操作就可以将前景目标扣取出来。

# encoding=utf-8
import os
import cv2
import numpy as np
im1_path = '1/test.png'  # 原图
im2_path = '2/test_alpha.png'  # alpha图
img1 = cv2.imread(im1_path)
img2 = cv2.imread(im2_path, cv2.IMREAD_GRAYSCALE)
h, w, c = img1.shape
img3 = np.zeros((h, w, 4))
img3[:, :, 0:3] = img1
img3[:, :, 3] = img2
cv2.imwrite('res.png', img3)

有大佬将其做成了一个工具:www.remove.bg(50次免费试用),以及还有一个python库

参考:

1. Github 项目 - U2Net 网络及实现

2. U2Net论文解读及代码测试

原文链接下载模型:0rl5

往期推荐

长按二维码关注我们

有趣的灵魂在等你