分布式训练中 PyTorch 的容错策略
一 检查点策略
- 全量检查点:定期持久化训练关键状态,至少包含模型参数、优化器状态、当前 epoch/step,必要时加入数据迭代器状态,用于从最近一致状态恢复训练。实践中常由rank=0写入共享存储,恢复时广播加载。示例:
- 保存:torch.save({'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}, path)
- 恢复:load 后 model.module.load_state_dict / optimizer.load_state_dict,并恢复 epoch/step
- 分层检查点:同时维护“恢复点(每 N 步,轻量)”与“周期检查点(每 epoch,完整)”。例如社区实践中使用参数如--recovery-interval=100、--checkpoint-hist=20,在故障后优先从最近的恢复点回滚,再加载最近周期检查点以继续训练,显著降低回退步数与开销。
- 增量与分布式优化:对超大模型可采用增量检查点(仅保存差异)降低 IO;在使用 FSDP 时可结合reshard优化保存/加载,缩短检查点时间并提升恢复效率。
二 弹性训练与自动恢复
- TorchElastic 核心机制:通过Rendezvous进行成员协商,配合Agent监控与重启异常 worker,在节点宕机、抢占等场景下实现不中断作业的容错与动态扩缩容。典型启动方式使用 rdzv_backend=etcd 与 rdzv_endpoint,训练脚本需实现save_checkpoint / load_checkpoint接口以在成员变更或恢复时持久化/还原训练进度。
- 云原生落地实践:在 Kubernetes 上结合 DLRover 与 TorchElastic,可实现故障机隔离与 Pod 替换、缩容继续训练、扩容自动加回以及FSDP 检查点优化,将故障恢复时间从小时级降至分钟级,显著提升有效训练时间与资源利用率。
三 通信与运行期稳定性
- 通信后端与重试:选用具备错误检测与恢复能力的后端(如 NCCL),结合超时重试与自动重连策略,降低集体通信中断的影响;必要时使用异步错误上报与rank=0 协调的恢复流程,确保参数版本一致性。
- 容错启动与编排:使用 torchrun / torchelastic.launch 等编排工具,自动注入 RANK/WORLD_SIZE/LOCAL_RANK 等环境变量,减少手工配置错误导致的训练中断;在云原生场景结合 Kubeflow 等框架实现作业级自动重启与弹性调度。
四 工程实践与监控建议
- 监控与告警:完善日志/指标与异常检测(如 GPU Xid/ECC/NVLINK/NCCL 错误),在异常趋势出现时提前触发检查点;训练过程中对 loss/学习率/GPU 利用率等关键指标进行实时上报与可视化,便于快速定位“卡住/发散”等问题。
- 可复现性与回滚:统一随机种子、算子确定性设置与数据分片/打乱一致性;在恢复后对比中断前后指标趋势,必要时回滚到更早检查点并降低学习率,确保收敛稳定。