【深度强化学习】DQN训练超级玛丽闯关
上一期 MyEncyclopedia公眾號(hào)文章 通過代碼學(xué)Sutton強(qiáng)化學(xué)習(xí):從Q-Learning 演化到 DQN,我們從原理上講解了DQN算法,這一期,讓我們通過代碼來實(shí)現(xiàn)DQN 在任天堂經(jīng)典的超級(jí)瑪麗游戲中的自動(dòng)通關(guān)吧。本系列將延續(xù)通過代碼學(xué)Sutton 強(qiáng)化學(xué)習(xí)系列,逐步通過代碼實(shí)現(xiàn)經(jīng)典深度強(qiáng)化學(xué)習(xí)應(yīng)用在各種游戲環(huán)境中。本文所有代碼在?
https://github.com/MyEncyclopedia/reinforcement-learning-2nd/tree/master/super_mario
最終訓(xùn)練第一關(guān)結(jié)果動(dòng)畫
DQN 算法回顧
上期詳細(xì)講解了DQN中的兩個(gè)重要的技術(shù):Target Network 和 Experience Replay,正是有了它們才使得 Deep Q Network在實(shí)戰(zhàn)中容易收斂,以下是Deepmind 發(fā)表在Nature 的 Human-level control through deep reinforcement learning 的完整算法流程。
?超級(jí)瑪麗 NES OpenAI 環(huán)境
安裝基于OpenAI gym的超級(jí)瑪麗環(huán)境執(zhí)行下面的 pip 命令即可。
pip?install?gym-super-mario-bros我們先來看一下游戲環(huán)境的輸入和輸出。下面代碼采用隨機(jī)的action來和游戲交互。有了 組合游戲系列3: 井字棋、五子棋的OpenAI Gym GUI環(huán)境?關(guān)于OpenAI Gym 的介紹,現(xiàn)在對(duì)于其基本的交互步驟已經(jīng)不陌生了。
import?gym_super_mario_bros from?random?import?random,?randrange from?gym_super_mario_bros.actions?import?RIGHT_ONLY from?nes_py.wrappers?import?JoypadSpace from?gym?import?wrappersenv?=?gym_super_mario_bros.make('SuperMarioBros-v0') env?=?JoypadSpace(env,?RIGHT_ONLY)#?Play?randomly done?=?False env.reset()step?=?0 while?not?done:action?=?randrange(len(RIGHT_ONLY))state,?reward,?done,?info?=?env.step(action)print(done,?step,?info)env.render()step?+=?1env.close()隨機(jī)策略的效果如下
注意我們?cè)谟螒颦h(huán)境初始化的時(shí)候用了參數(shù) RIGHT_ONLY,它定義成五種動(dòng)作的list,表示僅使用右鍵的一些組合,適用于快速訓(xùn)練來完成Mario第一關(guān)。
RIGHT_ONLY?=?[['NOOP'],['right'],['right',?'A'],['right',?'B'],['right',?'A',?'B'], ]觀察一些 info 輸出內(nèi)容,coins表示金幣獲得數(shù)量,flag_get 表示是否取得最后的旗子,time 剩余時(shí)間,以及 Mario 大小狀態(tài)和所在的 x,y位置。
{"coins":0,"flag_get":False,"life":2,"score":0,"stage":1,"status":"small","time":381,"world":1,"x_pos":594,"y_pos":89 }游戲圖像處理
Deep Reinforcement Learning 一般是 end-to-end learning,意味著將游戲的 screen image,即 observed state 直接視為真實(shí)狀態(tài) state,喂給神經(jīng)網(wǎng)絡(luò)去訓(xùn)練。于此相反的另一種做法是,通過游戲環(huán)境拿到內(nèi)部狀態(tài),例如所有相關(guān)物品的位置和屬性作為模型輸入。這兩種方式的區(qū)別在我看來有兩點(diǎn)。第一點(diǎn),用觀察到的屏幕像素代替真正的狀態(tài) state,在partially observable 的環(huán)境時(shí)可能因?yàn)?non-stationarity 導(dǎo)致無法很好的工作,而拿內(nèi)部狀態(tài)利用了額外的作弊信息,在partially observable環(huán)境中也可以工作。第二點(diǎn),第一種方式屏幕像素維度比較高,輸入數(shù)據(jù)量大,需要神經(jīng)網(wǎng)絡(luò)的大量訓(xùn)練擬合,第二種方式,內(nèi)部真實(shí)狀態(tài)往往維度低得多,訓(xùn)練起來很快,但缺點(diǎn)是因?yàn)槌藘?nèi)部狀態(tài)往往還需要游戲相關(guān)規(guī)則作為輸入,因此generalization能力不如前者強(qiáng)。
?這里,我們當(dāng)然采樣屏幕像素的 end-to-end 方式了,自然首要任務(wù)是將游戲幀圖像有效處理。超級(jí)瑪麗游戲環(huán)境的屏幕輸出是 (240, 256, 3) shape的 numpy array,通過下面一系列的轉(zhuǎn)換,盡可能的在不影響訓(xùn)練效果的情況下減小采樣到的數(shù)據(jù)量。
MaxAndSkipFrameWrapper:每4個(gè)frame連在一起,采取同樣的動(dòng)作,降低frame數(shù)量
FrameDownsampleWrapper:將原始的 (240, 256, 3) down sample 到 (84, 84, 1)
ImageToPyTorchWrapper:轉(zhuǎn)換成適合 pytorch 的 shape (1, 84, 84)?
FrameBufferWrapper:保存最后4次屏幕采樣
NormalizeFloats:Normalize 成 [0., 1.0] 的浮點(diǎn)值
CNN 模型
模型比較簡(jiǎn)單,三個(gè)卷積層后做 softmax輸出,輸出維度數(shù)為離散動(dòng)作數(shù)。act() 采用了epsilon-greedy 模式,即在epsilon小概率時(shí)采取隨機(jī)動(dòng)作來 explore,大于epsilon時(shí)采取估計(jì)的最可能動(dòng)作來 exploit。
class?DQNModel(nn.Module):def?__init__(self,?input_shape,?num_actions):super(DQNModel,?self).__init__()self._input_shape?=?input_shapeself._num_actions?=?num_actionsself.features?=?nn.Sequential(nn.Conv2d(input_shape[0],?32,?kernel_size=8,?stride=4),nn.ReLU(),nn.Conv2d(32,?64,?kernel_size=4,?stride=2),nn.ReLU(),nn.Conv2d(64,?64,?kernel_size=3,?stride=1),nn.ReLU())self.fc?=?nn.Sequential(nn.Linear(self.feature_size,?512),nn.ReLU(),nn.Linear(512,?num_actions))def?forward(self,?x):x?=?self.features(x).view(x.size()[0],?-1)return?self.fc(x)def?act(self,?state,?epsilon,?device):if?random()?>?epsilon:state?=?torch.FloatTensor(np.float32(state)).unsqueeze(0).to(device)q_value?=?self.forward(state)action?=?q_value.max(1)[1].item()else:action?=?randrange(self._num_actions)return?actionExperience Replay 緩存
實(shí)現(xiàn)采用了 Pytorch CartPole DQN 的官方代碼,本質(zhì)是一個(gè)最大為 capacity 的 list 保存了采樣到的 (s, a, r, s', is_done) ?五元組。
Transition?=?namedtuple('Transition',?('state',?'action',?'reward',?'next_state',?'done'))class?ReplayMemory:def?__init__(self,?capacity):self.capacity?=?capacityself.memory?=?[]self.position?=?0def?push(self,?*args):if?len(self.memory)?<?self.capacity:self.memory.append(None)self.memory[self.position]?=?Transition(*args)self.position?=?(self.position?+?1)?%?self.capacitydef?sample(self,?batch_size):return?random.sample(self.memory,?batch_size)def?__len__(self):return?len(self.memory)DQNAgent
我們將 DQN 的邏輯封裝在 DQNAgent 類中。DQNAgent 成員變量包括兩個(gè) DQNModel,一個(gè)ReplayMemory。
train() 方法中會(huì)每隔一定時(shí)間將 Target Network 的參數(shù)同步成現(xiàn)行Network的參數(shù)。在td_loss_backprop()方法中采樣 ReplayMemory 中的五元組,通過minimize TD error方式來改進(jìn)現(xiàn)行 Network 參數(shù) 。Loss函數(shù)為:
class?DQNAgent():def?act(self,?state,?episode_idx):self.update_epsilon(episode_idx)action?=?self.model.act(state,?self.epsilon,?self.device)return?actiondef?process(self,?episode_idx,?state,?action,?reward,?next_state,?done):self.replay_mem.push(state,?action,?reward,?next_state,?done)self.train(episode_idx)def?train(self,?episode_idx):if?len(self.replay_mem)?>?self.initial_learning:if?episode_idx?%?self.target_update_frequency?==?0:self.target_model.load_state_dict(self.model.state_dict())self.optimizer.zero_grad()self.td_loss_backprop()self.optimizer.step()def?td_loss_backprop(self):transitions?=?self.replay_mem.sample(self.batch_size)batch?=?Transition(*zip(*transitions))state?=?Variable(FloatTensor(np.float32(batch.state))).to(self.device)action?=?Variable(LongTensor(batch.action)).to(self.device)reward?=?Variable(FloatTensor(batch.reward)).to(self.device)next_state?=?Variable(FloatTensor(np.float32(batch.next_state))).to(self.device)done?=?Variable(FloatTensor(batch.done)).to(self.device)q_values?=?self.model(state)next_q_values?=?self.target_net(next_state)q_value?=?q_values.gather(1,?action.unsqueeze(-1)).squeeze(-1)next_q_value?=?next_q_values.max(1)[0]expected_q_value?=?reward?+?self.gamma?*?next_q_value?*?(1?-?done)loss?=?(q_value?-?expected_q_value.detach()).pow(2)loss?=?loss.mean()loss.backward()外層控制代碼
最后是外層調(diào)用代碼,基本和以前文章一樣。
def?train(env,?args,?agent):for?episode_idx?in?range(args.num_episodes):episode_reward?=?0.0state?=?env.reset()while?True:action?=?agent.act(state,?episode_idx)if?args.render:env.render()next_state,?reward,?done,?stats?=?env.step(action)agent.process(episode_idx,?state,?action,?reward,?next_state,?done)state?=?next_stateepisode_reward?+=?rewardif?done:print(f'{episode_idx}:?{episode_reward}')break著作權(quán)歸作者所有。商業(yè)轉(zhuǎn)載請(qǐng)聯(lián)系作者獲得授權(quán),非商業(yè)轉(zhuǎn)載請(qǐng)注明出處。
往期精彩回顧適合初學(xué)者入門人工智能的路線及資料下載機(jī)器學(xué)習(xí)及深度學(xué)習(xí)筆記等資料打印機(jī)器學(xué)習(xí)在線手冊(cè)深度學(xué)習(xí)筆記專輯《統(tǒng)計(jì)學(xué)習(xí)方法》的代碼復(fù)現(xiàn)專輯 AI基礎(chǔ)下載機(jī)器學(xué)習(xí)的數(shù)學(xué)基礎(chǔ)專輯 獲取本站知識(shí)星球優(yōu)惠券,復(fù)制鏈接直接打開: https://t.zsxq.com/qFiUFMV 本站qq群704220115。加入微信群請(qǐng)掃碼:總結(jié)
以上是生活随笔為你收集整理的【深度强化学习】DQN训练超级玛丽闯关的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【Python基础】13个知识点,系统整
- 下一篇: 我们的合作伙伴Datawhale两岁啦!