Stateful Streaming Inference with Apache Flink & ONNX Runtime
The Limitations of Stateless Inference in Sequential AI
In modern data architectures, the default pattern for machine learning inference is often a stateless REST or gRPC endpoint. A client sends a feature vector, the model returns a prediction. This works beautifully for tasks like image classification or sentiment analysis on isolated text. However, it fundamentally breaks down for a significant class of high-value problems that are inherently sequential and stateful:
* Real-time Fraud Detection: Is this transaction fraudulent? The answer depends not just on the current transaction's features, but on the user's transaction history over the last hour, day, or week.
* Predictive Maintenance: Will this IoT sensor fail? The prediction requires a time-series of recent sensor readings, not just the latest one.
* Session-based Recommendation Engines: What product should we recommend next? The optimal recommendation depends on the user's clickstream within the current session.
Attempting to solve these with a stateless API forces the client to manage and transmit the entire historical context with every request. This is inefficient, introduces high latency, and tightly couples the client application with the model's state requirements. The alternative—a stateful inference service—presents its own challenges: how do you manage potentially massive amounts of state with high availability, fault tolerance, and consistency, all while processing data streams at millions of events per second?
This is where the fusion of a distributed stateful stream processor and a high-performance inference runtime becomes a game-changing architectural pattern. By embedding ONNX Runtime within an Apache Flink job, we can perform stateful inference directly on the data stream, co-locating the model's logic and its required state. This article provides a blueprint for this architecture, focusing on production-level implementation details.
Architectural Blueprint: Flink + ONNX Runtime
Our goal is to build a system that can consume a stream of events (e.g., transactions from Kafka), maintain historical state for each entity (e.g., user), and apply an ML model (e.g., an LSTM or Transformer) to make a prediction.
Here's the high-level data flow:
* The job consumes events from Kafka.
* keyBy() partitions the stream by an entity ID (e.g., userId), ensuring all events for a given user are processed by the same parallel task and enabling stateful operations.
* A custom KeyedProcessFunction is applied to the keyed stream (a job-wiring sketch follows the lists below). Within it:
* State Retrieval: For each incoming event, the function retrieves the current state for that userId from Flink's managed state (backed by RocksDB).
* Inference: The ONNX Runtime engine, loaded into memory within the Flink operator, executes the model using the incoming event and the retrieved state.
* State Update: The function updates the userId's state with new information derived from the current event and model output (e.g., updating the hidden state of an RNN).
* Result Emission: The prediction result is emitted downstream.
This architecture leverages the core strengths of each component:
* Apache Flink: Provides robust, scalable, and fault-tolerant state management, checkpointing, and exactly-once processing guarantees.
* ONNX Runtime: Offers a high-performance, cross-platform, and framework-agnostic engine for executing ML models, with support for CPU, GPU (CUDA), and other accelerators.
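To make the data flow concrete, here is a minimal sketch of the job wiring. It assumes a Kafka topic named transactions, a hypothetical TransactionEventDeserializer, and a getUserId() accessor on the event POJO; all of these names are illustrative, not part of any library.
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

// Kafka source for the transaction stream (topic name and deserializer are assumptions)
KafkaSource<TransactionEvent> source = KafkaSource.<TransactionEvent>builder()
    .setBootstrapServers("kafka:9092")
    .setTopics("transactions")
    .setGroupId("fraud-inference")
    .setStartingOffsets(OffsetsInitializer.latest())
    .setValueOnlyDeserializer(new TransactionEventDeserializer()) // hypothetical schema
    .build();

env.fromSource(source, WatermarkStrategy.noWatermarks(), "transactions")
    // Partition by user so all of a user's events hit the same parallel task
    .keyBy(TransactionEvent::getUserId)
    // Stateful inference operator defined in the next section
    .process(new OnnxInferenceFunction("/opt/models/fraud.onnx"))
    .print(); // replace with a real sink (Kafka, database, ...) in production

env.execute("stateful-onnx-inference");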
Implementing the Core Inference Operator
The heart of our implementation is a Flink operator that manages the ONNX model's lifecycle and executes inference. A KeyedProcessFunction is the natural choice: it provides a lifecycle (open, close) for expensive resource initialization and, crucially, it can access Flink's keyed state. (A RichAsyncFunction might look attractive for non-blocking inference, but Flink does not allow keyed state access from async functions, so it cannot hold the per-user model state this pattern needs; instead, we rely on operator parallelism to keep throughput high.)
Let's define the structure. We'll use Java for this example.
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
import ai.onnxruntime.*;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Map;
// Assuming TransactionEvent and FraudPrediction are POJOs
public class OnnxInferenceFunction extends KeyedProcessFunction<String, TransactionEvent, FraudPrediction> {
// Transient fields are not serialized during checkpointing
private transient OrtEnvironment env;
private transient OrtSession session;
private final String modelPath;
// Flink's managed state for the model's hidden state (e.g., for an RNN)
private transient ValueState<float[][]> modelHiddenState;
public OnnxInferenceFunction(String modelPath) {
this.modelPath = modelPath;
}
@Override
public void open(Configuration parameters) throws Exception {
super.open(parameters);
// 1. Initialize ONNX Runtime Environment and Session
env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
// Fine-tune performance: e.g., options.setInterOpNumThreads(1);
// To use CUDA: options.addCUDA(0);
// Load model from a path accessible by the TaskManager
byte[] modelBytes = Files.readAllBytes(Paths.get(modelPath));
session = env.createSession(modelBytes, options);
// 2. Initialize Flink State Descriptor
ValueStateDescriptor<float[][]> descriptor = new ValueStateDescriptor<>(
"modelHiddenState", // the state name
float[][].class // type information for the hidden-state array
);
modelHiddenState = getRuntimeContext().getState(descriptor);
}
@Override
public void processElement(TransactionEvent input, Context ctx, Collector<FraudPrediction> out) throws Exception {
// Inference runs synchronously in the operator; scale parallelism to increase throughput
// 1. Retrieve current state for the key (e.g., userId)
float[][] currentState = modelHiddenState.value();
if (currentState == null) {
// Initialize state for a new user
currentState = new float[1][128]; // Example: batch_size=1, hidden_size=128
}
// 2. Preprocess input and create ONNX Tensors
// This is highly model-specific. Let's assume the model takes
// the current transaction features and the previous hidden state.
float[][] features = input.toFeatureArray(); // Shape: [1][50]
try (OnnxTensor featuresTensor = OnnxTensor.createTensor(env, features);
OnnxTensor stateTensor = OnnxTensor.createTensor(env, currentState)) {
Map<String, OnnxTensor> modelInputs = Map.of(
"input_features", featuresTensor,
"input_hidden_state", stateTensor
);
// 3. Execute the model
try (OrtSession.Result results = session.run(modelInputs)) {
OnnxValue outputProbabilityValue = results.get("output_probability").get();
float[][] outputProbability = (float[][]) outputProbabilityValue.getValue();
OnnxValue nextStateValue = results.get("output_hidden_state").get();
float[][] nextState = (float[][]) nextStateValue.getValue();
// 4. Update Flink state with the new hidden state from the model
modelHiddenState.update(nextState);
// 5. Emit the result
FraudPrediction prediction = new FraudPrediction(input.getTransactionId(), outputProbability[0][0]);
out.collect(prediction);
}
} catch (OrtException e) {
// Handle inference errors: log and drop, or route to a dead-letter queue via a side output
throw e; // propagating the exception fails the job; see the error-handling section for alternatives
}
}
@Override
public void close() throws Exception {
super.close();
if (session != null) {
session.close();
}
if (env != null) {
env.close();
}
}
}
Key Implementation Details:
* Resource Lifecycle (open/close): The OrtSession and OrtEnvironment are heavyweight objects. Initializing them once per operator instance in open() is critical for performance. They are marked transient because they are not serializable and must be recreated on each TaskManager after a failure or deployment.
* Flink Managed State: We use ValueState to store the last hidden state of our recurrent model. Flink manages the persistence and fault tolerance of this state. When a checkpoint is triggered, Flink serializes this state and writes it to a durable store (like S3 or HDFS). If the job fails, it can restart and restore the state precisely to where it left off.
* Keyed State Scoping: This function must be applied to a KeyedStream. The getRuntimeContext().getState() call implicitly scopes the state to the current key. This means each userId gets its own independent modelHiddenState.
* Native Memory Management: The try-with-resources block for OnnxTensor and Result is crucial. These objects manage off-heap memory, and failing to close them will lead to memory leaks.
Deep Dive: State Backend and Checkpointing
Your choice of state backend has massive performance and operational implications.
* HashMapStateBackend: Stores state on the Java heap. It's extremely fast for reads and writes but is limited by the TaskManager's memory. It's suitable for jobs with small state or for local development, but not for production scenarios with potentially unbounded state per key.
* RocksDBStateBackend: This is the de-facto standard for production Flink jobs with large state. It stores state in an embedded RocksDB instance on the local disk of each TaskManager. State size is only limited by local disk space. All reads/writes involve serialization/deserialization and disk I/O, making it slower than the heap-based backend, but it enables state sizes far exceeding available RAM.
For our stateful inference use case, where the state per user can grow and the number of users can be in the millions, RocksDBStateBackend is the only viable option.
// In your main Flink job setup
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// Use RocksDB and configure it to store checkpoints in S3
env.setStateBackend(new RocksDBStateBackend("s3://my-flink-checkpoints/checkpoints"));
// Enable checkpointing: every 5 minutes
env.enableCheckpointing(300000);
// Set advanced checkpointing options for production
CheckpointConfig checkpointConfig = env.getCheckpointConfig();
checkpointConfig.setCheckpointingMode(CheckpointingMode.EXACTLY_ONCE);
checkpointConfig.setMinPauseBetweenCheckpoints(60000); // 1 minute
checkpointConfig.setCheckpointTimeout(600000); // 10 minutes
checkpointConfig.setMaxConcurrentCheckpoints(1);
checkpointConfig.enableExternalizedCheckpoints(CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION);
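Note: on Flink 1.13 and later, the non-deprecated equivalent separates the state backend from checkpoint storage. A minimal sketch of that style, assuming the flink-statebackend-rocksdb dependency is on the classpath:
// Flink 1.13+ style: RocksDB as the state backend, S3 as checkpoint storage
env.setStateBackend(new EmbeddedRocksDBStateBackend());
env.getCheckpointConfig().setCheckpointStorage("s3://my-flink-checkpoints/checkpoints");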
The magic of Flink's checkpointing: When a checkpoint barrier flows through our OnnxInferenceFunction, Flink takes an atomic snapshot of the modelHiddenState for all keys and writes it to the configured durable storage. This snapshot is tied to a specific offset in the input Kafka topic. If the job fails and restarts, Flink restores the operator's state from the last successful checkpoint and rewinds the Kafka source to the corresponding offset. This guarantees that no data is lost and the model's state remains perfectly synchronized with the data stream, achieving exactly-once semantics.
Production Pattern: Dynamic Model Updates via Broadcast Streams
Machine learning models are not static. They are frequently retrained and updated. A naive approach would be to stop the Flink job, deploy the new code with the new model path, and restart. This causes downtime and loses any in-flight state. A far superior pattern is to use Flink's broadcast streams to push model updates to the running operators dynamically.
The Architecture:
* A second Kafka topic, e.g. model-control-stream, is created.
* When a new model has been trained and validated, a control message is published to this topic, e.g. {"modelId": "fraud-v2.1", "path": "s3://models/fraud-v2.1.onnx"}.
* The Flink job creates a BroadcastStream from this topic.
* We connect() our main KeyedStream of transactions with this BroadcastStream.
* The operator becomes a KeyedBroadcastProcessFunction instead of the plain KeyedProcessFunction. This function can process elements from both the main stream and the broadcast stream.
Implementation Sketch:
// New operator that handles both data and broadcasted model updates
public class DynamicOnnxInferenceFunction extends KeyedBroadcastProcessFunction<String, TransactionEvent, ModelUpdate, FraudPrediction> {
// Reserved key in the broadcast state that points at the currently active model ID
private static final String ACTIVE_MODEL_KEY = "__active_model__";
// Operator-local cache of loaded sessions, keyed by modelId.
// OrtSession is not serializable, so it must NOT live in Flink state;
// the cache is simply rebuilt lazily after a restart.
private transient Map<String, OrtSession> modelSessions;
private transient OrtEnvironment env;
// Descriptor for the broadcast state (modelId -> model path, plus the active-model pointer)
private final MapStateDescriptor<String, String> modelPathDescriptor =
new MapStateDescriptor<>("modelPaths", String.class, String.class);
@Override
public void open(Configuration parameters) {
env = OrtEnvironment.getEnvironment();
modelSessions = new HashMap<>();
}
// This method processes the main transaction stream
@Override
public void processElement(TransactionEvent value, ReadOnlyContext ctx, Collector<FraudPrediction> out) throws Exception {
ReadOnlyBroadcastState<String, String> modelPaths = ctx.getBroadcastState(modelPathDescriptor);
String currentModelId = modelPaths.get(ACTIVE_MODEL_KEY);
if (currentModelId == null) {
// No model is active yet: drop, buffer, or route to a side output
return;
}
OrtSession session = modelSessions.get(currentModelId);
if (session == null) {
// The session for the active model isn't loaded yet. Load it on demand.
String path = modelPaths.get(currentModelId);
byte[] modelBytes = Files.readAllBytes(Paths.get(path));
session = env.createSession(modelBytes, new OrtSession.SessionOptions());
modelSessions.put(currentModelId, session);
}
// ... Perform inference using the retrieved session ...
// ... Update keyed hidden state (not shown for brevity) ...
}
// This method processes the model update control stream
@Override
public void processBroadcastElement(ModelUpdate value, Context ctx, Collector<FraudPrediction> out) throws Exception {
BroadcastState<String, String> modelPaths = ctx.getBroadcastState(modelPathDescriptor);
// Store the path for the new model version
modelPaths.put(value.getModelId(), value.getPath());
// You can implement different strategies here:
// 1. Immediately switch all new events to the new model.
// 2. A/B test by routing a percentage of keys to the new model.
// 3. Wait for an explicit activation message.
// Simple strategy: flip the active-model pointer; broadcast state is visible to all parallel instances.
// Note: keyed state is not accessible here, and for a real A/B test the routing logic would live in processElement.
modelPaths.put(ACTIVE_MODEL_KEY, value.getModelId());
// Advanced: Pre-load the new model here to warm it up before traffic reaches it
}
}
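Wiring the control stream into the job is then a matter of broadcasting it with the same MapStateDescriptor and connecting it to the keyed transaction stream. A minimal sketch, where modelControlSource, transactions, and predictionsSink are assumed to exist elsewhere in the job:
// Broadcast the model-control stream with the same descriptor used inside the function
MapStateDescriptor<String, String> modelPathDescriptor =
    new MapStateDescriptor<>("modelPaths", String.class, String.class);

BroadcastStream<ModelUpdate> modelBroadcast = env
    .fromSource(modelControlSource, WatermarkStrategy.noWatermarks(), "model-control-stream")
    .broadcast(modelPathDescriptor);

transactions
    .keyBy(TransactionEvent::getUserId)
    .connect(modelBroadcast)
    .process(new DynamicOnnxInferenceFunction())
    .sinkTo(predictionsSink); // assumed sink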
This pattern provides incredible flexibility. You can achieve zero-downtime model deployments, A/B test competing models in production, or perform gradual rollouts by controlling which keys use which model version.
Edge Cases and Performance Tuning
Deploying this pattern in production requires handling several critical edge cases.
1. Backpressure and Latency
If model inference cannot keep up with the event arrival rate, Flink's backpressure mechanism will naturally slow down the Kafka consumer. This is a feature, not a bug: it prevents the job from being overwhelmed.
* Monitoring: Monitor the backpressure metric in the Flink UI. High backpressure indicates your inference logic is the bottleneck.
* Optimization:
* Scale Out: Increase the parallelism of your Flink job. Since the stream is keyed, this distributes the key space across more TaskManagers.
* Hardware: Use more powerful hardware. If your model benefits from it, switch to GPU instances and configure the ONNX Runtime CUDA Execution Provider. This requires careful setup of your Flink cluster (e.g., using Docker with nvidia-container-runtime).
* ONNX Session Options: Tune SessionOptions.setInterOpNumThreads() and SessionOptions.setIntraOpNumThreads(). For Flink, which manages its own task slots, a good starting point is setting both to 1 to avoid thread contention between parallel operators running on the same TaskManager, and to rely on Flink's parallelism instead (see the sketch below).
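As a concrete starting point, the thread settings live on the SessionOptions object created in open(); a minimal sketch:
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
// Keep ONNX Runtime single-threaded per operator instance and let Flink's
// parallelism provide the concurrency (avoids thread contention on shared TaskManagers)
options.setIntraOpNumThreads(1);
options.setInterOpNumThreads(1);
// If the TaskManagers have GPUs and the CUDA execution provider is installed:
// options.addCUDA(0);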
2. State Growth and TTL
If you are keying by userId, the state will grow indefinitely as new users appear. For many use cases, state older than a certain period is irrelevant (e.g., a user's transaction history from 3 years ago is not needed for real-time fraud detection).
Flink's State Time-To-Live (TTL) feature is the solution. It allows you to automatically purge state for a key after it hasn't been accessed for a configured duration.
// In the open() method of your function:
StateTtlConfig ttlConfig = StateTtlConfig
.newBuilder(Time.days(30)) // State expires 30 days after last access
.setUpdateType(StateTtlConfig.UpdateType.OnReadAndWrite)
.setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired)
.build();
ValueStateDescriptor<float[][]> descriptor = new ValueStateDescriptor<>("modelHiddenState", float[][].class);
descriptor.enableTimeToLive(ttlConfig);
modelHiddenState = getRuntimeContext().getState(descriptor);
This simple configuration prevents unbounded state growth in your RocksDB backend, which is a critical operational concern for long-running jobs.
3. Error Handling and Dead-Letter Queues
What happens if session.run() throws an OrtException due to malformed input? Or if a model update message points to a corrupted model file?
* Robust processElement: Wrap the core inference logic in a try-catch block. On failure, you have options:
1. Log the error and drop the event (simply return without emitting a result).
2. Use Flink's side-output mechanism to route the failed TransactionEvent and the exception to a separate stream, which can then be written to a dead-letter Kafka topic for later analysis (a sketch follows below).
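A minimal sketch of the side-output approach, based on the operator built earlier; the class name, tag name, and sinks below are illustrative:
// Sketch: a variant of the inference function that routes failures to a side output
public class OnnxInferenceFunctionWithDlq extends KeyedProcessFunction<String, TransactionEvent, FraudPrediction> {

    public static final OutputTag<TransactionEvent> FAILED_EVENTS =
        new OutputTag<TransactionEvent>("failed-inference") {};

    @Override
    public void processElement(TransactionEvent input, Context ctx, Collector<FraudPrediction> out) {
        try {
            // ... run ONNX inference and out.collect(prediction) as shown earlier ...
        } catch (Exception e) {
            // Route the offending event to the side output instead of failing the job
            ctx.output(FAILED_EVENTS, input);
        }
    }
}
// In the job graph, read the side output and write it to a dead-letter Kafka topic:
// SingleOutputStreamOperator<FraudPrediction> predictions = keyed.process(new OnnxInferenceFunctionWithDlq());
// predictions.getSideOutput(OnnxInferenceFunctionWithDlq.FAILED_EVENTS).sinkTo(deadLetterSink);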
4. Benchmarking and Sizing
Before going to production, rigorously benchmark your operator. Create a test Flink job that generates mock data and runs only your inference function.
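A minimal benchmarking harness might look like the sketch below, where randomTransaction() is a hypothetical generator of synthetic events and throughput/latency are read from Flink's metrics:
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

env.fromSequence(0, 10_000_000L)
    .map(i -> randomTransaction(i))      // hypothetical mock-data generator
    .returns(TransactionEvent.class)     // help Flink's type extraction with the lambda
    .keyBy(TransactionEvent::getUserId)
    .process(new OnnxInferenceFunction("/opt/models/fraud.onnx"))
    .addSink(new DiscardingSink<>());    // discard output; only the operator under test matters

env.execute("onnx-inference-benchmark");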
* Example Results: On a typical AWS c5.2xlarge instance (8 vCPUs), a single Flink operator with a moderately complex LSTM model might achieve a throughput of 2,000-5,000 events/sec with a p99 latency of <20ms using the CPU Execution Provider. Switching to a g4dn.xlarge instance with a GPU and the CUDA provider could increase throughput by 5-10x for models that can leverage it, pushing throughput to 25,000+ events/sec per operator instance.
This data is essential for capacity planning and determining the required parallelism for your production cluster.
Conclusion
By embedding ONNX Runtime within Apache Flink, we elevate stream processing from simple ETL and analytics to a platform for sophisticated, real-time AI. This stateful streaming inference pattern moves beyond the request-response paradigm, allowing models to understand the temporal context of data natively and at massive scale. While the implementation requires careful attention to state management, resource lifecycle, and fault tolerance, the result is a highly scalable, resilient, and powerful architecture capable of solving a class of problems that are intractable with traditional stateless services. This is the future of operational AI, running continuously and intelligently, directly where the data is born.