Checkpointing a 70B model takes 20-30 minutes and consumes 280GB of disk — you can only checkpoint every few hours

You are training a 70B-parameter model across 32 GPUs. A full checkpoint includes model weights (140GB in fp16), optimizer states (140GB; Adam stores two states per parameter), the learning-rate scheduler state, the data-loader position, and RNG state, for a total of 280-300GB. Writing 280GB to network storage takes 15-30 minutes depending on I/O bandwidth, and during that time all 32 GPUs sit idle; no training happens. If you checkpoint after every 30 minutes of training, you lose 15-30 minutes per checkpoint, so 33-50% of wall-clock time goes to saving. If you checkpoint every 3 hours, you lose only 8-15% to checkpointing but risk losing up to 3 hours of compute if the run crashes.

So what? Checkpointing frequency is a forced trade-off between safety (frequent checkpoints mean less lost compute on a crash) and efficiency (checkpoints waste training time). At $80/hour for 32 GPUs, every 20-minute checkpoint costs $27 in idle compute. Checkpointing every 30 minutes across a 7-day training run wastes roughly $5,400, about 40% of the run's $13,440 GPU bill. But not checkpointing risks $80 × 3 hours = $240 per crash, and crashes happen 1-2 times per day at scale.

Why does this persist? Asynchronous checkpointing (save to local SSD without pausing training, then copy to network storage in the background) exists in some frameworks (DeepSpeed, Nebula) but is not the default in PyTorch. Incremental checkpointing (saving only changed parameters) would shrink checkpoint size, but optimizer states change entirely every step. Checkpoint compression exists but adds CPU overhead.
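To make the trade-off concrete, here is a back-of-the-envelope model using the figures above ($80/hour for the cluster, a 20-minute checkpoint, 1-2 crashes per day). The function name and the assumption that a crash loses on average half a checkpoint interval are illustrative, not taken from any framework:

```python
def expected_daily_cost(interval_hours,
                        checkpoint_minutes=20,
                        cluster_cost_per_hour=80,
                        crashes_per_day=1.5):
    """Expected $/day spent on checkpoint stalls plus compute lost to crashes."""
    # One cycle = interval of training + the blocking checkpoint write.
    checkpoints_per_day = 24 / (interval_hours + checkpoint_minutes / 60)
    stall_hours = checkpoints_per_day * checkpoint_minutes / 60   # GPUs idle while saving
    lost_hours = crashes_per_day * interval_hours / 2             # on average half an interval is lost per crash
    return {
        "checkpoint_stall_$": round(stall_hours * cluster_cost_per_hour),
        "expected_crash_loss_$": round(lost_hours * cluster_cost_per_hour),
        "total_$": round((stall_hours + lost_hours) * cluster_cost_per_hour),
    }

for h in (0.5, 1, 3, 6):
    print(f"checkpoint every {h}h of training -> {expected_daily_cost(h)}")
```

At a 30-minute interval this reproduces the roughly $5,400 of checkpoint stall over 7 days quoted above; at a 3-hour interval the stall cost drops but the expected crash loss grows, which is exactly the forced trade-off.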

Evidence

Llama 2 70B checkpoint size: ~280GB (model + optimizer). NVMe SSD write speed: 3-7 GB/s (roughly 40-95 seconds for 280GB). Network storage write speed: 1-3 GB/s (1.5-5 minutes for 280GB). DeepSpeed ZeRO-3 supports async checkpointing, but it is not enabled by default. Meta's Llama 3 training report mentioned significant checkpoint overhead. PyTorch's native checkpointing (torch.save) is synchronous and blocking.
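The gap between NVMe and network write speed is what makes the asynchronous pattern attractive: block only for the fast local write, then push the copy to network storage off the training path. Below is a minimal sketch of that pattern in plain PyTorch; it is not how DeepSpeed or Nebula implement it, and the paths, function name, and directory layout are placeholders:

```python
import shutil
import threading
from pathlib import Path

import torch

def save_checkpoint_async(state: dict, step: int,
                          local_dir="/local_nvme/ckpts",    # placeholder fast local path
                          remote_dir="/mnt/shared/ckpts"):  # placeholder network storage path
    """Block only for the fast local write; copy to network storage in the background."""
    local_path = Path(local_dir) / f"step_{step}.pt"
    remote_path = Path(remote_dir) / f"step_{step}.pt"
    local_path.parent.mkdir(parents=True, exist_ok=True)

    # Blocking, but at NVMe speeds (3-7 GB/s) this is seconds to a minute, not 15-30 minutes.
    torch.save(state, local_path)

    def _upload():
        remote_path.parent.mkdir(parents=True, exist_ok=True)
        # Slow network copy happens off the training path.
        shutil.copyfile(local_path, remote_path)

    threading.Thread(target=_upload, daemon=True).start()
    # Training can resume immediately; the copy proceeds in the background.

# Illustrative usage:
# state = {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": step}
# save_checkpoint_async(state, step)
```

A real implementation would track the upload threads, verify the remote copy (for example with a checksum) before treating it as a valid restart point, and prune old local copies so the NVMe does not fill up.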

Comments