【one way的pytorch学习笔记】(七) Multi GPU 训练(未完..请忽略)


1. nn.DataParallel - 在模块级别实现数据并行

if len(args.gpu.split(','))>1:
    model_ft = nn.DataParallel(model_ft,device_ids=[int(g) for g in args.gpu.split(',')])
    model_ft = model_ft.to(device)
else:
    model_ft = model_ft.to(device)

2. DistributedDataParallel - 基于 torch.distributed 包的数据并行


(试一下三年多不更的更新能否复原原力值, 未完 …请忽略)


版权声明:本文为wangweiwells原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。