首先准备好模型和数据,假设有四张显卡的话,编号:0,1,2,3
首先加载数据,一般加载到第一张显卡上:
if torch.cuda.is_available():
# model.cuda()
features = features.cuda()然后将模型加载进来,采用如下的方式:
model = nn.DataParallel(model.cuda(),device_ids=[0,1,2,3])版权声明:本文为weixin_61445075原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。