有两个步骤可以微调新数据集上的模型。
1)添加新数据集的支持;
2)修改配置文件。
以对CIFAR10数据集的微调为例,用户需要修改配置中的五个部分。
1、继承基本配置
要在不同配置之间重用公共部分,我们支持从多个现有配置继承配置。要微调ResNet-50模型,需要继承新配置 _base_/models/resnet50.py以构建模型的基本结构。要使用CIFAR10数据集,新配置也可以简单地继承_base_/datasets/cifar10.py。对于运行时设置(例如训练计划),新配置需要继承_base_/default_runtime.py。
_base_ = [
'../_base_/models/resnet50.py',
'../_base_/datasets/cifar10.py',
'../_base_/default_runtime.py'
]此外,用户还可以选择编写全部内容,而不是使用继承,例如:configs/mnist/lenet5.py。
2、修改头
然后,新配置需要根据新数据集的类编号修改头。通过仅改变num_classes头部,除最终预测头部外,大部分已重用了预训练模型的权重。
_base_ = ['./resnet50.py']
model = dict(
pretrained=None,
head=dict(
type='LinearClsHead',
num_classes=10,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
))3、修改数据集
用户可能还需要准备数据集并编写有关数据集的配置。我们目前支持MNIST,CIFAR和ImageNet数据集。为了在CIFAR10上进行微调,其原始输入大小为32,因此我们应将其大小调整为224,以适应ImageNet上预训练的模型的输入大小。
_base_ = ['./cifar10.py']
img_norm_cfg = dict(
mean=[125.307, 122.961, 113.8575],
std=[51.5865, 50.847, 51.255],
to_rgb=True)
train_pipeline = [
dict(type='RandomCrop', size=32, padding=4),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Resize', size=224)
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='Resize', size=224)
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]4、修改训练计划
微调超参数与默认计划不同。它通常需要较小的学习率和较少的训练时间。
# optimizer
# lr is set for a batch size of 128
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='step',
step=[15])
total_epochs = 20
log_config = dict(interval=100)5、使用预训练模型
要使用预先训练的模型,新的配置会在中添加预先训练的模型的链接load_from。用户可能需要在训练之前下载模型权重,以避免训练期间的下载时间。
load_from = 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmclassification/models/tbd.pth' # noqa传送门:mmclassification项目阅读系列文章目录
教程文档翻译: