【神经网络笔记】——TensorFlow2.x自定义回调函数Callbacks

【神经网络笔记】——TensorFlow2.x自定义回调函数Callbacks

背景

在未知训练迭代次数的情况下,动态确定每个训练需要的迭代次数,需要设置自定义回调函数,配置acc或mse损失到达一定程度后停止训练

实例

import tensorflow as tf


class earlyStop(tf.keras.callbacks.Callback):
    def __init__(self,mode,acc_threshold=0.95,loss_threshold=0.025):
        super().__init__()
        self.mode = mode
        self.acc_threshold = acc_threshold
        self.loss_threshold = loss_threshold
    def on_epoch_end(self, epoch, logs=None):
        if self.mode=='a':#二分类
            if float(logs['acc'][-1])>self.acc_threshold:
                self.model.stop_training = True
                print("训练Early Stopping 迭代次数共计",epoch)
        elif self.mode=='b':#值回归
            if float(logs['loss'][-1])<self.loss_threshold:
                self.model.stop_training = True
                print("训练Early Stopping 迭代次数共计",epoch)

关键点


 1. 终止训练 
self.model.stop_training = True
 2. 回调函数调用
        history = model.fit(features,
                            to_categorical(labels,2),
                            epochs=epochs,  # 迭代次数
                            batch_size=batch_size,
                            verbose=2,  # 该方法训练不动 异常占比过小
                            # validation_split=0.1  # 没必要,数据集过度不平衡,参考价值不大
                            callbacks=[earlyStop(mode,acc_threshold,loss_threshold)]
                            )

总结

对于迭代次数不定的场景、需要按照条件停止训练的场景,都可以自己编写回调函数,不用官方提供的earlystopping!自力更生,慢却也最快。


版权声明:本文为BlackValley原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。