Why Distributed Training?
Two reasons to distribute training across multiple GPUs:
- Speed: A model that takes 2 weeks on one GPU can finish in roughly 2 days on 8, assuming near-linear scaling
- Memory: A 70B model requires ~140GB in bf16 — doesn't fit on a single 80GB A100
These require different strategies. Speed → data parallelism. Memory → model or tensor parallelism.
Data Parallelism: The Simplest Case
Each GPU trains on a different batch. Gradients are averaged across GPUs and weights are updated identically on all GPUs.
GPU 0: batch_0 → forward → loss_0 → gradients_0 ─┐
GPU 1: batch_1 → forward → loss_1 → gradients_1 ─┼──> average gradients → update all weights
GPU 2: batch_2 → forward → loss_2 → gradients_2 ─┘
This works because all GPUs hold a full copy of the model. Effective batch size = single-GPU batch × number of GPUs.
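To see why averaging works, here is a minimal single-process sketch in plain Python (no GPUs, hypothetical toy data): for a loss that is a mean over examples, averaging per-shard gradients of equal-sized shards reproduces the full-batch gradient exactly.

```python
# Toy check: "average per-GPU gradients" equals the full-batch gradient.
# Model: y = w * x with MSE loss; gradient per example is 2*(w*x - y)*x.

def grad(w, xs, ys):
    """Mean gradient of (w*x - y)^2 over a batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

full_grad = grad(w, xs, ys)            # gradient over the whole batch

# Split across two simulated "GPUs" and average their gradients
g0 = grad(w, xs[:2], ys[:2])
g1 = grad(w, xs[2:], ys[2:])
averaged = (g0 + g1) / 2

print(full_grad, averaged)             # identical values
```

Equality is exact only when every GPU gets the same number of examples, which is what `DistributedSampler` arranges.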
PyTorch DistributedDataParallel (DDP)
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def setup(rank, world_size):
    dist.init_process_group(
        backend="nccl",  # NCCL for GPU-GPU communication
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()
def train(rank, world_size, args):
    setup(rank, world_size)

    # Load model and move to this GPU
    model = MyModel().to(rank)
    # Wrap with DDP: handles gradient synchronization
    model = DDP(model, device_ids=[rank])

    # Distributed sampler: each GPU sees a different slice of the data
    # (`dataset` is assumed to be defined elsewhere)
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(args.epochs):
        sampler.set_epoch(epoch)  # Reshuffle each epoch
        for batch in loader:
            batch = {k: v.to(rank) for k, v in batch.items()}
            optimizer.zero_grad()
            loss = model(batch)   # MyModel is assumed to return the loss
            loss.backward()
            # DDP averages gradients across GPUs during backward()
            optimizer.step()

        # Only log/save from rank 0
        if rank == 0:
            print(f"Epoch {epoch}: loss={loss.item():.4f}")
            torch.save(model.module.state_dict(), "checkpoint.pt")  # .module unwraps DDP

    cleanup()

# Launch: torchrun --nproc_per_node=4 train.py
Launching with torchrun
# Single node, 4 GPUs
torchrun --nproc_per_node=4 train.py --epochs 10

# Multi-node: 2 nodes, 4 GPUs each (8 GPUs total)
# On node 0:
torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 \
    --master_addr=192.168.1.1 --master_port=29500 train.py
# On node 1:
torchrun --nnodes=2 --nproc_per_node=4 --node_rank=1 \
    --master_addr=192.168.1.1 --master_port=29500 train.py
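With torchrun, rank and world size do not need to be passed by hand: the launcher sets them as environment variables for each worker process. A minimal sketch of reading them (the helper name `get_dist_env` is mine; defaults make it safe in a single-process run):

```python
import os

def get_dist_env():
    """Read the per-worker variables that torchrun sets.

    RANK is the global rank, WORLD_SIZE the total process count,
    LOCAL_RANK the GPU index on this node.
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    return rank, world_size, local_rank

rank, world_size, local_rank = get_dist_env()
print(rank, world_size, local_rank)
```

Under torchrun you can also call `dist.init_process_group(backend="nccl")` with no explicit `rank`/`world_size` arguments; the default `env://` init method reads these same variables.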
Gradient Accumulation: Simulating Larger Batches
When you can't fit a large batch on one GPU:
accumulation_steps = 4  # Simulate a 4x larger batch size

optimizer.zero_grad()
for i, batch in enumerate(loader):
    loss = model(batch) / accumulation_steps  # Scale loss
    loss.backward()                           # Accumulate gradients
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()        # Update weights every 4 micro-batches
        optimizer.zero_grad()
This is equivalent to a batch size of batch_size × accumulation_steps without requiring the memory. (Under DDP, wrap the backward passes of all but the last micro-batch in `model.no_sync()` so gradients are synchronized only once per update, not on every micro-batch.)
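The equivalence is easy to verify numerically: for a loss that is a mean over examples, summing the gradients of `loss / accumulation_steps` over equal-sized micro-batches reproduces the full-batch gradient. A toy one-parameter sketch in plain Python:

```python
# Verify that accumulating gradients of (loss / steps) over micro-batches
# matches the gradient of the mean loss over the full batch.
# Toy model: loss = mean of (w - t)^2, so the gradient is mean of 2*(w - t).

def batch_grad(w, targets):
    return sum(2 * (w - t) for t in targets) / len(targets)

w = 0.0
targets = [1.0, 2.0, 3.0, 4.0]
accumulation_steps = 2

full = batch_grad(w, targets)                  # full-batch gradient

accumulated = 0.0
for i in range(0, len(targets), 2):            # micro-batches of size 2
    micro = batch_grad(w, targets[i:i + 2])
    accumulated += micro / accumulation_steps  # scaled, then summed

print(full, accumulated)                       # identical values
```

As with data parallelism, exact equality relies on equal-sized micro-batches; a ragged final micro-batch introduces a small discrepancy.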
FSDP: When the Model Doesn't Fit on One GPU
DDP requires each GPU to hold a full model copy. For large models (7B+), this is often impossible.
Fully Sharded Data Parallel (FSDP) shards the model across GPUs:
DDP: GPU 0: full model + batch_0
GPU 1: full model + batch_1
FSDP: GPU 0: shard_0 of model + batch_0
GPU 1: shard_1 of model + batch_1
(parameters are all-gathered when needed, discarded after)
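The memory difference is easy to quantify: under FULL_SHARD, persistent parameter memory is divided by the number of GPUs. A rough back-of-the-envelope helper (my own sketch, not a PyTorch API; it ignores activations, gradients, optimizer state, and all-gather working buffers):

```python
def param_memory_gb(n_params, bytes_per_param=2, num_gpus=1, sharded=False):
    """Persistent parameter memory per GPU, in GB.

    bytes_per_param defaults to 2 (bf16). With sharded=True (FSDP
    FULL_SHARD), each GPU permanently holds only 1/num_gpus of the weights.
    """
    total_gb = n_params * bytes_per_param / 1e9
    return total_gb / num_gpus if sharded else total_gb

# 70B model in bf16 (2 bytes/param)
print(param_memory_gb(70e9))                            # DDP: 140.0 GB per GPU
print(param_memory_gb(70e9, num_gpus=8, sharded=True))  # FSDP: 17.5 GB per GPU
```

This is why the 70B model from the introduction is hopeless under DDP but fits comfortably across 8 GPUs under FSDP, before even counting the savings on gradients and optimizer state.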
from functools import partial

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Auto-wrap policy: shard each transformer block separately
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # Maximum memory savings
    device_id=rank,
)
FSDP vs DDP: When to Use Which
| Strategy | When to use |
|---|---|
| Single GPU | Model fits on one GPU, speed is acceptable |
| DDP | Model fits on one GPU, need more speed |
| FSDP | Model doesn't fit on one GPU |
| Tensor Parallelism | Model so large that individual layers don't fit |
As a rule of thumb, full fine-tuning with DDP works up to roughly 7B parameters on 80GB A100s; optimizer state, not the weights themselves, is usually the limiting factor. For larger models, use FSDP.
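The table can be collapsed into a simple decision helper (a sketch with hypothetical inputs; in practice the decision also depends on activation memory and interconnect speed):

```python
def pick_strategy(model_gb, gpu_gb, num_gpus, need_speedup=False):
    """Rough strategy picker following the table above.

    model_gb is the full training footprint (weights + gradients +
    optimizer state), gpu_gb the memory of a single GPU.
    """
    if model_gb <= gpu_gb:
        if num_gpus > 1 and need_speedup:
            return "DDP"
        return "single GPU"
    if model_gb <= gpu_gb * num_gpus:
        return "FSDP"
    return "tensor parallelism"  # even sharding across these GPUs won't fit

print(pick_strategy(28, 80, 1))                     # single GPU
print(pick_strategy(28, 80, 4, need_speedup=True))  # DDP
print(pick_strategy(280, 80, 4))                    # FSDP
```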
Model Parallelism (Pipeline Parallelism)
Split the model's layers across GPUs. GPU 0 runs layers 0-11, GPU 1 runs layers 12-23:
# Naive model parallelism (pipeline bubbles are a problem)
class PipelineModel(nn.Module):
    def __init__(self):
        super().__init__()
        # First half on GPU 0
        self.embedding = nn.Embedding(vocab_size, d_model).to(0)
        self.layers_0_11 = nn.ModuleList([
            TransformerBlock(d_model) for _ in range(12)
        ]).to(0)
        # Second half on GPU 1
        self.layers_12_23 = nn.ModuleList([
            TransformerBlock(d_model) for _ in range(12)
        ]).to(1)
        self.lm_head = nn.Linear(d_model, vocab_size).to(1)

    def forward(self, x):
        x = self.embedding(x)       # GPU 0
        for layer in self.layers_0_11:
            x = layer(x)            # GPU 0
        x = x.to(1)                 # Transfer to GPU 1
        for layer in self.layers_12_23:
            x = layer(x)            # GPU 1
        return self.lm_head(x)      # GPU 1
Naive pipeline parallelism suffers from GPU idle time (pipeline bubbles). Production systems use microbatching to overlap computation across stages.
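Microbatching shrinks the bubble: in a GPipe-style schedule with p pipeline stages and m micro-batches, the idle fraction is (p - 1) / (m + p - 1). A quick calculation:

```python
def bubble_fraction(stages, microbatches):
    """Fraction of time pipeline stages sit idle in a GPipe-style schedule:
    (p - 1) / (m + p - 1) for p stages and m micro-batches."""
    return (stages - 1) / (microbatches + stages - 1)

# Naive pipeline = 1 micro-batch: with 2 stages, half the time is idle
print(bubble_fraction(2, 1))   # 0.5
# Same 2 stages, 8 micro-batches: the bubble shrinks to ~11%
print(bubble_fraction(2, 8))   # 0.111...
```

This is why production pipeline schedules push m well above p: the bubble vanishes as m grows, at the cost of holding activations for more in-flight micro-batches.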
Practical Recommendations
For fine-tuning a 7B model:
# Option 1: QLoRA on a single GPU (~8GB VRAM needed)
model = load_in_4bit(model_name)  # pseudocode: 4-bit quantized load
model = apply_lora(model, r=16)   # pseudocode: attach LoRA adapters
# Train normally

# Option 2: DDP with 4 GPUs (24GB each)
model = load_model(model_name, dtype=torch.bfloat16)
model = DDP(model, device_ids=[rank])
# ~4x speedup

# Option 3: FSDP with 2 GPUs (40GB each)
model = load_model(model_name, dtype=torch.bfloat16)
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
# Fits on 2x A100 40GB
Debugging Distributed Training
# Check gradient synchronization is working.
# NOTE: collectives must be called by ALL ranks; calling all_reduce inside
# `if rank == 0` deadlocks. After DDP's backward(), gradients are already
# averaged, so identical norms across ranks indicate sync is working.
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"[rank {rank}] {name}: grad_norm={param.grad.norm():.4f}")
# Print from a specific rank only
def print_rank(msg, rank=0, current_rank=None):
    if current_rank == rank:
        print(f"[Rank {current_rank}] {msg}")
# Check NCCL is working (should return quickly; a hang points to network or NCCL config)
dist.all_reduce(torch.tensor(1.0).to(rank))
For serving distributed models in production, see our LLM inference and optimization guide.