TensoFlow中的segment_sum函数

今天看TensorFlow自然语言处理这本书时,碰到了下面这个语句

data = tf.constant([1,2,3,4,5,6,7,8,9,10],dtype=tf.float32)
segment_ids = tf.constant([0,0,0,1,1,2,2,2,2,2],dtype=tf.int32)
x_seg_sum = tf.segment_sum(data,segment_ids)

然后不理解这个方法的含义,就去官方文档里看了下在这里插入图片描述这个其实描述的就挺明显了,意思是TensorFlow提供一些操作在张量分割上。

tf.segment_sum(data, segment_ids, name=None)

方法中的data就是你需要操作的张量,segment_ids是你给的对应的标号,标号给定的原则要与data的第一方向(first dimension)一致,这个第一方向类似于execl中的行数,给定的segment_ids的中元素的个数必须与和data第一方向上的个数d0相同,而且元素的值属于0到d0之间。
官方也给出了例子。在这里插入图片描述那么segment_也就很好理解了,就是根据给定的下标对张量进行分组,而下滑线后接的就是分组后的操作,比如segment_sumsegment_minsegmen_max


版权声明:本文为qq_43735982原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。