# 参考资料:https://blog.csdn.net/qq_38101208/article/details/110481390
#%%
import torch
import torch.nn as nn
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class CNN(nn.Module):
def __init__(self):
super(CNN,self).__init__()
pass
#网络
model = CNN().to(device)
#训练时(输入训练数据,标签)
x = torch.arange(4)
y = torch.arange(4)
x,y = x.to(device),y.to(device)
#预测时(输入训练数据)
#输出结果如果需要用numpy 进行处理,需要讲结果载入到CPU上
out = model(x).cpu().numpy()
版权声明:本文为xingghaoyuxitong原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。