PaddlePaddle版Flappy-Bird—使用DQN算法实现游戏智能
剛剛舉行的 WAVE SUMMIT 2019 深度學(xué)習(xí)開發(fā)者峰會(huì)上,PaddlePaddle 發(fā)布了 PARL 1.1 版本,這一版新增了 IMPALA、A3C、A2C 等一系列并行算法。作者重新測試了一遍內(nèi)置 example,發(fā)現(xiàn)卷積速度也明顯加快,從 1.0 版本的訓(xùn)練一幀需大約 1 秒優(yōu)化到了 0.15 秒(配置:win8,i5-6200U,GeForce-940M,batch-size=32)。
嘿嘿,超級(jí)本實(shí)現(xiàn)游戲智能的時(shí)代終于來臨!廢話不多說,我們趕緊試試 PARL 的官方 DQN 算法,玩一玩 Flappy-Bird。
關(guān)于作者:曹天明(kosora),2011 年畢業(yè)于天津科技大學(xué),7 年的 PHP+Java 經(jīng)驗(yàn)。個(gè)人研究方向——融合 CLRS 與 DRL 兩大技術(shù)體系,并行刷題和模型訓(xùn)練。專注于游戲智能、少兒趣味編程兩大領(lǐng)域。
?
模擬環(huán)境
相信大家對(duì)于這個(gè)游戲并不陌生,我們需要控制一只小鳥向前飛行,只有飛翔、下落兩種操作,小鳥每穿過一根柱子,總分就會(huì)增加。由于柱子是高低不平的,所以需要想盡辦法躲避它們。一旦碰到了柱子,或者碰到了上、下邊緣,都會(huì)導(dǎo)致 game-over。下圖展示了未經(jīng)訓(xùn)練的小笨鳥,可以看到,他處于人工智障的狀態(tài),經(jīng)常撞柱子或者撞草地:
▲?未經(jīng)訓(xùn)練的小笨鳥
先簡要分析一下環(huán)境 Environment 的主要代碼。
BirdEnv.py 繼承自 gym.Env,實(shí)現(xiàn)了 init、reset、reward、render 等標(biāo)準(zhǔn)接口。init 函數(shù),用于加載圖片、聲音等外部文件,并初始化得分、小鳥位置、上下邊緣、水管位置等環(huán)境信息:
????if?not?hasattr(self,'IMAGES'):
????????print('InitGame!')
????????self.beforeInit()
????self.score?=?self.playerIndex?=?self.loopIter?=?0
????self.playerx?=?int(SCREENWIDTH?*?0.3)
????self.playery?=?int((SCREENHEIGHT?-?self.PLAYER_HEIGHT)?/?2.25)
self.baseShift?=?self.IMAGES['base'].get_width()?-?self.BACKGROUND_WIDTH
????newPipe1?=?getRandomPipe(self.PIPE_HEIGHT)
????newPipe2?=?getRandomPipe(self.PIPE_HEIGHT)
????#...other?code
step 函數(shù),執(zhí)行兩個(gè)動(dòng)作,0 表示不采取行動(dòng)(小鳥會(huì)自動(dòng)下落),1 表示飛翔;step 函數(shù)有四個(gè)返回值,image_data 表示當(dāng)前狀態(tài),也就是游戲畫面,reward 表示本次 step 的即時(shí)獎(jiǎng)勵(lì),terminal 表示是否是吸收狀態(tài),{} 表示其他信息:
????pygame.event.pump()
????reward?=?0.1
????terminal?=?False
????if?input_action?==?1:
????????if?self.playery?>?-2?*?self.PLAYER_HEIGHT:
????????????self.playerVelY?=?self.playerFlapAcc
????????????self.playerFlapped?=?True
???#...other?code
???image_data=self.render()
???return?image_data,?reward,?terminal,{}
獎(jiǎng)勵(lì) reward;初始獎(jiǎng)勵(lì)是 +0.1,表示小鳥向前飛行一小段距離;穿過柱子,獎(jiǎng)勵(lì) +1;撞到柱子,獎(jiǎng)勵(lì)為 -1,并且到達(dá) terminal 狀態(tài):
reward?=?0.1
#...other?code
playerMidPos?=?self.playerx?+?self.PLAYER_WIDTH?/?2
for?pipe?in?self.upperPipes:
????pipeMidPos?=?pipe['x']?+?self.PIPE_WIDTH?/?2
????#穿過一個(gè)柱子獎(jiǎng)勵(lì)加1
????if?pipeMidPos?<=?playerMidPos?<?pipeMidPos?+?4:?????????????
????????self.score?+=?1
????????reward?=?self.reward(1)
#...other?code
if?isCrash:
????#撞到邊緣或者撞到柱子,結(jié)束,并且獎(jiǎng)勵(lì)為-1
????terminal?=?True
????reward?=?self.reward(-1)
reward 函數(shù),返回即時(shí)獎(jiǎng)勵(lì) r:
????return?r
reset 函數(shù),調(diào)用 init,并執(zhí)行一次飛翔操作,返回 observation,reward,isOver:
????self.__init__()
????self.mode=mode
????action0?=?1
????observation,?reward,?isOver,_?=?self.step(action0)
????return?observation,reward,isOver
render 函數(shù),渲染游戲界面,并返回當(dāng)前畫面:
????image_data?=?pygame.surfarray.array3d(pygame.display.get_surface())
????pygame.display.update()
????self.FPSCLOCK.tick(FPS)
????return?image_data
至此,強(qiáng)化學(xué)習(xí)所需的狀態(tài)、動(dòng)作、獎(jiǎng)勵(lì)等功能均定義完畢。接下來簡單推導(dǎo)一下 DQN (Deep-Q-Network) 算法的原理。
DQN的發(fā)展過程
DQN 的進(jìn)化歷史可謂源遠(yuǎn)流長,從最開始 Bellman 在 1956 年提出的動(dòng)態(tài)規(guī)劃,到后來 Watkins 在 1989 年提出的的 Q-learning,再到 DeepMind 的 Nature-2015 穩(wěn)定版,最后到 Dueling DQN、Priority Replay Memory、Parameter Noise 等優(yōu)化算法,橫跨整整一個(gè)甲子,凝聚了無數(shù)專家、教授們的心血。如今的我們站在先賢們的肩膀上,從以下角度逐步分析:
貝爾曼(最優(yōu))方程與 VQ 樹
Q-learning
參數(shù)化逼近
DQN 算法框架
貝爾曼 (最優(yōu)) 方程與VQ樹
我們從經(jīng)典的表格型強(qiáng)化學(xué)習(xí)(Tabular Reinforcement Learning)開始,回憶一下馬爾可夫決策(MDP)過程,MDP 可由五元組 (S,A,P,R,γ) 表示,其中:
S 狀態(tài)集合,維度為 1×|S|?
A 動(dòng)作集合,維度為 1×|A|?
P 狀態(tài)轉(zhuǎn)移概率矩陣,經(jīng)常寫成,其維度為 |S|×|A|×|S|?
R 回報(bào)函數(shù),如果依賴于狀態(tài)值函數(shù) V,維度為 1×|S|,如果依賴于狀態(tài)-動(dòng)作值函數(shù) Q,則維度為 |S|×|A|?
γ 折扣因子,用來計(jì)算帶折扣的累計(jì)回報(bào) G(t),維度為 1?
S、A、R、γ 均不難理解,可能部分同學(xué)對(duì)有疑問——既然 S 和 A 確定了,下一個(gè)狀態(tài) S' 不是也確定了嗎?為什么會(huì)有概率轉(zhuǎn)移矩陣呢?
其實(shí)我初學(xué)的時(shí)候也曾經(jīng)被這個(gè)問題困擾過,不妨通過如下兩個(gè)例子以示區(qū)別:
1. 恒等于 1.0 的情況。如圖 1 所示,也就是上一次我們?cè)诓呗蕴荻人惴ㄖ兴褂玫拿詫m,假設(shè)機(jī)器人處于左上角,這時(shí)候你命令機(jī)器人向右走,那么他轉(zhuǎn)移到紅框所示位置的概率就是 1.0,不會(huì)有任何異議:
▲?圖1.?迷宮尋寶
2. 不等于 1.0 的情況。假設(shè)現(xiàn)在我們下一個(gè)飛行棋,如圖 2 所示。有兩種骰子,第一種是普通的正方體骰子,可以投出 1~6,第二種是正四面體的骰子,可以投出 1~4。現(xiàn)在飛機(jī)處于紅框所示的位置,現(xiàn)在我們選擇投擲第二種骰子這個(gè)動(dòng)作,由于骰子本身具有均勻隨機(jī)性,所以飛機(jī)轉(zhuǎn)移到終點(diǎn)的概率僅僅是 0.25。這就說明,在某些環(huán)境中,給定 S、A 的情況下,轉(zhuǎn)移到具體哪一個(gè) S' 其實(shí)是不確定的:
▲?圖2.?飛行棋
除了經(jīng)典的五元組外,為了研究長期回報(bào),還經(jīng)常加入三個(gè)重要的元素,分別是:
策略 π(a∣s),維度為 |S|×|A|
狀態(tài)值函數(shù),維度為 1×|S|,表示當(dāng)智能體采用策略 π 時(shí),累積回報(bào)在狀態(tài) s 處的期望值:
▲?圖3.?狀態(tài)值函數(shù)
狀態(tài)-行為值函數(shù),也叫狀態(tài)-動(dòng)作值函數(shù),維度為 |S|×|A|,表示當(dāng)智能體采取策略 π 時(shí),累計(jì)回報(bào)在狀態(tài) s 處并執(zhí)行動(dòng)作 a 時(shí)的期望值:
▲?圖4.?狀態(tài)-行為值函數(shù)
知道了 π、v、q 的具體含義后,我們來看一個(gè)重要的概念,也就是 V、Q 的遞歸展開式。
學(xué)過動(dòng)態(tài)規(guī)劃的同學(xué)都知道,動(dòng)態(tài)規(guī)劃本質(zhì)上是一個(gè) bootstrap(自舉)問題,它包含最優(yōu)子結(jié)構(gòu)與重疊子問題兩個(gè)性質(zhì),也就是說,通常有兩種方法解決動(dòng)態(tài)規(guī)劃:
將總問題劃分為 k 個(gè)子問題,遞歸求解這些子問題,然后將子問題進(jìn)行合并,得到總問題的最優(yōu)解;對(duì)于重復(fù)的子問題,我們可以將他們進(jìn)行緩存(記憶搜索 MemorySearch,請(qǐng)回憶 f(n)=f(n-1)+f(n-2) 這個(gè)遞歸程序);
計(jì)算最小的子問題,合并這些子問題產(chǎn)生一個(gè)更大的子問題,不斷的自底向上計(jì)算,隨著子問題的規(guī)模越來越大,我們會(huì)得到最終的總問題的最優(yōu)解(打表 DP,請(qǐng)回憶楊輝三角中的 dp[i-1,j-1]+dp[i-1,j]=dp[i,j])。
這兩種切題技巧,對(duì)于有過 ACM 或者 LeetCode 刷題經(jīng)驗(yàn)的同學(xué),可以說是老朋友了,那么能否把以上思想遷移到強(qiáng)化學(xué)習(xí)呢?答案是肯定的!
分別考慮 v、q 的展開式:
處在狀態(tài) s 時(shí),由于有策略 π 的存在,故可以把狀態(tài)值函數(shù) v 展開成以下形式:
▲?圖5.?v展開成q
這個(gè)公式表示:在狀態(tài) s 處的值函數(shù),等于采取策略 π 時(shí),所有狀態(tài)-行為值函數(shù)的總和。
處在狀態(tài) s、并執(zhí)行動(dòng)作 a,可以把狀態(tài)-行為值函數(shù) q 展開成以下形式:
▲?圖6.?q展開成v
這個(gè)公式表示:在狀態(tài) s 采用動(dòng)作 a 的狀態(tài)行為值函數(shù),等于回報(bào)加上后序可能產(chǎn)生的的狀態(tài)值函數(shù)的總和。
我們可以看到:v 可以展開成 q,同時(shí) q 也可以展開成 v。
所以可以用以下 v、q 節(jié)點(diǎn)相隔的樹來表示以上兩個(gè)公式,這顆樹比純粹的公式更容易理解,我習(xí)慣上把它叫做 V-Q 樹,它顯然是一個(gè)遞歸的結(jié)構(gòu):
▲?圖7.?V-Q樹
注意畫紅圈中的兩個(gè)節(jié)點(diǎn),體現(xiàn)了重疊子問題特性。如何理解這個(gè)性質(zhì)呢?不妨回憶一下上文提到的飛行棋,假設(shè)飛機(jī)處在起點(diǎn)位置 1,那么無論投擲 1 號(hào)骰子還是 2 號(hào)骰子,都是有機(jī)會(huì)可以到達(dá)位置 3 的,這就是重疊子問題的一個(gè)例子。
有了這棵遞歸樹之后,就不難推導(dǎo)出 v 和 v',以及 q 和 q' 自身的遞歸展開式:
▲?圖8.?狀態(tài)值函數(shù)v自身的遞歸展開式
▲?圖9.?狀態(tài)-行為值函數(shù)q自身的遞歸展開式
其實(shí)無論是 v 還是 q,都擁有最優(yōu)子結(jié)構(gòu)特性。不妨利用反證法加以證明:
假設(shè)要求總問題 V(s) 的最優(yōu)解,那么它包含的每個(gè)子問題 V(s') 也必須是最優(yōu)解;否則,如果某個(gè)子問題 V(s') 不是最優(yōu),那么必然有一個(gè)更優(yōu)的子問題 V'(s') 存在,使得總問題 V'(s) 比原來的總問題 V(s) 更優(yōu),與我們的假設(shè)相矛盾,故最優(yōu)子結(jié)構(gòu)性質(zhì)得證,q(s) 的最優(yōu)子結(jié)構(gòu)性質(zhì)同理。
計(jì)算值函數(shù)的目的是為了構(gòu)建學(xué)習(xí)算法得到最優(yōu)策略,每個(gè)策略對(duì)應(yīng)著一個(gè)狀態(tài)值函數(shù),最優(yōu)策略自然也對(duì)應(yīng)著最優(yōu)狀態(tài)值函數(shù),故而定義如下兩個(gè)函數(shù):
最優(yōu)狀態(tài)值函數(shù),表示在所有策略中最大的值函數(shù),即:
▲?圖10.?最優(yōu)狀態(tài)值函數(shù)
最優(yōu)狀態(tài)-行為值函數(shù),表示在所有策略中最大的狀態(tài)-行為值函數(shù):
▲?圖11.?最優(yōu)狀態(tài)-行為值函數(shù)
結(jié)合上文的遞歸展開式和最優(yōu)子結(jié)構(gòu)性質(zhì),可以得到 v 與 q 的貝爾曼最優(yōu)方程:
▲?圖12.?v的貝爾曼最優(yōu)方程
▲?圖13.?q的貝爾曼最優(yōu)方程
重點(diǎn)理解第二個(gè)公式,也就是關(guān)于 q 的貝爾曼最優(yōu)方程,它是今天的主角 Q-learning 以及 DQN 的理論基礎(chǔ)。
有了貝爾曼最優(yōu)方程,我們就可以通過純粹貪心的策略來確定 π,即:僅僅把最優(yōu)動(dòng)作的概率設(shè)置為 1,其他所有非最優(yōu)動(dòng)作的概率都設(shè)置為 0。這樣做的好處是:當(dāng)算法收斂的時(shí)候,策略 π(a|s) 必然是一個(gè) one-hot 型的矩陣。用數(shù)學(xué)公式表達(dá)如下:
▲?圖14.?算法收斂時(shí)候的策略π
強(qiáng)化學(xué)習(xí)中的動(dòng)態(tài)規(guī)劃方法實(shí)質(zhì)上是一種 model-based(模型已知)方法,因?yàn)?MDP 五元組是已知的,特別是狀態(tài)轉(zhuǎn)移概率矩陣是已知的。
也就是說,所有的環(huán)境信息對(duì)于我們來說是 100% 完備的,故而可以對(duì)整個(gè)解空間樹進(jìn)行全局搜索,下圖展示了動(dòng)態(tài)規(guī)劃方法的示意圖,在確定根節(jié)點(diǎn)狀態(tài) S(t) 的最優(yōu)值的時(shí)候,必須遍歷他所有的 S(t+1) 子節(jié)點(diǎn)并選出最優(yōu)解:
▲?圖15.?動(dòng)態(tài)規(guī)劃方法的解空間搜索過程
不過,和傳統(tǒng)的刷題動(dòng)態(tài)規(guī)劃略有不同,強(qiáng)化學(xué)習(xí)往往是利用值迭代(Value Iteration)、策略迭代(Policy Iteration)、策略改善(Policy Improve)等方式使 v、q、π 等元素達(dá)到收斂狀態(tài),當(dāng)然也有直接利用矩陣求逆計(jì)算解析解的方法,有興趣的同學(xué)可以參考相關(guān)文獻(xiàn),這里不再贅述。
Q-learning
上文提到的動(dòng)態(tài)規(guī)劃方法是一種 model-based 方法,僅僅適用于已知的情況。若狀態(tài)轉(zhuǎn)移概率矩陣未知,model-free(無模型)方法就派上用場了,上一期的 MCPG 算法就是一種典型的 model-free 方法。它搜索解空間的方式更像是 DFS(深度優(yōu)先搜索),而且一條道走到黑,沒有指針回溯的操作,下圖展示了蒙特卡洛算法的求解示意圖:
▲?圖16.?MC系列方法的解空間搜索過程
雖然每次只能走一條分支,但隨機(jī)數(shù)發(fā)生器會(huì)幫助算法遍歷整個(gè)解空間,再通過大量的迭代,所有節(jié)點(diǎn)也會(huì)收斂到最優(yōu)解。
不過,MC 類方法有兩個(gè)小缺點(diǎn):
1. 使用作為訓(xùn)練標(biāo)簽,其本身就是值函數(shù)準(zhǔn)確的無偏估計(jì)。但是,這也正是它的缺點(diǎn),因?yàn)?MC 方法會(huì)經(jīng)歷很多隨機(jī)的狀態(tài)和動(dòng)作,使得每次得到的 G(t) 隨機(jī)性很大,具有很高的方差。
2. 由于采用的是一條道走到黑的方式從根節(jié)點(diǎn)遍歷到葉子節(jié)點(diǎn),所以必須要等到 episode 結(jié)束才能進(jìn)行訓(xùn)練,而且每輪 episode 產(chǎn)生的數(shù)據(jù)只訓(xùn)練一次,每輪 episode 產(chǎn)生數(shù)據(jù)的 batch-size 還不一定相同,所以在訓(xùn)練過程中,MC 方法的 loss 函數(shù)(或者 TD-Error)的波動(dòng)幅度較大,而數(shù)據(jù)利用效率不高。
那么,能否邊產(chǎn)生數(shù)據(jù)邊訓(xùn)練呢?可以!時(shí)序差分(Temporal-Difference-Learning,簡稱 TD)算法應(yīng)運(yùn)而生了。
時(shí)序差分學(xué)習(xí)是模擬(或者經(jīng)歷)一段序列,每行動(dòng)一步(或者幾步)就根據(jù)新狀態(tài)的價(jià)值估計(jì)當(dāng)前執(zhí)行的狀態(tài)價(jià)值。大致可以分為兩個(gè)小類:
1. TD(0) 算法,只向后估計(jì)一個(gè) step。其值函數(shù)更新公式為:
▲?圖17.?TD(0)算法的更新公式
其中,α 為學(xué)習(xí)率,稱為 TD 目標(biāo),MC 方法中的 G(t) 也可以叫做 TD 目標(biāo),稱為 TD-Error,當(dāng)模型收斂時(shí),TD-Error 會(huì)無限接近于 0。
2. Sarsa(λ) 算法,向后估計(jì) n 步,n 為有限值,還有一個(gè)衰減因子 λ。其值函數(shù)的更新公式為:
▲?圖18.?Sarsa(λ)算法的更新公式
▲?圖19.?的計(jì)算方法
與 MC 方法相比,TD 方法只用到了一步或者有限步隨機(jī)狀態(tài)和動(dòng)作,因此它是一個(gè)有偏估計(jì)。不過,由于 TD 目標(biāo)的隨機(jī)性比 MC 方法的 G(t) 要小,所以方差也比 MC 方法小的多,值函數(shù)的波動(dòng)幅度較小,訓(xùn)練比較穩(wěn)定。
看一下 TD 方法的解空間搜索示意圖,紅框表示 TD(0),藍(lán)框表示 Sarsa(λ)。雖然每次估計(jì)都有一定的偏差,但隨著算法的不斷迭代,所有的節(jié)點(diǎn)也會(huì)收斂到最優(yōu)解:
▲?圖20.?TD方法的解空間搜索過程
有了 TD 的框架,既然我們要求狀態(tài)值函數(shù) v、狀態(tài)-行為值函數(shù) q 的最優(yōu)解,那么是否能直接選擇最優(yōu)的 TD 目標(biāo)作為 Target 呢?答案是肯定的,這也是 Q-Learning 算法的基本思想,其公式如下所示:
▲?圖21.?Q-learning算法的學(xué)習(xí)公式
其中,動(dòng)作 a 由 ε-greedy 策略選出,從而在狀態(tài) s 處執(zhí)行 a 之后產(chǎn)生了另一個(gè)狀態(tài) s',接下來選出狀態(tài) s' 處最大的狀態(tài)-行為值函數(shù) q(s',a'),這樣,TD 目標(biāo)就可以確定為 R+γmax[a′]Q(s′,a′)。這種思想很像貪心算法中的總是選擇在當(dāng)前看來最優(yōu)的決策,它一開始可能會(huì)得到一個(gè)局部最優(yōu)解,不過沒關(guān)系,隨著算法的不斷迭代,整個(gè)解空間樹也會(huì)收斂到全局最優(yōu)解。
以下是 Q-learning 算法的偽代碼,和 on-policy 的 MC 方法對(duì)應(yīng),它是一種 off-policy(異策略)方法:
#define?maxStep=1024?//定義每一輪最多走多少步
initialize?Q_table[|S|,|A|]?//初始化Q矩陣
for?i?in?range(0,maxEpisode):
????s=env.reset()??//初始化狀態(tài)s
????for?j?in?range(0,maxStep):
????????//用ε-greedy策略在s行選一個(gè)動(dòng)作a
????????choose?action?a?using?ε-greedy?from?Q_table[s]?
????????s',R,terminal,_=env.step(a)?//執(zhí)行動(dòng)作a,得到下一個(gè)狀態(tài)s',獎(jiǎng)勵(lì)R,是否結(jié)束terminal
????????max_s_prime_action=np.max(Q_table[s',:])?//選s'對(duì)應(yīng)的最大行為值函數(shù)
????????td=R+γ*max_s_prime_action?//計(jì)算TD目標(biāo)
????????Q_table[s,a]=?Q_table[s,a]+α*(td-Q_table[s,a])?//學(xué)習(xí)Q(s,a)的值
????????s=s'?//更新s,注意,和sarsa算法不同,這里的a不用更新
????????if?terminal:
????????????break
Q-learning 是一種優(yōu)秀的算法,不僅簡單直觀,而且平均速度比 MC 快。在 DRL 未出現(xiàn)之前,它在強(qiáng)化學(xué)習(xí)中的地位,差不多可以媲美 SVM 在機(jī)器學(xué)習(xí)中的地位。
參數(shù)化逼近
有了 Q-learning 算法,是否就能一招吃遍天下鮮了呢?答案是否定的,我們看一下它存在的問題。
上文所提到的,無論是 DP、MC 還是 TD,都是基于表格(tabular)的方法,當(dāng)狀態(tài)空間比較小的時(shí)候,計(jì)算機(jī)內(nèi)存完全可以裝下,表格式型強(qiáng)化學(xué)習(xí)是完全適用的。但遇到高階魔方(三階魔方的總變化數(shù)是)、圍棋()這類問題時(shí),S、V、Q、P 等表格均會(huì)出現(xiàn)維度災(zāi)難,早就超出了計(jì)算機(jī)內(nèi)存甚至硬盤容量。這時(shí)候,參數(shù)化逼近方法就派上用場了。
所謂參數(shù)化逼近,是指值函數(shù)可以由一組參數(shù) θ 來近似,如 Q-learning 中的 Q(s,a) 可以寫成 Q(s,a|θ) 的形式。這樣,不但降低了存儲(chǔ)維度,還便于做一些額外的特征工程,而且 θ 更新的同時(shí),Q(s,a|θ) 會(huì)進(jìn)行整體更新,不僅避免了過擬合情況,還使得模型的泛化能力更強(qiáng)。
既然有了可訓(xùn)練參數(shù),我們就要研究損失函數(shù)了,Q-Learning 的損失函數(shù)是什么呢?
先看一下 Q-Learning 的優(yōu)化目標(biāo)——使得 TD-Error 最小:
▲?圖22.?Q-Learning的優(yōu)化目標(biāo)
加入?yún)?shù) θ 之后,若將 TD 目標(biāo)作為標(biāo)簽 target,將 Q(s,a) 作為模型的輸出 y,則問題轉(zhuǎn)化為:
▲?圖23.?帶參數(shù)的優(yōu)化目標(biāo)
這是我們所熟悉的監(jiān)督學(xué)習(xí)中的回歸問題,顯然 loss 函數(shù)就是 mse,故而可以用梯度下降算法最小化 loss,從而更新參數(shù) θ:
▲?圖24.?loss函數(shù)的梯度下降公式
注意到,TD 目標(biāo)是標(biāo)簽,所以 Q(s',a'|θ) 中的 θ 是不能更新的,這種方法并非完全的梯度法,只有部分梯度,稱為半梯度法,這是 NIPS-2013 的雛形。
后來,DeepMind 在 Nature-2015 版本中將 TD 網(wǎng)絡(luò)單獨(dú)分開,其參數(shù)為 θ',它本身并不參與訓(xùn)練,而是每隔固定步數(shù)將值函數(shù)逼近的網(wǎng)絡(luò)參數(shù) θ 拷貝給 θ',這樣保證了 DQN 的訓(xùn)練更加穩(wěn)定:
▲?圖25.?含有目標(biāo)網(wǎng)絡(luò)參數(shù)θ'的梯度下降公式
?
至此,DQN 的 Loss 函數(shù)、梯度下降公式推導(dǎo)完畢。
DQN算法框架
接下來,還要解決兩個(gè)問題——數(shù)據(jù)從哪里來?如何采集?
針對(duì)以上兩個(gè)問題,DeepMind 團(tuán)隊(duì)提出了深度強(qiáng)化學(xué)習(xí)的全新訓(xùn)練方法:經(jīng)驗(yàn)回放(experience replay)。
在強(qiáng)化學(xué)習(xí)過程中,智能體將數(shù)據(jù)存儲(chǔ)到一個(gè) ReplayBuffer 中(任何一種集合,可以是哈希表、數(shù)組、隊(duì)列,也可以是數(shù)據(jù)庫),然后利用均勻隨機(jī)采樣的方法從 ReplayBuffer 中抽取數(shù)據(jù),這些數(shù)據(jù)就可以進(jìn)行 Mini-Batch-SGD,這樣就打破了數(shù)據(jù)之間的相關(guān)性,使得數(shù)據(jù)之間盡量符合獨(dú)立同分布原則。
DQN 的基本網(wǎng)絡(luò)結(jié)構(gòu)如下:
▲?圖26.?DQN的基本網(wǎng)絡(luò)結(jié)構(gòu)
要特別注意:
1. 與參數(shù) θ 做線性運(yùn)算 (wx+b) 的僅僅是輸入狀態(tài) s,這一步?jīng)]有動(dòng)作 a 的參與;
2. output_1 的維度為 |A|,表示神經(jīng)網(wǎng)絡(luò) Q(s,θ) 的輸出;
3. 輸入動(dòng)作 a 是 one-hot,與 output_1 作哈達(dá)馬積后產(chǎn)生的 output_2 是一個(gè)數(shù)字,作為損失函數(shù)中的 Q(s,a|θ),也就是 y。
以下是 DQN 算法的偽代碼:
#定義為一個(gè)雙端隊(duì)列D,作為經(jīng)驗(yàn)回放區(qū)域,最大長度為max_size
Initialize?replay_memory?D?as?a?deque,mas_size=50000
#初始化狀態(tài)-行為值函數(shù)Q的神經(jīng)網(wǎng)絡(luò),權(quán)值隨機(jī)
Initialize?action-value?function?Q(s,a|θ)?as?Neural?Network?with?random-weights-initializer
#初始化TD目標(biāo)網(wǎng)絡(luò),初始權(quán)值和θ相等
Initialize?target?action-value?function?Q(s,a|θ)?with?weights?θ'=θ
#迭代max_episode個(gè)輪次
for?episode?in?range(0,max_episode=65535):
????#重置環(huán)境env,得到初始狀態(tài)s
????s=env.reset()
????#循環(huán)事件的每一步,最多迭代max_step_limit個(gè)step
????for?step?in?range(0,max_step_limit=1024):
????????#通過ε-greedy的方式選出一個(gè)動(dòng)作action
????????With?probability?ε?select?a?random?action?a?or?select?a=argmax(Q(s,θ))
????????#在env中執(zhí)行動(dòng)作a,得到下一個(gè)狀態(tài)s',獎(jiǎng)勵(lì)R,是否終止terminal
????????s',R,terminal,_=env.step(a)
????????#將五元組(s,a,s',R,terminal)壓進(jìn)隊(duì)尾
????????D.addLast(s,a,s',R,terminal)
????????#如果隊(duì)列滿,彈出隊(duì)頭元素
????????if?D.isFull():
????????????D.removeFirst()
????????#更新狀態(tài)s
????????s=s'
????????#從隊(duì)列中進(jìn)行隨機(jī)采樣
????????batch_experience[s,a,s',R,terminal]=random_select(D,batch_size=32)
????????#計(jì)算TD目標(biāo)
????????target?=?R?+?γ*(1-?terminal)?*?np.max(Q(s',θ'))
????????#對(duì)loss函數(shù)執(zhí)行Gradient-decent,訓(xùn)練參數(shù)θ
????????θ=θ+α*(target-Q(s,a|θ))▽Q(s,a|θ)
????????#每隔C步,同步θ與θ'的權(quán)值
????????Every?C?steps?set?θ'=θ
????????#是否結(jié)束
if terminal:
break?
我們玩的游戲 Flappy-Bird,它的輸入是一幀一幀的圖片,所以,經(jīng)典的 Atari-CNN 模型就可以派上用場了:
▲?圖27.?Atari游戲的CNN網(wǎng)絡(luò)結(jié)構(gòu)
網(wǎng)絡(luò)的輸入是被處理成灰度圖的最近 4 幀 84*84 圖像(4 是經(jīng)驗(yàn)值),經(jīng)過若干 CNN 和 FullyConnect 后,輸出各個(gè)動(dòng)作所對(duì)應(yīng)的狀態(tài)-行為值函數(shù) Q。以下是每一層的具體參數(shù),由于 atari 游戲最多有 18 個(gè)動(dòng)作,所以最后一層的維度是 18:
▲?圖28.?神經(jīng)網(wǎng)絡(luò)的具體參數(shù)
至此,理論部分推導(dǎo)完畢。下面,我們分析一下 PARL 中的 DQN 部分的源碼,并實(shí)現(xiàn) Flappy-Bird 的游戲智能。
代碼實(shí)現(xiàn)
依次分析 env、model、algorithm、agent、replay_memory、train 等模塊。
1. BirdEnv.py,環(huán)境;上文已經(jīng)分析過了。
2. BirdModel.py,神經(jīng)網(wǎng)絡(luò)模型;使用三層 CNN+兩層 FC,CNN 的 padding 方式都是 valid,最后輸出狀態(tài)-行為值函數(shù) Q,維度為 |A|。注意輸入圖片歸一化,并按照官方模板填入代碼:
????def?__init__(self,?act_dim):
????????self.act_dim?=?act_dim
????????#padding方式為valid
????????p_valid=0
????????self.conv1?=?layers.conv2d(
????????????num_filters=32,?filter_size=8,?stride=4,?padding=p_valid,?act='relu')
????????self.conv2?=?layers.conv2d(
????????????num_filters=64,?filter_size=4,?stride=2,?padding=p_valid,?act='relu')
????????self.conv3?=?layers.conv2d(
????????????num_filters=64,?filter_size=3,?stride=1,?padding=p_valid,?act='relu')
????????self.fc0=layers.fc(size=512)
????????self.fc1?=?layers.fc(size=act_dim)
????def?value(self,?obs):
????????#輸入歸一化
????????obs?=?obs?/?255.0
????????out?=?self.conv1(obs)
????????out?=?self.conv2(out)
????????out?=?self.conv3(out)
????????out?=?layers.flatten(out,?axis=1)
????????out?=?self.fc0(out)
????????out?=?self.fc1(out)
????????return?out
3. dqn.py,算法層;官方倉庫已經(jīng)提供好了,我們無需自己再寫,直接復(fù)用算法庫(parl.algorithms)里邊的 DQN 算法即可。?
簡單分析一下 DQN 的源碼實(shí)現(xiàn)。
define_learn 函數(shù),用于神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)。接收 [狀態(tài) obs, 動(dòng)作 action, 即時(shí)獎(jiǎng)勵(lì) reward, 下一個(gè)狀態(tài) next_obs, 是否終止 terminal] 這樣一個(gè)五元組,代碼實(shí)現(xiàn)如下:
pred_value?=?self.model.value(obs)
#根據(jù)next_obs以及參數(shù)θ'計(jì)算目標(biāo)網(wǎng)絡(luò)的狀態(tài)-行為值函數(shù)next_pred_value,對(duì)應(yīng)偽代碼中的Q(s',θ')
next_pred_value?=?self.target_model.value(next_obs)
#選出next_pred_value的最大值best_v,對(duì)應(yīng)偽代碼中的np.max(Q(s',θ'));注意θ'不參與訓(xùn)練,所以要stop_gradient
best_v?=?layers.reduce_max(next_pred_value,?dim=1)
best_v.stop_gradient?=?True
#計(jì)算TD目標(biāo)
target?=?reward?+?(1.0?-?layers.cast(terminal,?dtype='float32'))?*?self.gamma?*?best_v
#輸入的動(dòng)作action與pred_value作哈達(dá)瑪積,選出要評(píng)估的狀態(tài)-行為值函數(shù)pred_action_value,對(duì)應(yīng)偽代碼中的?Q(s,a|θ)
action_onehot?=?layers.one_hot(action,?self.action_dim)
action_onehot?=?layers.cast(action_onehot,?dtype='float32')
pred_action_value?=?layers.reduce_sum(layers.elementwise_mul(action_onehot,?pred_value),?dim=1)
#mse以及梯度下降,對(duì)應(yīng)偽代碼中的θ=θ+α*(target-Q(s,a|θ))▽Q(s,a|θ)
cost?=?layers.square_error_cost(pred_action_value,?target)
cost?=?layers.reduce_mean(cost)
optimizer?=?fluid.optimizer.Adam(self.lr,?epsilon=1e-3)
optimizer.minimize(cost)
sync_target 函數(shù)用于同步網(wǎng)絡(luò)參數(shù):
????"""?sync?parameters?of?self.target_model?with?self.model
????"""
????self.model.sync_params_to(self.target_model,?gpu_id=gpu_id)
4. BirdAgent.py,智能體。其中,build_program 函數(shù)封裝了 algorithm 中的 define_predict 和 define_learn,sample 函數(shù)以 ε-greedy 策略選擇動(dòng)作,predict 函數(shù)以 100% 貪心的策略選擇 argmax 動(dòng)作,learn 函數(shù)接收五元組 (obs, act, reward, next_obs, terminal) 完成學(xué)習(xí)功能,這些函數(shù)和 Policy-Gradient 的寫法類似。
除了這些常用功能之外,由于游戲的訓(xùn)練時(shí)間比較長,所以附加了兩個(gè)函數(shù),save_params 用于保存模型,load_params 用于加載模型:
def?save_params(self,?learnDir,predictDir):
????fluid.io.save_params(
????????executor=self.fluid_executor,
????????dirname=learnDir,
????????main_program=self.learn_programs[0])???
????fluid.io.save_params(
????????executor=self.fluid_executor,
????????dirname=predictDir,
????????main_program=self.predict_programs[0])?????
#加載模型
def?load_params(self,?learnDir,predictDir):?
????fluid.io.load_params(
????????executor=self.fluid_executor,
????????dirname=learnDir,
????????main_program=self.learn_programs[0])??
????fluid.io.load_params(
????????executor=self.fluid_executor,
????????dirname=predictDir,
????????main_program=self.predict_programs[0])?
另外,還有四個(gè)超參數(shù),可以進(jìn)行微調(diào):
self.update_target_steps?=?5000
#初始探索概率ε,超參數(shù)可微調(diào)
self.exploration?=?0.8
#每步探索的衰減程度,超參數(shù)可微調(diào)
self.exploration_dacay=1e-6
#最小探索概率,超參數(shù)可微調(diào)
self.min_exploration=0.05
5. replay_memory.py,經(jīng)驗(yàn)回放單元。雙端隊(duì)列 _context 是一個(gè)滑動(dòng)窗口,用來記錄最近 3 幀(再加上新產(chǎn)生的 1 幀就是 4 幀);state、action、reward 等用 numpy 數(shù)組存儲(chǔ),因?yàn)?numpy 的功能比雙端隊(duì)列更豐富,max_size 表示 replay_memory 的最大容量:
self.action?=?np.zeros((self.max_size,?),?dtype='int32')
self.reward?=?np.zeros((self.max_size,?),?dtype='float32')
self.isOver?=?np.zeros((self.max_size,?),?dtype='bool')
#_context是一個(gè)滑動(dòng)窗口,長度永遠(yuǎn)保持3
self._context?=?deque(maxlen=context_len?-?1)
其他的 append、recent_state、sample_batch 等函數(shù)并不難理解,都是基于 numpy 數(shù)組的進(jìn)一步封裝,略過一遍即可看懂。
6. Train_Test_Working_Flow.py,訓(xùn)練與測試,讓環(huán)境 evn 和智能體 agent 進(jìn)行交互。最重要的就是 run_train_episode 函數(shù),體現(xiàn)了 DQN 的主要邏輯,重點(diǎn)分析注釋部分與 DQN 偽代碼的對(duì)應(yīng)關(guān)系,其他都是編程細(xì)節(jié):
def?run_train_episode(env,?agent,?rpm):
????global?trainEpisode
????global?meanReward
????total_reward?=?0
????all_cost?=?[]
????#重置環(huán)境
????state,_,?__?=?env.reset()
????step?=?0
????#循環(huán)每一步
????while?True:
????????context?=?rpm.recent_state()
????????context.append(resizeBirdrToAtari(state))
????????context?=?np.stack(context,?axis=0)
????????#用ε-greedy的方式選一個(gè)動(dòng)作
????????action?=?agent.sample(context)
????????#執(zhí)行動(dòng)作
????????next_state,?reward,?isOver,_?=?env.step(action)
????????step?+=?1
????????#存入replay_buffer
????????rpm.append(Experience(resizeBirdrToAtari(state),?action,?reward,?isOver))
????????if?rpm.size()?>?MEMORY_WARMUP_SIZE:
????????????if?step?%?UPDATE_FREQ?==?0:
????????????????#從replay_buffer中隨機(jī)采樣
????????????????batch_all_state,?batch_action,?batch_reward,?batch_isOver?=?rpm.sample_batch(batchSize)
????????????????batch_state?=?batch_all_state[:,?:CONTEXT_LEN,?:,?:]
????????????????batch_next_state?=?batch_all_state[:,?1:,?:,?:]
????????????????#執(zhí)行SGD,訓(xùn)練參數(shù)θ
????????????????cost=agent.learn(batch_state,batch_action,?batch_reward,batch_next_state,?batch_isOver)
????????????????all_cost.append(float(cost))
????????total_reward?+=?reward
????????state?=?next_state
????????if?isOver?or?step>=MAX_Step_Limit:
????????????break
????if?all_cost:
????????trainEpisode+=1
????????#以滑動(dòng)平均的方式打印平均獎(jiǎng)勵(lì)
????????meanReward=meanReward+(total_reward-meanReward)/trainEpisode
????????print('\n?trainEpisode:{},total_reward:{:.2f},?meanReward:{:.2f}?mean_cost:{:.3f}'\
??????????????.format(trainEpisode,total_reward,?meanReward,np.mean(all_cost)))
????return?total_reward,?step
除了主要邏輯外,還有一些常見的優(yōu)化手段,防止訓(xùn)練過程中出現(xiàn) trick:
MEMORY_WARMUP_SIZE?=?MEMORY_SIZE//20
##一輪episode最多執(zhí)行多少次step,不然小鳥會(huì)無限制的飛下去,相當(dāng)于gym.env中的_max_episode_steps屬性
MAX_Step_Limit=int(1<<12)
#用一個(gè)雙端隊(duì)列記錄最近16次episode的平均獎(jiǎng)勵(lì)
avgQueue=deque(maxlen=16)
另外,還有其他一些超參數(shù),比如學(xué)習(xí)率 LEARNING_RATE、衰減因子 GAMMA、記錄日志的頻率 log_freq 等等,都可以進(jìn)行微調(diào):
GAMMA?=?0.99
#學(xué)習(xí)率
LEARNING_RATE?=?1e-3?*?0.5
#記錄日志的頻率
log_freq=10
main 函數(shù)在這里,輸入 train 訓(xùn)練網(wǎng)絡(luò),輸入 test 進(jìn)行測試:
????print("train?or?test??")
????mode=input()
????print(mode)
????if?mode=='train':
????????train()
????elif?mode=='test':
????????test()
????else:
????????print('Invalid?input!')
這是模型在我本機(jī)訓(xùn)練的輸出日志,大概 3300 個(gè) episode、50 萬步之后,模型就收斂了:
▲?圖29.?模型訓(xùn)練的輸出日志
平均獎(jiǎng)勵(lì):
▲?圖30.?最近16次平均獎(jiǎng)勵(lì)變化曲線
各位同學(xué)可以試著調(diào)節(jié)超參數(shù),或者修改網(wǎng)絡(luò)模型,看看能不能遇到一些坑?哪些因素會(huì)影響訓(xùn)練效率?如何提升收斂速度?
接下來就是見證奇跡的時(shí)刻,當(dāng)初懵懂的小笨鳥,如今已修煉成精了!
▲?訓(xùn)練完的FlappyBird
觀看 4 分鐘完整版:
https://www.bilibili.com/video/av49282860/
Github源碼:
https://github.com/kosoraYintai/PARL-Sample/tree/master/flappy_bird
參考文獻(xiàn)
[1] Bellman, R.E. & Dreyfus, S.E. (1962). Applied dynamic programming. RAND Corporation.?
[2] Sutton, R.S. (1988). Learning to predict by the methods of temporal difference.Machine Learning, 3, pp. 9–44.
[3] V. Mnih, K. Kavukcuoglu, D. Silver, A. A. Rusu, et al., "Human-level control through deep reinforcement learning," Nature, vol. 518(7540), pp. 529-533, 2015.
[4] https://leetcode.com/problems/climbing-stairs/?
[5]?https://leetcode.com/problems/pascals-triangle-ii/?
[6]?https://github.com/yenchenlin/DeepLearningFlappyBird?
[7]?https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow
點(diǎn)擊以下標(biāo)題查看更多往期內(nèi)容:?
目標(biāo)檢測小tricks之樣本不均衡處理
圖神經(jīng)網(wǎng)絡(luò)綜述:模型與應(yīng)用
DRr-Net:基于動(dòng)態(tài)重讀機(jī)制的句子語義匹配方法
小樣本學(xué)習(xí)(Few-shot Learning)綜述
萬字綜述之生成對(duì)抗網(wǎng)絡(luò)(GAN)
可逆ResNet:極致的暴力美學(xué)
基于多任務(wù)學(xué)習(xí)的可解釋推薦系統(tǒng)
AAAI 2019 | 基于分層強(qiáng)化學(xué)習(xí)的關(guān)系抽取
#投 稿 通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識(shí)的人。
總有一些你不認(rèn)識(shí)的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。
PaperWeekly 鼓勵(lì)高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺(tái)上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)習(xí)心得或技術(shù)干貨。我們的目的只有一個(gè),讓知識(shí)真正流動(dòng)起來。
??來稿標(biāo)準(zhǔn):
? 稿件確系個(gè)人原創(chuàng)作品,來稿需注明作者個(gè)人信息(姓名+學(xué)校/工作單位+學(xué)歷/職位+研究方向)?
? 如果文章并非首發(fā),請(qǐng)?jiān)谕陡鍟r(shí)提醒并附上所有已發(fā)布鏈接?
? PaperWeekly 默認(rèn)每篇文章都是首發(fā),均會(huì)添加“原創(chuàng)”標(biāo)志
? 投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請(qǐng)單獨(dú)在附件中發(fā)送?
? 請(qǐng)留下即時(shí)聯(lián)系方式(微信或手機(jī)),以便我們?cè)诰庉嫲l(fā)布時(shí)和作者溝通
?
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
關(guān)于PaperWeekly
PaperWeekly 是一個(gè)推薦、解讀、討論、報(bào)道人工智能前沿論文成果的學(xué)術(shù)平臺(tái)。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號(hào)后臺(tái)點(diǎn)擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
▽ 點(diǎn)擊 |?閱讀原文?| 訪問PARL官網(wǎng)
總結(jié)
以上是生活随笔為你收集整理的PaddlePaddle版Flappy-Bird—使用DQN算法实现游戏智能的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 报名 | “智见AI”SpringCam
- 下一篇: 2019年“计算法学”夏令营即日起接收报