对于keras回调函数简单整理
callback
回调函数是一个函数的合集,会在训练阶段中使用。你可以使用回调函数来查看训练模型的内在状态统计。你可以传递一个列表的回调函数(作为callback关键字参数)到Sequential或Model类型的.fit方法。在训练时,相应的回调函数的方法就会被在各自的阶段被调用。
可以通过扩展keras.callbacks.Callback基类来创建一个自定义的回调函数。通过类的属性self.model,回调函数可以获得它所联系的模型。
定义一个自己的回调函数
首先需要继承keras.callbacks.Callback类
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
|
from keras.callbacks import Callback
class LossHistory(Callback): # 继承自Callback类
def __init__(self, tt, path):
# 存储所消耗的总时间
self.countTimes = tt
# 数据保存的路径
self.path = path
"""
在模型开始的时候定义四个属性,每一个属性都是字典类型,存储相对应的值和epoch
"""
def on_train_begin(self, logs={}):
self.times = []
self.totalTime = time.time()
self.batch_losses = {"batch_loss": []}
self.batch_accuracys = {"batch_accuracy": []}
self.epoch_losses = {"epoch_loss":[]}
self.epoch_accuracys = {"epoch_accuracy":[]}
self.epoch_val_losses = {"epoch_val_loss": []}
self.epoch_val_accuracys = {"epoch_val_accuracy": []}
# 在每一个batch结束后记录相应的值
def on_train_end(self, logs={}):
self.totalTime = time.time() - self.totalTime
def on_batch_end(self, batch, logs={}):
self.batch_losses["batch_loss"].append(logs.get("loss"))
self.batch_accuracys["batch_accuracy"].append(logs.get("accuracy"))
# 在每一个epoch之后记录相应的值
def on_epoch_begin(self, batch, logs={}):
self.epoch_time_start = time.time()
def on_epoch_end(self, batch, logs={}):
self.times.append(time.time() - self.epoch_time_start)
self.epoch_losses["epoch_loss"].append(logs.get("loss"))
self.epoch_accuracys["epoch_accuracy"].append(logs.get("accuracy"))
self.epoch_val_losses["epoch_val_loss"].append(logs.get("val_loss"))
self.epoch_val_accuracys["epoch_val_accuracy"].append(logs.get("val_accuracy"))
def hist(self):
times = {
"epoch_Times":self.times
}
totalTime = {
"total_Time":self.totalTime
}
epoch_loss = self.epoch_losses
epoch_accuracy = self.epoch_accuracys
epoch_val_loss = self.epoch_val_losses
epoch_val_acc = self.epoch_val_accuracys
batch_loss = self.batch_losses
batch_accuracy = self.batch_accuracys
return [totalTime, times, epoch_loss, epoch_accuracy, epoch_val_loss, epoch_val_acc, batch_loss, batch_accuracy]
def saveHistory(self):
hisData = self.hist()
np.save(self.path+'\\'+str(self.countTimes)+"historyData.npy",hisData)
|
在上述创建好了一个类之后下面进行调用
1
2
3
4
5
6
|
import times
historys = LossHistory(tt=times, path="weights")
# 实例化你的model后
model.fit(x, y, epochs=10, batch_size=32, validation_data=(x_val, y_val), callbacks=[historys])
# 调用LossHistory类中的方法,进行数据保存
history.saveHistory()
|