TensorFlow学习--tf.strided_slice/张量切片

tf.strided_slice()对张量进行切片操作,就是从张量中提取一个片段.即从片段指定的位置begin开始,以stride步幅提取,直到所有维度都不小于end为止.

tf.strided_slice(
    input_,          #输入张量
    begin,             #切片起始处
    end,               #切片结束处
    strides=None,      #切片步长
    begin_mask=0,      #起始掩码
    end_mask=0,        #结束掩码
    ellipsis_mask=0,   #掩码
    new_axis_mask=0,   #掩码
    shrink_axis_mask=0,#掩码
    var=None,          #与input_None对应的变量
    name=None          #操作的名称
)

一维

tf.strided_slice()对一维张量/向量的切片操作与list/tuple/NumPy中的切片操作类似.

示例:

#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
t = [1,2,3,4,5]
x = tf.strided_slice(t,[0],[3])
y = tf.strided_slice(t,[1],[-2])
with tf.Session() as sess:
    print(sess.run(x))
    print(sess.run(y))

输出:

[1 2 3]
[2 3]

多维

切片操作见示例.

#!/usr/bin/python
# coding:utf-8

import tensorflow as tf
t = tf.constant([[[1, 1, 1], [2, 2, 2], [7, 7, 7]],
                 [[3, 3, 3], [4, 4, 4], [8, 8, 8]],
                 [[5, 5, 5], [6, 6, 6], [9, 9, 9]]])

z1 = tf.strided_slice(t, [1], [-1], [1])
z2 = tf.strided_slice(t, [1, 0], [-1, 2], [1, 1])
z3 = tf.strided_slice(t, [1, 0, 1], [-1, 2, 3], [1, 1, 1])

with tf.Session() as sess:
    print(sess.run(z1))
    print(sess.run(z2))
    print(sess.run(z3))

z1 = tf.strided_slice(t, [1], [-1], [1])如图:
这里写图片描述
z2 = tf.strided_slice(t, [1, 0], [-1, 2], [1, 1])如图:
这里写图片描述
z3 = tf.strided_slice(t, [1, 0, 1], [-1, 2, 3], [1, 1, 1])如图:
这里写图片描述



版权声明:本文为akadiao原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。