speculative-decoding

安装量: 175
排名: #4940

安装

npx skills add https://github.com/davila7/claude-code-templates --skill speculative-decoding

Speculative Decoding: Accelerating LLM Inference When to Use This Skill

Use Speculative Decoding when you need to:

Speed up inference by 1.5-3.6× without quality loss Reduce latency for real-time applications (chatbots, code generation) Optimize throughput for high-volume serving Deploy efficiently on limited hardware Generate faster without changing model architecture

Key Techniques: Draft model speculative decoding, Medusa (multiple heads), Lookahead Decoding (Jacobi iteration)

Papers: Medusa (arXiv 2401.10774), Lookahead Decoding (ICML 2024), Speculative Decoding Survey (ACL 2024)

Installation

Standard speculative decoding (transformers)

pip install transformers accelerate

Medusa (multiple decoding heads)

git clone https://github.com/FasterDecoding/Medusa cd Medusa pip install -e .

Lookahead Decoding

git clone https://github.com/hao-ai-lab/LookaheadDecoding cd LookaheadDecoding pip install -e .

Optional: vLLM with speculative decoding

pip install vllm

Quick Start Basic Speculative Decoding (Draft Model) from transformers import AutoModelForCausalLM, AutoTokenizer

Load target model (large, slow)

target_model = AutoModelForCausalLM.from_pretrained( "meta-llama/Llama-2-70b-hf", device_map="auto", torch_dtype=torch.float16 )

Load draft model (small, fast)

draft_model = AutoModelForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16 )

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")

Generate with speculative decoding

prompt = "Explain quantum computing in simple terms:" inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

Transformers 4.36+ supports assisted generation

outputs = target_model.generate( **inputs, assistant_model=draft_model, # Enable speculative decoding max_new_tokens=256, do_sample=True, temperature=0.7, )

response = tokenizer.decode(outputs[0], skip_special_tokens=True) print(response)

Medusa (Multiple Decoding Heads) from medusa.model.medusa_model import MedusaModel

Load Medusa-enhanced model

model = MedusaModel.from_pretrained( "FasterDecoding/medusa-vicuna-7b-v1.3", # Pre-trained with Medusa heads torch_dtype=torch.float16, device_map="auto" )

tokenizer = AutoTokenizer.from_pretrained("FasterDecoding/medusa-vicuna-7b-v1.3")

Generate with Medusa (2-3× speedup)

prompt = "Write a Python function to calculate fibonacci numbers:" inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.medusa_generate( **inputs, max_new_tokens=256, temperature=0.7, posterior_threshold=0.09, # Acceptance threshold posterior_alpha=0.3, # Tree construction parameter )

response = tokenizer.decode(outputs[0], skip_special_tokens=True)

Lookahead Decoding (Jacobi Iteration) from lookahead.lookahead_decoding import LookaheadDecoding

Load model

model = AutoModelForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto" ) tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

Initialize lookahead decoding

lookahead = LookaheadDecoding( model=model, tokenizer=tokenizer, window_size=15, # Lookahead window (W) ngram_size=5, # N-gram size (N) guess_size=5 # Number of parallel guesses )

Generate (1.5-2.3× speedup)

prompt = "Implement quicksort in Python:" output = lookahead.generate(prompt, max_new_tokens=256) print(output)

Core Concepts 1. Speculative Decoding (Draft Model)

Idea: Use small draft model to generate candidates, large target model to verify in parallel.

Algorithm:

Draft model generates K tokens speculatively Target model evaluates all K tokens in parallel (single forward pass) Accept tokens where draft and target agree Reject first disagreement, continue from there def speculative_decode(target_model, draft_model, prompt, K=4): """Speculative decoding algorithm.""" # 1. Generate K draft tokens draft_tokens = draft_model.generate(prompt, max_new_tokens=K)

# 2. Target model evaluates all K tokens in one forward pass
target_logits = target_model(draft_tokens)  # Parallel!

# 3. Accept/reject based on probability match
accepted = []
for i in range(K):
    p_draft = softmax(draft_model.logits[i])
    p_target = softmax(target_logits[i])

    # Acceptance probability
    if random.random() < min(1, p_target[draft_tokens[i]] / p_draft[draft_tokens[i]]):
        accepted.append(draft_tokens[i])
    else:
        break  # Reject, resample from target

return accepted

Performance:

Speedup: 1.5-2× with good draft model Zero quality loss (mathematically equivalent to target model) Best when draft model is 5-10× smaller than target 2. Medusa (Multiple Decoding Heads)

Source: arXiv 2401.10774 (2024)

Innovation: Add multiple prediction heads to existing model, predict future tokens without separate draft model.

Architecture:

Input → Base LLM (frozen) → Hidden State ├→ Head 1 (predicts token t+1) ├→ Head 2 (predicts token t+2) ├→ Head 3 (predicts token t+3) └→ Head 4 (predicts token t+4)

Training:

Medusa-1: Freeze base LLM, train only heads 2.2× speedup, lossless Medusa-2: Fine-tune base LLM + heads together 2.3-3.6× speedup, better quality

Tree-based Attention:

Medusa constructs tree of candidates

Example: Predict 2 steps ahead with top-2 per step

Root

# / \

T1a T1b (Step 1: 2 candidates)

# / \ / \

T2a T2b T2c T2d (Step 2: 4 candidates total)

Single forward pass evaluates entire tree!

Advantages:

No separate draft model needed Minimal training (only heads) Compatible with any LLM 3. Lookahead Decoding (Jacobi Iteration)

Source: ICML 2024

Core idea: Reformulate autoregressive decoding as solving system of equations, solve in parallel using Jacobi iteration.

Mathematical formulation:

Traditional: y_t = f(x, y_1, ..., y_{t-1}) (sequential) Jacobi: y_t^{(k+1)} = f(x, y_1^{(k)}, ..., y_{t-1}^{(k)}) (parallel)

Two branches:

Lookahead Branch: Generate n-grams in parallel

Window size W: How many steps to look ahead N-gram size N: How many past tokens to use

Verification Branch: Verify promising n-grams

Match n-grams with generated tokens Accept if first token matches class LookaheadDecoding: def init(self, model, window_size=15, ngram_size=5): self.model = model self.W = window_size # Lookahead window self.N = ngram_size # N-gram size

def generate_step(self, tokens):
    # Lookahead branch: Generate W × N candidates
    candidates = {}
    for w in range(1, self.W + 1):
        for n in range(1, self.N + 1):
            # Generate n-gram starting at position w
            ngram = self.generate_ngram(tokens, start=w, length=n)
            candidates[(w, n)] = ngram

    # Verification branch: Find matching n-grams
    verified = []
    for ngram in candidates.values():
        if ngram[0] == tokens[-1]:  # First token matches last input
            if self.verify(tokens, ngram):
                verified.append(ngram)

    # Accept longest verified n-gram
    return max(verified, key=len) if verified else [self.model.generate_next(tokens)]

Performance:

Speedup: 1.5-2.3× (up to 3.6× for code generation) No draft model or training needed Works out-of-the-box with any model Method Comparison Method Speedup Training Needed Draft Model Quality Loss Draft Model Speculative 1.5-2× No Yes (external) None Medusa 2-3.6× Minimal (heads only) No (built-in heads) None Lookahead 1.5-2.3× None No None Naive Batching 1.2-1.5× No No None Advanced Patterns Training Medusa Heads from medusa.model.medusa_model import MedusaModel from medusa.model.kv_cache import initialize_past_key_values import torch.nn as nn

1. Load base model

base_model = AutoModelForCausalLM.from_pretrained( "lmsys/vicuna-7b-v1.3", torch_dtype=torch.float16 )

2. Add Medusa heads

num_heads = 4 medusa_heads = nn.ModuleList([ nn.Linear(base_model.config.hidden_size, base_model.config.vocab_size, bias=False) for _ in range(num_heads) ])

3. Training loop (freeze base model for Medusa-1)

for param in base_model.parameters(): param.requires_grad = False # Freeze base

optimizer = torch.optim.Adam(medusa_heads.parameters(), lr=1e-3)

for batch in dataloader: # Forward pass hidden_states = base_model(**batch, output_hidden_states=True).hidden_states[-1]

# Predict future tokens with each head
loss = 0
for i, head in enumerate(medusa_heads):
    logits = head(hidden_states)
    # Target: tokens shifted by (i+1) positions
    target = batch['input_ids'][:, i+1:]
    loss += F.cross_entropy(logits[:, :-i-1], target)

# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()

Hybrid: Speculative + Medusa

Use Medusa as draft model for speculative decoding

draft_medusa = MedusaModel.from_pretrained("medusa-vicuna-7b") target_model = AutoModelForCausalLM.from_pretrained("vicuna-33b")

Draft generates multiple candidates with Medusa

draft_tokens = draft_medusa.medusa_generate(prompt, max_new_tokens=5)

Target verifies in single forward pass

outputs = target_model.generate( prompt, assistant_model=draft_medusa, # Use Medusa as draft max_new_tokens=256 )

Combines benefits: Medusa speed + large model quality

Optimal Draft Model Selection def select_draft_model(target_model_size, target): """Select optimal draft model for speculative decoding.""" # Rule: Draft should be 5-10× smaller if target_model_size == "70B": return "7B" # 10× smaller elif target_model_size == "33B": return "7B" # 5× smaller elif target_model_size == "13B": return "1B" # 13× smaller else: return None # Target too small, use Medusa/Lookahead instead

Example

draft = select_draft_model("70B", target_model)

Returns "7B" → Use Llama-2-7b as draft for Llama-2-70b

Best Practices 1. Choose the Right Method

New deployment → Medusa (best overall speedup, no draft model)

if deploying_new_model: use_method = "Medusa"

Existing deployment with small model available → Draft speculative

elif have_small_version_of_model: use_method = "Draft Model Speculative"

Want zero training/setup → Lookahead

elif want_plug_and_play: use_method = "Lookahead Decoding"

  1. Hyperparameter Tuning

Draft Model Speculative:

K = number of speculative tokens

K = 4 # Good default K = 2 # Conservative (higher acceptance) K = 8 # Aggressive (lower acceptance, but more when accepted)

Rule: Larger K → more speedup IF draft model is good

Medusa:

Posterior threshold (acceptance confidence)

posterior_threshold = 0.09 # Standard (from paper) posterior_threshold = 0.05 # More conservative (slower, higher quality) posterior_threshold = 0.15 # More aggressive (faster, may degrade quality)

Tree depth (how many steps ahead)

medusa_choices = [[0], [0, 0], [0, 1], [0, 0, 0]] # Depth 3 (standard)

Lookahead:

Window size W (lookahead distance)

N-gram size N (context for generation)

7B model (more resources)

W, N = 15, 5

13B model (moderate)

W, N = 10, 5

33B+ model (limited resources)

W, N = 7, 5

  1. Production Deployment

vLLM with speculative decoding

from vllm import LLM, SamplingParams

Initialize with draft model

llm = LLM( model="meta-llama/Llama-2-70b-hf", speculative_model="meta-llama/Llama-2-7b-hf", # Draft model num_speculative_tokens=5, use_v2_block_manager=True, )

Generate

prompts = ["Tell me about AI:", "Explain quantum physics:"] sampling_params = SamplingParams(temperature=0.7, max_tokens=256)

outputs = llm.generate(prompts, sampling_params) for output in outputs: print(output.outputs[0].text)

Resources Medusa Paper: https://arxiv.org/abs/2401.10774 Medusa GitHub: https://github.com/FasterDecoding/Medusa Lookahead Decoding (ICML 2024): https://lmsys.org/blog/2023-11-21-lookahead-decoding/ Lookahead GitHub: https://github.com/hao-ai-lab/LookaheadDecoding Speculative Decoding Survey (ACL 2024): https://aclanthology.org/2024.findings-acl.456.pdf Comprehensive Survey: https://arxiv.org/abs/2401.07851 See Also references/draft_model.md - Draft model selection and training references/medusa.md - Medusa architecture and training references/lookahead.md - Lookahead decoding implementation details

返回排行榜