多智能体连续行为空间问题求解——MADDPG
目錄
- 1. 問(wèn)題出現(xiàn):連續(xù)行為空間出現(xiàn)
- 2. DDPG 算法
- 2.1 DDPG 算法原理
- 2.2 DDPG 算法實(shí)現(xiàn)代碼
- 2.2.1 Actor & Critic
- 2.2.2 Target Network
- 2.2.3 Memory Pool
- 2.2.4 Update Parameters(evaluate network)
- 2.2.5 Update Parameters(target network)
- 3. MADDPG 算法
- 3.1 Actor 網(wǎng)絡(luò)定義
- 3.2 Critic 網(wǎng)絡(luò)定義
- 3.3 Update Parameters 過(guò)程
MADDPG 是一種針對(duì)多智能體、連續(xù)行為空間設(shè)計(jì)的算法。MADDPG 的前身是DDPG,DDPG 算法旨在解決連續(xù)性行為空間的強(qiáng)化學(xué)習(xí)問(wèn)題,而 MADDPG 是在 DDPG 的基礎(chǔ)上做了改進(jìn),使其能夠適用于多智能體之間的合作任務(wù)學(xué)習(xí)。本文先從 DDPG 引入,接著再介紹如何在 DDPG 算法上進(jìn)行修改使其變成 MADDPG 算法。
1. 問(wèn)題出現(xiàn):連續(xù)行為空間出現(xiàn)
Q-Learning 算法是強(qiáng)化學(xué)習(xí)中一種常用的方法,但傳統(tǒng)的 Q-Learning 需要枚舉所有的狀態(tài)空間并建立 Q-Table,為了解決龐大不可枚舉的狀態(tài)空間問(wèn)題,DQN 被人們?cè)O(shè)計(jì)出來(lái),利用神經(jīng)網(wǎng)絡(luò)近似擬合的方法來(lái)避免了窮舉所有可能的狀態(tài)空間。但 DQN 算法有一個(gè)問(wèn)題,那就是在計(jì)算當(dāng)前 Q 值的時(shí)候需要求出下一個(gè)狀態(tài)中每一個(gè)動(dòng)作的值函數(shù),選擇最大的動(dòng)作值函數(shù)值來(lái)進(jìn)行計(jì)算。
Qπ(st,at)=R(st,at)+γmaxaQπ(st+1,at+1)Q^{\pi}(s_t, a_t) = R(s_t, a_t) + \gamma max_aQ^{\pi}(s_{t+1}, a_{t+1}) Qπ(st?,at?)=R(st?,at?)+γmaxa?Qπ(st+1?,at+1?)
在 Actor-Critic 算法中同樣會(huì)面臨這個(gè)問(wèn)題,更新 critic 網(wǎng)絡(luò)時(shí)候需要計(jì)算下一個(gè)狀態(tài)下所有行為的Q值并取其平均值,計(jì)算公式如下:
Qπ(st,at)=R(st,at)+γEπ[Qπ(st+1,at+1)]Q^{\pi}(s_t, a_t) = R(s_t, a_t) + \gamma E_{\pi}[Q^{\pi}(s_{t+1}, a_{t+1})] Qπ(st?,at?)=R(st?,at?)+γEπ?[Qπ(st+1?,at+1?)]
其中 Eπ[Qπ(st+1,at+1)]E_{\pi}[Q^{\pi}(s_{t+1}, a_{t+1})]Eπ?[Qπ(st+1?,at+1?)] 是枚舉所有動(dòng)作的得分效用并乘上對(duì)應(yīng)動(dòng)作的選取概率(當(dāng)然在 AC 中可以直接通過(guò)擬合一個(gè) V(s)V(s)V(s) 來(lái)近似替代枚舉結(jié)果)。那么不管是 DQN 還是 AC 算法,都涉及到需要計(jì)算整個(gè)行為空間中所有行為的效用值,一旦行為空間演變?yōu)檫B續(xù)型的就無(wú)法使用以上算法,因?yàn)闊o(wú)法窮舉所有的行為并計(jì)算所有行為的值之和了。為此,在解決連續(xù)行為空間問(wèn)題的時(shí)候,我們需要一種新的算法,能夠不用窮舉所有行為的值就能完成算法更新,DDPG 的出現(xiàn)很好的解決了這個(gè)問(wèn)題。
2. DDPG 算法
2.1 DDPG 算法原理
DPG(Deterministic Policy Gradient)算法是一種 “確定性行為策略” 算法,我們之前問(wèn)題的難點(diǎn)在于對(duì)于連續(xù)的龐大行為空間,我們無(wú)法一一枚舉所有可能的行為。因此,DPG 認(rèn)為,在求取下一個(gè)狀態(tài)的狀態(tài)值時(shí),我們沒(méi)有必要去計(jì)算所有可能的行為值并跟據(jù)每個(gè)行為被采取的概率做加權(quán)平均,我們只需要認(rèn)為在一個(gè)狀態(tài)下只有可能采取某一個(gè)確定的行為 aaa,即該行為 aaa 被采取的概率為百分之百,這樣就行了,于是整個(gè) Q 值計(jì)算函數(shù)就變成了:
Qμ(st,at)=R(st,at)+γQμ(st+1,μ(st+1))]Q^{\mu}(s_t, a_t) = R(s_t, a_t) + \gamma Q^{\mu}(s_{t+1}, \mu{(s_{t+1})})] Qμ(st?,at?)=R(st?,at?)+γQμ(st+1?,μ(st+1?))]
即,原本的行為 aaa 是由隨機(jī)策略 π\(zhòng)piπ 進(jìn)行概率選擇,而現(xiàn)在這個(gè)行為由一個(gè)確定性策略 μ\muμ 來(lái)選擇,確定性策略是指只要輸入一個(gè)狀態(tài)就一定能得到唯一一個(gè)確定的輸出行為,而隨機(jī)性策略指的是輸入一個(gè)狀態(tài),輸出的是整個(gè)行為空間的所有行為概率分布。DDPG 是 DPG 算法上融合進(jìn)神經(jīng)網(wǎng)絡(luò)技術(shù),變成了 Deep Deterministic Policy Gradient,其整體思路和 DPG 是一致的。
2.2 DDPG 算法實(shí)現(xiàn)代碼
DDPG 沿用了 Actor-Critic 算法結(jié)構(gòu),在代碼中也存在一個(gè) Actor 和一個(gè) Critic,Actor 負(fù)責(zé)做行為決策,而 Critic 負(fù)責(zé)做行為效用評(píng)估,這里使用 DDPG 學(xué)習(xí)玩 gym 中一個(gè)倒立擺的游戲,游戲中的 action 為順時(shí)針或逆時(shí)針的旋轉(zhuǎn)力度,旋轉(zhuǎn)力度是一個(gè)連續(xù)行為,力的大小是一個(gè)連續(xù)的隨機(jī)變量,最終期望能夠通過(guò)不斷學(xué)習(xí)后算法能夠?qū)W會(huì)如何讓桿子倒立在上面靜止不動(dòng),如下圖所示:
2.2.1 Actor & Critic
我們先來(lái)看看在 DDPG 中 Actor 和 Critic 分別是怎么實(shí)現(xiàn)的的,Actor 和 Critic 的定義如下(代碼參考自這里):
class Actor(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Actor, self).__init__()self.linear1 = nn.Linear(input_size, hidden_size)self.linear2 = nn.Linear(hidden_size, hidden_size)self.linear3 = nn.Linear(hidden_size, output_size)def forward(self, s):x = F.relu(self.linear1(s))x = F.relu(self.linear2(x))x = torch.tanh(self.linear3(x))return xclass Critic(nn.Module):def __init__(self, input_size, hidden_size, output_size):super().__init__()self.linear1 = nn.Linear(input_size, hidden_size)self.linear2 = nn.Linear(hidden_size, hidden_size)self.linear3 = nn.Linear(hidden_size, output_size)def forward(self, s, a):x = torch.cat([s, a], 1) # DDPG與普通AC算法的不同之處x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))x = self.linear3(x)return xActor 的設(shè)計(jì)和以往相同,沒(méi)什么太大變化。
Critic 的實(shí)現(xiàn)有了一些改變,在 forward 函數(shù)中,原始的 critic 只用傳入狀態(tài)sss,輸出所有動(dòng)作的效用值,但由于這是連續(xù)動(dòng)作空間,無(wú)法輸出每一個(gè)行為的值,因此 critic 網(wǎng)絡(luò)改為接收一個(gè)狀態(tài) sss 和一個(gè)具體行為 aaa 作為輸入,輸出的是具體行為 aaa 在當(dāng)前狀態(tài) sss 下的效用值,即 critic 網(wǎng)絡(luò)輸出維度為1。
2.2.2 Target Network
除了在 critic 網(wǎng)絡(luò)上有了改變之外,DDPG 在整個(gè)算法層面上也做了修改。DDPG 參照了 DQN 的方式,為了算法添加了 target network,即固定住一個(gè) target 網(wǎng)絡(luò)產(chǎn)生樣本,另一個(gè) evaluate 網(wǎng)絡(luò)不斷更新迭代的思想,因此整個(gè)算法包含 4 個(gè)網(wǎng)絡(luò):
actor = Actor(s_dim, 256, a_dim) actor_target = Actor(s_dim, 256, a_dim) critic = Critic(s_dim+a_dim, 256, a_dim) # 輸入維度是 狀態(tài)空間 + 行為空間 critic_target = Critic(s_dim+a_dim, 256, a_dim)值得注意的是,在上述 critic 網(wǎng)絡(luò)中輸入的是 s_dim + a_dim,為什么是加 a_dim 呢?因?yàn)樵?DDPG 算法中,critic 網(wǎng)絡(luò)評(píng)判的是一組行為的效用值,即如果有(油門、方向盤)這兩個(gè)行為的話,那么傳入的應(yīng)該是(油門大小、方向盤轉(zhuǎn)動(dòng)度數(shù))這一組行為,critic 網(wǎng)絡(luò)對(duì)這一組動(dòng)作行為做一個(gè)效用評(píng)判。
2.2.3 Memory Pool
之前提到 DDPG 算法借用了 DQN 思想,除了加入了 Target 網(wǎng)絡(luò)之外還引入了 Memory Pool 機(jī)制,將收集到的歷史經(jīng)驗(yàn)存放到記憶庫(kù)中,在更新的時(shí)候取一個(gè) batch 的數(shù)據(jù)來(lái)計(jì)算均值,memory pool 代碼如下:
# 經(jīng)驗(yàn)池 buffer = []# 往經(jīng)驗(yàn)池存放經(jīng)驗(yàn)數(shù)據(jù) def put(self, *transition): if len(self.buffer)== self.capacity:self.buffer.pop(0)self.buffer.append(transition)2.2.4 Update Parameters(evaluate network)
在定義好了這些結(jié)構(gòu)之后,我們就開(kāi)始看看如何進(jìn)行梯度更新吧。所需要更新參數(shù)的網(wǎng)絡(luò)一共有 4 個(gè),2 個(gè) target network 和 2 個(gè) evaluate network,target network 的更新是在訓(xùn)練迭代了若干輪后將 evaluate network 當(dāng)前的參數(shù)值復(fù)制過(guò)去即可,只不過(guò)這里并不是直接復(fù)制,會(huì)做一些處理,這里我們先來(lái)看 evaluate network 是如何進(jìn)行參數(shù)更新的,actor 和 critic 的更新代碼如下 :
def critic_learn():a1 = self.actor_target(s1).detach()y_true = r1 + self.gamma * self.critic_target(s1, a1).detach() # 下一個(gè)狀態(tài)的目標(biāo)狀態(tài)值y_pred = self.critic(s0, a0) # 下一個(gè)狀態(tài)的預(yù)測(cè)狀態(tài)值loss_fn = nn.MSELoss()loss = loss_fn(y_pred, y_true)self.critic_optim.zero_grad()loss.backward()self.critic_optim.step()def actor_learn():loss = -torch.mean( self.critic(s0, self.actor(s0)) )self.actor_optim.zero_grad()loss.backward()self.actor_optim.step()我們先來(lái)看 critic 的 learn 函數(shù),loss 函數(shù)比較的是 用當(dāng)前網(wǎng)絡(luò)預(yù)測(cè)當(dāng)前狀態(tài)的Q值 和 利用回報(bào)R與下一狀態(tài)的狀態(tài)值之和 之間的 error 值,現(xiàn)在問(wèn)題在于下一個(gè)狀態(tài)的狀態(tài)值如何計(jì)算,在 DDPG 算法中由于確定了在一種狀態(tài)下只會(huì)以100%的概率去選擇一個(gè)確定的動(dòng)作,因此在計(jì)算下一個(gè)狀態(tài)的狀態(tài)值的時(shí)候,直接根據(jù) actor 網(wǎng)絡(luò)輸出一個(gè)在下一個(gè)狀態(tài)會(huì)采取的行為,把這個(gè)行為當(dāng)作100%概率的確定行為,并根據(jù)這個(gè)行為和下一刻的狀態(tài)輸入 critic 網(wǎng)絡(luò)得到下一個(gè)狀態(tài)的狀態(tài)值,最后通過(guò)計(jì)算這兩個(gè)值的差來(lái)進(jìn)行反向梯度更新(TD-ERROR)。
再來(lái)看看 actor 的 learn 函數(shù),actor 還是普通的更新思路 —— actor 選擇一個(gè)可能的行為,通過(guò) reward 來(lái)決定增加選取這個(gè) action 的概率還是降低選擇這個(gè) action 的概率。而增加/減少概率的多少由 critic 網(wǎng)絡(luò)來(lái)決定,若 critic 網(wǎng)絡(luò)評(píng)判出來(lái)當(dāng)前狀態(tài)下采取當(dāng)前行為會(huì)得到一個(gè)非常高的正效用值,那么梯度更新后 actor 下次采取這個(gè)行為的概率就會(huì)大幅度增加。而傳統(tǒng)的 actor 在進(jìn)行行為選擇時(shí)神經(jīng)網(wǎng)絡(luò)會(huì)輸出每一個(gè)行為的被采取概率,按照這些概率來(lái)隨機(jī)選擇一個(gè)行為,但在 DDPG 算法中,所有行為都是被確定性選擇的,不會(huì)存在隨機(jī)性,因此在代碼中傳入的是經(jīng)過(guò) actor 后得到的輸出行為,認(rèn)為該行為就是100%被確定性選擇的,沒(méi)有之前的按概率選擇行為這一個(gè)環(huán)節(jié)了。 選好行為后和當(dāng)前狀態(tài)一起傳給 critic 網(wǎng)絡(luò)做效用值評(píng)估。
2.2.5 Update Parameters(target network)
Target Network 在 DDPG 算法中沿用了 DQN 的思路,在迭代一定的輪數(shù)后,會(huì)從 evaluate network 中 copy 參數(shù)到自身網(wǎng)絡(luò)中去。但是不同的是,DDPG 在進(jìn)行參數(shù)復(fù)制的時(shí)候選擇的是 soft update 的方式,即,在進(jìn)行參數(shù)復(fù)制的時(shí)候不是進(jìn)行直接復(fù)制值,而是將 target net 和 evaluate net 的參數(shù)值以一定的權(quán)重值加起來(lái),融合成新的網(wǎng)絡(luò)參數(shù),代碼如下:
def soft_update(net_target, net, tau):for target_param, param in zip(net_target.parameters(), net.parameters()):target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)參數(shù) tau 是保留程度參數(shù),tau 值越大則保留的原網(wǎng)絡(luò)的參數(shù)的程度越大。
3. MADDPG 算法
在理解了 DDPG 算法后,理解 MADDPG 就比較容易了。MADDPG 是 Multi-Agent 下的 DDPG 算法,主要針對(duì)于多智能體之間連續(xù)行為進(jìn)行求解。MADDPG 同樣沿用了 AC 算法的架構(gòu),和 DDPG 相比只是在 Critic 網(wǎng)絡(luò)上的輸入做了一些額外信息的添加,下面結(jié)合實(shí)際代碼來(lái)分析:
3.1 Actor 網(wǎng)絡(luò)定義
class Actor(nn.Module):def __init__(self, args, agent_id):""" 網(wǎng)絡(luò)層定義部分 """super(Actor, self).__init__()self.fc1 = nn.Linear(args.obs_shape[agent_id], 64) # 定義輸入維度self.fc2 = nn.Linear(64, 64)self.fc3 = nn.Linear(64, 64)self.action_out = nn.Linear(64, args.action_shape[agent_id]) # 定義輸出維度def forward(self, x):""" 網(wǎng)絡(luò)前向傳播過(guò)程定義 """x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.relu(self.fc3(x))actions = torch.tanh(self.action_out(x))return actions上面是 MADDPG 中 actor 網(wǎng)絡(luò)的定義代碼,由于一個(gè)場(chǎng)景中可能存在多種不同的智能體,其觀測(cè)空間維度與行為空間維度都不盡相同,因此在進(jìn)行 actor 定義時(shí)需傳入每個(gè)智能體自身所符合的維度信息,如上述代碼一樣,通過(guò) agent_id 來(lái)獲取具體的智能體信息,前向傳播過(guò)程與 DDPG 相同,沒(méi)有什么特殊之處。
3.2 Critic 網(wǎng)絡(luò)定義
class Critic(nn.Module):def __init__(self, args):super(Critic, self).__init__()self.max_action = args.high_actionself.fc1 = nn.Linear(sum(args.obs_shape) + sum(args.action_shape), 64) # 定義輸入層維度(聯(lián)合觀測(cè)+聯(lián)合行為)self.fc2 = nn.Linear(64, 64)self.fc3 = nn.Linear(64, 64)self.q_out = nn.Linear(64, 1)def forward(self, state, action):state = torch.cat(state, dim=1) # 聯(lián)合觀測(cè)action = torch.cat(action, dim=1) # 聯(lián)合行為x = torch.cat([state, action], dim=1) # 聯(lián)合觀測(cè) + 聯(lián)合行為x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.relu(self.fc3(x))q_value = self.q_out(x)return q_valueCritic 的代碼如上,可見(jiàn) MADDPG 中的 Critic 是一個(gè)中心化網(wǎng)絡(luò),即傳入的不只是當(dāng)前 Agent 的(s,a)信息,還加入了其他 Agent 的(s,a)信息。這種做法在多智能體算法中不算新奇了,在訓(xùn)練學(xué)習(xí)階段利用中心化的評(píng)價(jià)網(wǎng)絡(luò)來(lái)指導(dǎo) Actor 的更新在許多多智能體算法當(dāng)中都用到了這個(gè)技巧。值得一提的是,由于 Critic 需要指導(dǎo) Actor 的更新,所以理論上需要讓 Critic 比 Actor 更快的收斂,因此通常 Critic 的 learning rate 需要設(shè)置的比 Actor 要稍大些。
3.3 Update Parameters 過(guò)程
下面我們來(lái)看看 Actor 和 Critic 的更新過(guò)程:
- Critic 更新
上面是 Critic 的更新過(guò)程,Critic 的更新很好理解,利用聯(lián)合觀測(cè)來(lái)確定聯(lián)合行為(DPG中一個(gè)觀測(cè)就對(duì)應(yīng)一個(gè)具體的行為),輸入到 Critic 網(wǎng)絡(luò)中進(jìn)行計(jì)算,最后利用 TD-Error 進(jìn)行梯度更新。
- Actor 更新
Actor 在進(jìn)行更新的時(shí)候,首先把當(dāng)前 Agent 的當(dāng)前行為替換成了另外一個(gè)行為,再用新的聯(lián)合行為去預(yù)估 Critic 的值,新的聯(lián)合行為中其他 Agent 的行為是保持不變的。那么這里為什么要單獨(dú)改變自身 Agent 的行為呢?這是因?yàn)?MADDPG 是一種 off-policy 的算法,我們所取的更新樣本是來(lái)自 Memory Pool 中的,是以往的歷史經(jīng)驗(yàn),但我們現(xiàn)在自身的 Policy 已經(jīng)和之前的不一樣了(已經(jīng)進(jìn)化過(guò)了),因此需要按照現(xiàn)在的 Policy 重新選擇一個(gè)行為進(jìn)行計(jì)算。這和 PPO 算法中的 Importance Sampling 的思想一樣,PPO 是采用概率修正的方式來(lái)解決行為不一致問(wèn)題,而 MADDPG 中干脆直接就舍棄歷史舊行為,按照當(dāng)前策略重采樣一次行為來(lái)進(jìn)行計(jì)算。
- Target 網(wǎng)絡(luò)更新
和 DDPG 一樣,MADDPG 中針對(duì) Actor 和 Critic 的 target 網(wǎng)絡(luò)也是采用 soft update 的,具體內(nèi)容參見(jiàn) 2.2.5 小節(jié)。
以上就是 MADDPG 的全部?jī)?nèi)容。
總結(jié)
以上是生活随笔為你收集整理的多智能体连续行为空间问题求解——MADDPG的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: Sigmoid函数与逻辑回归
- 下一篇: MultiProcessing中主进程与