Speculative Decoding: Accelerating LLM Inference with a Draft Model
The Inescapable Bottleneck of Autoregressive Generation
In production AI systems, the latency of Large Language Model (LLM) inference is a critical metric. For any engineer who has deployed a large autoregressive model like Llama 3 or Mistral, the fundamental performance bottleneck is painfully clear: generation is sequential. Each token is generated conditioned on the prompt and all previously generated tokens, a process captured by token_{n+1} = Model(prompt + token_1 + ... + token_n).
This sequential dependency means that despite the immense parallel processing power of modern GPUs, we are fundamentally limited by the time it takes to perform a single forward pass through the model. For large models (7B+ parameters), this process is not compute-bound; it's memory-bandwidth-bound. The bulk of the latency comes from loading the model's weights from high-bandwidth memory (HBM) into the GPU's SRAM for each and every generated token.
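A back-of-the-envelope calculation makes this wall concrete. Assuming, purely for illustration, an 8B-parameter model stored in bfloat16 and roughly 2 TB/s of effective HBM bandwidth (in the ballpark of an A100), weight traffic alone bounds single-stream decoding:

$$
\frac{8\times 10^{9}\ \text{params}\times 2\ \text{bytes/param}}{2\times 10^{12}\ \text{bytes/s}} \approx 8\ \text{ms per token}
\quad\Longrightarrow\quad \lesssim 125\ \text{tokens/s for a single sequence, regardless of available FLOPs.}
$$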
The Key-Value (KV) cache is a standard optimization that mitigates part of this problem. It caches the intermediate attention keys and values for the prompt and previously generated tokens, so they don't need to be recomputed. While essential, the KV cache does not change the one-at-a-time nature of generation. The core loop, sketched in code below the list, remains:
- Run one forward pass for one token.
- Wait for the result.
- Append the result to the input.
- Repeat.
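Concretely, the loop looks something like the following minimal sketch (greedy decoding with the transformers API; the model ID is illustrative and any causal LM behaves the same way):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16, device_map="cuda")

@torch.no_grad()
def greedy_generate(prompt: str, max_new_tokens: int = 64) -> str:
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    past_key_values = None
    for _ in range(max_new_tokens):
        # One full trip through the weights per emitted token.
        outputs = model(
            input_ids if past_key_values is None else input_ids[:, -1:],
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values             # reuse cached attention states
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)  # append and repeat
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)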
This article dissects an advanced technique to break this sequential chain: Speculative Decoding. We will bypass high-level explanations and dive directly into a production-grade implementation, analyzing the nuanced trade-offs and critical details required to achieve a 2-3x speedup in inference latency without any degradation in output quality.
The Core Algorithm: Verification over Generation
Speculative decoding, proposed concurrently by researchers at Google Research and at DeepMind, reframes the problem. Instead of generating tokens one by one with the large, slow target model (M_t), we use a much smaller, faster draft model (M_d) to generate a sequence of candidate tokens. Then, we use the powerful target model to verify this entire sequence in a single parallel forward pass.
This is the central insight: a single forward pass of M_t over k tokens is significantly faster than k sequential forward passes.
The algorithm proceeds in these steps:
1. Drafting: Use M_d to autoregressively generate a draft sequence of γ (gamma) tokens: d_1, d_2, ..., d_γ.
2. Verification: Build an input for M_t that includes the confirmed tokens followed by the γ draft tokens, and run a single, parallel forward pass of M_t on this entire sequence. This yields γ+1 output logit distributions: p_0, p_1, ..., p_γ, where p_i is the target model's prediction for the token following the i-th draft token (so p_0 predicts the position of d_1).
3. Acceptance loop: For each draft token d_i at position i (from 1 to γ):
   * Let q_i be the probability distribution from the draft model M_d that was used to sample d_i.
   * Let p_{i-1} be the probability distribution from the target model M_t for the same position.
   * Accept d_i if a randomly drawn number r from U(0, 1) is less than or equal to p_{i-1}(d_i) / q_i(d_i). This is a form of rejection sampling.
   * If d_i is accepted, continue to the next draft token d_{i+1}.
   * If d_i is rejected, it and all subsequent draft tokens (d_i, ..., d_γ) are discarded.
   * If all γ draft tokens are accepted, sample one final token from the last target distribution p_γ and append it to the accepted sequence.
4. Correction: If a token d_i was rejected, sample a replacement from a *corrected* probability distribution so that the final output statistically matches the target model's distribution. The corrected distribution is the normalized positive part of (p_{i-1} - q_i). Sample from it, append the result, and discard the rest of the draft.
This process guarantees that the final sequence of tokens has the exact same probability distribution as if it were generated by the target model M_t alone. The speedup comes from the fact that, on average, we accept multiple tokens for the cost of a single M_t forward pass.
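In code, the per-position accept/reject decision and the corrected resampling boil down to a few lines. The following is a minimal, self-contained sketch (the helper name and tensor shapes are illustrative); the full loop, including drafting, batched verification, and KV cache handling, follows in the next section.

import torch

def accept_or_resample(p: torch.Tensor, q: torch.Tensor, draft_token: int):
    """One verification step at a single position.

    p: target model's distribution for this position, shape [vocab_size].
    q: draft model's distribution that proposed `draft_token`, shape [vocab_size].
    Returns (accepted, token).
    """
    # Accept the draft token with probability min(1, p(d) / q(d)).
    if torch.rand(1).item() <= (p[draft_token] / q[draft_token]).item():
        return True, draft_token
    # Otherwise sample the replacement from the corrected distribution (p - q)+, renormalized.
    corrected = torch.clamp(p - q, min=0.0)
    corrected = corrected / corrected.sum()
    return False, torch.multinomial(corrected, num_samples=1).item()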
Production-Grade Implementation with `transformers`
Let's move from theory to a concrete implementation using Python, PyTorch, and the Hugging Face transformers library. For this example, we'll use meta-llama/Meta-Llama-3-8B-Instruct as our target model and meta-llama/Llama-3.2-1B-Instruct as our draft model. A key prerequisite is that both models share the same tokenizer (or at least compatible token mappings), because the two models exchange raw token IDs. This rules out otherwise attractive small models such as TinyLlama, whose 32K-entry Llama 2 vocabulary cannot draft for Llama 3's 128K-entry vocabulary.
1. Setup and Model Loading
First, we set up our environment and load the models onto the GPU. We'll also load the shared tokenizer.
import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Configuration ---
TARGET_MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
DRAFT_MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Load Tokenizer (must be shared by both models) ---
# The draft and target models exchange raw token IDs, so they must use the same
# vocabulary. Llama 3 and Llama 3.2 share the same 128K-token vocabulary, so one
# tokenizer serves both. A draft model from another family (e.g., TinyLlama with
# its 32K Llama 2 vocabulary) would silently break the scheme.
tokenizer = AutoTokenizer.from_pretrained(TARGET_MODEL_ID)
# --- Load Models ---
print("Loading target model...")
target_model = AutoModelForCausalLM.from_pretrained(
TARGET_MODEL_ID,
torch_dtype=torch.bfloat16, # Use bfloat16 for performance
device_map=DEVICE,
attn_implementation="flash_attention_2", # Requires flash-attn library
)
print("Loading draft model...")
draft_model = AutoModelForCausalLM.from_pretrained(
DRAFT_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map=DEVICE,
attn_implementation="flash_attention_2",
)
target_model.eval()
draft_model.eval()
print(f"Models loaded on {DEVICE}")
2. The Speculative Decoding Core Logic
Now for the main function. This function will encapsulate the drafting, verification, and acceptance loop. Pay close attention to the management of the KV caches (past_key_values), which is critical for performance.
@torch.no_grad()
def speculative_decode(
prompt: str,
target_model: AutoModelForCausalLM,
draft_model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
max_new_tokens: int = 128,
gamma: int = 4, # Number of draft tokens
temperature: float = 0.7,
top_p: float = 0.9,
):
# --- Initialization ---
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
generated_tokens = list(input_ids.cpu().numpy()[0])
n_generated = 0
# KV caches for both models
target_past_key_values = None
draft_past_key_values = None
# --- Generation Loop ---
while n_generated < max_new_tokens:
# --- 1. Drafting Phase ---
draft_tokens = []
draft_logits = []
current_draft_input = input_ids
for _ in range(gamma):
draft_outputs = draft_model(
current_draft_input,
past_key_values=draft_past_key_values,
use_cache=True
)
# Get logits for the next token, apply temperature, etc.
next_token_logits = draft_outputs.logits[:, -1, :]
if temperature > 0:
next_token_logits = next_token_logits / temperature
            # For simplicity, we draft greedily (argmax). Strictly speaking, the
            # lossless guarantee of speculative sampling assumes draft tokens are
            # *sampled* from q; a production implementation would sample here.
next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
draft_tokens.append(next_token)
draft_logits.append(next_token_logits)
current_draft_input = next_token
draft_past_key_values = draft_outputs.past_key_values
# --- 2. Verification Phase ---
draft_sequence = torch.cat(draft_tokens, dim=1)
verify_input_ids = torch.cat([input_ids, draft_sequence], dim=1)
target_outputs = target_model(
verify_input_ids,
past_key_values=target_past_key_values,
use_cache=True
)
        # --- 3. Acceptance/Rejection Phase ---
        n_accepted = 0
        # Number of tokens fed to the target model *before* the drafts. Because the
        # model's output at each position predicts the *next* token, the logits for
        # draft token i live at position prefix_len - 1 + i.
        prefix_len = input_ids.shape[1]
        for i in range(gamma):
            # Target model's probability distribution for the i-th draft token
            target_verify_logits = target_outputs.logits[:, prefix_len - 1 + i, :]
            if temperature > 0:
                target_verify_logits = target_verify_logits / temperature
            target_probs = torch.nn.functional.softmax(target_verify_logits, dim=-1)
            # Draft model's probability for the i-th draft token
            draft_token_logits = draft_logits[i]
            draft_probs = torch.nn.functional.softmax(draft_token_logits, dim=-1)
            draft_token_id = draft_sequence[0, i]
            p = target_probs[0, draft_token_id]
            q = draft_probs[0, draft_token_id]
            # Rejection sampling: accept with probability min(1, p / q)
            if torch.rand(1).item() <= (p / q).item():
                # Accept
                n_accepted += 1
            else:
                # Reject: sample from the corrected distribution (p - q)+
                corrected_probs = torch.clamp(target_probs - draft_probs, min=0.0)
                norm_factor = corrected_probs.sum()
                if norm_factor > 1e-6:
                    corrected_probs = corrected_probs / norm_factor
                    resampled_token = torch.multinomial(corrected_probs, num_samples=1)
                else:
                    # Fallback to sampling from the original target distribution
                    resampled_token = torch.multinomial(target_probs, num_samples=1)
                # Keep the accepted prefix, append the resampled token, discard the rest
                draft_sequence = torch.cat([draft_sequence[:, :n_accepted], resampled_token], dim=1)
                break
        # --- 4. Finalization and KV Cache Update ---
        if n_accepted == gamma:
            # All draft tokens were accepted: sample one bonus token from the
            # target's distribution for the position after the last draft (p_gamma).
            bonus_logits = target_outputs.logits[:, -1, :]
            if temperature > 0:
                bonus_logits = bonus_logits / temperature
            bonus_probs = torch.nn.functional.softmax(bonus_logits, dim=-1)
            bonus_token = torch.multinomial(bonus_probs, num_samples=1)
            draft_sequence = torch.cat([draft_sequence, bonus_token], dim=1)
        accepted_tokens = draft_sequence
        # Update the list of all generated tokens
        newly_generated = accepted_tokens[0].tolist()
        generated_tokens.extend(newly_generated)
        n_generated += len(newly_generated)
        # Critical Step: Update KV Caches
        # The input for the next iteration is the sequence of newly accepted tokens
        input_ids = accepted_tokens.to(DEVICE)
        # Trim both caches back to the confirmed prefix (everything *before* this
        # round's accepted tokens). The accepted tokens are then re-processed as
        # fresh input on the next iteration: slightly redundant, but it keeps the
        # two models' cache states exactly aligned. A production system would keep
        # the cache entries of accepted draft tokens instead of recomputing them.
        # NOTE: the slicing below assumes the legacy tuple cache format; recent
        # transformers versions return Cache objects, which may need conversion
        # via cache.to_legacy_cache() / DynamicCache.from_legacy_cache().
        confirmed_len = target_outputs.past_key_values[0][0].shape[2] - gamma
        target_past_key_values = tuple(
            (k[:, :, :confirmed_len, :], v[:, :, :confirmed_len, :])
            for k, v in target_outputs.past_key_values
        )
        draft_past_key_values = tuple(
            (k[:, :, :confirmed_len, :], v[:, :, :confirmed_len, :])
            for k, v in draft_past_key_values
        )
        if n_generated >= max_new_tokens:
            break
return tokenizer.decode(generated_tokens, skip_special_tokens=True)
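Before benchmarking, a quick smoke test of the function is worthwhile (the prompt and parameters here are arbitrary):

output = speculative_decode(
    "Explain speculative decoding in one short paragraph.",
    target_model,
    draft_model,
    tokenizer,
    max_new_tokens=64,
    gamma=4,
    temperature=0.7,
)
print(output)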
Performance Analysis and Benchmarking
The entire point of this complexity is speed. To quantify the gains, we must benchmark our implementation against the standard model.generate() method.
The key metrics are:
* Tokens per Second (TPS): The ultimate measure of throughput.
* Acceptance Rate: The average number of draft tokens accepted per verification step. A higher rate means higher efficiency.
* Speedup: The ratio of standard generation time to speculative generation time.
Here is a simple benchmarking script:
def benchmark():
prompt = "The field of artificial intelligence has seen remarkable progress in recent years, particularly in the domain of"
max_new_tokens = 256
gamma_values = [2, 4, 6, 8]
print("--- Standard Autoregressive Decoding Benchmark ---")
start_time = time.time()
_ = target_model.generate(
tokenizer.encode(prompt, return_tensors="pt").to(DEVICE),
max_new_tokens=max_new_tokens,
do_sample=False, # Use greedy for fair comparison
pad_token_id=tokenizer.eos_token_id
)
end_time = time.time()
standard_time = end_time - start_time
standard_tps = max_new_tokens / standard_time
print(f"Standard Generation Time: {standard_time:.2f}s")
print(f"Standard Tokens/Second: {standard_tps:.2f}\n")
print("--- Speculative Decoding Benchmark ---")
results = []
for gamma in gamma_values:
print(f"Benchmarking with gamma = {gamma}...")
start_time = time.time()
speculative_decode(
prompt,
target_model,
draft_model,
tokenizer,
max_new_tokens=max_new_tokens,
gamma=gamma,
            temperature=0,  # Disables temperature scaling; acceptance and resampling remain stochastic
)
end_time = time.time()
speculative_time = end_time - start_time
speculative_tps = max_new_tokens / speculative_time
speedup = standard_time / speculative_time
results.append({
"gamma": gamma,
"time": speculative_time,
"tps": speculative_tps,
"speedup": speedup
})
# Print results in a markdown table
print("| Gamma (γ) | Time (s) | Tokens/Second | Speedup vs. Standard |")
print("|-----------|----------|---------------|----------------------|")
print(f"| Standard | {standard_time:.2f} | {standard_tps:.2f} | 1.00x |")
for res in results:
print(f"| {res['gamma']:^9} | {res['time']:.2f} | {res['tps']:.2f} | {res['speedup']:.2f}x |")
# Run the benchmark
# benchmark()
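The script above reports only wall-clock throughput. To also report the acceptance rate defined earlier, speculative_decode would need to accumulate two counters inside its generation loop (the sum of n_accepted across rounds and the number of verification rounds); those counters are hypothetical additions, not part of the implementation above. Given such counters, summarizing them is trivial:

def speculation_stats(total_accepted: int, n_rounds: int, gamma: int, n_generated: int) -> dict:
    """Summarize draft acceptance from counters collected inside speculative_decode.

    total_accepted: sum of n_accepted over all verification rounds.
    n_rounds:       number of target-model verification passes.
    n_generated:    total tokens emitted (accepted drafts + resampled/bonus tokens).
    """
    return {
        "acceptance_rate": total_accepted / max(n_rounds * gamma, 1),
        "tokens_per_target_pass": n_generated / max(n_rounds, 1),
    }

# Example: on average 3.2 of 4 draft tokens accepted per verification pass
print(speculation_stats(total_accepted=160, n_rounds=50, gamma=4, n_generated=210))
# {'acceptance_rate': 0.8, 'tokens_per_target_pass': 4.2}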
Expected Benchmark Results
Running this on a high-end GPU (e.g., an A100 or H100) would yield results similar to this hypothetical table:
| Gamma (γ) | Time (s) | Tokens/Second | Speedup vs. Standard |
|---|---|---|---|
| Standard | 8.53 | 30.01 | 1.00x |
| 2 | 4.71 | 54.35 | 1.81x |
| 4 | 3.10 | 82.58 | 2.75x |
| 6 | 3.25 | 78.77 | 2.62x |
| 8 | 3.68 | 69.56 | 2.32x |
Analysis:
* The speedup is significant, peaking at 2.75x with gamma=4.
* There is a clear "sweet spot" for gamma. As gamma increases, the draft model is more likely to make a mistake, leading to a lower acceptance rate. The overhead of generating a long, incorrect draft sequence and running a large verification pass outweighs the benefit.
* The optimal gamma depends on the quality of the draft model and the complexity of the text being generated.
Advanced Considerations and Production Pitfalls
While the implementation above demonstrates the core concept, several nuances are critical for a robust, production-ready system.
1. Draft Model Selection and Alignment
The choice of the draft model is the most important factor for success.
* Speed: It must be substantially faster than the target model. A good rule of thumb is for the draft model's forward pass to be at least 5-10x faster.
* Quality: It doesn't need to be perfect, but its probability distribution should be a reasonable approximation of the target model's. The higher the correlation, the higher the acceptance rate. The best results are often achieved by using a distilled version of the target model or a smaller model from the same family that has been fine-tuned on the target model's outputs.
2. KV Cache Management: The Devil in the Details
Our simplified implementation trims both caches back to the confirmed prefix and lets the next iteration re-process the newly accepted tokens. This keeps the two models perfectly aligned, but it wastes some compute. A production system requires more surgical synchronization.
When n_accepted tokens are accepted, the target model's past_key_values from its verification pass already contain valid entries for those n_accepted tokens; only the entries for rejected draft tokens need to be dropped. You must slice the cache to exactly that length.
The draft model's KV cache must then be brought into the same state. Since the draft model never processed the resampled (or bonus) token, the most robust approach is to run a short draft-model forward pass over just the tokens its cache is missing. This adds a small amount of overhead, but it is far cheaper than re-encoding the whole sequence and it guarantees perfect alignment for the next drafting phase.
# A more robust KV cache update sketch
# ... after a verification pass in which `n_accepted` draft tokens were accepted

# Update the target model's cache: keep the confirmed prefix plus the accepted
# draft tokens, dropping only the entries for rejected drafts.
new_seq_len = target_outputs.past_key_values[0][0].shape[2] - (gamma - n_accepted)
target_past_key_values = tuple(
    (k[:, :, :new_seq_len, :], v[:, :, :new_seq_len, :])
    for k, v in target_outputs.past_key_values
)

# Resynchronize the draft model's cache: replay only the tokens it has not yet
# processed, on top of its own cache for the accepted prefix.
# This is the key step to avoid divergence.
resync_outputs = draft_model(
    new_tokens,                          # the resampled or bonus token(s) the draft cache has not seen
    past_key_values=draft_prefix_cache,  # the draft cache trimmed to the accepted prefix
    use_cache=True,
)
draft_past_key_values = resync_outputs.past_key_values
3. The Mathematics of Lossless Acceleration
The reason speculative decoding is "lossless" (i.e., produces the same distribution as the target model) lies in the acceptance rule and the correction step. At a given position we have two distributions: p from the target and q from the draft, and the draft token was proposed from q. Rejection sampling tells us how to combine an accept/reject test on that proposal with a corrected fallback so that the emitted token is distributed exactly according to p.
Accepting the proposed token x with probability min(1, p(x)/q(x)) emits it with overall probability min(p(x), q(x)). When a rejection occurs, sampling from p' = (p - q)+ / Z, where Z is a normalization constant and + denotes the positive part (clamping negative values to zero), contributes exactly the missing probability mass. Summing the two routes recovers p, which is what preserves the statistical integrity of the output stream.
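Written out for a single position (same notation as above, noting that Z = Σ_x (p(x) - q(x))+ is also the total rejection probability), the two routes by which a token x can be emitted sum exactly to p(x):

$$
\Pr[\text{emit } x]
= \underbrace{q(x)\,\min\!\Big(1, \tfrac{p(x)}{q(x)}\Big)}_{\text{proposed and accepted}}
+ \underbrace{Z \cdot \tfrac{\big(p(x) - q(x)\big)_+}{Z}}_{\text{rejected, then resampled}}
= \min\big(p(x), q(x)\big) + \big(p(x) - q(x)\big)_+
= p(x).
$$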
4. Interaction with Sampling Techniques
Our example used greedy sampling for simplicity. When using stochastic methods like top_k or top_p sampling:
* The drafting phase can use any sampling method. Greedy is often fastest.
* The verification phase must use the logits from the target model. The acceptance/rejection logic remains based on the full probability distributions.
* The correction/finalization sampling (after a rejection or full acceptance) must use the desired sampling method (top_k, top_p) on the target model's (or corrected) logits. This ensures the final output adheres to the user's sampling constraints; a sketch of such a filter follows below.
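For example, the top_p argument accepted by speculative_decode above is currently unused. A small nucleus-filtering helper along these lines (an illustrative sketch, not a library API) could be applied to the target or corrected probabilities just before the final torch.multinomial call:

import torch

def apply_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Zero out tokens outside the top-p nucleus of a [1, vocab_size] distribution and renormalize."""
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep the smallest set of tokens whose cumulative probability reaches top_p
    # (always keep at least the single most likely token).
    keep = cumulative - sorted_probs < top_p
    keep[..., 0] = True
    filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    # Scatter the filtered probabilities back into vocabulary order.
    out = torch.zeros_like(probs)
    out.scatter_(-1, sorted_idx, filtered)
    return out

# Usage inside the rejection branch (hypothetical):
# corrected_probs = apply_top_p(corrected_probs, top_p=0.9)
# resampled_token = torch.multinomial(corrected_probs, num_samples=1)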
Conclusion: A Powerful Tool for Production Inference
Speculative decoding is not a simple drop-in replacement for standard generation. It is an advanced systems-level optimization that requires a deep understanding of model architectures, KV cache mechanics, and probability theory. However, for applications where inference latency is a critical bottleneck, the potential 2-3x speedup is a game-changer.
By carefully selecting a fast and well-aligned draft model, managing the KV caches with precision, and correctly implementing the rejection sampling loop, engineering teams can significantly reduce the cost and improve the user experience of their deployed LLMs. As models continue to grow, techniques like speculative decoding that tackle the memory bandwidth wall will become not just advantageous, but essential for building responsive and scalable AI products.