TTA(Test-Time Augmentation) 之Pytorch

TTA(Test-Time Augmentation) ,即测试时的数据增强

实现步骤如下:

  1. 将1个batch的数据通过flips, rotation, scale, etc.等操作生成batches
  2. 将各个batch分别输入网络
  3. 每个batch的masks/labels反向转换
  4. 通过mean, max, gmean, etc.合并各个batch预测的结果
  5. 最后输出最终的masks/labels
                   Input
             |           # input batch of images 
        / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)
       | | | | | | |     # pass augmented batches through model
       | | | | | | |     # reverse transformations for each batch of masks/labels
        \ \ \ / / /      # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
           Output
      

安装

        $ pip install ttach
      

使用方法如下

        import ttach as tta
...
model.load_state_dict(torch.load('models/%s/model.pth' %args.name))

model.eval()
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
...
      

更多有关使用方法,可以看下面的参考链接

reference

https://github.com/qubvel/ttach

 


版权声明:本文为weixin_40519315原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。