Stateful Feature Engineering with Flink's KeyedProcessFunction


Beyond Windows: The Power of `KeyedProcessFunction` for Complex Features

In real-time Machine Learning systems, feature engineering often transcends simple aggregations over fixed windows. While Flink's windowing API is powerful for calculating metrics like 'transactions per minute', it falls short when features require complex state logic, dynamic event-driven triggers, or interactions between multiple event streams. Consider features crucial for fraud detection:

* Rapid Repeat Transactions: A user making several transactions at the same merchant within a 2-minute window.

* Anomalous Spending: A transaction amount that is more than 5x the user's historical average transaction value.

* Card Testing: A series of small, rapid-fire transactions followed by a large one.

These scenarios cannot be modeled effectively with standard tumbling or sliding windows. They require per-key (e.g., per-user) state management and the ability to react to events on a granular level. This is the domain of Flink's Process Function API, specifically the KeyedProcessFunction.

This function provides the two fundamental building blocks for advanced stateful processing:

  • Access to Keyed State: Fine-grained, fault-tolerant state scoped to a specific key (e.g., userId, cardId). Flink manages the partitioning and checkpointing of this state automatically.
  • Access to Timers: The ability to register event-time or processing-time timers. These timers trigger a callback at a future point, allowing for delayed actions, state cleanup, and timeout-based logic.

In this post, we'll build a sophisticated feature engineering pipeline for a fraud detection model. We will ingest a stream of financial transactions from Kafka and use a KeyedProcessFunction to compute complex features in real time, demonstrating patterns you can directly apply in production environments.

    The Core Scenario: Real-time Fraud Detection Features

    Our goal is to process a JSON stream of transaction events from a Kafka topic and enrich them with stateful features. Each event has the following structure:

    json
    {
      "transactionId": "t123",
      "userId": "u456",
      "merchantId": "m789",
      "amount": 150.75,
      "timestamp": 1678886400000
    }

    We will compute the following features for each incoming transaction:

  • user_avg_tx_amount: The user's average transaction amount calculated over their entire history.
  • tx_amount_vs_avg_ratio: The ratio of the current transaction amount to the user's historical average.
  • tx_within_2min_window_count: The number of transactions this user has made in the last 2 minutes.

This requires us to maintain three pieces of state for each userId:

    * The total sum of their transaction amounts.

    * The total count of their transactions.

    * A timestamp-ordered list of their recent transactions to calculate the 2-minute window feature.

    Setting Up the Flink Job

    First, let's define our data structures and the basic Flink DataStream job setup. We'll use Java and Maven.

    Maven Dependencies:

    xml
    <dependencies>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-streaming-java</artifactId>
            <version>1.17.1</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-clients</artifactId>
            <version>1.17.1</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-connector-kafka</artifactId>
            <version>3.0.1-1.17</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-json</artifactId>
            <version>1.17.1</version>
        </dependency>
    </dependencies>

    Data POJOs:

    java
    // Input Transaction Event
    public class Transaction {
        public String transactionId;
        public String userId;
        public String merchantId;
        public double amount;
        public long timestamp;
        // Getters, setters, constructor
    }
    
    // Output Enriched Transaction with Features
    public class EnrichedTransaction extends Transaction {
        public double userAvgTxAmount;
        public double txAmountVsAvgRatio;
        public int txWithin2minWindowCount;
        // Getters, setters, constructor
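        // Copy constructor used later by the KeyedProcessFunction via new EnrichedTransaction(tx);
        // an explicit no-arg constructor keeps the class a valid Flink POJO.
        public EnrichedTransaction() {}
        public EnrichedTransaction(Transaction tx) {
            this.transactionId = tx.transactionId;
            this.userId = tx.userId;
            this.merchantId = tx.merchantId;
            this.amount = tx.amount;
            this.timestamp = tx.timestamp;
        }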
    }

    Main Flink Job Structure:

    java
    import org.apache.flink.api.common.eventtime.WatermarkStrategy;
    import org.apache.flink.connector.kafka.source.KafkaSource;
    import org.apache.flink.connector.kafka.source.enumerator.initializer.OffsetsInitializer;
    import org.apache.flink.formats.json.JsonDeserializationSchema;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    
    public class FraudDetectionJob {
        public static void main(String[] args) throws Exception {
            final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    
        // Production configuration: use RocksDB for the state backend (see the section below)
        // env.setStateBackend(new EmbeddedRocksDBStateBackend());
        // env.getCheckpointConfig().setCheckpointStorage("hdfs:///flink/checkpoints");
        // env.enableCheckpointing(60000); // Checkpoint every 60 seconds
    
            KafkaSource<Transaction> source = KafkaSource.<Transaction>builder()
                .setBootstrapServers("kafka:9092")
                .setTopics("transactions")
                .setGroupId("fraud-detector-group")
                .setStartingOffsets(OffsetsInitializer.latest())
                .setValueOnlyDeserializer(new JsonDeserializationSchema<>(Transaction.class))
                .build();
    
            DataStream<Transaction> transactions = env.fromSource(source, 
                WatermarkStrategy.<Transaction>forMonotonousTimestamps()
                    .withTimestampAssigner((event, timestamp) -> event.timestamp), 
                "Kafka Transaction Source");
    
            DataStream<EnrichedTransaction> enrichedTransactions = transactions
                .keyBy(t -> t.userId)
                .process(new FraudFeatureGenerator());
    
            enrichedTransactions.print(); // In production, sink to another Kafka topic or database
    
            env.execute("Real-time Fraud Detection Feature Engineering");
        }
    }
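
Instead of print(), a production job would typically write the enriched stream back to Kafka. Below is a minimal sketch of such a sink; the output topic name enriched-transactions is an assumption, and it serializes the POJO to JSON with flink-json's JsonSerializationSchema, the counterpart of the JsonDeserializationSchema used above.

    java
    import org.apache.flink.connector.base.DeliveryGuarantee;
    import org.apache.flink.connector.kafka.sink.KafkaRecordSerializationSchema;
    import org.apache.flink.connector.kafka.sink.KafkaSink;
    import org.apache.flink.formats.json.JsonSerializationSchema;

    // Serialize each enriched event back to JSON and write it to an output topic.
    KafkaSink<EnrichedTransaction> featureSink = KafkaSink.<EnrichedTransaction>builder()
        .setBootstrapServers("kafka:9092")
        .setRecordSerializer(KafkaRecordSerializationSchema.builder()
            .setTopic("enriched-transactions") // assumed output topic name
            .setValueSerializationSchema(new JsonSerializationSchema<EnrichedTransaction>())
            .build())
        .setDeliveryGuarantee(DeliveryGuarantee.AT_LEAST_ONCE)
        .build();

    enrichedTransactions.sinkTo(featureSink);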

    Implementing the `FraudFeatureGenerator` `KeyedProcessFunction`

    This is where the core logic resides. We'll define our state handles and implement the processElement and onTimer methods.

    java
    import org.apache.flink.api.common.state.MapState;
    import org.apache.flink.api.common.state.MapStateDescriptor;
    import org.apache.flink.api.common.state.ValueState;
    import org.apache.flink.api.common.state.ValueStateDescriptor;
    import org.apache.flink.api.common.typeinfo.Types;
    import org.apache.flink.configuration.Configuration;
    import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
    import org.apache.flink.util.Collector;
    
    public class FraudFeatureGenerator extends KeyedProcessFunction<String, Transaction, EnrichedTransaction> {
    
        // State for calculating the historical average
        private transient ValueState<Double> totalAmountState;
        private transient ValueState<Long> transactionCountState;
    
        // State for the 2-minute window feature. Key: timestamp, Value: transactionId
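        // Note: keying the map by timestamp assumes at most one transaction per user per millisecond;
        // a collision would overwrite the earlier entry and slightly undercount the window feature.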
        private transient MapState<Long, String> recentTransactionsState;
    
        @Override
        public void open(Configuration parameters) throws Exception {
            // Historical Average State
            ValueStateDescriptor<Double> totalAmountDescriptor = new ValueStateDescriptor<>("totalAmount", Types.DOUBLE);
            totalAmountState = getRuntimeContext().getState(totalAmountDescriptor);
    
            ValueStateDescriptor<Long> txCountDescriptor = new ValueStateDescriptor<>("txCount", Types.LONG);
            transactionCountState = getRuntimeContext().getState(txCountDescriptor);
    
            // Recent Transactions State
            MapStateDescriptor<Long, String> recentTxDescriptor = new MapStateDescriptor<>("recentTransactions", Types.LONG, Types.STRING);
            recentTransactionsState = getRuntimeContext().getMapState(recentTxDescriptor);
        }
    
        @Override
        public void processElement(Transaction tx, Context ctx, Collector<EnrichedTransaction> out) throws Exception {
            // --- 1. Update Historical Average State ---
            Double currentTotal = totalAmountState.value();
            if (currentTotal == null) currentTotal = 0.0;
            totalAmountState.update(currentTotal + tx.amount);
    
            Long currentCount = transactionCountState.value();
            if (currentCount == null) currentCount = 0L;
            transactionCountState.update(currentCount + 1);
    
            // --- 2. Calculate Historical Features ---
            double userAvgTxAmount = (currentTotal + tx.amount) / (currentCount + 1);
            double txAmountVsAvgRatio = (currentCount > 0) ? tx.amount / (currentTotal / currentCount) : 1.0;
    
            // --- 3. Update and Calculate 2-Minute Window Feature ---
            long currentTimestamp = tx.timestamp;
            recentTransactionsState.put(currentTimestamp, tx.transactionId);
            
            // Register a timer to clean up this transaction from state later
            // We add 2 minutes (120,000 ms) to the event timestamp
            ctx.timerService().registerEventTimeTimer(currentTimestamp + 120000);
    
            // Iterate through the map state to count recent transactions
            int recentTxCount = 0;
            long twoMinBoundary = currentTimestamp - 120000;
            for (Long ts : recentTransactionsState.keys()) {
                if (ts >= twoMinBoundary) {
                    recentTxCount++;
                }
            }
    
            // --- 4. Emit the Enriched Transaction ---
            EnrichedTransaction enriched = new EnrichedTransaction(tx);
            enriched.userAvgTxAmount = userAvgTxAmount;
            enriched.txAmountVsAvgRatio = txAmountVsAvgRatio;
            enriched.txWithin2minWindowCount = recentTxCount;
    
            out.collect(enriched);
        }
    
        @Override
        public void onTimer(long timestamp, OnTimerContext ctx, Collector<EnrichedTransaction> out) throws Exception {
            // The timer timestamp is the original event_timestamp + 2 minutes.
            // We can now safely remove the original transaction from our state.
            long originalTimestamp = timestamp - 120000;
            recentTransactionsState.remove(originalTimestamp);
        }
    }
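
Before dissecting the implementation, note that a KeyedProcessFunction like this can be exercised without a Kafka cluster using Flink's operator test harnesses (from the flink-test-utils and flink-streaming-java test artifacts). A rough sketch, where newTx(...) is a hypothetical helper that builds a Transaction from an id, userId, amount, and timestamp:

    java
    import org.apache.flink.api.common.typeinfo.Types;
    import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
    import org.apache.flink.streaming.util.ProcessFunctionTestHarnesses;

    // Drive the function directly, keyed by userId, without any Kafka involved.
    OneInputStreamOperatorTestHarness<Transaction, EnrichedTransaction> harness =
        ProcessFunctionTestHarnesses.forKeyedProcessFunction(
            new FraudFeatureGenerator(), tx -> tx.userId, Types.STRING);
    harness.open();

    // Two transactions for the same user, 30 seconds apart (newTx is a hypothetical helper).
    harness.processElement(newTx("t1", "u456", 100.0, 1_000L), 1_000L);
    harness.processElement(newTx("t2", "u456", 600.0, 31_000L), 31_000L);

    // The second emitted record should show 2 transactions in the 2-minute window
    // and a tx_amount_vs_avg_ratio of 600 / 100 = 6.0.
    System.out.println(harness.extractOutputValues());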

    Deconstructing the `KeyedProcessFunction` Implementation

    Let's analyze the advanced patterns used here:

  • State Initialization (open): We don't initialize state in the constructor. The open() method is the correct lifecycle hook, as it's called once per parallel task instance, providing access to the RuntimeContext needed to register state descriptors. This is crucial for fault tolerance and recovery.
  • ValueState for Simple Aggregates: For the historical average, ValueState is perfect. It stores a single value per key (userId). We handle the null case for the very first event for a given key, a common edge case.
  • MapState for Custom Windows: The real power is shown in the 2-minute window calculation. Instead of a fixed Flink window, we use a MapState to store the timestamps of recent transactions. This gives us complete control. When a new event arrives, we can iterate over the map's keys to count how many fall within our desired look-back period. This is far more flexible than a sliding window, as the 'window' is re-evaluated on every single event.
  • Event-Time Timers for State Cleanup (onTimer): Storing every transaction timestamp forever would lead to unbounded state growth. This is a critical production issue. We solve it using ctx.timerService().registerEventTimeTimer(). For each transaction we add to MapState, we register a timer to fire 2 minutes after its event-time timestamp. The onTimer callback then removes that specific transaction from the state. This is a highly efficient, self-cleaning state management pattern.
    * Why Event Time? We use event time to ensure our logic is deterministic and correct, regardless of processing delays or out-of-order events. If we used processing time, a system slowdown could cause us to prematurely evict state, leading to incorrect feature values. The snippet just below this list makes the contrast concrete.
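
For contrast, here is how the two timer flavours are registered from inside processElement; only the event-time call is used in this job:

    java
    // Event-time timer (used above): fires once the watermark passes tx.timestamp + 2 minutes,
    // so the outcome depends only on the data, not on when Flink happens to process it.
    ctx.timerService().registerEventTimeTimer(tx.timestamp + 120000L);

    // Processing-time timer (not used here): fires two minutes of wall-clock time from now.
    // A backlog or a slow operator would shift when this fires relative to the events.
    ctx.timerService().registerProcessingTimeTimer(ctx.timerService().currentProcessingTime() + 120000L);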

    Production-Grade Considerations and Optimizations

    Running this in production requires more than just correct logic. Here's how to make it robust and performant.

State Backend: `EmbeddedRocksDBStateBackend` is Non-Negotiable

For any stateful application with non-trivial state, the RocksDB state backend is essential. Flink's default HashMapStateBackend keeps all state as objects on the JVM heap.

* Problem: With millions of users, our recentTransactionsState map would quickly cause OutOfMemoryError exceptions on a heap-based backend.

* Solution: EmbeddedRocksDBStateBackend stores state in an embedded on-disk key-value store (RocksDB). It uses managed off-heap memory and spills to local disk, allowing your state size to far exceed available RAM, a common requirement for real-world user-centric applications.

    Configuration:

    java
    // In your main job class
    import org.apache.flink.contrib.streaming.state.EmbeddedRocksDBStateBackend;
    import org.apache.flink.streaming.api.CheckpointingMode;

    // ...

    env.setStateBackend(new EmbeddedRocksDBStateBackend());
    env.getCheckpointConfig().setCheckpointStorage("s3://my-flink-app/checkpoints");
    env.getCheckpointConfig().setCheckpointingMode(CheckpointingMode.EXACTLY_ONCE);
    env.getCheckpointConfig().setCheckpointInterval(60000); // 1 minute
    env.getCheckpointConfig().setMinPauseBetweenCheckpoints(30000); // 30s pause
    env.getCheckpointConfig().setCheckpointTimeout(180000); // 3 minutes
    env.getCheckpointConfig().setMaxConcurrentCheckpoints(1);

This configuration enables durable, fault-tolerant state by periodically snapshotting the RocksDB state to a distributed filesystem such as S3 or HDFS. If a TaskManager fails, Flink restores the state from the last successful checkpoint and resumes processing with exactly-once state consistency (end-to-end exactly-once delivery additionally requires a transactional sink).
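
One packaging note: EmbeddedRocksDBStateBackend lives in its own module, so the job needs the matching dependency at compile time (it is bundled with the standard Flink distribution, hence the provided scope; the version here is chosen to match the 1.17.1 artifacts above):

    xml
    <dependency>
        <groupId>org.apache.flink</groupId>
        <artifactId>flink-statebackend-rocksdb</artifactId>
        <version>1.17.1</version>
        <scope>provided</scope>
    </dependency>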

    Handling Late-Arriving Data

Our WatermarkStrategy was forMonotonousTimestamps(), which optimistically assumes event timestamps never decrease. In reality, events can be delayed and arrive out of order. A more robust strategy is forBoundedOutOfOrderness.

    java
    import java.time.Duration;

    WatermarkStrategy.<Transaction>forBoundedOutOfOrderness(Duration.ofSeconds(10))
        .withTimestampAssigner((event, timestamp) -> event.timestamp)

    This tells Flink that watermarks (which trigger timers) should lag behind the maximum event time seen so far by 10 seconds. This gives late events a 10-second grace period to arrive before their corresponding timers are fired. Events arriving later than the watermark are considered 'late'.

    What happens to a transaction that arrives 15 seconds late? The timer to clean it up might have already fired. While our current logic is safe (removing a non-existent key is a no-op), in more complex scenarios, you might need to handle this explicitly. You can send late data to a side output for separate processing:

    java
    // Inside your KeyedProcessFunction
    final OutputTag<Transaction> lateDataTag = new OutputTag<Transaction>("late-transactions"){};
    
    @Override
    public void processElement(Transaction tx, Context ctx, Collector<EnrichedTransaction> out) {
        if (tx.timestamp < ctx.timerService().currentWatermark()) {
            // This event is late
            ctx.output(lateDataTag, tx);
        } else {
            // Process as normal
            // ...
        }
    }
    
    // In your main job
    SingleOutputStreamOperator<EnrichedTransaction> mainStream = transactions
        .keyBy(t -> t.userId)
        .process(new FraudFeatureGenerator());
    
    DataStream<Transaction> lateStream = mainStream.getSideOutput(new OutputTag<Transaction>("late-transactions"){});
    // Sink lateStream to a logging topic for analysis

    Advanced Edge Case: State Schema Evolution

    Imagine a V2 of our application where we also need to track the average transaction amount per merchant for a user. This requires changing our state object. For example, we might change totalAmountState from a ValueState to a MapState (merchantId -> totalAmount).

    If you simply deploy the new code, Flink will fail to restore from the previous checkpoint because the state serializers are incompatible. This is a catastrophic failure in production.

Flink provides a robust mechanism for state schema evolution. Simple changes, such as adding or removing fields on POJO- or Avro-serialized state types, are handled automatically; a wholesale type change like Long to UserStats requires an explicit migration path. That path involves more boilerplate, but it is essential for long-running applications.

Example: Migrating transactionCountState from a plain Long to a custom UserStats object.

    V1 State:

    java
    ValueStateDescriptor<Long> txCountDescriptor = new ValueStateDescriptor<>("txCount", Types.LONG);

    V2 POJO and State:

    java
    // New POJO for state
    public class UserStats {
        public long count;
        public long lastUpdated;
    }
    
    // V2 State Descriptor
    ValueStateDescriptor<UserStats> txCountDescriptor = 
        new ValueStateDescriptor<>("txCount", TypeInformation.of(UserStats.class));

    To make this migration work, you need to provide Flink with a TypeSerializerSnapshot that can read the old format (Long) and convert it to the new format (UserStats). This is an advanced topic involving creating custom TypeSerializer classes. Here's a conceptual outline:

  • Create a custom UserStatsSerializer that extends TypeSerializer<UserStats>.
  • Implement the serializer's snapshotConfiguration() method, which returns a TypeSerializerSnapshot.
  • The snapshot class must implement resolveSchemaCompatibility(). This is the key method where you inspect the previous serializer's snapshot: if it corresponds to the old LongSerializer, you return TypeSerializerSchemaCompatibility.compatibleAfterMigration() and supply a serializer that can read the old Long format, so Flink can rewrite the state as UserStats objects during restore.

This process allows you to upgrade the job by restoring from a V1 savepoint, migrating the state schema on the fly, and continuing processing without data loss.
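
If writing a full custom serializer is more machinery than you want, a commonly used lighter-weight alternative is lazy, in-operator migration: keep the V1 state handle under its original name, register the V2 state under a new name, and move each key's value across the first time that key is touched. The sketch below assumes it is merged into the FraudFeatureGenerator shown earlier (reusing its imports plus org.apache.flink.api.common.typeinfo.TypeInformation); the state name "userStats" and the getOrMigrateStats helper are illustrative, not part of the original design.

    java
    // V1 handle keeps its original name and type so the savepoint restores cleanly.
    private transient ValueState<Long> legacyTxCountState;
    // V2 handle gets a new name, so no serializer migration is required.
    private transient ValueState<UserStats> userStatsState;

    @Override
    public void open(Configuration parameters) {
        legacyTxCountState = getRuntimeContext().getState(
            new ValueStateDescriptor<>("txCount", Types.LONG));
        userStatsState = getRuntimeContext().getState(
            new ValueStateDescriptor<>("userStats", TypeInformation.of(UserStats.class)));
    }

    // Call this from processElement before using the V2 state.
    private UserStats getOrMigrateStats(long nowMillis) throws Exception {
        UserStats stats = userStatsState.value();
        if (stats == null) {
            stats = new UserStats();
            Long legacyCount = legacyTxCountState.value();
            if (legacyCount != null) {
                stats.count = legacyCount;   // carry the V1 counter over
                legacyTxCountState.clear();  // release the old entry for this key
            }
            stats.lastUpdated = nowMillis;
            userStatsState.update(stats);
        }
        return stats;
    }

The trade-off is that legacy entries linger for keys that never reappear; configuring state TTL on the legacy descriptor (via StateTtlConfig) bounds that growth.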

    Conclusion

    The KeyedProcessFunction is the ultimate tool in the Flink developer's arsenal for implementing complex, stateful logic. By moving beyond the constraints of the standard Window API, you can build highly specific and powerful features that are critical for modern real-time applications like fraud detection, anomaly detection, and real-time personalization.

    We've demonstrated how to combine ValueState, MapState, and event-time timers to create a self-cleaning, efficient feature generation process. More importantly, we've addressed the critical production concerns that separate a proof-of-concept from a resilient, scalable system: selecting the right state backend (RocksDB), configuring checkpointing for exactly-once guarantees, handling data imperfections like late arrivals, and planning for the inevitable challenge of state schema evolution. Mastering these patterns is essential for any senior engineer tasked with building mission-critical stream processing pipelines.
