TensorFlow 中返回指定条件的元素 tf.where 的基本用法及实例代码

一、环境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

cudnn64_7.dll

Python 3.6.3

Windows 10

 

二、官方说明

根据指定的条件返回输入张量 x 或 y 中指定元素

https://tensorflow.google.cn/api_docs/python/tf/where

如果 x 和 y 都是 None 时,该操作返回 condition 中为 True 的元素的索引坐标,该坐标以 2 维张量形式返回,第一个维度(行)表示 condition 中为 True 的元素的编号,第二个维度(列)表示 condition 中为 True 的元素的坐标。注意:输出张量的形状依赖数据张量中为 True 的值得多少。输出索引是行主序。

如果 x 和 y 都是 non-None,x 和 y 必须形状相同。如果 x 和 y 是张量,那么 condition 必须是标量形式的张量。如果 x 和 y 都是高阶向量,那么 condition 要么必须是与 x 的第一个维度相同尺度的向量,要么必须和 x 的形状完全相同。

condition 充当选择的掩码,根据其中每个元素的值来决定从 x ( if true ) 或 y ( if false) 取哪些相应的元素 / 行到输出中

如果 condition 是向量,同时 x 和 y 是更高阶的矩阵,然后 condition 决定从 x 和 y 中选择哪些行(外侧维度)。如果 condition 与 x 和 y 具有相同的形状,然后 condition 将决定从 x 和 y 中选择哪些元素

tf.where(
    condition,
    x=None,
    y=None,
    name=None
)

参数:

condition:布尔类型的张量

x:可以是与 condition 具有相同形状的张量。如果 condition 的阶为1,x 的阶可能更高,但是其第一个维度必须和 condition 的尺度相同

y:类型和形状跟 x 相同的张量

name:可选参数,操作的名称

返回:

在 x 和 y 都不是 None 的情况下,返回类型和形状跟 x 和 y 相同的张量,其形状为 (num_true, dim_size(condition))

 

三、实例

1、x 和 y 为 None 时

>>> import tensorflow as tf
>>> x = tf.constant(value=[1,2,3,4,5])
>>> y = tf.where(condition=x>1)
>>> y
# <tf.Tensor 'Where:0' shape=(?, 1) dtype=int64>
>>> sess = tf.InteractiveSession()
>>> y.eval()
# array([[1],
#        [2],
#        [3],
#        [4]])
>>> y.shape
# TensorShape([Dimension(None), Dimension(1)])
>>> tf.shape(y).eval()
# array([4, 1], dtype=int32)
>>> sess.close()

2、x 和 y 都是标量时,condition 也必须是标量

>>> x = tf.constant(value=1)
>>> x
<tf.Tensor 'Const_5:0' shape=() dtype=int32>
>>> y = tf.constant(value=2)
>>> y
<tf.Tensor 'Const_6:0' shape=() dtype=int32>
>>> condition = True    # x 和 y 是标量,condition 也必须是 布尔型标量,因此不能是 condition = [True]
>>> result = tf.where(condition=condition, x=x, y=y)
>>> result
<tf.Tensor 'Select_1:0' shape=() dtype=int32>
>>> sess = tf.InteractiveSession()
>>> y.eval()
2
>>> sess.close()

3、condition 是和 x 的第一个维度的尺度相同的标量

>>> import tensorflow as tf

>>> x = tf.constant(value=[[[1,2,3],[4,5,6]],[[-1,-2,-3],[-4,-5,-6]]])
>>> x
<tf.Tensor 'Const:0' shape=(2, 2, 3) dtype=int32>

>>> y = tf.constant(value=[[[11,12,13],[14,15,16]],[[-11,-12,-13],[-14,-15,-16]]])
>>> y
<tf.Tensor 'Const_3:0' shape=(2, 2, 3) dtype=int32>

>>> condition = tf.constant(value=[True, False])
>>> condition
<tf.Tensor 'Const_2:0' shape=(2,) dtype=bool>

>>> result = tf.where(condition=condition, x=x, y=y)
>>> result
<tf.Tensor 'Select_5:0' shape=(2, 2, 3) dtype=int32>

>>> sess = tf.InteractiveSession()
>>> result.eval()
array([[[  1,   2,   3],
        [  4,   5,   6]],

       [[-11, -12, -13],
        [-14, -15, -16]]], dtype=int32)
>>> sess.close()

4、condition 与 x 具有相同的维度

>>> import tensorflow as tf
>>> x = tf.constant(value=[[[1,2,3],[4,5,6]]])
>>> x
<tf.Tensor 'Const:0' shape=(1, 2, 3) dtype=int32>
>>> y = tf.constant(value=[[[-1,-2,-3],[-4,-5,-6]]])
>>> y
<tf.Tensor 'Const_1:0' shape=(1, 2, 3) dtype=int32>
>>> condition = [[[True, False, True],[False, True, False]]]
>>> result = tf.where(condition=condition, x=x, y=y)
>>> result
<tf.Tensor 'Select:0' shape=(1, 2, 3) dtype=int32>
>>> sess = tf.InteractiveSession()
>>> result.eval()
array([[[ 1, -2,  3],
        [-4,  5, -6]]], dtype=int32)
>>> sess.close()



>>> import tensorflow as tf
>>> x = tf.constant(value=[[[1,2,3],[4,5,6]]])
>>> x
<tf.Tensor 'Const:0' shape=(1, 2, 3) dtype=int32>
>>> y = tf.constant(value=[[[-1,-2,-3],[-4,-5,-6]]])
>>> y
<tf.Tensor 'Const_1:0' shape=(1, 2, 3) dtype=int32>
>>> condition = [[[True, False, True],[False, True, True]]]
>>> result = tf.where(condition=condition, x=x, y=y)
>>> result
<tf.Tensor 'Select_1:0' shape=(1, 2, 3) dtype=int32>
>>> result.eval()
array([[[ 1, -2,  3],
        [-4,  5,  6]]], dtype=int32)
>>> sess.close()

 


版权声明:本文为sdnuwjw原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。