计图MPI分布式多卡
計圖MPI分布式多卡
計圖分布式基于MPI(Message Passing Interface),主要闡述使用計圖MPI,進行多卡和分布式訓練。目前計圖分布式處于測試階段。
計圖MPI安裝
計圖依賴OpenMPI,用戶可以使用如下命令安裝OpenMPI:
sudo apt install openmpi-bin openmpi-common libopenmpi-dev
計圖會自動檢測環境變量中是否包含mpicc,如果計圖成功的檢測到了mpicc,輸出如下信息:
[i 0502 14:09:55.758481 24 init.py:203] Found mpicc(1.10.2) at /usr/bin/mpicc
如果計圖沒有在環境變量中找到mpi,用戶也可以手動指定mpicc的路徑告訴計圖,添加環境變量即可:export mpicc_path=/you/mpicc/path
OpenMPI安裝完成以后,用戶無需修改代碼,需要做的僅僅是修改啟動命令行,計圖就會用數據并行的方式,自動完成并行操作。
單卡訓練代碼
python3.7 -m jittor.test.test_resnet
分布式多卡訓練代碼
mpirun -np 4 python3.7 -m jittor.test.test_resnet
指定特定顯卡的多卡訓練代碼
CUDA_VISIBLE_DEVICES=“2,3” mpirun -np 2 python3.7 -m jittor.test.test_resnet
便捷性的背后,計圖的分布式算子的支撐,計圖支持的mpi算子后端會使用nccl進行進一步的加速。計圖所有分布式算法的開發,均在Python前端完成,讓分布式算法的靈活度增強,開發分布式算法的難度也大大降低。
基于這些mpi算子接口,研發團隊已經集成了如下三種分布式相關的算法:
? 分布式數據并行加載
? 分布式優化器
? 分布式同步批歸一化層
用戶在使用MPI進行分布式訓練時,計圖內部的Dataset類會自動并行分發數據,需要注意的是Dataset類中設置的Batch size是所有節點的batch size之和,也就是總batch size,不是單個節點接收到的batch size。
MPI接口
目前MPI開放接口如下:
? jt.mpi: 計圖的MPI模塊,當計圖不在MPI環境下時,jt.mpi == None, 用戶可以用這個判斷是否在mpi環境下。
? jt.Module.mpi_param_broadcast(root=0): 將模塊的參數從root節點廣播給其他節點。
? jt.mpi.mpi_reduce(x, op=‘add’, root=0): 將所有節點的變量x使用算子op,reduce到root節點。如果op是’add’或者’sum’,該接口會把所有變量求和,如果op是’mean’,該接口會取均值。
? jt.mpi.mpi_broadcast(x, root=0): 將變量x從root節點廣播到所有節點。
? jt.mpi.mpi_all_reduce(x, op=‘add’): 將所有節點的變量x使用一起reduce,并且吧reduce的結果再次廣播到所有節點。如果op是’add’或者’sum’,該接口會把所有變量求和,如果op是’mean’,該接口會取均值。
實例:MPI實現分布式同步批歸一化層
下面的代碼是使用計圖實現分布式同步批,歸一化層的實例代碼,在原來批歸一化層的基礎上,只需增加三行代碼,就可以實現分布式的batch norm,添加的代碼如下:
將均值和方差,通過all reduce同步到所有節點
if self.sync and jt.mpi:
xmean = xmean.mpi_all_reduce(“mean”)
x2mean = x2mean.mpi_all_reduce(“mean”)
注:計圖內部已經實現了同步的批歸一化層,用戶不需要自己實現
分布式同步批歸一化層的完整代碼:
class BatchNorm(Module):
def init(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
assert affine == None
self.sync = syncself.num_features = num_featuresself.is_train = is_trainself.eps = epsself.momentum = momentumself.weight = init.constant((num_features,), "float32", 1.0)self.bias = init.constant((num_features,), "float32", 0.0)self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()def execute(self, x):if self.is_train:xmean = jt.mean(x, dims=[0,2,3], keepdims=1)x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)# 將均值和方差,通過all reduce同步到所有節點if self.sync and jt.mpi:xmean = xmean.mpi_all_reduce("mean")x2mean = x2mean.mpi_all_reduce("mean")xvar = x2mean-xmean*xmeannorm_x = (x-xmean)/jt.sqrt(xvar+self.eps)self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentumself.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentumelse:running_mean = self.running_mean.broadcast(x, [0,2,3])running_var = self.running_var.broadcast(x, [0,2,3])norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)w = self.weight.broadcast(x, [0,2,3])b = self.bias.broadcast(x, [0,2,3])return norm_x * w + b
總結
以上是生活随笔為你收集整理的计图MPI分布式多卡的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 自定义算子高性能开发
- 下一篇: 计图点云库