《python深度学习》笔记(十四):指数移动平均值EMA
定義
指數(shù)移動(dòng)平均(Exponential Moving Average)也叫權(quán)重移動(dòng)平均(Weighted Moving Average),是一種給予近期數(shù)據(jù)更高權(quán)重平均的方法。
?
作用
給W和b使用EMA,就是防止訓(xùn)練過程遇到異常數(shù)據(jù)或者隨機(jī)跳躍(畢竟是隨機(jī)批量,數(shù)據(jù)不確定)影響訓(xùn)練效果的,讓W(xué)和b維持相對穩(wěn)定。
普通的參數(shù)權(quán)重相當(dāng)于一直累積更新整個(gè)訓(xùn)練過程的梯度,使用EMA的參數(shù)權(quán)重相當(dāng)于使用訓(xùn)練過程梯度的加權(quán)平均(剛開始的梯度權(quán)值很小)。由于剛開始訓(xùn)練不穩(wěn)定,得到的梯度給更小的權(quán)值更為合理,所以EMA會(huì)有效。
啥時(shí)使用
EMA在數(shù)據(jù)量小或者數(shù)據(jù)不穩(wěn)定或者batch_size小的情況下尤其有用
比如回歸問題的波士頓房價(jià)數(shù)據(jù)集,還有使用預(yù)訓(xùn)練的CNN中batch_size=20比較小。
或者看曲線,那種包含噪聲,波動(dòng)很大,或者縱軸范圍較大,數(shù)據(jù)方差較大的圖像,為了使曲線變得平滑,更具可讀性,所以使用EMA
?代碼實(shí)現(xiàn)
class EMA():def __init__(self, model, decay):self.model = modelself.decay = decayself.shadow = {}self.backup = {}def register(self):for name, param in self.model.named_parameters():if param.requires_grad:self.shadow[name] = param.data.clone()def update(self):for name, param in self.model.named_parameters():if param.requires_grad:assert name in self.shadownew_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]self.shadow[name] = new_average.clone()def apply_shadow(self):for name, param in self.model.named_parameters():if param.requires_grad:assert name in self.shadowself.backup[name] = param.dataparam.data = self.shadow[name]def restore(self):for name, param in self.model.named_parameters():if param.requires_grad:assert name in self.backupparam.data = self.backup[name]self.backup = {}# 初始化 ema = EMA(model, 0.999) ema.register()# 訓(xùn)練過程中,更新完參數(shù)后,同步update shadow weights def train():optimizer.step()ema.update()# eval前,apply shadow weights;eval之后,恢復(fù)原來模型的參數(shù) def evaluate():ema.apply_shadow()# evaluate ema.restore() import matplotlib.pyplot as pltpoints = [1, 5, 3, 9, 4] def smooth_curve(points, factor=0.9):smoothed_points =[] # 數(shù)據(jù)點(diǎn),權(quán)重系數(shù)for point in points: # 遍歷所有的數(shù)據(jù)點(diǎn)if smoothed_points: # 如果列表中有數(shù)據(jù),則執(zhí)行下面步驟previous = smoothed_points[-1]smoothed_points.append(previous * factor + point * (1 - factor))# 指數(shù)移動(dòng)平均值EMA,前一個(gè)數(shù)據(jù)點(diǎn)*加權(quán)系數(shù)+當(dāng)前數(shù)據(jù)點(diǎn)*(1-加權(quán)系數(shù))else:smoothed_points.append(point) # append添加到列表中最后面return smoothed_pointsresults = smooth_curve(points) print(results) plt.plot(range(1, len(points) + 1), results) plt.show()[1, 1.4, 1.56, 2.304, 2.4736]?
總結(jié)
以上是生活随笔為你收集整理的《python深度学习》笔记(十四):指数移动平均值EMA的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python自动化测试之Appium自动
- 下一篇: 基于Python深度图生成3D点云