

Deep Learning Notes: DQN Principles, Algorithm, and PyTorch Implementation

Published: 2024/9/15

Table of Contents

  • Q-learning schematic
  • Q-learning algorithm description
  • PyTorch implementation
    • Q-network implementation
    • DQN implementation
      • Two Q-networks, one of which serves as the target Q-network
      • take action: the part that interacts with the environment; actions are selected ε-greedily
      • store transitions: saves data for experience replay
      • The learning step, the core of the algorithm: process a minibatch, update the Q-network by regression, and periodically update the target Q-network
    • Training loop: shape the environment's reward, and realize the algorithm's "for each episode" (via a range loop) and "for each time step" (terminated by the environment's done flag)

Q-learning schematic

(The original schematic figure is not reproduced here.)

Q-learning algorithm description:
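The original post shows the algorithm only as an image. For reference, the standard tabular Q-learning update it describes, and the DQN loss that replaces the table with a network, are:

```latex
% Tabular Q-learning update with learning rate \alpha and discount \gamma:
Q(s,a) \leftarrow Q(s,a)
  + \alpha \left[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \right]

% DQN parameterizes Q(s,a;\theta) with a network and regresses it toward
% a target computed from a periodically synced frozen copy \theta^{-}:
L(\theta) = \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}}
  \left[ \left( r + \gamma \max_{a'} Q(s',a';\theta^{-}) - Q(s,a;\theta) \right)^{2} \right]
```

The frozen copy \theta^{-} is the target Q-network updated periodically in the code below, and \mathcal{D} is the experience-replay buffer.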

PyTorch implementation:

Q-network implementation:

The input is a state s; the output is Q(s, a_i), i.e., the Q-value of every action a_i in state s.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(N_STATES, 50)
        self.fc1.weight.data.normal_(0, 0.1)   # initialization
        self.out = nn.Linear(50, N_ACTIONS)
        self.out.weight.data.normal_(0, 0.1)   # initialization

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        actions_value = self.out(x)  # Q(s, a_i) for every action a_i
        return actions_value
```
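The snippets in this article reference several globals (N_STATES, EPSILON, and so on) that the original does not define. A plausible set for the CartPole-v0 environment used below, with all values being assumptions rather than taken from the original, is:

```python
# Hyperparameters and environment constants assumed for CartPole-v0;
# the original article does not list them, so treat these as illustrative.
BATCH_SIZE = 32            # minibatch size for experience replay
LR = 0.01                  # Adam learning rate
EPSILON = 0.9              # probability of taking the greedy action
GAMMA = 0.9                # discount factor
TARGET_REPLACE_ITER = 100  # target-network sync period, in learn() calls
MEMORY_CAPACITY = 2000     # replay buffer size
N_STATES = 4               # CartPole observation: x, x_dot, theta, theta_dot
N_ACTIONS = 2              # CartPole actions: push left / push right
ENV_A_SHAPE = 0            # 0 because CartPole's action space is a scalar int
```

In the reference code these are typically derived from the environment itself, e.g. `N_ACTIONS = env.action_space.n` and `N_STATES = env.observation_space.shape[0]`.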

DQN implementation:

The DQN class contains:

Two Q-networks, one of which serves as the target Q-network;

```python
class DQN(object):
    def __init__(self):
        self.eval_net, self.target_net = Net(), Net()
        self.learn_step_counter = 0   # for target updating
        self.memory_counter = 0       # for storing memory
        # each row stores one transition: s, a, r, s_
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()
```

take action: obtains the next action. This is the part that interacts with the environment; actions are selected ε-greedily.

```python
def choose_action(self, x):
    x = torch.unsqueeze(torch.FloatTensor(x), 0)  # input only one sample
    if np.random.uniform() < EPSILON:   # greedy
        actions_value = self.eval_net.forward(x)
        action = torch.max(actions_value, 1)[1].data.numpy()
        # return the argmax index
        action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)
    else:                               # random
        action = np.random.randint(0, N_ACTIONS)
        action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)
    return action
```

store transitions: saves data for experience replay;

```python
def store_transition(self, s, a, r, s_):
    transition = np.hstack((s, [a, r], s_))
    # replace the old memory with new memory
    index = self.memory_counter % MEMORY_CAPACITY
    self.memory[index, :] = transition
    self.memory_counter += 1
```
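The modulo index above makes the buffer circular: once the counter exceeds the capacity, new transitions overwrite the oldest ones. A miniature, torch-free sketch of that indexing (tiny capacity and fake 2-element transitions are illustrative choices, not from the original):

```python
import numpy as np

# Ring-buffer indexing as in store_transition, with capacity 3.
capacity = 3
memory = np.zeros((capacity, 2))
counter = 0
for t in range(5):                 # store 5 transitions into 3 slots
    index = counter % capacity     # wraps around, overwriting the oldest
    memory[index] = [t, t + 0.5]
    counter += 1

print(memory[:, 0].tolist())  # slots now hold transitions 3, 4, 2
```

Transitions 0 and 1 have been overwritten by 3 and 4, while slot 2 still holds transition 2.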

The most important part is the learning step, the core of the algorithm description: process a minibatch, update the Q-network by regression, and periodically update the target Q-network.

```python
def learn(self):
    # target parameter update
    if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
        self.target_net.load_state_dict(self.eval_net.state_dict())
    self.learn_step_counter += 1

    # sample batch transitions
    sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
    b_memory = self.memory[sample_index, :]
    b_s = torch.FloatTensor(b_memory[:, :N_STATES])
    b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES + 1].astype(int))
    b_r = torch.FloatTensor(b_memory[:, N_STATES + 1:N_STATES + 2])
    b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])

    # q_eval w.r.t. the action taken in the experience
    q_eval = self.eval_net(b_s).gather(1, b_a)   # shape (batch, 1)
    q_next = self.target_net(b_s_).detach()      # detach: don't backprop through target net
    q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)  # shape (batch, 1)
    loss = self.loss_func(q_eval, q_target)

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
```
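Per row of the minibatch, the regression target computed above reduces to r + γ · max over a' of Q_target(s', a'). A torch-free sketch with made-up numbers (the Q-values and rewards here are invented for illustration):

```python
GAMMA = 0.9  # discount factor (assumed value)

# Made-up target-network outputs for two next-states, two actions each.
q_next = [[1.0, 3.0],
          [2.0, 0.5]]
rewards = [1.0, -1.0]

# q_target = r + GAMMA * max_a' Q_target(s', a'), as in learn()
q_target = [r + GAMMA * max(row) for r, row in zip(rewards, q_next)]
print(q_target)  # ≈ [3.7, 0.8]
```

The eval net is then regressed toward these fixed targets with an MSE loss; keeping the target net frozen between syncs is what stabilizes the regression.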

Training loop: shapes the game environment's reward, and realizes the algorithm's "for each episode" (via a range loop) and "for each time step" (terminated by the done flag returned by the environment).

```python
dqn = DQN()

print('\nCollecting experience...')
for i_episode in range(400):
    s = env.reset()
    ep_r = 0
    while True:
        env.render()
        a = dqn.choose_action(s)

        # take action
        s_, r, done, info = env.step(a)

        # modify the reward: favor a centered cart and an upright pole
        x, x_dot, theta, theta_dot = s_
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
        r = r1 + r2

        dqn.store_transition(s, a, r, s_)
        ep_r += r

        if dqn.memory_counter > MEMORY_CAPACITY:
            dqn.learn()
            if done:
                print('Ep: ', i_episode, '| Ep_r: ', round(ep_r, 2))

        if done:
            break
        s = s_
```
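The reward shaping above can be isolated and checked on its own. The threshold values below are CartPole's defaults (x_threshold = 2.4, theta_threshold_radians = 12° in radians); treat them as assumptions about the environment rather than part of the original article:

```python
import math

# CartPole's default termination thresholds (assumed, not from the article).
x_threshold = 2.4
theta_threshold_radians = 12 * 2 * math.pi / 360  # ~0.2095 rad

def shaped_reward(x, theta):
    # Higher reward the closer the cart is to the center...
    r1 = (x_threshold - abs(x)) / x_threshold - 0.8
    # ...and the closer the pole is to vertical.
    r2 = (theta_threshold_radians - abs(theta)) / theta_threshold_radians - 0.5
    return r1 + r2

print(shaped_reward(0.0, 0.0))  # ≈ 0.7: best case, centered and upright
```

Unlike CartPole's native constant reward of +1 per step, this shaped reward grades how good each state is, which gives the Q-network a much denser learning signal.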
