Speculative Decoding: Accelerating LLM Inference with a Draft Model

Goh Ling Yong

The Inescapable Bottleneck of Autoregressive Generation

In production AI systems, the latency of Large Language Model (LLM) inference is a critical metric. For any engineer who has deployed a large autoregressive model like Llama 3 or Mistral, the fundamental performance bottleneck is painfully clear: generation is sequential. Each token is generated based on the previous one, a process dictated by the equation token_n+1 = Model(prompt + token_1 + ... + token_n).

This sequential dependency means that despite the immense parallel processing power of modern GPUs, we are fundamentally limited by the time it takes to perform a single forward pass through the model. For large models (7B+ parameters) decoding at the small batch sizes typical of latency-sensitive serving, this process is not compute-bound; it's memory-bandwidth-bound. The bulk of the latency comes from loading the model's weights from high-bandwidth memory (HBM) into the GPU's on-chip SRAM for each and every generated token.
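
As a rough back-of-the-envelope illustration: an 8B-parameter model stored in bfloat16 occupies about 16 GB of weights, and at roughly 2 TB/s of HBM bandwidth (in the ballpark of an A100), simply streaming those weights once takes about 8 ms. That caps single-stream decoding at roughly 125 tokens per second, regardless of how much compute the GPU has to spare.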

The Key-Value (KV) cache is a standard optimization that mitigates part of this problem. It caches the intermediate attention keys and values for the prompt and previously generated tokens, so they don't need to be recomputed. While essential, the KV cache does not change the one-at-a-time nature of generation. The core loop remains:

  • Run one forward pass for one token.
  • Wait for the result.
  • Append the result to the input.
  • Repeat.
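
To make the baseline concrete, here is a minimal sketch of that loop with the transformers library, assuming a causal LM and its tokenizer are already loaded; it decodes greedily and omits stopping conditions:

python
import torch

@torch.no_grad()
def naive_generate(model, tokenizer, prompt, max_new_tokens=64):
    # Standard autoregressive decoding: one full forward pass per generated token.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    generated = []
    past_key_values = None
    next_input = input_ids
    for _ in range(max_new_tokens):
        outputs = model(next_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values  # KV cache: the prefix is never recomputed
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        generated.append(next_token)
        next_input = next_token  # only the newest token is fed on the next pass
    return tokenizer.decode(torch.cat(generated, dim=1)[0], skip_special_tokens=True)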

This article dissects an advanced technique to break this sequential chain: Speculative Decoding. We will bypass high-level explanations and dive directly into a concrete implementation, analyzing the nuanced trade-offs and production-critical details required to achieve a 2-3x speedup in inference latency without any degradation in output quality.

The Core Algorithm: Verification over Generation

Speculative decoding, proposed independently by researchers at Google Research and DeepMind, reframes the problem. Instead of generating tokens one by one with the large, slow target model (M_t), we use a much smaller, faster draft model (M_d) to generate a sequence of candidate tokens. Then, we use the powerful target model to verify this entire sequence in a single parallel forward pass.

This is the central insight: a single forward pass of M_t over k tokens is significantly faster than k sequential forward passes.

The algorithm proceeds in these steps:

  • Drafting Phase: Given the current sequence of confirmed tokens, use the small draft model M_d to autoregressively generate a draft sequence of γ (gamma) tokens: d_1, d_2, ..., d_γ.
  • Verification Phase: Construct a new input tensor for the target model M_t that includes the confirmed tokens followed by the γ draft tokens. Run a single, parallel forward pass of M_t on this entire sequence. This yields γ+1 useful output distributions p_0, p_1, ..., p_γ, where p_{i-1} is the target model's distribution at the position of draft token d_i (so p_0 sits immediately after the confirmed prefix, and p_γ covers the position after the last draft token).
  • Acceptance/Rejection Phase: Iterate through the draft tokens and decide whether to accept them. For each draft token d_i at position i (from 1 to γ):

    * Let q_i be the probability distribution from the draft model M_d that was used to sample d_i.

    * Let p_{i-1} be the probability distribution from the target model M_t for the same position.

    * Accept d_i if a randomly drawn number r from U(0, 1) satisfies r ≤ p_{i-1}(d_i) / q_i(d_i). This is a form of rejection sampling: draft tokens that the target model rates at least as likely as the draft model did are always accepted.

    * If d_i is accepted, continue to the next draft token d_{i+1}.

    * If d_i is rejected, it and all subsequent draft tokens (d_{i+1}, ..., d_γ) are discarded.

  • Correction and Finalization:

    * If all γ draft tokens are accepted, we sample one final "bonus" token from the last target distribution p_γ and append it to the accepted sequence.

    * If a token d_i was rejected, we must sample a replacement from a corrected probability distribution so that the final output still statistically matches the target model. The corrected distribution is norm(max(0, p_{i-1} - q_i)): subtract the draft distribution from the target distribution, clamp negative values to zero, and renormalize. We sample from it, append the result, and discard the rest of the draft. A toy example follows this list.

    This process guarantees that the final sequence of tokens has exactly the same distribution as if it had been generated by the target model M_t alone. The speedup comes from the fact that, on average, we accept multiple tokens for the cost of a single M_t forward pass.
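
    To make the accept/correct step concrete, here is a toy walkthrough over a hypothetical three-token vocabulary; the two distributions are made up purely for illustration.

    python
    import torch
    
    # Hypothetical distributions at a single draft position (3-token vocabulary).
    p = torch.tensor([0.6, 0.3, 0.1])  # target model M_t
    q = torch.tensor([0.2, 0.7, 0.1])  # draft model M_d
    
    d = 1                              # suppose the draft model proposed token 1
    accept_prob = min(1.0, (p[d] / q[d]).item())  # = 0.3 / 0.7 ≈ 0.43
    
    if torch.rand(1).item() <= accept_prob:
        print("accept draft token", d)
    else:
        # Corrected distribution: norm(max(0, p - q)) = [0.4, 0.0, 0.0] / 0.4 = [1, 0, 0]
        corrected = torch.clamp(p - q, min=0.0)
        corrected /= corrected.sum()
        print("reject; resampled token", torch.multinomial(corrected, num_samples=1).item())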

    Production-Grade Implementation with `transformers`

    Let's move from theory to a concrete implementation using Python, PyTorch, and the Hugging Face transformers library. For this example, we'll use meta-llama/Llama-3-8B-Instruct as our target model and TinyLlama/TinyLlama-1.1B-Chat-v1.0 as our draft model. A key prerequisite is that both models share the same tokenizer, or at least compatible token mappings; strictly speaking, this Llama-3/TinyLlama pairing does not satisfy that requirement, so treat it as illustrative and see the note in the loading code below.

    1. Setup and Model Loading

    First, we set up our environment and load the models onto the GPU. We'll also load the shared tokenizer.

    python
    import torch
    import time
    from transformers import AutoTokenizer, AutoModelForCausalLM
    
    # --- Configuration ---
    TARGET_MODEL_ID = "meta-llama/Llama-3-8B-Instruct"
    DRAFT_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # --- Load Tokenizer ---
    # IMPORTANT: speculative decoding requires the draft and target models to share a
    # tokenizer/vocabulary, because the draft's token IDs and probabilities must be
    # directly comparable to the target's. Llama-3 and TinyLlama do NOT share one, so
    # this pairing is illustrative only; in production, use a draft model from the same
    # family (or a distilled version of the target) so the vocabularies match exactly.
    tokenizer = AutoTokenizer.from_pretrained(TARGET_MODEL_ID)
    
    # --- Load Models ---
    print("Loading target model...")
    target_model = AutoModelForCausalLM.from_pretrained(
        TARGET_MODEL_ID,
        torch_dtype=torch.bfloat16, # Use bfloat16 for performance
        device_map=DEVICE,
        attn_implementation="flash_attention_2", # Requires flash-attn library
    )
    
    print("Loading draft model...")
    draft_model = AutoModelForCausalLM.from_pretrained(
        DRAFT_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map=DEVICE,
        attn_implementation="flash_attention_2",
    )
    
    target_model.eval()
    draft_model.eval()
    
    print(f"Models loaded on {DEVICE}")
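
    Because the draft and target vocabularies must line up token-for-token, it is worth checking compatibility explicitly before pairing two arbitrary checkpoints. A minimal sanity check, which the Llama-3/TinyLlama pairing above will fail, might look like this:

    python
    def check_tokenizer_compat(target_id: str, draft_id: str) -> bool:
        # The two models must map the same strings to the same token IDs; otherwise the
        # draft model's proposals and probabilities are meaningless to the target model.
        t_tok = AutoTokenizer.from_pretrained(target_id)
        d_tok = AutoTokenizer.from_pretrained(draft_id)
        compatible = t_tok.get_vocab() == d_tok.get_vocab()
        if not compatible:
            print(f"Vocab sizes: target={len(t_tok)}, draft={len(d_tok)} -> incompatible")
        return compatible
    
    # check_tokenizer_compat(TARGET_MODEL_ID, DRAFT_MODEL_ID)  # False for this example pair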

    2. The Speculative Decoding Core Logic

    Now for the main function. This function will encapsulate the drafting, verification, and acceptance loop. Pay close attention to the management of the KV caches (past_key_values), which is critical for performance.

    python
    @torch.no_grad()
    def speculative_decode(
        prompt: str,
        target_model: AutoModelForCausalLM,
        draft_model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        max_new_tokens: int = 128,
        gamma: int = 4, # Number of draft tokens per iteration
        temperature: float = 0.7,
        top_p: float = 0.9, # Not applied in this simplified version; see the sampling section below
    ):
        # --- Initialization ---
        # `context_ids` always holds the full confirmed sequence (prompt + accepted tokens).
        context_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
        n_generated = 0
    
        # Tokens not yet covered by the target model's KV cache; they are fed on the
        # next verification pass. Initially this is the whole prompt.
        pending_ids = context_ids
        target_past_key_values = None
    
        # --- Generation Loop ---
        while n_generated < max_new_tokens:
            # --- 1. Drafting Phase ---
            # For simplicity, rebuild the draft model's KV cache from the full confirmed
            # context every iteration. This is suboptimal; see the KV cache section below.
            draft_past_key_values = None
            draft_tokens = []
            draft_probs_list = []
    
            current_draft_input = context_ids
            for _ in range(gamma):
                draft_outputs = draft_model(
                    current_draft_input,
                    past_key_values=draft_past_key_values,
                    use_cache=True,
                )
                # Distribution for the next token (q_i), with temperature applied.
                next_token_logits = draft_outputs.logits[:, -1, :]
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature
                probs = torch.softmax(next_token_logits.float(), dim=-1)
    
                # For simplicity, we draft greedily. A fully faithful implementation
                # would sample from `probs`, the q_i used in the acceptance test below.
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
    
                draft_tokens.append(next_token)
                draft_probs_list.append(probs)
                current_draft_input = next_token
                draft_past_key_values = draft_outputs.past_key_values
    
            draft_sequence = torch.cat(draft_tokens, dim=1)  # shape [1, gamma]
    
            # --- 2. Verification Phase ---
            # Feed the tokens the target cache has not seen yet, plus the draft tokens.
            verify_input_ids = torch.cat([pending_ids, draft_sequence], dim=1)
            target_outputs = target_model(
                verify_input_ids,
                past_key_values=target_past_key_values,
                use_cache=True,
            )
            # The logits at position (pending_len - 1 + i) of this pass give the target
            # distribution for draft token i; slicing from pending_len - 1 therefore
            # yields gamma + 1 distributions (the last one is the "bonus" position).
            pending_len = pending_ids.shape[1]
            verify_logits = target_outputs.logits[:, pending_len - 1 :, :]
            if temperature > 0:
                verify_logits = verify_logits / temperature
            target_probs = torch.softmax(verify_logits.float(), dim=-1)
    
            # --- 3. Acceptance/Rejection Phase ---
            accepted = []
            for i in range(gamma):
                p = target_probs[:, i, :]   # target distribution for draft token i
                q = draft_probs_list[i]     # draft distribution for draft token i
                token_id = draft_sequence[0, i]
    
                # Rejection sampling: accept with probability min(1, p(d_i) / q(d_i)).
                if torch.rand(1).item() <= (p[0, token_id] / q[0, token_id]).item():
                    accepted.append(draft_sequence[:, i : i + 1])
                else:
                    # Reject: sample from the corrected distribution norm(max(0, p - q))
                    # and discard this draft token and everything after it.
                    corrected = torch.clamp(p - q, min=0.0)
                    norm_factor = corrected.sum()
                    if norm_factor > 1e-6:
                        resampled = torch.multinomial(corrected / norm_factor, num_samples=1)
                    else:
                        # Degenerate case: fall back to the target distribution.
                        resampled = torch.multinomial(p, num_samples=1)
                    accepted.append(resampled)
                    break
            else:
                # All gamma draft tokens were accepted: sample one bonus token from p_gamma.
                accepted.append(torch.multinomial(target_probs[:, gamma, :], num_samples=1))
    
            new_tokens = torch.cat(accepted, dim=1)
    
            # --- 4. Finalization and KV Cache Update ---
            context_ids = torch.cat([context_ids, new_tokens], dim=1)
            n_generated += new_tokens.shape[1]
    
            # Critical step: trim the target cache back to the prefix it had confirmed
            # *before* this round's draft tokens, so no rejected entries linger. The
            # tokens accepted this round are simply re-fed as `pending_ids` next time.
            # NOTE: the tuple slicing below assumes the legacy (key, value) cache format;
            # recent transformers versions return a Cache object instead.
            cache = target_outputs.past_key_values
            if hasattr(cache, "to_legacy_cache"):
                cache = cache.to_legacy_cache()
            keep_len = cache[0][0].shape[2] - gamma
            target_past_key_values = tuple(
                (k[:, :, :keep_len, :], v[:, :, :keep_len, :]) for k, v in cache
            )
            # If your transformers version requires a Cache object as input, wrap it:
            # from transformers import DynamicCache
            # target_past_key_values = DynamicCache.from_legacy_cache(target_past_key_values)
            pending_ids = new_tokens
    
        return tokenizer.decode(context_ids[0], skip_special_tokens=True)
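
    A minimal invocation of the function above might look like this; the exact text will vary from run to run because the acceptance test and correction sampling are stochastic.

    python
    prompt = "Explain speculative decoding to a systems engineer in two sentences."
    output = speculative_decode(
        prompt,
        target_model,
        draft_model,
        tokenizer,
        max_new_tokens=128,
        gamma=4,
        temperature=0.7,
    )
    print(output)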
    

    Performance Analysis and Benchmarking

    The entire point of this complexity is speed. To quantify the gains, we must benchmark our implementation against the standard model.generate() method.

    The key metrics are:

    * Tokens per Second (TPS): The ultimate measure of throughput.

    * Acceptance Rate: The fraction of draft tokens accepted per verification step (equivalently, the average number of accepted tokens divided by γ). A higher rate means higher efficiency.

    * Speedup: The ratio of standard generation time to speculative generation time.

    Here is a simple benchmarking script:

    python
    def benchmark():
        prompt = "The field of artificial intelligence has seen remarkable progress in recent years, particularly in the domain of"
        max_new_tokens = 256
        gamma_values = [2, 4, 6, 8]
    
        print("--- Standard Autoregressive Decoding Benchmark ---")
        start_time = time.time()
        _ = target_model.generate(
            tokenizer.encode(prompt, return_tensors="pt").to(DEVICE),
            max_new_tokens=max_new_tokens,
            do_sample=False, # Use greedy for fair comparison
            pad_token_id=tokenizer.eos_token_id
        )
        end_time = time.time()
        standard_time = end_time - start_time
        standard_tps = max_new_tokens / standard_time
        print(f"Standard Generation Time: {standard_time:.2f}s")
        print(f"Standard Tokens/Second: {standard_tps:.2f}\n")
    
        print("--- Speculative Decoding Benchmark ---")
        results = []
        for gamma in gamma_values:
            print(f"Benchmarking with gamma = {gamma}...")
            start_time = time.time()
            speculative_decode(
                prompt,
                target_model,
                draft_model,
                tokenizer,
                max_new_tokens=max_new_tokens,
                gamma=gamma,
            temperature=0, # No temperature scaling; drafting is greedy, though acceptance/correction remain stochastic
            )
            end_time = time.time()
            speculative_time = end_time - start_time
            speculative_tps = max_new_tokens / speculative_time
            speedup = standard_time / speculative_time
            results.append({
                "gamma": gamma,
                "time": speculative_time,
                "tps": speculative_tps,
                "speedup": speedup
            })
    
        # Print results in a markdown table
        print("| Gamma (γ) | Time (s) | Tokens/Second | Speedup vs. Standard |")
        print("|-----------|----------|---------------|----------------------|")
        print(f"| Standard  | {standard_time:.2f}     | {standard_tps:.2f}          | 1.00x                |")
        for res in results:
            print(f"| {res['gamma']:^9} | {res['time']:.2f}     | {res['tps']:.2f}          | {res['speedup']:.2f}x                |")
    
    # Run the benchmark
    # benchmark()

    Expected Benchmark Results

    Running this on a high-end GPU (e.g., an A100 or H100) would yield results similar to this hypothetical table:

    | Gamma (γ) | Time (s) | Tokens/Second | Speedup vs. Standard |
    |-----------|----------|---------------|----------------------|
    | Standard  | 8.53     | 30.01         | 1.00x                |
    | 2         | 4.71     | 54.35         | 1.81x                |
    | 4         | 3.10     | 82.58         | 2.75x                |
    | 6         | 3.25     | 78.77         | 2.62x                |
    | 8         | 3.68     | 69.56         | 2.32x                |

    Analysis:

    * The speedup is significant, peaking at 2.75x with gamma=4.

    * There is a clear "sweet spot" for gamma. As gamma increases, the draft model is more likely to make a mistake, leading to a lower acceptance rate. The overhead of generating a long, incorrect draft sequence and running a large verification pass outweighs the benefit.

    * The optimal gamma depends on the quality of the draft model and the complexity of the text being generated.
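
    The shape of this curve can be reasoned about analytically. Under the simplifying assumption from the original speculative decoding paper that each draft token is accepted independently with probability α, the expected number of confirmed tokens per target forward pass is (1 - α^(γ+1)) / (1 - α); a quick sketch:

    python
    def expected_tokens_per_target_pass(alpha: float, gamma: int) -> float:
        # Expected confirmed tokens per verification pass, assuming an i.i.d.
        # per-token acceptance probability `alpha` < 1 (a deliberate simplification).
        return (1 - alpha ** (gamma + 1)) / (1 - alpha)
    
    for g in (2, 4, 6, 8):
        print(g, round(expected_tokens_per_target_pass(alpha=0.8, gamma=g), 2))
    # The expected yield per pass grows ever more slowly with gamma, while each
    # drafting phase gets strictly longer -- hence the throughput sweet spot.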

    Advanced Considerations and Production Pitfalls

    While the implementation above demonstrates the core concept, several nuances are critical for a robust, production-ready system.

    1. Draft Model Selection and Alignment

    The choice of the draft model is the most important factor for success.

    * Speed: It must be substantially faster than the target model. A good rule of thumb is for the draft model's forward pass to be at least 5-10x faster.

    * Quality: It doesn't need to be perfect, but its probability distribution should be a reasonable approximation of the target model's. The higher the correlation, the higher the acceptance rate. The best results are often achieved by using a distilled version of the target model or a smaller model from the same family that has been fine-tuned on the target model's outputs.
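
    A quick way to sanity-check the speed criterion is to time one forward pass of each model on the same input. The sketch below reuses torch, time, the tokenizer, DEVICE, and the two models from the setup code, and assumes a CUDA device:

    python
    def time_forward_pass(model, input_ids, warmup: int = 3, iters: int = 10) -> float:
        # Average per-pass latency in seconds; explicit synchronization keeps timings honest.
        with torch.no_grad():
            for _ in range(warmup):
                model(input_ids)
            torch.cuda.synchronize()
            start = time.time()
            for _ in range(iters):
                model(input_ids)
            torch.cuda.synchronize()
        return (time.time() - start) / iters
    
    probe = tokenizer.encode("A short probe sequence for timing.", return_tensors="pt").to(DEVICE)
    t_target = time_forward_pass(target_model, probe)
    t_draft = time_forward_pass(draft_model, probe)
    print(f"target: {t_target * 1000:.1f} ms | draft: {t_draft * 1000:.1f} ms | ratio: {t_target / t_draft:.1f}x")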

    2. KV Cache Management: The Devil in the Details

    Our simplified implementation reset the draft model's KV cache, which is inefficient. A production system requires meticulous synchronization.

    When n_accepted tokens are accepted, the target model's past_key_values from its verification pass already contain the state for these n_accepted tokens. You must carefully slice this cache to the correct length.

    The draft model's KV cache must then be brought into the same state. Its entries for the accepted draft tokens, computed during the drafting pass, are still valid, so the cheapest robust approach is to trim away the rejected entries and then run the draft model over the few confirmed tokens it has not yet seen (typically just the correction or bonus token). This adds a small overhead, but it ensures both caches describe exactly the same prefix before the next drafting phase.

    python
    # A more robust KV cache update sketch (a fragment, assuming legacy tuple caches).
    # Assumes: `n_accepted` draft tokens were accepted this round, `confirmed_ids` is the
    # full confirmed sequence including this round's tokens, and `confirmed_len` is its
    # length *before* this round.
    
    # Target model's cache: keep the confirmed prefix plus the accepted draft tokens,
    # whose entries from the verification pass are already valid.
    target_keep = target_outputs.past_key_values[0][0].shape[2] - (gamma - n_accepted)
    target_past_key_values = tuple(
        (k[:, :, :target_keep, :], v[:, :, :target_keep, :])
        for k, v in target_outputs.past_key_values
    )
    
    # Draft model's cache: its entries for the accepted draft tokens (from the drafting
    # pass) are also valid, so trim off the rejected ones. The drafting pass never fed
    # d_gamma, so this cache holds at most gamma - 1 draft entries.
    draft_keep = min(confirmed_len + n_accepted, draft_past_key_values[0][0].shape[2])
    draft_past_key_values = tuple(
        (k[:, :, :draft_keep, :], v[:, :, :draft_keep, :])
        for k, v in draft_past_key_values
    )
    
    # Resynchronize: advance the draft cache over the confirmed tokens it has not yet
    # seen, so both caches describe the same prefix before the next drafting phase.
    resync_outputs = draft_model(
        confirmed_ids[:, draft_keep:], past_key_values=draft_past_key_values, use_cache=True
    )
    draft_past_key_values = resync_outputs.past_key_values

    3. The Mathematics of Lossless Acceleration

    The reason speculative decoding is "lossless" (i.e., it produces exactly the same distribution as the target model) lies in the acceptance rule and the correction step. A draft token d_i is accepted with probability min(1, p(d_i) / q(d_i)), where p is the target's distribution and q the draft's distribution at that position. When d_i is rejected, we are left with those two distributions and must decide what to emit instead; rejection sampling provides the answer.

    By sampling from a new distribution p' = (p - q)+ / Z, where + denotes the positive part (clamping negative values to zero) and Z is a normalization constant, the two cases combine exactly. For any token x, the probability of emitting x is q(x) · min(1, p(x)/q(x)) + P(reject) · p'(x); the first term equals min(p(x), q(x)), and since P(reject) = 1 − Σ_y min(p(y), q(y)) = Z, the second term equals (p(x) − q(x))+. Their sum is exactly p(x), and this identity is what preserves the statistical integrity of the output stream.
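
    The identity is easy to verify numerically on a toy vocabulary; the two distributions below are invented for illustration:

    python
    import torch
    
    p = torch.tensor([0.50, 0.20, 0.15, 0.10, 0.05])  # target distribution (hypothetical)
    q = torch.tensor([0.30, 0.40, 0.10, 0.10, 0.10])  # draft distribution (hypothetical)
    
    accept = torch.minimum(p, q)              # P(draft proposes x and it is accepted)
    reject_prob = 1.0 - accept.sum()          # overall probability of a rejection
    corrected = torch.clamp(p - q, min=0.0)
    corrected = corrected / corrected.sum()   # p' = (p - q)+ / Z
    
    emitted = accept + reject_prob * corrected
    print(torch.allclose(emitted, p))         # True: the emitted distribution is exactly p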

    4. Interaction with Sampling Techniques

    Our example used greedy drafting for simplicity. When using stochastic methods like top_k or top_p sampling:

    * The drafting phase can use any sampling strategy; greedy drafting is often fastest. Strictly speaking, losslessness holds only when the q used in the acceptance test is the distribution the draft token was actually sampled from, so greedy drafting makes the scheme a (usually very close) approximation.

    * The acceptance/rejection test should compare the distributions that would actually be sampled from: apply the same temperature and top_k / top_p warping to both the target and the draft logits before computing p and q.

    * The correction/finalization sampling (after a rejection or a full acceptance) must likewise draw from the warped target (or corrected) distribution, so that the final output adheres to the user's sampling constraints.
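
    As a sketch of what "warping" means in practice, here is a common formulation of a top-p (nucleus) filter that could be applied to the target logits before the softmax and sampling steps above; it is not taken from any particular library:

    python
    import torch
    
    def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
        # Keep the smallest set of highest-probability tokens whose cumulative
        # probability reaches top_p; mask everything else to -inf before softmax.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        to_remove = cumulative > top_p
        to_remove[..., 1:] = to_remove[..., :-1].clone()  # shift right: always keep the top token
        to_remove[..., 0] = False
        mask = to_remove.scatter(dim=-1, index=sorted_indices, src=to_remove)
        return logits.masked_fill(mask, float("-inf"))
    
    # Example: warp the target logits, then compute the distribution used for both
    # the acceptance test and the correction sampling.
    # warped_probs = torch.softmax(top_p_filter(target_verify_logits / temperature), dim=-1)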

    Conclusion: A Powerful Tool for Production Inference

    Speculative decoding is not a simple drop-in replacement for standard generation. It is an advanced systems-level optimization that requires a deep understanding of model architectures, KV cache mechanics, and probability theory. However, for applications where inference latency is a critical bottleneck, the potential 2-3x speedup is a game-changer.

    By carefully selecting a fast and well-aligned draft model, managing the KV caches with precision, and correctly implementing the rejection sampling loop, engineering teams can significantly reduce the cost and improve the user experience of their deployed LLMs. As models continue to grow, techniques like speculative decoding that tackle the memory bandwidth wall will become not just advantageous, but essential for building responsive and scalable AI products.
