tensorflow加载与保存模型

问题:

1.训练好分类模型,比如训练保存了一个10分类的模型,但是实际用的时候呢,分类数可能会改变,但是还想继续使用前面保存的模型。那么相当于是只加载前几层的参数,最后一层做一些修改。
2.加载预训练模型时,预训练模型缺少网络中定义的变量

保存模型

saver = tf.train.Saver()
saver.save(sess,“model.ckpt”)

加载模型

saver.restore(sess,“model.ckpt”)

不传参数时,相当于是保存了所有的参数,然后加载所有的参数。

加载模型时变量缺失情况

我们可以先将模型定义的变量输出看一下,得到变量信息方式有很多
以resnet50为例

    from tensorflow.contrib.slim.nets import resnet_v1
	slim = tensorflow.contrib.slim
	
    inputdata = tf.placeholder(tf.float32, shape=(1, 224, 224, 3), name='input')
    net, end_points = resnet_v1.resnet_v1_50(inputdata, 1000, is_training=False)

我们在构建好网络后,加载网络中定义的变量

# way1
variables_to_resotre = tf.global_variables()
# way2
variables_to_resotre = slim.get_variables_to_restore()

输出这些变量的类型,可发现这些变量为列表类型。
那么我们只需要删除掉我们不需要的元素即可。
使用元素的name属性可得到其名称
在这里插入图片描述

删除样例如下:

#得到该网络中,所有可以加载的参数
variables = tf.slim.get_variables_to_restore()
#删除output层中的参数
variables_to_resotre = [v for v in varialbes if v.name !='output']
#构建这部分参数的saver
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess,'model.ckpt')

版权声明:本文为Mr_WHITE2原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。