PyTorch框架学习十三——优化器
PyTorch框架學(xué)習(xí)十三——優(yōu)化器
- 一、優(yōu)化器
- 二、Optimizer類
- 1.基本屬性
- 2.基本方法
- 三、學(xué)習(xí)率與動(dòng)量
- 1.學(xué)習(xí)率learning rate
- 2.動(dòng)量、沖量Momentum
- 四、十種常見的優(yōu)化器(簡單羅列)
上次筆記簡單介紹了一下?lián)p失函數(shù)的概念以及18種常用的損失函數(shù),這次筆記介紹優(yōu)化器的相關(guān)知識(shí)以及PyTorch中的使用。
一、優(yōu)化器
PyTorch中的優(yōu)化器:管理并更新模型中可學(xué)習(xí)參數(shù)的值,使得模型輸出更接近真實(shí)標(biāo)簽。
導(dǎo)數(shù):函數(shù)在指定坐標(biāo)軸上的變化率。
方向?qū)?shù):指定方向上的變化率。
梯度:一個(gè)向量,方向?yàn)榉较驅(qū)?shù)取得最大值的方向。
二、Optimizer類
1.基本屬性
2.基本方法
(1)zero_grad():清空所管理的參數(shù)的梯度。因?yàn)镻yTorch中張量梯度不會(huì)自動(dòng)清零。
weight = torch.randn((2, 2), requires_grad=True) weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)print("weight before step:{}".format(weight.data)) optimizer.step() # 修改lr=1 0.1觀察結(jié)果 print("weight after step:{}".format(weight.data))print("weight in optimizer:{}\nweight in weight:{}\n".format(id(optimizer.param_groups[0]['params'][0]), id(weight)))print("weight.grad is {}\n".format(weight.grad)) optimizer.zero_grad() print("after optimizer.zero_grad(), weight.grad is\n{}".format(weight.grad))結(jié)果如下:
weight before step:tensor([[0.6614, 0.2669],[0.0617, 0.6213]]) weight after step:tensor([[ 0.5614, 0.1669],[-0.0383, 0.5213]]) weight in optimizer:1314236528344 weight in weight:1314236528344weight.grad is tensor([[1., 1.],[1., 1.]])after optimizer.zero_grad(), weight.grad is tensor([[0., 0.],[0., 0.]])(2) step():執(zhí)行一步優(yōu)化更新。
weight = torch.randn((2, 2), requires_grad=True) weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)print("weight before step:{}".format(weight.data)) optimizer.step() # 修改lr=1 0.1觀察結(jié)果 print("weight after step:{}".format(weight.data))結(jié)果如下:
weight before step:tensor([[0.6614, 0.2669],[0.0617, 0.6213]]) weight after step:tensor([[ 0.5614, 0.1669],[-0.0383, 0.5213]])(3) add_param_group():添加參數(shù)組。
weight = torch.randn((2, 2), requires_grad=True) weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)print("optimizer.param_groups is\n{}".format(optimizer.param_groups))w2 = torch.randn((3, 3), requires_grad=True)optimizer.add_param_group({"params": w2, 'lr': 0.0001})print("optimizer.param_groups is\n{}".format(optimizer.param_groups))結(jié)果如下:
optimizer.param_groups is [{'params': [tensor([[0.6614, 0.2669],[0.0617, 0.6213]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}] optimizer.param_groups is [{'params': [tensor([[0.6614, 0.2669],[0.0617, 0.6213]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}, {'params': [tensor([[-0.4519, -0.1661, -1.5228],[ 0.3817, -1.0276, -0.5631],[-0.8923, -0.0583, -0.1955]], requires_grad=True)], 'lr': 0.0001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}](4)state_dict():獲取優(yōu)化器當(dāng)前狀態(tài)信息字典。
weight = torch.randn((2, 2), requires_grad=True) weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1, momentum=0.9) opt_state_dict = optimizer.state_dict()print("state_dict before step:\n", opt_state_dict)for i in range(10):optimizer.step()print("state_dict after step:\n", optimizer.state_dict())torch.save(optimizer.state_dict(), os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))結(jié)果如下:
state_dict before step:{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2872948098296]}]} state_dict after step:{'state': {2872948098296: {'momentum_buffer': tensor([[6.5132, 6.5132],[6.5132, 6.5132]])}}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2872948098296]}]}獲取到了優(yōu)化器當(dāng)前狀態(tài)的信息字典,其中那個(gè)2872948098296是存放權(quán)重的地址,并將這些參數(shù)信息保存為一個(gè)pkl文件:
(5)load_state_dict():加載狀態(tài)信息字典。
optimizer = optim.SGD([weight], lr=0.1, momentum=0.9) state_dict = torch.load(os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))print("state_dict before load state:\n", optimizer.state_dict()) optimizer.load_state_dict(state_dict) print("state_dict after load state:\n", optimizer.state_dict())從剛剛保存參數(shù)的pkl文件中讀取參數(shù)賦給一個(gè)新的空的優(yōu)化器,結(jié)果為:
state_dict before load state:{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [1838346925624]}]} state_dict after load state:{'state': {1838346925624: {'momentum_buffer': tensor([[6.5132, 6.5132],[6.5132, 6.5132]])}}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [1838346925624]}]}注:state_dict()與load_state_dict()一般經(jīng)常用于模型訓(xùn)練中的保存和讀取模型參數(shù),防止斷電等突發(fā)情況導(dǎo)致模型訓(xùn)練強(qiáng)行中斷而前功盡棄。
三、學(xué)習(xí)率與動(dòng)量
1.學(xué)習(xí)率learning rate
梯度下降:
其中LR就是學(xué)習(xí)率,作用是控制更新的步伐,如果太大可能導(dǎo)致模型無法收斂或者是梯度爆炸,如果太小可能使得訓(xùn)練時(shí)間過長,需要調(diào)節(jié)。
2.動(dòng)量、沖量Momentum
結(jié)合當(dāng)前梯度與上一次更新信息,用于當(dāng)前更新。
PyTorch中梯度下降的更新公式為:
其中:
- Wi:第i次更新的參數(shù)。
- lr:學(xué)習(xí)率。
- Vi:更新量。
- m:momentum系數(shù)。
- g(Wi):Wi的梯度。
舉個(gè)例子:
100這個(gè)時(shí)刻的更新量不僅與當(dāng)前梯度有關(guān),還與之前的梯度有關(guān),只是越以前的對(duì)當(dāng)前時(shí)刻的影響就越小。
momentum的作用主要是可以加速收斂。
四、十種常見的優(yōu)化器(簡單羅列)
目前對(duì)優(yōu)化器的了解還不多,以后會(huì)繼續(xù)跟進(jìn),這里就簡單羅列一下:
總結(jié)
以上是生活随笔為你收集整理的PyTorch框架学习十三——优化器的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Python学习笔记(列表和元组的简单实
- 下一篇: (Matlab函数详解)机器学习中的4种