GPU Memory Management
Understanding GPU Memory
GPU memory is one of the most precious resources in deep learning. Running out of memory (OOM errors) is frustrating, but understanding how memory is allocated helps you optimize usage.
Memory Breakdown
Your GPU memory is used by:

- Model Parameters (typically 10-40%)
  - Weights and biases
  - Relatively fixed size
- Gradients (10-40%)
  - Same size as the parameters during training
  - Freed after the optimizer step
- Activations (20-60%)
  - Intermediate layer outputs
  - Grow with batch size
  - Usually the biggest memory consumer
- Optimizer States (varies)
  - Adam: 2x parameter size
  - SGD with momentum: 1x parameter size
  - Plain SGD: minimal overhead
- CUDA Context (~500 MB-2 GB)
  - PyTorch/framework overhead
  - Largely fixed
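Before tuning anything, it helps to estimate how much of this budget is fixed. The sketch below adds up the parameter, gradient, and optimizer-state footprint; activations are excluded because they depend on batch size and architecture, and the helper name and the 350M-parameter figure are illustrative assumptions.

def estimate_training_memory_gb(num_params, bytes_per_param=4, optimizer="adam"):
    """Rough lower bound: parameters + gradients + optimizer states (no activations)."""
    params = num_params * bytes_per_param
    grads = num_params * bytes_per_param          # gradients mirror the parameters
    multiplier = {"adam": 2, "sgd_momentum": 1, "sgd": 0}[optimizer]
    opt_states = num_params * 4 * multiplier      # optimizer states are typically kept in FP32
    return (params + grads + opt_states) / 1024**3

# Example: a hypothetical 350M-parameter model trained in FP32 with Adam
print(f"{estimate_training_memory_gb(350e6):.1f} GB before activations and CUDA context")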
Quick Diagnostics
Check Current Memory Usage

PyTorch:

import torch

def print_gpu_memory():
    """Print current GPU memory usage"""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        max_allocated = torch.cuda.max_memory_allocated() / 1024**3
        print(f"Allocated: {allocated:.2f} GB")
        print(f"Reserved: {reserved:.2f} GB")
        print(f"Max Allocated: {max_allocated:.2f} GB")

        # Get total GPU memory
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"Total GPU Memory: {total:.2f} GB")
        print(f"Utilization: {allocated/total*100:.1f}%")

print_gpu_memory()

TensorFlow:

import tensorflow as tf

def print_gpu_memory():
    """Print current GPU memory usage"""
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        # Get memory info for the first GPU
        gpu = gpus[0]
        memory_info = tf.config.experimental.get_memory_info(gpu.name.replace('/physical_device:', ''))
        allocated = memory_info['current'] / 1024**3
        peak = memory_info['peak'] / 1024**3
        print(f"Current Allocated: {allocated:.2f} GB")
        print(f"Peak Allocated: {peak:.2f} GB")
        # TensorFlow allocates memory dynamically, so the total may not be immediately available
        print("Note: TensorFlow uses dynamic memory allocation")

print_gpu_memory()

Monitor During Training
# Add memory logging to your training loop
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()

        # Your training code here
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Log memory every 100 batches
        if batch_idx % 100 == 0:
            print_gpu_memory()

Out of Memory (OOM) Solutions
1. Reduce Batch Size (Easiest)
# Start large and reduce until it fits
from torch.utils.data import DataLoader

batch_sizes = [128, 64, 32, 16, 8]

for bs in batch_sizes:
    try:
        train_loader = DataLoader(dataset, batch_size=bs)
        model.train()

        # Test one batch
        data, target = next(iter(train_loader))
        output = model(data.cuda())
        loss = criterion(output, target.cuda())
        loss.backward()

        print(f"Batch size {bs} works!")
        break
    except RuntimeError as e:
        if 'out of memory' in str(e):
            torch.cuda.empty_cache()
            print(f"Batch size {bs} - OOM")
            continue
        else:
            raise

2. Gradient Accumulation
Simulate larger batch sizes without using more memory:
# Effective batch size = batch_size * accumulation_steps
accumulation_steps = 4

optimizer.zero_grad()
for i, (data, target) in enumerate(train_loader):
    data, target = data.cuda(), target.cuda()

    # Forward pass
    output = model(data)
    loss = criterion(output, target)

    # Normalize the loss to account for accumulation
    loss = loss / accumulation_steps
    loss.backward()

    # Update weights every accumulation_steps batches
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

3. Mixed Precision Training
Reduces memory usage by ~40-50%:
PyTorch:

from torch.cuda.amp import autocast, GradScaler

model = model.cuda()
scaler = GradScaler()

for epoch in range(epochs):
    for data, target in train_loader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()

        # Run the forward pass in mixed precision
        with autocast():
            output = model(data)
            loss = criterion(output, target)

        # Backward pass with loss scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

TensorFlow:

import tensorflow as tf

# Enable mixed precision globally
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

# Build and compile the model
model = create_model()
optimizer = tf.keras.optimizers.Adam()

# Keras wraps the optimizer with loss scaling automatically under this policy
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train normally - mixed precision is handled automatically
model.fit(train_dataset, epochs=epochs)

Memory savings:
- FP16 uses 2 bytes vs FP32’s 4 bytes
- Can often double batch size
- Minimal accuracy impact for most models
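To see the per-element savings directly, you can check the element size of each dtype; a small sketch using PyTorch's Tensor.element_size():

import torch

# Bytes per element for common dtypes
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(dtype, torch.tensor([], dtype=dtype).element_size(), "bytes per element")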
4. Gradient Checkpointing
Trade compute for memory by recomputing activations:
PyTorch:

from torch.utils.checkpoint import checkpoint
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1000, 1000)
        self.layer2 = nn.Linear(1000, 1000)
        self.layer3 = nn.Linear(1000, 10)

    def forward(self, x):
        # Use checkpointing for memory-intensive layers
        x = checkpoint(self.layer1, x, use_reentrant=False)
        x = checkpoint(self.layer2, x, use_reentrant=False)
        x = self.layer3(x)
        return x

TensorFlow:

import tensorflow as tf

# TensorFlow provides gradient checkpointing via tf.recompute_grad
class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.layer1 = tf.keras.layers.Dense(1000)
        self.layer2 = tf.keras.layers.Dense(1000)
        self.layer3 = tf.keras.layers.Dense(10)
        # Wrap the memory-intensive block so its activations are recomputed during backprop
        self._checkpointed_block = tf.recompute_grad(self._block)

    def _block(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

    def call(self, x, training=False):
        x = self._checkpointed_block(x)
        x = self.layer3(x)
        return x

# Or use gradient_checkpointing in transformers
# from transformers import TFAutoModel
# model = TFAutoModel.from_pretrained("model_name", gradient_checkpointing=True)

Trade-off:
- Saves ~40-50% memory on activations
- Increases training time by ~20-30%
- Best for very deep networks
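For models built as an nn.Sequential stack, PyTorch also offers checkpoint_sequential, which splits the stack into segments and recomputes the activations inside each segment during the backward pass. A minimal sketch, assuming a recent PyTorch where checkpoint_sequential accepts use_reentrant:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(*[nn.Linear(1000, 1000) for _ in range(10)])
x = torch.randn(32, 1000, requires_grad=True)

# Split the 10 layers into 2 segments; activations inside each segment are recomputed on backward
out = checkpoint_sequential(model, 2, x, use_reentrant=False)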
5. Clear Unused Variables
# Free memory explicitly
del large_tensor
torch.cuda.empty_cache()

# Don't keep unnecessary computation graphs
with torch.no_grad():
    # Operations here won't build a computation graph
    validation_output = model(val_data)

# Detach tensors you don't need gradients for
prediction = output.detach()

6. Optimize Model Size
# Use smaller models
import torchvision

model = torchvision.models.resnet18()  # Instead of resnet152

# Reduce hidden dimensions (illustrative transformer configuration)
model = Transformer(
    d_model=512,    # Instead of 1024
    nhead=8,        # Instead of 16
    num_layers=6    # Instead of 12
)

# Use depthwise separable convolutions
from torch.nn import Conv2d

# Standard conv
conv = Conv2d(256, 256, kernel_size=3, padding=1)

# Depthwise separable (fewer parameters)
depthwise = Conv2d(256, 256, kernel_size=3, padding=1, groups=256)
pointwise = Conv2d(256, 256, kernel_size=1)

Memory Optimization Strategies
Strategy 1: Multi-GPU Training
Distribute the model across GPUs:
# DataParallel (simple but not optimal)
model = nn.DataParallel(model)

# DistributedDataParallel (recommended)
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize the process group
dist.init_process_group(backend='nccl')

# Wrap the model
model = model.cuda()
model = DDP(model, device_ids=[local_rank])

Strategy 2: CPU Offloading
Move some operations to the CPU:
# Keep the model on the GPU, but compute some metrics on the CPU
model.cuda()

for data, target in train_loader:
    data, target = data.cuda(), target.cuda()

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    # Compute accuracy on the CPU to save GPU memory
    pred = output.argmax(dim=1).cpu()
    target_cpu = target.cpu()
    accuracy = (pred == target_cpu).float().mean()

Strategy 3: In-Place Operations
Reduce memory by modifying tensors in-place:
# Instead of:
x = x + 1
x = torch.relu(x)

# Use in-place operations:
x += 1              # or x.add_(1)
x = torch.relu_(x)  # or nn.ReLU(inplace=True)

# In models:
self.relu = nn.ReLU(inplace=True)

Memory Profiling
PyTorch Memory Profiler
from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True
) as prof:
    # Your training code
    output = model(data)
    loss = criterion(output, target)
    loss.backward()

# Print memory stats
print(prof.key_averages().table(
    sort_by="cuda_memory_usage",
    row_limit=10
))

# Export for visualization (e.g. in chrome://tracing)
prof.export_chrome_trace("trace.json")

Memory Snapshot
import torch

# Start recording memory allocation history
torch.cuda.memory._record_memory_history()

try:
    # Run your code
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
except RuntimeError as e:
    if 'out of memory' in str(e):
        # Dump a memory snapshot for analysis
        torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")
finally:
    # Stop recording
    torch.cuda.memory._record_memory_history(enabled=None)

Best Practices Checklist
- Start with these settings:

  torch.backends.cudnn.benchmark = True         # Faster training
  torch.backends.cuda.matmul.allow_tf32 = True  # Faster matmul

- Always use mixed precision (easy 2x improvement)
- Monitor memory throughout training (catch leaks early)
- Use gradient checkpointing for very deep models
- Clear cache between experiments: torch.cuda.empty_cache()
- Test batch size before long training runs
- Use torch.no_grad() during validation/inference:

  with torch.no_grad():
      val_loss = validate(model, val_loader)
Common OOM Causes & Fixes
| Problem | Solution |
|---|---|
| OOM during first forward pass | Reduce batch size or model size |
| OOM during backward pass | Use gradient checkpointing |
| OOM increases over time | Memory leak - check for tensors kept in lists |
| OOM only on large inputs | Implement dynamic batching by input size |
| OOM after many epochs | Clear unused cache, check for accumulating metrics |
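The "OOM increases over time" row is most often caused by storing tensors that are still attached to the computation graph, for example when accumulating losses for logging. A minimal sketch of the leak and its fix (variable names are illustrative):

losses = []
for data, target in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(data.cuda()), target.cuda())
    loss.backward()
    optimizer.step()

    losses.append(loss)           # Leak: keeps each batch's computation graph alive
    # losses.append(loss.item())  # Fix: store a plain Python float instead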
Quick Reference: Memory Reduction Techniques
Ordered by effectiveness vs effort:
- Mixed Precision Training - roughly halves memory, ~5 lines of code ⭐
- Reduce Batch Size - Variable memory, 1 line of code
- Gradient Accumulation - Simulate larger batches, 10 lines of code
- Gradient Checkpointing - 30-50% memory, 10-20% slower
- Model Optimization - Variable, requires architecture changes
- Multi-GPU Training - Linear scaling, requires multiple GPUs