• 首页 > 
  • AI技术 > 
  • 如何在PyTorch中实现多GPU协同训练

如何在PyTorch中实现多GPU协同训练

GPU
小华
2025-10-01

在PyTorch中实现多GPU协同训练,通常需要使用torch.nn.DataParalleltorch.nn.parallel.DistributedDataParallel。以下是这两种方法的简要说明和示例代码:

方法一:使用 torch.nn.DataParallel

DataParallel 是一个简单的并行化方法,它将模型复制到每个GPU上,并将输入数据分割到各个GPU上进行前向传播,然后将结果收集并合并。

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# 假设我们有一个模型和一个数据集
model = YourModel()
dataset = YourDataset()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 将模型移动到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 使用DataParallel包装模型
if torch.cuda.device_count() > 1:
print(f"Let's use {torch.cuda.device_count()} GPUs!")
model = nn.DataParallel(model)
# 训练模型
for epoch in range(num_epochs):
for data, target in dataloader:
data, target = data.to(device), target.to(device)
# 前向传播
output = model(data)
# 计算损失
loss = nn.CrossEntropyLoss()(output, target)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()

方法二:使用 torch.nn.parallel.DistributedDataParallel

DistributedDataParallel 是一个更高级的并行化方法,它支持多进程并行训练,通常比 DataParallel 更高效。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def main(rank, world_size):
# 初始化分布式环境
dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
# 设置设备
device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
# 假设我们有一个模型和一个数据集
model = YourModel().to(device)
dataset = YourDataset()
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
# 使用DistributedDataParallel包装模型
model = DDP(model, device_ids=[rank])
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(num_epochs):
sampler.set_epoch(epoch)
for data, target in dataloader:
data, target = data.to(device), target.to(device)
# 前向传播
output = model(data)
# 计算损失
loss = criterion(output, target)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
if __name__ == "__main__":
world_size = torch.cuda.device_count()
mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)

注意事项

  1. 数据并行 vs 分布式并行
  • DataParallel 适用于单机多卡的情况,简单易用,但在大规模分布式训练中效率较低。
  • DistributedDataParallel 适用于多机多卡的情况,支持更高效的并行训练。
  1. 环境设置
  • 使用 DistributedDataParallel 时,需要正确设置分布式环境变量,例如 WORLD_SIZERANK 等。
  • 可以使用 torch.distributed.launchaccelerate 库来简化分布式训练的启动过程。
  1. 性能优化
  • 在使用 DistributedDataParallel 时,可以通过设置 find_unused_parameters=False 来减少不必要的参数检查,提高性能。
  • 确保数据加载器使用 DistributedSampler,以避免数据重复或遗漏。

通过以上方法,你可以在PyTorch中实现多GPU协同训练。选择合适的方法取决于你的具体需求和硬件资源。

亿速云提供售前/售后服务

售前业务咨询

售后技术保障

400-100-2938

7*24小时售后电话

官方微信小程序