long-context

Install

npx skills add https://github.com/davila7/claude-code-templates --skill long-context

Long Context: Extending Transformer Context Windows

When to Use This Skill

Use Long Context techniques when you need to:

- Process long documents (32k, 64k, 128k+ tokens) with transformer models
- Extend context windows of pre-trained models (LLaMA, Mistral, etc.)
- Implement efficient positional encodings (RoPE, ALiBi)
- Train models with length extrapolation capabilities
- Deploy models that handle variable-length inputs efficiently
- Fine-tune existing models for longer contexts with minimal compute

Key Techniques: RoPE (Rotary Position Embeddings), YaRN, ALiBi (Attention with Linear Biases), Position Interpolation

Papers: RoFormer (arXiv 2104.09864), YaRN (arXiv 2309.00071), ALiBi (arXiv 2108.12409), Position Interpolation (arXiv 2306.15595)

Installation

HuggingFace Transformers (includes RoPE, YaRN support)

pip install transformers torch

For custom implementations

pip install einops  # Tensor operations
pip install rotary-embedding-torch  # Standalone RoPE

Optional: FlashAttention for efficiency

pip install flash-attn --no-build-isolation

Quick Start

RoPE (Rotary Position Embeddings)

```python
import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE)."""

    def __init__(self, dim, max_seq_len=8192, base=10000):
        super().__init__()
        # Compute inverse frequencies: theta_j = base^(-2j/dim)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, seq_len, device):
        # Position indices
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)

        # Compute frequencies
        freqs = torch.outer(t, self.inv_freq)  # (seq_len, dim/2)

        # Duplicate so both halves of the head dim share the same angles
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        return emb.cos(), emb.sin()


def rotate_half(x):
    """Rotate half the hidden dimensions."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary embeddings to queries and keys."""
    # q, k shape: (batch, heads, seq_len, dim); cos, sin broadcast from (seq_len, dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Usage (move the module to the device so inv_freq and positions live together)
device = "cuda" if torch.cuda.is_available() else "cpu"
rope = RotaryEmbedding(dim=64, max_seq_len=8192).to(device)
cos, sin = rope(seq_len=2048, device=device)

# In attention layer: query/key shaped (batch, heads, seq_len, head_dim)
query = torch.randn(1, 8, 2048, 64, device=device)
key = torch.randn(1, 8, 2048, 64, device=device)
q_rotated, k_rotated = apply_rotary_pos_emb(query, key, cos, sin)
```

ALiBi (Attention with Linear Biases)

```python
import math

import torch


def get_alibi_slopes(num_heads):
    """Get ALiBi slope values for each attention head."""

    def get_slopes_power_of_2(n):
        # Geometric sequence starting at 2^(-8/n); for n = 8: 1/2, 1/4, ..., 1/256
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return get_slopes_power_of_2(num_heads)
    # Closest power of 2 below num_heads
    closest_power = 2 ** math.floor(math.log2(num_heads))
    slopes = get_slopes_power_of_2(closest_power)
    # Fill the remaining heads with every other slope from the next power of 2
    extra = get_slopes_power_of_2(2 * closest_power)
    slopes.extend(extra[0::2][: num_heads - closest_power])
    return slopes


def create_alibi_bias(seq_len, num_heads):
    """Create ALiBi attention bias."""
    # Signed distance (j - i): <= 0 for keys at or before the query
    context_position = torch.arange(seq_len)
    memory_position = torch.arange(seq_len)
    relative_position = memory_position[None, :] - context_position[:, None]

    # Get slopes
    slopes = torch.tensor(get_alibi_slopes(num_heads))

    # Apply slopes to distances
    alibi = slopes[:, None, None] * relative_position[None, :, :]
    return alibi  # (num_heads, seq_len, seq_len)


# Usage in attention
device = "cuda" if torch.cuda.is_available() else "cpu"
num_heads = 8
seq_len = 2048
alibi_bias = create_alibi_bias(seq_len, num_heads).to(device)

# attn_scores shape: (batch, num_heads, seq_len, seq_len); random here for illustration
attn_scores = torch.randn(1, num_heads, seq_len, seq_len, device=device)

# Add bias to attention scores, then mask future keys (causal LM)
causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(1)
attn_scores = (attn_scores + alibi_bias).masked_fill(causal_mask, float("-inf"))
attn_weights = torch.softmax(attn_scores, dim=-1)
```

Position Interpolation for LLaMA

```python
import torch
from transformers import AutoConfig, LlamaForCausalLM

model_name = "meta-llama/Llama-2-7b-hf"

# Original context: 4096 tokens (Llama 2)
config = AutoConfig.from_pretrained(model_name)

# Extend to 32k with position interpolation:
# linear scaling divides RoPE position indices by `factor` (4096 * 8 = 32768)
config.rope_scaling = {"type": "linear", "factor": 8.0}

# Or use dynamic scaling
# config.rope_scaling = {"type": "dynamic", "factor": 8.0}

# Set rope_scaling before loading: the rotary embeddings are built from the config
model = LlamaForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16)

# Then fine-tune on long documents; the PI paper needed only ~1000 steps
```

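Core Concepts

1. RoPE (Rotary Position Embeddings)

How it works:

- Encodes absolute position by rotating each query/key dimension pair by a position-dependent angle
- Inside the attention dot product the two rotations combine into one relative rotation, so scores depend only on the offset m - n
- Has no learned position table, so any input length can be fed in (quality still degrades well beyond the trained length without scaling)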

Mathematical formulation (each dimension pair j of the projected vector treated as one complex number):

q_m = (W_q * x_m) * e^(i m θ_j)
k_n = (W_k * x_n) * e^(i n θ_j)

where θ_j = base^(-2j/d) for j ∈ [0, d/2)
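A quick numerical check of the relative-position property: because every pair of dimensions is rotated by its own angle, the score Re[Σ_j (W_q x_m)_j · conj((W_k x_n)_j) · e^(i(m - n)θ_j)] depends on m and n only through m - n. The sketch below assumes RoFormer's adjacent-pair layout (dims 2j and 2j+1 form one complex number), which differs from the half-split `rotate_half` layout in the Quick Start only by a permutation of dimensions. `rope_rotate` and `rope_score` are hypothetical helpers written for this illustration.

```python
import torch


def rope_rotate(x, pos, base=10000.0):
    """Rotate a real vector of even length d at position `pos` (adjacent-pair layout)."""
    d = x.shape[-1]
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)  # theta_j = base^(-2j/d)
    z = torch.view_as_complex(x.reshape(-1, 2).contiguous())  # dims (2j, 2j+1) -> one complex number
    return z * torch.polar(torch.ones_like(theta), pos * theta)  # multiply by e^(i * pos * theta_j)


def rope_score(q, k, m, n):
    """Attention logit between a query at position m and a key at position n."""
    # For each 2-D pair, the real dot product equals Re(a * conj(b))
    return (rope_rotate(q, m) * rope_rotate(k, n).conj()).real.sum()


torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)
# Same offset m - n = 7 at two different absolute positions
print(torch.allclose(rope_score(q, k, 10, 3), rope_score(q, k, 110, 103)))  # True
```

At position 0 no rotation is applied, so `rope_score(q, k, 0, 0)` is just the plain dot product `q @ k`.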

Advantages:

- Inter-token dependency decays with increasing relative distance
- Compatible with linear attention
- Extrapolates somewhat better than sinusoidal absolute encodings, though not far past the training length

2. YaRN (Yet another RoPE extensioN)

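Key innovations:

- "NTK-by-parts" interpolation: high-frequency RoPE dimensions are left unchanged, low-frequency dimensions are interpolated, with a linear ramp in between (builds on NTK-aware scaling; NTK = Neural Tangent Kernel)
- Attention temperature scaling
- Efficient context extension (10× fewer training tokens than previous methods)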

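Parameters:

```python
# YaRN configuration
yarn_config = {
    "scale": 16,                               # Extension factor s
    "original_max_position_embeddings": 2048,  # Base (pre-extension) context
    "extrapolation_factor": 1.0,               # Weight on the unscaled (extrapolated) part of the ramp
    "attn_factor": 1.0,                        # Multiplier on the attention temperature term
    "beta_fast": 32,                           # Dims with > beta_fast rotations over the base context stay unscaled
    "beta_slow": 1,                            # Dims with < beta_slow rotations are fully interpolated
}
```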
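A minimal sketch of how the NTK-by-parts rule reshapes the RoPE frequencies, written from the formulas in the YaRN paper rather than copied from the reference repository (which computes the ramp boundaries per dimension index, so values differ slightly). `yarn_inv_freq` and `yarn_attention_scale` are hypothetical helpers; `beta_fast` and `beta_slow` are read as rotation counts over the original context, matching the defaults in the configuration above.

```python
import math

import torch


def yarn_inv_freq(dim, base=10000.0, scale=16.0, original_max_pos=2048,
                  beta_fast=32.0, beta_slow=1.0):
    """RoPE inverse frequencies after YaRN-style NTK-by-parts interpolation."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    # Full rotations each dimension completes over the original context: r = L / wavelength
    rotations = original_max_pos * inv_freq / (2 * math.pi)
    # Ramp gamma: 0 below beta_slow rotations (fully interpolate), 1 above beta_fast (keep as-is)
    gamma = ((rotations - beta_slow) / (beta_fast - beta_slow)).clamp(0.0, 1.0)
    return (1.0 - gamma) * (inv_freq / scale) + gamma * inv_freq


def yarn_attention_scale(scale):
    """Attention temperature factor from the YaRN paper: sqrt(1/t) = 0.1 * ln(s) + 1."""
    return 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0


freqs = yarn_inv_freq(dim=128, scale=16.0, original_max_pos=2048)
mscale = yarn_attention_scale(16.0)  # about 1.277 for s = 16
```

High-frequency dimensions (many rotations within the original window) keep their frequency, so local word-order information is untouched; only the slow dimensions, which never completed a full rotation during pre-training, are divided by the scale. The reference code multiplies both cos and sin by the attention factor, so attention logits end up scaled by its square.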

Performance (as reported in the YaRN paper):

- Extends LLaMA 2 models to 128k tokens
- 2.5× fewer training steps than previous methods
- State-of-the-art context window extension at the time of publication

3. ALiBi (Attention with Linear Biases)

Core idea:

- No positional embeddings are added to token embeddings
- A distance penalty is applied directly to the attention scores
- The penalty is proportional to the query-key distance

Formula:

attention_bias[i, j] = -m * |i - j|

where m is a fixed (not learned) slope per attention head, drawn from a geometric sequence; for 8 heads: 1/2, 1/4, ..., 1/256
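To make the formula concrete, the sketch below builds the bias for one head with slope m = 1/2 (the first of the 8-head slopes) over 4 positions. It uses the signed distance j - i, which equals -|i - j| for every key at or before the query, and masks later keys as a causal decoder would. `alibi_bias_one_head` is a hypothetical helper for illustration.

```python
import torch


def alibi_bias_one_head(seq_len, slope):
    """ALiBi bias for a single head, with future keys masked out (causal)."""
    i = torch.arange(seq_len)[:, None]  # query positions (rows)
    j = torch.arange(seq_len)[None, :]  # key positions (columns)
    bias = slope * (j - i).float()      # equals -slope * |i - j| wherever j <= i
    return bias.masked_fill(j > i, float("-inf"))


bias = alibi_bias_one_head(4, slope=0.5)
print(bias[3].tolist())  # [-1.5, -1.0, -0.5, 0.0]
```

The farther back a key is, the larger the penalty, and because nothing is learned the same function applies unchanged at any sequence length.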

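Advantages:

- Trained on 1,024-token inputs, a 1.3B-parameter model matches the perplexity of a sinusoidal model trained on 2,048 tokens while training 11% faster and using 11% less memory
- Strong length extrapolation (train on 1k, test on 2k+)
- Inductive bias toward recency

4. Position Interpolation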

Technique:

- Linearly down-scale position indices
- Interpolate within the trained position range (instead of extrapolating beyond it)
- Minimal fine-tuning required

Formula:

Original (L tokens): position indices [0, 1, 2, ..., L-1]

Extended (2× extension, 2L tokens): position indices [0, 0.5, 1.0, ..., L-0.5]

scaled_position[i] = i / extension_factor
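Position Interpolation is this rescaling applied before the RoPE angles are computed. The sketch below (hypothetical helper `rope_angles`) checks the two properties the technique relies on: with a factor of 2, every position of a 2L-token input lands inside the trained range [0, L), and position 2i receives exactly the angles position i had during pre-training.

```python
import torch


def rope_angles(positions, dim=64, base=10000.0, extension_factor=1.0):
    """RoPE rotation angles (position * theta_j) with Position Interpolation applied."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    scaled = positions.to(torch.float64) / extension_factor  # scaled_position[i] = i / extension_factor
    return torch.outer(scaled, inv_freq)  # (num_positions, dim/2)


L = 2048                                  # trained context length
extended = torch.arange(2 * L)            # 2x longer input
scaled = extended.to(torch.float64) / 2.0
print(scaled.max().item())                # 2047.5: still inside [0, L)

original = rope_angles(torch.arange(L))
interpolated = rope_angles(extended, extension_factor=2.0)
# Position 2i gets the same angles position i had during pre-training
print(torch.allclose(interpolated[0::2], original))  # True
```

The odd positions fall halfway between trained angles, which is why a short fine-tune is still needed even though no angle is out of range.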

Results (from the Position Interpolation paper):

- LLaMA 7B-65B extended to 32k tokens
- Within 1000 fine-tuning steps
- Theoretical upper bound on attention-score error at least ~600× smaller than with extrapolation

Method Comparison

| Method | Max Context | Training Needed | Memory | Extrapolation | Best For |
|---|---|---|---|---|---|
| RoPE | 8k-32k | Full pre-training | Moderate | Good | New models |
| YaRN | 32k-128k | Minimal (~10× fewer tokens) | Moderate | Excellent | Extending existing models |
| ALiBi | Unlimited (in principle) | Full pre-training | Low (-11%) | Excellent | Training from scratch |
| Position Interpolation | 32k+ | Minimal (~1k steps) | Moderate | Poor (by design) | Quick extension |

Implementation Patterns

HuggingFace Transformers Integration

```python
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "mistralai/Mistral-7B-v0.1"
config = AutoConfig.from_pretrained(model_name)

# RoPE with YaRN scaling (8k -> 64k). attention_factor is omitted so the
# YaRN default temperature (0.1 * ln(factor) + 1) is used instead of a fixed 1.0.
config.rope_scaling = {
    "type": "yarn",
    "factor": 8.0,
    "original_max_position_embeddings": 8192,
}

# Position interpolation (simpler)
# config.rope_scaling = {"type": "linear", "factor": 4.0}

# Dynamic scaling (adjusts based on input length)
# config.rope_scaling = {"type": "dynamic", "factor": 8.0}

# from_pretrained keeps the pretrained weights; from_config would build a randomly initialized model
model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
```

Custom RoPE Implementation

```python
import torch.nn as nn
import torch.nn.functional as F

# Reuses RotaryEmbedding and apply_rotary_pos_emb from the Quick Start above


class LongContextAttention(nn.Module):
    """Multi-head attention with RoPE."""

    def __init__(self, hidden_size, num_heads, max_seq_len=32768):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Q, K, V projections
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

        # RoPE
        self.rotary_emb = RotaryEmbedding(dim=self.head_dim, max_seq_len=max_seq_len)

    def forward(self, hidden_states):
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape for multi-head: (batch, heads, seq_len, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to queries and keys only
        cos, sin = self.rotary_emb(seq_len, device=hidden_states.device)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Scaled dot-product attention with a causal mask (decoder-only LM)
        attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Reshape and project
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, -1)
        return self.o_proj(attn_output)
```

Fine-tuning for Long Context

Minimal Fine-tuning (Position Interpolation)

```python
from transformers import AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments

model_name = "meta-llama/Llama-2-7b-hf"

# Extend model config before loading (Llama 2: 4096 * 8 = 32768)
config = AutoConfig.from_pretrained(model_name)
config.max_position_embeddings = 32768
config.rope_scaling = {"type": "linear", "factor": 8.0}
model = AutoModelForCausalLM.from_pretrained(model_name, config=config)

# Training args (minimal steps needed)
training_args = TrainingArguments(
    output_dir="./llama-32k",
    num_train_epochs=1,
    max_steps=1000,  # ~1000 steps sufficed in the Position Interpolation paper
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-5,
    warmup_steps=100,
    logging_steps=10,
    save_steps=500,
)

# Train on long documents
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=long_document_dataset,  # placeholder: tokenized ~32k-token sequences with labels
)

trainer.train()
```

YaRN Fine-tuning

```bash
# Clone YaRN implementation
git clone https://github.com/jquesnelle/yarn
cd yarn

# Fine-tune LLaMA with YaRN (illustrative invocation; check the repository README
# for the current script name and flags)
python scripts/train.py \
    --model meta-llama/Llama-2-7b-hf \
    --scale 16 \
    --rope_theta 10000 \
    --max_length 32768 \
    --batch_size 1 \
    --gradient_accumulation 16 \
    --steps 400 \
    --learning_rate 2e-5
```

Best Practices

1. Choose the Right Method

```python
# For NEW models (training from scratch)
use_method = "ALiBi"  # Best extrapolation, lowest memory

# For EXTENDING existing RoPE models
use_method = "YaRN"  # Most efficient extension (~10× fewer training tokens)

# For QUICK extension with minimal compute
use_method = "Position Interpolation"  # ~1000 steps

# For MODERATE extension with good efficiency
use_method = "Linear RoPE Scaling"  # Built-in, simple (HF's config form of Position Interpolation)
```

2. Scaling Factor Selection

```python
# Conservative (safer, better quality)
scaling_factor = 2.0  # 8k → 16k

# Moderate (good balance)
scaling_factor = 4.0  # 8k → 32k

# Aggressive (requires more fine-tuning)
scaling_factor = 8.0   # 8k → 64k
scaling_factor = 16.0  # 8k → 128k

# Rule: larger factors need more fine-tuning steps
steps_needed = 100 * scaling_factor  # Rough estimate
```
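The arithmetic above can be wrapped in a small helper. `plan_rope_scaling` is hypothetical; the dict layout mirrors the `rope_scaling` configs used earlier, and the step count is only the rough 100-steps-per-unit-of-factor rule stated above, not a measured requirement.

```python
def plan_rope_scaling(original_ctx, target_ctx, rope_type="linear"):
    """Scaling factor, rope_scaling dict and rough step budget for a context extension."""
    if target_ctx <= original_ctx:
        raise ValueError("target_ctx must be larger than original_ctx")
    factor = target_ctx / original_ctx
    return {
        "rope_scaling": {"type": rope_type, "factor": factor},
        "rough_steps": int(100 * factor),  # steps_needed = 100 * scaling_factor
    }


print(plan_rope_scaling(8192, 32768))
# {'rope_scaling': {'type': 'linear', 'factor': 4.0}, 'rough_steps': 400}
```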

3. Fine-tuning Data

```python
# ✅ Good: long documents matching the target length
train_data = [
    {"text": long_doc_32k_tokens},  # Full 32k
    {"text": long_doc_24k_tokens},  # Varied lengths
    {"text": long_doc_16k_tokens},
]

# ❌ Bad: short documents (won't learn long context)
train_data = [
    {"text": short_doc_2k_tokens},
]
```

Use datasets like:

- PG-19 (books, long texts)
- arXiv papers
- Long-form conversations
- GitHub repositories (concatenated files)
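One common way to turn a collection of long documents into fixed-length training examples is to keep only documents that are genuinely long and slice their token ids into target-length chunks. The sketch below works on plain lists of token ids so the tokenizer is left out; `pack_long_documents`, `seq_len` and `min_tokens` are hypothetical names chosen for illustration.

```python
def pack_long_documents(docs_token_ids, seq_len=32768, min_tokens=16384):
    """Split each sufficiently long document into chunks of at most seq_len tokens.

    Documents shorter than min_tokens are dropped (they teach nothing about long
    context), and any chunk (in practice, a trailing remainder) shorter than
    min_tokens is dropped too.
    """
    examples = []
    for ids in docs_token_ids:
        if len(ids) < min_tokens:
            continue
        for start in range(0, len(ids), seq_len):
            chunk = ids[start:start + seq_len]
            if len(chunk) >= min_tokens:
                examples.append(chunk)
    return examples


docs = [list(range(40000)), list(range(2000)), list(range(20000))]
print([len(x) for x in pack_long_documents(docs)])  # [32768, 20000]
```

Keeping chunks between `min_tokens` and `seq_len` preserves the varied lengths recommended above instead of padding everything to exactly 32k.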

4. Avoid Common Pitfalls

```python
# ❌ Bad: applying position interpolation without fine-tuning
model.config.rope_scaling = {"type": "linear", "factor": 16.0}
# Model will perform poorly without fine-tuning!

# ✅ Good: fine-tune after scaling
model.config.rope_scaling = {"type": "linear", "factor": 16.0}
fine_tune(model, long_documents, steps=1000)  # fine_tune: placeholder training routine

# ❌ Bad: too aggressive scaling without data
scale_to_1M_tokens()  # placeholder; won't work without massive fine-tuning

# ✅ Good: incremental scaling
# 8k → 16k → 32k → 64k (fine-tune at each step)
```
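The incremental path above can be laid out programmatically. With linear scaling, each stage's `factor` is measured against the model's original context, so it grows at every stage rather than staying at 2. `incremental_schedule` is a hypothetical helper that doubles the window per stage.

```python
def incremental_schedule(original_ctx, target_ctx):
    """Doubling stages from original_ctx to target_ctx, with the linear RoPE factor per stage."""
    stages = []
    ctx = original_ctx
    while ctx < target_ctx:
        new_ctx = min(ctx * 2, target_ctx)
        # Factor is relative to the ORIGINAL context, not the previous stage
        stages.append({"from": ctx, "to": new_ctx, "factor": new_ctx / original_ctx})
        ctx = new_ctx
    return stages


for stage in incremental_schedule(8192, 65536):
    print(stage)
# {'from': 8192, 'to': 16384, 'factor': 2.0}
# {'from': 16384, 'to': 32768, 'factor': 4.0}
# {'from': 32768, 'to': 65536, 'factor': 8.0}
```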

Production Deployment

Inference with Long Context

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load long-context model
model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/LLaMA-2-7B-32K",  # 32k context
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("togethercomputer/LLaMA-2-7B-32K")

# Process long document
long_text = "..."  # placeholder: a document of roughly 30k tokens
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to(model.device)

# Prompt plus generated tokens must fit in the 32k window
max_new_tokens = 512
assert inputs["input_ids"].shape[1] + max_new_tokens <= 32768

# Generate (temperature only takes effect when sampling is enabled)
outputs = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    do_sample=True,
    temperature=0.7,
)

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
```

Memory Optimization

```python
import torch
from transformers import AutoModelForCausalLM
from vllm import LLM

# Use Flash Attention 2 (requires the flash-attn package and fp16/bf16 weights)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",  # 2-3× faster
    torch_dtype=torch.float16,
)

# Use gradient checkpointing for fine-tuning (trades recomputation for activation memory)
model.gradient_checkpointing_enable()

# Use paged attention (vLLM)
llm = LLM(
    model="togethercomputer/LLaMA-2-7B-32K",
    max_model_len=32768,  # 32k context
    gpu_memory_utilization=0.9,
)
```
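At long context the KV cache, not the weights, is often what runs out of memory. A back-of-the-envelope estimate: each token stores one key and one value vector per layer and per KV head. `kv_cache_bytes` is a hypothetical helper; its defaults assume Llama-2-7B's shape (32 layers, 32 KV heads, head dimension 128) with fp16 values.

```python
def kv_cache_bytes(seq_len, num_layers=32, num_kv_heads=32, head_dim=128,
                   bytes_per_value=2, batch_size=1):
    """Approximate KV-cache size: K and V tensors of shape (batch, kv_heads, seq_len, head_dim) per layer."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * seq_len * batch_size


gib = kv_cache_bytes(32768) / 2**30
print(f"{gib:.1f} GiB")  # 16.0 GiB for one 32k-token sequence
```

The estimate grows linearly with sequence length and batch size, which is why the vLLM example above caps `max_model_len`; models with grouped-query attention (fewer KV heads) shrink it proportionally.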

Resources

- RoPE Paper: https://arxiv.org/abs/2104.09864 (RoFormer)
- YaRN Paper: https://arxiv.org/abs/2309.00071
- ALiBi Paper: https://arxiv.org/abs/2108.12409 (Train Short, Test Long)
- Position Interpolation: https://arxiv.org/abs/2306.15595
- HuggingFace RoPE Utils: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py
- YaRN Implementation: https://github.com/jquesnelle/yarn
- Together AI Blog: https://www.together.ai/blog/llama-2-7b-32k

See Also

- references/rope.md - Detailed RoPE implementation and theory
- references/extension_methods.md - YaRN, ALiBi, Position Interpolation comparisons
- references/fine_tuning.md - Complete fine-tuning guide for context extension
