(why DeviceMesh exists & how it manages process groups).
2) Build model on meta device (recommended for very large models)

For big models, initialize on `meta`, apply sharding, then materialize weights on GPU:

- `with torch.device("meta"): model = ...`
- apply `fully_shard(...)` on submodules, then `fully_shard(model)`
- `model.to_empty(device="cuda")`
- `model.reset_parameters()` (or your init routine)

Reference: `references/pytorch_fsdp2_tutorial.md` (the migration guide shows this flow explicitly).
3) Apply `fully_shard()` bottom-up (wrapping policy = "apply where needed")

Do not call `fully_shard` only on the topmost module. The recommended sharding pattern for transformer-like models:

- iterate modules: `if isinstance(m, TransformerBlock): fully_shard(m, ...)`
- then `fully_shard(model, ...)`

Why: `fully_shard` forms "parameter groups" for collective efficiency and excludes params already grouped by earlier calls. Bottom-up gives better overlap and lower peak memory.

Reference: `references/pytorch_fully_shard_api.md` (bottom-up requirement and why).
4) Configure `reshard_after_forward` for memory/perf trade-offs

Default behavior: `None` means `True` for non-root modules and `False` for the root module (a good default).

Heuristics:

- If you're memory-bound: keep the defaults or force `True` on many blocks.
- If you're throughput-bound and can afford the memory: consider keeping unsharded params longer (root often `False`).
- Advanced: pass an `int` to reshard to a smaller mesh after forward (e.g., intra-node) if it's a meaningful divisor.

Reference: `references/pytorch_fully_shard_api.md` (full semantics).
5) Mixed precision & offload (optional but common)
The coding agent should implement a script with these labeled blocks:

- `init_distributed()`: init the process group, set the device
- `build_model_meta()`: model on meta, apply `fully_shard`, materialize weights
- `build_optimizer()`: optimizer created after sharding
- `train_step()`: forward/backward/step with `model(inputs)` and DTensor-aware patterns
- `checkpoint_save/load()`: DCP or distributed state-dict helpers

Concrete examples live in `references/pytorch_examples_fsdp2.md` and the official tutorial reference.
References
- `references/pytorch_fsdp2_tutorial.md`
- `references/pytorch_fully_shard_api.md`
- `references/pytorch_ddp_notes.md`
- `references/pytorch_fsdp1_api.md`
- `references/pytorch_device_mesh_tutorial.md`
- `references/pytorch_tp_tutorial.md`
- `references/pytorch_dcp_overview.md`
- `references/pytorch_dcp_recipe.md`
- `references/pytorch_dcp_async_recipe.md`
- `references/pytorch_examples_fsdp2.md`
- `references/torchtitan_fsdp_notes.md` (optional, production notes)
- `references/ray_train_fsdp2_example.md` (optional, integration example)