Contents
TensorFlow checkpoint to PB file
PB file to ONNX
ONNX to TensorRT plan (float16) and TensorRT inference script
TensorFlow checkpoint to PB file
Using BERT Large as an example.
freeze_model.py
import tensorflow as tf

# input_checkpoint = './squad_large_output/model.ckpt-43799'
input_checkpoint = './github/bert/checkpoints_b8/model.ckpt-4'
output_node_names = "unstack"
output_graph = './data/bert_large_squad_b8.pb'

saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph()  # get the default graph
# tf.contrib.quantize.create_eval_graph(input_graph=graph)
input_graph_def = graph.as_graph_def()  # serialized GraphDef of the current graph

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, input_checkpoint)  # restore the graph and load the weights
    output_graph_def = tf.graph_util.convert_variables_to_constants(  # freeze the model: turn variables into constants
        sess=sess,
        input_graph_def=input_graph_def,  # equivalent to sess.graph_def
        output_node_names=output_node_names.split(","))  # multiple output nodes are comma-separated
    with tf.gfile.GFile(output_graph, "wb") as f:  # save the frozen model
        f.write(output_graph_def.SerializeToString())  # serialize and write
    # print("%d ops in the final graph." % len(output_graph_def.node))  # number of ops in the final graph
fix_constant.py
import tensorflow as tf

input_model_path = './data/bert_large_squad_b8.pb'
output_model_path = "./data/bert_large_squad_const_fixed_b8.pb"

tf.reset_default_graph()
sess = tf.Session()
with tf.gfile.FastGFile(input_model_path, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

zero = tf.constant(0.0)
BATCH_SIZE = 8
MAX_SEQ_LENGTH = 384
HIDDEN_SIZE = 1024

# Replace the dropout randomness with constant zero so the graph becomes deterministic.
input_map = {
    "bert/embeddings/dropout/random_uniform/RandomUniform:0": zero,
    # 'loss/dropout/random_uniform/RandomUniform:0': zero,
    "bert/embeddings/dropout/rate:0": zero,
    # 'loss/dropout/rate:0': zero,
}
for i in range(24):
    input_map["bert/encoder/layer_%d/attention/self/dropout/random_uniform/RandomUniform:0" % i] = zero
    input_map["bert/encoder/layer_%d/attention/output/dropout/random_uniform/RandomUniform:0" % i] = zero
    input_map["bert/encoder/layer_%d/output/dropout/random_uniform/RandomUniform:0" % i] = zero
    input_map["bert/encoder/layer_%d/attention/self/dropout/rate:0" % i] = zero
    input_map["bert/encoder/layer_%d/attention/output/dropout/rate:0" % i] = zero
    input_map["bert/encoder/layer_%d/output/dropout/rate:0" % i] = zero

# Expose fixed-shape placeholders as the graph inputs.
input_map["bert/embeddings/GatherV2:0"] = tf.placeholder(
    "float32", name="embedding_lookup", shape=(BATCH_SIZE * MAX_SEQ_LENGTH, HIDDEN_SIZE)
)
input_map["bert/encoder/Cast:0"] = tf.placeholder(
    "float32", name="input_mask", shape=(BATCH_SIZE, 1, MAX_SEQ_LENGTH)
)
input_map["bert/embeddings/one_hot:0"] = tf.placeholder(
    "float32", name="one_hot", shape=(BATCH_SIZE * MAX_SEQ_LENGTH, 2)
)

tf.import_graph_def(graph_def, name="", input_map=input_map)

output_node_names = "unstack"
output_graph_def = tf.graph_util.convert_variables_to_constants(  # freeze the model: turn variables into constants
    sess=sess,
    input_graph_def=tf.get_default_graph().as_graph_def(),  # equivalent to sess.graph_def
    output_node_names=output_node_names.split(","))  # multiple output nodes are comma-separated

# with tf.gfile.GFile("../model_zoo/bert/bert_squad_simple.pb", "wb") as f:  # save the model
#     f.write(output_graph_def.SerializeToString())  # serialize and write
with tf.gfile.GFile(output_model_path, "wb") as f:  # save the rewritten model
    f.write(output_graph_def.SerializeToString())  # serialize and write

with open("graph_nodes_simple.log", 'w') as f:
    for node in output_graph_def.node:
        print(node.name, file=f)
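Before exporting to ONNX, it helps to verify that the three placeholders introduced above (embedding_lookup, input_mask, one_hot) are now the only graph inputs. This is a minimal sketch, assuming the same TF 1.x environment and the output path used above:

import tensorflow as tf

# List the Placeholder nodes (graph inputs) of the rewritten PB and their shapes.
graph_def = tf.GraphDef()
with tf.gfile.GFile("./data/bert_large_squad_const_fixed_b8.pb", "rb") as f:
    graph_def.ParseFromString(f.read())
for node in graph_def.node:
    if node.op == "Placeholder":
        print(node.name, node.attr["shape"].shape)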
PB file to ONNX
Install tf2onnx, then convert the PB file and export the ONNX model.
pip install -U tf2onnx
python -m tf2onnx.convert --input <INPUT>.pb --output <OUTPUT>.onnx --inputs embedding_lookup:0,input_mask:0,one_hot:0 --outputs unstack:0,unstack:1
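Optionally, the exported model can be sanity-checked with the onnx Python package before building the TensorRT engine. This is a minimal sketch; replace <OUTPUT>.onnx with the file produced by the command above:

import onnx

# Validate the exported model and print its input/output names.
model = onnx.load("<OUTPUT>.onnx")
onnx.checker.check_model(model)
print("inputs :", [i.name for i in model.graph.input])
print("outputs:", [o.name for o in model.graph.output])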
ONNX to TensorRT plan (float16) and TensorRT inference script
tensorrt_inference.py
import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import time
# from onnx import ModelProto

# onnx_path = "data/bert_large_squad_const_fixed.onnx"
onnx_path = "data/onnx_bert_large_squad_b8_fp32.onnx"
out_plan_file = "data/trt_bert_large_squad_b8_fp16.plan"
# serialized_plan_file = "data/bert_large_squad_const_fixed_fp32_engine.trt"
serialized_plan_file = "data/bert_large_squad_const_fixed_engine.trt"
# serialized_plan_file = "data/trt_bert_large_fp16.plan"

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt_runtime = trt.Runtime(TRT_LOGGER)


def build_engine(onnx_path, fp16=True):
    """
    Build a TensorRT engine from an ONNX file.

    Args:
        onnx_path : Path to the ONNX file.
        fp16      : Build the engine with float16 kernels when True.
    """
    EXPLICIT_BATCH = 1 << (int)(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
        EXPLICIT_BATCH
    ) as network, builder.create_builder_config() as config, trt.OnnxParser(network, TRT_LOGGER) as parser:
        config.max_workspace_size = 32 << 20
        with open(onnx_path, "rb") as model:
            if not parser.parse(model.read()):
                for i in range(parser.num_errors):
                    print(parser.get_error(i))
        if fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            config.set_flag(trt.BuilderFlag.STRICT_TYPES)
        # network.get_input(0).shape = shape
        engine = builder.build_engine(network, config)
        return engine


def save_engine(engine, file_name):
    buf = engine.serialize()
    with open(file_name, "wb") as f:
        f.write(buf)


def load_engine(trt_runtime, plan_path):
    print("loading plan from ", plan_path, flush=True)
    with open(plan_path, "rb") as f:
        engine_data = f.read()
    engine = trt_runtime.deserialize_cuda_engine(engine_data)
    return engine


def allocate_buffers(engine):
    # Allocate one page-locked host buffer and one device buffer per binding.
    h_inputs = list()
    h_outputs = list()
    d_inputs = list()
    d_outputs = list()
    stream = cuda.Stream()
    for binding in engine:
        h_mem = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(
            binding)), trt.nptype(engine.get_binding_dtype(binding)))
        if engine.binding_is_input(binding):
            h_inputs.append(h_mem)
            d_inputs.append(cuda.mem_alloc(h_mem.nbytes))
        else:
            h_outputs.append(h_mem)
            d_outputs.append(cuda.mem_alloc(h_mem.nbytes))
    return h_inputs, d_inputs, h_outputs, d_outputs, stream


def rand_gen(shape, low=-0.2, high=0.2, dtype="float32"):
    return np.random.uniform(low=low, high=high, size=shape).astype(dtype)


def gen_input_data(engine):
    # Generate one flat random array per input binding.
    input_data = list()
    for binding in engine:
        if engine.binding_is_input(binding):
            shape = np.array(engine.get_binding_shape(binding))
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            input_data.append(rand_gen(shape, dtype=dtype).ravel())
    return input_data


def do_inference(engine, input_data, h_i_mems, d_i_mems, h_o_mems, d_o_mems, stream, repeat_time=10):
    # Load the random data into the page-locked host buffers.
    assert len(input_data) == len(h_i_mems)
    for i_data, i_mem in zip(input_data, h_i_mems):
        np.copyto(i_mem, i_data)
    with engine.create_execution_context() as context:
        start = time.time()
        for _ in range(repeat_time):
            for d_i_mem, h_i_mem in zip(d_i_mems, h_i_mems):
                cuda.memcpy_htod_async(d_i_mem, h_i_mem, stream)
            # Run inference
            # context.profiler = trt.Profiler()
            bindings = [int(x) for x in d_i_mems] + [int(x) for x in d_o_mems]
            context.execute_async_v2(
                bindings=bindings, stream_handle=stream.handle)
            # Transfer predictions back from the GPU.
            for h_o_mem, d_o_mem in zip(h_o_mems, d_o_mems):
                cuda.memcpy_dtoh_async(h_o_mem, d_o_mem, stream)
            # Synchronize the stream
            stream.synchronize()
        end = time.time()
    print("\nInference done.", flush=True)
    print(
        f"Run {repeat_time} times, total consumed: {end - start} seconds", flush=True)
    print(f"on average: {(end-start)/repeat_time} seconds", flush=True)


################################################################################
# engine = load_engine(trt_runtime, serialized_plan_file)
print("start building engine...", flush=True)
engine = build_engine(onnx_path)
print("building engine done!", flush=True)
save_engine(engine, out_plan_file)

# for binding in engine:
#     size = trt.volume(engine.get_binding_shape(binding)) * 1
#     dims = engine.get_binding_shape(binding)
#     # print(size)
#     print(dims)
#     print(binding)
#     print("input =", engine.binding_is_input(binding))
#     dtype = trt.nptype(engine.get_binding_dtype(binding))
#     print("dtype: ", dtype)

h_i_mems, d_i_mems, h_o_mems, d_o_mems, stream = allocate_buffers(engine)
print("allocate buffer done.")
input_data = gen_input_data(engine)
print("gen data done.")

repeat_times = 100
do_inference(engine, input_data, h_i_mems, d_i_mems,
             h_o_mems, d_o_mems, stream, repeat_times)
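Once the plan has been serialized by save_engine, later runs can skip the slow build step and deserialize the saved plan instead. A minimal sketch that re-uses the helpers defined above (out_plan_file is the FP16 plan written by this script):

# Sketch: re-use the serialized FP16 plan instead of rebuilding the engine.
engine = load_engine(trt_runtime, out_plan_file)
h_i_mems, d_i_mems, h_o_mems, d_o_mems, stream = allocate_buffers(engine)
input_data = gen_input_data(engine)
do_inference(engine, input_data, h_i_mems, d_i_mems,
             h_o_mems, d_o_mems, stream, repeat_times)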
Copyright notice: this is an original article by silverdemon, released under the CC 4.0 BY-SA license. Please include a link to the original source and this notice when reposting.