Running a TensorFlow PB Model with TensorRT in float16

Contents

TensorFlow checkpoint to PB file

PB file to ONNX

ONNX to TensorRT plan (float16) and TensorRT inference script



TensorFlow checkpoint to PB file

Take BERT Large as the example model.

freeze_model.py

import tensorflow as tf
 
# input_checkpoint = './squad_large_output/model.ckpt-43799'
input_checkpoint = './github/bert/checkpoints_b8/model.ckpt-4'
 
output_node_names = "unstack"
output_graph = './data/bert_large_squad_b8.pb'
 
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph() # get the default graph
# tf.contrib.quantize.create_eval_graph(input_graph=graph)
input_graph_def = graph.as_graph_def()  # serialized GraphDef representing the current graph
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, input_checkpoint) # restore the graph and load the trained weights
    output_graph_def = tf.graph_util.convert_variables_to_constants(  # freeze the model: turn variables into constants
        sess=sess,
        input_graph_def=input_graph_def,  # equivalent to sess.graph_def
        output_node_names=output_node_names.split(","))  # separate multiple output nodes with commas
 
    with tf.gfile.GFile(output_graph, "wb") as f:  # save the frozen model
        f.write(output_graph_def.SerializeToString())  # serialize and write
    # print("%d ops in the final graph." % len(output_graph_def.node))  # number of ops in the final graph

fix_constant.py

import tensorflow as tf
 
input_model_path =  './data/bert_large_squad_b8.pb'
output_model_path = "./data/bert_large_squad_const_fixed_b8.pb"
 
tf.reset_default_graph()
sess = tf.Session()
with tf.gfile.FastGFile(input_model_path, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
 
zero = tf.constant(0.0)
BATCH_SIZE = 8
MAX_SEQ_LENGTH = 384
HIDDEN_SIZE = 1024
input_map = {
    "bert/embeddings/dropout/random_uniform/RandomUniform:0": zero,
    #   'loss/dropout/random_uniform/RandomUniform:0': zero,
    "bert/embeddings/dropout/rate:0": zero,
    #   'loss/dropout/rate:0': zero,
}
for i in range(24):
    input_map["bert/encoder/layer_%d/attention/self/dropout/random_uniform/RandomUniform:0" % i] = zero
    input_map["bert/encoder/layer_%d/attention/output/dropout/random_uniform/RandomUniform:0" % i] = zero
    input_map["bert/encoder/layer_%d/output/dropout/random_uniform/RandomUniform:0" % i] = zero
    input_map["bert/encoder/layer_%d/attention/self/dropout/rate:0" % i] = zero
    input_map["bert/encoder/layer_%d/attention/output/dropout/rate:0" % i] = zero
    input_map["bert/encoder/layer_%d/output/dropout/rate:0" % i] = zero
input_map["bert/embeddings/GatherV2:0"] = tf.placeholder(
    "float32", name="embedding_lookup", shape=(BATCH_SIZE * MAX_SEQ_LENGTH, HIDDEN_SIZE)
)
input_map["bert/encoder/Cast:0"] = tf.placeholder(
    "float32", name="input_mask", shape=(BATCH_SIZE, 1, MAX_SEQ_LENGTH)
)
input_map["bert/embeddings/one_hot:0"] = tf.placeholder(
    "float32", name="one_hot", shape=(BATCH_SIZE * MAX_SEQ_LENGTH, 2)
)
tf.import_graph_def(graph_def, name="", input_map=input_map)
 
output_node_names = "unstack"
output_graph_def = tf.graph_util.convert_variables_to_constants(  # freeze the model: turn variables into constants
    sess=sess,
    input_graph_def=tf.get_default_graph().as_graph_def(),  # equivalent to sess.graph_def
    output_node_names=output_node_names.split(","))  # separate multiple output nodes with commas
 
# with tf.gfile.GFile("../model_zoo/bert/bert_squad_simple.pb", "wb") as f:  # save the model
#     f.write(output_graph_def.SerializeToString())  # serialize and write
with tf.gfile.GFile(output_model_path, "wb") as f:  # save the rewritten model
    f.write(output_graph_def.SerializeToString())  # serialize and write
 
 
with open("graph_nodes_simple.log", 'w') as f:
    for node in output_graph_def.node:
        print(node.name, file=f)
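
fix_constant.py replaces the dropout nodes with a constant 0.0 and rewrites the embedding lookup (bert/embeddings/GatherV2), the attention-mask cast (bert/encoder/Cast), and the token-type one-hot (bert/embeddings/one_hot) as explicit placeholders, so the rewritten graph exposes three float32 inputs: embedding_lookup, input_mask, and one_hot. A quick sanity check (a sketch, assuming the same output_model_path as above) is to load the new PB and print its placeholder nodes:

import tensorflow as tf

output_model_path = "./data/bert_large_squad_const_fixed_b8.pb"

graph_def = tf.GraphDef()
with tf.gfile.FastGFile(output_model_path, "rb") as f:
    graph_def.ParseFromString(f.read())

# Expect three Placeholder nodes: embedding_lookup, input_mask, one_hot.
for node in graph_def.node:
    if node.op == "Placeholder":
        print(node.name, node.attr["shape"])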

PB file to ONNX

Install tf2onnx, then convert the frozen PB file into an ONNX model:

pip install -U tf2onnx
python -m tf2onnx.convert --input <INPUT>.pb --output <OUTPUT>.onnx --inputs embedding_lookup:0,input_mask:0,one_hot:0 --outputs unstack:0,unstack:1
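
Before building the TensorRT engine, it is worth sanity-checking the exported ONNX model with onnxruntime. The sketch below makes several assumptions (onnxruntime installed; the same B=8 / S=384 / H=1024 shapes as above; input and output names follow tf2onnx's default of keeping the ":0"-suffixed tensor names passed via --inputs/--outputs, so check sess.get_inputs() if they differ):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("data/onnx_bert_large_squad_b8_fp32.onnx")
print([i.name for i in sess.get_inputs()])  # confirm the actual input names

BATCH_SIZE, MAX_SEQ_LENGTH, HIDDEN_SIZE = 8, 384, 1024
feeds = {
    "embedding_lookup:0": np.random.uniform(
        -0.2, 0.2, (BATCH_SIZE * MAX_SEQ_LENGTH, HIDDEN_SIZE)).astype(np.float32),
    "input_mask:0": np.ones((BATCH_SIZE, 1, MAX_SEQ_LENGTH), dtype=np.float32),
    "one_hot:0": np.zeros((BATCH_SIZE * MAX_SEQ_LENGTH, 2), dtype=np.float32),
}
outputs = sess.run(["unstack:0", "unstack:1"], feeds)
print([o.shape for o in outputs])  # start/end logits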

ONNX to TensorRT plan (float16) and TensorRT inference script
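
As an alternative to the Python builder script below, the trtexec tool that ships with TensorRT can build (and benchmark) an FP16 plan directly from the ONNX file. A rough equivalent of the script's build step, with the same paths assumed (exact flags may differ between TensorRT versions):

trtexec --onnx=data/onnx_bert_large_squad_b8_fp32.onnx --fp16 --saveEngine=data/trt_bert_large_squad_b8_fp16.plan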

tensorrt_inference.py

import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import time
# from onnx import ModelProto
 
# onnx_path = "data/bert_large_squad_const_fixed.onnx"
onnx_path = "data/onnx_bert_large_squad_b8_fp32.onnx"
out_plan_file = "data/trt_bert_large_squad_b8_fp16.plan"
serialized_plan_file = "data/bert_large_squad_const_fixed_fp32_engine.trt"
serialized_plan_file = "data/bert_large_squad_const_fixed_engine.trt"
# serialized_plan_file = "data/trt_bert_large_fp16.plan"
 
 
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt_runtime = trt.Runtime(TRT_LOGGER)
 
 
def build_engine(onnx_path, fp16=True):
    """
    This is the function to create the TensorRT engine
    Args:
       onnx_path : Path to onnx_file.
       shape : Shape of the input of the ONNX file.
    """
    EXPLICIT_BATCH = 1 << (int)(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
        EXPLICIT_BATCH
    ) as network, builder.create_builder_config() as config, trt.OnnxParser(network, TRT_LOGGER) as parser:
        config.max_workspace_size = 32 << 20
        with open(onnx_path, "rb") as model:
            if not parser.parse(model.read()):  # surface parser errors instead of failing silently
                for i in range(parser.num_errors):
                    print(parser.get_error(i), flush=True)
                raise RuntimeError("failed to parse the ONNX model")
        if fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            config.set_flag(trt.BuilderFlag.STRICT_TYPES)
        # network.get_input(0).shape = shape
        engine = builder.build_engine(network, config)
        return engine
 
 
def save_engine(engine, file_name):
    buf = engine.serialize()
    with open(file_name, "wb") as f:
        f.write(buf)
 
 
def load_engine(trt_runtime, plan_path):
    print("loading plan from ", plan_path, flush=True)
    with open(plan_path, "rb") as f:
        engine_data = f.read()
    engine = trt_runtime.deserialize_cuda_engine(engine_data)
    return engine
 
 
def allocate_buffers(engine):
    h_inputs = list()
    h_outputs = list()
    d_inputs = list()
    d_outputs = list()
    stream = cuda.Stream()
    for binding in engine:
        # one page-locked host buffer and one device buffer per binding
        h_mem = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(
            binding)), trt.nptype(engine.get_binding_dtype(binding)))
        if engine.binding_is_input(binding):
            h_inputs.append(h_mem)
            d_inputs.append(cuda.mem_alloc(h_mem.nbytes))
        else:
            h_outputs.append(h_mem)
            d_outputs.append(cuda.mem_alloc(h_mem.nbytes))
    return h_inputs, d_inputs, h_outputs, d_outputs, stream
 
 
def rand_gen(shape, low=-0.2, high=0.2, dtype="float32"):
    return np.random.uniform(low=low, high=high, size=shape).astype(dtype)
 
 
def gen_input_data(engine):
    input_data = list()
    for binding in engine:
        if engine.binding_is_input(binding):
            shape = np.array(engine.get_binding_shape(binding))
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            input_data.append(rand_gen(shape, dtype=dtype).ravel())
    return input_data
 
 
def do_inference(engine, input_data, h_i_mems, d_i_mems, h_o_mems, d_o_mems, stream, repeat_time=10):
    # load random data to page-locked buffer
    assert len(input_data) == len(h_i_mems)
    for i_data, i_mem in zip(input_data, h_i_mems):
        np.copyto(i_mem, i_data)  # np.copyto(dst, src): fill the page-locked buffer with the generated data
 
    with engine.create_execution_context() as context:
        start = time.time()
        for _ in range(repeat_time):
            for d_i_mem, h_i_mem in zip(d_i_mems, h_i_mems):
                cuda.memcpy_htod_async(d_i_mem, h_i_mem, stream)
 
            # Run inference
            # context.profiler = trt.Profiler()
            bindings = [int(x) for x in d_i_mems] + [int(x) for x in d_o_mems]
            context.execute_async_v2(
                bindings=bindings, stream_handle=stream.handle)
 
            # Transfer predictions back from the GPU.
            for h_o_mem, d_o_mem in zip(h_o_mems, d_o_mems):
                cuda.memcpy_dtoh_async(h_o_mem, d_o_mem, stream)
            # Synchronize the stream
            stream.synchronize()
        end = time.time()
    print("\nInference done.", flush=True)
    print(
        f"Run {repeat_time} times, total consumed: {end - start} seconds", flush=True)
    print(f"on average: {(end-start)/repeat_time} seconds", flush=True)
 
 
################################################################################
# engine = load_engine(trt_runtime, serialized_plan_file)
print("start building engine...", flush=True)
engine = build_engine(onnx_path)
print("building engine done!", flush=True)
save_engine(engine, out_plan_file)
 
# for binding in engine:
#     size = trt.volume(engine.get_binding_shape(binding)) * 1
#     dims = engine.get_binding_shape(binding)
#     # print(size)
#     print(dims)
#     print(binding)
#     print("input =", engine.binding_is_input(binding))
#     dtype = trt.nptype(engine.get_binding_dtype(binding))
#     print("dtype: ", dtype)
 
h_i_mems, d_i_mems, h_o_mems, d_o_mems, stream = allocate_buffers(engine)
print("allocate buffer done.")
 
input_data = gen_input_data(engine)
print("gen data done.")
repeat_times = 100
do_inference(engine, input_data, h_i_mems, d_i_mems,
             h_o_mems, d_o_mems, stream, repeat_times)
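
Once the plan file has been written by save_engine, later runs can skip the slow build step and deserialize it instead. A sketch that reuses the helpers defined in this script (note that a serialized plan is only valid for the GPU and TensorRT version it was built with):

engine = load_engine(trt_runtime, out_plan_file)
h_i_mems, d_i_mems, h_o_mems, d_o_mems, stream = allocate_buffers(engine)
input_data = gen_input_data(engine)
do_inference(engine, input_data, h_i_mems, d_i_mems, h_o_mems, d_o_mems, stream, repeat_times)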


Copyright notice: this is an original article by silverdemon, released under the CC 4.0 BY-SA license. Please include a link to the original source and this notice when reposting.