Edge Inference: Real-Time Personalization with WASM and ONNX Runtime
The Latency Tax of Centralized ML Personalization
In modern web applications, real-time personalization is no longer a luxury—it's an expectation. Whether it's re-ranking a product feed, suggesting related articles, or customizing a user interface, the speed of this personalization directly impacts user engagement. The conventional architecture involves a client request that triggers a server-side API call to a separate, centralized machine learning microservice. This pattern, while simple to reason about, introduces a significant latency tax.
Consider the network path:
- Client → CDN Edge
- CDN Edge → Application Origin Server
- Application Origin → ML Inference Service (often in a different VPC or even region)
- ML Inference Service → Application Origin
- Application Origin → CDN Edge
- CDN Edge → Client
Each hop adds tens, if not hundreds, of milliseconds. For UI-blocking personalization, this round-trip penalty is unacceptable. Caching strategies at the origin or CDN are often ineffective because the response is unique per-user or per-session. The ideal solution is to perform the inference as close to the user as possible: at the CDN edge.
This article presents a production-ready pattern for achieving this with WebAssembly (WASM) and the ONNX Runtime. We will bypass high-level, black-box AI services and build a custom, high-performance inference pipeline that runs within the tight constraints of an edge worker environment like Cloudflare Workers or Vercel Edge Functions. We'll tackle the non-trivial challenges of compiling C++ code to WASM, optimizing model size and performance, and managing the model lifecycle at the edge.
Architectural Overview: Inference on the Edge
Our target architecture looks like this:
1. Offline: Train the model, export it to ONNX, and compile the ONNX Runtime to WebAssembly (.wasm). This is a critical one-time step.
2. Deployment: The .wasm runtime and the .onnx model file are deployed alongside our edge function code.
3. Request time: For each incoming request, the edge function:
a. Instantiates the WASM-based ONNX Runtime.
b. Loads the ONNX model into the runtime.
c. Pre-processes request data (e.g., user ID, product features from a fetched API response) into tensors.
d. Executes the model to get inference results (e.g., a new product ranking).
e. Post-processes the results and modifies the response to the user.
This entire process occurs within the same edge location that serves the initial request, reducing latency from 200-500ms to a mere 20-50ms.
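In code, those request-time steps map onto a fetch handler with roughly the following shape. This is only a sketch with placeholder helper names (ensureRuntimeAndModel, buildInputTensors, runInference, buildPersonalizedResponse are not real APIs); the real implementation is built out in Part 3.
// Rough shape of the edge function; helper names are placeholders, not real APIs.
export default {
  async fetch(request: Request, env: Env): Promise<Response> {
    await ensureRuntimeAndModel(env);                       // steps a + b: instantiate the WASM runtime, load the .onnx model (cached when warm)
    const tensors = await buildInputTensors(request, env);  // step c: pre-process request data into tensors
    const scores = await runInference(tensors);             // step d: execute the model
    return buildPersonalizedResponse(scores);               // step e: post-process and modify the response
  },
};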
Part 1: Model Preparation and ONNX Export
We need a model that is small enough to be loaded quickly in an edge environment but complex enough to provide meaningful personalization. A simple feed-forward neural network for ranking is a perfect candidate. Let's assume we're building a system to re-rank a list of articles for a user based on their embedding and the articles' embeddings.
The Ranking Model in PyTorch
Our model will take a concatenated user and item embedding and output a single score. Higher scores indicate a better match.
# model.py
import torch
import torch.nn as nn
class RankingModel(nn.Module):
def __init__(self, embedding_dim: int):
super(RankingModel, self).__init__()
# Example: User embedding dim = 32, Item embedding dim = 32 -> input_dim = 64
input_dim = embedding_dim * 2
self.layer_stack = nn.Sequential(
nn.Linear(input_dim, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 1)
)
def forward(self, user_embedding, item_embedding):
# Concatenate user and item embeddings
x = torch.cat((user_embedding, item_embedding), dim=1)
return self.layer_stack(x)
# --- Training would happen here ---
# For this example, we'll just instantiate and export a dummy model.
def export_model_to_onnx():
EMBEDDING_DIM = 32
BATCH_SIZE = 1 # We'll use dynamic axes to handle variable batch sizes
model = RankingModel(embedding_dim=EMBEDDING_DIM)
model.eval() # Set model to evaluation mode
# Create dummy inputs that match the expected dimensions
dummy_user_embedding = torch.randn(BATCH_SIZE, EMBEDDING_DIM, requires_grad=False)
dummy_item_embedding = torch.randn(BATCH_SIZE, EMBEDDING_DIM, requires_grad=False)
# Define input and output names
input_names = ["user_embedding", "item_embedding"]
output_names = ["score"]
print("Exporting model to ONNX...")
torch.onnx.export(
model,
(dummy_user_embedding, dummy_item_embedding),
"ranker.onnx",
input_names=input_names,
output_names=output_names,
opset_version=14,
dynamic_axes={
'user_embedding': {0: 'batch_size'},
'item_embedding': {0: 'batch_size'},
'score': {0: 'batch_size'}
},
verbose=True
)
print("Model exported to ranker.onnx")
if __name__ == "__main__":
export_model_to_onnx()
Critical Implementation Detail: The dynamic_axes parameter is non-negotiable for production. It tells the ONNX exporter that the first dimension (the batch size) is variable. Without this, our model would be fixed to whatever batch size we used for the dummy input (BATCH_SIZE = 1). In our edge function, we will want to score multiple items for a single user simultaneously, so we need to process a variable number of item embeddings. This configuration allows the ONNX Runtime to accept an input tensor of shape [N, 32], where N can be any positive integer.
Running python model.py will produce a ranker.onnx file. This is our portable model artifact.
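Before touching the edge runtime, it's worth a quick local sanity check that the export and the dynamic batch axis behave as expected. The sketch below assumes onnxruntime-node is installed in a throwaway Node.js project; the file name verify_export.ts is arbitrary.
// verify_export.ts (assumes: npm install onnxruntime-node)
import * as ort from 'onnxruntime-node';

async function main() {
  const session = await ort.InferenceSession.create('ranker.onnx');

  // Score 3 items for one user; both inputs are [N, 32] thanks to dynamic_axes
  const numItems = 3;
  const dim = 32;
  const user = new Float32Array(numItems * dim).fill(0.5);
  const items = Float32Array.from({ length: numItems * dim }, () => Math.random());

  const feeds = {
    user_embedding: new ort.Tensor('float32', user, [numItems, dim]),
    item_embedding: new ort.Tensor('float32', items, [numItems, dim]),
  };

  const results = await session.run(feeds);
  // Expect one score per item
  console.log('scores:', Array.from(results.score.data as Float32Array));
}

main().catch(console.error);
If the model had been exported without dynamic_axes, this run would fail for any batch size other than the one used for the dummy inputs.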
Part 2: Compiling ONNX Runtime for WASM/WASI
This is the most complex and least-documented part of the process. We cannot simply npm install the Node.js build of ONNX Runtime; we need a WebAssembly build of the C++ runtime that can run inside the isolate-based environments used by edge platforms. This requires compiling the C++ source code with the Emscripten toolchain.
The easiest way to ensure a reproducible build environment is using Docker.
The Build Script and Dockerfile
First, create a build script, build_onnx_wasm.sh, to encapsulate the CMake commands.
#!/bin/bash
set -e
# Clone the ONNX Runtime repository
# We pin to a specific version for reproducibility
ORT_VERSION="1.14.1"
git clone --single-branch --branch v${ORT_VERSION} --depth 1 https://github.com/microsoft/onnxruntime.git
cd onnxruntime
# Run the build script
# This script is provided by the ONNX Runtime team
# Key flags:
# --build_wasm: Enables the WebAssembly build
# --disable_wasm_exception_catching: Reduces binary size. We'll handle errors in JS.
# --disable_rtti: Disables Run-Time Type Information for smaller binary.
# --minimal_build: Only include ops required by our model.
# --config Release: Standard release optimizations.
# --add_wasm_flags "-s MODULARIZE=1 -s EXPORT_ES6=1": Crucial for modern JS module compatibility
./build.sh --config Release --build_wasm --disable_wasm_exception_catching --disable_rtti --minimal_build --parallel --add_wasm_flags "-s MODULARIZE=1 -s EXPORT_ES6=1"
# The output will be in build/js/packages/onnxruntime-web/dist/
# We are interested in ort-wasm.wasm and its JS wrapper, ort-wasm.min.js.
Now, let's create a Dockerfile to execute this script using the correct Emscripten SDK version.
# Use the Emscripten SDK image which has all the C++ -> WASM toolchains
FROM emscripten/emsdk:3.1.25
WORKDIR /src
# Copy the build script into the container
COPY build_onnx_wasm.sh .
RUN chmod +x build_onnx_wasm.sh
# Run the build
RUN ./build_onnx_wasm.sh
# This is a multi-stage build. The final stage will just copy the artifacts.
FROM alpine:latest
WORKDIR /artifacts
# Copy the essential build artifacts from the builder stage
COPY --from=0 /src/onnxruntime/build/js/packages/onnxruntime-web/dist/ort-wasm.wasm .
COPY --from=0 /src/onnxruntime/build/js/packages/onnxruntime-web/dist/ort-wasm.min.js .
# The final artifacts will be in the /artifacts directory of this image
To build:
1. Place build_onnx_wasm.sh and the Dockerfile in the same directory.
2. Build the image: docker build -t onnx-wasm-builder .
3. Create a container from it: docker create --name onnx-artifacts onnx-wasm-builder
4. Copy out the artifacts: docker cp onnx-artifacts:/artifacts .
You will now have a local artifacts directory containing ort-wasm.wasm and ort-wasm.min.js. These are the core components of our edge inference engine.
Note on Minimal Builds: For a true production build, you would use the --include_ops_by_config flag with a configuration file listing only the ONNX operators your ranker.onnx model uses (e.g., MatMul, Add, Relu). This can dramatically shrink the WASM binary size. You can generate this config file using ONNX Runtime's Python tools.
Part 3: Implementing the Edge Function on Cloudflare Workers
Now we'll assemble the pieces into a functioning Cloudflare Worker. We'll use Wrangler for deployment.
Project Setup
npm create cloudflare@latest edge-inference-worker -- --type=module-worker
cd edge-inference-worker
npm install onnxruntime-common # Provides type definitions and tensor helpers
1. Create a vendor directory and place your compiled ort-wasm.min.js and ort-wasm.wasm inside it.
2. Place the ranker.onnx model file in the project root.
3. Update wrangler.toml to include the WASM and model files:
name = "edge-inference-worker"
main = "src/index.ts"
compatibility_date = "2023-10-30"
[vars]
ENVIRONMENT = "development"
# Make the WASM module available to the worker
[[wasm_modules]]
name = "ort_wasm"
path = "vendor/ort-wasm.wasm"
# Make the ONNX model available as a binding
[[r2_buckets]]
binding = "MODEL_BUCKET"
bucket_name = "my-ml-models"
# For local dev, we can use a site binding to load the model
[site]
bucket = "."
For production, you should upload ranker.onnx to a service like Cloudflare R2 and use the r2_buckets binding. For local development with wrangler dev, the [site] binding allows us to load the file directly from the project directory.
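One way to keep the worker indifferent to where the bytes live is to hide the lookup behind a small helper. This is a minimal sketch under the bindings above; loadModelBytes is our own name, not part of any API, and the fallback mirrors the __STATIC_CONTENT access used later in the worker.
// Sketch: resolve model bytes from R2 in production, or from the [site]
// static-content KV namespace during local development.
async function loadModelBytes(env: Env, key: string): Promise<ArrayBuffer> {
  if (env.MODEL_BUCKET) {
    const object = await env.MODEL_BUCKET.get(key); // resolves to R2ObjectBody | null
    if (object) {
      return await object.arrayBuffer();
    }
  }
  const bytes = await env.__STATIC_CONTENT.get(key, 'arrayBuffer');
  if (!bytes) {
    throw new Error(`Model ${key} not found in R2 or static content`);
  }
  return bytes;
}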
The Worker Implementation
Here is the core logic in src/index.ts. This code is dense and annotated with production considerations.
// src/index.ts
import { Tensor, InferenceSession } from 'onnxruntime-common';
import ortInit from '../vendor/ort-wasm.min.js';
// Import the WASM module we defined in wrangler.toml
import ortWasmModule from 'ort_wasm';
// Define the environment bindings for TypeScript
export interface Env {
// This binding gives us access to the .onnx file
__STATIC_CONTENT: KVNamespace;
// In production, you would use R2
MODEL_BUCKET: R2Bucket;
}
// --- Singleton Pattern for Caching ---
// These variables are defined in the global scope of the worker.
// They will persist between requests on a warm instance, avoiding
// the expensive re-initialization cost.
let ort: any;
let session: InferenceSession;
let modelArrayBuffer: ArrayBuffer;
async function initializeOrt(env: Env) {
if (ort && session) {
// Already initialized on this warm worker
return;
}
console.log('Performing cold start initialization...');
const start = Date.now();
// 1. Initialize the ONNX Runtime
// The JS wrapper needs to be initialized with the WASM module.
ort = await ortInit({ wasm: { module: ortWasmModule } });
// Disable multithreading for serverless environments
ort.env.wasm.numThreads = 1;
console.log(`ORT initialized in ${Date.now() - start}ms`);
// 2. Load the model data
// In local dev with `[site]`, files are served via a KV namespace.
// In production, this would be `env.MODEL_BUCKET.get('ranker.onnx')`.
const modelResponse = await env.__STATIC_CONTENT.get('ranker.onnx', 'arrayBuffer');
if (!modelResponse) {
throw new Error('Model not found!');
}
modelArrayBuffer = modelResponse;
console.log(`Model loaded in ${Date.now() - start}ms`);
// 3. Create the inference session
// This parses the model and prepares it for execution.
// It's a key operation to cache.
session = await InferenceSession.create(modelArrayBuffer, {
executionProviders: ['wasm'],
graphOptimizationLevel: 'all',
});
console.log(`Session created. Total cold start: ${Date.now() - start}ms`);
}
export default {
async fetch(request: Request, env: Env, ctx: ExecutionContext): Promise<Response> {
try {
// Initialization is deferred until the first request
await initializeOrt(env);
// --- Example Scenario: Re-ranking a list of items ---
// Assume we get a user ID and a list of item IDs to rank
const { userId, itemIds } = await request.json<{ userId: string; itemIds: string[] }>();
// In a real app, you would fetch these embeddings from a low-latency key-value store
// like Cloudflare KV or Redis.
const userEmbedding = await fetchUserEmbedding(userId);
const itemEmbeddings = await fetchItemEmbeddings(itemIds);
const numItems = itemIds.length;
// Prepare model inputs. The user embedding is repeated for each item.
const userEmbeddingTiled = new Float32Array(numItems * 32);
for (let i = 0; i < numItems; i++) {
userEmbeddingTiled.set(userEmbedding, i * 32);
}
// The item embeddings are flattened into a single array.
const itemEmbeddingsFlattened = new Float32Array(numItems * 32);
itemEmbeddings.forEach((emb, i) => {
itemEmbeddingsFlattened.set(emb, i * 32);
});
// Create ONNX Runtime Tensors
const userTensor = new Tensor('float32', userEmbeddingTiled, [numItems, 32]);
const itemTensor = new Tensor('float32', itemEmbeddingsFlattened, [numItems, 32]);
const feeds = {
user_embedding: userTensor,
item_embedding: itemTensor,
};
// --- Run Inference ---
const inferenceStart = Date.now();
const results = await session.run(feeds);
const inferenceTime = Date.now() - inferenceStart;
const scores = results.score.data as Float32Array;
// Post-process: Combine items with their scores and sort
const rankedItems = itemIds
.map((id, i) => ({ id, score: scores[i] }))
.sort((a, b) => b.score - a.score);
return new Response(JSON.stringify({ rankedItems, inferenceTime }), {
headers: { 'Content-Type': 'application/json' },
});
} catch (e: any) {
console.error(e);
return new Response(`Error during inference: ${e.message}`, { status: 500 });
}
},
};
// --- Mock data fetching functions ---
async function fetchUserEmbedding(userId: string): Promise<Float32Array> {
// In production, fetch from KV/Upstash
return new Float32Array(32).fill(0.5);
}
async function fetchItemEmbeddings(itemIds: string[]): Promise<Float32Array[]> {
// In production, use kv.getMultiple()
return itemIds.map(() => new Float32Array(32).map(() => Math.random()));
}
Key Production Patterns in this Code:
* Singleton caching: ort, session, and modelArrayBuffer are declared in the global scope. The initializeOrt function acts as a guard, only running its expensive logic once per worker instance. Subsequent requests to the same (now warm) instance will skip this entire block, saving 50-200ms.
* Direct tensor construction: We build Float32Arrays and create Tensor objects directly. This avoids any intermediate JSON serialization/deserialization and minimizes data copying between the JavaScript heap and the WASM linear memory.
* Batched inference: All items are scored in a single session.run() call. This is vastly more efficient than looping and calling run() for each item, as it amortizes per-call overhead and lets the runtime execute the whole batch in vectorized loops.
Part 4: Advanced Optimization and Performance Tuning
Getting the basic implementation working is one thing; making it production-ready for high-throughput, low-latency scenarios is another.
Cold Start Analysis and Mitigation
A cold start for our worker involves:
- JS Worker script parsing (~5-10ms)
- WASM module compilation and instantiation (~30-100ms, depending on size)
- Model loading from R2/KV (~5-20ms)
- ONNX session creation (model parsing) (~10-50ms)
Total cold start latency can easily be 150ms+. The biggest lever we can pull to reduce this is shrinking the WASM binary and the ONNX model.
Model Quantization (FP32 to INT8)
Quantization is the process of converting a model's weights from 32-bit floating-point numbers (FP32) to 8-bit integers (INT8). This has two major benefits:
* Model Size Reduction: The .onnx file becomes ~4x smaller.
* Faster Inference: Integer arithmetic is often faster than floating-point math, especially on CPUs without specialized vector units.
We can use the onnxruntime.quantization Python library to do this.
# quantize.py
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
model_fp32 = 'ranker.onnx'
model_quant = 'ranker.quant.onnx'
print("Starting quantization...")
quantize_dynamic(model_fp32,
model_quant,
weight_type=QuantType.QInt8)
print(f"Quantized model saved to {model_quant}")
Running this script will generate ranker.quant.onnx. You can expect to see the file size drop from ~80KB to ~20KB. This smaller model will load faster from storage and be parsed more quickly by the InferenceSession.create call, directly reducing cold start time. The impact on model accuracy is usually negligible for many ranking and recommendation tasks.
Benchmarking: Cold vs. Warm Performance
Let's analyze the expected performance:
* Cold Start Request:
* initializeOrt(): ~150ms
* Data fetching (embeddings): ~15ms
* Tensor preparation: ~1ms
* session.run(): ~10ms
* Total: ~176ms
* Warm Start Request:
* initializeOrt(): ~0ms (skipped)
* Data fetching (embeddings): ~15ms
* Tensor preparation: ~1ms
* session.run(): ~10ms
* Total: ~26ms
The difference is stark. The goal is to maximize the warm hit rate, which Cloudflare and other providers do automatically based on traffic patterns. Using a quantized model might reduce the initializeOrt() time to under 80ms, a significant improvement.
Edge Case: Model Versioning and A/B Testing
How do you roll out a new model without downtime? You can't just replace the model file, as existing warm workers might have sessions open with the old model. A robust pattern involves versioning.
1. Store versioned model files in R2 (e.g., ranker_v1.quant.onnx, ranker_v2.quant.onnx).
2. Store a routing configuration in Cloudflare KV (e.g., a key model_config with value {"default": "v1", "beta_users": "v2"}).
3. Modify initializeOrt to first read this config from KV.
4. Decide which version the request should receive (a request header? a beta cookie?) and load the appropriate model from R2.
// Simplified A/B testing logic within initializeOrt
// 1. Read config from a low-TTL KV key
const modelConfig = await env.CONFIG_KV.get('model_config', 'json');
// 2. Determine model version based on request
const useBeta = request.headers.get('X-Use-Beta-Model') === 'true';
const model_version = useBeta ? modelConfig.beta_users : modelConfig.default;
const model_key = `ranker_${model_version}.quant.onnx`;
// 3. Check if the correct session is already cached
// The singleton 'session' object can be a map of version -> InferenceSession
if (sessions[model_version]) {
// Session for this version is already warm
return;
}
// 4. Load the specific model version from R2 and create a new session
const modelObject = await env.MODEL_BUCKET.get(model_key); // R2 returns an R2ObjectBody (or null if missing)
const modelBuffer = await modelObject.arrayBuffer();
sessions[model_version] = await InferenceSession.create(modelBuffer, ...);
This pattern allows for seamless, zero-downtime model updates and sophisticated experimentation directly at the edge.
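For cookie-less experiments, one common complement to the header check above is deterministic bucketing: hash a stable identifier so each user consistently lands in the same cohort. This is our own sketch, not part of the article's pipeline; the config shape and pickModelVersion name are illustrative.
// Sketch: stable A/B assignment by hashing the user ID (FNV-1a).
function pickModelVersion(
  userId: string,
  config: { default: string; beta: string; betaFraction: number }
): string {
  let hash = 0x811c9dc5;
  for (let i = 0; i < userId.length; i++) {
    hash ^= userId.charCodeAt(i);
    hash = Math.imul(hash, 0x01000193);
  }
  // Map the 32-bit hash to [0, 1) and compare against the rollout fraction
  const bucket = (hash >>> 0) / 0x100000000;
  return bucket < config.betaFraction ? config.beta : config.default;
}

// Example: roll the v2 ranker out to 10% of users
// const version = pickModelVersion(userId, { default: 'v1', beta: 'v2', betaFraction: 0.1 });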
Part 5: Observability and Security
Monitoring Performance
Your edge function is now a critical piece of infrastructure. You must monitor it. Wrap key operations in timers and log the structured data.
// Inside the fetch handler
const timings: Record<string, number> = {};
timings.start = Date.now();
await initializeOrt(env);
timings.init_complete = Date.now();
// ... run inference ...
timings.inference_complete = Date.now();
const logData = {
userId,
numItems,
cold_start: timings.init_complete - timings.start > 10, // A simple heuristic
init_duration: timings.init_complete - timings.start,
inference_duration: timings.inference_complete - timings.init_complete,
// ... other metadata
};
// Use console.log with JSON for structured logging services
console.log(JSON.stringify(logData));
Services like Cloudflare Analytics Engine, Datadog, or Logflare can ingest these structured logs, allowing you to build dashboards monitoring p95/p99 inference latency, cold start frequency, and error rates.
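If you route metrics to Cloudflare's Analytics Engine rather than scraping logs, the same fields can be written as a data point. A minimal sketch, assuming an analytics_engine_datasets binding named INFERENCE_METRICS has been added to wrangler.toml:
// Sketch: writing the timing data to an Analytics Engine dataset.
// Assumes wrangler.toml contains:
//   [[analytics_engine_datasets]]
//   binding = "INFERENCE_METRICS"
interface EnvWithMetrics extends Env {
  INFERENCE_METRICS: AnalyticsEngineDataset;
}

function recordInferenceMetrics(
  env: EnvWithMetrics,
  data: { userId: string; numItems: number; cold_start: boolean; init_duration: number; inference_duration: number }
) {
  env.INFERENCE_METRICS.writeDataPoint({
    blobs: [data.cold_start ? 'cold' : 'warm'],                            // low-cardinality label
    doubles: [data.numItems, data.init_duration, data.inference_duration], // numeric metrics
    indexes: [data.userId],                                                // sampling key (at most one index)
  });
}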
Security Considerations
The WASM sandbox provides strong isolation between the host (the Worker runtime) and our inference code. This prevents a bug in the ONNX Runtime from compromising the entire edge node. However, you are still responsible for:
* Input Validation: Sanitize all inputs before creating tensors. An attacker could provide malformed data (e.g., an array of strings instead of numbers) to crash the worker. Ensure your data fetching and pre-processing steps are robust; a small validation sketch follows this list.
* Model Provenance: Only deploy models that you have trained and exported from a trusted environment. Loading a malicious .onnx file could potentially exploit a vulnerability in the ONNX Runtime parser, even within the WASM sandbox.
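Here is a minimal validation sketch for the payload shape used in this article (our own code, not from the original; the 128-character and 500-item caps are arbitrary example limits):
// Sketch: reject malformed payloads before any embeddings are fetched
// or tensors are built. Limits are illustrative, not prescriptive.
function validatePayload(body: unknown): { userId: string; itemIds: string[] } {
  if (typeof body !== 'object' || body === null) {
    throw new Error('Request body must be a JSON object');
  }
  const { userId, itemIds } = body as { userId?: unknown; itemIds?: unknown };
  if (typeof userId !== 'string' || userId.length === 0 || userId.length > 128) {
    throw new Error('userId must be a non-empty string (max 128 chars)');
  }
  if (!Array.isArray(itemIds) || itemIds.length === 0 || itemIds.length > 500) {
    throw new Error('itemIds must be a non-empty array (max 500 items)');
  }
  if (!itemIds.every((id): id is string => typeof id === 'string')) {
    throw new Error('every itemId must be a string');
  }
  return { userId, itemIds };
}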
Conclusion
Executing ML inference at the edge with WASM and ONNX Runtime is a powerful, advanced technique that moves beyond the limitations of traditional centralized architectures. While the initial setup, particularly the custom compilation of the ONNX Runtime, is complex, the payoff is immense: a dramatic reduction in personalization latency, leading to a tangibly better user experience.
By embracing patterns like singleton caching for warm instances, model quantization, and versioned A/B testing, you can build a system that is not only fast but also robust, scalable, and maintainable. This architecture represents a shift in how we think about deploying machine learning, treating inference not as a remote procedure call but as a local, high-performance function integrated directly into the content delivery path.