tf.strided_slice()对张量进行切片操作,就是从张量中提取一个片段.即从片段指定的位置begin开始,以stride步幅提取,直到所有维度都不小于end为止.
tf.strided_slice(
input_, #输入张量
begin, #切片起始处
end, #切片结束处
strides=None, #切片步长
begin_mask=0, #起始掩码
end_mask=0, #结束掩码
ellipsis_mask=0, #掩码
new_axis_mask=0, #掩码
shrink_axis_mask=0,#掩码
var=None, #与input_None对应的变量
name=None #操作的名称
)一维
tf.strided_slice()对一维张量/向量的切片操作与list/tuple/NumPy中的切片操作类似.
示例:
#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
t = [1,2,3,4,5]
x = tf.strided_slice(t,[0],[3])
y = tf.strided_slice(t,[1],[-2])
with tf.Session() as sess:
print(sess.run(x))
print(sess.run(y))输出:
[1 2 3]
[2 3]多维
切片操作见示例.
#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
t = tf.constant([[[1, 1, 1], [2, 2, 2], [7, 7, 7]],
[[3, 3, 3], [4, 4, 4], [8, 8, 8]],
[[5, 5, 5], [6, 6, 6], [9, 9, 9]]])
z1 = tf.strided_slice(t, [1], [-1], [1])
z2 = tf.strided_slice(t, [1, 0], [-1, 2], [1, 1])
z3 = tf.strided_slice(t, [1, 0, 1], [-1, 2, 3], [1, 1, 1])
with tf.Session() as sess:
print(sess.run(z1))
print(sess.run(z2))
print(sess.run(z3))z1 = tf.strided_slice(t, [1], [-1], [1])如图:
z2 = tf.strided_slice(t, [1, 0], [-1, 2], [1, 1])如图:
z3 = tf.strided_slice(t, [1, 0, 1], [-1, 2, 3], [1, 1, 1])如图:
版权声明:本文为akadiao原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。