# Ray Train - Distributed Training Orchestration

## Quick start

Ray Train scales machine learning training from a single GPU to multi-node clusters with minimal code changes.

**Installation**:

```bash
pip install -U "ray[train]"
```

**Basic PyTorch training (single node)**:

```python
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
import torch
import torch.nn as nn

# Define training function
def train_func(config):
    # Your normal PyTorch code
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Prepare for distributed (Ray handles device placement)
    model = train.torch.prepare_model(model)

    for epoch in range(10):
        # Your training loop
        output = model(torch.randn(32, 10))
        loss = output.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Report metrics (logged automatically)
        train.report({"loss": loss.item(), "epoch": epoch})

# Run distributed training
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=4,  # 4 GPUs/workers
        use_gpu=True,
    ),
)
result = trainer.fit()
print(f"Final loss: {result.metrics['loss']}")
```

That's it! Ray handles:
- Distributed coordination
- GPU allocation
- Fault tolerance
- Checkpointing
- Metric aggregation

## Common workflows

### Workflow 1: Scale existing PyTorch code

**Original single-GPU code**:

```python
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(epochs):
    for batch in dataloader:
        loss = model(batch)
        loss.backward()
        optimizer.step()
```

**Ray Train version (scales to multi-GPU/multi-node)**:

```python
from ray.train.torch import TorchTrainer
from ray import train

def train_func(config):
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters())

    # Prepare for distributed (automatic device placement)
    model = train.torch.prepare_model(model)
    dataloader = train.torch.prepare_data_loader(dataloader)

    for epoch in range(epochs):
        for batch in dataloader:
            loss = model(batch)
            loss.backward()
            optimizer.step()

        # Report metrics
        train.report({"loss": loss.item()})

# Scale to 8 GPUs
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
)
trainer.fit()
```
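The same `train_func` also runs at laptop scale; only the `ScalingConfig` changes. A minimal sketch:

```python
# Same training function, single worker, CPU-only
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=1, use_gpu=False),
)
trainer.fit()
```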
**Benefits**: Same code runs on 1 GPU or 1,000 GPUs.

### Workflow 2: HuggingFace Transformers integration

```python
from ray.train import ScalingConfig
from ray.train.huggingface import TransformersTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

def train_func(config):
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Training arguments (HuggingFace API)
    training_args = TrainingArguments(
        output_dir="./output",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        learning_rate=2e-5,
    )

    # Ray automatically handles distributed training
    from transformers import Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,  # your tokenized dataset
    )
    trainer.train()

# Scale to multi-node (2 nodes × 8 GPUs = 16 workers)
trainer = TransformersTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=16,
        use_gpu=True,
        resources_per_worker={"GPU": 1},
    ),
)
result = trainer.fit()
```

### Workflow 3: Hyperparameter tuning with Ray Tune

```python
from ray import tune
from ray.train.torch import TorchTrainer
from ray.tune.schedulers import ASHAScheduler

def train_func(config):
    # Use hyperparameters from config
    lr = config["lr"]
    batch_size = config["batch_size"]

    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model = train.torch.prepare_model(model)

    for epoch in range(10):
        # Training loop
        loss = train_epoch(model, optimizer, batch_size)
        train.report({"loss": loss, "epoch": epoch})

# Define search space (nested under "train_loop_config" so the
# sampled values reach train_func's config argument)
param_space = {
    "train_loop_config": {
        "lr": tune.loguniform(1e-5, 1e-2),
        "batch_size": tune.choice([16, 32, 64, 128]),
    }
}

# Run 20 trials with early stopping
tuner = tune.Tuner(
    TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    ),
    param_space=param_space,
    tune_config=tune.TuneConfig(
        num_samples=20,
        scheduler=ASHAScheduler(metric="loss", mode="min"),
    ),
)
results = tuner.fit()

best = results.get_best_result(metric="loss", mode="min")
print(f"Best hyperparameters: {best.config}")
```
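To reuse the winning trial rather than just its config, the best result also carries the trial's last reported checkpoint. A minimal sketch, assuming each trial saved a `model.pt` checkpoint as in Workflow 4 below (otherwise `best.checkpoint` is `None`; the file name is an assumption):

```python
# Hedged sketch: requires trials to have reported checkpoints
if best.checkpoint:
    with best.checkpoint.as_directory() as ckpt_dir:
        state = torch.load(f"{ckpt_dir}/model.pt")  # file name assumed
        model.load_state_dict(state["model"])
```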
**Result**: Distributed hyperparameter search across the cluster.

### Workflow 4: Checkpointing and fault tolerance

```python
import tempfile

from ray import train
from ray.train import Checkpoint

def train_func(config):
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters())
    # Try to resume from a checkpoint
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            state = torch.load(f"{checkpoint_dir}/model.pt")
            model.load_state_dict(state["model"])
            optimizer.load_state_dict(state["optimizer"])
            start_epoch = state["epoch"]
    else:
        start_epoch = 0

    model = train.torch.prepare_model(model)

    for epoch in range(start_epoch, 100):
        loss = train_epoch(model, optimizer)

        # Save a checkpoint every 10 epochs
        if epoch % 10 == 0:
            with tempfile.TemporaryDirectory() as tmp_dir:
                torch.save(
                    {
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch,
                    },
                    f"{tmp_dir}/model.pt",
                )
                checkpoint = Checkpoint.from_directory(tmp_dir)
                train.report({"loss": loss}, checkpoint=checkpoint)
        else:
            train.report({"loss": loss})

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
)

# Automatically resumes from the latest checkpoint if training fails
result = trainer.fit()
```

### Workflow 5: Multi-node training

```python
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Connect to the Ray cluster
ray.init(address="auto")
# Or: ray.init("ray://head-node:10001")

# Train across 4 nodes × 8 GPUs = 32 workers
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=32,
        use_gpu=True,
        resources_per_worker={"GPU": 1, "CPU": 4},
        placement_strategy="SPREAD",  # Spread workers across nodes
    ),
)
result = trainer.fit()
```

**Launch Ray cluster**:

```bash
# On the head node
ray start --head --port=6379

# On each worker node
ray start --address=<head-node-ip>:6379
```
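Once the cluster is up, a quick sanity check from Python confirms that every node registered; a minimal sketch (the totals shown in the comment are illustrative):

```python
import ray

ray.init(address="auto")
# Aggregate resources across all connected nodes,
# e.g. {'CPU': 128.0, 'GPU': 32.0, ...}
print(ray.cluster_resources())
```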
## When to use vs alternatives

**Use Ray Train when**:
- Training across multiple machines (multi-node)
- You need hyperparameter tuning at scale
- You want fault tolerance (auto-restart of failed workers)
- You need elastic scaling (add/remove nodes during training)
- You want a unified framework (same code for PyTorch/TF/HF)

**Key advantages**:
- **Multi-node orchestration**: easiest multi-node setup
- **Ray Tune integration**: best-in-class hyperparameter tuning
- **Fault tolerance**: automatic recovery from failures (see the sketch after this list)
- **Elastic**: add/remove nodes without restarting
- **Framework agnostic**: PyTorch, TensorFlow, HuggingFace, XGBoost
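Fault tolerance is enabled through the run configuration; a minimal sketch, assuming `train_func` reports checkpoints as in Workflow 4 so retries can resume:

```python
from ray.train import RunConfig, FailureConfig

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    # Retry up to 3 worker failures, resuming from the latest checkpoint
    run_config=RunConfig(failure_config=FailureConfig(max_failures=3)),
)
```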
**Use alternatives instead**:
- **Accelerate**: single-node multi-GPU, simpler
- **PyTorch Lightning**: high-level abstractions, callbacks
- **DeepSpeed**: maximum performance, complex setup
- **Raw DDP**: maximum control, minimal overhead

## Common issues

### Issue: Ray cluster not connecting

Check the cluster status:

```bash
ray status
```

It should show all nodes and resources:

```
Nodes: 4
GPUs: 32
Workers: Ready
```
If not connected:

```bash
# Restart the head node
ray stop
ray start --head --port=6379 --dashboard-host=0.0.0.0

# Restart each worker node
ray stop
ray start --address=<head-ip>:6379
```

### Issue: Out of memory

Reduce the number of workers or use gradient accumulation:

```python
scaling_config = ScalingConfig(
    num_workers=4,  # Reduced from 8
    use_gpu=True,
)

# In train_func, accumulate gradients over several batches
for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

### Issue: Slow training

Check whether data loading is the bottleneck:

```python
import time

def train_func(config):
    for epoch in range(epochs):
        start = time.time()
        for batch in dataloader:
            data_time = time.time() - start
            # Train...
            start = time.time()
            print(f"Data loading: {data_time:.3f}s")
```
If data loading is slow, increase DataLoader workers:

```python
dataloader = DataLoader(dataset, num_workers=8)
```
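For larger jobs, another option is to move ingestion to Ray Data, which streams and shards the dataset across training workers. A minimal sketch with a hypothetical S3 path, assuming a Parquet dataset:

```python
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Hypothetical dataset location; replace with your own
ds = ray.data.read_parquet("s3://my-bucket/train/")

def train_func(config):
    # Each worker receives its own shard of the "train" dataset
    shard = train.get_dataset_shard("train")
    for batch in shard.iter_torch_batches(batch_size=32):
        ...  # train on the batch

trainer = TorchTrainer(
    train_func,
    datasets={"train": ds},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
```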
## Advanced topics

- **Multi-node setup**: see `references/multi-node.md` for Ray cluster deployment on AWS, GCP, Kubernetes, and SLURM.
- **Hyperparameter tuning**: see `references/hyperparameter-tuning.md` for Ray Tune integration, search algorithms (Optuna, HyperOpt), and population-based training.
- **Custom training loops**: see `references/custom-loops.md` for advanced Ray Train usage, custom backends, and integration with other frameworks.
## Hardware requirements

- **Single node**: 1+ GPUs (or CPUs)
- **Multi-node**: 2+ machines with network connectivity
- **Cloud**: AWS, GCP, Azure (Ray autoscaling)
- **On-prem**: Kubernetes, SLURM clusters

**Supported accelerators**:
- NVIDIA GPUs (CUDA)
- AMD GPUs (ROCm)
- TPUs (Google Cloud)
- CPUs

## Resources

- Docs: https://docs.ray.io/en/latest/train/train.html
- GitHub: https://github.com/ray-project/ray (⭐ 36,000+)
- Version: 2.40.0+
- Examples: https://docs.ray.io/en/latest/train/examples.html
- Slack: https://forms.gle/9TSdDYUgxYs8SA9e8
- Used by: OpenAI, Uber, Spotify, Shopify, Instacart