1. nn.DataParallel - 在模块级别实现数据并行
if len(args.gpu.split(','))>1:
model_ft = nn.DataParallel(model_ft,device_ids=[int(g) for g in args.gpu.split(',')])
model_ft = model_ft.to(device)
else:
model_ft = model_ft.to(device)
2. DistributedDataParallel - 基于 torch.distributed 包的数据并行
(试一下三年多不更的更新能否复原原力值, 未完 …请忽略)
版权声明:本文为wangweiwells原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。