一、概要
mmclassification是一个分类框架,可以通过配置的方式快速的对一些idea进行验证,本文通过一个案例对mmclassification的使用过程进行一个简单说明。
二、编写配置
关于mmclassification的配置文件组成及其基本的编写流程请参考官方文档,里面有比较详实的介绍。这里以本人实际编写的配置为例,对几个特殊的点重点进行下说明,这几点在文档里面笔墨不是很多,希望能帮到路过的朋友。
_base_ = [
'../_base_/models/resnet18.py',
'../_base_/datasets/imagenet_bs32.py',
'../_base_/schedules/imagenet_bs2048_AdamW.py',
'../_base_/default_runtime.py'
]
img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# 修改输出分类数量
model = dict(head=dict(num_classes=2))
optimizer = dict(type='AdamW', lr=0.0015, weight_decay=0.001)
lr_config = dict(min_lr=0.0001, warmup_iters=5, warmup_ratio=0.2, by_epoch=True)
# dataset settings
dataset_type = 'demo_dataset'
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomFlip', flip_prob=0.7, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['img']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=(600, -1)),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
# samples_per_gpu = batch_size / 训练模型使用的GPU个数
samples_per_gpu=64,
# 每个卡上面有多少个线程负责加载数据
workers_per_gpu=4,
train=dict(
type=dataset_type,
"""
图片路径以data_prefix + filename的方式表示,如果data_prefix为空,
那么anno_file的第一列一定要是文件路径,否则将找不到图片
"""
data_prefix='',
ann_file='data/train.txt',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_prefix='',
ann_file='data/val.txt',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_prefix='',
ann_file='data/test.txt',
pipeline=test_pipeline))
evaluation = dict(
# 单个指标的值是字符串形式, 如'accuracy', 多个指标是列表形式,如下所示
metric=['accuracy','precision','recall'],
"""
如果不需要topk, 则topk一定要设为1, 否则代码内部会默认设为(1,5),
也可以设置个人希望设置的数值, 比如topk=(1,3)
"""
metric_options=dict(topk=1),
# 仅保存最优检查点
save_best='auto',
)
# 如果想使用预训练权重, load_from设置成预训练权重路径
load_from = '/mnt/home/pth/resnet18-5c106cde.pth'
# 日志配置中的interval项可以控制多少个batch输出一次日志信息
log_config = dict(interval=20, hooks=[dict(type='TextLoggerHook')])三、运行
bash tools/dist_train.sh 配置文件路径 卡的数量,例如:bash tools/dist_train.sh configs/tutorial/resnet50_finetune_cifar.py 8,建议修改dist_train.sh内部代码如下:nohup python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} > log.txt 2>&1 &,这样不至于阻塞。
版权声明:本文为gaoxueyi551原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。