mmclassification-模型微调(二)

有两个步骤可以微调新数据集上的模型。

1)添加新数据集的支持;

2)修改配置文件。

以对CIFAR10数据集的微调为例,用户需要修改配置中的五个部分。

1、继承基本配置

要在不同配置之间重用公共部分,我们支持从多个现有配置继承配置。要微调ResNet-50模型,需要继承新配置 _base_/models/resnet50.py以构建模型的基本结构。要使用CIFAR10数据集,新配置也可以简单地继承_base_/datasets/cifar10.py。对于运行时设置(例如训练计划),新配置需要继承_base_/default_runtime.py。

_base_ = [
    '../_base_/models/resnet50.py',
    '../_base_/datasets/cifar10.py', 
    '../_base_/default_runtime.py'
]

此外,用户还可以选择编写全部内容,而不是使用继承,例如:configs/mnist/lenet5.py。

2、修改头

然后,新配置需要根据新数据集的类编号修改头。通过仅改变num_classes头部,除最终预测头部外,大部分已重用了预训练模型的权重。

_base_ = ['./resnet50.py']
model = dict(
    pretrained=None,
    head=dict(
        type='LinearClsHead',
        num_classes=10,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    ))

3、修改数据集

用户可能还需要准备数据集并编写有关数据集的配置。我们目前支持MNIST,CIFAR和ImageNet数据集。为了在CIFAR10上进行微调,其原始输入大小为32,因此我们应将其大小调整为224,以适应ImageNet上预训练的模型的输入大小。

_base_ = ['./cifar10.py']
img_norm_cfg = dict(
     mean=[125.307, 122.961, 113.8575],
     std=[51.5865, 50.847, 51.255],
     to_rgb=True)
     
train_pipeline = [
    dict(type='RandomCrop', size=32, padding=4),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Resize', size=224) 
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
 ]
 
 test_pipeline = [
    dict(type='Resize', size=224)
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
 ]

4、修改训练计划

微调超参数与默认计划不同。它通常需要较小的学习率和较少的训练时间。

# optimizer
# lr is set for a batch size of 128
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)

# learning policy
lr_config = dict(
    policy='step',
    step=[15])
total_epochs = 20
log_config = dict(interval=100)

5、使用预训练模型

要使用预先训练的模型,新的配置会在中添加预先训练的模型的链接load_from。用户可能需要在训练之前下载模型权重,以避免训练期间的下载时间。

load_from = 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmclassification/models/tbd.pth' # noqa

传送门:mmclassification项目阅读系列文章目录

教程文档翻译:

mmclassification-安装使用(一)

mmclassification-模型微调(二)

mmclassification-添加新数据集(三)

mmclassification-自定义数据管道(四)

mmclassification-添加新模块(五)