pytorch创建的变量默认是存放在cpu上的,如下列代码所示:
>>> import torch
>>> R = torch.eye(3)
>>> R.device
# 输出为:device(type='cpu')
torch变量在不同device之间的转换方法有两种,如下:
# cpu -> cuda
>>> R = R.cuda(0) # 这里0表示第一个cuda
>>> R.device
# 输出为:device(type='cuda', index=0)
还一种方式是:
>>> device = torch.device('cuda:0')
>>> R = R.to(device)
>>> R.device
# 输出为:device(type='cuda', index=0)
将cuda上的变量转换到cpu上的方法是一样的,使用的函数为:
>>> R = R.cpu
# 或
>>> device = torch.device('cpu')
>>> R = R.to(device)
版权声明:本文为weixin_44120025原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。