estimator 模型保存与使用

1:estimator 是tensorflow的高级封装库,但是tensorflow 分为两个版本,1.X与2.X,本次文章两个版本都会说明,方便大家进行判断

1.0保存与读取

output_dir=’../outer‘
def serving_input_fn():
    label_ids = tf.placeholder(tf.int32, [None, max_seq_length], name='label_ids')
    input_ids = tf.placeholder(tf.int32, [None, max_seq_length], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, max_seq_length], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, max_seq_length], name='segment_ids')
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })()
    return input_fn
estimator.export_savedmodel(output_dir, serving_input_fn)
predict_fn = tf.contrib.predictor.from_saved_model(output_dirs)
# 注意这个地方传入的一般为numpy的格式,具体还要看报错是啥
print(predict_fn({'input_ids': input_ids,
                      'segment_ids': segment_ids,
                      'label_ids': label_ids,
                      'input_mask': input_mask}))

2.0保存与读取

tf.saved_model.load(self.model_path)
model = self.predict_fn.signatures["serving_default"]
ret = model(input_ids=tf.constant(input_feature['input_ids']),
input_mask=tf.constant(input_feature['input_mask']),
label_ids=tf.constant(input_feature['label_ids']),
segment_ids=tf.constant(input_feature['segment_ids']))

2.0的读取方式变换了,没有之前的tf.contrib这个库了,所以方法变为tf.saved_model.load这种,而且要用signatures指名输出参数,这个地方不建议修改,主要是后面的参数,必须要和模型对应,不像之前是字典模式,如果你的输入参数无法进行这样写,建议用**传入

   map_dict = {'Input-Token': tf.constant(input_feature['Input-Token'], dtype=tf.float32),
               'Input-Segment': tf.constant(input_feature['Input-Segment'], dtype=tf.float32)}
   ret = model(**map_dict)

20220411
一般需要keras版本,这里新增一个版本对应
在这里插入图片描述


版权声明:本文为qq_39309652原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。