TransformerLens: Mechanistic Interpretability for Transformers
TransformerLens is the de facto standard library for mechanistic interpretability research on GPT-style language models. Created by Neel Nanda and maintained by Bryce Meyer, it provides clean interfaces to inspect and manipulate model internals via HookPoints on every activation.
GitHub: TransformerLensOrg/TransformerLens (2,900+ stars)
When to Use TransformerLens
Use TransformerLens when you need to:
- Reverse-engineer algorithms learned during training
- Perform activation patching / causal tracing experiments
- Study attention patterns and information flow
- Analyze circuits (e.g., induction heads, IOI circuit)
- Cache and inspect intermediate activations
- Apply direct logit attribution
Consider alternatives when:
- You need to work with non-transformer architectures → Use nnsight or pyvene
- You want to train/analyze Sparse Autoencoders → Use SAELens
- You need remote execution on massive models → Use nnsight with NDIF
- You want higher-level causal intervention abstractions → Use pyvene

Installation

```bash
pip install transformer-lens
```
For the development version:

```bash
pip install git+https://github.com/TransformerLensOrg/TransformerLens
```
Core Concepts

HookedTransformer
The main class that wraps transformer models with HookPoints on every activation:
```python
import os

from transformer_lens import HookedTransformer

# Load a model
model = HookedTransformer.from_pretrained("gpt2-small")

# For gated models (LLaMA, Mistral), provide a Hugging Face access token first
os.environ["HF_TOKEN"] = "your_token"
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
```
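Conceptually, a HookPoint is an identity layer placed at an activation. On its own it changes nothing, but any function registered on it can read the value flowing through (caching) or return a replacement (patching, ablation). The pure-Python toy below illustrates only that idea. `ToyHookPoint`, `toy_forward`, `cache_hook` and `zero_ablate_hook` are hypothetical names and are not TransformerLens's implementation.

```python
from typing import Callable, Dict, List, Optional


class ToyHookPoint:
    """Hypothetical identity layer that runs registered hooks on an activation.

    Hooks are called as hook(activation, hook_point), mirroring the
    (activation, hook) signature of TransformerLens hook functions.
    A hook that returns None leaves the activation unchanged.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.hooks: List[Callable] = []

    def __call__(self, activation: List[float]) -> List[float]:
        for hook in self.hooks:
            result: Optional[List[float]] = hook(activation, self)
            if result is not None:
                activation = result
        return activation


def toy_forward(x: List[float], points: Dict[str, ToyHookPoint]) -> float:
    """Hypothetical two-layer model; each intermediate passes through a hook point."""
    h1 = points["hook_h1"]([2 * v for v in x])
    h2 = points["hook_h2"]([v + 1 for v in h1])
    return sum(h2)


points = {name: ToyHookPoint(name) for name in ("hook_h1", "hook_h2")}
cache: Dict[str, List[float]] = {}


def cache_hook(activation, hook):
    cache[hook.name] = list(activation)  # read-only: returns None


def zero_ablate_hook(activation, hook):
    return [0.0 for _ in activation]  # returns a replacement activation


points["hook_h1"].hooks.append(cache_hook)
print(toy_forward([1.0, 2.0], points), cache)  # 8.0 {'hook_h1': [2.0, 4.0]}

points["hook_h1"].hooks.append(zero_ablate_hook)
print(toy_forward([1.0, 2.0], points))  # 2.0
```

Caching and patching in TransformerLens rest on this same mechanism. `run_with_cache` registers hooks that store activations, and `run_with_hooks` registers hooks that may return a modified activation.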
Supported Models (50+)

| Family | Models |
|---|---|
| GPT-2 | gpt2, gpt2-medium, gpt2-large, gpt2-xl |
| LLaMA | llama-7b, llama-13b, llama-2-7b, llama-2-13b |
| EleutherAI | pythia-70m to pythia-12b, gpt-neo, gpt-j-6b |
| Mistral | mistral-7b, mixtral-8x7b |
| Others | phi, qwen, opt, gemma |

Activation Caching
Run the model and cache all intermediate activations:
```python
# Get all activations
tokens = model.to_tokens("The Eiffel Tower is in")
logits, cache = model.run_with_cache(tokens)

# Access specific activations
residual = cache["resid_post", 5]  # Layer 5 residual stream
attn_pattern = cache["pattern", 3]  # Layer 3 attention pattern
mlp_out = cache["mlp_out", 7]  # Layer 7 MLP output

# Filter which activations to cache (saves memory)
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda name: "resid_post" in name,
)
```
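Tuple keys such as `cache["resid_post", 5]` are shorthand for full hook names like `blocks.5.hook_resid_post`. These full names are the strings that `names_filter` receives and that `fwd_hooks` expects. The sketch below mimics that expansion for the keys in the ActivationCache table. `act_name` is a hypothetical stand-in covering only those keys; the library's own helper is `transformer_lens.utils.get_act_name`.

```python
# Hypothetical helper: expands (name, layer) shorthand into a full hook name.
# Covers only the keys listed in the ActivationCache Keys table.
ATTN_ACTS = {"q", "k", "v", "z", "pattern"}
BLOCK_ACTS = {"resid_pre", "resid_mid", "resid_post", "attn_out", "mlp_out"}


def act_name(name: str, layer: int) -> str:
    if name in ATTN_ACTS:
        return f"blocks.{layer}.attn.hook_{name}"  # inside the attention layer
    if name in BLOCK_ACTS:
        return f"blocks.{layer}.hook_{name}"  # block-level activations
    raise KeyError(f"toy helper does not cover {name!r}")


print(act_name("resid_post", 5))  # blocks.5.hook_resid_post
print(act_name("pattern", 3))     # blocks.3.attn.hook_pattern
```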
ActivationCache Keys

| Key pattern | Shape | Description |
|---|---|---|
| resid_pre, layer | [batch, pos, d_model] | Residual before attention |
| resid_mid, layer | [batch, pos, d_model] | Residual after attention |
| resid_post, layer | [batch, pos, d_model] | Residual after MLP |
| attn_out, layer | [batch, pos, d_model] | Attention output |
| mlp_out, layer | [batch, pos, d_model] | MLP output |
| pattern, layer | [batch, head, q_pos, k_pos] | Attention pattern (post-softmax) |
| q, layer | [batch, pos, head, d_head] | Query vectors |
| k, layer | [batch, pos, head, d_head] | Key vectors |
| v, layer | [batch, pos, head, d_head] | Value vectors |
| z, layer | [batch, pos, head, d_head] | Per-head attention output before W_O |

Workflow 1: Activation Patching (Causal Tracing)
Identify which activations causally affect model output by patching clean activations into corrupted runs.
Step-by-Step

```python
import torch
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)  # inference only: no autograd graph, less memory

model = HookedTransformer.from_pretrained("gpt2-small")

# 1. Define clean and corrupted prompts
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)
# Position-wise patching requires both prompts to have the same number of tokens
assert clean_tokens.shape == corrupted_tokens.shape, "prompts must tokenize to equal lengths"

# 2. Get clean activations
_, clean_cache = model.run_with_cache(clean_tokens)

# 3. Define metric (e.g., logit difference)
paris_token = model.to_single_token(" Paris")
rome_token = model.to_single_token(" Rome")

def metric(logits):
    return logits[0, -1, paris_token] - logits[0, -1, rome_token]

# 4. Patch each position and layer
results = torch.zeros(model.cfg.n_layers, clean_tokens.shape[1])

for layer in range(model.cfg.n_layers):
    for pos in range(clean_tokens.shape[1]):
        def patch_hook(activation, hook):
            # Overwrite this position's residual stream with its clean value
            activation[0, pos] = clean_cache[hook.name][0, pos]
            return activation

        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)],
        )
        # .item() gives a plain float, avoiding device mismatches with `results`
        results[layer, pos] = metric(patched_logits).item()

# 5. Visualize results (layer x position heatmap)
```
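To see what the patching loop measures without downloading a model, here is a pure-Python toy with the same structure. `toy_run` is a hypothetical two-layer model in which layer 1 moves information one position to the right. The clean and corrupted inputs differ only at position 1. Results use one common normalization, (patched - corrupted) / (clean - corrupted). Under that normalization, 1.0 means patching that single activation fully restores the clean metric, and 0.0 means it has no effect.

```python
from typing import Dict, List, Optional, Tuple

Patch = Optional[Tuple[int, int, float]]  # (layer, pos, value) to overwrite


def toy_run(x: List[float], patch: Patch = None) -> Tuple[float, Dict[int, List[float]]]:
    """Hypothetical 2-layer toy model; returns (metric, per-layer residual cache)."""
    cache: Dict[int, List[float]] = {}
    resid = list(x)
    for layer in range(2):
        if layer == 0:
            resid = [2.0 * v for v in resid]  # per-position transform
        else:  # add each position's left neighbour: info moves one position right
            resid = [r + (resid[p - 1] if p > 0 else 0.0) for p, r in enumerate(resid)]
        if patch is not None and patch[0] == layer:
            resid[patch[1]] = patch[2]  # the "hook": overwrite one activation
        cache[layer] = list(resid)
    return resid[-1], cache  # metric read out at the final position


clean_x, corrupted_x = [1.0, 5.0, 0.0], [1.0, 3.0, 0.0]  # differ only at position 1
clean_metric, clean_cache = toy_run(clean_x)
corrupted_metric, _ = toy_run(corrupted_x)

results = []
for layer in range(2):
    row = []
    for pos in range(len(clean_x)):
        patched_metric, _ = toy_run(corrupted_x, (layer, pos, clean_cache[layer][pos]))
        row.append((patched_metric - corrupted_metric) / (clean_metric - corrupted_metric))
    results.append(row)

print(results)  # [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
```

The resulting grid shows the relevant information at position 1 after layer 0 and at the final position after layer 1. A real layer x position heatmap is meant to reveal this kind of information flow.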
Checklist
- Define clean and corrupted inputs that differ minimally (and tokenize to the same length)
- Choose a metric that captures the behavior difference
- Cache clean activations
- Systematically patch each (layer, position) combination
- Visualize results as a heatmap
- Identify causal hotspots

Workflow 2: Circuit Analysis (Indirect Object Identification)
Replicate the IOI circuit discovery from "Interpretability in the Wild".
Step-by-Step

```python
import torch
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)
model = HookedTransformer.from_pretrained("gpt2-small")

# IOI task: "When John and Mary went to the store, Mary gave a bottle to"
# Model should predict " John" (indirect object, IO) over " Mary" (subject, S)
prompt = "When John and Mary went to the store, Mary gave a bottle to"
tokens = model.to_tokens(prompt)

# 1. Get baseline logits
logits, cache = model.run_with_cache(tokens)

john_token = model.to_single_token(" John")
mary_token = model.to_single_token(" Mary")

# 2. Compute logit difference (IO - S)
logit_diff = logits[0, -1, john_token] - logits[0, -1, mary_token]
print(f"Logit difference: {logit_diff.item():.3f}")

# 3. Direct logit attribution by head
# Logit-difference direction in the residual stream: W_U[:, IO] - W_U[:, S]
logit_diff_dir = model.W_U[:, john_token] - model.W_U[:, mary_token]  # [d_model]
# Final LayerNorm scale at the last position, frozen from this forward pass
final_scale = cache["ln_final.hook_scale"][0, -1]  # [1]

def get_head_contribution(layer, head):
    head_out = cache["z", layer][0, -1, head, :]  # [d_head], final position
    W_O = model.W_O[layer, head]  # [d_head, d_model]
    # Project into the residual stream, apply the final LN scale,
    # then read off the contribution to the logit difference
    return ((head_out @ W_O) / final_scale) @ logit_diff_dir

# 4. Map all heads
head_contributions = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        head_contributions[layer, head] = get_head_contribution(layer, head).item()

# 5. Identify top contributing heads (name movers, backup name movers)
top = torch.topk(head_contributions.flatten(), k=5)
for score, idx in zip(top.values, top.indices):
    layer, head = divmod(idx.item(), model.cfg.n_heads)
    print(f"L{layer}H{head}: {score.item():.3f}")
```
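Direct logit attribution relies on two facts. The final residual stream is a sum of component outputs, and once the final LayerNorm scale is frozen, the map from residual stream to logit difference is linear. The pure-Python toy below checks that per-component contributions add up to the total logit difference. All vectors and the scale are made-up values, the component names are hypothetical, and the sketch ignores LayerNorm's centering step.

```python
from typing import Dict, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


# Hypothetical components writing to the final-position residual stream
components: Dict[str, Vector] = {
    "embed": [0.5, -1.0, 0.0],
    "head_a": [2.0, 0.5, 1.0],
    "head_b": [-1.0, 0.0, 0.5],
    "mlp": [0.5, 0.5, -0.5],
}
# Hypothetical unembedding columns for the IO and S tokens
w_u_io, w_u_s = [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]
logit_diff_dir = [a - b for a, b in zip(w_u_io, w_u_s)]

resid = [sum(vals) for vals in zip(*components.values())]  # residual = sum of parts
scale = 2.0  # hypothetical final-LayerNorm scale, frozen from the forward pass

total = dot(resid, logit_diff_dir) / scale
per_component = {name: dot(vec, logit_diff_dir) / scale for name, vec in components.items()}

print(per_component)
print(total, sum(per_component.values()))  # 1.5 1.5
```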
Checklist
- Set up the task with clear IO/S tokens
- Compute the baseline logit difference
- Decompose by attention head contributions
- Identify key circuit components (name movers, S-inhibition, induction)
- Validate with ablation experiments

Workflow 3: Induction Head Detection
Find induction heads, which implement the [A][B]...[A] → [B] pattern.
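```python
import torch
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)
model = HookedTransformer.from_pretrained("gpt2-small")

# Create a repeated sequence: [BOS] t_1 ... t_L t_1 ... t_L (random tokens).
# In the second copy, each [A] should predict the [B] that followed it before.
seq_len = 50
rand_tokens = torch.randint(1000, 10000, (1, seq_len))
bos = torch.tensor([[model.tokenizer.bos_token_id]])
repeated_tokens = torch.cat([bos, rand_tokens, rand_tokens], dim=1).to(model.cfg.device)

_, cache = model.run_with_cache(repeated_tokens)

# Induction heads attend from each second-copy token back to the token just
# after its first occurrence: key position = query position - (seq_len - 1).
# With seq_len > 2 this differs from a previous-token head's offset of 1,
# which a 3-token [A][B][A] prompt cannot distinguish.
induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
    pattern = cache["pattern", layer][0]  # [head, q_pos, k_pos]
    stripe = pattern.diagonal(offset=1 - seq_len, dim1=-2, dim2=-1)
    # Drop the first two entries, whose queries lie in the first copy
    induction_scores[layer] = stripe[:, 2:].mean(dim=-1).cpu()

# Heads with high scores are induction heads
top_heads = torch.topk(induction_scores.flatten(), k=5)
```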
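An induction head's signature is attention from a repeated token to the token that followed its previous occurrence. A previous-token head instead attends one position back. The pure-Python check below scores synthetic one-hot attention patterns on a [BOS] + L random tokens + the same L tokens layout. `induction_score` and `one_hot_pattern` are hypothetical helpers written for this illustration.

```python
from typing import Callable, List

Matrix = List[List[float]]


def induction_score(pattern: Matrix, seq_len: int) -> float:
    """Mean attention from each second-copy query to the token after its first copy.

    Assumes positions [BOS] t_1..t_L t_1..t_L (L = seq_len), so a query at
    position q > L should attend to key position q - L + 1.
    """
    queries = range(seq_len + 1, 2 * seq_len + 1)
    return sum(pattern[q][q - seq_len + 1] for q in queries) / seq_len


def one_hot_pattern(n: int, target: Callable[[int], int]) -> Matrix:
    """Hypothetical attention pattern where query q puts all its weight on target(q)."""
    return [[1.0 if k == target(q) else 0.0 for k in range(n)] for q in range(n)]


L = 4
n = 2 * L + 1
induction = one_hot_pattern(n, lambda q: q - L + 1 if q > L else 0)
previous_token = one_hot_pattern(n, lambda q: max(q - 1, 0))

print(induction_score(induction, L))       # 1.0
print(induction_score(previous_token, L))  # 0.0
```

With L = 2, the induction offset q - L + 1 equals the previous-token offset q - 1. The score then could not separate the two head types.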
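Common Issues & Solutions

Issue: Hooks persist after debugging

`run_with_hooks` removes its hooks when it returns (`reset_hooks_end=True` by default). Hooks attached with `model.add_hook`, however, stay active for every later forward pass until you remove them.

```python
# WRONG: a hook added with add_hook keeps running on later calls
model.add_hook("blocks.0.hook_resid_post", hook_fn)  # hook_fn: any hook function
model(tokens)
model(tokens)  # hook_fn still runs here!

# RIGHT: remove leftover hooks, and prefer run_with_hooks for temporary hooks
model.reset_hooks()
model.run_with_hooks(tokens, fwd_hooks=[("blocks.0.hook_resid_post", hook_fn)])
```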
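Issue: Tokenization gotchas

```python
# WRONG: assuming a word is always one token
model.to_tokens("Tim")   # BOS + "Tim" (to_tokens prepends a BOS token by default)
model.to_tokens("Neel")  # BOS + "Ne" + "el" (the name is split into two tokens!)

# RIGHT: check tokenization explicitly
tokens = model.to_tokens("Neel", prepend_bos=False)
print(model.to_str_tokens(tokens))  # ['Ne', 'el']
```

Leading spaces matter too: " Paris" and "Paris" are different tokens. This is why the metrics above use `model.to_single_token(" Paris")`.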
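Issue: LayerNorm ignored in analysis

```python
# residual: input to this layer's MLP
residual = cache["resid_mid", layer]

# WRONG: ignoring LayerNorm (and the bias)
pre_activation = residual @ model.W_in[layer]

# RIGHT: apply ln2 first, then the MLP input weights and bias.
# With from_pretrained's default fold_ln=True, ln2's weight and bias are folded
# into W_in / b_in, so ln2 itself only centres and rescales.
ln_out = model.blocks[layer].ln2(residual)
pre_activation = ln_out @ model.W_in[layer] + model.b_in[layer]  # should match cache["pre", layer]
```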
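With TransformerLens's default `fold_ln=True`, LayerNorm's elementwise weight and bias are absorbed into the following linear layer. The toy check below uses made-up numbers, pure Python and hypothetical helper names. It confirms that applying LayerNorm's affine step followed by `W_in` and `b_in` gives the same result as normalizing and then using folded weights. It also shows that skipping LayerNorm entirely gives a different answer.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def normalize(x: Vector, eps: float = 1e-5) -> Vector:
    """LayerNorm without weight/bias: centre, then divide by the RMS of the centred vector."""
    mean = sum(x) / len(x)
    centred = [v - mean for v in x]
    scale = math.sqrt(sum(v * v for v in centred) / len(x) + eps)
    return [v / scale for v in centred]


def matvec(x: Vector, W: Matrix, b: Vector) -> Vector:
    """x @ W + b for W of shape [len(x), len(b)]."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]


# Hypothetical LayerNorm weight/bias and MLP input weights (d_model=3, d_mlp=2)
w, b = [2.0, 0.5, 1.0], [0.1, 0.0, -0.2]
W_in, b_in = [[1.0, 0.0], [0.5, -1.0], [0.0, 2.0]], [0.0, 1.0]
x = [3.0, -1.0, 4.0]

# Unfolded: LayerNorm's affine step, then the MLP input projection
ln_out = [w_i * n_i + b_i for w_i, n_i, b_i in zip(w, normalize(x), b)]
unfolded = matvec(ln_out, W_in, b_in)

# Folded: scale row i of W_in by w[i]; absorb b into the bias (b @ W_in + b_in)
W_fold = [[w[i] * W_in[i][j] for j in range(len(b_in))] for i in range(len(w))]
b_fold = matvec(b, W_in, b_in)
folded = matvec(normalize(x), W_fold, b_fold)

naive = matvec(x, W_in, b_in)  # skipping LayerNorm entirely
print(unfolded, folded, naive)  # unfolded == folded up to rounding; naive differs
```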
Issue: Memory explosion with large models

```python
# Use selective caching
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda n: "resid_post" in n or "pattern" in n,
    device="cpu",  # Cache on CPU
)
```
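A back-of-the-envelope estimate shows why filtering matters. Residual-stream caches grow linearly with sequence length, while attention patterns grow quadratically. The sketch assumes float32 activations (4 bytes) and GPT-2 small's dimensions: 12 layers, 12 heads, d_model = 768. `cache_bytes` is a hypothetical helper.

```python
def cache_bytes(n_layers: int, *shape: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to cache one activation type across all layers (float32 by default)."""
    total = n_layers * bytes_per_elem
    for dim in shape:
        total *= dim
    return total


# GPT-2 small, one prompt of 1024 tokens
resid = cache_bytes(12, 1, 1024, 768)         # resid_post: [batch, pos, d_model]
pattern = cache_bytes(12, 1, 12, 1024, 1024)  # pattern: [batch, head, q_pos, k_pos]
print(resid / 2**20, pattern / 2**20)  # 36.0 576.0  (MiB)
```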
Key Classes Reference

| Class | Purpose |
|---|---|
| HookedTransformer | Main model wrapper with hooks |
| ActivationCache | Dictionary-like cache of activations |
| HookedTransformerConfig | Model configuration |
| FactoredMatrix | Efficient factored matrix operations |

Integration with SAELens
TransformerLens integrates with SAELens for Sparse Autoencoder analysis:
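```python
from sae_lens import SAE
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small")
# Recent SAELens versions return the SAE directly; older releases return a
# (sae, cfg_dict, sparsity) tuple from from_pretrained.
sae = SAE.from_pretrained(
    "gpt2-small-res-jb", "blocks.8.hook_resid_pre", device=model.cfg.device
)

# Run with SAE
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
sae_acts = sae.encode(cache["resid_pre", 8])
```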
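SAELens supports several SAE architectures, and `encode` differs between them. This pure-Python sketch assumes the standard ReLU SAE: features f = ReLU((x - b_dec) @ W_enc + b_enc) and reconstruction f @ W_dec + b_dec. The weights are made-up, tiny values, and `sae_encode`, `sae_decode` and `matvec` are hypothetical helpers, not SAELens code.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(x: Vector, W: Matrix, b: Vector) -> Vector:
    """x @ W + b for W of shape [len(x), len(b)]."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]


def sae_encode(x: Vector, W_enc: Matrix, b_enc: Vector, b_dec: Vector) -> Vector:
    """Standard ReLU SAE encoder: ReLU((x - b_dec) @ W_enc + b_enc)."""
    centred = [xi - bi for xi, bi in zip(x, b_dec)]
    return [max(0.0, h) for h in matvec(centred, W_enc, b_enc)]


def sae_decode(f: Vector, W_dec: Matrix, b_dec: Vector) -> Vector:
    """Reconstruction: f @ W_dec + b_dec."""
    return matvec(f, W_dec, b_dec)


# Hypothetical weights: d_model = 2, d_sae = 3 (an overcomplete dictionary)
W_enc = [[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]]
b_enc = [0.0, 0.0, -0.5]
W_dec = [[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]]
b_dec = [0.5, -0.5]

x = [1.5, 0.5]  # stand-in for one position's resid_pre vector
f = sae_encode(x, W_enc, b_enc, b_dec)
print(f, sae_decode(f, W_dec, b_dec))  # [1.0, 1.0, 0.0] [1.5, 0.5]
```

The third feature's pre-activation is negative, so ReLU zeroes it. Most features being exactly zero on a given input is the sparsity that makes SAE features inspectable.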
Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the references/ folder:
| File | Contents |
|---|---|
| references/README.md | Overview and quick start guide |
| references/api.md | Complete API reference for HookedTransformer, ActivationCache, HookPoints |
| references/tutorials.md | Step-by-step tutorials for activation patching, circuit analysis, logit lens |

External Resources

Tutorials
- Main Demo Notebook
- Activation Patching Demo
- ARENA Mech Interp Course (200+ hours of tutorials)

Papers
- A Mathematical Framework for Transformer Circuits
- In-context Learning and Induction Heads
- Interpretability in the Wild (IOI)

Official Documentation
- Official Docs
- Model Properties Table
- Neel Nanda's Glossary

Version Notes
- v2.0: Removed HookedSAE (moved to SAELens)
- v3.0 (alpha): TransformerBridge for loading any nn.Module