Pytorch自定义损失函数losses.py

PyTorch已经有很多标准的损失函数,有时也需要创建自己的损失函数。为此,需要创建一个单独的文件 losses.py ,然后扩展 nn.Module类创建自定义损失函数:

class CustomLoss(torch.nn.Module):
    
    def __init__(self):
        super(CustomLoss,self).__init__()
        
    def forward(self,x,y):
        loss = torch.mean((x - y)**2)
        return loss

版权声明:本文为KyrieHe原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。