mmclassification的使用问题

一、概要

    mmclassification是一个分类框架,可以通过配置的方式快速的对一些idea进行验证,本文通过一个案例对mmclassification的使用过程进行一个简单说明。

二、编写配置

    关于mmclassification的配置文件组成及其基本的编写流程请参考官方文档,里面有比较详实的介绍。这里以本人实际编写的配置为例,对几个特殊的点重点进行下说明,这几点在文档里面笔墨不是很多,希望能帮到路过的朋友。

_base_ = [
    '../_base_/models/resnet18.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs2048_AdamW.py',
    '../_base_/default_runtime.py'
]

img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# 修改输出分类数量
model = dict(head=dict(num_classes=2))
optimizer = dict(type='AdamW', lr=0.0015, weight_decay=0.001)
lr_config = dict(min_lr=0.0001, warmup_iters=5, warmup_ratio=0.2, by_epoch=True)

# dataset settings
dataset_type = 'demo_dataset'
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomFlip', flip_prob=0.7, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['img']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(600, -1)),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

data = dict(
    # samples_per_gpu = batch_size / 训练模型使用的GPU个数
    samples_per_gpu=64,
    # 每个卡上面有多少个线程负责加载数据
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        """
        图片路径以data_prefix + filename的方式表示,如果data_prefix为空,
        那么anno_file的第一列一定要是文件路径,否则将找不到图片
        """
        data_prefix='',
        ann_file='data/train.txt',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='',
        ann_file='data/val.txt',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_prefix='',
        ann_file='data/test.txt',
        pipeline=test_pipeline))

evaluation = dict(
    # 单个指标的值是字符串形式, 如'accuracy', 多个指标是列表形式,如下所示
    metric=['accuracy','precision','recall'],
    """ 
    如果不需要topk, 则topk一定要设为1, 否则代码内部会默认设为(1,5), 
    也可以设置个人希望设置的数值, 比如topk=(1,3)
    """
    metric_options=dict(topk=1),
    # 仅保存最优检查点
    save_best='auto',
    )

# 如果想使用预训练权重, load_from设置成预训练权重路径
load_from = '/mnt/home/pth/resnet18-5c106cde.pth'

# 日志配置中的interval项可以控制多少个batch输出一次日志信息
log_config = dict(interval=20, hooks=[dict(type='TextLoggerHook')])

三、运行

    bash tools/dist_train.sh 配置文件路径 卡的数量,例如:bash tools/dist_train.sh configs/tutorial/resnet50_finetune_cifar.py 8,建议修改dist_train.sh内部代码如下:nohup python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} > log.txt 2>&1 &,这样不至于阻塞。


版权声明:本文为gaoxueyi551原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。