Pytorch 多 GPU 并行处理机制
生活随笔
收集整理的這篇文章主要介紹了
Pytorch 多 GPU 并行处理机制
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
Pytorch 的多 GPU 處理接口是 torch.nn.DataParallel(module, device_ids),其中 module 參數(shù)是所要執(zhí)行的模型,而 device_ids 則是指定并行的 GPU id 列表。
而其并行處理機(jī)制是,首先將模型加載到主 GPU 上,然后再將模型復(fù)制到各個指定的從 GPU 中,然后將輸入數(shù)據(jù)按 batch 維度進(jìn)行劃分,具體來說就是每個 GPU 分配到的數(shù)據(jù) batch 數(shù)量是總輸入數(shù)據(jù)的 batch 除以指定 GPU 個數(shù)。每個 GPU 將針對各自的輸入數(shù)據(jù)獨(dú)立進(jìn)行 forward 計算,最后將各個 GPU 的 loss 進(jìn)行求和,再用反向傳播更新單個 GPU 上的模型參數(shù),再將更新后的模型參數(shù)復(fù)制到剩余指定的 GPU 中,這樣就完成了一次迭代計算。所以該接口還要求輸入數(shù)據(jù)的 batch 數(shù)量要不小于所指定的 GPU 數(shù)量。
這里有兩點(diǎn)需要注意:
- 主 GPU 默認(rèn)情況下是 0 號 GPU,也可以通過
torch.cuda.set_device(id)來手動更改默認(rèn) GPU。 - 提供的多 GPU 并行列表中需要包含有主 GPU。
作者:葉俊賢
鏈接:https://www.jianshu.com/p/9e36e5e36638
來源:簡書
簡書著作權(quán)歸作者所有,任何形式的轉(zhuǎn)載都請聯(lián)系作者獲得授權(quán)并注明出處。
總結(jié)
以上是生活随笔為你收集整理的Pytorch 多 GPU 并行处理机制的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: PyTorch之前向传播函数自动调用fo
- 下一篇: python any()和all()用法