今天开始更新学习 FaceBook 的深度学习框架 PyTorch !
PyTorch 底层优化的非常好,而且与 Numpy 无缝对接,用起来很清爽,不像 TensorFlow 那么“反 Python”~
先看了 Deep Learning with PyTorch: A 60 Minute Blitz ,题目说是“一小时搞定”,但就我这个上了岁数的人来讲,花了一晚上才把一整套流程跑了一遍。。。
接下来看看踩得第一个坑,在使用交叉熵损失函数 (Cross Entropy) 时抛出异常:
RuntimeError: multi-target not supported at …\aten\src\THNN/generic/ClassNLLCriterion.c:20
程序运行的经过为:
cross_entropy = nn.CrossEntropyLoss()
loss = cross_entropy(predicts, labels)
导致的原因为,函数的参数
predicts
的 shape
为 n*m
,n
为 batch_data
的样本数,m
为 模型输出层的维数;
labels
的 shape
为 1*n
,是一个一维 Tensor
(在四分类任务中,形如 [1, 0, 3, 1, 2]
),其中每个元素为类别的类别号(例如 一个样本的类别为飞机
,其对应的类别标签为 1
,则对应的 label
则为 1
)
看一下另一个损失函数的坑,MSE(Mean Square Error):
RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'target’
程序运行的经过为:
mse = nn.MSELoss()
loss = mse(result, batch_label)
导致原因为:
mse
方法要求 target
数据类型为 float
,所以要对 batch_label 做数据类型转换:
# 可以直接在全部数据中转换
batch_label.float()
版权声明:本文为weixin_37352167原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。