TensorFlow 2.0深度强化学习指南
在本教程中,我將通過實施Advantage Actor-Critic(演員-評論家,A2C)代理來解決經典的CartPole-v0環境,通過深度強化學習(DRL)展示即將推出的TensorFlow2.0特性。雖然我們的目標是展示TensorFlow2.0,但我將盡最大努力讓DRL的講解更加平易近人,包括對該領域的簡要概述。
事實上,由于2.0版本的焦點是讓開發人員的生活變得更輕松,所以我認為現在是使用TensorFlow進入DRL的好時機,本文用到的例子的源代碼不到150行!代碼可以在這里或者這里獲取。
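Setup
Since TensorFlow 2.0 is still experimental, I recommend installing it in a separate virtual environment. I personally prefer Anaconda, so I'll use it to demonstrate the installation: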
<span style="color:#f8f8f2"><code class="language-shell">conda create -n tf2 python=3.6
source activate tf2
pip install tf-nightly-2.0-preview  # tf-nightly-gpu-2.0-preview for GPU version</code></span>
Let's quickly verify that everything works as expected:
<span style="color:#f8f8f2"><code class="language-python">>>> import tensorflow as tf
>>> print(tf.__version__)
1.13.0-dev20190117
>>> print(tf.executing_eagerly())
True</code></span>
Don't worry about the 1.13.x version number; it just means this is an early preview build. The thing to note here is that we are in eager mode by default!
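<span style="color:#f8f8f2"><code class="language-none">>>> print(tf.reduce_sum([1, 2, 3, 4, 5]))
tf.Tensor(15, shape=(), dtype=int32)</code></span>
If you are not yet familiar with eager mode, it essentially means that computations are executed at runtime, rather than through a pre-compiled graph. You can find a good overview in the TensorFlow documentation.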
Deep Reinforcement Learning
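In general, reinforcement learning (RL) is a high-level framework for solving sequential decision-making problems. An RL agent navigates an environment by taking actions based on some observations, receiving rewards as a result. Most RL algorithms work by maximizing the sum of rewards the agent collects in a trajectory, e.g. during one round of a game.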
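To make the agent-environment loop concrete, here is a minimal pure-Python sketch. `ToyEnv`, `run_episode` and both policies are hypothetical names invented for this illustration (they are not part of gym); the environment only mimics the classic gym `reset()`/`step()` interface used later in this post. A policy that keeps the agent in bounds collects a larger reward sum than a random one, and that sum is exactly what RL algorithms try to maximize.

```python
import random


class ToyEnv:
    """Hypothetical 1-D task (not part of gym): keep a position inside [-2, 2].

    Mirrors the classic gym interface used later in this post:
    reset() -> obs and step(action) -> (obs, reward, done, info).
    """

    def __init__(self, max_steps=10):
        self.max_steps = max_steps

    def reset(self):
        self.pos, self.t = 0, 0
        return self.pos

    def step(self, action):
        # action 1 pushes right, action 0 pushes left
        self.pos += 1 if action == 1 else -1
        self.t += 1
        done = abs(self.pos) > 2 or self.t >= self.max_steps
        return self.pos, 1.0, done, {}  # +1 reward for every step survived


def run_episode(env, policy):
    """Roll out one episode (trajectory) and return the sum of its rewards."""
    obs, done, ep_reward = env.reset(), False, 0.0
    while not done:
        action = policy(obs)  # the agent acts based on its observation
        obs, reward, done, _ = env.step(action)
        ep_reward += reward
    return ep_reward


rng = random.Random(0)
random_policy = lambda obs: rng.randint(0, 1)
centering_policy = lambda obs: 1 if obs < 0 else 0  # push back toward 0

print(run_episode(ToyEnv(), random_policy))     # somewhere between 3.0 and 10.0
print(run_episode(ToyEnv(), centering_policy))  # 10.0
```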
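The output of an RL algorithm is typically a policy: a function that maps states to actions. A valid policy can be as simple as a hard-coded no-op action. Stochastic policies are represented as conditional probability distributions over actions, given some state.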
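A minimal sketch of both kinds of policy, assuming a discrete action space: a hard-coded no-op policy, and a stochastic policy that samples from a softmax over logits (the same kind of unnormalized log-probabilities the model later in this post outputs). All function names here are illustrative, not from any library.

```python
import math
import random


def softmax(logits):
    """Turn unnormalized log-probabilities into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def noop_policy(state):
    """A valid (if useless) deterministic policy: always take action 0."""
    return 0


def stochastic_policy(logits, rng):
    """Sample an action from pi(a | s), given the logits computed for state s."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(42)
logits = [0.0, math.log(3.0)]  # corresponds to pi = [0.25, 0.75]
samples = [stochastic_policy(logits, rng) for _ in range(2000)]
print(softmax(logits))
print(sum(samples) / len(samples))  # close to 0.75
```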
Actor-Critic Methods
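RL algorithms are often grouped based on the objective function they optimize. Value-based methods, such as DQN, work by reducing the error of the expected state-action values.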
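For reference, the error that DQN-style methods reduce is typically the temporal-difference error between Q(s, a) and a one-step target r + gamma * max_a' Q(s', a'). The sketch below computes that target and a squared loss for a single transition; it leaves out everything else DQN adds (replay buffer, target network, Huber loss), and the function names are hypothetical.

```python
def td_target(reward, next_q_values, gamma, done):
    """One-step target: r + gamma * max_a' Q(s', a'), with no bootstrap at episode end."""
    if done:
        return reward
    return reward + gamma * max(next_q_values)


def td_loss(q_value, reward, next_q_values, gamma, done):
    """Squared error between the current estimate Q(s, a) and its target."""
    error = td_target(reward, next_q_values, gamma, done) - q_value
    return error ** 2


# Q(s, a) = 2.0; after acting we observe r = 1.0 and Q(s', .) = [0.5, 2.0]
print(td_target(1.0, [0.5, 2.0], gamma=0.9, done=False))  # about 2.8
print(td_loss(2.0, 1.0, [0.5, 2.0], gamma=0.9, done=False))  # about 0.64
```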
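Policy gradient methods directly optimize the policy itself by adjusting its parameters, typically via gradient descent. Computing the gradients exactly is usually intractable, so they are often estimated via Monte Carlo methods instead.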
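To see what a Monte Carlo estimate of a policy gradient looks like, here is a sketch on a hypothetical two-armed bandit with a softmax policy, assuming fixed per-arm rewards. It uses the score-function (REINFORCE) estimator, averaging r * grad log pi(a) over sampled actions, and compares it with the exact gradient, which is still tractable in this tiny case.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def exact_policy_gradient(logits, arm_rewards):
    """dJ/dlogit_k = pi_k * (r_k - J) for J = sum_a pi_a * r_a (softmax bandit)."""
    probs = softmax(logits)
    j = sum(p * r for p, r in zip(probs, arm_rewards))
    return [p * (r - j) for p, r in zip(probs, arm_rewards)]


def mc_policy_gradient(logits, arm_rewards, n_samples, rng):
    """Score-function (REINFORCE) Monte Carlo estimate of the same gradient."""
    probs = softmax(logits)
    grad = [0.0] * len(logits)
    for _ in range(n_samples):
        a = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        r = arm_rewards[a]
        for k in range(len(logits)):
            # d log softmax(logits)[a] / d logit_k = 1[k == a] - pi_k
            grad[k] += r * ((1.0 if k == a else 0.0) - probs[k])
    return [g / n_samples for g in grad]


logits, arm_rewards = [0.0, 0.0], [1.0, 0.0]
print(exact_policy_gradient(logits, arm_rewards))  # [0.25, -0.25]
print(mc_policy_gradient(logits, arm_rewards, 20000, random.Random(0)))
```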
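The most popular approach is a hybrid of the two: actor-critic methods, where the agent's policy is optimized through policy gradients, while a value-based method is used as a bootstrap for the expected value estimates.
Deep Actor-Critic Methods
While much of the fundamental RL theory was developed for the tabular case, modern RL is almost exclusively done with function approximators, such as artificial neural networks. Specifically, an RL algorithm is considered "deep" if its policy and value functions are approximated with deep neural networks.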
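Asynchronous Advantage Actor-Critic
Over the years, a number of improvements have been added to address the sample efficiency and stability of the learning process.
First, gradients are weighted with returns: discounted future rewards. This somewhat alleviates the credit assignment problem and resolves theoretical issues with infinite time steps.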
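Discounted returns are usually computed backwards through a batch with G_t = r_t + gamma * G_{t+1}, resetting at episode boundaries. The sketch below assumes that `dones[t]` marks that the episode ended after step t, and that an optional `bootstrap_value` (for example, a critic's estimate) stands in for the return after the last collected step; the names are illustrative.

```python
def discounted_returns(rewards, dones, gamma, bootstrap_value=0.0):
    """G_t = r_t + gamma * G_{t+1}, cut off wherever an episode ended (done=1).

    bootstrap_value stands in for the (estimated) return after the last step,
    e.g. a critic's value estimate when a batch stops mid-episode.
    """
    returns = [0.0] * len(rewards)
    running = bootstrap_value
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running * (1.0 - dones[t])
        returns[t] = running
    return returns


print(discounted_returns([1, 1, 1], [0, 0, 0], gamma=0.5))  # [1.75, 1.5, 1.0]
print(discounted_returns([1, 1, 1], [0, 1, 0], gamma=0.5))  # [1.5, 1.0, 1.0]
```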
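Second, an advantage function is used instead of raw returns. The advantage is formed as the difference between the returns and some baseline, and can be seen as a measure of how good a given action is compared to some average.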
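To illustrate why a baseline helps, this sketch uses a hypothetical two-armed softmax bandit and weights each score-function sample by (reward - baseline), i.e. an advantage. Assuming rewards that are large and similar for both arms, the raw-return estimator is very noisy, while subtracting the average reward leaves the expected gradient unchanged and, in this deliberately simple case, removes the variance entirely.

```python
import random
import statistics


def gradient_samples(arm_rewards, probs, baseline, n, rng):
    """Per-sample REINFORCE estimates of dJ/dlogit_0 in a 2-armed softmax bandit,
    weighting each sample by (reward - baseline)."""
    out = []
    for _ in range(n):
        a = rng.choices([0, 1], weights=probs, k=1)[0]
        grad_log_pi_0 = (1.0 if a == 0 else 0.0) - probs[0]
        out.append((arm_rewards[a] - baseline) * grad_log_pi_0)
    return out


probs, arm_rewards = [0.5, 0.5], [11.0, 10.0]
avg_reward = sum(p * r for p, r in zip(probs, arm_rewards))  # 10.5, used as the baseline

raw = gradient_samples(arm_rewards, probs, 0.0, 20000, random.Random(0))
adv = gradient_samples(arm_rewards, probs, avg_reward, 20000, random.Random(0))
print(statistics.mean(raw), statistics.pvariance(raw))  # mean near 0.25, variance near 27.6
print(statistics.mean(adv), statistics.pvariance(adv))  # 0.25 0.0
```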
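Third, an additional entropy maximization term is used in the objective function to ensure the agent sufficiently explores various policies. In essence, entropy measures how random a probability distribution is, and it is maximized by the uniform distribution.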
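A quick numerical check of the entropy claim, as a sketch assuming a discrete distribution given as a plain list of probabilities: the uniform distribution over four actions reaches the maximum, log(4), while skewed or deterministic distributions score lower.

```python
import math


def entropy(probs):
    """H(p) = -sum_i p_i * log(p_i), in nats; terms with p_i = 0 contribute 0."""
    return sum(-p * math.log(p) for p in probs if p > 0)


uniform = [0.25] * 4
skewed = [0.7, 0.1, 0.1, 0.1]
deterministic = [1.0, 0.0, 0.0, 0.0]

print(entropy(uniform))        # log(4), about 1.386: the maximum for 4 actions
print(entropy(skewed))         # lower, about 0.94
print(entropy(deterministic))  # 0.0: no randomness at all
```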
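Finally, multiple workers are used in parallel to speed up sample collection, while also helping to decorrelate the samples during training.
Combining all of these changes with deep neural networks, we arrive at two of the most popular modern algorithms: (asynchronous) advantage actor-critic, A3C or A2C for short. The difference between the two is more technical than theoretical: as the name suggests, it boils down to how the parallel workers estimate their gradients and propagate them to the model.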
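The difference can be caricatured with a toy scalar parameter and a quadratic loss. This is a hypothetical illustration under simplifying assumptions (no networks, workers simulated sequentially), not how either algorithm is implemented in practice: A2C-style, the workers' gradients are averaged into one synchronous update; A3C-style, each worker applies its own gradient as soon as it is done, even though it was computed on parameters that have since changed.

```python
def worker_gradient(theta, batch):
    """Gradient of a toy loss 0.5 * mean((theta - x)^2) over one worker's samples."""
    return sum(theta - x for x in batch) / len(batch)


def a2c_style_update(theta, worker_batches, lr):
    """Synchronous: wait for all workers, average their gradients, apply once."""
    grads = [worker_gradient(theta, batch) for batch in worker_batches]
    return theta - lr * sum(grads) / len(grads)


def a3c_style_updates(theta, worker_batches, lr):
    """Asynchronous (simulated sequentially): every worker computed its gradient
    from the same parameter snapshot, then applies it on its own, so later
    updates use gradients that are stale w.r.t. the current parameters."""
    snapshot = theta
    for batch in worker_batches:
        theta = theta - lr * worker_gradient(snapshot, batch)
    return theta


batches = [[1.0], [3.0]]  # two workers, one sample each
print(a2c_style_update(0.0, batches, lr=0.5))   # 1.0
print(a3c_style_updates(0.0, batches, lr=0.5))  # 2.0
```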
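With this, I'll wrap up our tour of DRL methods, since the focus of this blog post is more on TensorFlow 2.0 features. Don't worry if you're still unsure about the subject; things should become clearer with the code examples. If you want to learn more, Spinning Up in Deep RL is a good resource to get started.
Advantage Actor-Critic with TensorFlow 2.0
Let's see what it takes to implement the basis of many modern DRL algorithms: an actor-critic agent. As described in the previous section, for simplicity's sake we won't implement parallel workers, though most of the code will support them. Interested readers can use this as an exercise opportunity.
As a testbed, we will use the CartPole-v0 environment. Although it is somewhat simplistic, it is still a great option to get started with. I always rely on it as a sanity check when implementing RL algorithms.
Policy & Value via the Keras Model API
First, let's create the policy and value estimation networks under a single model class: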
<span style="color:#f8f8f2"><code class="language-python">import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as kl


class ProbabilityDistribution(tf.keras.Model):
    def call(self, logits):
        # sample a random categorical action from given logits
        return tf.squeeze(tf.random.categorical(logits, 1), axis=-1)


class Model(tf.keras.Model):
    def __init__(self, num_actions):
        # pass the model name as a keyword argument
        super().__init__(name='mlp_policy')
        # no tf.get_variable(), just simple Keras API
        self.hidden1 = kl.Dense(128, activation='relu')
        self.hidden2 = kl.Dense(128, activation='relu')
        self.value = kl.Dense(1, name='value')
        # logits are unnormalized log probabilities
        self.logits = kl.Dense(num_actions, name='policy_logits')
        self.dist = ProbabilityDistribution()

    def call(self, inputs):
        # inputs is a numpy array, convert to Tensor
        x = tf.convert_to_tensor(inputs, dtype=tf.float32)
        # separate hidden layers from the same input tensor
        hidden_logs = self.hidden1(x)
        hidden_vals = self.hidden2(x)
        return self.logits(hidden_logs), self.value(hidden_vals)

    def action_value(self, obs):
        # executes call() under the hood
        logits, value = self.predict(obs)
        action = self.dist.predict(logits)
        # a simpler option, will become clear later why we don't use it
        # action = tf.random.categorical(logits, 1)
        return np.squeeze(action, axis=-1), np.squeeze(value, axis=-1)</code></span>
Let's verify that the model works as expected:
<span style="color:#f8f8f2"><code class="language-python">import gym

env = gym.make('CartPole-v0')
model = Model(num_actions=env.action_space.n)

obs = env.reset()
# no feed_dict or tf.Session() needed at all
action, value = model.action_value(obs[None, :])
print(action, value)  # e.g. [1] [-0.00145713]; values depend on the random initial weights</code></span>
Things to note here:
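- Model layers and the execution path are defined separately;
- There is no "input" layer; the model accepts raw numpy arrays;
- Two computation paths can be defined in one model via the functional API;
- A model can contain helper methods, such as action sampling;
- In eager mode, everything works directly from raw numpy arrays.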
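To make the two computation paths explicit, here is a hypothetical NumPy-only mirror of the forward pass of the Keras `Model` above. It assumes the Keras `Dense` defaults (Glorot-uniform kernels, zero biases) and CartPole's sizes (4-dimensional observations, 2 actions); it does no training and is only meant to show the shapes the two heads produce.

```python
import numpy as np


def init_dense(rng, n_in, n_out):
    """Glorot-uniform kernel and zero bias (the Keras Dense defaults)."""
    limit = np.sqrt(6.0 / (n_in + n_out))
    return rng.uniform(-limit, limit, size=(n_in, n_out)), np.zeros(n_out)


def dense(x, params, relu=False):
    w, b = params
    y = x @ w + b
    return np.maximum(y, 0.0) if relu else y


class NumpyActorCritic:
    """Hypothetical NumPy mirror of the Keras Model above (forward pass only)."""

    def __init__(self, obs_dim, num_actions, hidden=128, seed=0):
        rng = np.random.default_rng(seed)
        self.hidden1 = init_dense(rng, obs_dim, hidden)  # policy path
        self.hidden2 = init_dense(rng, obs_dim, hidden)  # value path
        self.logits = init_dense(rng, hidden, num_actions)
        self.value = init_dense(rng, hidden, 1)

    def __call__(self, obs):
        x = np.asarray(obs, dtype=np.float32)
        # two separate hidden paths fed by the same input
        hidden_logs = dense(x, self.hidden1, relu=True)
        hidden_vals = dense(x, self.hidden2, relu=True)
        return dense(hidden_logs, self.logits), dense(hidden_vals, self.value)


model = NumpyActorCritic(obs_dim=4, num_actions=2)  # CartPole: 4-dim obs, 2 actions
logits, value = model(np.array([[0.01, -0.02, 0.03, 0.04]]))
print(logits.shape, value.shape)  # (1, 2) (1, 1)
```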
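Random Agent
Now we can move on to the fun stuff: the A2CAgent class. First, let's add a test method that runs through a full episode and returns the sum of rewards.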
<span style="color:#f8f8f2"><code class="language-python">class A2CAgent:
    def __init__(self, model):
        self.model = model

    def test(self, env, render=True):
        obs, done, ep_reward = env.reset(), False, 0
        while not done:
            action, _ = self.model.action_value(obs[None, :])
            obs, reward, done, _ = env.step(action)
            ep_reward += reward
            if render:
                env.render()
        return ep_reward</code></span>
Let's see how much our model scores with randomly initialized weights:
<span style="color:#f8f8f2"><code class="language-python">agent = A2CAgent(model)
rewards_sum = agent.test(env)
print("%d out of 200" % rewards_sum)  # 18 out of 200</code></span>
Not even close to optimal; time to get to the training part!
Loss / Objective Function
As I described in the DRL overview section, an agent improves its policy through gradient descent based on some loss (objective) function. In actor-critic, we train on three objectives: improving the policy with advantage-weighted gradients plus entropy maximization, and minimizing the value estimation errors.
<span style="color:#f8f8f2"><code class="language-python">import tensorflow as tf
import tensorflow.keras.losses as kls
import tensorflow.keras.optimizers as ko


class A2CAgent:
    def __init__(self, model):
        # hyperparameters for loss terms
        self.params = {'value': 0.5, 'entropy': 0.0001}
        self.model = model
        self.model.compile(
            optimizer=ko.RMSprop(learning_rate=0.0007),
            # define separate losses for policy logits and value estimate
            loss=[self._logits_loss, self._value_loss]
        )

    def test(self, env, render=True):
        # unchanged from previous section
        ...

    def _value_loss(self, returns, value):
        # value loss is typically MSE between value estimates and returns
        return self.params['value'] * kls.mean_squared_error(returns, value)

    def _logits_loss(self, acts_and_advs, logits):
        # a trick to input actions and advantages through the same API
        actions, advantages = tf.split(acts_and_advs, 2, axis=-1)
        # sparse CE loss object that supports the sample_weight argument;
        # from_logits ensures transformation into normalized probabilities
        weighted_sparse_ce = kls.SparseCategoricalCrossentropy(from_logits=True)
        # policy loss is defined by policy gradients, weighted by advantages
        # note: we only calculate the loss on the actions we've actually taken
        actions = tf.cast(actions, tf.int32)
        policy_loss = weighted_sparse_ce(actions, logits, sample_weight=advantages)
        # entropy can be calculated as the cross-entropy of the action
        # probabilities with themselves (normalized probabilities, not raw logits)
        probs = tf.nn.softmax(logits)
        entropy_loss = kls.categorical_crossentropy(probs, probs)
        # we want to minimize the policy loss and maximize entropy;
        # signs are flipped here because the optimizer minimizes
        return policy_loss - self.params['entropy'] * entropy_loss</code></span>
We're done with the objective functions! Note how compact the code is: there are almost more comment lines than code.
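The Agent Training Loop
Finally, there is the training loop itself. It is relatively long, but fairly straightforward: collect samples, calculate returns and advantages, and train the model on them.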
<span style="color:#f8f8f2"><code class="language-python">class A2CAgent:
    def __init__(self, model):
        # hyperparameters for loss terms
        self.params = {'value': 0.5, 'entropy': 0.0001, 'gamma': 0.99}
        # unchanged from previous section
        ...

    def train(self, env, batch_sz=32, updates=1000):
        # storage helpers for a single batch of data
        actions = np.empty((batch_sz,), dtype=np.int32)
        rewards, dones, values = np.empty((3, batch_sz))
        observations = np.empty((batch_sz,) + env.observation_space.shape)
        # training loop: collect samples, send to optimizer, repeat updates times
        ep_rews = [0.0]
        next_obs = env.reset()
        for update in range(updates):
            for step in range(batch_sz):
                observations[step] = next_obs.copy()
                actions[step], values[step] = self.model.action_value(next_obs[None, :])
                next_obs, rewards[step], dones[step], _ = env.step(actions[step])

                ep_rews[-1] += rewards[step]
                if dones[step]:
                    ep_rews.append(0.0)
                    next_obs = env.reset()

            _, next_value = self.model.action_value(next_obs[None, :]<span
style="color:#f8f8f2">)</span>returns<span style="color:#f8f8f2">,</span> advs <span style="color:#f8f8f2">=</span> self<span style="color:#f8f8f2">.</span>_returns_advantages<span style="color:#f8f8f2">(</span>rewards<span style="color:#f8f8f2">,</span> dones<span style="color:#f8f8f2">,</span> values<span style="color:#f8f8f2">,</span> next_value<span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># a trick to input actions and advantages through same API</span></span>acts_and_advs <span style="color:#f8f8f2">=</span> np<span style="color:#f8f8f2">.</span>concatenate<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">[</span>actions<span style="color:#f8f8f2">[</span><span style="color:#f8f8f2">:</span><span style="color:#f8f8f2">,</span> <span style="color:#f92672">None</span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">,</span> advs<span style="color:#f8f8f2">[</span><span style="color:#f8f8f2">:</span><span style="color:#f8f8f2">,</span> <span style="color:#f92672">None</span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">,</span> axis<span style="color:#f8f8f2">=</span><span style="color:#f8f8f2"><span style="color:#ae81ff">-</span></span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># performs a full training step on the collected batch</span></span><span style="color:slategray"><span style="color:#75715e"># note: no need to mess around with gradients, Keras API handles it</span></span>losses <span style="color:#f8f8f2">=</span> self<span style="color:#f8f8f2">.</span>model<span style="color:#f8f8f2">.</span>train_on_batch<span style="color:#f8f8f2">(</span>observations<span style="color:#f8f8f2">,</span> <span style="color:#f8f8f2">[</span>acts_and_advs<span style="color:#f8f8f2">,</span> returns<span 
style="color:#f8f8f2">]</span><span style="color:#f8f8f2">)</span><span style="color:#66d9ef"><span style="color:#f92672">return</span></span> ep_rews<span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">_returns_advantages</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> rewards</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> dones</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> values</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> next_value</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># next_value is the bootstrap value estimate of a future state (the critic)</span></span>returns <span style="color:#f8f8f2">=</span> np<span style="color:#f8f8f2">.</span>append<span style="color:#f8f8f2">(</span>np<span style="color:#f8f8f2">.</span>zeros_like<span style="color:#f8f8f2">(</span>rewards<span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">,</span> next_value<span style="color:#f8f8f2">,</span> axis<span style="color:#f8f8f2">=</span><span style="color:#f8f8f2"><span style="color:#ae81ff">-</span></span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># returns are calculated as discounted sum of future rewards</span></span><span style="color:#66d9ef"><span style="color:#f92672">for</span></span> t <span style="color:#66d9ef"><span style="color:#f92672">in</span></span> reversed<span 
style="color:#f8f8f2">(</span>range<span style="color:#f8f8f2">(</span>rewards<span style="color:#f8f8f2">.</span>shape<span style="color:#f8f8f2">[</span><span style="color:#ae81ff"><span style="color:#ae81ff">0</span></span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">:</span>returns<span style="color:#f8f8f2">[</span>t<span style="color:#f8f8f2">]</span> <span style="color:#f8f8f2">=</span> rewards<span style="color:#f8f8f2">[</span>t<span style="color:#f8f8f2">]</span> <span style="color:#f8f8f2">+</span> self<span style="color:#f8f8f2">.</span>params<span style="color:#f8f8f2">[</span><span style="color:#a6e22e"><span style="color:#e6db74">'gamma'</span></span><span style="color:#f8f8f2">]</span> <span style="color:#f8f8f2">*</span> returns<span style="color:#f8f8f2">[</span>t<span style="color:#f8f8f2">+</span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">]</span> <span style="color:#f8f8f2">*</span> <span style="color:#f8f8f2">(</span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">-</span>dones<span style="color:#f8f8f2">[</span>t<span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">)</span>returns <span style="color:#f8f8f2">=</span> returns<span style="color:#f8f8f2">[</span><span style="color:#f8f8f2">:</span><span style="color:#f8f8f2"><span style="color:#ae81ff">-</span></span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">]</span><span style="color:slategray"><span style="color:#75715e"># advantages are returns - baseline, value estimates in our case</span></span>advantages <span style="color:#f8f8f2">=</span> returns <span style="color:#f8f8f2">-</span> values<span style="color:#66d9ef"><span style="color:#f92672">return</span></span> returns<span style="color:#f8f8f2">,</span> advantages<span 
style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">test</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> env</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> render</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">=</span></span><span style="color:#ae81ff"><span style="color:#f8f8f2">True</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># unchanged from previous section</span></span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">_value_loss</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> returns</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> value</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># unchanged from previous section</span></span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">_logits_loss</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span 
style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> acts_and_advs</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> logits</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># unchanged from previous section</span></span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span></code></span>訓練和結果
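Before starting a full training run, it can help to sanity-check the batch preparation on a tiny hand-made example. The sketch below is an illustrative assumption and is not part of the original source. It copies the logic of `_returns_advantages` into a standalone NumPy function (the name `returns_advantages` and all toy numbers are made up). It then packs actions and advantages the same way `acts_and_advs` is built in `train`. The code works backwards from the bootstrap value, computing each return as `r_t + gamma * G_{t+1} * (1 - done_t)`, so a `done` flag stops value from leaking across an episode boundary.

```python
import numpy as np


def returns_advantages(rewards, dones, values, next_value, gamma=0.99):
    """Standalone copy of A2CAgent._returns_advantages (illustrative sketch)."""
    # Extra last slot holds the critic's bootstrap estimate for the next state.
    returns = np.append(np.zeros_like(rewards), next_value, axis=-1)
    # Backwards pass: G_t = r_t + gamma * G_{t+1} * (1 - done_t).
    for t in reversed(range(rewards.shape[0])):
        returns[t] = rewards[t] + gamma * returns[t + 1] * (1 - dones[t])
    returns = returns[:-1]
    # Advantage = return minus the value baseline.
    advantages = returns - values
    return returns, advantages


# Toy batch of 4 steps; the episode ends at step 1 (dones[1] == 1).
rewards = np.array([1.0, 1.0, 1.0, 1.0])
dones = np.array([0.0, 1.0, 0.0, 0.0])
values = np.array([0.5, 0.5, 0.5, 0.5])
next_value = np.array([2.0])  # bootstrap value of the state after the batch

returns, advs = returns_advantages(rewards, dones, values, next_value, gamma=0.5)
print(returns.tolist())  # [1.5, 1.0, 2.0, 2.0]
print(advs.tolist())     # [1.0, 0.5, 1.5, 1.5]

# Pack actions and advantages into one (batch, 2) array, as train() does.
# Concatenation upcasts the int32 actions to float64, so cast them back when unpacking.
actions = np.array([0, 1, 1, 0], dtype=np.int32)
acts_and_advs = np.concatenate([actions[:, None], advs[:, None]], axis=-1)
acts = acts_and_advs[:, 0].astype(np.int32)
packed_advs = acts_and_advs[:, 1]
```

Step 1 ends its episode, so its return is just its own reward (1.0). Steps 2 and 3 bootstrap from `next_value` instead.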
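We're now all set to train our single-worker A2C agent on CartPole-v0! Training shouldn't take more than a few minutes. Once it finishes, you should see the agent reach the target score of 200 out of 200.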
<span style="color:#f8f8f2"><code class="language-python">rewards_history = agent.train(env)
print("Finished training, testing...")
print("%d out of 200" % agent.test(env)) # 200 out of 200</code></span>
In the source code, I include some additional helpers that print out running episode rewards and losses, along with a basic plotter for rewards_history.
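Those helpers aren't reproduced here. As a rough stand-in, here is a minimal sketch of the kind of smoothing a rewards_history plot commonly applies: a trailing moving average over the last few episodes. The function name `running_mean`, the window size and the sample numbers are assumptions for illustration. Drawing the curve (for example with matplotlib) is left out so the snippet only depends on NumPy. One detail comes straight from `train`: its returned list ends with the still-unfinished episode, so it may be worth dropping that last entry before plotting.

```python
import numpy as np


def running_mean(episode_rewards, window=10):
    """Trailing moving average of per-episode rewards (hypothetical helper)."""
    rewards = np.asarray(episode_rewards, dtype=np.float64)
    smoothed = np.empty_like(rewards)
    for i in range(len(rewards)):
        # Average over the last `window` episodes (fewer at the very start).
        start = max(0, i + 1 - window)
        smoothed[i] = rewards[start:i + 1].mean()
    return smoothed


# Made-up history; the final 7.0 stands for the unfinished episode train() leaves at the end.
rewards_history = [10.0, 20.0, 30.0, 40.0, 7.0]
completed = rewards_history[:-1]
print(running_mean(completed, window=2).tolist())  # [10.0, 15.0, 25.0, 35.0]
```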
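Static Computational Graph
With all this eager-mode success, you might wonder whether static graph execution is still possible. Of course it is! What's more, enabling it takes just one more line of code!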
<span style="color:#f8f8f2"><code class="language-python">with tf.Graph().as_default():
    print(tf.executing_eagerly()) # False

    model = Model(num_actions=env.action_space.n)
    agent = A2CAgent(model)

    rewards_history = agent.train(env)
    print("Finished training, testing...")
    print("%d out of 200" % agent.test(env)) # 200 out of 200</code></span>
One caveat: during static graph execution we can't just leave Tensors lying around, which is why we needed the CategoricalDistribution trick during model definition. In fact, while looking for a way to execute in static mode, I discovered an interesting low-level detail about models built through the Keras API.
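One More Thing…
Remember when I said TensorFlow runs in eager mode by default, and even proved it with a code snippet? Well, I was wrong.
If you use the Keras API to build and manage your models, it will attempt to compile them into static graphs under the hood. So what you end up with is the performance of static computational graphs combined with the flexibility of eager execution.
You can check your model's status via the model.run_eagerly flag. You can also force eager mode by setting this flag to True, though most of the time you probably won't need to. If Keras detects that there is no way around eager mode, it will fall back to it on its own.
To show that the model really does run as a static graph, here is a simple benchmark: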
<span style="color:#f8f8f2"><code class="language-python"># create a batch of 100000 samples
env = gym.make('CartPole-v0')
obs = np.repeat(env.reset()[None, :], 100000, axis=0)</code></span>
Eager Benchmark
<span style="color:#f8f8f2"><code class="language-python">%%time

model = Model(env.action_space.n)
model.run_eagerly = True

print("Eager Execution: ", tf.executing_eagerly())
print("Eager Keras Model:", model.run_eagerly)

_ = model(obs)

######## Results #######
# Eager Execution: True
# Eager Keras Model: True
# CPU times: user 639 ms, sys: 736 ms, total: 1.38 s</code></span>
Static Benchmark
<span style="color:#f8f8f2"><code class="language-none">%%time

with tf.Graph().as_default():
    model = Model(env.action_space.n)

    print("Eager Execution: ", tf.executing_eagerly())
    print("Eager Keras Model:", model.run_eagerly)

    _ = model.predict(obs)

######## Results #######
# Eager Execution: False
# Eager Keras Model: False
# CPU times: user 793 ms, sys: 79.7 ms, total: 873 ms</code></span>
Default Benchmark
<span style="color:#333333"><span style="color:#f8f8f2"><code class="language-none">%%time

model = Model(env.action_space.n)

print("Eager Execution: ", tf.executing_eagerly())
print("Eager Keras Model:", model.run_eagerly)

_ = model.predict(obs)

######## Results #######
# Eager Execution: True
# Eager Keras Model: False
# CPU times: user 994 ms, sys: 23.1 ms, total: 1.02 s</code></span></span>
As you can see, eager mode lags behind static mode (1.38 s vs. 873 ms total). With the default settings, tf.executing_eagerly() reports True, but model.run_eagerly is False. So by default our Keras model was indeed executing statically.
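`%%time` is an IPython/Jupyter cell magic, so the benchmark snippets above only work as notebook cells. In a plain Python script, a comparable measurement can be sketched with the standard library's `time.perf_counter`. The `timed` context manager below is a hypothetical helper and not part of the original source. Note that it reports wall-clock time, whereas `%%time` breaks out CPU user/sys times, so its numbers are not directly comparable with the results above.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, results):
    """Record the wall-clock seconds spent inside the `with` block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Runs even if the block raises, so failed runs are still timed.
        results[label] = time.perf_counter() - start
        print(f"{label}: {results[label] * 1000:.1f} ms")


# Usage sketch: in the real benchmark the body would be model(obs) or
# model.predict(obs); a dummy CPU-bound workload stands in here.
timings = {}
with timed("dummy workload", timings):
    total = sum(i * i for i in range(100_000))
```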
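Conclusion
Hopefully this article has helped you understand both DRL and TensorFlow 2.0. Keep in mind that TensorFlow 2.0 is still only a preview build, not even a release candidate, so everything is subject to change. If there is something about TensorFlow you particularly dislike, let its developers know!
A lingering question people might have is: is TensorFlow better than PyTorch? Maybe, maybe not. Both are great libraries, so it's hard to say one way or the other. If you're familiar with PyTorch, you've probably noticed that TensorFlow 2.0 has not only caught up with it but also avoided some of the PyTorch API's pitfalls.
Either way, this competition has already produced positive results for developers on both sides, and I'm excited to see what the frameworks will look like in the future.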
This article is original content from the Yunqi Community and may not be reproduced without permission.