Maximizing Transformer Throughput with Triton's Dynamic Batching
The Concurrency Bottleneck in Transformer Inference
Transformer-based models, while state-of-the-art, are computationally demanding. A common production anti-pattern is deploying them with a 1:1 mapping of incoming requests to model executions. This approach, while simple, is catastrophically inefficient. A single inference request for a model like BERT, even with a moderate sequence length, will only utilize a fraction of the thousands of CUDA cores available on a modern GPU like an NVIDIA A100 or H100. The GPU remains mostly idle, waiting for the next request, leading to abysmal throughput and a high cost per inference.
The core problem is the mismatch between the serial nature of incoming requests and the massively parallel architecture of the GPU. To achieve high throughput, we must batch multiple inference requests together and execute them as a single, larger computation. This allows the GPU to process data in parallel, drastically improving its utilization (Tensor Core saturation) and overall throughput.
While static batching (where the client gathers requests) is an option, it introduces client-side complexity and latency. A far more elegant and powerful solution is server-side dynamic batching, a flagship feature of the NVIDIA Triton Inference Server.
Let's quantify the problem. Consider a standard BERT model for sentiment analysis, deployed on Triton without any batching. The configuration is minimal:
Baseline `config.pbtxt` (no batching):
name: "bert_onnx_no_batching"
platform: "onnxruntime_onnx"
max_batch_size: 0 # Explicitly disable batching
input [
{
name: "input_ids"
data_type: TYPE_INT64
dims: [ -1, -1 ]
},
{
name: "attention_mask"
data_type: TYPE_INT64
dims: [ -1, -1 ]
}
]
output [
{
name: "output"
data_type: TYPE_FP32
dims: [ -1, 2 ]
}
]
Running Triton's `perf_analyzer` against this endpoint reveals the performance ceiling. This tool is essential for any serious performance-tuning work with Triton.
# Assumes you have a Triton server running with the model loaded
perf_analyzer -m bert_onnx_no_batching --concurrency-range 1:16 -u localhost:8000 --shape input_ids:1,128 --shape attention_mask:1,128
The output would show throughput saturating quickly. For a typical setup on a V100, you might see something like this:
| Concurrency | Throughput (inf/sec) | Avg Latency (ms) |
|---|---|---|
| 1 | 45 | 22 |
| 4 | 55 | 72 |
| 8 | 58 | 137 |
| 16 | 59 | 270 |
Notice how throughput barely increases beyond a concurrency of 4, while latency explodes. The server is processing requests mostly serially, and the GPU is starved. This is the problem dynamic batching is designed to solve.
Dissecting Triton's Dynamic Batching Scheduler
Dynamic batching instructs Triton to hold incoming inference requests in a queue for a very short, configurable period. The Triton scheduler then intelligently groups these queued requests into a single batch, pads them to a uniform shape if necessary, and sends this larger batch to the model backend for execution. This entire process is transparent to the client; it sends a request for a batch-size 1 inference and receives a batch-size 1 response.
The magic happens within the scheduler's algorithm, which is primarily governed by two parameters in the `config.pbtxt`:

* `max_batch_size`: The absolute maximum number of requests that can be combined into a single batch.
* `dynamic_batching`: A block containing the fine-tuning parameters:
  * `preferred_batch_size`: An array of batch sizes the scheduler should aim for. It will try to form a batch of one of these sizes before moving on.
  * `max_queue_delay_microseconds`: The maximum time a request will wait in the queue before the scheduler is forced to form a batch, even if it's smaller than a preferred size.
The scheduler's logic is a constant balancing act:
- A request arrives and is added to the queue.
- The scheduler checks whether the queued requests can form a batch matching a `preferred_batch_size`.
- If it matches, a batch is immediately formed and sent for execution.
- If it doesn't match, the scheduler checks the age of the oldest request in the queue.
- If that age exceeds `max_queue_delay_microseconds`, a batch is formed with whatever is currently in the queue (up to `max_batch_size`) and dispatched.
- If neither condition is met, the scheduler waits for more requests to arrive.
This mechanism allows the system to achieve high throughput under heavy load (by forming large, preferred-size batches) while maintaining bounded latency under light load (by using the delay timeout).
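To make the decision rule concrete, here is a toy Python sketch of those two triggers. It is not Triton's implementation, just an illustration of how a preferred batch size and a queue-delay timeout interact; all parameter values are assumptions for the example.

import time
from collections import deque

# Toy illustration of the two dispatch triggers described above -- not Triton's
# actual implementation. Parameter values are illustrative, not Triton defaults.
PREFERRED_BATCH_SIZES = {8, 16, 32}
MAX_BATCH_SIZE = 64
MAX_QUEUE_DELAY_S = 0.005  # 5000 microseconds

queue = deque()  # items are (arrival_time, request)

def maybe_form_batch(now):
    """Return a batch if either trigger fires, otherwise None (keep waiting)."""
    if not queue:
        return None
    # Trigger 1: the queue has reached a preferred (or the maximum) batch size.
    if len(queue) in PREFERRED_BATCH_SIZES or len(queue) >= MAX_BATCH_SIZE:
        size = min(len(queue), MAX_BATCH_SIZE)
        return [queue.popleft()[1] for _ in range(size)]
    # Trigger 2: the oldest request has waited longer than the queue delay allows.
    oldest_arrival = queue[0][0]
    if now - oldest_arrival >= MAX_QUEUE_DELAY_S:
        size = min(len(queue), MAX_BATCH_SIZE)
        return [queue.popleft()[1] for _ in range(size)]
    return None

# Five requests arrive: no preferred size is hit, so the delay trigger fires instead.
for i in range(5):
    queue.append((time.monotonic(), f"request-{i}"))
time.sleep(0.006)
print(maybe_form_batch(time.monotonic()))  # ['request-0', ..., 'request-4']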
Production Implementation: A BERT Example with Dynamic Batching
Let's refactor our BERT deployment to use dynamic batching and demonstrate the performance impact.
1. Model Directory Structure
Your model repository should look like this. We'll use an ONNX-exported version of `bert-base-uncased`; a minimal export sketch follows the directory layout below.
/models
└── bert_onnx
├── 1
│ └── model.onnx
└── config.pbtxt
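If you still need to produce `model.onnx`, an export along these lines should work (a sketch assuming `torch` and `transformers` are installed; the output path, opset version, and dynamic-axis names are illustrative):

# Minimal export sketch: produces the model.onnx referenced above.
import torch
from transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model.eval()

sample = tokenizer("Triton dynamic batching example", return_tensors="pt")

torch.onnx.export(
    model,
    (sample["input_ids"], sample["attention_mask"]),
    "models/bert_onnx/1/model.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["output"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "output": {0: "batch"},
    },
    opset_version=14,
)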
2. `config.pbtxt` with Dynamic Batching
This configuration is where the tuning happens. We'll enable dynamic batching and set some sensible defaults.
name: "bert_onnx"
platform: "onnxruntime_onnx"
max_batch_size: 64 # Enable batching up to 64 requests
input [
{
name: "input_ids"
data_type: TYPE_INT64
dims: [ -1 ] # Note: We define the shape for a single item, not the batch
},
{
name: "attention_mask"
data_type: TYPE_INT64
dims: [ -1 ]
}
]
output [
{
name: "output"
data_type: TYPE_FP32
dims: [ 2 ]
}
]
# Dynamic Batching Configuration
dynamic_batching {
preferred_batch_size: [ 8, 16, 32 ]
max_queue_delay_microseconds: 5000 # 5 milliseconds
}
# Optional: Instance grouping for parallel execution
instance_group [
{
count: 1
kind: KIND_GPU
}
]
Key Changes:
* `max_batch_size` is now `64`, telling Triton it can handle batches of this size.
* The `dims` for inputs/outputs now describe a *single* instance in the batch. Triton will prepend the batch dimension automatically.
* The `dynamic_batching` block is added. We've told the scheduler to prioritize creating batches of 8, 16, or 32. If a preferred batch can't be formed quickly, it will dispatch whatever it has after 5 ms to avoid starving requests. A quick sanity check of the loaded configuration is shown below.
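Before benchmarking, it's worth confirming that the server actually picked up these settings. A small sketch using the `tritonclient` HTTP API (the server address is assumed to be localhost:8000) fetches the loaded config and checks for the `dynamic_batching` block:

import tritonclient.http as httpclient

# Fetch the configuration Triton actually loaded and confirm dynamic batching is on.
client = httpclient.InferenceServerClient(url="localhost:8000")
config = client.get_model_config("bert_onnx")

print("max_batch_size:", config.get("max_batch_size"))
print("dynamic_batching:", config.get("dynamic_batching", "NOT ENABLED"))

# A missing 'dynamic_batching' key means requests will not be batched server-side.
assert "dynamic_batching" in config, "dynamic_batching block missing from loaded config"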
3. Advanced Python Client for Concurrent Requests
A simple loop isn't sufficient to test high concurrency. We need an asynchronous client that can fire off many requests in parallel. We'll use `gevent` and the `tritonclient` library.
import gevent
import gevent.pool
from transformers import BertTokenizer
import numpy as np
import tritonclient.http as httpclient
import time

# --- Configuration ---
MODEL_NAME = "bert_onnx"
TRITON_URL = "localhost:8000"
NUM_REQUESTS = 1000
CONCURRENCY = 100

# --- Sample Data ---
sentences = [
    "This movie was absolutely fantastic!",
    "I would not recommend this product to my worst enemy.",
    "The service was mediocre at best.",
    "A true masterpiece of modern cinema."
]

# --- Tokenizer ---
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# --- Triton Client Setup ---
# The HTTP client is gevent-based; 'concurrency' sizes its connection pool so
# the spawned greenlets can actually issue requests in parallel.
triton_client = httpclient.InferenceServerClient(url=TRITON_URL, concurrency=CONCURRENCY)

def prepare_input(sentence: str):
    inputs = tokenizer(sentence, return_tensors="np", padding='max_length', max_length=128, truncation=True)
    input_ids = inputs['input_ids'].astype(np.int64)
    attention_mask = inputs['attention_mask'].astype(np.int64)

    # Set up Triton input objects
    triton_input_ids = httpclient.InferInput('input_ids', input_ids.shape, 'INT64')
    triton_input_ids.set_data_from_numpy(input_ids)
    triton_attention_mask = httpclient.InferInput('attention_mask', attention_mask.shape, 'INT64')
    triton_attention_mask.set_data_from_numpy(attention_mask)
    return [triton_input_ids, triton_attention_mask]

def inference_worker(worker_id, sentence):
    """A single worker that sends one request."""
    try:
        inputs = prepare_input(sentence)
        response = triton_client.infer(
            model_name=MODEL_NAME,
            inputs=inputs
        )
        # In a real app, you'd process the response
        # output = response.as_numpy('output')
        return True
    except Exception as e:
        print(f"Worker {worker_id} failed: {e}")
        return False

def main():
    print(f"Starting benchmark with {CONCURRENCY} concurrent workers, {NUM_REQUESTS} total requests.")
    pool = gevent.pool.Pool(CONCURRENCY)
    jobs = []
    start_time = time.time()

    for i in range(NUM_REQUESTS):
        sentence = sentences[i % len(sentences)]
        jobs.append(pool.spawn(inference_worker, i, sentence))

    gevent.joinall(jobs)
    end_time = time.time()

    total_time = end_time - start_time
    throughput = NUM_REQUESTS / total_time

    print("\n--- Benchmark Results ---")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"Total requests: {NUM_REQUESTS}")
    print(f"Throughput: {throughput:.2f} inferences/sec")

if __name__ == '__main__':
    main()
4. Benchmarking with `perf_analyzer`
While the Python script is good for a functional test, `perf_analyzer` is the canonical tool for rigorous benchmarking.
perf_analyzer -m bert_onnx --concurrency-range 1:128:8 -u localhost:8000 --shape input_ids:128 --shape attention_mask:128
Now, the results will be dramatically different:
| Concurrency | Throughput (inf/sec) | Avg Latency (ms) |
|---|---|---|
| 1 | 44 | 27 |
| 8 | 340 | 28 |
| 16 | 650 | 29 |
| 32 | 1150 | 32 |
| 64 | 1800 | 40 |
| 128 | 1950 | 70 |
This is a ~33x increase in throughput at high concurrency. Notice that latency remains low and stable until we approach the server's saturation point. The scheduler is successfully forming large batches (likely of size 32 or 64), leading to massive gains in GPU utilization.
Advanced Patterns and Edge Case Handling
Achieving the initial throughput gain is just the beginning. Production environments demand fine-grained control and handling of complex scenarios.
Tuning `max_queue_delay_microseconds`
This parameter is the most critical lever for managing the latency-throughput trade-off. There is no single correct value; it depends entirely on your application's SLOs.
* Low-Latency Use Case (e.g., real-time search suggestions): Your p99 latency SLO might be <50 ms. You would set a low delay, perhaps `1000` (1 ms) or even `500` (0.5 ms). This ensures that even a single, isolated request is served quickly. The trade-off is that during periods of low traffic you will form smaller, less efficient batches, reducing the maximum possible throughput.
* High-Throughput Use Case (e.g., offline document processing): Latency is less critical than overall processing speed. You can set a much higher delay, like `20000` (20 ms) or more. This gives the scheduler ample time to wait for enough requests to form a large, optimal batch (e.g., 32 or 64), maximizing GPU utilization at the cost of higher per-request latency.
Methodology for Tuning:
- Start with a low `max_queue_delay_microseconds` (e.g., 1000).
- Use `perf_analyzer` to measure throughput and latency across your expected concurrency range.
- Gradually increase the delay and repeat the measurements.
- Track how throughput and latency change as you raise `max_queue_delay_microseconds`.
- Choose the highest delay value that comfortably meets your latency SLO. This will give you the best possible throughput for your specific constraints. A script that automates this sweep is sketched after the list.
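The sweep is straightforward to automate. The sketch below is one way to do it under a few assumptions: the model repository lives at the hypothetical path `/models/bert_onnx`, `perf_analyzer` is on the PATH, and Triton runs with `--model-control-mode=explicit` so that `load_model` re-reads the edited config.

import re
import subprocess
import tritonclient.http as httpclient

# Candidate queue delays to sweep, in microseconds (illustrative values).
DELAYS_US = [500, 1000, 2000, 5000, 10000, 20000]
CONFIG_PATH = "/models/bert_onnx/config.pbtxt"  # hypothetical repository path
MODEL_NAME = "bert_onnx"

client = httpclient.InferenceServerClient(url="localhost:8000")

for delay in DELAYS_US:
    # Rewrite the delay in place (assumes the field already exists in the config).
    with open(CONFIG_PATH) as f:
        config = f.read()
    config = re.sub(r"max_queue_delay_microseconds:\s*\d+",
                    f"max_queue_delay_microseconds: {delay}", config)
    with open(CONFIG_PATH, "w") as f:
        f.write(config)

    # Reload so Triton picks up the new config (requires explicit model control).
    client.load_model(MODEL_NAME)

    # Run perf_analyzer and keep the raw output for later comparison.
    result = subprocess.run(
        ["perf_analyzer", "-m", MODEL_NAME, "--concurrency-range", "1:64:8",
         "-u", "localhost:8000", "--shape", "input_ids:128",
         "--shape", "attention_mask:128"],
        capture_output=True, text=True, check=True,
    )
    with open(f"perf_delay_{delay}.txt", "w") as f:
        f.write(result.stdout)
    print(f"max_queue_delay_microseconds={delay}: done")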
Handling Variable Sequence Lengths with Ragged Batching
Our example used fixed-size padding (`max_length=128`). This is inefficient. If you batch a request with 10 tokens and one with 120 tokens, both get padded to 128, and the GPU wastes cycles computing on padding tokens.
Triton can address this with ragged batching. This allows requests with different sequence lengths to be batched together without being padded to the single largest length in the batch. Instead, Triton provides the backend with the raw, unpadded tensor data along with metadata to reconstruct the individual sequence boundaries.
To enable this, you must:

1. Update the `config.pbtxt` to signal that the model can handle ragged inputs:

# In the input definition
input [
  {
    name: "input_ids"
    data_type: TYPE_INT64
    dims: [ -1 ]
    allow_ragged_batch: true
  },
  ...
]

2. Update the client so that it sends each sequence at its true, unpadded length rather than padding everything to `max_length` (see the sketch after this list).
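For illustration, here is how the `prepare_input` helper from the earlier benchmark script might change so each request carries only its real tokens. This is a sketch under the assumption that the backend is configured for ragged inputs as above; the 512-token cap is an arbitrary safety limit.

import numpy as np
import tritonclient.http as httpclient
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def prepare_unpadded_input(sentence: str):
    # No fixed-length padding: the tensors are exactly as long as the tokenized sentence.
    encoded = tokenizer(sentence, return_tensors="np", truncation=True, max_length=512)
    input_ids = encoded['input_ids'].astype(np.int64)          # shape: (1, seq_len)
    attention_mask = encoded['attention_mask'].astype(np.int64)

    triton_input_ids = httpclient.InferInput('input_ids', list(input_ids.shape), 'INT64')
    triton_input_ids.set_data_from_numpy(input_ids)
    triton_attention_mask = httpclient.InferInput('attention_mask', list(attention_mask.shape), 'INT64')
    triton_attention_mask.set_data_from_numpy(attention_mask)
    return [triton_input_ids, triton_attention_mask]

# Requests of very different lengths no longer pay for padding to 128 tokens.
short_review = prepare_unpadded_input("Great!")
long_review = prepare_unpadded_input("A much longer review that previously would have been padded out...")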
This is an advanced feature that significantly improves performance for NLP tasks with highly variable text lengths, as it eliminates wasted computation.
Request Prioritization for Mixed Workloads
Imagine your single BERT model serves two purposes: a low-latency, user-facing feature (high priority) and a high-throughput, offline analysis pipeline (low priority). Without intervention, a flood of low-priority requests could saturate the queue, increasing latency for high-priority ones.
Triton's dynamic batcher supports priority levels.
`config.pbtxt` with Priority:
dynamic_batching {
  preferred_batch_size: [ 8, 16, 32 ]
  max_queue_delay_microseconds: 5000
  priority_levels: 3
  default_priority_level: 2
}
* `priority_levels: 3`: We define 3 priority levels (1 is highest, 3 is lowest).
* `default_priority_level: 2`: Requests without a specified priority get level 2.
The scheduler will now attempt to form batches using requests from the highest-priority queue (level 1) first. It will only form a batch from a lower-priority queue if the higher-priority queues are empty.
Client-side Priority Setting:
The client specifies the priority in the request parameters.
# High priority request
response = triton_client.infer(
    model_name=MODEL_NAME,
    inputs=inputs,
    priority=1
)

# Low priority request (or omit for default)
response = triton_client.infer(
    model_name=MODEL_NAME,
    inputs=inputs,
    priority=3
)
This ensures your latency-sensitive traffic is always processed first, even when the server is under heavy load from background tasks.
The Conundrum: Dynamic Batching and Stateful LLMs
Dynamic batching excels for stateless models where each request is independent. Stateful models, like generative LLMs (GPT, Llama), present a major challenge. An LLM inference is a sequence of operations (one per generated token), and the state (KV cache) must be maintained between steps for a given user session.
Triton handles this with the Sequence Batcher, which uses a `correlation_id` provided by the client to route all requests in a sequence to the same model instance. The dynamic batcher and sequence batcher are fundamentally at odds:
* Dynamic Batcher: Groups unrelated requests from *different* sequences together at a single point in time.
* Sequence Batcher: Groups related requests from the *same* sequence together across time.
You generally cannot use the dynamic batcher for the token-by-token decoding loop of an LLM. Doing so would incorrectly mix the KV caches of different sequences.
However, there is a powerful hybrid pattern used in production systems like NVIDIA's TensorRT-LLM backend:
The advanced solution is to use an ensemble model in Triton that separates these concerns:
* Prompt Processing Model: A model responsible only for the prefill step. This model's `config.pbtxt` uses dynamic batching to group incoming prompts for maximum throughput.
* Token Generation Model: A second, stateful model responsible for the decoding loop. This model's `config.pbtxt` uses the sequence batcher to maintain state via correlation IDs.
This architecture provides the best of both worlds: high-throughput ingestion of new requests via dynamic batching on the prefill, and correct, stateful management of ongoing generation streams via the sequence batcher.
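To make the sequence-batcher half concrete, the sketch below shows how a client would tag each decode-step request so Triton routes the whole stream to the same model instance. The model name and the `inputs` argument are hypothetical placeholders; `sequence_id`, `sequence_start`, and `sequence_end` are the standard `tritonclient` arguments that carry the correlation ID.

import tritonclient.http as httpclient

# Hypothetical stateful decoding model served behind Triton's sequence batcher.
client = httpclient.InferenceServerClient(url="localhost:8000")
CORRELATION_ID = 42  # one ID per user session / generation stream

def decode_step(inputs, first=False, last=False):
    # Every request carrying the same sequence_id is routed to the same model
    # instance, so its KV cache remains consistent across decode steps.
    return client.infer(
        model_name="token_generation",  # hypothetical model name
        inputs=inputs,
        sequence_id=CORRELATION_ID,
        sequence_start=first,           # marks the first step of the sequence
        sequence_end=last,              # marks the last step so Triton can release state
    )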
Production Deployment Checklist
Before deploying a Transformer model with dynamic batching, senior engineers should validate the following:
- Have you used `perf_analyzer` to profile your model's performance across the full range of expected concurrent loads?
- Does the memory footprint at `max_batch_size` fit comfortably within your GPU's VRAM? Monitor `nvidia-smi` during stress tests to ensure you're not close to the limit, which can cause CUDA OOM errors.
- Is `max_queue_delay_microseconds` tuned to meet your specific p95/p99 latency requirements? Have you verified this with load testing?

By moving beyond a basic configuration and embracing these advanced patterns, you can transform an underperforming model endpoint into a highly efficient, cost-effective, and production-ready inference service.