keras维度转换问题

一般如果只涉及卷积和全连接不需要考虑维度转换的问题,但是当使用inception模块,或者使用RNN和CNN结合的时候需要考虑维度转换的问题。
1.在keras中使用layers.Reshape函数,flatten()的源代码就是使用reshape函数,对于四维的输入,(?,f1,f2,通道),如果要转换成三维,在keras里面可以不用考虑第一维的batch,直接layers.Reshape((f1*f2,通道))就可以
2.用lambda函数进行切片操作,可以降低维度,比如对于一个(a,b,c,d)的四维tensor,要取(a,b,c,0)表示第一个通道的三维数据(a,b,c)可以使用:

def slice(x,index):
    return x[:,:,:,index]
x=layers.Lambda(slice,arguments={'index':0})(input_datas)

3.另外,还需注意,在keras内使用x.shape[0],比如返回的是32,这个32不是整数类型的,这里是dimension类型的,要使用这个32,需要int(x.shape[0]),比如1中的layers.Reshape((int(x.shape[1]*x.shape[2]),int(x.shape[3])))


版权声明:本文为qq_26064989原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。