Stateful Streaming Inference with Kafka and ONNX for Fraud Detection
The Architectural Flaw of Stateless Inference
In the realm of real-time machine learning, particularly for applications like fraud detection, a common anti-pattern is the stateless inference endpoint. A typical flow involves a microservice receiving a single event (e.g., a transaction), making a synchronous call to a feature store or database to gather historical context, and then invoking a model for a prediction. While simple, this architecture crumbles under high throughput and low-latency requirements. The database becomes a central bottleneck, and network I/O for feature retrieval introduces unacceptable delays.
Sophisticated fraud isn't detected by analyzing a single transaction in isolation. It's identified through patterns and deviations in behavior over time:
* Velocity Attacks: A user making 10 small transactions in 2 minutes from different merchants.
* Anomalous Aggregates: A sudden spike in transaction value compared to the user's 24-hour rolling average.
* Geographic Jumps: A transaction in New York followed five minutes later by one in London.
These patterns are inherently stateful. To detect them, our system needs a memory of the user's recent activity. The core challenge is maintaining and accessing this state at the speed of the event stream. This is where traditional architectures fail and a stream-native approach excels.
This article details a production-ready pattern that co-locates state and computation. We will build a fraud detection application using Kafka Streams for stateful stream processing and the ONNX Runtime (ORT) for high-performance, decoupled model inference. By managing state directly on the stream processing nodes within RocksDB (managed by Kafka Streams), we eliminate external database calls during the inference hot path, achieving millisecond-level processing latency.
Core Components: A Synergistic Pair
Kafka Streams for Stateful Processing
Kafka Streams is not just a library for reading from and writing to Kafka topics. Its power lies in the Processor API, which provides mechanisms for stateful operations. The key component for our use case is the StateStore.
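A StateStore is a fault-tolerant key-value store, typically backed by a local RocksDB instance on each application node. Kafka Streams makes it resilient by maintaining a compacted changelog topic in Kafka: updates to the local StateStore are also written to this topic. If a node fails, another node restores the state by replaying the changelog, so no committed state is lost. Exactly-once semantics are a separate, opt-in guarantee: with processing.guarantee set to exactly_once_v2 (the default is at_least_once), changelog writes, output records, and consumer offsets are committed in a single Kafka transaction.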
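To make the restore step concrete, here is a toy, in-memory model of changelog replay. It is a hypothetical illustration, not Kafka Streams' actual restore code (RocksDB, offsets, and batching are left out): records are applied in log order, each one overwrites the previous value for its key, and a null value acts as a tombstone that deletes the key.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Toy model of changelog restoration (hypothetical; not Kafka Streams internals).
// Replaying a key/value log in order rebuilds the latest value per key.
class ChangelogReplay {

    // One changelog record; a null value is a tombstone meaning "key deleted".
    record Entry(String key, String value) {}

    // Apply every record in log order to an empty store.
    static Map<String, String> restore(List<Entry> changelog) {
        Map<String, String> store = new HashMap<>();
        for (Entry e : changelog) {
            if (e.value() == null) {
                store.remove(e.key());          // tombstone: drop the key
            } else {
                store.put(e.key(), e.value());  // later writes overwrite earlier ones
            }
        }
        return store;
    }
}
```

Kafka Streams creates changelog topics for key-value stores with log compaction enabled, so the broker may discard superseded records for a key. As the model suggests, replaying only the latest record per key rebuilds the same store, which is why compaction does not lose state.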
For our fraud detection system, we will use a KeyValueStore where the key is the userId and the value is a serialized object containing aggregated features like:
* transactionCountLast5Min
* totalSpendLastHour
* uniqueMerchantsLast24Hours
* lastTransactionTimestamp
This lets us enrich each incoming transaction with historical context through a lookup against the local RocksDB instance, typically in microseconds and with no network round trip.
ONNX Runtime for Decoupled High-Performance Inference
While our stream processing logic is in Java (a common choice for Kafka Streams), our models are likely trained in Python using frameworks like PyTorch, TensorFlow, or Scikit-learn. Deploying a Python model server and making RPC calls from our Java application reintroduces the network latency we're trying to avoid. Furthermore, Python's Global Interpreter Lock (GIL) can be a bottleneck in high-concurrency scenarios.
The Open Neural Network Exchange (ONNX) format solves this. It's an open standard for representing machine learning models. We can train a model in Python, export it to the .onnx format, and then load it directly into our Java application using the ONNX Runtime (ORT).
ORT is a high-performance inference engine with a Java API. It allows us to execute the model in the same process as our Kafka Streams application, eliminating network overhead. ORT is highly optimized, supporting various execution providers (e.g., CPU, CUDA, TensorRT) to leverage hardware acceleration.
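This decoupling is powerful: data scientists can iterate on models in Python, and engineers can deploy them into a high-performance Java streaming application without code changes, simply by providing a new model file, as long as the new model keeps the same input and output signature (tensor names, shapes, element types, and feature order).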
Production Implementation: The Fraud Detection Stream Processor
Let's build the application. We'll assume we have a Kafka topic named transactions with JSON-serialized transaction events.
Data Schemas (Conceptual)
// Input: transactions topic
{
"transactionId": "uuid",
"userId": "string",
"amount": 123.45,
"merchantId": "string",
"countryCode": "US",
"timestamp": "long"
}
// Internal State: UserProfile (serialized in StateStore)
{
"userId": "string",
"windowedTransactionCounts": {"5m": 10, "1h": 50},
"windowedTransactionAmounts": {"5m": 1200.0, "1h": 8000.0},
"lastTransactionTimestamp": "long",
"seenCountryCodes": ["US", "GB"]
}
// Output: fraud-alerts topic
{
"transactionId": "uuid",
"userId": "string",
"isFraudulent": true,
"fraudScore": 0.98,
"reason": "High transaction velocity and anomalous amount"
}
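The transformer in Step 2 calls getLastTransactionTimestamp(), addTransaction(...), getTransactionCount("5m") and getAverageSpend("1h") on UserProfile, but the class itself is not shown. The sketch below is one hypothetical way to back those calls. It stores timestamped entries rather than the pre-aggregated buckets in the conceptual schema above, because a bucket such as "5m": 10 cannot be decremented when an old transaction ages out unless you also know when each transaction happened. To stay self-contained, addTransaction takes the timestamp and amount directly (the transformer would pass transaction.getTimestamp() and transaction.getAmount()), and serialization for the store's serde is omitted.

```java
import java.time.Duration;
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical UserProfile: keeps timestamped entries (not pre-aggregated buckets)
// so windowed aggregates stay correct as old transactions age out.
class UserProfile {
    // Longest window any feature uses; older entries are evicted.
    private static final long RETENTION_MS = Duration.ofHours(1).toMillis();

    private record Entry(long timestampMs, double amount) {}

    private final String userId;
    private final Deque<Entry> recent = new ArrayDeque<>();
    private long lastTransactionTimestamp = 0L; // 0 means "no transaction seen yet"

    UserProfile(String userId) {
        this.userId = userId;
    }

    String getUserId() {
        return userId;
    }

    long getLastTransactionTimestamp() {
        return lastTransactionTimestamp;
    }

    // Record a transaction, then evict entries older than the retention period.
    // Assumes roughly in-order events; a late event is appended without reordering.
    void addTransaction(long timestampMs, double amount) {
        recent.addLast(new Entry(timestampMs, amount));
        lastTransactionTimestamp = Math.max(lastTransactionTimestamp, timestampMs);
        long cutoff = lastTransactionTimestamp - RETENTION_MS;
        while (!recent.isEmpty() && recent.peekFirst().timestampMs() < cutoff) {
            recent.removeFirst();
        }
    }

    // Transactions in the window (latest timestamp - window, latest timestamp].
    long getTransactionCount(String window) {
        long from = lastTransactionTimestamp - windowMillis(window);
        return recent.stream().filter(e -> e.timestampMs() > from).count();
    }

    // Mean amount in the same window, or 0.0 if it is empty.
    double getAverageSpend(String window) {
        long from = lastTransactionTimestamp - windowMillis(window);
        return recent.stream()
                .filter(e -> e.timestampMs() > from)
                .mapToDouble(Entry::amount)
                .average()
                .orElse(0.0);
    }

    // Parses labels such as "5m" or "1h"; a window longer than RETENTION_MS
    // only sees the entries that are still retained.
    private static long windowMillis(String window) {
        long n = Long.parseLong(window.substring(0, window.length() - 1));
        switch (window.charAt(window.length() - 1)) {
            case 'm': return Duration.ofMinutes(n).toMillis();
            case 'h': return Duration.ofHours(n).toMillis();
            default: throw new IllegalArgumentException("Unsupported window: " + window);
        }
    }
}
```

Scanning the deque on each call is linear in the number of retained transactions for one user, which is usually small for a one-hour horizon.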
Step 1: Setting up the Kafka Streams Topology
The foundation is a Topology that defines our data flow and connects our custom processor to a state store.
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Produced;
import org.apache.kafka.streams.kstream.Repartitioned;
import org.apache.kafka.streams.state.Stores;
import java.util.Properties;
public class FraudDetectionApplication {
// Package-private so FraudDetectionTransformer can look the store up by name
static final String USER_PROFILE_STORE_NAME = "user-profile-store";
public static void main(String[] args) {
Properties props = getStreamsConfig(); // Standard Kafka Streams properties (helper not shown)
// The path to our pre-trained ONNX model file
String onnxModelPath = System.getenv("ONNX_MODEL_PATH");
Topology topology = buildTopology(onnxModelPath);
KafkaStreams streams = new KafkaStreams(topology, props);
streams.start();
Runtime.getRuntime().addShutdownHook(new Thread(streams::close));
}
public static Topology buildTopology(String onnxModelPath) {
StreamsBuilder builder = new StreamsBuilder();
// Create a persistent KeyValue store for user profiles.
// RocksDB is the default implementation.
builder.addStateStore(Stores.keyValueStoreBuilder(
Stores.persistentKeyValueStore(USER_PROFILE_STORE_NAME),
Serdes.String(),
new JsonSerde<UserProfile>() // Custom Serde for our UserProfile class (not shown)
));
builder.stream("transactions", Consumed.with(Serdes.String(), new JsonSerde<Transaction>()))
// Re-key by userId so each user's events and state belong to the same task
.selectKey((key, transaction) -> transaction.getUserId())
// transform() does not repartition automatically after a key change,
// so route the re-keyed records through an internal topic explicitly
.repartition(Repartitioned.with(Serdes.String(), new JsonSerde<Transaction>()))
// The core logic is in our custom Transformer
// (transform() is deprecated since Kafka 3.3 in favour of process())
.transform(
() -> new FraudDetectionTransformer(onnxModelPath),
USER_PROFILE_STORE_NAME
)
.to("fraud-alerts", Produced.with(Serdes.String(), new JsonSerde<FraudAlert>()));
return builder.build();
}
}
Two critical details here:
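* selectKey((key, transaction) -> transaction.getUserId()): We re-key the stream by userId. Each partition is processed by exactly one stream task, so records with the same key reach the same task only if they sit in the same partition, which requires the topic to be partitioned by that key. Because transform() does not trigger automatic repartitioning after a key change, the re-keyed stream must be routed through an internal repartition topic (repartition() in the DSL) unless the transactions topic is already keyed by userId. Once that holds, a user's transactions are always processed by the task that holds their state in its local StateStore, avoiding any network hops for state access.
* transform(...): We use the transform operator, which allows arbitrary stateful processing, and connect it to USER_PROFILE_STORE_NAME so the store can be accessed within the transformer.
Step 2: The Stateful Transformer with ONNX Inference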
The FraudDetectionTransformer class implements the Transformer interface and does the per-event work: it loads the ONNX model, reads the user's profile from the state store, derives features, updates the profile, and runs inference.
import ai.onnxruntime.*;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.kstream.Transformer;
import org.apache.kafka.streams.processor.ProcessorContext;
import org.apache.kafka.streams.state.KeyValueStore;
import java.util.Collections;
public class FraudDetectionTransformer implements Transformer<String, Transaction, KeyValue<String, FraudAlert>> {
private final String onnxModelPath;
private ProcessorContext context;
private KeyValueStore<String, UserProfile> userProfileStore;
private OrtEnvironment env;
private OrtSession session;
private String inputName;
public FraudDetectionTransformer(String onnxModelPath) {
this.onnxModelPath = onnxModelPath;
}
@Override
public void init(ProcessorContext context) {
this.context = context;
this.userProfileStore = (KeyValueStore<String, UserProfile>) context.getStateStore(FraudDetectionApplication.USER_PROFILE_STORE_NAME);
// Initialize ONNX Runtime
try {
this.env = OrtEnvironment.getEnvironment(); // process-wide singleton
this.session = env.createSession(onnxModelPath, new OrtSession.SessionOptions());
// Resolve the model's input name once instead of on every record
this.inputName = session.getInputNames().iterator().next();
System.out.println("ONNX model loaded successfully.");
// You can print model input/output info for debugging
this.session.getInputInfo().forEach((key, value) -> System.out.println("Input: " + key + " -> " + value.getInfo().toString()));
} catch (OrtException e) {
throw new RuntimeException("Failed to initialize ONNX Runtime session", e);
}
}
@Override
public KeyValue<String, FraudAlert> transform(String userId, Transaction transaction) {
// 1. Retrieve user's current state
UserProfile userProfile = userProfileStore.get(userId);
if (userProfile == null) {
userProfile = new UserProfile(userId);
}
// 2. Derive stateful features from the profile before this transaction is added
long timeSinceLastTx = (userProfile.getLastTransactionTimestamp() == 0) ? -1 : transaction.getTimestamp() - userProfile.getLastTransactionTimestamp();
double avgAmountLastHour = userProfile.getAverageSpend("1h");
double amountDeviation = transaction.getAmount() - avgAmountLastHour;
// 3. Update the user's state with the new transaction
userProfile.addTransaction(transaction);
userProfileStore.put(userId, userProfile); // Write back to the state store
// 4. Construct the feature vector for the model
// The order and shape must exactly match the model's expectations.
// Assuming the model expects a float tensor of shape [1, 5]
float[][] featureVector = new float[1][5];
featureVector[0][0] = (float) transaction.getAmount();
featureVector[0][1] = (float) amountDeviation;
featureVector[0][2] = (float) timeSinceLastTx;
featureVector[0][3] = (float) userProfile.getTransactionCount("5m"); // now includes the current transaction
featureVector[0][4] = (float) countryCodeToFloat(transaction.getCountryCode()); // Example feature engineering (helper not shown)
// 5. Run inference; try-with-resources closes the native tensor and results
try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, featureVector);
OrtSession.Result results = session.run(Collections.singletonMap(inputName, inputTensor))) {
// Assuming output 0 is a [1, 2] float tensor: [prob_not_fraud, prob_fraud]
float[][] fraudProbabilities = (float[][]) results.get(0).getValue();
float fraudScore = fraudProbabilities[0][1];
// 6. Emit an alert if score exceeds threshold
if (fraudScore > 0.9f) {
FraudAlert alert = new FraudAlert(transaction.getTransactionId(), userId, true, fraudScore, "High fraud score detected");
return KeyValue.pair(userId, alert);
}
} catch (OrtException e) {
// Log error, but don't crash the stream
System.err.println("Error during ONNX inference for user " + userId + ": " + e.getMessage());
}
// No alert to emit; returning null forwards nothing downstream
return null;
}
@Override
public void close() {
// Close this task's session. The OrtEnvironment is a singleton shared by
// every transformer instance in the JVM, so it is not closed here.
try {
if (session != null) session.close();
} catch (OrtException e) {
System.err.println("Error closing ONNX resources: " + e.getMessage());
}
}
}
This transform method is the heart of our system. For every transaction it performs a read-compute-write cycle against the local state store and runs the model in the same thread and process, so the hot path involves no synchronous network call; the changelog and output writes are batched asynchronously by the Kafka producer.
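Because the supplier passed to transform() (() -> new FraudDetectionTransformer(onnxModelPath)) creates one transformer per stream task, each task loads its own copy of the model into memory. ONNX Runtime supports concurrent run() calls on one session, so one option is to share a single session across all tasks in the JVM. Below is a hypothetical reference-counted holder (SharedResource is not a Kafka or ONNX Runtime class): init() would call acquire() and close() would call release(). A Supplier cannot throw checked exceptions, so a factory that calls env.createSession(...) would need to wrap OrtException in an unchecked exception.

```java
import java.util.function.Supplier;

// Hypothetical reference-counted holder: the first acquire() creates the resource
// and the last release() closes it, so N transformer instances share one object.
final class SharedResource<T extends AutoCloseable> {
    private final Supplier<T> factory;
    private T instance;
    private int refCount;

    SharedResource(Supplier<T> factory) {
        this.factory = factory;
    }

    // Call from Transformer.init(): creates the resource on first use.
    synchronized T acquire() {
        if (refCount == 0) {
            instance = factory.get();
        }
        refCount++;
        return instance;
    }

    // Call from Transformer.close(): closes the resource when the last user releases it.
    synchronized void release() {
        if (refCount == 0) {
            throw new IllegalStateException("release() called without a matching acquire()");
        }
        refCount--;
        if (refCount == 0) {
            T toClose = instance;
            instance = null;
            try {
                toClose.close();
            } catch (Exception e) {
                throw new IllegalStateException("Failed to close shared resource", e);
            }
        }
    }
}
```

In practice this would live in a static field such as static final SharedResource<OrtSession> SHARED_SESSION, with each transformer keeping the returned session for its lifetime.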
Advanced Patterns and Edge Case Handling
A proof-of-concept is one thing; a production system is another. Here are critical considerations for making this architecture robust.
Model Management: Hot-Swapping without Downtime
ML models are not static. Data science teams will produce new, improved versions that must be deployed without stopping the world. A naive approach of restarting the Kafka Streams application causes processing delays and state-store rebalancing overhead.
A robust solution is to use a control topic.
* Store your ONNX models in a versioned object store like AWS S3 or Google Cloud Storage (e.g., s3://my-models/fraud-model-v1.2.onnx).
* Create a dedicated, low-volume Kafka topic called model-update-commands.
* When a new model is ready for deployment, an operator or a CI/CD pipeline publishes a message to this topic, e.g., {"modelName": "fraud-detector", "newVersionPath": "s3://my-models/fraud-model-v1.3.onnx"}.
* Modify the FraudDetectionApplication to consume from this control topic using a separate GlobalKTable. A GlobalKTable replicates the data from a topic to every instance of your application, which makes it a natural fit for broadcasting configuration changes.
* In your FraudDetectionTransformer, hold the session in an AtomicReference<OrtSession> instead of a plain OrtSession field.
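* A GlobalKTable does not expose a change callback, so the update hook needs another mechanism: for example, register a global store with StreamsBuilder#addGlobalStore, whose processor is invoked for every record on the control topic, or poll the table's store periodically with a punctuator. The hook downloads the new model file from S3 to a temporary local path, creates a new OrtSession, and then atomically swaps the reference in the AtomicReference.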
Code Snippet (Conceptual)
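// In the Transformer (or in a holder shared by all transformer instances)
private final AtomicReference<OrtSession> sessionRef = new AtomicReference<>();
// In the application setup; ModelUpdateCommand and its serde are application classes
GlobalKTable<String, ModelUpdateCommand> modelUpdates = builder.globalTable(
"model-update-commands",
Consumed.with(Serdes.String(), new JsonSerde<ModelUpdateCommand>()));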
// In the transform() method, get the session for each message
OrtSession currentSession = sessionRef.get();
if (currentSession != null) {
// ... run inference
}
// A separate mechanism (e.g., a global-store processor or a punctuator)
// observes updates from the control topic and executes this logic:
private void updateModel(String newModelPath) {
try {
// Download the new model from S3/GCS to a local temp file (hypothetical helper)
Path localPath = downloadModel(newModelPath);
OrtSession newSession = env.createSession(localPath.toString(), new OrtSession.SessionOptions());
// Atomically swap the session
OrtSession oldSession = sessionRef.getAndSet(newSession);
// Close the old session after a delay so in-flight inference calls can finish
if (oldSession != null) {
// Hypothetical helper, e.g., a scheduled executor task that closes it after ~30 seconds
closeOldSessionGracefully(oldSession);
}
} catch (Exception e) {
// Log the error; the existing model stays in place if the new one fails to load
}
}
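This pattern enables zero-downtime model deployments. The reference swap is atomic, so each record is scored by either the old model or the new one, never a half-loaded session, and subsequent records use the new model immediately. Two caveats: closing the old session after a fixed delay assumes no in-flight inference outlasts that delay, and because each instance applies the update independently, different instances may briefly run different model versions.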
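The conceptual snippet above leaves the holder and the delayed close as placeholders. The following is a minimal, runnable sketch of that shape using only the JDK, assuming the model type is any AutoCloseable (OrtSession qualifies). HotSwappableModel is a hypothetical name, and the close delay is the same heuristic described above.

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical hot-swap holder: readers call current(), an updater calls swap(),
// and the previous model is closed after a fixed delay (a heuristic that assumes
// no in-flight inference call outlasts closeDelayMs).
final class HotSwappableModel<T extends AutoCloseable> implements AutoCloseable {
    private final AtomicReference<T> ref;
    private final long closeDelayMs;
    private final ScheduledExecutorService closer =
            Executors.newSingleThreadScheduledExecutor(r -> {
                Thread t = new Thread(r, "model-closer");
                t.setDaemon(true); // do not keep the JVM alive for pending closes
                return t;
            });

    HotSwappableModel(T initial, long closeDelayMs) {
        this.ref = new AtomicReference<>(initial);
        this.closeDelayMs = closeDelayMs;
    }

    // Hot path: a single atomic read, no locking.
    T current() {
        return ref.get();
    }

    // Publish the new model atomically, then schedule the old one for closing.
    void swap(T next) {
        T previous = ref.getAndSet(next);
        if (previous != null) {
            closer.schedule(() -> closeQuietly(previous), closeDelayMs, TimeUnit.MILLISECONDS);
        }
    }

    private static void closeQuietly(AutoCloseable resource) {
        try {
            resource.close();
        } catch (Exception e) {
            System.err.println("Failed to close old model: " + e.getMessage());
        }
    }

    // Stop accepting new close tasks (already scheduled ones still run by default)
    // and close the model that is currently live.
    @Override
    public void close() {
        closer.shutdown();
        T last = ref.getAndSet(null);
        if (last != null) {
            closeQuietly(last);
        }
    }
}
```

A reference-counting scheme would remove the timing assumption at the cost of bookkeeping on every inference call; the delayed close keeps the hot path to a single atomic read.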
Handling Late-Arriving Data
In any distributed system, events can arrive out of order. A transaction from 5 minutes ago might arrive after a more recent one. If we naively update our state, our aggregates (like transactionCountLast5Min) will become incorrect. Kafka Streams has a first-class solution for this: Windowing with Grace Periods.
While our Transformer example uses a manual approach, for time-based aggregations, using Kafka Streams' built-in windowing is often more robust.
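import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.kstream.TimeWindows;
import org.apache.kafka.streams.state.WindowStore;
import java.time.Duration;
// ... inside buildTopology()
builder.stream("transactions", ...)
// groupBy() re-keys by userId and repartitions automatically before aggregating
.groupBy((key, transaction) -> transaction.getUserId(),
Grouped.with(Serdes.String(), new JsonSerde<Transaction>()))
// 5-minute tumbling windows that accept late records for 1 extra minute
.windowedBy(TimeWindows.ofSizeAndGrace(Duration.ofMinutes(5), Duration.ofMinutes(1)))
.aggregate(
() -> 0L, // Initializer
(userId, transaction, aggregate) -> aggregate + 1, // Aggregator
Materialized.<String, Long, WindowStore<Bytes, byte[]>>as("5-min-counts")
.withKeySerde(Serdes.String())
.withValueSerde(Serdes.Long())
);
* TimeWindows.ofSizeAndGrace(Duration.ofMinutes(5), Duration.ofMinutes(1)): This groups transactions into 5-minute, non-overlapping (tumbling) windows based on each record's event timestamp, as returned by the configured TimestampExtractor. Older code writes this as TimeWindows.of(...).grace(...), which is deprecated since Kafka 3.0.
* The grace period is the critical part. It tells Kafka Streams to keep accepting records for a window for an extra minute after the window's end, measured in stream time (the highest event timestamp observed so far), not wall-clock time. If a transaction that belongs to that window arrives late but within the grace period, it is still included in the correct window's aggregate. Records arriving after the grace period are dropped, so late data cannot corrupt windows that have already closed.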
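The following toy simulation, a hypothetical sketch rather than Kafka Streams code, shows how tumbling windows with a grace period decide whether a late record is counted. It tracks stream time as the largest event timestamp seen so far and drops a record once stream time has reached its window's end plus the grace period.

```java
import java.util.Map;
import java.util.TreeMap;

// Toy model of tumbling-window counting with a grace period, driven by stream
// time (the largest event timestamp seen so far). Hypothetical; not Kafka Streams code.
class GraceWindowCounter {
    private final long sizeMs;
    private final long graceMs;
    private long streamTime = Long.MIN_VALUE;
    private final Map<Long, Long> countsByWindowStart = new TreeMap<>();
    private long dropped = 0;

    GraceWindowCounter(long sizeMs, long graceMs) {
        this.sizeMs = sizeMs;
        this.graceMs = graceMs;
    }

    // Returns true if the event was counted, false if its window had already closed.
    boolean add(long eventTimeMs) {
        streamTime = Math.max(streamTime, eventTimeMs);
        long windowStart = eventTimeMs - Math.floorMod(eventTimeMs, sizeMs);
        long windowEnd = windowStart + sizeMs; // exclusive end
        // A window closes once stream time reaches its end plus the grace period.
        if (streamTime >= windowEnd + graceMs) {
            dropped++;
            return false;
        }
        countsByWindowStart.merge(windowStart, 1L, Long::sum);
        return true;
    }

    long count(long windowStart) {
        return countsByWindowStart.getOrDefault(windowStart, 0L);
    }

    long droppedCount() {
        return dropped;
    }
}
```

With 5-minute windows and a 1-minute grace period, a record stamped 4:00 is still counted after stream time has advanced to 5:30, but a record stamped 3:00 is dropped once stream time reaches 7:00, because its window [0:00, 5:00) closed at 6:00.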
Performance Tuning and Optimization
To meet strict latency SLAs, both Kafka Streams and ONNX Runtime must be tuned.
Kafka Streams Tuning (in streams.properties):
* num.stream.threads: Increase this up to the number of cores available to the instance, allowing parallel processing of different topic partitions. Each stream task is processed by a single thread, so threads beyond the number of tasks assigned to the instance sit idle.
* commit.interval.ms: The default is 30000 ms under at-least-once processing (100 ms with exactly-once). For faster recovery, reduce this to 1000 or even 100. Committing offsets more often adds overhead but reduces the amount of data that needs to be reprocessed after a failure.
* rocksdb.config.setter: Implement a custom RocksDBConfigSetter class to tune the underlying state store. You can configure block cache size, write buffer sizes, and enable bloom filters to optimize read performance, which is crucial for our userProfileStore.get(userId) call.
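The main() method earlier calls a getStreamsConfig() helper that is not shown. One hypothetical shape for it, using the documented property names from this section, is sketched below. The values are illustrative starting points rather than tuned recommendations, the bootstrap address and core count are passed in as assumptions, and the RocksDB setter class name is a placeholder.

```java
import java.util.Properties;

// Hypothetical getStreamsConfig(): documented Kafka Streams property names,
// illustrative (untuned) values.
class StreamsConfigSketch {
    static Properties getStreamsConfig(String bootstrapServers, int cores) {
        Properties props = new Properties();
        props.put("application.id", "fraud-detection");      // also prefixes internal topic names
        props.put("bootstrap.servers", bootstrapServers);
        props.put("num.stream.threads", String.valueOf(cores)); // extra threads idle if > tasks
        props.put("commit.interval.ms", "1000");
        // exactly_once_v2 needs Kafka Streams 3.0+ and brokers 2.5+
        props.put("processing.guarantee", "exactly_once_v2");
        // Placeholder class name for a custom RocksDBConfigSetter implementation:
        // props.put("rocksdb.config.setter", "com.example.FraudRocksDbConfig");
        return props;
    }
}
```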
ONNX Runtime Tuning:
* Execution Providers: By default, ORT uses the standard CPU provider. If GPUs are available and you deploy the GPU-enabled ONNX Runtime package (the onnxruntime_gpu Maven artifact instead of onnxruntime), you can configure the CUDA or TensorRT execution providers for significant speedups.
// In the Transformer's init() method
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
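// Register the CUDA provider on device 0. addCUDA throws an OrtException if CUDA is unavailable (e.g., with the CPU-only build); operators the CUDA provider cannot run are assigned to the CPU provider.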
opts.addCUDA(0);
this.session = env.createSession(onnxModelPath, opts);
* Intra-Op Parallelism: ORT can use multiple threads for a single inference operation. This can be configured in SessionOptions.
opts.setIntraOpNumThreads(4); // Use 4 threads for one inference call
Be careful with this setting in a multi-threaded Kafka Streams application. If num.stream.threads * intraOpNumThreads exceeds the available cores, the threads contend for CPU and tail latency suffers. It's often better to run more stream threads with single-threaded inference calls.
* Graph Optimizations: ORT can apply optimizations (e.g., operator fusion) to the model graph. You can set the optimization level.
opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT); // Java name for ORT_ENABLE_ALL
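If you do enable intra-op parallelism, one simple way to respect the thread budget from the intra-op bullet is to divide the available cores by the number of stream threads. The helper below is a hypothetical sketch of that arithmetic; it ignores other threads in the process (GC, the Kafka clients' network threads), so treat its result as a ceiling rather than a target.

```java
// Hypothetical helper: split the cores so that
// numStreamThreads * intraOpThreads does not exceed availableCores.
class ThreadBudget {
    static int intraOpThreads(int availableCores, int numStreamThreads) {
        if (availableCores < 1 || numStreamThreads < 1) {
            throw new IllegalArgumentException("cores and stream threads must be >= 1");
        }
        // Integer division rounds down; never go below one thread per inference call.
        return Math.max(1, availableCores / numStreamThreads);
    }
}
```

In init() this could feed opts.setIntraOpNumThreads(ThreadBudget.intraOpThreads(Runtime.getRuntime().availableProcessors(), numStreamThreads)), where numStreamThreads is the value configured for num.stream.threads.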
Scalability and Fault Tolerance by Design
The beauty of this architecture is that scalability and fault tolerance are inherent properties of Kafka and Kafka Streams.
* Horizontal Scalability: To handle more traffic, you simply launch more instances of your application. Kafka's consumer group protocol automatically rebalances the transactions topic partitions among the available instances. When a partition is moved to a new instance, Kafka Streams ensures its corresponding state store (and the data within it) is also migrated by restoring from the changelog topic.
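* Fault Tolerance: If an application instance crashes, the partitions it was processing are reassigned to the remaining healthy instances. Those instances restore the state for the newly assigned partitions from the changelog topics on the Kafka brokers and resume processing from the last committed offsets. Under the default at-least-once guarantee, records processed after the last commit are processed again; with exactly_once_v2, their effects are applied exactly once. Within the application tier there is no single point of failure, although restoring a large store from its changelog takes time.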
* Interactive Queries: Kafka Streams allows you to directly query the local state stores of your application instances. Each instance holds only the users whose partitions it owns, so a REST endpoint on each instance first uses KafkaStreams#queryMetadataForKey to find the instance that hosts a given userId, then either reads its local RocksDB store or forwards the request. Kafka Streams provides this metadata but not the RPC layer, which you implement yourself. This is a powerful pattern for building dashboards or letting manual review processes see the real-time state of any user without querying a separate database.
Conclusion
By embedding state management and model inference directly into our stream processing application, we've built a system that is not only extremely performant but also scalable, resilient, and operationally simpler than a complex microservice architecture with multiple external dependencies. This stateful streaming inference pattern, combining the state management power of Kafka Streams with the portable performance of ONNX Runtime, represents a modern, robust approach to building next-generation, real-time AI systems. It moves beyond simple, stateless predictions and enables the detection of complex, temporal patterns that are critical for domains like fraud detection, real-time recommendations, and predictive maintenance.