问题:
1.训练好分类模型,比如训练保存了一个10分类的模型,但是实际用的时候呢,分类数可能会改变,但是还想继续使用前面保存的模型。那么相当于是只加载前几层的参数,最后一层做一些修改。
2.加载预训练模型时,预训练模型缺少网络中定义的变量
保存模型
saver = tf.train.Saver()
saver.save(sess,“model.ckpt”)
加载模型
saver.restore(sess,“model.ckpt”)
不传参数时,相当于是保存了所有的参数,然后加载所有的参数。
加载模型时变量缺失情况
我们可以先将模型定义的变量输出看一下,得到变量信息方式有很多
以resnet50为例
from tensorflow.contrib.slim.nets import resnet_v1
slim = tensorflow.contrib.slim
inputdata = tf.placeholder(tf.float32, shape=(1, 224, 224, 3), name='input')
net, end_points = resnet_v1.resnet_v1_50(inputdata, 1000, is_training=False)
我们在构建好网络后,加载网络中定义的变量
# way1
variables_to_resotre = tf.global_variables()
# way2
variables_to_resotre = slim.get_variables_to_restore()
输出这些变量的类型,可发现这些变量为列表类型。
那么我们只需要删除掉我们不需要的元素即可。
使用元素的name属性可得到其名称
删除样例如下:
#得到该网络中,所有可以加载的参数
variables = tf.slim.get_variables_to_restore()
#删除output层中的参数
variables_to_resotre = [v for v in varialbes if v.name !='output']
#构建这部分参数的saver
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess,'model.ckpt')
版权声明:本文为Mr_WHITE2原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。