
Training Utilities

Essential utilities for managing deep learning training:

  • Kill stuck or runaway processes
  • Resume interrupted training
  • Monitor training progress
  • Manage distributed training
  • Clean up GPU memory

kill_gpu_process.sh - Kill processes on specific GPU
#!/bin/bash
# Kill all processes running on GPU 0

GPU_ID=0

# Get PIDs of processes on GPU
PIDS=$(nvidia-smi --query-compute-apps=pid --format=csv,noheader -i $GPU_ID)

if [ -z "$PIDS" ]; then
    echo "No processes running on GPU $GPU_ID"
else
    echo "Killing processes on GPU $GPU_ID: $PIDS"
    for PID in $PIDS; do
        kill -9 $PID
    done
    echo "Done"
fi
cleanup_zombies.sh
#!/bin/bash
# Clean up zombie Python processes
# (a <defunct> process is already dead; it lingers until its parent reaps it,
# so signal the parent rather than the zombie itself)

ps aux | grep -E 'python.*<defunct>' | grep -v grep | awk '{print $2}' | while read pid; do
    ppid=$(ps -o ppid= -p "$pid" | tr -d ' ')
    echo "Zombie $pid found; signaling parent $ppid"
    kill -s SIGCHLD "$ppid"      # ask the parent to reap the zombie
    # If the parent is itself stuck, kill it instead: kill -9 "$ppid"
done

# Reset the GPU to reclaim memory (requires root; no processes may be using the GPU)
nvidia-smi --gpu-reset

resume_training.py - Resume from checkpoint
#!/usr/bin/env python3
import torch
import os
from pathlib import Path

def find_latest_checkpoint(checkpoint_dir):
    """Find the most recent checkpoint"""
    checkpoints = list(Path(checkpoint_dir).glob('checkpoint_epoch_*.pt'))
    if not checkpoints:
        return None
    return max(checkpoints, key=os.path.getctime)

def resume_training(checkpoint_path, model, optimizer, scheduler=None):
    """Resume training from checkpoint"""
    print(f"Resuming from {checkpoint_path}")

    checkpoint = torch.load(checkpoint_path)

    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])

    # Load optimizer state
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Load scheduler if available
    if scheduler and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Get training state
    start_epoch = checkpoint['epoch'] + 1
    best_loss = checkpoint.get('best_loss', float('inf'))

    print(f"Resuming from epoch {start_epoch}, best loss: {best_loss:.4f}")

    return start_epoch, best_loss

# Usage in training script
checkpoint_dir = './checkpoints'
latest_ckpt = find_latest_checkpoint(checkpoint_dir)

if latest_ckpt:
    start_epoch, best_loss = resume_training(
        latest_ckpt, model, optimizer, scheduler
    )
else:
    start_epoch = 0
    best_loss = float('inf')

# Continue training from start_epoch
for epoch in range(start_epoch, num_epochs):
    # Training loop
    pass
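
The resume helpers above assume each checkpoint file stores the model, optimizer, and training state under specific keys. A minimal saving counterpart that writes those keys might look like the following sketch (the save_checkpoint helper and its arguments are illustrative, not part of an existing script):
import torch
from pathlib import Path

def save_checkpoint(checkpoint_dir, epoch, model, optimizer, scheduler=None, best_loss=None):
    """Write a checkpoint with the keys resume_training() expects."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    if scheduler is not None:
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    if best_loss is not None:
        checkpoint['best_loss'] = best_loss

    path = Path(checkpoint_dir) / f'checkpoint_epoch_{epoch}.pt'
    torch.save(checkpoint, path)
    print(f"Saved checkpoint to {path}")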
auto_resume_training.sh - Automatically resume if interrupted
#!/bin/bash
# Automatically resume training if it crashes or is interrupted

SCRIPT="train.py"
CHECKPOINT_DIR="./checkpoints"
MAX_RETRIES=3
RETRY_COUNT=0

while [ $RETRY_COUNT -lt $MAX_RETRIES ]; do
    echo "Starting training (attempt $((RETRY_COUNT + 1))/$MAX_RETRIES)..."

    # Run training script
    python $SCRIPT --checkpoint_dir "$CHECKPOINT_DIR" --resume

    EXIT_CODE=$?

    if [ $EXIT_CODE -eq 0 ]; then
        echo "Training completed successfully!"
        exit 0
    else
        echo "Training failed with exit code $EXIT_CODE"
        RETRY_COUNT=$((RETRY_COUNT + 1))

        if [ $RETRY_COUNT -lt $MAX_RETRIES ]; then
            echo "Waiting 10 seconds before retry..."
            sleep 10
        fi
    fi
done

echo "Training failed after $MAX_RETRIES attempts"
exit 1

watch_logs.sh - Monitor training progress
#!/bin/bash
# Watch training logs with live updates

LOG_FILE="training.log"

# Follow log file
tail -f "$LOG_FILE" | grep --line-buffered -E "Epoch|Loss|Accuracy"

# Or with color highlighting
tail -f "$LOG_FILE" | \
    grep --line-buffered --color=always -E "Epoch.*|Loss.*|$"
plot_training.py - Plot loss from logs
#!/usr/bin/env python3
import re
import matplotlib.pyplot as plt

def parse_training_log(log_file):
    """Extract epoch, loss, accuracy from log file"""
    epochs = []
    train_losses = []
    val_losses = []

    with open(log_file) as f:
        for line in f:
            # Example: Epoch 10 - Train Loss: 0.234, Val Loss: 0.189
            match = re.search(r'Epoch (\d+).*Train Loss: ([\d.]+).*Val Loss: ([\d.]+)', line)
            if match:
                epochs.append(int(match.group(1)))
                train_losses.append(float(match.group(2)))
                val_losses.append(float(match.group(3)))

    return epochs, train_losses, val_losses

def plot_training_curves(log_file, output='training_curves.png'):
    """Plot training curves from log file"""
    epochs, train_losses, val_losses = parse_training_log(log_file)

    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses, label='Train Loss', linewidth=2)
    plt.plot(epochs, val_losses, label='Val Loss', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Progress')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(output, dpi=150, bbox_inches='tight')
    print(f"Saved plot to {output}")

if __name__ == "__main__":
    plot_training_curves('training.log')

launch_distributed.sh - PyTorch DistributedDataParallel
#!/bin/bash
# Launch distributed training on multiple GPUs

# Single node, multiple GPUs
# (torch.distributed.launch is deprecated in recent PyTorch releases; prefer torchrun below)
python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --master_port=29500 \
    train.py \
    --distributed

# Or using torchrun (recommended for PyTorch 1.10+)
torchrun \
    --standalone \
    --nnodes=1 \
    --nproc_per_node=4 \
    train.py
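
Both launchers expect train.py to initialize the process group itself. A minimal sketch of that setup, assuming torchrun-style environment variables (LOCAL_RANK is set per worker; MyModel is a placeholder class):
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK for every worker
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = MyModel().cuda(local_rank)            # placeholder model class
model = DDP(model, device_ids=[local_rank])

# ... training loop ...

dist.destroy_process_group()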
check_distributed.sh
#!/bin/bash
# Check status of distributed training

echo "=== GPU Processes Per Node ==="
nvidia-smi --query-compute-apps=pid,process_name,used_memory --format=csv

echo ""
echo "=== Network Connections (DDP) ==="
netstat -tuln | grep 29500

echo ""
echo "=== Process Tree ==="
ps auxf | grep -E 'python|train'

Limit memory per process
import torch

# Cap this process at 50% of GPU 0's total memory
torch.cuda.set_per_process_memory_fraction(0.5, device=0)

# Or cap at an absolute amount (8 GB here) by converting it to a fraction
torch.cuda.set_per_process_memory_fraction(
    8 * 1024**3 / torch.cuda.get_device_properties(0).total_memory,
    device=0
)
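
One way to sanity-check the cap while training is to print the allocator's counters, for example:
print(f"Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
print(f"Reserved:  {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")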
job_queue.sh - Run multiple training jobs sequentially
#!/bin/bash
# Queue training jobs to run one after another

JOBS=(
    "python train.py --model resnet50 --epochs 100"
    "python train.py --model efficientnet --epochs 100"
    "python train.py --model vit --epochs 100"
)

for JOB in "${JOBS[@]}"; do
    echo "Running: $JOB"
    eval "$JOB"

    if [ $? -ne 0 ]; then
        echo "Job failed: $JOB"
        exit 1
    fi

    # Wait 10 seconds between jobs
    sleep 10
done

echo "All jobs completed!"

detect_deadlock.py - Find training deadlocks
#!/usr/bin/env python3
import torch
import signal
import sys

def timeout_handler(signum, frame):
    """Print stack trace on timeout"""
    print("\n=== TIMEOUT - Stack Trace ===")
    import traceback
    traceback.print_stack(frame)
    sys.exit(1)

# Set a timeout: dump a stack trace if no progress is made for 60 seconds
# (signal.alarm is Unix-only and must be set from the main thread)
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(60)  # 60 second timeout

# In the training loop, re-arm the alarm at the start of each step
for batch in dataloader:
    signal.alarm(60)  # Reset timeout

    # Training code (model, criterion, optimizer defined elsewhere in the script)
    inputs, targets = batch
    output = model(inputs)
    loss = criterion(output, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    signal.alarm(0)  # Cancel alarm once the step completes
find_memory_leak.py
#!/usr/bin/env python3
import gc
from collections import Counter

import torch

def find_tensor_leaks():
    """Find CUDA tensors that aren't being freed"""
    tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                tensors.append((type(obj).__name__, tuple(obj.size()), obj.element_size() * obj.nelement()))
        except Exception:
            # Some tracked objects raise when inspected; skip them
            pass

    # Group tensors by shape and count occurrences
    sizes = Counter(t[1] for t in tensors)

    print("=== Tensor Sizes in Memory ===")
    for size, count in sizes.most_common(10):
        print(f"{size}: {count} tensors")

# Call periodically during training
find_tensor_leaks()

Best Practices

  1. Save Checkpoints Frequently - Every epoch or every N steps
  2. Test Resume Logic - Verify training can resume correctly
  3. Monitor Resource Usage - Watch GPU memory and utilization
  4. Graceful Shutdown - Handle SIGTERM/SIGINT for clean exits (a sketch follows this list)
  5. Log Everything - Detailed logs help debug issues
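
A minimal sketch of such a shutdown handler, assuming the save_checkpoint helper sketched earlier and the start_epoch/num_epochs variables from the resume example (all illustrative names):
import signal
import sys

stop_requested = False

def handle_shutdown(signum, frame):
    """Ask the training loop to stop cleanly after the current epoch."""
    global stop_requested
    stop_requested = True
    print(f"Received signal {signum}; will checkpoint and exit after this epoch")

signal.signal(signal.SIGTERM, handle_shutdown)
signal.signal(signal.SIGINT, handle_shutdown)

for epoch in range(start_epoch, num_epochs):
    # ... training code ...
    save_checkpoint(checkpoint_dir, epoch, model, optimizer)  # illustrative helper
    if stop_requested:
        sys.exit(0)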