保存权重的方式+设置学习率下降的方式+早停的方式+history的保存与加载

from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

    # 保存的方式,3世代保存一次
    checkpoint_period = ModelCheckpoint(
                                    log_dir + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',
                                    monitor='val_loss', 
                                    save_weights_only=True, 
                                    save_best_only=True, 
                                    period=3
                                )
    # 学习率下降的方式,val_loss3次不下降就下降学习率继续训练
    reduce_lr = ReduceLROnPlateau(
                            monitor='val_loss', 
                            factor=0.5, 
                            patience=3, 
                            verbose=1
                        )
    # 是否需要早停,当val_loss一直不下降的时候意味着模型基本训练完毕,可以停止
    early_stopping = EarlyStopping(
                            monitor='val_loss', 
                            min_delta=0, 
                            patience=10, 
                            verbose=1
                        )

然后训练的时候:

    # 开始训练
    model.fit_generator(generate_arrays_from_file(lines[:num_train], batch_size),
            steps_per_epoch=max(1, num_train//batch_size),
            validation_data=generate_arrays_from_file(lines[num_train:], batch_size),
            validation_steps=max(1, num_val//batch_size),
            epochs=50,
            initial_epoch=0,
            callbacks=[checkpoint_period, reduce_lr,early_stopping])

    model.save_weights(log_dir+'last1.h5')

值得借鉴的地方有:

  1. 保存模型的时候(这里设置每三个epochs保存一次)
    log_dir + ‘ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5’
    这里给槽的方式并没有用.format这种落后的形式,而是用这种带冒号的形式

  2. 注意调用三者的是:callbacks=[checkpoint_period, reduce_lr,early_stopping])

最后补充保存和加载history的方式:

保存

# # history的保存
import pickle
with open('model_vi_bs_128_lrdefault_hist.pickle', 'wb') as file_pi:
        pickle.dump(history.history, log_dir+file_pi)

加载

import pickle
with open('logs/model_vi_bs_128_lrdefault_hist.pickle','rb') as fr:
    history = pickle.load(fr)

版权声明:本文为weixin_44684139原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。