1. Overview
tf.metrics implements a range of evaluation metric functions, such as mean, precision, auc, mean_squared_error and precision_at_k. In TF 1.15, the main functions are:
accuracy(...): Calculates how often predictions match labels.
auc(...): Computes the approximate AUC via a Riemann sum.
average_precision_at_k(...): Computes average precision@k of predictions with respect to sparse labels.
false_negatives(...): Computes the total number of false negatives.
false_negatives_at_thresholds(...): Computes false negatives at provided threshold values.
false_positives(...): Sum the weights of false positives.
false_positives_at_thresholds(...): Computes false positives at provided threshold values.
mean(...): Computes the (weighted) mean of the given values.
mean_absolute_error(...): Computes the mean absolute error between the labels and predictions.
mean_cosine_distance(...): Computes the cosine distance between the labels and predictions.
mean_iou(...): Calculate per-step mean Intersection-Over-Union (mIOU).
mean_per_class_accuracy(...): Calculates the mean of the per-class accuracies.
mean_relative_error(...): Computes the mean relative error by normalizing with the given values.
mean_squared_error(...): Computes the mean squared error between the labels and predictions.
mean_tensor(...): Computes the element-wise (weighted) mean of the given tensors.
percentage_below(...): Computes the percentage of values less than the given threshold.
precision(...): Computes the precision of the predictions with respect to the labels.
precision_at_k(...): Computes precision@k of the predictions with respect to sparse labels.
precision_at_thresholds(...): Computes precision values for different thresholds on predictions.
precision_at_top_k(...): Computes precision@k of the predictions with respect to sparse labels.
recall(...): Computes the recall of the predictions with respect to the labels.
recall_at_k(...): Computes recall@k of the predictions with respect to sparse labels.
recall_at_thresholds(...): Computes various recall values for different thresholds on predictions.
recall_at_top_k(...): Computes recall@k of top-k predictions with respect to sparse labels.
root_mean_squared_error(...): Computes the root mean squared error between the labels and predictions.
sensitivity_at_specificity(...): Computes the sensitivity at a given specificity.
sparse_average_precision_at_k(...): Renamed to average_precision_at_k, please use that method instead. (deprecated)
sparse_precision_at_k(...): Renamed to precision_at_k, please use that method instead. (deprecated)
specificity_at_sensitivity(...): Computes the specificity at a given sensitivity.
true_negatives(...): Sum the weights of true_negatives.
true_negatives_at_thresholds(...): Computes true negatives at provided threshold values.
true_positives(...): Sum the weights of true_positives.
true_positives_at_thresholds(...): Computes true positives at provided threshold values.
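Many of the functions above are built from the same counters. In TF 1.x, tf.metrics.precision is derived from true_positives and false_positives accumulators, and tf.metrics.recall is derived from true_positives and false_negatives. The pure-Python sketch below mirrors that idea without TensorFlow. The class name ConfusionCounters is hypothetical. The sketch accumulates weighted counts across batches and derives precision and recall from them. Like TF's precision, it returns 0 when the denominator is 0.

```python
from typing import Optional, Sequence


class ConfusionCounters:
    """Hypothetical stand-in for the true/false positive/negative local variables."""

    def __init__(self) -> None:
        # Accumulators start at zero and only ever grow.
        self.tp = 0.0
        self.fp = 0.0
        self.fn = 0.0

    def update(self, labels: Sequence[bool], predictions: Sequence[bool],
               weights: Optional[Sequence[float]] = None) -> None:
        """Add the weighted counts of one batch to the accumulators."""
        if weights is None:
            weights = [1.0] * len(labels)
        for y, p, w in zip(labels, predictions, weights):
            if p and y:
                self.tp += w
            elif p:
                self.fp += w
            elif y:
                self.fn += w
            # True negatives are not needed for precision or recall.

    def precision(self) -> float:
        den = self.tp + self.fp
        return self.tp / den if den > 0 else 0.0

    def recall(self) -> float:
        den = self.tp + self.fn
        return self.tp / den if den > 0 else 0.0


counters = ConfusionCounters()
counters.update([True, False, True], [True, True, False])    # batch 1
counters.update([True, False, False], [True, False, True])   # batch 2
print(counters.precision(), counters.recall())  # precision 2/4 = 0.5, recall 2/3
```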
2. A single batch or all samples?
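I first used tf.metrics to compute AUC, and one question puzzled me. AUC is computed once for every batch, so is each AUC value based only on that batch's samples? Reading the source code answered it: each AUC value is cumulative, covering every sample from the first batch up to the current one. Consider how we would implement this ourselves. We would need one or more variables that hold running totals (call them accumulator variables). Each batch's samples are added to the accumulators, and the metric is computed from them. In tf.metrics these accumulators are placed in the GraphKeys.LOCAL_VARIABLES and GraphKeys.METRIC_VARIABLES collections. (Note: local variables live on each machine rather than being shared. In distributed evaluation, every machine therefore keeps its own accumulators, and these are finally aggregated across machines. In the mean code below, this aggregation is done by _aggregate_across_replicas.)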
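Here is a pure-Python sketch of a streaming AUC that makes the accumulator idea concrete. It is loosely modeled on tf.metrics.auc, which computes an approximate AUC via a Riemann sum over confusion-matrix counts accumulated at a fixed set of thresholds (default num_thresholds=200). It is a simplified illustration, not TensorFlow's code. The class and method names are hypothetical, and empty denominators are simply treated as 0. Each of the two batches is perfectly ranked on its own, yet the cumulative AUC is lower. This is exactly the difference between a per-batch value and the accumulated one.

```python
from typing import List, Sequence


def _ratio(num: float, den: float) -> float:
    # Treat an empty denominator as 0 instead of dividing by zero.
    return num / den if den > 0 else 0.0


class StreamingAUC:
    """Hypothetical sketch of a ROC AUC accumulated across batches."""

    def __init__(self, num_thresholds: int = 200) -> None:
        eps = 1e-7
        # Evenly spaced thresholds in (0, 1), plus one just below 0 and one just above 1.
        inner = [(i + 1) / (num_thresholds - 1) for i in range(num_thresholds - 2)]
        self.thresholds: List[float] = [-eps] + inner + [1.0 + eps]
        n = len(self.thresholds)
        # Accumulator variables: confusion-matrix counts at every threshold.
        self.tp = [0.0] * n
        self.fp = [0.0] * n
        self.tn = [0.0] * n
        self.fn = [0.0] * n

    def update(self, labels: Sequence[int], predictions: Sequence[float]) -> float:
        """Fold one batch into the accumulators, then return the cumulative AUC."""
        for i, t in enumerate(self.thresholds):
            for y, p in zip(labels, predictions):
                predicted_positive = p > t
                if predicted_positive and y:
                    self.tp[i] += 1
                elif predicted_positive:
                    self.fp[i] += 1
                elif y:
                    self.fn[i] += 1
                else:
                    self.tn[i] += 1
        return self.result()

    def result(self) -> float:
        """Read the AUC from the accumulators without modifying them."""
        tpr = [_ratio(tp, tp + fn) for tp, fn in zip(self.tp, self.fn)]
        fpr = [_ratio(fp, fp + tn) for fp, tn in zip(self.fp, self.tn)]
        # Trapezoidal rule over the ROC curve; fpr shrinks as the threshold grows.
        return sum((fpr[i] - fpr[i + 1]) * (tpr[i] + tpr[i + 1]) / 2
                   for i in range(len(fpr) - 1))


batch1 = ([1, 0], [0.6, 0.4])
batch2 = ([1, 0], [0.3, 0.2])

only_batch2 = StreamingAUC()
only_batch2.update(*batch2)           # AUC of batch 2 by itself

stream = StreamingAUC()
stream.update(*batch1)                # accumulators now hold batch 1
cumulative = stream.update(*batch2)   # accumulators hold batch 1 + batch 2
print(only_batch2.result(), cumulative)  # 1.0 0.75
```

Counted over all four samples together, the positive scored 0.3 ranks below the negative scored 0.4. Only 3 of the 4 positive/negative pairs are ordered correctly, which is why the cumulative value is 0.75.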
The metric functions in tf.metrics generally return two things: a metric value and an update_op. The update_op is an op that updates the accumulator variables.
3. Code walkthrough: mean as an example
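```python
# Excerpt from tensorflow/python/ops/metrics_impl.py (TF 1.15).
# metric_variable, _remove_squeezable_dimensions and _aggregate_across_replicas
# are helper functions defined in the same module.
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.util.tf_export import tf_export


# @tf_export registers a public API path for this function, so it can be called
# as tf.metrics.mean() (tf.compat.v1.metrics.mean() in TF 2.x).
@tf_export(v1=['metrics.mean'])
def mean(values,
         weights=None,
         metrics_collections=None,
         updates_collections=None,
         name=None):
  """Computes the (weighted) mean of the given values.

  The function creates two local variables, `total` and `count`, that are used
  to compute the mean of `values`. The mean is returned as `mean`, which is an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of
  `values` and `weights`, and it increments `count` with the reduced sum of
  `weights`. A weight of 0 can be used as a mask to average only part of
  `values`.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0 or the same as that of
      `values`, and which must be broadcastable to `values` (i.e., every
      dimension must be either `1` or the same as the corresponding `values`
      dimension).
    metrics_collections: An optional list of collections that `mean` is
      added to.
    updates_collections: An optional list of collections that `update_op` is
      added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, i.e. `total` / `count`.
    update_op: An operation that increments `total` and `count` each time it
      runs and returns the updated mean.

  Raises:
    ValueError: If `weights` does not have a valid shape, or if either
      `metrics_collections` or `updates_collections` is not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.cast(values, dtypes.float32)  # Cast the inputs to float32.

    # metric_variable creates a zero-initialized, non-trainable variable and adds
    # it to the ops.GraphKeys.LOCAL_VARIABLES and ops.GraphKeys.METRIC_VARIABLES
    # collections. Being local means that in distributed evaluation each worker
    # keeps its own total and count; they are aggregated when the mean is read.
    total = metric_variable([], dtypes.float32, name='total')
    count = metric_variable([], dtypes.float32, name='count')

    # num_values is the effective number of values in this batch:
    # - without weights, it is the number of elements in values;
    # - with weights, values are multiplied by the broadcast weights and
    #   num_values is the sum of the weights.
    if weights is None:
      num_values = math_ops.cast(array_ops.size(values), dtypes.float32)
    else:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float32), values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.reduce_sum(weights)

    # Add this batch's (weighted) sum to the running total.
    update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
    with ops.control_dependencies([values]):
      # Update the running count only after `values` has been computed.
      update_count_op = state_ops.assign_add(count, num_values)

    def compute_mean(_, t, c):
      # maximum(c, 0) plus div_no_nan: a zero count yields 0 instead of NaN.
      return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')

    # With a distribution strategy each worker runs a replica;
    # _aggregate_across_replicas computes mean_t from the total and count
    # aggregated across all replicas.
    mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)
    # update_op divides the *updated* total by the *updated* count, so running
    # it executes update_total_op and update_count_op and returns the new mean.
    update_op = math_ops.div_no_nan(
        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op
```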
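The total/count bookkeeping in mean can be reproduced in plain Python to check the arithmetic. The sketch below is hypothetical: StreamingMean is not a TensorFlow API. It handles only flat lists, with weights given as None, a scalar, or one weight per value. Weights of 0 act as a mask, and an empty stream reads as 0, just as div_no_nan returns 0 for a zero denominator. The merge method shows why replicas should sum their total and count before dividing, rather than averaging their per-replica means.

```python
from typing import Optional, Sequence, Union


class StreamingMean:
    """Hypothetical pure-Python mirror of tf.metrics.mean's accumulators."""

    def __init__(self) -> None:
        self.total = 0.0  # running sum of values * weights
        self.count = 0.0  # running sum of weights

    def update(self, values: Sequence[float],
               weights: Optional[Union[float, Sequence[float]]] = None) -> float:
        """Like update_op: add one batch, then return the updated mean."""
        if weights is None:
            weights = [1.0] * len(values)
        elif isinstance(weights, (int, float)):
            weights = [float(weights)] * len(values)  # broadcast a scalar weight
        if len(weights) != len(values):
            raise ValueError("weights must be a scalar or match values in length")
        self.total += sum(v * w for v, w in zip(values, weights))
        self.count += sum(weights)
        return self.result()

    def result(self) -> float:
        """Like mean_t: read total / count without updating anything."""
        return self.total / self.count if self.count > 0 else 0.0

    def merge(self, other: "StreamingMean") -> "StreamingMean":
        """Combine two replicas by summing their accumulators."""
        merged = StreamingMean()
        merged.total = self.total + other.total
        merged.count = self.count + other.count
        return merged


m = StreamingMean()
m.update([1.0, 2.0, 3.0])                    # batch 1
m.update([10.0, 20.0], weights=[1.0, 0.0])   # 20.0 is masked out by weight 0
print(m.result())                            # (1 + 2 + 3 + 10) / 4 = 4.0

worker_a, worker_b = StreamingMean(), StreamingMean()
worker_a.update([1.0])
worker_b.update([3.0, 3.0, 3.0])
print(worker_a.merge(worker_b).result())     # 10 / 4 = 2.5
print(StreamingMean().result())              # empty stream reads as 0.0
```

Averaging the two workers' means would give (1 + 3) / 2 = 2.0. That result over-weights the worker that saw fewer values, whereas summing the accumulators first gives the correct 2.5.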
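The key point is the return value (mean_t, update_op). Almost every metric function, including auc and precision, returns a pair of the form (metric_value, update_op). The metric_value reads the current metric from the accumulators without changing them (here, total / count). The update_op first folds the current batch into the accumulators and then returns the updated metric. So update_op runs once per batch, and metric_value is read whenever the accumulated result is needed.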
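Copyright notice: This is an original article by hongxingabc, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.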