Multi-GPU Training Setup

Training on multiple GPUs can dramatically reduce training time, but requires proper setup and understanding of different parallelism strategies.

Data Parallelism: most common approach

  • Same model replicated on each GPU
  • Different data batches per GPU
  • Gradients synchronized after backward pass
  • Linear scaling (ideally)

Best for: Most training scenarios, models that fit on single GPU
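
A minimal single-process sketch using torch.nn.DataParallel (the model and shapes here are illustrative; the DDP approach described next is generally preferred):

import torch
import torch.nn as nn

# Illustrative stand-in for your model
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))

# Replicates the model on every visible GPU; each input batch is split
# along dim 0 and scattered to the replicas, outputs are gathered on GPU 0
model = nn.DataParallel(model).cuda()

x = torch.randn(64, 512).cuda()
out = model(x)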

DistributedDataParallel (DDP): recommended over DP

  • More efficient than Data Parallel
  • Better multi-node support
  • Process per GPU
  • NCCL backend for communication

Best for: Any multi-GPU training

Model Parallelism: for very large models

  • Model split across GPUs
  • Different layers on different GPUs
  • Sequential execution
  • Communication overhead

Best for: Models too large for single GPU
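
A minimal sketch of naive model parallelism with an illustrative two-layer split, placing each half on its own GPU and copying activations between them:

import torch
import torch.nn as nn

class TwoGPUModel(nn.Module):
    """Illustrative model split across cuda:0 and cuda:1."""
    def __init__(self):
        super().__init__()
        self.part1 = nn.Sequential(nn.Linear(512, 1024), nn.ReLU()).to("cuda:0")
        self.part2 = nn.Linear(1024, 10).to("cuda:1")

    def forward(self, x):
        x = self.part1(x.to("cuda:0"))
        # Copying activations between GPUs is the communication
        # overhead mentioned above
        return self.part2(x.to("cuda:1"))

model = TwoGPUModel()
out = model(torch.randn(32, 512))  # loss.backward() works as usual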

Pipeline Parallelism: advanced

  • Model split into stages
  • Micro-batching
  • Overlapped computation
  • Complex to implement

Best for: Very large models with many layers
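
A rough sketch of the micro-batching idea, hand-rolled with an illustrative two-stage split (real pipeline schedules overlap the stages instead of running them back to back):

import torch
import torch.nn as nn

stage1 = nn.Sequential(nn.Linear(512, 1024), nn.ReLU()).to("cuda:0")
stage2 = nn.Linear(1024, 10).to("cuda:1")

batch = torch.randn(64, 512)
outputs = []
for micro_batch in batch.chunk(4):  # split the batch into 4 micro-batches
    # A real pipeline feeds the next micro-batch into stage1 while stage2
    # is still busy with the previous one (overlapped computation)
    act = stage1(micro_batch.to("cuda:0"))
    outputs.append(stage2(act.to("cuda:1")))

out = torch.cat(outputs)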

Critical: Use identical GPUs

✓ Good: 4x RTX 4090
✓ Good: 2x A100 80GB
✗ Bad: 2x RTX 4090 + 1x RTX 4080
✗ Bad: 1x A100 + 1x A6000

Why identical GPUs?

  • Mixed GPUs run at the speed of the slowest card
  • Memory differences force the per-GPU batch down to the smallest card
  • Different architectures can hit driver and compute-capability quirks
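
A quick way to confirm what PyTorch sees per device; names and memory sizes should be identical across all cards:

import torch

for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    # Mismatched names or memory sizes indicate a mixed-GPU setup
    print(f"GPU {i}: {props.name}, {props.total_memory / 1024**3:.1f} GB")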

Check PCIe lanes:

# See PCIe generation and lanes
lspci -vv | grep -i "lnkcap\|lnksta"

# Ideal for multi-GPU:
# 2 GPUs: x16/x16 or x16/x8
# 3 GPUs: x16/x8/x8
# 4 GPUs: x8/x8/x8/x8

PCIe Gen 4 x8 vs Gen 3 x16:

  • Gen 4 x8: ~16 GB/s
  • Gen 3 x16: ~16 GB/s
  • Both adequate for most training

NVLink, if available (A100, H100, etc.):

# Check NVLink status
nvidia-smi nvlink --status

# Should show connected topology

NVLink benefits:

  • Roughly an order of magnitude more GPU-to-GPU bandwidth than PCIe
  • Faster gradient synchronization and model/pipeline-parallel transfers
  • Better scaling efficiency
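
Peer-to-peer access between GPU pairs, which NVLink accelerates, can also be checked from Python:

import torch

n = torch.cuda.device_count()
for i in range(n):
    for j in range(n):
        if i == j:
            continue
        # True means GPU i can access GPU j's memory directly,
        # without staging transfers through host RAM
        ok = torch.cuda.can_device_access_peer(i, j)
        print(f"GPU {i} -> GPU {j}: P2P {'enabled' if ok else 'disabled'}")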

System RAM rule of thumb: 16GB per GPU

2 GPUs: 32GB minimum
4 GPUs: 64GB minimum
8 GPUs: 128GB minimum

Step 1: Basic Training Script

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# YourModel, train_dataset, batch_size and num_epochs below are
# placeholders for your own model, dataset and hyperparameters

def setup(rank, world_size):
    """Initialize distributed training"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Initialize process group and bind this process to its own GPU
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    """Clean up distributed training"""
    dist.destroy_process_group()

def train(rank, world_size):
    """Training function for each GPU"""
    # Setup
    setup(rank, world_size)

    # Create model and move to GPU
    model = YourModel().to(rank)
    model = DDP(model, device_ids=[rank])

    # Create distributed sampler
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()  # swap in the loss for your task

    # Training loop
    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)  # Important for shuffling

        for batch in train_loader:
            data, target = batch
            data, target = data.to(rank), target.to(rank)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

    cleanup()

def main():
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()

Step 2: Launch Training

# With mp.spawn (as in the script above)
python train_script.py

# With torchrun (the launcher creates one process per GPU; see the sketch below)
torchrun --nproc_per_node=4 train_script.py

# Multi-node
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
  --master_addr="192.168.1.1" --master_port=12355 train_script.py
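
The script in Step 1 creates its own processes with mp.spawn, which fits the plain python launch. When launching with torchrun, the launcher creates one process per GPU and exports RANK, LOCAL_RANK, and WORLD_SIZE, so setup would instead read them from the environment; a minimal sketch:

import os
import torch
import torch.distributed as dist

def setup_from_env():
    """Initialize the process group when launched via torchrun."""
    # With the default env:// init method, rank, world size, and the
    # master address/port are read from the environment variables
    # that torchrun exports
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank

# local_rank = setup_from_env()
# model = DDP(model.to(local_rank), device_ids=[local_rank])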

Effective batch size (linear scaling rule):

# Single GPU
batch_size_per_gpu = 32

# 4 GPUs
total_batch_size = 32 * 4  # = 128

Adjust learning rate:

import math

# Linear scaling (most common)
lr_single_gpu = 1e-3
lr_multi_gpu = 1e-3 * num_gpus

# Square root scaling (sometimes better)
lr_multi_gpu = 1e-3 * math.sqrt(num_gpus)

For even larger effective batch sizes:

accumulation_steps = 4  # Accumulate gradients over 4 batches
effective_batch = batch_size * num_gpus * accumulation_steps

for i, batch in enumerate(train_loader):
    data, target = batch
    data, target = data.to(rank), target.to(rank)

    output = model(data)
    loss = criterion(output, target) / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
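
With DDP, every backward() triggers a gradient all-reduce, which is wasted work on the intermediate accumulation steps; DDP's no_sync() context manager defers the synchronization until the step that actually updates the weights. A sketch reusing the names from the loop above:

for i, batch in enumerate(train_loader):
    data, target = batch
    data, target = data.to(rank), target.to(rank)

    if (i + 1) % accumulation_steps == 0:
        # Final micro-batch: gradients are all-reduced across GPUs
        loss = criterion(model(data), target) / accumulation_steps
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    else:
        # Intermediate micro-batches: skip the all-reduce
        with model.no_sync():
            loss = criterion(model(data), target) / accumulation_steps
            loss.backward()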

Gradient bucketing:

# PyTorch DDP automatically buckets gradients
model = DDP(
    model,
    device_ids=[rank],
    bucket_cap_mb=25,  # Default, tune if needed
    find_unused_parameters=False  # Set True only if necessary
)

NCCL tuning:

# Environment variables for better performance
export NCCL_DEBUG=INFO  # For debugging
export NCCL_IB_DISABLE=0  # Enable InfiniBand if available
export NCCL_SOCKET_IFNAME=eth0  # Specify network interface

Monitor GPU utilization:

# Watch all GPUs
watch -n 1 nvidia-smi

# Or use nvtop for better visualization
nvtop

What to look for:

  • All GPUs at similar utilization (80-100%)
  • Balanced memory usage across GPUs (see the sketch after this list)
  • Minimal GPU-Util fluctuation
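
Memory balance can also be logged from inside the training loop with torch.cuda's memory counters; a small sketch, using rank as the process's GPU index as in the script above:

import torch

def log_gpu_memory(rank, step):
    """Print this rank's allocated and reserved GPU memory."""
    allocated = torch.cuda.memory_allocated(rank) / 1024**3
    reserved = torch.cuda.memory_reserved(rank) / 1024**3
    print(f"[rank {rank}] step {step}: "
          f"allocated {allocated:.2f} GB, reserved {reserved:.2f} GB")

# Call periodically, e.g.:
# if step % 100 == 0:
#     log_gpu_memory(rank, step)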

Ideal scaling:

1 GPU:  1x speed (baseline)
2 GPUs: 2x speed (100% efficiency)
4 GPUs: 4x speed (100% efficiency)

Real-world scaling:

1 GPU:  1x speed
2 GPUs: 1.9x speed (95% efficiency) ✓ Good
4 GPUs: 3.5x speed (88% efficiency) ✓ Acceptable
8 GPUs: 6.5x speed (81% efficiency) ✓ Acceptable

Measure it:

import time

# Single GPU
start = time.time()
train_one_epoch()
single_gpu_time = time.time() - start

# Multi GPU
start = time.time()
train_one_epoch()  # With DDP
multi_gpu_time = time.time() - start

speedup = single_gpu_time / multi_gpu_time
efficiency = speedup / num_gpus * 100

print(f"Speedup: {speedup:.2f}x")
print(f"Efficiency: {efficiency:.1f}%")
# Check all GPUs visible
python -c "import torch; print(torch.cuda.device_count())"

# Should match:
nvidia-smi -L

If mismatch:

# Check CUDA_VISIBLE_DEVICES
echo $CUDA_VISIBLE_DEVICES

# Clear it if set incorrectly
unset CUDA_VISIBLE_DEVICES

Uneven GPU utilization:

Symptoms:

  • GPU 0 at 100%, others at 60%
  • Different memory usage across GPUs

Causes:

  1. Data loading bottleneck
  2. Uneven data distribution
  3. Model on GPU 0 during evaluation

Solutions:

# Use distributed sampler
train_sampler = DistributedSampler(dataset, shuffle=True)

# Increase num_workers
DataLoader(dataset, num_workers=8)  # Higher

# Run evaluation on this rank's own GPU
model.eval()
with torch.no_grad():
    # Ensure data goes to correct GPU
    data = data.to(rank)

Surprising but common:

Cause: Larger effective batch size

Solution:

# Reduce per-GPU batch size
batch_size = 16  # Instead of 32

# Or use gradient accumulation
# (See above)

If scaling or throughput is worse than expected, check:

  1. PCIe bandwidth: nvidia-smi topo -m
  2. CPU bottleneck: htop during training
  3. Data loading: profile with the PyTorch Profiler (see the sketch after this list)
  4. Network (multi-node): check with iperf3
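
A minimal PyTorch Profiler sketch for the data-loading check; the log directory name and the five-step cutoff are arbitrary choices, and the trace can be viewed in TensorBoard:

import torch
from torch.profiler import profile, schedule, ProfilerActivity

def profile_training_steps(model, train_loader, optimizer, criterion, rank):
    """Profile a handful of training steps on this rank."""
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=schedule(wait=1, warmup=1, active=3),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler_logs"),
    ) as prof:
        for step, (data, target) in enumerate(train_loader):
            if step >= 5:
                break
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            prof.step()  # advance the wait/warmup/active schedule

    # Long gaps between GPU kernels point to a data-loading bottleneck
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))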

For models too large for single GPU:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Wrap model with FSDP
model = FSDP(
    model,
    device_id=rank,
    auto_wrap_policy=default_auto_wrap_policy,
    mixed_precision=mp_policy,
)

# Train normally
# Model automatically sharded across GPUs
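
The default_auto_wrap_policy and mp_policy above are not defined in the snippet; one common way to configure them, assuming a recent PyTorch with size-based wrapping and bfloat16 mixed precision:

import functools
import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap any submodule with more than ~1M parameters in its own FSDP unit
default_auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=1_000_000
)

# Keep parameters, gradient reduction, and buffers in bfloat16
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)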

When you need it:

  • More than 8 GPUs
  • Model doesn’t fit on single node
  • Distributed across cluster

See: HPC Integration for SLURM-based multi-node training

import torch
import time
from torch.utils.data import DataLoader, TensorDataset

def benchmark_multi_gpu(model, batch_size, num_iterations=100):
    """Benchmark multi-GPU training speed"""

    # Dummy data
    dummy_data = torch.randn(1000, 3, 224, 224)
    dummy_labels = torch.randint(0, 1000, (1000,))
    dataset = TensorDataset(dummy_data, dummy_labels)

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    # Warmup
    for i, (data, target) in enumerate(loader):
        if i >= 10:
            break
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    # Benchmark
    torch.cuda.synchronize()
    start = time.time()

    iters_run = 0
    for i, (data, target) in enumerate(loader):
        if i >= num_iterations:
            break
        data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        iters_run += 1

    torch.cuda.synchronize()
    elapsed = time.time() - start

    # Use the number of iterations actually run (the dummy dataset may
    # contain fewer than num_iterations batches)
    throughput = (iters_run * batch_size) / elapsed
    print(f"Throughput: {throughput:.2f} images/sec")
    print(f"Time per iteration: {elapsed/iters_run*1000:.2f} ms")

# Usage
# benchmark_multi_gpu(model, batch_size=32)