clip_gradient_norms()

 参考   clip_gradient_norms() - 云+社区 - 腾讯云

def clip_gradient_norms(gradients_to_variables, max_norm):
  clipped_grads_and_vars = []
  for grad, var in gradients_to_variables:
    if grad is not None:
      if isinstance(grad, ops.IndexedSlices):
        tmp = clip_ops.clip_by_norm(grad.values, max_norm)
        grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
      else:
        grad = clip_ops.clip_by_norm(grad, max_norm)
    clipped_grads_and_vars.append((grad, var))
  return clipped_grads_and_vars

用给定的值剪辑渐变。

参数:

  • gradients_to_variables:从渐变到变量对(元组)的列表
  • max_norm:最大值

返回值:

  • 变量对的剪切梯度列表

版权声明:本文为weixin_36670529原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。