
torch代码计算
def paras_cnn(k,s,p,i=64): x=torch.ones(1,1,i,i) conv = torch.nn.Conv2d(1, 1, kernel_size=k, stride=s, padding=p) convt= torch.nn.ConvTranspose2d(1, 1, kernel_size=k, stride=s, padding=p) h1=conv(x) h2=convt(x) y=convt(h1) print("conv(x):{} \t convT(x):{} \t convT(conv(x)):{}".format((h1.shape[2],h1.shape[3]),(h2.shape[2],h2.shape[3]),(y.shape[2],y.shape[3]))) return h1.shape[2],h1.shape[3],h2.shape[2],h2.shape[3],y.shape[2],y.shape[3]
版权声明:本文为hechao3225原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。