参考 clip_gradient_norms() - 云+社区 - 腾讯云
def clip_gradient_norms(gradients_to_variables, max_norm):
clipped_grads_and_vars = []
for grad, var in gradients_to_variables:
if grad is not None:
if isinstance(grad, ops.IndexedSlices):
tmp = clip_ops.clip_by_norm(grad.values, max_norm)
grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
else:
grad = clip_ops.clip_by_norm(grad, max_norm)
clipped_grads_and_vars.append((grad, var))
return clipped_grads_and_vars
用给定的值剪辑渐变。
参数:
- gradients_to_variables:从渐变到变量对(元组)的列表
- max_norm:最大值
返回值:
- 变量对的剪切梯度列表
版权声明:本文为weixin_36670529原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。