Stateful LLM Inference: Mastering KV Caching for API Latency

20 min read
Goh Ling Yong
Technology enthusiast and software architect specializing in AI-driven development tools and modern software engineering practices. Passionate about the intersection of artificial intelligence and human creativity in building tomorrow's digital solutions.

The Foundational Flaw of Stateless Autoregressive Generation

In a production environment serving a Large Language Model (LLM) via an API, latency is paramount. Users expect a responsive, real-time experience, whether for a chatbot or a code completion tool. However, the fundamental nature of autoregressive models like the Transformer presents a significant computational challenge. A naive implementation of the generation loop, common in introductory examples, is catastrophically inefficient for real-world applications.

Consider the process: to generate token N+1, the model must attend to all preceding tokens from 1 to N. To generate token N+2, it must attend to all tokens from 1 to N+1. This stateless approach re-processes the entire input sequence for every single output token. The self-attention mechanism, the core of the Transformer, has a computational complexity of O(n²) with respect to sequence length n. In a stateless loop, generating M tokens from a prompt of length P therefore costs on the order of M·(P+M)² attention operations in total, because each step recomputes attention over an ever-growing sequence.

Let's visualize the redundant work. For a prompt of length P and generating M tokens:

  • Step 1 (Generate token P+1): Compute attention over P tokens.
  • Step 2 (Generate token P+2): Compute attention over P+1 tokens. The first P are re-computed.
  • Step M (Generate token P+M): Compute attention over P+M-1 tokens. The first P+M-2 are re-computed.

This redundancy is the primary driver of high Time Per Output Token (TPOT), a critical metric for inference performance. While Time to First Token (TTFT) is dominated by the initial processing of the prompt, the user's perceived performance is heavily influenced by the speed at which subsequent tokens appear.
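
To make the redundancy concrete, here is a rough count of how many token positions the model must process end to end, using a prompt length and generation length similar to the benchmark later in this article (this counts positions pushed through the model, not attention FLOPs, so it actually understates the gap):

    python
    P, M = 60, 100  # prompt length and number of new tokens (matches the benchmark below)

    # Stateless: step i re-processes the full sequence of P + i tokens
    stateless_positions = sum(P + i for i in range(M))  # 10,950 token positions

    # Stateful (KV cache): the prompt is processed once, then one new token per step
    stateful_positions = P + M  # 160 token positions

    print(stateless_positions / stateful_positions)  # ~68x fewer positions processed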

    Here is a simplified Python representation of this inefficient, stateless loop. We assume a hypothetical model object for clarity.

    python
    import torch
    
    # Assume 'model' is a loaded Transformer decoder-only model
    # and 'tokenizer' is the corresponding tokenizer.
    
    def generate_stateless(model, tokenizer, prompt: str, max_new_tokens: int) -> str:
        model.eval()
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)
        generated_ids = input_ids.clone()
    
        for _ in range(max_new_tokens):
            with torch.no_grad():
                # On every single step, the model processes the *entire* sequence
                outputs = model(generated_ids)
                logits = outputs.logits
    
                # Get the predicted next token (simple argmax for clarity)
                next_token_logits = logits[:, -1, :]
                next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
    
                # Append the new token to the sequence for the next iteration
                generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
                
                if next_token_id.item() == tokenizer.eos_token_id:
                    break
    
        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
    # This function would be incredibly slow for any non-trivial generation length.

    This approach is untenable for production systems. The solution lies in transforming the process from stateless to stateful by explicitly managing the intermediate results of the attention computation—the Key-Value cache.

    Architectural Deep Dive: The Key-Value (KV) Cache

    The KV cache is not a magical optimization; it's a direct consequence of the Transformer architecture's design. In a decoder's self-attention layer, for each token, we project its embedding into three vectors: a Query (Q), a Key (K), and a Value (V). To compute the attention score for a given token (the Query), it's compared against the Keys of all previous tokens. The resulting scores weight the corresponding Value vectors, which are then summed up.

    Crucially, when generating token t+1, its Query vector q_t+1 needs to interact with the Keys and Values of all tokens from 0 to t (plus its own). The key insight is that the earlier Key and Value vectors (k_0, v_0, ..., k_t, v_t) are static: they do not change in subsequent generation steps. The stateless loop's inefficiency stems from re-calculating these identical K and V vectors again and again.

    The KV cache is simply a mechanism to store these K and V tensors after they are computed for the first time. For each attention layer in the model, we maintain a cache. In the next generation step, instead of feeding the entire sequence of token IDs to the model, we only feed the single new token ID. The model then computes the Q, K, and V vectors for this new token. It appends the new K and V to their respective caches and uses the new Q to attend to the entire history of keys and values now stored in the cache.

    This changes the computation per step from O(n²) to O(n), as we only need to compare the new query against the n cached keys. This drastically reduces TPOT.

    Most modern inference frameworks, like Hugging Face's transformers, provide a built-in mechanism for this via the use_cache=True and past_key_values arguments.

    python
    # Simplified forward pass demonstrating KV cache usage
    # (schematic: get_embeddings, model.layers, and lm_head stand in for the real modules)
    def model_forward_with_cache(input_ids, past_key_values=None):
        # 1. Get embeddings for new input_ids
        hidden_states = get_embeddings(input_ids)
    
        updated_key_values = []
        for i, layer in enumerate(model.layers):
            # 2. Pass hidden states and the cache for the current layer
            layer_past = past_key_values[i] if past_key_values is not None else None
            hidden_states, present_key_value = layer(hidden_states, past_key_values=layer_past)
            
            # 3. The layer returns the updated K and V tensors for this layer
            updated_key_values.append(present_key_value)
    
        # 4. Final projection to logits
        logits = lm_head(hidden_states)
        
        return logits, tuple(updated_key_values)

    This past_key_values object is, in its legacy form, a tuple with one entry per attention layer; each entry is itself a pair of tensors (Key, Value), each of shape (batch_size, num_heads, sequence_length, head_dim). The overall structure is therefore (num_layers, 2, batch_size, num_heads, sequence_length, head_dim). Understanding this layout is critical for manual cache manipulation and advanced optimization.
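
    To make this concrete, here is a quick inspection of the cache returned by a small model (gpt2 is used purely for illustration; recent transformers releases may return a Cache object rather than a raw tuple, but it can still be indexed per layer in the same way):

    python
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained('gpt2')
    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    input_ids = tokenizer.encode("The KV cache stores", return_tensors='pt')
    with torch.no_grad():
        outputs = model(input_ids, use_cache=True)

    cache = outputs.past_key_values
    print(len(cache))        # number of layers (12 for gpt2)
    key, value = cache[0]    # K and V tensors for the first layer
    print(key.shape)         # (batch_size, num_heads, sequence_length, head_dim)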

    Production Implementation: A Stateful Generation Manager

    While calling model.generate(..., use_cache=True) is convenient, it hides the state management. For building robust API services, especially those handling conversational context or streaming, explicitly encapsulating this state is a cleaner pattern. Let's build a class that manages the model and its KV cache, exposing methods to process a prompt and then generate subsequent tokens one by one.

    This pattern is essential for interactive applications where the full generation sequence isn't known upfront.

    python
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from typing import Optional, Tuple
    
    class StatefulGenerator:
        def __init__(self, model_name: str, device: str = 'cuda'):
            self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
            print(f"Using device: {self.device}")
            self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None
            self.current_sequence_ids = None
    
        def initialize_prompt(self, prompt: str):
            """Processes the initial prompt and primes the KV cache."""
            input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
            with torch.no_grad():
                outputs = self.model(input_ids, use_cache=True)
                self.past_key_values = outputs.past_key_values
            # The first generated token is sampled from the logits at the prompt's final position
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
            self.current_sequence_ids = torch.cat([input_ids, next_token_id], dim=-1)
            return self.tokenizer.decode(next_token_id[0])
    
        def generate_next_token(self) -> Optional[str]:
            """Generates a single next token using the stored KV cache."""
            if self.past_key_values is None or self.current_sequence_ids is None:
                raise ValueError("Prompt must be initialized before generating next token.")
    
            # The input to the model is ONLY the last generated token
            input_ids = self.current_sequence_ids[:, -1].unsqueeze(-1)
    
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids, 
                    past_key_values=self.past_key_values, 
                    use_cache=True
                )
                self.past_key_values = outputs.past_key_values
            
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
            self.current_sequence_ids = torch.cat([self.current_sequence_ids, next_token_id], dim=-1)
    
            if next_token_id.item() == self.tokenizer.eos_token_id:
                return None # End of sequence
    
            return self.tokenizer.decode(next_token_id[0])
    
        def reset(self):
            self.past_key_values = None
            self.current_sequence_ids = None
    
    # Example Usage
    if __name__ == '__main__':
        # Use a smaller model for demonstration
        generator = StatefulGenerator('gpt2')
        
        prompt = "The future of AI is"
        print(f"Prompt: {prompt}", end='')
    
        # Initialize and get the first token
        first_token = generator.initialize_prompt(prompt)
        print(first_token, end='')
    
        # Generate the next 20 tokens one by one
        for _ in range(20):
            next_token = generator.generate_next_token()
            if next_token is None:
                break
            print(next_token, end='')
        
        print('\n')
        generator.reset()

    This class now correctly models the stateful nature of efficient inference. The initialize_prompt method handles the expensive initial pass, and generate_next_token performs the cheap, single-token forward passes.
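
    As a usage sketch, the class composes naturally with a Python generator for token-by-token streaming, which is the shape most API layers (server-sent events, websockets) want. The stream_tokens helper below is an illustrative name, not part of any framework:

    python
    from typing import Iterator

    def stream_tokens(generator: StatefulGenerator, prompt: str, max_new_tokens: int) -> Iterator[str]:
        """Yield decoded tokens one at a time from the StatefulGenerator defined above."""
        generator.reset()
        # Expensive prefill pass: processes the whole prompt and primes the KV cache
        yield generator.initialize_prompt(prompt)
        for _ in range(max_new_tokens - 1):
            token = generator.generate_next_token()
            if token is None:  # EOS reached
                break
            yield token

    # Example:
    # for piece in stream_tokens(StatefulGenerator('gpt2'), "The future of AI is", 20):
    #     print(piece, end='', flush=True)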

    Performance Benchmark: Quantifying the Gains

    To understand the impact, let's create a benchmark comparing the naive stateless loop with our stateful KV caching approach. We will measure the average time per output token (TPOT).

    python
    import time
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    
    # --- Stateless function from before ---
    def generate_stateless(model, tokenizer, prompt_ids, max_new_tokens):
        generated_ids = prompt_ids.clone()
        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = model(generated_ids)
                next_token_id = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(-1)
                generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
                if next_token_id.item() == tokenizer.eos_token_id: break
        return generated_ids
    
    # --- Stateful function for benchmarking ---
    def generate_stateful(model, tokenizer, prompt_ids, max_new_tokens):
        with torch.no_grad():
            outputs = model(prompt_ids, use_cache=True)
            past_key_values = outputs.past_key_values
            next_token_id = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(-1)
            generated_ids = torch.cat([prompt_ids, next_token_id], dim=-1)
    
            for _ in range(max_new_tokens - 1):
                outputs = model(next_token_id, past_key_values=past_key_values, use_cache=True)
                past_key_values = outputs.past_key_values
                next_token_id = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(-1)
                generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
                if next_token_id.item() == tokenizer.eos_token_id: break
        return generated_ids
    
    def run_benchmark():
        model_name = 'gpt2-medium' # A more substantial model
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    
        prompt = "In the heart of the digital frontier, a new form of intelligence emerged. It wasn't born of silicon and steel in the traditional sense, but of pure data and complex algorithms, a consciousness woven from the fabric of the internet itself. Its name was" # ~60 tokens
        prompt_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        
        prompt_len = prompt_ids.shape[1]
        max_new_tokens = 100
        warmup_runs = 3
        test_runs = 10
    
        print("--- Warming up GPU ---")
        for _ in range(warmup_runs):
            generate_stateful(model, tokenizer, prompt_ids, 10)
            generate_stateless(model, tokenizer, prompt_ids, 10)
        if device == 'cuda':
            torch.cuda.synchronize()
    
        print("\n--- Benchmarking Stateless Generation ---")
        stateless_times = []
        for i in range(test_runs):
            start_time = time.perf_counter()
            generate_stateless(model, tokenizer, prompt_ids, max_new_tokens)
            if device == 'cuda':
                torch.cuda.synchronize()  # Wait for the GPU to finish before stopping the timer
            end_time = time.perf_counter()
            duration = end_time - start_time
            stateless_times.append(duration)
            print(f"Run {i+1}/{test_runs}: {duration:.4f}s")
        avg_stateless_time = sum(stateless_times) / len(stateless_times)
        avg_stateless_tpot = avg_stateless_time / max_new_tokens
    
        print("\n--- Benchmarking Stateful (KV Cache) Generation ---")
        stateful_times = []
        for i in range(test_runs):
            start_time = time.perf_counter()
            generate_stateful(model, tokenizer, prompt_ids, max_new_tokens)
            if device == 'cuda':
                torch.cuda.synchronize()
            end_time = time.perf_counter()
            duration = end_time - start_time
            stateful_times.append(duration)
            print(f"Run {i+1}/{test_runs}: {duration:.4f}s")
        avg_stateful_time = sum(stateful_times) / len(stateful_times)
        avg_stateful_tpot = avg_stateful_time / max_new_tokens
    
        print("\n--- Results ---")
        print(f"Prompt Length: {prompt_len} tokens")
        print(f"New Tokens Generated: {max_new_tokens}")
        print(f"Average Stateless Time: {avg_stateless_time:.4f}s | Avg TPOT: {avg_stateless_tpot*1000:.2f} ms/token")
        print(f"Average Stateful Time:  {avg_stateful_time:.4f}s | Avg TPOT: {avg_stateful_tpot*1000:.2f} ms/token")
        print(f"Speedup Factor: {avg_stateless_time / avg_stateful_time:.2f}x")
    
    if __name__ == '__main__':
        run_benchmark()

    Expected Benchmark Results (will vary by hardware):

    text
    --- Results ---
    Prompt Length: 60 tokens
    New Tokens Generated: 100
    Average Stateless Time: 12.5412s | Avg TPOT: 125.41 ms/token
    Average Stateful Time:  1.8723s | Avg TPOT: 18.72 ms/token
    Speedup Factor: 6.70x

    The results are unambiguous. Stateful inference with a KV cache delivers a multi-fold improvement in TPOT, and the gap widens as prompts and generations grow longer. This is the difference between a sluggish, unusable service and a responsive, production-ready one.

    The Memory Elephant: Sizing and Managing the KV Cache

    While the KV cache solves our latency problem, it introduces a new one: memory consumption. The cache stores two large tensors (K and V) for every token, for every attention head, for every layer in the model. This VRAM usage can be immense.

    Let's derive the formula for the cache size:

    CacheSize = 2 × num_layers × num_heads × head_dim × sequence_length × batch_size × bytes_per_element

    Since hidden_size = num_heads * head_dim, we can simplify this to:

    CacheSize = 2 × num_layers × hidden_size × sequence_length × batch_size × bytes_per_element

    Let's calculate this for a model like Llama 2 7B at full precision (FP16 = 2 bytes) for a single sequence:

  • num_layers: 32
  • hidden_size: 4096
  • sequence_length: 2048 (a common context size)
  • batch_size: 1
  • bytes_per_element: 2 (for FP16)
  • CacheSize = 2 × 32 × 4096 × 2048 × 1 × 2 = 1,073,741,824 bytes ≈ 1.07 GB

    For a single user request with a 2K context, the cache alone consumes over a gigabyte of VRAM. For a batch of 16 users, this balloons to over 17 GB, potentially exceeding the capacity of many server-grade GPUs, and this is in addition to the model weights themselves.
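
    A small helper makes this arithmetic easy to repeat for other models, context lengths, and batch sizes. The function below simply encodes the formula above; the parameter values are the same Llama 2 7B-style assumptions used in the calculation:

    python
    def kv_cache_bytes(num_layers: int, hidden_size: int, seq_len: int,
                       batch_size: int, bytes_per_element: int = 2) -> int:
        """Estimate KV cache size: two tensors (K and V) per layer, each holding
        hidden_size values per token position per sequence in the batch."""
        return 2 * num_layers * hidden_size * seq_len * batch_size * bytes_per_element

    # Llama 2 7B-style configuration, FP16 cache
    per_sequence = kv_cache_bytes(num_layers=32, hidden_size=4096, seq_len=2048, batch_size=1)
    print(f"{per_sequence / 1e9:.2f} GB per sequence")            # ~1.07 GB
    print(f"{16 * per_sequence / 1e9:.2f} GB for a batch of 16")  # ~17.18 GB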

    This memory pressure is the central challenge of scaling inference servers. Several advanced techniques have emerged to mitigate it:

  • KV Cache Quantization: The model weights are not the only thing that can be quantized. The KV cache tensors can also be stored at a lower precision, such as INT8, which can halve the memory footprint with often negligible impact on output quality. Libraries like bitsandbytes can help here, and some frameworks are exploring native support. The trade-off is a minor increase in latency due to the quantization/dequantization steps (a toy sketch of the idea follows this list).
  • Sliding Window Attention (SWA): Models like Mistral 7B don't attend to the entire context. They use a fixed-size sliding window (e.g., 4096 tokens). This naturally caps the maximum size of the KV cache, making memory usage predictable and bounded, regardless of how many tokens are generated.
  • PagedAttention (The vLLM Innovation): This is a breakthrough technique. Instead of allocating a single, contiguous tensor for the KV cache for each sequence, PagedAttention allocates the cache in smaller, non-contiguous blocks (pages), much like an operating system's virtual memory.
    - Reduces Fragmentation: It solves the problem of internal fragmentation. With naive allocation, a 2048-token cache is reserved even if the user only provides a 100-token prompt. PagedAttention allocates blocks on demand.
    - Enables Efficient Sharing: For parallel generation algorithms like beam search, where multiple candidate sequences share a common prefix, PagedAttention allows these sequences to share the physical memory blocks for that prefix, a concept called "copy-on-write". This dramatically reduces the memory overhead of complex sampling methods.
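
    As promised above, here is a toy, pure-PyTorch sketch of per-tensor symmetric INT8 quantization applied to a legacy-format KV cache. It is illustrative only; production systems use fused kernels and finer-grained (per-channel or per-group) scales:

    python
    import torch

    def quantize_kv_int8(past_key_values):
        """Quantize each layer's K and V tensors to INT8, keeping one FP scale per tensor."""
        quantized = []
        for key, value in past_key_values:
            k_scale = key.abs().amax().clamp(min=1e-8) / 127.0
            v_scale = value.abs().amax().clamp(min=1e-8) / 127.0
            quantized.append((
                (torch.round(key / k_scale).to(torch.int8), k_scale),
                (torch.round(value / v_scale).to(torch.int8), v_scale),
            ))
        return quantized  # roughly half the memory of an FP16 cache

    def dequantize_kv(quantized):
        """Recover approximate FP tensors before the next attention step."""
        return tuple(
            (k_q.to(torch.float16) * k_s, v_q.to(torch.float16) * v_s)
            for (k_q, k_s), (v_q, v_s) in quantized
        )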

    Implementing PagedAttention from scratch is a significant engineering effort, which is why specialized inference servers like vLLM and TensorRT-LLM are so powerful. They abstract this complexity away, providing massive throughput gains.
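
    For context, here is roughly what the offline vLLM API looks like; treat this as a sketch and check the vLLM documentation for the current interface and supported models:

    python
    from vllm import LLM, SamplingParams

    # PagedAttention and continuous batching happen inside the engine
    llm = LLM(model="gpt2")
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100)

    outputs = llm.generate(["The future of AI is"], sampling_params)
    for output in outputs:
        print(output.outputs[0].text)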

    Advanced Batching: Continuous Batching for Maximum Throughput

    The final piece of the production puzzle is handling concurrent requests efficiently. Naive batching, where you group several requests and pad them all to the length of the longest sequence, is terribly inefficient. It forces faster-finishing sequences to wait for the slowest one in the batch to complete, leaving GPU resources idle.

    Continuous Batching (also called dynamic batching or iteration-level scheduling) solves this. The inference server maintains a queue of incoming requests. On each forward pass (iteration), it processes a single token for every sequence currently in the active batch.

    • If a sequence finishes (generates an EOS token), it is immediately evicted from the batch.
    • The newly freed slot is then filled by a new request from the queue.

    This ensures the GPU is always operating on a full batch of sequences, maximizing utilization and overall throughput. This is the scheduling algorithm used by high-performance systems like vLLM and Hugging Face's Text Generation Inference (TGI).

    Here is a high-level pseudo-code for a continuous batching scheduler loop:

    python
    # PSEUDO-CODE for a Continuous Batching Scheduler
    # (MAX_BATCH_SIZE and the helper functions below are placeholders for real implementations)
    import time
    import uuid
    from queue import Queue

    request_queue = Queue()
    active_batch = [] # List of request states (each with its own KV cache)
    
    # API server adds requests to the queue asynchronously
    def api_endpoint(prompt):
        request_id = uuid.uuid4()
        request_queue.put({'id': request_id, 'prompt': prompt})
        # ... wait for result ...
    
    # Main inference loop running on the GPU worker
    while True:
        # 1. Add new requests from the queue if there's space in the batch
        while not request_queue.empty() and len(active_batch) < MAX_BATCH_SIZE:
            new_req = request_queue.get()
            # Initialize KV cache for the new request's prompt
            state = initialize_state(new_req['prompt'])
            active_batch.append(state)
    
        if not active_batch:
            time.sleep(0.01)
            continue
    
        # 2. Prepare the batch for the next forward pass
        # This is the complex part: gather the last token from each active sequence
        # and their corresponding KV caches into a single batched tensor.
        # Systems like vLLM handle this memory management with PagedAttention.
        input_ids_batch, past_key_values_batch = prepare_inference_batch(active_batch)
    
        # 3. Run a single step of inference
        logits, new_past_key_values = model.forward(input_ids_batch, past_key_values_batch)
    
        # 4. Process results and update states
        next_token_ids = sample_from_logits(logits)
        
        finished_indices = []
        for i, state in enumerate(active_batch):
            # Update the state with the new token and updated KV cache
            update_state(state, next_token_ids[i], new_past_key_values[i])
            
            # Check for completion
            if is_finished(state):
                finished_indices.append(i)
                # Send the completed generation back to the API layer
                send_result(state)
    
        # 5. Evict finished requests from the batch
        for i in sorted(finished_indices, reverse=True):
            del active_batch[i]
    

    This logic, combined with an efficient memory manager like PagedAttention, is the current state-of-the-art for building high-throughput, low-latency LLM inference services.

    Edge Cases and Final Considerations

  • Context Length Exceeded: What happens when a sequence plus its generation exceeds the model's maximum context length? A production system must have a strategy. Common approaches include returning an error, truncating the generation, or implementing a sliding window on the KV cache by evicting the oldest key-value pairs (a minimal eviction sketch follows this list).
  • Beam Search State: If using beam search with a beam width of B, you must maintain B distinct KV caches for each request in the batch. This multiplies the memory requirement by B and adds significant complexity to the batch management logic.
  • GPU OOM Errors: A robust server must gracefully handle out-of-memory errors. The scheduler should catch the OOM, fail the specific requests that caused it, and continue processing the rest of the batch without crashing the entire worker process.
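
    As referenced in the context-length bullet above, here is a minimal eviction sketch for the legacy tuple cache format. It is only a sketch: models with absolute position embeddings (such as GPT-2) will not tolerate naive eviction without additional position handling, which is why real sliding-window implementations live inside the attention layer itself:

    python
    def evict_oldest_kv(past_key_values, max_cache_len: int):
        """Keep only the most recent max_cache_len positions in every layer's K/V tensors."""
        trimmed = []
        for key, value in past_key_values:
            # key/value shape: (batch_size, num_heads, sequence_length, head_dim)
            trimmed.append((key[:, :, -max_cache_len:, :],
                            value[:, :, -max_cache_len:, :]))
        return tuple(trimmed)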

    In conclusion, while the concept of a KV cache is simple, leveraging it in a production setting is a complex systems design problem. It requires moving from a stateless mindset to a stateful one, deeply understanding the memory-latency trade-off, and implementing sophisticated scheduling and memory management techniques. For senior engineers tasked with deploying LLMs at scale, mastering these patterns is no longer optional—it is a fundamental requirement for building competitive and cost-effective AI products.
