Distributed Training with PyTorch DDP | Generated by AI


Overview of Karpathy’s Explanation

In his YouTube video “Let’s reproduce GPT-2 (124M)” (starting around the 2-hour 58-minute mark), Andrej Karpathy dives into distributed training for large language models using PyTorch. He focuses on scaling a single-GPU training loop across multiple GPUs (e.g., 8) with only minor code changes. The core idea is to run an identical process on each GPU, each handling a non-overlapping shard of the data, and to synchronize by averaging gradients. This yields an effective batch size equal to the per-GPU batch size multiplied by the number of GPUs. He emphasizes simplicity, using PyTorch’s DistributedDataParallel (DDP) module, and walks through code examples for a GPT-2-like model.

Karpathy uses an analogy: Imagine a team of researchers (GPUs) independently analyzing different parts of a dataset but periodically “comparing notes” (gradient averaging) to align on the solution. He launches training with torchrun --standalone --nproc_per_node=8 train_gpt2.py, which automatically sets environment variables like RANK, WORLD_SIZE, and LOCAL_RANK.
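
A minimal sketch of the setup code that consumes those environment variables; the variable names follow a common convention and are not copied verbatim from the video’s train_gpt2.py:

```python
# Minimal DDP setup under torchrun: read RANK / LOCAL_RANK / WORLD_SIZE and pick a device.
import os
import torch
import torch.distributed as dist

ddp = int(os.environ.get("RANK", -1)) != -1  # true when launched via torchrun
if ddp:
    dist.init_process_group(backend="nccl")           # one process per GPU
    ddp_rank = int(os.environ["RANK"])                # global process id, 0..world_size-1
    ddp_local_rank = int(os.environ["LOCAL_RANK"])    # GPU index on this node
    ddp_world_size = int(os.environ["WORLD_SIZE"])    # total number of processes
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0                    # rank 0 does the logging/printing
else:
    ddp_rank, ddp_local_rank, ddp_world_size = 0, 0, 1
    device = "cuda" if torch.cuda.is_available() else "cpu"
    master_process = True
```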

Distributed Training

Karpathy explains distributed training as a way to parallelize across GPUs while keeping the core forward/backward pass mostly unchanged.

He notes: “The forward is unchanged and backward is mostly unchanged and we’re tacking on this average.”

Distributed Data Parallel (DDP)

DDP is Karpathy’s go-to for multi-GPU training, preferred over the older DataParallel due to better handling of gradient sync and multi-node setups. Wrap the model like this: model = DDP(model, device_ids=[local_rank]).

Quote: “What DDP does for you is… once the backward pass is over, it will call what’s called all-reduce and it basically does an average across all the ranks of their gradients and then it will deposit that average on every single rank.”
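
To make the mechanism concrete, here is a hedged sketch of the wrap-and-step pattern, with a toy linear model and random tensors standing in for the GPT-2 model and real batches; it reuses device and ddp_local_rank from the setup sketch above:

```python
# Sketch: after wrapping in DDP, a normal backward() also averages gradients across ranks.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

model = nn.Linear(128, 128).to(device)             # toy stand-in for the GPT-2 model
model = DDP(model, device_ids=[ddp_local_rank])    # enables gradient all-reduce in backward()
raw_model = model.module                           # the unwrapped model, e.g. for checkpointing
optimizer = torch.optim.AdamW(raw_model.parameters(), lr=3e-4)

x = torch.randn(16, 128, device=device)            # this rank's shard of the batch
y = torch.randn(16, 128, device=device)

out = model(x)                                     # forward pass: unchanged from single-GPU code
loss = F.mse_loss(out, y)
optimizer.zero_grad()
loss.backward()                                    # DDP all-reduces (averages) the gradients here
optimizer.step()                                   # every rank applies the same averaged update
```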

For loss logging, each rank only sees its local loss, so reduce the scalars: sum them across ranks with all_reduce (op=ReduceOp.SUM), divide by world_size to get the average, and print the result on rank 0 only.
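
A sketch of that sum-then-divide recipe, reusing ddp, ddp_world_size, and master_process from the setup sketch (recent PyTorch also offers ReduceOp.AVG as a one-step alternative):

```python
# Sketch: average the per-rank losses so the logged number reflects the whole effective batch.
import torch.distributed as dist

loss_accum = loss.detach()                             # this rank's local scalar loss
if ddp:
    dist.all_reduce(loss_accum, op=dist.ReduceOp.SUM)  # every rank now holds the sum
    loss_accum = loss_accum / ddp_world_size           # divide to get the average
if master_process:
    print(f"loss: {loss_accum.item():.4f}")            # log once, from rank 0
```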

The Concept of Rank

Rank is the unique integer ID for each process (0 to world_size - 1, e.g., 0-7 for 8 GPUs), set via os.getenv('RANK'). It determines which shard of the data each process consumes, which GPU the process binds to (together with LOCAL_RANK), and whether the process acts as the “master” (rank 0) responsible for logging and printing.

In evaluation (e.g., on the HellaSwag dataset), each rank computes local counts (correct predictions, total examples), all-reduces the sums, and rank 0 computes and reports the accuracy.

Karpathy stresses deterministic seeding: all ranks see the same shuffled order, but each rank keeps only its own slice, so the shards do not overlap.
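
One way to realize that pattern, as a sketch rather than the video’s exact data loader: every rank builds the same permutation from the same seed, then keeps a strided, non-overlapping slice:

```python
# Sketch: same seed everywhere -> identical shuffle; rank-strided indexing -> disjoint shards.
import torch

num_examples = 10_000                      # illustrative dataset size
torch.manual_seed(1337)                    # identical seed on every rank
perm = torch.randperm(num_examples)        # identical permutation on every rank

# each rank keeps every world_size-th index, offset by its own rank
my_indices = perm[ddp_rank::ddp_world_size]
```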

Additional Notes on Evaluation and Implementation

Karpathy demonstrates distributed eval on a multiple-choice dataset (HellaSwag): shard the examples across ranks, compute local correct/total counts, all-reduce the counts, and print the accuracy on rank 0. For predictions, select the option whose completion has the lowest average loss (equivalently, the highest average token probability) per example, using padding and masking to batch options of different lengths.
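
A sketch of this eval loop; eval_examples, the example["label"] field, and the score_options helper (which would return the average loss for each candidate completion) are hypothetical stand-ins, not the video’s exact code:

```python
# Sketch: shard the eval set by rank, count locally, all-reduce, report on rank 0.
import torch
import torch.distributed as dist

num_correct = torch.zeros(1, device=device)
num_total = torch.zeros(1, device=device)

for i, example in enumerate(eval_examples):        # hypothetical list of multiple-choice examples
    if i % ddp_world_size != ddp_rank:             # each rank handles a disjoint subset
        continue
    option_losses = score_options(model, example)  # hypothetical: average loss per candidate option
    pred = option_losses.argmin().item()           # lowest average loss = predicted option
    num_correct += int(pred == example["label"])
    num_total += 1

if ddp:
    dist.all_reduce(num_correct, op=dist.ReduceOp.SUM)
    dist.all_reduce(num_total, op=dist.ReduceOp.SUM)

if master_process:
    accuracy = (num_correct / num_total).item()
    print(f"HellaSwag accuracy: {accuracy:.4f}")
```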

Practical tips: print and log only on rank 0 to avoid duplicated output; keep seeding deterministic so every rank agrees on the data order while consuming disjoint shards; and use all_reduce for any scalar you report (loss, accuracy counts) so it reflects all ranks rather than just one.

This section builds directly on prior single-GPU code, showing minimal changes for scaling.

Let’s reproduce GPT-2 (124M)

