Stateful Feature Engineering in Flink for Real-time Fraud Detection

Goh Ling Yong
Technology enthusiast and software architect specializing in AI-driven development tools and modern software engineering practices. Passionate about the intersection of artificial intelligence and human creativity in building tomorrow's digital solutions.

The Latency Wall in Fraud Detection: Why Batch Features Fail

In modern financial and e-commerce systems, fraud detection is a millisecond game. Sophisticated fraudulent patterns, such as rapid-fire card testing or account takeovers, unfold in seconds. Traditional ML systems relying on batch-computed features are fundamentally incapable of reacting in this timeframe. By the time a nightly Spark job calculates that a credit card was used 50 times in the last hour, the damage is already done.

The challenge is to compute features as events happen. We need to answer questions like:

  • How many times has this credit card been used in the last 2 minutes?
  • Is the current transaction amount more than 3 standard deviations above the user's 30-day rolling average?
  • How many distinct user accounts have used this device ID in the last hour?

These are not simple, stateless transformations. They require maintaining and updating state over time for millions of entities (users, cards, devices) concurrently. This is the domain of stateful stream processing, and Apache Flink is the definitive tool for building such systems at scale.

This article is not an introduction to Flink. It assumes you understand the basics of the DataStream API, sources, sinks, and simple windowing. We will dive directly into the advanced patterns required to build a robust, production-ready feature engineering pipeline for real-time fraud detection.

We will build a pipeline that ingests a stream of Transaction events and outputs a stream of enriched FeatureVector events, ready for a downstream model inference service. We'll focus on the heart of the problem: implementing complex, stateful feature extractors.

Data Model and Pipeline Setup

First, let's define our core data structures. We'll work with a stream of financial transactions.

java
// Transaction.java
public class Transaction {
    public String transactionId;
    public String userId;
    public String cardId;
    public String deviceId;
    public double amount;
    public long timestamp; // Event time in milliseconds

    // Constructors, getters, setters...
}

// FeatureVector.java
public class FeatureVector {
    public String transactionId;
    public double amount;

    // --- Features --- //
    // Feature 1: Short-term velocity
    public long cardTxCountLast2Min;

    // Feature 2: User behavioral deviation
    public double userAvgTxAmountLast30Days;
    public double amountDeviationRatio;

    // Feature 3: Entity association
    public long distinctUsersForDeviceLast1Hour;

    // Constructors, getters, setters...
}

Our Flink job will consume Transaction objects from a Kafka topic, process them through a series of stateful operators, and emit FeatureVector objects to another Kafka topic.

java
// FraudDetectionJob.java (Skeleton)
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

// Production config: EventTime, Checkpointing, State Backend
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); // event time is the default (and this call is deprecated) from Flink 1.12 onward
env.enableCheckpointing(60000); // Checkpoint every 60 seconds
env.getCheckpointConfig().setCheckpointingMode(CheckpointingMode.EXACTLY_ONCE);

// For large state, RocksDB is essential
env.setStateBackend(new RocksDBStateBackend("hdfs:///flink/checkpoints"));

DataStream<Transaction> transactionStream = env
    .addSource(new FlinkKafkaConsumer<>("transactions", new TransactionSchema(), kafkaProps))
    .assignTimestampsAndWatermarks(
        // We'll discuss watermark strategies later
        WatermarkStrategy.<Transaction>forBoundedOutOfOrderness(Duration.ofSeconds(20))
            .withTimestampAssigner((event, timestamp) -> event.timestamp)
    );

// The core logic will go here

// Sink to Kafka for model inference service
DataStream<FeatureVector> featureStream = ... // Result of our processing
featureStream.addSink(new FlinkKafkaProducer<>("feature_vectors", new FeatureVectorSchema(), kafkaProps));

env.execute("Real-time Fraud Feature Engineering");

This setup includes critical production configurations: event time processing, exactly-once checkpointing, and the RocksDB state backend for handling potentially massive state that won't fit in memory.
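
The skeleton above references a TransactionSchema and kafkaProps that are not shown. Below is a minimal sketch of what they might look like, assuming JSON-encoded records deserialized with Jackson; the bootstrap server and group id are placeholders, and any serialization format works just as well.

java
// Hypothetical Kafka properties -- values are placeholders.
Properties kafkaProps = new Properties();
kafkaProps.setProperty("bootstrap.servers", "kafka:9092");
kafkaProps.setProperty("group.id", "fraud-feature-engineering");

// Maps raw Kafka bytes to Transaction objects. Assumes Transaction has a
// public no-arg constructor so Jackson can instantiate it.
public class TransactionSchema implements DeserializationSchema<Transaction> {

    private static final ObjectMapper MAPPER = new ObjectMapper();

    @Override
    public Transaction deserialize(byte[] message) throws IOException {
        return MAPPER.readValue(message, Transaction.class);
    }

    @Override
    public boolean isEndOfStream(Transaction nextElement) {
        return false; // the transaction stream is unbounded
    }

    @Override
    public TypeInformation<Transaction> getProducedType() {
        return TypeInformation.of(Transaction.class);
    }
}

FeatureVectorSchema on the producer side would be the mirror image, implementing SerializationSchema<FeatureVector>.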

Feature 1: High-Frequency Velocity with `KeyedProcessFunction`

A simple tumbling window could count transactions per minute, but this is often too coarse. Fraudsters operate on a timescale of seconds. A fixed window (.window(TumblingEventTimeWindows.of(Time.minutes(2)))) also introduces artificial latency; a result is only emitted at the end of the window.

We need a more responsive pattern: for each transaction, we want to know the count of transactions for that cardId in the preceding 2 minutes. This calls for a KeyedProcessFunction, Flink's most powerful low-level operator, giving us direct access to state and timers.

The Pattern:

  • Key the stream by cardId.
  • Use a KeyedProcessFunction to process each transaction.
  • For each incoming transaction, add its timestamp to a ListState. This list stores the timestamps of recent transactions for the current card.
    • Prune the list by removing timestamps older than our 2-minute look-back window.
    • The size of the pruned list is our feature value.
  • Use event-time timers (registered via ctx.timerService()) to periodically clean up state for cards that are no longer active, preventing unbounded state growth. This is a crucial optimization.

    java
    // In the main job...
    DataStream<Tuple2<String, Long>> cardVelocityStream = transactionStream
        .keyBy(t -> t.cardId)
        .process(new CardTransactionVelocity());
    
    // Implementation of the KeyedProcessFunction
    public class CardTransactionVelocity extends KeyedProcessFunction<String, Transaction, Tuple2<String, Long>> {
    
        // State to store timestamps of transactions in the look-back window
        private transient ListState<Long> transactionTimestamps;
    
        // Look-back window: 2 minutes
        private static final long WINDOW_SIZE_MS = 2 * 60 * 1000;
    
        @Override
        public void open(Configuration parameters) {
            ListStateDescriptor<Long> descriptor = new ListStateDescriptor<>("transactionTimestamps", Long.class);
            transactionTimestamps = getRuntimeContext().getListState(descriptor);
        }
    
        @Override
        public void processElement(Transaction tx, Context ctx, Collector<Tuple2<String, Long>> out) throws Exception {
            long currentTimestamp = tx.timestamp;
            long pruneTimestamp = currentTimestamp - WINDOW_SIZE_MS;

            // --- State Pruning --- //
            // This is the critical part: we must manage the state manually.
            // Rebuild the list of in-window timestamps and write it back with update().
            // (Removing elements via the iterator returned by get() is not persisted
            // by all state backends, RocksDB in particular.)
            List<Long> recentTimestamps = new ArrayList<>();
            Iterable<Long> storedTimestamps = transactionTimestamps.get();
            if (storedTimestamps != null) {
                for (Long ts : storedTimestamps) {
                    if (ts >= pruneTimestamp) {
                        recentTimestamps.add(ts);
                    }
                }
            }

            // Add the current transaction's timestamp and persist the pruned list
            recentTimestamps.add(currentTimestamp);
            transactionTimestamps.update(recentTimestamps);

            // --- Feature Calculation --- //
            // The size of the pruned list is the card's transaction count in the last 2 minutes.
            out.collect(new Tuple2<>(tx.transactionId, (long) recentTimestamps.size()));

            // --- State Cleanup via Timers --- //
            // Register an event-time timer 2 minutes after this event. Flink deduplicates
            // timers per key and timestamp, and when the last timer fires with no newer
            // activity inside the window, the state for this card is cleared.
            // This ensures we don't keep state for inactive cards forever.
            ctx.timerService().registerEventTimeTimer(currentTimestamp + WINDOW_SIZE_MS);
        }

        @Override
        public void onTimer(long timestamp, OnTimerContext ctx, Collector<Tuple2<String, Long>> out) throws Exception {
            // This timer fires when the state for this key might be getting stale.
            long pruneTimestamp = ctx.timerService().currentWatermark() - WINDOW_SIZE_MS;

            List<Long> recentTimestamps = new ArrayList<>();
            Iterable<Long> storedTimestamps = transactionTimestamps.get();
            if (storedTimestamps != null) {
                for (Long ts : storedTimestamps) {
                    if (ts >= pruneTimestamp) {
                        recentTimestamps.add(ts);
                    }
                }
            }

            if (recentTimestamps.isEmpty()) {
                // Nothing left inside the window: drop the state for this card entirely.
                transactionTimestamps.clear();
            } else {
                transactionTimestamps.update(recentTimestamps);
            }
        }
    }

    Why this is superior to a sliding window:

    * Low Latency: A feature is computed for every single event, not just at window boundaries.

    * Accuracy: It calculates the count over the exact preceding time interval from the event's timestamp, avoiding window alignment artifacts.

    * Efficiency: While more complex, manual state pruning can be more efficient than Flink's internal window buffers for certain access patterns, especially when the state per key is small.

    Edge Case: State Growth: The manual timer-based cleanup is vital. Without it, the ListState for every credit card ever seen would persist indefinitely. Flink's built-in State TTL is an alternative, but understanding manual timer management is key for complex cleanup logic.
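
    For comparison, here is a minimal sketch of that built-in State TTL alternative, applied inside CardTransactionVelocity.open(). Two caveats: State TTL is based on processing time rather than event time, and for ListState the TTL applies per list entry, which happens to suit this use case. The 2-minute TTL mirrors WINDOW_SIZE_MS.

    java
    // Sketch: let Flink expire old timestamps instead of the onTimer() cleanup above.
    StateTtlConfig ttlConfig = StateTtlConfig
        .newBuilder(Time.minutes(2)) // mirrors WINDOW_SIZE_MS
        .setUpdateType(StateTtlConfig.UpdateType.OnCreateAndWrite)
        .setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired)
        .build();

    ListStateDescriptor<Long> descriptor =
        new ListStateDescriptor<>("transactionTimestamps", Long.class);
    descriptor.enableTimeToLive(ttlConfig); // for ListState, TTL is applied per entry
    transactionTimestamps = getRuntimeContext().getListState(descriptor);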

    Feature 2: Behavioral Deviation with Managed State

    Next, we want to detect when a transaction is anomalous compared to a user's typical behavior. For this, we need to maintain a long-term rolling average of their transaction amount.

    A 30-day window is too large to hold all transactions in a ListState. This would lead to massive state size and slow computations. Instead, we can maintain an aggregate (sum and count) in a ValueState.

    The Pattern:

  • Key the stream by userId.
  • Use a RichFlatMapFunction (or KeyedProcessFunction) to process each transaction.
  • Maintain a ValueState holding a custom object UserBehaviorProfile which contains totalAmount and transactionCount.
  • For each new transaction, update the totalAmount and transactionCount in the state.
  • Calculate the current average (totalAmount / transactionCount).
  • Compute the feature: deviationRatio = currentTx.amount / currentAverage.
  • To make this a rolling average, we need a mechanism to expire old data. A full 30-day sliding window is too expensive; a pragmatic compromise is a less precise but far more scalable approach: periodic decay or tumbling-window aggregation.
  • The implementation below keeps things simple: a cumulative profile held in ValueState, expired via State TTL. A more robust variant updates the long-term profile from daily tumbling-window aggregates, which separates real-time scoring from the profile update, a common production pattern (a sketch of that variant follows the code).

    java
    // State object for user's long-term profile
    public class UserBehaviorProfile {
        public double totalAmount;
        public long transactionCount;
        public long lastUpdatedTimestamp;
    }
    
    // Main job logic to join real-time stream with the stateful profile
    KeyedStream<Transaction, String> transactionByUser = transactionStream.keyBy(t -> t.userId);
    
    // This function will update and use the profile
    DataStream<Tuple3<String, Double, Double>> deviationStream = transactionByUser
        .flatMap(new CalculateAmountDeviation());
    
    // Implementation of the stateful FlatMap
    public class CalculateAmountDeviation extends RichFlatMapFunction<Transaction, Tuple3<String, Double, Double>> {
    
        private transient ValueState<UserBehaviorProfile> profileState;
    
        @Override
        public void open(Configuration parameters) {
            ValueStateDescriptor<UserBehaviorProfile> descriptor =
                new ValueStateDescriptor<>("userBehaviorProfile", UserBehaviorProfile.class);
            
            // Configure State TTL to automatically clean up profiles for inactive users
            StateTtlConfig ttlConfig = StateTtlConfig
                .newBuilder(Time.days(60)) // Keep state for 60 days
                .setUpdateType(StateTtlConfig.UpdateType.OnCreateAndWrite)
                .setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired)
                .build();
            descriptor.enableTimeToLive(ttlConfig);
            
            profileState = getRuntimeContext().getState(descriptor);
        }
    
        @Override
        public void flatMap(Transaction tx, Collector<Tuple3<String, Double, Double>> out) throws Exception {
            UserBehaviorProfile profile = profileState.value();
    
            // Initialize profile if it's the user's first transaction
            if (profile == null) {
                profile = new UserBehaviorProfile();
                profile.totalAmount = 0.0;
                profile.transactionCount = 0L;
            }
    
            double currentAvg = (profile.transactionCount > 0) ? (profile.totalAmount / profile.transactionCount) : tx.amount;
            double deviationRatio = (currentAvg > 0) ? (tx.amount / currentAvg) : 1.0;
    
            // Output the features for this specific transaction
            out.collect(new Tuple3<>(tx.transactionId, currentAvg, deviationRatio));
    
            // --- Update the long-term profile state --- //
            // In a real system, you might use a more sophisticated update logic
            // (e.g., exponential decay) instead of a simple cumulative average.
            profile.totalAmount += tx.amount;
            profile.transactionCount++;
            profile.lastUpdatedTimestamp = tx.timestamp;
            profileState.update(profile);
        }
    }
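
    As noted above, a more robust variant feeds the long-term profile from daily tumbling-window aggregates instead of every raw transaction. A minimal sketch of the aggregation side follows; how the daily (sum, count) output is merged into UserBehaviorProfile (for example, keeping only the most recent 30 days) is left out, and in practice you would use the two-argument aggregate(aggFn, windowFn) overload to attach the userId to each result.

    java
    // Daily (sum, count) per user, computed incrementally so the window never
    // buffers individual transactions.
    DataStream<Tuple2<Double, Long>> dailyUserAggregates = transactionStream
        .keyBy(t -> t.userId)
        .window(TumblingEventTimeWindows.of(Time.days(1)))
        .aggregate(new DailyAmountAggregate());

    public class DailyAmountAggregate
            implements AggregateFunction<Transaction, Tuple2<Double, Long>, Tuple2<Double, Long>> {

        @Override
        public Tuple2<Double, Long> createAccumulator() {
            return new Tuple2<>(0.0, 0L);
        }

        @Override
        public Tuple2<Double, Long> add(Transaction tx, Tuple2<Double, Long> acc) {
            return new Tuple2<>(acc.f0 + tx.amount, acc.f1 + 1);
        }

        @Override
        public Tuple2<Double, Long> getResult(Tuple2<Double, Long> acc) {
            return acc;
        }

        @Override
        public Tuple2<Double, Long> merge(Tuple2<Double, Long> a, Tuple2<Double, Long> b) {
            return new Tuple2<>(a.f0 + b.f0, a.f1 + b.f1);
        }
    }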

    Performance and Production Considerations:

    * State TTL: This is a perfect use case for Flink's built-in State TTL. We configure it to discard user profiles that haven't been updated in 60 days. This is far simpler and more robust than manual timer-based cleanup for this scenario.

    * State Backend: For millions of users, the total state size will be substantial. This is where the RocksDBStateBackend is non-negotiable. It stores state on local disk, spilling from memory, allowing for state sizes far exceeding available RAM.

    * Model Accuracy vs. Performance: The simple cumulative average shown here is a starting point. A more accurate model would use exponential weighting or re-aggregate daily/hourly tumbling windows to gradually phase out old data. The key architectural pattern—separating the stateful profile from the per-event calculation—remains the same.
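
    To make the exponential-weighting idea concrete, here is a minimal sketch of an EWMA update inside flatMap(). The smoothing factor ALPHA and the runningAvgAmount field (which would be added to UserBehaviorProfile) are illustrative assumptions, not part of the pipeline above; with this variant, currentAvg is read directly from profile.runningAvgAmount instead of being computed as totalAmount / transactionCount.

    java
    // Illustrative smoothing factor; tune it so recent activity dominates the estimate.
    private static final double ALPHA = 0.05;

    // Inside flatMap(), replacing the cumulative update:
    if (profile.transactionCount == 0) {
        profile.runningAvgAmount = tx.amount; // first observation seeds the average
    } else {
        profile.runningAvgAmount = ALPHA * tx.amount + (1 - ALPHA) * profile.runningAvgAmount;
    }
    profile.transactionCount++;
    profile.lastUpdatedTimestamp = tx.timestamp;
    profileState.update(profile);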

    Feature 3: Multi-Entity Association with `MapState`

    Our most complex feature: How many distinct user accounts have used this device ID in the last hour? This is a strong indicator of fraud (e.g., a bot farm using one device to test many stolen credentials).

    This requires state keyed by deviceId, but the value of the state needs to track multiple userIds and their last-seen timestamps.

    The Pattern:

  • Key the stream by deviceId.
  • Use a KeyedProcessFunction.
  • The state will be a MapState, mapping userId to the lastSeenTimestamp.
  • For each incoming transaction, add/update the (userId, timestamp) entry in the MapState.
    • To calculate the feature, iterate through the map and count how many entries have a timestamp within the last hour.
  • Crucially, we must prune the map to remove old userIds to prevent the map from growing indefinitely.
    java
    // In the main job...
    DataStream<Tuple2<String, Long>> deviceUserCountStream = transactionStream
        .keyBy(t -> t.deviceId)
        .process(new DeviceUserAssociationCounter());
    
    // Implementation
    public class DeviceUserAssociationCounter extends KeyedProcessFunction<String, Transaction, Tuple2<String, Long>> {
    
        private transient MapState<String, Long> userTimestampMap;
        private static final long WINDOW_SIZE_MS = 60 * 60 * 1000; // 1 hour
    
        @Override
        public void open(Configuration parameters) {
            MapStateDescriptor<String, Long> descriptor =
                new MapStateDescriptor<>("userTimestampMap", String.class, Long.class);
            userTimestampMap = getRuntimeContext().getMapState(descriptor);
        }
    
        @Override
        public void processElement(Transaction tx, Context ctx, Collector<Tuple2<String, Long>> out) throws Exception {
            long currentTimestamp = tx.timestamp;
    
            // Update the map with the current user and timestamp
            userTimestampMap.put(tx.userId, currentTimestamp);
    
            // --- Prune and Count in one pass --- //
            long pruneTimestamp = currentTimestamp - WINDOW_SIZE_MS;
            long distinctUserCount = 0;
            Iterator<Map.Entry<String, Long>> iterator = userTimestampMap.entries().iterator();
            
            while (iterator.hasNext()) {
                Map.Entry<String, Long> entry = iterator.next();
                if (entry.getValue() < pruneTimestamp) {
                    // This user's last activity was outside the window, so remove them
                    iterator.remove();
                } else {
                    distinctUserCount++;
                }
            }
    
            out.collect(new Tuple2<>(tx.transactionId, distinctUserCount));
    
            // Register a cleanup timer to handle idle devices
            ctx.timerService().registerEventTimeTimer(currentTimestamp + WINDOW_SIZE_MS);
        }
    
        @Override
        public void onTimer(long timestamp, OnTimerContext ctx, Collector<Tuple2<String, Long>> out) throws Exception {
            // Similar to the velocity example, this timer cleans up state for devices
            // that haven't been seen in a while.
            long pruneTimestamp = ctx.timerService().currentWatermark() - WINDOW_SIZE_MS;
            
            Iterator<Map.Entry<String, Long>> iterator = userTimestampMap.entries().iterator();
            while (iterator.hasNext()) {
                if (iterator.next().getValue() < pruneTimestamp) {
                    iterator.remove();
                }
            }
    
            if (userTimestampMap.isEmpty()) {
                userTimestampMap.clear();
            }
        }
    }

    Scalability Challenge: Key Skew

    What if one deviceId corresponds to a public proxy or a NAT gateway, resulting in millions of users? This single key would become a hot spot, overwhelming one Flink task manager. This is known as key skew.

    Mitigating Key Skew:

  • Two-Phase Aggregation: A standard pattern is to append a deterministic salt to the hot key, splitting the load across multiple sub-tasks. A second downstream operator then merges the partial results from the salted keys (see the sketch after this list).

    * keyBy(t -> t.deviceId + "_" + Math.floorMod(t.userId.hashCode(), 10)) fans out the processing for a hot deviceId across 10 parallel sub-keys (Math.floorMod avoids negative salts from hashCode()).

    * Because the salt is derived from the userId, each user maps to exactly one sub-key, so a subsequent aggregation keyed by the original deviceId can simply sum the per-salt distinct counts.

  • Pre-sharding in Kafka: If the hot keys are known, you can produce to different Kafka partitions based on a compound key (deviceId, userId) to ensure better data distribution before it even reaches Flink.
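
    A minimal sketch of the two-phase pattern for this feature. The operator names (SaltedDistinctUserCounter, MergeSaltedCounts) are illustrative, not classes defined earlier; the key point is that phase one produces a partial distinct count per salt bucket and phase two combines them per original deviceId.

    java
    // Phase 1: key by (deviceId + salt); each sub-key maintains its own MapState of users
    // and emits (deviceId, transactionId, partialDistinctCount).
    DataStream<Tuple3<String, String, Long>> partialCounts = transactionStream
        .keyBy(t -> t.deviceId + "_" + Math.floorMod(t.userId.hashCode(), 10))
        .process(new SaltedDistinctUserCounter());

    // Phase 2: re-key by the original deviceId and combine the latest partial count
    // from each salt bucket into the final distinct-user count per transaction.
    DataStream<Tuple2<String, Long>> deviceUserCountStream = partialCounts
        .keyBy(t -> t.f0)
        .process(new MergeSaltedCounts());
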
    Assembling the Final Feature Vector

    We now have three separate streams, each calculating a different feature. To create the final FeatureVector, we need to join them. Since they are all keyed by transactionId (or can be), the most efficient way to do this in a streaming context is with a CoProcessFunction or, for more than two streams, a series of joins.

    However, a simpler and often more robust pattern in practice is to perform the joins within a single, multi-purpose KeyedProcessFunction. This avoids the complexity of Flink's stream-to-stream joins and keeps related state co-located.

    In practice, the cleanest version of this is to pick a primary entity key (e.g., userId), compute every feature keyed by that entity in a single processor, and enrich the event as it flows.

    Revised Architecture: A Single Enrichment KeyedProcessFunction

    java
    // Key by the primary entity, e.g., userId
    DataStream<FeatureVector> featureStream = transactionStream
        .keyBy(t -> t.userId)
        .process(new UnifiedFeatureExtractor());
    
    // This single function will manage state for multiple features.
    // NOTE: This approach is good when most features are keyed by the same entity.
    // For features keyed differently (like our deviceId feature), a stream join is unavoidable.

    For our specific features, a join is necessary. We'll use Flink's Interval Join, which is designed for this exact use case: joining two streams on a common key where elements are considered a match if their timestamps are within a certain bound.

    java
    // Assume we have:
    // cardVelocityStream: DataStream<Tuple2<String, Long>> (txId, count)
    // deviationStream: DataStream<Tuple3<String, Double, Double>> (txId, avg, ratio)
    // deviceUserCountStream: DataStream<Tuple2<String, Long>> (txId, count)
    
    // Join 1: cardVelocity + deviation
    DataStream<PartialFeatures1> partial1 = cardVelocityStream
        .keyBy(t -> t.f0) // key by transactionId
        .intervalJoin(deviationStream.keyBy(t -> t.f0))
        .between(Time.seconds(-5), Time.seconds(5)) // Match events within a 10s window
        .process(new JoinFunction1());
    
    // Join 2: partial1 + deviceUserCount
    DataStream<FeatureVector> finalFeatures = partial1
        .keyBy(p -> p.transactionId)
        .intervalJoin(deviceUserCountStream.keyBy(t -> t.f0))
        .between(Time.seconds(-5), Time.seconds(5))
        .process(new FinalJoinFunction());
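
    The join functions referenced above are ProcessJoinFunction implementations. A minimal sketch of JoinFunction1 follows; the PartialFeatures1 field names are assumptions chosen to match the feature streams, not definitions from earlier in the article.

    java
    public class JoinFunction1 extends ProcessJoinFunction<
            Tuple2<String, Long>,            // (transactionId, cardTxCountLast2Min)
            Tuple3<String, Double, Double>,  // (transactionId, userAvg, deviationRatio)
            PartialFeatures1> {

        @Override
        public void processElement(Tuple2<String, Long> velocity,
                                   Tuple3<String, Double, Double> deviation,
                                   Context ctx,
                                   Collector<PartialFeatures1> out) {
            PartialFeatures1 partial = new PartialFeatures1();
            partial.transactionId = velocity.f0;
            partial.cardTxCountLast2Min = velocity.f1;
            partial.userAvgTxAmountLast30Days = deviation.f1;
            partial.amountDeviationRatio = deviation.f2;
            out.collect(partial);
        }
    }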

    This join strategy requires careful tuning of the time bounds (.between()). The bounds must be large enough to account for any differential latency between the parallel feature calculation paths but small enough to avoid incorrect joins. This adds operational complexity, which is why designing the pipeline to minimize joins is a key architectural goal.

    Conclusion: Beyond the Basics

    We've constructed a Flink pipeline that moves far beyond simple stateless transformations or basic windowing. By leveraging KeyedProcessFunction, various state primitives (ListState, ValueState, MapState), and manual timer management, we can calculate complex, low-latency features that are impossible to generate with batch systems.

    Key Takeaways for Production Systems:

    * Choose the Right Tool: For per-event, low-latency feature calculation with look-backs, KeyedProcessFunction is superior to standard windows.

    * Aggressively Manage State: Unbounded state is the silent killer of streaming jobs. Use State TTL, manual timers, and intelligent pruning to keep your state footprint under control.

    * Embrace RocksDB: For any non-trivial use case, on-disk state management with RocksDB is essential for stability and scalability. Plan your disk I/O and memory accordingly.

    * Anticipate Skew: Data is never perfectly distributed. Be prepared to implement two-phase aggregation or other mitigation strategies for hot keys.

    * Event Time is Non-Negotiable: For correctness in a distributed system with potential network delays or source lag, always process data in event time and configure a sensible watermark strategy to handle out-of-order data.
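
    On that last point, the watermark strategy from the job skeleton can be hardened slightly. A minimal sketch with illustrative bounds: withIdleness keeps the watermark advancing even when a Kafka partition goes quiet, which would otherwise stall every event-time timer and window downstream.

    java
    WatermarkStrategy<Transaction> watermarks = WatermarkStrategy
        .<Transaction>forBoundedOutOfOrderness(Duration.ofSeconds(20)) // tolerate 20s of disorder
        .withTimestampAssigner((tx, recordTimestamp) -> tx.timestamp)
        .withIdleness(Duration.ofMinutes(1)); // keep watermarks advancing past idle partitions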

    Building a real-time feature engineering pipeline is a significant engineering challenge, but it's a foundational component for any organization serious about real-time AI. Flink provides the powerful, low-level building blocks necessary to solve these problems with high throughput and correctness guarantees.
