Continuous Batching for High-Throughput LLM Inference APIs
The Bottleneck of Static Batching in Production LLM Services
For any team deploying Large Language Models (LLMs) behind an API, the initial performance optimization strategy is almost always static batching. The concept is simple: instead of feeding requests to the GPU one by one, we collect a number of requests (a 'batch'), concatenate them, and perform a single forward pass. Given the parallel nature of GPUs, this dramatically increases throughput compared to a batch size of one. While this is a step up from naive sequential processing, it introduces two critical production bottlenecks, particularly for interactive applications like chatbots or co-pilots: head-of-line blocking and GPU underutilization.
Consider a static batching system with a batch size of 8 and a timeout of 100ms. The process looks like this:
1. The server receives Request A.
2. The server waits: it needs 7 more requests, or for 100ms to pass.
3. Requests B, C, and D arrive.
4. The 100ms timeout is reached. The batch {A, B, C, D} is sent to the model.
5. The model processes the batch; the GPU is at 100% utilization.
6. All requests in the batch must complete before the next batch can start. If Request A needs to generate 500 tokens but B, C, and D only need 50, the entire batch is blocked until A finishes. This leads to high variance in response times and poor perceived performance for users who submitted short requests.
Note that during the waiting period (steps 2-4), the GPU is completely idle, wasting expensive compute resources.
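To make the contrast concrete, here is a minimal sketch of what such a static batching loop typically looks like. It is illustrative only: the names request_source, poll, and run_model are placeholders, not part of the implementation built later in this article.
# Illustrative static batching loop (not part of the implementation below).
import time

BATCH_SIZE = 8
BATCH_TIMEOUT_S = 0.1  # 100 ms

def static_batch_loop(request_source, run_model):
    """request_source.poll() returns a new request or None; run_model
    generates *all* tokens for every request in the batch before returning."""
    while True:
        batch, deadline = [], time.monotonic() + BATCH_TIMEOUT_S
        # Step 1: block here until the batch is full or the timeout expires.
        # The GPU sits idle for this entire window.
        while len(batch) < BATCH_SIZE and time.monotonic() < deadline:
            req = request_source.poll()
            if req is not None:
                batch.append(req)
        if not batch:
            continue
        # Step 2: run the whole batch to completion. Every request waits
        # for the longest one, even if it finished generating long ago.
        run_model(batch)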
This latency-throughput trade-off is untenable for high-performance systems. To achieve both low latency for individual requests and high overall throughput, we must evolve our architecture to a more dynamic model: Continuous Batching.
Architectural Shift: The Continuous Batching Paradigm
Continuous batching, also known as in-flight batching or iteration-level scheduling (the term 'dynamic batching' is sometimes used loosely for it, though that name more often refers to request-level batching with a timeout), fundamentally changes how requests are scheduled and processed. Instead of forming discrete, static batches, the system maintains a single, continuously evolving batch that is modified on each iteration of the model's generation loop.
Here's the core principle: as soon as a single sequence within the active batch finishes its generation (i.e., it produces an EOS token or reaches its max_tokens limit), its slot in the batch is immediately freed up. The scheduler can then instantly fill this empty slot with a new request from a waiting queue, without ever stopping the generation loop for the other, still-active requests in the batch.
This approach yields several significant advantages:
* Maximized GPU Utilization: The GPU is almost never idle as long as there are requests in the queue. There is no 'waiting for a batch to fill' period.
* Low Average Latency: New requests are added to the batch on the very next forward pass, dramatically reducing a request's time-to-first-token.
* Fairness and Reduced Tail Latency: Short requests are no longer blocked by long ones. They enter the batch, generate their tokens, and exit, improving the overall user experience.
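Stripped of all implementation detail, the scheduling loop behind these properties looks roughly like the sketch below. It is conceptual only: the list-based queue and the toy forward_pass are stand-ins for the real components we build in the rest of this article.
# Conceptual shape of a continuous batching loop, using plain Python
# structures as stand-ins for the real queue and model.
from collections import deque

MAX_BATCH_SIZE = 4
waiting_queue = deque({"tokens_left": n} for n in (2, 5, 3, 1, 4, 2))
active_batch = []

def forward_pass(batch):
    # Stand-in for one model iteration: each sequence "generates" one token.
    for seq in batch:
        seq["tokens_left"] -= 1

while waiting_queue or active_batch:
    # Evict sequences that finished on the previous iteration.
    active_batch = [seq for seq in active_batch if seq["tokens_left"] > 0]
    # Immediately backfill freed slots from the waiting queue.
    while len(active_batch) < MAX_BATCH_SIZE and waiting_queue:
        active_batch.append(waiting_queue.popleft())
    # One iteration: exactly one new token per active sequence.
    if active_batch:
        forward_pass(active_batch)
        print(f"iteration done, batch size = {len(active_batch)}")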
To implement this, we need a more sophisticated architecture than a simple API server calling a model. The key components are:
* API Server: accepts HTTP requests, creates an internal request object for each one, and waits for its completion.
* Request Queue: a buffer that decouples request arrival from GPU scheduling.
* Scheduler: a background loop that fills empty slots in the active batch from the queue and evicts finished sequences.
* Model Runner: the component that executes one forward pass over the active batch on the GPU.
The sequence diagram below shows how these components interact. We'll then build a production-grade implementation of this system in Python, using asyncio for concurrency and FastAPI for the web layer.
sequenceDiagram
    participant Client
    participant API as API Server (FastAPI)
    participant Queue as Request Queue (asyncio.Queue)
    participant Scheduler
    participant Runner as Model Runner (GPU)
    Client->>API: POST /generate (Request A)
    API->>API: Create Request A object with asyncio.Event
    API->>Queue: Put Request A
    Note over API: Handler awaits Request A's completion Event
    Scheduler->>Queue: Get Request A
    Scheduler->>Runner: Add Request A to active batch
    Note over Scheduler,Runner: Iteration 1
    Runner->>Runner: Forward pass (batch = {A})
    Runner-->>Scheduler: Return token for A
    Client->>API: POST /generate (Request B)
    API->>Queue: Put Request B
    Scheduler->>Queue: Get Request B
    Scheduler->>Runner: Add Request B to active batch
    Note over Scheduler,Runner: Iteration 2
    Runner->>Runner: Forward pass (batch = {A, B})
    Runner-->>Scheduler: Return tokens for A and B
    Note over Scheduler,Runner: Request A finishes generation (EOS token)
    Scheduler->>Runner: Remove Request A from batch
    Scheduler->>API: Set result for Request A, set its Event
    API-->>Client: Return full generation for Request A
    Client->>API: POST /generate (Request C)
    API->>Queue: Put Request C
    Scheduler->>Queue: Get Request C
    Scheduler->>Runner: Add Request C to batch (fills A's slot)
    Note over Scheduler,Runner: Iteration N
    Runner->>Runner: Forward pass (batch = {B, C})
    Runner-->>Scheduler: Return tokens for B and C
Deep Dive: A Production-Grade Python Implementation
We'll use asyncio to manage the concurrent operations. This is crucial because the API server needs to handle many simultaneous connections while the scheduler runs its own logic loop, all without blocking.
1. Data Structures for Requests and Batches
First, we need robust data structures to represent individual requests and the state of the active batch.
# file: data_models.py
import time
import uuid
from typing import List, Optional
from pydantic import BaseModel, Field
from asyncio import Event
class GenerationParams(BaseModel):
max_new_tokens: int = 256
temperature: float = 0.8
top_p: float = 0.95
class APIRequest(BaseModel):
prompt: str
params: GenerationParams = Field(default_factory=GenerationParams)
class Request:
"""Internal representation of a request throughout its lifecycle."""
def __init__(self, request_id: str, prompt: str, params: GenerationParams):
self.id = request_id
self.prompt = prompt
self.params = params
self.arrival_time = time.time()
self.generated_tokens: List[str] = []
self.finished = False
self.finish_reason: Optional[str] = None
self.completion_event = Event()
self.result: Optional[str] = None
self.error: Optional[str] = None
def __repr__(self):
return f"Request(id={self.id}, prompt='{self.prompt[:20]}...')"
class Batch:
"""Represents the current state of the batch on the GPU."""
def __init__(self, requests: List[Request]):
self.id = str(uuid.uuid4())
self.requests = requests
# In a real implementation, this would hold tensor data, KV caches, etc.
self.size = len(requests)
@classmethod
def from_requests(cls, requests: List[Request]) -> "Batch":
return cls(requests=requests)
def filter_finished(self) -> "Batch":
"""Return a new batch containing only unfinished requests."""
unfinished_reqs = [req for req in self.requests if not req.finished]
return Batch(requests=unfinished_reqs)
The Request class is critical. It contains not just the input data but also state management fields like finished, generated_tokens, and most importantly, completion_event. This asyncio.Event is the signaling mechanism that will allow the API server coroutine to await until the background scheduler has finished processing this specific request.
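If the Event-based handoff is unfamiliar, the pattern in isolation is just the following self-contained toy (independent of the classes above): one coroutine parks on the event, another fills in the result and sets it.
# Minimal demonstration of the Event handoff used between the API handler
# and the scheduler: one coroutine awaits, another sets the result.
import asyncio

async def worker(req: dict):
    await asyncio.sleep(0.5)            # pretend to generate tokens
    req["result"] = "generated text"
    req["done"].set()                   # wake up whoever is waiting

async def handler():
    req = {"done": asyncio.Event(), "result": None}
    asyncio.create_task(worker(req))    # scheduler-side work in the background
    await req["done"].wait()            # API-side coroutine parks here
    print(req["result"])

asyncio.run(handler())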
2. The Model Runner (Simulated)
To focus on the scheduling logic, we'll simulate the model runner. A real implementation would wrap a library like Hugging Face's transformers, vLLM, or a custom PyTorch model. Our simulation will mimic the key characteristics: processing time is dependent on batch size and sequence length, and it processes one token per sequence per forward pass.
# file: model_runner.py
import time
import random
import asyncio
from typing import List, Dict
from data_models import Batch, Request
class MockModelRunner:
def __init__(self):
# Simulate model loading time
print("Simulating model load...")
time.sleep(2)
print("Model loaded.")
async def process_batch(self, batch: Batch) -> Dict[str, str]:
"""Simulates a single forward pass of the model."""
# Simulate GPU processing time: base latency + per-request latency
# This is a highly simplified model of reality.
base_latency = 0.01 # Latency for kernel launch, etc.
per_token_latency = 0.005
processing_time = base_latency + (len(batch.requests) * per_token_latency)
await asyncio.sleep(processing_time)
# Simulate token generation for each request in the batch
results = {}
for req in batch.requests:
# Simulate EOS token or max_tokens reached
if len(req.generated_tokens) >= req.params.max_new_tokens:
req.finished = True
req.finish_reason = "max_tokens"
continue
if random.random() < 0.05 and len(req.generated_tokens) > 0: # 5% chance of EOS
req.finished = True
req.finish_reason = "eos_token"
token = "<EOS>"
else:
token = f" {req.id[-4:]}-{len(req.generated_tokens)}"
req.generated_tokens.append(token)
results[req.id] = token # For streaming, we'd yield this
return results
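A quick way to sanity-check the mock in isolation is a small throwaway driver like the one below (not part of the service; it simply runs a few iterations and watches requests drop out of the batch).
# Drive the mock runner for a few iterations without any server around it.
import asyncio
from data_models import Batch, GenerationParams, Request
from model_runner import MockModelRunner

async def main():
    runner = MockModelRunner()
    reqs = [Request(f"req_{i}", f"prompt {i}", GenerationParams(max_new_tokens=5))
            for i in range(3)]
    batch = Batch.from_requests(reqs)
    for step in range(6):
        await runner.process_batch(batch)
        batch = batch.filter_finished()
        print(f"step {step}: {batch.size} request(s) still active")

asyncio.run(main())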
3. The Core Logic: The Scheduler
This is where the continuous batching magic happens. The scheduler runs an infinite async loop, pulling requests from the queue and managing the active batch.
# file: scheduler.py
import asyncio
from asyncio import Queue
from typing import List, Optional
from data_models import Request, Batch
from model_runner import MockModelRunner
class Scheduler:
def __init__(self, request_queue: Queue, model: MockModelRunner, max_batch_size: int):
self.request_queue = request_queue
self.model = model
self.max_batch_size = max_batch_size
self.active_batch: Optional[Batch] = None
self._shutdown = False
async def run_loop(self):
print("Scheduler loop started.")
while not self._shutdown:
# Pull new requests from the queue to fill the batch
await self.fill_batch()
if self.active_batch and self.active_batch.requests:
# Process the current batch
await self.model.process_batch(self.active_batch)
# Handle finished requests
self.handle_finished_requests()
# Filter out finished requests for the next iteration
self.active_batch = self.active_batch.filter_finished()
else:
# If no active batch, wait a bit to prevent busy-waiting
await asyncio.sleep(0.01)
async def fill_batch(self):
current_batch_size = len(self.active_batch.requests) if self.active_batch else 0
while current_batch_size < self.max_batch_size:
try:
# Non-blocking get from the queue
new_request = self.request_queue.get_nowait()
if self.active_batch is None:
self.active_batch = Batch.from_requests([new_request])
else:
self.active_batch.requests.append(new_request)
current_batch_size += 1
print(f"Added request {new_request.id} to batch. New size: {current_batch_size}")
except asyncio.QueueEmpty:
# No more requests in the queue, break the loop
break
def handle_finished_requests(self):
if not self.active_batch:
return
for req in self.active_batch.requests:
if req.finished and not req.completion_event.is_set():
req.result = "".join(req.generated_tokens).strip()
# This is the crucial step: signal the waiting API coroutine
req.completion_event.set()
print(f"Finished request {req.id}. Reason: {req.finish_reason}")
def shutdown(self):
self._shutdown = True
The key methods are:
* run_loop(): The main loop that orchestrates everything.
* fill_batch(): This method attempts to fill any empty slots in self.active_batch with requests from self.request_queue. It uses get_nowait() for a non-blocking check, ensuring the loop continues to run even if the queue is empty.
* handle_finished_requests(): After a model pass, this iterates through the batch. For any request that the model flagged as finished, it sets the final result and, most importantly, calls req.completion_event.set(). This unblocks the original await call in the API server.
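Before wiring in the web layer, the scheduler can be smoke-tested directly. The script below is a throwaway sketch under the same assumptions as the mock runner: enqueue some requests, let the loop run, and await their completion events.
# Exercise the scheduler end-to-end without FastAPI.
import asyncio
from data_models import GenerationParams, Request
from model_runner import MockModelRunner
from scheduler import Scheduler

async def main():
    queue = asyncio.Queue()
    scheduler = Scheduler(queue, MockModelRunner(), max_batch_size=4)
    loop_task = asyncio.create_task(scheduler.run_loop())

    # Six requests against a batch capacity of four: two will wait in the
    # queue and backfill slots as earlier requests finish.
    reqs = [Request(f"req_{i}", f"prompt {i}", GenerationParams(max_new_tokens=8))
            for i in range(6)]
    for req in reqs:
        queue.put_nowait(req)

    await asyncio.gather(*(req.completion_event.wait() for req in reqs))
    for req in reqs:
        print(req.id, req.finish_reason, len(req.generated_tokens), "tokens")

    scheduler.shutdown()
    await loop_task

asyncio.run(main())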
4. The API Server with FastAPI
Finally, the web server that ties it all together. It accepts requests, puts them in the queue, and awaits their completion.
# file: main.py
import uuid
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from data_models import APIRequest, Request
from model_runner import MockModelRunner
from scheduler import Scheduler
# --- Global State ---
# In a real app, this might be managed differently (e.g., dependency injection)
app_state = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
print("Application startup...")
app_state['request_queue'] = asyncio.Queue()
app_state['model_runner'] = MockModelRunner()
app_state['scheduler'] = Scheduler(
request_queue=app_state['request_queue'],
model=app_state['model_runner'],
max_batch_size=16
)
# Run the scheduler in the background
app_state['scheduler_task'] = asyncio.create_task(app_state['scheduler'].run_loop())
yield
# Shutdown
print("Application shutdown...")
app_state['scheduler'].shutdown()
app_state['scheduler_task'].cancel()
app = FastAPI(lifespan=lifespan)
@app.post("/v1/generate")
async def generate(api_request: APIRequest):
request_id = f"req_{uuid.uuid4()}"
internal_request = Request(
request_id=request_id,
prompt=api_request.prompt,
params=api_request.params
)
try:
# This is where backpressure would be handled
app_state['request_queue'].put_nowait(internal_request)
except asyncio.QueueFull:
raise HTTPException(status_code=503, detail="Server is overloaded, please try again later.")
# Wait for the request to be processed by the scheduler
try:
await asyncio.wait_for(internal_request.completion_event.wait(), timeout=30.0)
except asyncio.TimeoutError:
# This is a critical edge case: request cancellation
# We need to inform the scheduler to potentially drop this request
# (More on this in the next section)
raise HTTPException(status_code=504, detail="Request timed out.")
if internal_request.error:
return JSONResponse(status_code=500, content={"error": internal_request.error})
return {
"request_id": internal_request.id,
"prompt": internal_request.prompt,
"result": internal_request.result,
"finish_reason": internal_request.finish_reason
}
@app.get("/health")
def health():
return {"status": "ok", "queue_size": app_state['request_queue'].qsize()}
# To run: uvicorn main:app --port 8000
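Once the server is running, a quick way to exercise the endpoint from another terminal is a minimal client like this (assuming the third-party requests package is installed; the payload fields mirror the APIRequest model above).
# Simple client call against the local server started with uvicorn.
import requests

resp = requests.post(
    "http://localhost:8000/v1/generate",
    json={"prompt": "Hello, world", "params": {"max_new_tokens": 32}},
    timeout=60,
)
resp.raise_for_status()
print(resp.json())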
The lifespan context manager is used to initialize our scheduler and model runner and start the scheduler's background task when the application starts. The /v1/generate endpoint does three key things:
1. It creates an internal Request object from the incoming API payload.
2. It puts that object onto the request_queue for the scheduler to pick up.
3. It calls await internal_request.completion_event.wait(): this line pauses the execution of this specific request handler until the scheduler signals that the work is done. This is the core of the asynchronous decoupling.
Advanced Edge Cases and Production Hardening
Our implementation works, but a production system requires handling more complex scenarios.
1. Backpressure Management
What happens if requests arrive much faster than the GPU can process them? Our asyncio.Queue will grow indefinitely, consuming all available memory. This is a recipe for a cascading failure.
Solution: Use a bounded queue. We can initialize our queue with a maxsize.
# In main.py, lifespan
MAX_QUEUE_SIZE = 100
app_state['request_queue'] = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
When the queue is full, put_nowait() will raise asyncio.QueueFull. We catch this in our API endpoint and return a 503 Service Unavailable status. This is a crucial backpressure mechanism that protects the service from being overwhelmed. Clients can then implement their own retry logic with exponential backoff.
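On the client side, such a retry loop might look like the sketch below (again using the requests package; the status-code handling, delays, and retry limit are assumptions to adapt to your service).
# Client-side retry with exponential backoff for 503 (overloaded) responses.
import time
import requests

def generate_with_retry(prompt: str, max_retries: int = 5) -> dict:
    delay = 0.5  # seconds; doubles after each 503
    for attempt in range(max_retries):
        resp = requests.post(
            "http://localhost:8000/v1/generate",
            json={"prompt": prompt},
            timeout=60,
        )
        if resp.status_code != 503:
            resp.raise_for_status()
            return resp.json()
        time.sleep(delay)
        delay *= 2
    raise RuntimeError("Server still overloaded after retries")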
2. Request Cancellation and Timeout Handling
In our current code, if a request times out (asyncio.TimeoutError), we return a 504 to the client, but the request remains in the scheduler's batch. It will continue to consume GPU resources until it finishes naturally, even though the client has already disconnected. This is a resource leak.
Solution: We need a mechanism for the API handler to signal cancellation to the scheduler.
First, add a cancelled flag to our Request object.
# in data_models.py, Request class
class Request:
def __init__(...):
# ... existing fields
self.cancelled = False
Next, modify the API endpoint to set this flag on timeout.
# in main.py, generate endpoint
try:
await asyncio.wait_for(internal_request.completion_event.wait(), timeout=30.0)
except asyncio.TimeoutError:
internal_request.cancelled = True
print(f"Request {internal_request.id} timed out and was cancelled.")
raise HTTPException(status_code=504, detail="Request timed out.")
Finally, the scheduler must check for this flag before and after processing.
# in scheduler.py, Scheduler class
async def fill_batch(self):
# ... inside the while loop
try:
new_request = self.request_queue.get_nowait()
if new_request.cancelled:
# Discard requests that were cancelled before they even started
print(f"Discarding cancelled request {new_request.id} from queue.")
continue
# ... rest of the logic
# ...
def handle_finished_requests(self):
# ...
# Add a check for cancellation *during* processing
if not self.active_batch:
return
# We need to rebuild the request list to handle cancellations
requests_to_keep = []
for req in self.active_batch.requests:
if req.cancelled:
print(f"Dropping cancelled request {req.id} from active batch.")
# Don't set completion event, the API handler already timed out
continue
if req.finished and not req.completion_event.is_set():
req.result = "".join(req.generated_tokens).strip()
req.completion_event.set()
print(f"Finished request {req.id}. Reason: {req.finish_reason}")
requests_to_keep.append(req)
self.active_batch.requests = requests_to_keep
This ensures that cancelled requests are purged from the system, freeing up valuable batch slots for active requests.
3. Handling Heterogeneous Generation Parameters
In a real service, users will submit requests with different max_new_tokens, temperature, etc. Our model runner and scheduler must respect these on a per-request basis. The process_batch method in a real system wouldn't just take a Batch object; it would take the tensors and a list of generation configs. The sampling logic (e.g., top-p, temperature) inside the model runner would then be applied individually before generating the next token for each sequence in the batch.
Our simulation already handles max_new_tokens correctly on a per-request basis. A production implementation would extend this to all sampling parameters.
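As an illustration, per-request sampling over a batch of logits might look roughly like this pure-Python sketch. A real runner would do this with batched tensor operations; the dictionary shapes and helper names here are assumptions for clarity, not part of the implementation above.
# Per-request sampling: each sequence in the batch gets its own temperature
# and top-p applied to its own logits before the next token is chosen.
import math
import random
from typing import Dict, List

from data_models import GenerationParams

def sample_next_token(logits: List[float], temperature: float, top_p: float) -> int:
    # Temperature scaling followed by a numerically stable softmax.
    scaled = [l / max(temperature, 1e-5) for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Top-p (nucleus) filtering: keep the smallest set of tokens whose
    # cumulative probability reaches top_p.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for idx in ranked:
        kept.append(idx)
        cumulative += probs[idx]
        if cumulative >= top_p:
            break
    kept_probs = [probs[i] for i in kept]
    return random.choices(kept, weights=kept_probs, k=1)[0]

def sample_batch(batch_logits: Dict[str, List[float]],
                 params_by_id: Dict[str, GenerationParams]) -> Dict[str, int]:
    # One independent sampling decision per request in the batch.
    return {
        req_id: sample_next_token(logits,
                                  params_by_id[req_id].temperature,
                                  params_by_id[req_id].top_p)
        for req_id, logits in batch_logits.items()
    }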
Performance Analysis: Static vs. Continuous Batching
Let's analyze the performance difference under a bursty traffic pattern. Imagine 10 requests arriving in quick succession. 5 requests want 50 tokens, and 5 want 500 tokens.
Scenario 1: Static Batching (batch_size=8)
* Batch 1: The first 8 requests are collected. Let's say it's 4 short and 4 long. The batch starts processing.
* The 4 short requests finish quickly, but their GPU slots remain occupied and blocked until the 4 long requests are complete. The entire batch takes the time of the longest request.
* Latency for short requests: Very high. They are penalized for being in a batch with long requests.
* Batch 2: The remaining 2 requests (1 short, 1 long) wait until Batch 1 is fully complete. The GPU is idle during this handoff. They form a small, inefficient batch of 2.
* GPU Utilization: Spiky. Full utilization during batch processing, but idle during batch formation and handoff.
Scenario 2: Continuous Batching (max_batch_size=8)
* Iterations 1-50: All 8 slots are filled. The model generates one token for each of the 8 requests on every iteration.
* Iteration ~50: The four short requests (50 tokens each) finish. Two of the freed slots are immediately backfilled by the two waiting requests on the very next iteration; the other slots stay empty only because the queue is now exhausted.
* The batch size stays at or near the maximum of 8 for the entire duration, as long as there are requests in the queue.
* Latency for short requests: Minimal. They get in, get their tokens, and get out.
* GPU Utilization: Consistently high. The GPU is always working on a nearly-full batch.
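A rough back-of-envelope comparison, under the simplifying assumption that every forward pass costs about the same regardless of batch size, makes the gap explicit:
# Rough iteration counts for the 10-request example above, assuming every
# forward pass costs the same amount of time regardless of batch size.
short_tokens, long_tokens = 50, 500

# Static batching: batch 1 = 4 short + 4 long, batch 2 = 1 short + 1 long.
# Each batch runs until its longest member is done, and batch 2 cannot
# start until batch 1 has fully finished.
static_total = long_tokens + long_tokens       # ~1000 iterations end to end
static_short_latency = long_tokens             # batch-1 short requests wait ~500

# Continuous batching: 8 requests start immediately; the 2 queued requests
# enter as soon as the first short requests free their slots (~iteration 50),
# so the last (long) request finishes around iteration 50 + 500.
continuous_total = short_tokens + long_tokens  # ~550 iterations end to end
continuous_short_latency = short_tokens        # short requests exit after ~50

print(f"static:     ~{static_total} iterations total, "
      f"short-request latency ~{static_short_latency} iterations")
print(f"continuous: ~{continuous_total} iterations total, "
      f"short-request latency ~{continuous_short_latency} iterations")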
| Metric | Static Batching | Continuous Batching |
|---|---|---|
| Avg. Time to First Token | High (due to batch wait) | Low (immediate entry) |
| Avg. Request Latency | High (head-of-line block) | Low (especially for short reqs) |
| GPU Utilization | Spiky / Lower | Consistently High |
| Throughput (tokens/sec) | Moderate | High / Near-Optimal |
Conclusion: A Non-Negotiable Pattern for Scalable AI
While the implementation of a continuous batching scheduler is significantly more complex than a simple static batching loop, it is a non-negotiable architectural pattern for building serious, production-grade LLM inference services. The dramatic improvements in GPU utilization, latency, and overall throughput directly translate to lower operational costs and a vastly superior user experience.
Frameworks like vLLM and Text Generation Inference (TGI) have this pattern at their core, abstracting away much of the complexity. However, understanding the underlying principles of asynchronous request handling, scheduling, and state management is crucial for any senior engineer tasked with deploying, debugging, or optimizing these systems. Whether you use an off-the-shelf solution or build a custom inference server, mastering the continuous batching paradigm is essential for navigating the performance challenges of the modern AI landscape.