ValueError: Input 0 of node xxx was passed float from xxx 0 incompatible with expected resource.

读模型中使用

 tf.import_graph_def(od_graph_def, name='') 

出现

ValueError: Input 0 of node batch_normalization_v1/cond/ReadVariableOp/Switch was passed float from batch_normalization_v1/moving_mean:0 incompatible with expected resource.

这个问题困扰了我很久,出现在1.13.1版本的tensofrlow中,原因是使用tf.keras.layers.BatchNormalization模块后类型,读取模型类型的不匹配。据说使用res残差块也会出现该问题。

目前没有在该版本中比较有效的解决方法。

这里贴下我的解决方法:

2020.4.8更新:使用tensorflow1.14版本可以直接运行。

使用tensorflow版本来保存模型后,切回原版本读取,就不会出现读取模型出错。

下面代码中,model为tf.keras模型

    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
    
    full_model = tf.function(lambda x: model(x))
    full_model = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
 
    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="",
                      name="frozen_graph.pb",
                      as_text=False)
 
 
    with tf.Graph().as_default():
        output_graph_def = tf.compat.v1.GraphDef()
 
        # 打开.pb模型
        with open("frozen_graph.pb", "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tensors = tf.import_graph_def(output_graph_def, name='')
            # print("tensors:", tensors)
 
 
        with tf.compat.v1.Session() as sess:
            op = sess.graph.get_operations()
            for i, m in enumerate(op):
                print('op{}:'.format(i), m.values())
 
 
            input_x = sess.graph.get_tensor_by_name("x:0")  #可以看op的首末名input.name
            print("input_X:", input_x)
 
            out_softmax = sess.graph.get_tensor_by_name(
                "Identity:0") #可以看op的首末名
            print("Output:", out_softmax)
 
            # 读入图片
            img = cv2.imread("1.jpg")
            img = cv2.resize(img, (128, 128))
            img = img.astype(np.float32)
            img = img / 255;
            # img=np.reshape(img,(1,28,28,1))
            print("img data type:", img.dtype)
 
            img_out_softmax = sess.run(out_softmax,
                                       feed_dict={input_x: np.reshape(img, (1, 128, 128, 3))})
            print("img_out_softmax:", img_out_softmax)
            for i, prob in enumerate(img_out_softmax[0]):
                print('class {} prob:{}'.format(i, prob))
            prediction_labels = np.argmax(img_out_softmax, axis=1)
            print("Final class if:", prediction_labels)
            print("prob of label:", img_out_softmax[0, prediction_labels])

 


版权声明:本文为a362682954原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。