读模型中使用
tf.import_graph_def(od_graph_def, name='') 出现
ValueError: Input 0 of node batch_normalization_v1/cond/ReadVariableOp/Switch was passed float from batch_normalization_v1/moving_mean:0 incompatible with expected resource.这个问题困扰了我很久,出现在1.13.1版本的tensofrlow中,原因是使用tf.keras.layers.BatchNormalization模块后类型,读取模型类型的不匹配。据说使用res残差块也会出现该问题。
目前没有在该版本中比较有效的解决方法。
这里贴下我的解决方法:
2020.4.8更新:使用tensorflow1.14版本可以直接运行。
使用tensorflow版本来保存模型后,切回原版本读取,就不会出现读取模型出错。
下面代码中,model为tf.keras模型
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
logdir="",
name="frozen_graph.pb",
as_text=False)
with tf.Graph().as_default():
output_graph_def = tf.compat.v1.GraphDef()
# 打开.pb模型
with open("frozen_graph.pb", "rb") as f:
output_graph_def.ParseFromString(f.read())
tensors = tf.import_graph_def(output_graph_def, name='')
# print("tensors:", tensors)
with tf.compat.v1.Session() as sess:
op = sess.graph.get_operations()
for i, m in enumerate(op):
print('op{}:'.format(i), m.values())
input_x = sess.graph.get_tensor_by_name("x:0") #可以看op的首末名input.name
print("input_X:", input_x)
out_softmax = sess.graph.get_tensor_by_name(
"Identity:0") #可以看op的首末名
print("Output:", out_softmax)
# 读入图片
img = cv2.imread("1.jpg")
img = cv2.resize(img, (128, 128))
img = img.astype(np.float32)
img = img / 255;
# img=np.reshape(img,(1,28,28,1))
print("img data type:", img.dtype)
img_out_softmax = sess.run(out_softmax,
feed_dict={input_x: np.reshape(img, (1, 128, 128, 3))})
print("img_out_softmax:", img_out_softmax)
for i, prob in enumerate(img_out_softmax[0]):
print('class {} prob:{}'.format(i, prob))
prediction_labels = np.argmax(img_out_softmax, axis=1)
print("Final class if:", prediction_labels)
print("prob of label:", img_out_softmax[0, prediction_labels])
版权声明:本文为a362682954原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。