Author: 颜挺帅 @知乎 (republished with permission)

Source: https://zhuanlan.zhihu.com/p/489892744

Editor: 极市平台

Because of my work I have recently been filling in my knowledge of distributed training. After a round of theoretical study I still felt something was missing: many points refused to click at a concrete level (for example, what distributed primitives such as scatter and all-reduce look like in code, how the ring all-reduce algorithm is used for gradient synchronization, and how a parameter server performs partial parameter updates).

The famous physicist and Nobel laureate Richard Feynman left a line on his office blackboard: "What I cannot create, I do not understand." Programmers have a similar slogan: "show me the code". I therefore decided to write a series of articles on distributed training that turn the usually abstract concepts into code, guaranteeing that every example is executable, verifiable, and reproducible, and to publish the source code so that we can learn from each other.

After some investigation I found that PyTorch offers very good abstractions and a complete set of interfaces for distributed training, so this series uses PyTorch as the main framework. Many of the examples come from the PyTorch documentation and have been debugged and extended on that basis.

Finally, since there are already plenty of theoretical introductions to distributed training online, theory will not be the focus of this series; the emphasis is on the code. The articles published so far are:

Pytorch - A minimal hands-on introduction to distributed training: https://zhuanlan.zhihu.com/p/477073906

Pytorch - Distributed communication primitives (with source code): https://zhuanlan.zhihu.com/p/478953028

Pytorch - Hand-written allreduce distributed training (with source code): https://zhuanlan.zhihu.com/p/482557067

Pytorch - A minimal implementation of inter-operator parallelism (with source code): https://zhuanlan.zhihu.com/p/483640235

Pytorch - A minimal multi-node multi-GPU implementation (with source code): https://zhuanlan.zhihu.com/p/486130584

1. Introduction

PyTorch introduced torchrun in 1.9.0 to replace the torch.distributed.launch used in earlier versions. On top of what torch.distributed.launch provides, torchrun mainly adds two capabilities: failover (when a worker fails, all workers are automatically restarted and training continues) and elasticity (nodes can be dynamically added to or removed from the job).

In this example, a 4-GPU worker group is first started on node0. After it has trained for a while, 4 more GPU workers are started on node1; together with the workers already running on node0 they form a new worker group, giving a 2-node, 8-GPU distributed training job.


2. Building the model

A simple fully connected neural network model:

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))
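As a quick sanity check (my addition, not part of the original script), the model can be run on a random batch to confirm the expected shapes:

import torch

model = ToyModel()
x = torch.randn(20, 10)   # a batch of 20 samples with 10 features each
y = model(x)
print(y.shape)            # expected: torch.Size([20, 5])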

3. Checkpoint handling

Every time a node is added or removed, all workers are killed and then restarted. The training code therefore has to save its training state, so that after a restart it can continue from where the previous run left off.

The information that typically needs to be saved includes: the number of epochs already trained, the model weights (model.state_dict()), and the optimizer state (optimizer.state_dict()).

The save and load code looks like this:

def save_checkpoint(epoch, model, optimizer, path):
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimize_state_dict": optimizer.state_dict(),
    }, path)

def load_checkpoint(path):
    checkpoint = torch.load(path)
    return checkpoint
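One practical detail worth noting (my addition, not part of the original code): when several ranks on different GPUs load the same checkpoint, torch.load accepts a map_location argument so that tensors are mapped onto each rank's own device instead of the device they were saved from. A variant of the loader that does this might look like the sketch below (the helper name load_checkpoint_on_device is hypothetical):

def load_checkpoint_on_device(path, local_rank):
    # map stored tensors onto this rank's GPU rather than the GPU they were saved from
    return torch.load(path, map_location=f"cuda:{local_rank}")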

4. Training code

The initialization logic is as follows: read LOCAL_RANK and RANK from the environment variables set by torchrun, build the model on the local GPU and wrap it with DDP, create the loss function and optimizer, and resume from checkpoint.pt if it already exists.

local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) train worker starting...")

model = ToyModel().cuda(local_rank)
ddp_model = DDP(model, [local_rank])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
max_epoch = 100
first_epoch = 0

ckp_path = "checkpoint.pt"
if os.path.exists(ckp_path):
    print(f"load checkpoint from {ckp_path}")
    checkpoint = load_checkpoint(ckp_path)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimize_state_dict"])
    first_epoch = checkpoint["epoch"]

Training logic:

for i in range(first_epoch, max_epoch):
    time.sleep(1)  # slow training down so that the effect of dynamically adding a node is easier to observe
    outputs = ddp_model(torch.randn(20, 10).to(local_rank))
    labels = torch.randn(20, 5).to(local_rank)
    loss = loss_fn(outputs, labels)
    loss.backward()
    print(f"[{os.getpid()}] epoch {i} (rank = {rank}, local_rank = {local_rank}) loss = {loss.item()}\n")
    optimizer.step()
    save_checkpoint(i, model, optimizer, ckp_path)
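Note that in this minimal example every rank writes checkpoint.pt at every epoch. Because DDP keeps the parameters identical across ranks this works, but a common variation (my suggestion, not the author's code) is to let only rank 0 write the file, assuming the checkpoint path is on storage visible to all nodes; otherwise the per-rank save above is what gives every node its own local copy:

# replaces the unconditional save_checkpoint call at the end of the epoch loop
if rank == 0:
    save_checkpoint(i, model, optimizer, ckp_path)
dist.barrier()  # keep all ranks in step until the checkpoint has been written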

5. How to launch

Because we use torchrun to launch the multi-node, multi-GPU job, there is no need to start multiple processes ourselves through the spawn interface (torchrun takes care of launching our Python script as a process). We simply call the train function written above, adding the distributed process group initialization (dist.init_process_group) before it and the teardown (dist.destroy_process_group) after it.

The code below shows how the train function above is invoked.

def run():
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "LOCAL_WORLD_SIZE")
    }
    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
    dist.init_process_group(backend="nccl")
    train()
    dist.destroy_process_group()

if __name__ == "__main__":
    run()
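For completeness, here is a minimal sketch (my reconstruction, assuming the same symbol names; the exact layout of train_elastic.py in the BetterDL repository may differ) of the imports and file structure that the snippets above rely on:

import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

# ToyModel (section 2) and save_checkpoint / load_checkpoint (section 3)
# are defined at module level.

def train():
    # section 4: read LOCAL_RANK / RANK, build the DDP model,
    # resume from checkpoint.pt if present, then run the epoch loop
    ...

def run():
    # section 5: init_process_group -> train() -> destroy_process_group
    ...

if __name__ == "__main__":
    run()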

This example uses torchrun to run the multi-node, multi-GPU distributed training job (note: torch.distributed.launch has been deprecated by PyTorch and should no longer be used). The launch script is shown below (note: both node0 and node1 are started with this same script). --nnodes=1:3 declares that the job may run with anywhere from 1 to 3 nodes, which is what enables elastic scaling; --max_restarts=3 allows the worker group to be restarted up to 3 times; and the rdzv_* options point every node at the same c10d rendezvous endpoint.

torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py

6. Analysis of the results

Code: BetterDL - train_elastic.py

Environment: two machines, each with 4 V100 GPUs

Image: pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
GPU: V100

First, run the launch script on node0:

torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py

This produces the following output:

r/workspace/DDP# sh run_elastic.sh
[4031] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4029] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4030] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4032] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4029] (rank = 0, local_rank = 0) train worker starting...
[4030] (rank = 1, local_rank = 1) train worker starting...
[4032] (rank = 3, local_rank = 3) train worker starting...
[4031] (rank = 2, local_rank = 2) train worker starting...
[4101] epoch 0 (rank = 1, local_rank = 1) loss = 0.9288564920425415
[4103] epoch 0 (rank = 3, local_rank = 3) loss = 0.9711472988128662
[4102] epoch 0 (rank = 2, local_rank = 2) loss = 1.0727070569992065
[4100] epoch 0 (rank = 0, local_rank = 0) loss = 0.9402943253517151
[4100] epoch 1 (rank = 0, local_rank = 0) loss = 1.0327017307281494
[4101] epoch 1 (rank = 1, local_rank = 1) loss = 1.4485043287277222
[4103] epoch 1 (rank = 3, local_rank = 3) loss = 1.0959293842315674
[4102] epoch 1 (rank = 2, local_rank = 2) loss = 1.0669530630111694
...

Then run the same script on node1:

torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py

The output on node1 is as follows. Note that WORLD_SIZE has grown to 8, the new workers are assigned ranks 4-7, and they resume from the epoch-35 checkpoint rather than starting from epoch 0:

/workspace/DDP# sh run_elastic.sh
[696] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[697] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[695] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[694] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[697] (rank = 7, local_rank = 3) train worker starting...
[695] (rank = 5, local_rank = 1) train worker starting...
[694] (rank = 4, local_rank = 0) train worker starting...
[696] (rank = 6, local_rank = 2) train worker starting...
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
[697] epoch 35 (rank = 7, local_rank = 3) loss = 1.1888569593429565
[694] epoch 35 (rank = 4, local_rank = 0) loss = 0.8916441202163696
[695] epoch 35 (rank = 5, local_rank = 1) loss = 1.5685604810714722
[696] epoch 35 (rank = 6, local_rank = 2) loss = 1.11683189868927
[696] epoch 36 (rank = 6, local_rank = 2) loss = 1.3724170923233032
[694] epoch 36 (rank = 4, local_rank = 0) loss = 1.061527967453003
[695] epoch 36 (rank = 5, local_rank = 1) loss = 0.96876460313797
[697] epoch 36 (rank = 7, local_rank = 3) loss = 0.8060566782951355
...

The output on node0 is as follows. When node1 joins, the original four workers are sent SIGTERM, all workers are restarted as part of the new 8-GPU worker group, and training resumes from the saved checkpoint at epoch 35:

...
[4100] epoch 35 (rank = 0, local_rank = 0) loss = 1.0746158361434937
[4101] epoch 35 (rank = 1, local_rank = 1) loss = 1.1712706089019775
[4103] epoch 35 (rank = 3, local_rank = 3) loss = 1.1774182319641113
[4102] epoch 35 (rank = 2, local_rank = 2) loss = 1.0898035764694214
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4100 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4101 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4102 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4103 closing signal SIGTERM
[4164] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4165] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4162] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4163] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4162] (rank = 0, local_rank = 0) train worker starting...
[4163] (rank = 1, local_rank = 1) train worker starting...
[4164] (rank = 2, local_rank = 2) train worker starting...
[4165] (rank = 3, local_rank = 3) train worker starting...
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
[4165] epoch 35 (rank = 3, local_rank = 3) loss = 1.3437936305999756
[4162] epoch 35 (rank = 0, local_rank = 0) loss = 1.5693414211273193
[4163] epoch 35 (rank = 1, local_rank = 1) loss = 1.199862003326416
[4164] epoch 35 (rank = 2, local_rank = 2) loss = 1.0465545654296875
[4163] epoch 36 (rank = 1, local_rank = 1) loss = 0.9741991758346558
[4162] epoch 36 (rank = 0, local_rank = 0) loss = 1.3609280586242676
[4164] epoch 36 (rank = 2, local_rank = 2) loss = 0.9585908055305481
[4165] epoch 36 (rank = 3, local_rank = 3) loss = 0.9169824123382568
...
