Callback

对于keras回调函数简单整理

callback

回调函数是一个函数的合集,会在训练阶段中使用。你可以使用回调函数来查看训练模型的内在状态统计。你可以传递一个列表的回调函数(作为callback关键字参数)到SequentialModel类型的.fit方法。在训练时,相应的回调函数的方法就会被在各自的阶段被调用。

可以通过扩展keras.callbacks.Callback基类来创建一个自定义的回调函数。通过类的属性self.model,回调函数可以获得它所联系的模型。

定义一个自己的回调函数

首先需要继承keras.callbacks.Callback

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from keras.callbacks import Callback
class LossHistory(Callback):  # 继承自Callback类
    def __init__(self, tt, path):
        # 存储所消耗的总时间
        self.countTimes = tt
        # 数据保存的路径
        self.path = path

    """
    在模型开始的时候定义四个属性,每一个属性都是字典类型,存储相对应的值和epoch
    """

    def on_train_begin(self, logs={}):
        self.times = []
        self.totalTime = time.time()

        self.batch_losses = {"batch_loss": []}
        self.batch_accuracys = {"batch_accuracy": []}
        self.epoch_losses = {"epoch_loss":[]}
        self.epoch_accuracys = {"epoch_accuracy":[]}
        self.epoch_val_losses = {"epoch_val_loss": []}
        self.epoch_val_accuracys = {"epoch_val_accuracy": []}
    # 在每一个batch结束后记录相应的值
    def on_train_end(self, logs={}):
        self.totalTime = time.time() - self.totalTime

    def on_batch_end(self, batch, logs={}):
        self.batch_losses["batch_loss"].append(logs.get("loss"))
        self.batch_accuracys["batch_accuracy"].append(logs.get("accuracy"))
    # 在每一个epoch之后记录相应的值
    def on_epoch_begin(self, batch, logs={}):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, batch, logs={}):
        self.times.append(time.time() - self.epoch_time_start)
        self.epoch_losses["epoch_loss"].append(logs.get("loss"))
        self.epoch_accuracys["epoch_accuracy"].append(logs.get("accuracy"))
        self.epoch_val_losses["epoch_val_loss"].append(logs.get("val_loss"))
        self.epoch_val_accuracys["epoch_val_accuracy"].append(logs.get("val_accuracy"))

    def hist(self):
        times = {
            "epoch_Times":self.times
        }
        totalTime = {
            "total_Time":self.totalTime
        }
        epoch_loss = self.epoch_losses
        epoch_accuracy = self.epoch_accuracys
        epoch_val_loss = self.epoch_val_losses
        epoch_val_acc = self.epoch_val_accuracys
        batch_loss = self.batch_losses
        batch_accuracy = self.batch_accuracys
        return [totalTime, times, epoch_loss, epoch_accuracy, epoch_val_loss, epoch_val_acc, batch_loss, batch_accuracy]
    def saveHistory(self):
        hisData = self.hist()
        np.save(self.path+'\\'+str(self.countTimes)+"historyData.npy",hisData)

在上述创建好了一个类之后下面进行调用

1
2
3
4
5
6
import times
historys = LossHistory(tt=times, path="weights")
# 实例化你的model后
model.fit(x, y, epochs=10, batch_size=32, validation_data=(x_val, y_val), callbacks=[historys])
# 调用LossHistory类中的方法,进行数据保存
history.saveHistory()