Edge Inference: Real-Time Personalization with WASM and ONNX Runtime

20 min read
Goh Ling Yong
Technology enthusiast and software architect specializing in AI-driven development tools and modern software engineering practices. Passionate about the intersection of artificial intelligence and human creativity in building tomorrow's digital solutions.

The Latency Tax of Centralized ML Personalization

In modern web applications, real-time personalization is no longer a luxury—it's an expectation. Whether it's re-ranking a product feed, suggesting related articles, or customizing a user interface, the speed of this personalization directly impacts user engagement. The conventional architecture involves a client request that triggers a server-side API call to a separate, centralized machine learning microservice. This pattern, while simple to reason about, introduces a significant latency tax.

Consider the network path:

  • Client → CDN Edge
  • CDN Edge → Application Origin Server
  • Application Origin → ML Inference Service (often in a different VPC or even region)
  • ML Inference Service → Application Origin
  • Application Origin → CDN Edge
  • CDN Edge → Client

Each hop adds tens, if not hundreds, of milliseconds. For UI-blocking personalization, this round-trip penalty is unacceptable. Caching strategies at the origin or CDN are often ineffective because the response is unique per-user or per-session. The ideal solution is to perform the inference as close to the user as possible: at the CDN edge.

This article presents a production-ready pattern for achieving this with WebAssembly (WASM) and the ONNX Runtime. We will bypass high-level, black-box AI services and build a custom, high-performance inference pipeline that runs within the tight constraints of an edge worker environment like Cloudflare Workers or Vercel Edge Functions. We'll tackle the non-trivial challenges of compiling C++ code to WASM, optimizing model size and performance, and managing the model lifecycle at the edge.

Architectural Overview: Inference on the Edge

Our target architecture looks like this:

  • Model Training (Offline): A model (e.g., a small neural network for ranking) is trained in a standard Python environment using a framework like PyTorch or TensorFlow.
  • Model Export: The trained model is exported to the Open Neural Network Exchange (ONNX) format, a portable standard for ML models.
  • Runtime Compilation (Offline): The ONNX Runtime, a high-performance C++ inference engine, is compiled to a WebAssembly binary (.wasm). This is a critical one-time step.
  • Edge Deployment: The .wasm runtime and the .onnx model file are deployed alongside our edge function code.
  • Live Inference: An incoming user request hits the edge worker. The worker:
    a. Instantiates the WASM-based ONNX Runtime.
    b. Loads the ONNX model into the runtime.
    c. Pre-processes request data (e.g., user ID, product features from a fetched API response) into tensors.
    d. Executes the model to get inference results (e.g., a new product ranking).
    e. Post-processes the results and modifies the response to the user.

This entire process occurs within the same edge location that serves the initial request, reducing latency from 200-500ms to a mere 20-50ms.


    Part 1: Model Preparation and ONNX Export

    We need a model that is small enough to be loaded quickly in an edge environment but complex enough to provide meaningful personalization. A simple feed-forward neural network for ranking is a perfect candidate. Let's assume we're building a system to re-rank a list of articles for a user based on their embedding and the articles' embeddings.

    The Ranking Model in PyTorch

    Our model will take a concatenated user and item embedding and output a single score. Higher scores indicate a better match.

    python
    # model.py
    import torch
    import torch.nn as nn
    
    class RankingModel(nn.Module):
        def __init__(self, embedding_dim: int):
            super(RankingModel, self).__init__()
            # Example: User embedding dim = 32, Item embedding dim = 32 -> input_dim = 64
            input_dim = embedding_dim * 2
            self.layer_stack = nn.Sequential(
                nn.Linear(input_dim, 64),
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, 1)
            )
    
        def forward(self, user_embedding, item_embedding):
            # Concatenate user and item embeddings
            x = torch.cat((user_embedding, item_embedding), dim=1)
            return self.layer_stack(x)
    
    # --- Training would happen here --- 
    # For this example, we'll just instantiate and export a dummy model.
    
    def export_model_to_onnx():
        EMBEDDING_DIM = 32
        BATCH_SIZE = 1 # We'll use dynamic axes to handle variable batch sizes
        
        model = RankingModel(embedding_dim=EMBEDDING_DIM)
        model.eval() # Set model to evaluation mode
    
        # Create dummy inputs that match the expected dimensions
        dummy_user_embedding = torch.randn(BATCH_SIZE, EMBEDDING_DIM, requires_grad=False)
        dummy_item_embedding = torch.randn(BATCH_SIZE, EMBEDDING_DIM, requires_grad=False)
    
        # Define input and output names
        input_names = ["user_embedding", "item_embedding"]
        output_names = ["score"]
    
        print("Exporting model to ONNX...")
        torch.onnx.export(
            model,
            (dummy_user_embedding, dummy_item_embedding),
            "ranker.onnx",
            input_names=input_names,
            output_names=output_names,
            opset_version=14,
            dynamic_axes={
                'user_embedding': {0: 'batch_size'},
                'item_embedding': {0: 'batch_size'},
                'score': {0: 'batch_size'}
            },
            verbose=True
        )
        print("Model exported to ranker.onnx")
    
    if __name__ == "__main__":
        export_model_to_onnx()

    Critical Implementation Detail: The dynamic_axes parameter is non-negotiable for production. It tells the ONNX exporter that the first dimension (the batch size) is variable. Without this, our model would be fixed to whatever batch size we used for the dummy input (BATCH_SIZE = 1). In our edge function, we will want to score multiple items for a single user simultaneously, so we need to process a variable number of item embeddings. This configuration allows the ONNX Runtime to accept an input tensor of shape [N, 32], where N can be any positive integer.

    Running python model.py will produce a ranker.onnx file. This is our portable model artifact.
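
Before deploying anything to the edge, it is worth sanity-checking the export locally. The sketch below is not part of the article's pipeline; it assumes onnxruntime-node is installed and simply confirms that the dynamic batch axis accepts a batch size other than 1.

typescript
// check_model.ts - quick local sanity check for the exported model (illustrative helper)
import * as ort from 'onnxruntime-node';

async function main() {
    const session = await ort.InferenceSession.create('ranker.onnx');

    const batchSize = 5;      // anything other than the export-time batch size of 1
    const embeddingDim = 32;

    // Random embeddings just to exercise the graph
    const userData = Float32Array.from({ length: batchSize * embeddingDim }, () => Math.random());
    const itemData = Float32Array.from({ length: batchSize * embeddingDim }, () => Math.random());

    const results = await session.run({
        user_embedding: new ort.Tensor('float32', userData, [batchSize, embeddingDim]),
        item_embedding: new ort.Tensor('float32', itemData, [batchSize, embeddingDim]),
    });

    // If dynamic_axes was set correctly, we get one score per row: dims [5, 1]
    console.log('score dims:', results.score.dims);
}

main().catch(console.error);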


    Part 2: Compiling ONNX Runtime for WASM/WASI

This is the most complex and least-documented part of the process. We cannot simply npm install the stock ONNX Runtime package built for the browser or Node.js; we need a WebAssembly build that can be loaded and instantiated inside an edge worker runtime. This requires compiling the C++ source code.

    The easiest way to ensure a reproducible build environment is using Docker.

    The Build Script and Dockerfile

    First, create a build script, build_onnx_wasm.sh, to encapsulate the CMake commands.

    bash
    #!/bin/bash
    set -e
    
    # Clone the ONNX Runtime repository
    # We pin to a specific version for reproducibility
    ORT_VERSION="1.14.1"
    
    git clone --single-branch --branch v${ORT_VERSION} --depth 1 https://github.com/microsoft/onnxruntime.git
    cd onnxruntime
    
    # Run the build script
    # This script is provided by the ONNX Runtime team
    # Key flags:
    # --build_wasm: Enables the WebAssembly build
    # --disable_wasm_exception_catching: Reduces binary size. We'll handle errors in JS.
    # --disable_rtti: Disables Run-Time Type Information for smaller binary.
    # --minimal_build: Only include ops required by our model.
    # --config Release: Standard release optimizations.
    # --add_wasm_flags "-s MODULARIZE=1 -s EXPORT_ES6=1": Crucial for modern JS module compatibility
    
    ./build.sh --config Release --build_wasm --disable_wasm_exception_catching --disable_rtti --minimal_build --parallel --add_wasm_flags "-s MODULARIZE=1 -s EXPORT_ES6=1"
    
    # The output will be in build/js/packages/onnxruntime-web/dist/
# We are interested in ort-wasm.wasm and its JS wrapper (ort-wasm.min.js).

    Now, let's create a Dockerfile to execute this script using the correct Emscripten SDK version.

    dockerfile
    # Use the Emscripten SDK image which has all the C++ -> WASM toolchains
    FROM emscripten/emsdk:3.1.25
    
    WORKDIR /src
    
    # Copy the build script into the container
    COPY build_onnx_wasm.sh .
    RUN chmod +x build_onnx_wasm.sh
    
    # Run the build
    RUN ./build_onnx_wasm.sh
    
    # This is a multi-stage build. The final stage will just copy the artifacts.
    FROM alpine:latest
    WORKDIR /artifacts
    
    # Copy the essential build artifacts from the builder stage
    COPY --from=0 /src/onnxruntime/build/js/packages/onnxruntime-web/dist/ort-wasm.wasm .
    COPY --from=0 /src/onnxruntime/build/js/packages/onnxruntime-web/dist/ort-wasm.min.js .
    
    # The final artifacts will be in the /artifacts directory of this image

    To build:

  • Save the files as build_onnx_wasm.sh and Dockerfile.
  • Build the image: docker build -t onnx-wasm-builder .
  • Create a container from it: docker create --name onnx-artifacts onnx-wasm-builder
  • Copy the artifacts out: docker cp onnx-artifacts:/artifacts .

You will now have a local artifacts directory containing ort-wasm.wasm and ort-wasm.min.js. These are the core components of our edge inference engine.

    Note on Minimal Builds: For a true production build, you would use the --include_ops_by_config flag with a configuration file listing only the ONNX operators your ranker.onnx model uses (e.g., MatMul, Add, Relu). This can dramatically shrink the WASM binary size. You can generate this config file using ONNX Runtime's Python tools.


    Part 3: Implementing the Edge Function on Cloudflare Workers

    Now we'll assemble the pieces into a functioning Cloudflare Worker. We'll use Wrangler for deployment.

    Project Setup

    bash
    npm create cloudflare@latest edge-inference-worker -- --type=module-worker
    cd edge-inference-worker
    npm install onnxruntime-common # Provides type definitions and tensor helpers
  • Create a vendor directory and place your compiled ort-wasm.min.js and ort-wasm.wasm inside it.
  • Place your ranker.onnx model file in the root directory.
  • Modify your wrangler.toml to include the WASM and model files:
toml
    name = "edge-inference-worker"
    main = "src/index.ts"
    compatibility_date = "2023-10-30"
    
    [vars]
    ENVIRONMENT = "development"
    
    # Make the WASM module available to the worker
[wasm_modules]
ort_wasm = "vendor/ort-wasm.wasm"
    
    # Make the ONNX model available as a binding
    [[r2_buckets]]
    binding = "MODEL_BUCKET"
    bucket_name = "my-ml-models"
    
    # For local dev, we can use a site binding to load the model
    [site]
    bucket = "."

    For production, you should upload ranker.onnx to a service like Cloudflare R2 and use the r2_buckets binding. For local development with wrangler dev, the [site] binding allows us to load the file directly from the project directory.
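
One way to keep the worker code identical across environments is to hide that storage choice behind a small helper. The following is a sketch only, built on the Env bindings shown in the worker code below (MODEL_BUCKET for R2 in production, __STATIC_CONTENT for the local [site] binding):

typescript
// Sketch: resolve the model bytes from R2 in production, falling back to the
// [site] static content binding during wrangler dev.
async function loadModelBytes(env: Env, key: string = 'ranker.onnx'): Promise<ArrayBuffer> {
    // Prefer R2 when the binding is available (production)
    if (env.MODEL_BUCKET) {
        const object = await env.MODEL_BUCKET.get(key);
        if (object) {
            return await object.arrayBuffer();
        }
    }

    // Local development: the [site] bucket exposes files through a KV namespace
    const local = await env.__STATIC_CONTENT.get(key, 'arrayBuffer');
    if (!local) {
        throw new Error(`Model ${key} not found in R2 or static content`);
    }
    return local;
}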

    The Worker Implementation

    Here is the core logic in src/index.ts. This code is dense and annotated with production considerations.

    typescript
    // src/index.ts
    import { Tensor, InferenceSession } from 'onnxruntime-common';
    import ortInit from '../vendor/ort-wasm.min.js';
    
    // Import the WASM module we defined in wrangler.toml
    import ortWasmModule from 'ort_wasm';
    
    // Define the environment bindings for TypeScript
    export interface Env {
        // This binding gives us access to the .onnx file
        __STATIC_CONTENT: KVNamespace;
        // In production, you would use R2
        MODEL_BUCKET: R2Bucket;
    }
    
    // --- Singleton Pattern for Caching --- 
    // These variables are defined in the global scope of the worker.
    // They will persist between requests on a warm instance, avoiding
    // the expensive re-initialization cost.
    let ort: any;
    let session: InferenceSession;
    let modelArrayBuffer: ArrayBuffer;
    
    async function initializeOrt(env: Env) {
        if (ort && session) {
            // Already initialized on this warm worker
            return;
        }
    
        console.log('Performing cold start initialization...');
        const start = Date.now();
    
        // 1. Initialize the ONNX Runtime
        // The JS wrapper needs to be initialized with the WASM module.
        ort = await ortInit({ wasm: { module: ortWasmModule } });
        // Disable multithreading for serverless environments
        ort.env.wasm.numThreads = 1;
        console.log(`ORT initialized in ${Date.now() - start}ms`);
    
        // 2. Load the model data
        // In local dev with `[site]`, files are served via a KV namespace.
        // In production, this would be `env.MODEL_BUCKET.get('ranker.onnx')`.
        const modelResponse = await env.__STATIC_CONTENT.get('ranker.onnx', 'arrayBuffer');
        if (!modelResponse) {
            throw new Error('Model not found!');
        }
        modelArrayBuffer = modelResponse;
        console.log(`Model loaded in ${Date.now() - start}ms`);
    
        // 3. Create the inference session
        // This parses the model and prepares it for execution.
        // It's a key operation to cache.
        session = await InferenceSession.create(modelArrayBuffer, {
            executionProviders: ['wasm'],
            graphOptimizationLevel: 'all',
        });
    console.log(`Session created. Total cold start: ${Date.now() - start}ms`);
    }
    
    export default {
        async fetch(request: Request, env: Env, ctx: ExecutionContext): Promise<Response> {
            try {
                // Initialization is deferred until the first request
                await initializeOrt(env);
    
                // --- Example Scenario: Re-ranking a list of items ---
                // Assume we get a user ID and a list of item IDs to rank
                const { userId, itemIds } = await request.json<{ userId: string; itemIds: string[] }>();
    
                // In a real app, you would fetch these embeddings from a low-latency key-value store
                // like Cloudflare KV or Redis.
                const userEmbedding = await fetchUserEmbedding(userId);
                const itemEmbeddings = await fetchItemEmbeddings(itemIds);
    
                const numItems = itemIds.length;
    
                // Prepare model inputs. The user embedding is repeated for each item.
                const userEmbeddingTiled = new Float32Array(numItems * 32);
                for (let i = 0; i < numItems; i++) {
                    userEmbeddingTiled.set(userEmbedding, i * 32);
                }
                
                // The item embeddings are flattened into a single array.
                const itemEmbeddingsFlattened = new Float32Array(numItems * 32);
                itemEmbeddings.forEach((emb, i) => {
                    itemEmbeddingsFlattened.set(emb, i * 32);
                });
    
                // Create ONNX Runtime Tensors
                const userTensor = new Tensor('float32', userEmbeddingTiled, [numItems, 32]);
                const itemTensor = new Tensor('float32', itemEmbeddingsFlattened, [numItems, 32]);
    
                const feeds = {
                    user_embedding: userTensor,
                    item_embedding: itemTensor,
                };
    
                // --- Run Inference --- 
                const inferenceStart = Date.now();
                const results = await session.run(feeds);
                const inferenceTime = Date.now() - inferenceStart;
    
                const scores = results.score.data as Float32Array;
    
                // Post-process: Combine items with their scores and sort
                const rankedItems = itemIds
                    .map((id, i) => ({ id, score: scores[i] }))
                    .sort((a, b) => b.score - a.score);
    
                return new Response(JSON.stringify({ rankedItems, inferenceTime }), {
                    headers: { 'Content-Type': 'application/json' },
                });
    
            } catch (e: any) {
                console.error(e);
                return new Response(`Error during inference: ${e.message}`, { status: 500 });
            }
        },
    };
    
    // --- Mock data fetching functions ---
    async function fetchUserEmbedding(userId: string): Promise<Float32Array> {
        // In production, fetch from KV/Upstash
        return new Float32Array(32).fill(0.5);
    }
    
    async function fetchItemEmbeddings(itemIds: string[]): Promise<Float32Array[]> {
    // In production, fetch these from KV (e.g., parallel get() calls or one packed blob per list)
        return itemIds.map(() => new Float32Array(32).map(() => Math.random()));
    }

    Key Production Patterns in this Code:

  • Singleton Caching: ort, session, and modelArrayBuffer are declared in the global scope. The initializeOrt function acts as a guard, only running its expensive logic once per worker instance. Subsequent requests to the same (now warm) instance will skip this entire block, saving 50-200ms. (A promise-based refinement for concurrent cold-start requests is sketched just after this list.)
  • Efficient Tensor Creation: We construct Float32Arrays and create Tensor objects directly. This avoids any intermediate JSON serialization/deserialization and minimizes data copying between the JavaScript heap and the WASM linear memory.
  • Batch Inference: We process all items for a given request in a single session.run() call. This is vastly more efficient than looping and calling run() for each item, as it leverages the parallelization capabilities of the underlying ONNX runtime and reduces overhead.
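
One caveat with the guard in initializeOrt: if several requests land on a cold instance at once, each of them can start the expensive initialization before the first finishes. A common refinement, sketched below rather than taken from the code above, is to cache the initialization promise itself so concurrent requests await the same work:

typescript
// Sketch: promise-based guard so concurrent cold-start requests share one initialization
let initPromise: Promise<void> | undefined;

function ensureOrtInitialized(env: Env): Promise<void> {
    // The first caller kicks off initializeOrt(); everyone else awaits the same promise.
    if (!initPromise) {
        initPromise = initializeOrt(env).catch((err) => {
            // Reset on failure so a later request can retry initialization
            initPromise = undefined;
            throw err;
        });
    }
    return initPromise;
}

// In the fetch handler, replace `await initializeOrt(env)` with:
// await ensureOrtInitialized(env);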

Part 4: Advanced Optimization and Performance Tuning

    Getting the basic implementation working is one thing; making it production-ready for high-throughput, low-latency scenarios is another.

    Cold Start Analysis and Mitigation

    A cold start for our worker involves:

    • JS Worker script parsing (~5-10ms)
    • WASM module compilation and instantiation (~30-100ms, depending on size)
    • Model loading from R2/KV (~5-20ms)
    • ONNX session creation (model parsing) (~10-50ms)

    Total cold start latency can easily be 150ms+. The biggest lever we can pull to reduce this is shrinking the WASM binary and the ONNX model.

    Model Quantization (FP32 to INT8)

    Quantization is the process of converting a model's weights from 32-bit floating-point numbers (FP32) to 8-bit integers (INT8). This has two major benefits:

  • Model Size Reduction: The .onnx file becomes roughly 4x smaller.
  • Faster Inference: Integer arithmetic can be faster than floating-point math, although the actual speedup depends on the runtime and the hardware executing the WASM.

    We can use the onnxruntime.quantization Python library to do this.

    python
    # quantize.py
    import onnx
    from onnxruntime.quantization import quantize_dynamic, QuantType
    
    model_fp32 = 'ranker.onnx'
    model_quant = 'ranker.quant.onnx'
    
    print("Starting quantization...")
    quantize_dynamic(model_fp32,
                     model_quant,
                     weight_type=QuantType.QInt8)
    
    print(f"Quantized model saved to {model_quant}")

    Running this script will generate ranker.quant.onnx. You can expect to see the file size drop from ~80KB to ~20KB. This smaller model will load faster from storage and be parsed more quickly by the InferenceSession.create call, directly reducing cold start time. The impact on model accuracy is usually negligible for many ranking and recommendation tasks.
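
To convince yourself that the accuracy loss is acceptable for your ranking task, you can compare the two artifacts offline before shipping the quantized model. A rough sketch using onnxruntime-node (the file names and the mean-absolute-difference metric are illustrative choices, not part of the pipeline above):

typescript
// compare_models.ts - rough offline parity check between the FP32 and INT8 models
import * as ort from 'onnxruntime-node';

async function score(modelPath: string, user: Float32Array, items: Float32Array, n: number) {
    const session = await ort.InferenceSession.create(modelPath);
    const results = await session.run({
        user_embedding: new ort.Tensor('float32', user, [n, 32]),
        item_embedding: new ort.Tensor('float32', items, [n, 32]),
    });
    return Array.from(results.score.data as Float32Array);
}

async function main() {
    const n = 100;
    const user = Float32Array.from({ length: n * 32 }, () => Math.random());
    const items = Float32Array.from({ length: n * 32 }, () => Math.random());

    const fp32 = await score('ranker.onnx', user, items, n);
    const int8 = await score('ranker.quant.onnx', user, items, n);

    // Mean absolute difference in raw scores; for ranking, you may care more
    // about whether the induced ordering changes than about the exact values.
    const mad = fp32.reduce((acc, v, i) => acc + Math.abs(v - int8[i]), 0) / n;
    console.log('mean absolute score difference:', mad);
}

main().catch(console.error);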

    Benchmarking: Cold vs. Warm Performance

    Let's analyze the expected performance:

Cold Start Request:

  • initializeOrt(): ~150ms
  • Data fetching (embeddings): ~15ms
  • Tensor preparation: ~1ms
  • session.run(): ~10ms
  • Total: ~176ms

Warm Start Request:

  • initializeOrt(): ~0ms (skipped)
  • Data fetching (embeddings): ~15ms
  • Tensor preparation: ~1ms
  • session.run(): ~10ms
  • Total: ~26ms

    The difference is stark. The goal is to maximize the warm hit rate, which Cloudflare and other providers do automatically based on traffic patterns. Using a quantized model might reduce the initializeOrt() time to under 80ms, a significant improvement.

    Edge Case: Model Versioning and A/B Testing

    How do you roll out a new model without downtime? You can't just replace the model file, as existing warm workers might have sessions open with the old model. A robust pattern involves versioning.

  • Store models in R2 with versioned keys (e.g., ranker_v1.quant.onnx, ranker_v2.quant.onnx).
  • Store the "active" version configuration in a central KV store (e.g., a key model_config with value {"default": "v1", "beta_users": "v2"}).
  • Modify initializeOrt to first read this config from KV.
  • The worker then constructs the model key based on the request (e.g., does the user have a beta cookie?) and loads the appropriate model from R2.
typescript
// Simplified A/B testing sketch (initializeOrt now also receives the request,
// and the single `session` global becomes a map: version -> InferenceSession)

// 1. Read the active-version config from a low-TTL KV key (CONFIG_KV binding)
const modelConfig = await env.CONFIG_KV.get('model_config', 'json');

// 2. Determine the model version for this request
const useBeta = request.headers.get('X-Use-Beta-Model') === 'true';
const modelVersion = useBeta ? modelConfig.beta_users : modelConfig.default;
const modelKey = `ranker_${modelVersion}.quant.onnx`;

// 3. Check whether a session for this version is already cached on this warm worker
if (sessions[modelVersion]) {
    return;
}

// 4. Load the specific model version from R2 and create a new session
const modelObject = await env.MODEL_BUCKET.get(modelKey);
if (!modelObject) {
    throw new Error(`Model ${modelKey} not found in R2`);
}
sessions[modelVersion] = await InferenceSession.create(await modelObject.arrayBuffer(), {
    executionProviders: ['wasm'],
    graphOptimizationLevel: 'all',
});

    This pattern allows for seamless, zero-downtime model updates and sophisticated experimentation directly at the edge.
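
The X-Use-Beta-Model header is convenient for testing, but a real experiment usually needs a stable, server-side cohort assignment rather than a client-controlled flag. One possible sketch, assuming a hypothetical beta_percentage field is added to the KV config:

typescript
// Sketch: deterministic percentage rollout based on a hash of the user ID.
// FNV-1a is used here only because it is tiny and dependency-free.
function userBucket(userId: string): number {
    let hash = 0x811c9dc5;
    for (let i = 0; i < userId.length; i++) {
        hash ^= userId.charCodeAt(i);
        hash = Math.imul(hash, 0x01000193);
    }
    // Map the unsigned 32-bit hash into 0..99
    return (hash >>> 0) % 100;
}

// e.g. modelConfig = { default: 'v1', beta_users: 'v2', beta_percentage: 10 }
function pickModelVersion(userId: string, modelConfig: any): string {
    return userBucket(userId) < modelConfig.beta_percentage
        ? modelConfig.beta_users
        : modelConfig.default;
}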


    Part 5: Observability and Security

    Monitoring Performance

    Your edge function is now a critical piece of infrastructure. You must monitor it. Wrap key operations in timers and log the structured data.

    typescript
    // Inside the fetch handler
const timings: Record<string, number> = {};
    
    timings.start = Date.now();
    await initializeOrt(env);
    timings.init_complete = Date.now();
    
    // ... run inference ...
    timings.inference_complete = Date.now();
    
    const logData = {
        userId,
        numItems,
        cold_start: timings.init_complete - timings.start > 10, // A simple heuristic
        init_duration: timings.init_complete - timings.start,
        inference_duration: timings.inference_complete - timings.init_complete,
        // ... other metadata
    };
    
    // Use console.log with JSON for structured logging services
    console.log(JSON.stringify(logData));

    Services like Cloudflare Analytics Engine, Datadog, or Logflare can ingest these structured logs, allowing you to build dashboards monitoring p95/p99 inference latency, cold start frequency, and error rates.
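
On Cloudflare specifically, the same data can be pushed to Workers Analytics Engine instead of relying solely on log scraping. A sketch, assuming a hypothetical analytics_engine_datasets binding named INFERENCE_METRICS is added to wrangler.toml:

typescript
// Sketch: pushing per-request metrics to Workers Analytics Engine.
// Assumes wrangler.toml contains:
//   [[analytics_engine_datasets]]
//   binding = "INFERENCE_METRICS"
export interface EnvWithMetrics extends Env {
    INFERENCE_METRICS: AnalyticsEngineDataset;
}

function reportMetrics(env: EnvWithMetrics, logData: {
    userId: string;
    numItems: number;
    cold_start: boolean;
    init_duration: number;
    inference_duration: number;
}) {
    env.INFERENCE_METRICS.writeDataPoint({
        // Low-cardinality labels go into blobs, numeric metrics into doubles
        blobs: [logData.cold_start ? 'cold' : 'warm'],
        doubles: [logData.numItems, logData.init_duration, logData.inference_duration],
        indexes: [logData.userId], // sampling key; keep cardinality in mind
    });
}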

    Security Considerations

    The WASM sandbox provides strong isolation between the host (the Worker runtime) and our inference code. This prevents a bug in the ONNX Runtime from compromising the entire edge node. However, you are still responsible for:

  • Input Validation: Sanitize all inputs before creating tensors. An attacker could provide malformed data (e.g., an array of strings instead of numbers) to crash the worker. Ensure your data fetching and pre-processing steps are robust; a minimal validation sketch follows this list.
  • Model Provenance: Only deploy models that you have trained and exported from a trusted environment. Loading a malicious .onnx file could potentially exploit a vulnerability in the ONNX Runtime parser, even within the WASM sandbox.
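
As a concrete illustration of the input validation point above, here is a minimal sketch that checks the request shape used in Part 3 before any tensors are built (the MAX_ITEMS cap is an arbitrary illustrative value):

typescript
// Sketch: validate the inference request before any tensors are built.
const MAX_ITEMS = 200; // arbitrary illustrative limit

function validateRankingRequest(body: unknown): { userId: string; itemIds: string[] } {
    if (typeof body !== 'object' || body === null) {
        throw new Error('Request body must be a JSON object');
    }
    const { userId, itemIds } = body as { userId?: unknown; itemIds?: unknown };

    if (typeof userId !== 'string' || userId.length === 0) {
        throw new Error('userId must be a non-empty string');
    }
    if (!Array.isArray(itemIds) || itemIds.length === 0 || itemIds.length > MAX_ITEMS) {
        throw new Error(`itemIds must be a non-empty array of at most ${MAX_ITEMS} entries`);
    }
    if (!itemIds.every((id) => typeof id === 'string')) {
        throw new Error('Every itemId must be a string');
    }
    return { userId, itemIds };
}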

    Conclusion

    Executing ML inference at the edge with WASM and ONNX Runtime is a powerful, advanced technique that moves beyond the limitations of traditional centralized architectures. While the initial setup, particularly the custom compilation of the ONNX Runtime, is complex, the payoff is immense: a dramatic reduction in personalization latency, leading to a tangibly better user experience.

    By embracing patterns like singleton caching for warm instances, model quantization, and versioned A/B testing, you can build a system that is not only fast but also robust, scalable, and maintainable. This architecture represents a shift in how we think about deploying machine learning, treating inference not as a remote procedure call but as a local, high-performance function integrated directly into the content delivery path.
