Low-Latency RAG: Vector Similarity Caching Patterns for Production Q&A

Goh Ling Yong

The New I/O Bottleneck: Vector Search in Real-Time RAG

In modern AI applications, particularly real-time question-answering systems, the Retrieval-Augmented Generation (RAG) pattern has become canonical. While it elegantly grounds Large Language Models (LLMs) in factual, proprietary data, it introduces a new performance bottleneck that is often overlooked in initial development: the latency of the vector similarity search. Senior engineers know that any external I/O call is a potential performance killer, and a query to a vector database is no exception.

This isn't a beginner's guide to RAG. We assume you're already running a RAG pipeline and are now facing the production reality of P95 and P99 latencies that are unacceptable for an interactive user experience. A typical RAG flow looks like this:

User Query -> Embedding Model -> Vector DB Search -> Retrieved Context -> LLM Prompt -> LLM Response

When you profile this pipeline, the LLM's time-to-first-token is a major factor, but the vector search—especially over a large corpus with millions of vectors and complex filtering—can easily add 150ms to 500ms or more. This is the new database query cost, and it's what we will systematically dismantle in this article.

We will not discuss trivial solutions like caching the final LLM response. That's a fragile strategy, easily defeated by trivial rephrasing of the user's query. Instead, we'll focus on caching the most expensive, deterministic part of the retrieval process: the mapping from a query vector to a set of relevant document IDs. The core challenge, of course, is that you can't key a cache on a 768-dimensional floating-point vector. We'll solve this with advanced, production-ready patterns.
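
To make that concrete, here is a minimal sketch (reusing the all-MiniLM-L6-v2 model from the baseline below) of why the embedding vector itself cannot serve as a cache key: two paraphrases of the same question produce highly similar, but not byte-identical, vectors, so any exact-match key derived from them never collides.

python
# naive_key_demo.py -- illustration only: why exact-match keys on embeddings fail
import hashlib
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')
q1 = model.encode("How do I reset my password?")
q2 = model.encode("I forgot my password, what do I do?")

# The two queries are semantically close...
cosine = np.dot(q1, q2) / (np.linalg.norm(q1) * np.linalg.norm(q2))
print(f"cosine similarity: {cosine:.3f}")

# ...but hashing the raw vector bytes yields unrelated keys, so a naive cache always misses.
key1 = hashlib.sha256(np.float32(q1).tobytes()).hexdigest()
key2 = hashlib.sha256(np.float32(q2).tobytes()).hexdigest()
print(key1 == key2)  # False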

Baseline Performance: Quantifying the Problem

Before optimizing, we must measure. Let's establish a simple baseline FastAPI application using a sentence-transformer for embeddings and a local FAISS index (simulating a fast, in-memory vector DB).

python
# baseline_rag_app.py
import time
import faiss
import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# --- Setup ---
app = FastAPI()

# 1. Load a pre-trained model
model = SentenceTransformer('all-MiniLM-L6-v2')
dimension = model.get_sentence_embedding_dimension()

# 2. Create a dummy FAISS index
num_vectors = 1_000_000
dummy_vectors = np.float32(np.random.random((num_vectors, dimension)))
faiss.normalize_L2(dummy_vectors)
index = faiss.IndexFlatL2(dimension)
index.add(dummy_vectors)

# --- API ---
class Query(BaseModel):
    text: str
    top_k: int = 5

@app.post("/rag_baseline")
def retrieve_documents(query: Query):
    # 1. Embed the query
    start_embed = time.perf_counter()
    query_vector = model.encode(query.text, convert_to_tensor=False)
    query_vector = np.float32([query_vector])
    faiss.normalize_L2(query_vector)
    end_embed = time.perf_counter()

    # 2. Search the vector index
    start_search = time.perf_counter()
    distances, indices = index.search(query_vector, query.top_k)
    end_search = time.perf_counter()

    return {
        "timings": {
            "embedding_ms": (end_embed - start_embed) * 1000,
            "vector_search_ms": (end_search - start_search) * 1000,
        },
        "results": {
            "ids": indices[0].tolist(),
            "distances": distances[0].tolist(),
        }
    }

Running a simple load test with hey or k6 against this endpoint reveals the problem. While embedding might take ~20-30ms, the vector search on a 1M vector index, even a simple IndexFlatL2, can take 50-100ms. With a more complex index like IndexIVFPQ or a networked call to a managed service like Pinecone or Weaviate, this search time becomes the dominant latency factor.
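
If you want to capture those percentiles without reaching for hey or k6, a rough sequential probe from Python is enough to make the point. This is a sketch, assuming the app is served locally on port 8000; it is not a substitute for a proper concurrent load test.

python
# latency_probe.py -- rough percentile probe against the baseline endpoint
import time
import numpy as np
import requests

URL = "http://localhost:8000/rag_baseline"  # assumed local dev address
payload = {"text": "How do I reset my password?", "top_k": 5}

latencies_ms = []
for _ in range(200):
    start = time.perf_counter()
    requests.post(URL, json=payload, timeout=10)
    latencies_ms.append((time.perf_counter() - start) * 1000)

print(f"p50: {np.percentile(latencies_ms, 50):.1f} ms")
print(f"p95: {np.percentile(latencies_ms, 95):.1f} ms")
print(f"p99: {np.percentile(latencies_ms, 99):.1f} ms")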

Pattern 1: Approximate Caching with Locality-Sensitive Hashing (LSH)

The fundamental problem is that two semantically similar queries, e.g., "How do I reset my password?" and "I forgot my password, what do I do?", will produce slightly different embedding vectors. A standard key-value cache will see them as two distinct keys, resulting in a cache miss for the second query.

We need a way to make similar vectors map to the same cache key. This is precisely the problem that Locality-Sensitive Hashing (LSH) is designed to solve. LSH is a family of hashing techniques that ensures a higher probability of collision for similar items.
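
For intuition, here is a minimal sign-random-projection (SimHash-style) sketch of that property: project the embedding onto a few random hyperplanes and keep only the sign bits, so nearby vectors usually land in the same bucket while unrelated vectors almost never do. The implementation in the next section uses MinHash via datasketch instead; this snippet only illustrates why similar vectors can share a key.

python
# simhash_sketch.py -- illustrative sign-random-projection hashing
import numpy as np

rng = np.random.default_rng(42)
DIM, NUM_BITS = 384, 16  # 384 matches all-MiniLM-L6-v2; 16 bits => ~65k buckets
hyperplanes = rng.standard_normal((NUM_BITS, DIM))

def simhash_key(vector: np.ndarray) -> int:
    """Pack the signs of NUM_BITS random projections into an integer bucket ID."""
    bits = (hyperplanes @ vector) > 0
    return sum(int(b) << i for i, b in enumerate(bits))

v1 = rng.standard_normal(DIM)
v2 = v1 + 0.05 * rng.standard_normal(DIM)   # small perturbation of v1
v3 = rng.standard_normal(DIM)               # unrelated vector

print(simhash_key(v1) == simhash_key(v2))  # usually True
print(simhash_key(v1) == simhash_key(v3))  # almost always False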

Our strategy:

  • When a query arrives, generate its embedding vector.
  • Apply an LSH function to the vector to produce a set of hash-based signatures (our cache keys).
  • Check an in-memory cache (like Redis) for any of these keys.
  • If a key is found (a cache hit), we retrieve the pre-computed list of document IDs and bypass the expensive vector DB search.
  • If it's a miss, we perform the vector DB search, then store the results in the cache against the LSH-generated keys.

Implementation with `datasketch` and Redis

Let's implement this. We'll use the datasketch library, which provides a convenient MinHashLSH implementation.

python
# lsh_caching_rag_app.py
import time
import faiss
import numpy as np
import redis
import json
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from datasketch import MinHash, MinHashLSH

# --- Setup ---
app = FastAPI()
model = SentenceTransformer('all-MiniLM-L6-v2')
dimension = model.get_sentence_embedding_dimension()

# Dummy FAISS index (same as baseline)
num_vectors = 1_000_000
dummy_vectors = np.float32(np.random.random((num_vectors, dimension)))
faiss.normalize_L2(dummy_vectors)
index = faiss.IndexFlatL2(dimension)
index.add(dummy_vectors)

# Redis connection
r = redis.Redis(decode_responses=True)

# LSH setup
# The parameters `num_perm` and `threshold` are critical for tuning.
# `num_perm` controls the hash signature size (accuracy vs. memory).
# `threshold` is the Jaccard similarity threshold for considering two items as candidates.
# Note: for brevity, the endpoint below keys the cache directly on the MinHash
# signature; a fuller implementation would insert signatures into `lsh` and use
# `lsh.query()` to retrieve candidate cache keys.
NUM_PERM = 128
LSH_THRESHOLD = 0.8  # High threshold for high similarity
lsh = MinHashLSH(threshold=LSH_THRESHOLD, num_perm=NUM_PERM)

# --- Helper Functions ---
def vector_to_minhash(vector, num_perm=NUM_PERM):
    """Convert a dense vector to a MinHash signature."""
    # MinHash works on sets of integers. We need to discretize our float vector.
    # A simple but effective way is to use the signs of the vector components.
    # More sophisticated methods exist (e.g., hyperplane LSH).
    # For this example, we'll use a simple discretization.
    vector = (vector > 0).astype(int)
    m = MinHash(num_perm=num_perm)
    for i, val in enumerate(vector):
        if val == 1:
            m.update(str(i).encode('utf8'))
    return m

# --- API ---
class Query(BaseModel):
    text: str
    top_k: int = 5

@app.post("/rag_lsh_cache")
def retrieve_documents_with_lsh_cache(query: Query):
    timings = {}

    # 1. Embed the query
    start_embed = time.perf_counter()
    query_vector = model.encode(query.text, convert_to_tensor=False)
    end_embed = time.perf_counter()
    timings["embedding_ms"] = (end_embed - start_embed) * 1000
    
    # 2. Generate MinHash for the query vector
    start_lsh = time.perf_counter()
    query_minhash = vector_to_minhash(query_vector)
    # The cache key will be the first hash value from the LSH bands
    # This is a simplification; in practice, you might check multiple bands.
    cache_key = f"lsh_cache:{query_minhash.hashvalues[0]}"
    end_lsh = time.perf_counter()
    timings["lsh_generation_ms"] = (end_lsh - start_lsh) * 1000

    # 3. Check cache
    start_cache_check = time.perf_counter()
    cached_result = r.get(cache_key)
    end_cache_check = time.perf_counter()
    timings["cache_check_ms"] = (end_cache_check - start_cache_check) * 1000

    if cached_result:
        timings["cache_hit"] = True
        return {
            "timings": timings,
            "results": json.loads(cached_result)
        }

    # 4. Cache Miss: Perform the actual vector search
    timings["cache_hit"] = False
    start_search = time.perf_counter()
    search_vector = np.float32([query_vector])
    faiss.normalize_L2(search_vector)
    distances, indices = index.search(search_vector, query.top_k)
    end_search = time.perf_counter()
    timings["vector_search_ms"] = (end_search - start_search) * 1000

    results = {
        "ids": indices[0].tolist(),
        "distances": distances[0].tolist(),
    }

    # 5. Store result in cache
    # Set a TTL (Time-To-Live) for cache invalidation
    r.set(cache_key, json.dumps(results), ex=3600) # Cache for 1 hour

    return {
        "timings": timings,
        "results": results
    }

Edge Cases and Performance Considerations for LSH Caching

  • Tuning num_perm and threshold: This is the most critical part. A low threshold or low num_perm will cause too many collisions (low precision), caching incorrect results for dissimilar queries. A high threshold or high num_perm will cause too few collisions (low recall), reducing your cache hit rate. This requires empirical tuning based on your query distribution and accuracy requirements. You may need to run offline experiments to find the sweet spot.
  • Vector Discretization: The vector_to_minhash function is a simplification. The choice of how to convert a dense float vector into a set of features for MinHash (or other LSH schemes like SimHash for cosine similarity) is non-trivial. Random projection-based methods (like in Spotify's annoy library) are another popular approach. The goal is to preserve the similarity metric of the original vector space in the hashed representation.
  • Cache Invalidation: What happens when the underlying documents are updated, deleted, or added? A simple TTL is a blunt instrument. A more robust solution involves an event-driven approach. When a document is updated in your primary data store, a change data capture (CDC) event (e.g., from Debezium via Kafka) should trigger a process that identifies and invalidates all cache keys that might contain that document ID. This is complex, as it requires an inverted index mapping from document_id to the lsh_cache_keys that returned it. A sketch of this reverse mapping follows this list.
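
Here is a minimal sketch of that reverse mapping, assuming the Redis client r and the result payloads from the LSH example above; the CDC consumer that calls invalidate_document is out of scope here.

python
# invalidation_sketch.py -- reverse mapping from document IDs to the cache keys that contain them
import json
import redis

r = redis.Redis(decode_responses=True)

def cache_results(cache_key: str, results: dict, ttl_s: int = 3600) -> None:
    """Store search results and record, per returned document ID, which cache key references it."""
    pipe = r.pipeline()
    pipe.set(cache_key, json.dumps(results), ex=ttl_s)
    for doc_id in results["ids"]:
        # Keep the reverse index alive a bit longer than the cached entry itself.
        pipe.sadd(f"doc_to_keys:{doc_id}", cache_key)
        pipe.expire(f"doc_to_keys:{doc_id}", ttl_s * 2)
    pipe.execute()

def invalidate_document(doc_id: int) -> None:
    """Called by the CDC consumer when a document changes: purge every dependent cache entry."""
    reverse_key = f"doc_to_keys:{doc_id}"
    stale_keys = r.smembers(reverse_key)
    if stale_keys:
        r.delete(*stale_keys)
    r.delete(reverse_key)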

Pattern 2: Proactive Caching with Query Vector Clustering

LSH caching is a reactive strategy. An alternative, proactive approach is to analyze historical query patterns and pre-warm the cache for the most common types of queries. This is particularly effective in domains with high query repetition, like customer support or product information retrieval.

The strategy:

  • Offline Step: Periodically (e.g., daily), collect all query embedding vectors from your application logs.
  • Cluster: Run a clustering algorithm like K-Means on these vectors. The resulting centroids represent the "average" vector for common query topics.
  • Pre-warm Cache: For each centroid, perform the expensive vector DB search and store the results in your cache. The cache key is simply the cluster ID (e.g., cluster_cache:1, cluster_cache:2).
  • Online Step: When a new query arrives, embed it. Instead of a full vector DB search, perform a much faster search to find the nearest cluster centroid to the new query vector. Use that centroid's cluster ID to hit the pre-warmed cache.

Implementation with Scikit-learn and FAISS

Here's how the online and offline components would look.

Offline Clustering and Cache Warming Script:

python
# offline_cluster_job.py
import numpy as np
import faiss
import redis
import json
from sklearn.cluster import MiniBatchKMeans

# Assume `logged_query_vectors` is a NumPy array of shape (N, D)
# loaded from your production logs.
# For example:
num_logged_queries = 50000
dimension = 384
logged_query_vectors = np.float32(np.random.random((num_logged_queries, dimension)))

# --- Main Clustering Logic ---
NUM_CLUSTERS = 500  # This is a key hyperparameter

print("Starting K-Means clustering...")
kmeans = MiniBatchKMeans(n_clusters=NUM_CLUSTERS,
                         random_state=42,
                         batch_size=256,
                         n_init=10)
kmeans.fit(logged_query_vectors)

centroids = kmeans.cluster_centers_
print(f"Found {len(centroids)} centroids.")

# --- Cache Warming ---
# Setup connections (FAISS index, Redis)
# This would be the same production FAISS index from the online app
index = faiss.read_index("path/to/your/prod.index")
r = redis.Redis(decode_responses=True)

print("Warming cache with centroid search results...")
for i, centroid in enumerate(centroids):
    centroid_vector = np.float32([centroid])
    faiss.normalize_L2(centroid_vector)

    # Perform the expensive search for the centroid
    distances, indices = index.search(centroid_vector, 5)  # top_k=5

    results = {
        "ids": indices[0].tolist(),
        "distances": distances[0].tolist(),
    }

    # Store in Redis with the cluster ID as the key
    cache_key = f"cluster_cache:{i}"
    r.set(cache_key, json.dumps(results))

print("Cache warming complete.")

# We also need to store the centroids themselves for online lookup.
# Normalize them first so the online lookup (which normalizes the query vector)
# compares unit vectors consistently.
centroids = np.float32(centroids)
faiss.normalize_L2(centroids)
centroid_index = faiss.IndexFlatL2(dimension)
centroid_index.add(centroids)
faiss.write_index(centroid_index, "path/to/your/centroids.index")
print("Centroid index saved.")

Online API Endpoint using the Centroid Cache:

python
# In your main FastAPI app...

# --- Load Centroid Index at Startup ---
centroid_index = faiss.read_index("path/to/your/centroids.index")

@app.post("/rag_cluster_cache")
def retrieve_documents_with_cluster_cache(query: Query):
    timings = {}

    # 1. Embed query
    start_embed = time.perf_counter()
    query_vector = model.encode(query.text, convert_to_tensor=False)
    query_vector = np.float32([query_vector])
    faiss.normalize_L2(query_vector)
    end_embed = time.perf_counter()
    timings["embedding_ms"] = (end_embed - start_embed) * 1000

    # 2. Find nearest centroid (this is a very fast search)
    start_centroid_search = time.perf_counter()
    # Search for the single nearest neighbor (k=1) in the small centroid index
    _, nearest_cluster_ids = centroid_index.search(query_vector, 1)
    cluster_id = nearest_cluster_ids[0][0]
    end_centroid_search = time.perf_counter()
    timings["centroid_search_ms"] = (end_centroid_search - start_centroid_search) * 1000

    # 3. Hit the cache with the cluster ID
    start_cache_check = time.perf_counter()
    cache_key = f"cluster_cache:{cluster_id}"
    cached_result = r.get(cache_key)
    end_cache_check = time.perf_counter()
    timings["cache_check_ms"] = (end_cache_check - start_cache_check) * 1000

    if cached_result:
        timings["cache_hit"] = True
        return {
            "timings": timings,
            "results": json.loads(cached_result)
        }
    else:
        # This should be rare if the cache is warmed properly
        # Fallback to a full search or return an error
        timings["cache_hit"] = False
        # ... (implementation for fallback) ...
        return {"error": "Cache miss on cluster cache", "timings": timings}

This pattern trades offline computational cost for extremely low online latency. The search against a small index of 500 centroids is orders of magnitude faster than searching against millions of document vectors.

Combining Patterns: A Multi-Layered Caching Architecture

Neither pattern is a silver bullet. The LSH cache works well for ad-hoc, long-tail queries that are semantically similar to previous ones, while the cluster cache excels at handling the high-volume, repetitive head of your query distribution. In a high-performance production system, you should combine them into a multi-layered cache.

Here is the lookup flow:

  • L1: Exact Match Cache: A simple Redis cache keyed on the raw query string. CACHE_KEY = "raw_query:" + query.text. Catches exact duplicate queries. Short TTL.
  • L2: LSH Approximate Match Cache: If L1 misses, compute the LSH hash of the query vector and check for similar past queries. CACHE_KEY = "lsh_cache:" + lsh_hash.
  • L3: Cluster Proactive Cache: If L2 misses, find the nearest query cluster centroid and check the pre-warmed cache. CACHE_KEY = "cluster_cache:" + cluster_id.
  • Fallback: Vector Database: If all caches miss, perform the expensive, full vector database search. Then, importantly, populate the L2 LSH cache with the result to benefit future similar queries.

This layered approach provides the best of all worlds: immediate hits for duplicates, fast responses for semantically similar queries, and excellent performance for the most common query types, all while maintaining a fallback for novel requests. A condensed sketch of the lookup flow is shown below.
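
The sketch below condenses that flow into a single lookup function. It assumes the module-level objects defined in the earlier examples (model, r, index, centroid_index, and vector_to_minhash); the TTLs and promotion policy are illustrative, not prescriptive.

python
# layered_lookup_sketch.py -- L1 exact match, L2 LSH, L3 cluster, then full vector search
import json
import faiss
import numpy as np

# Assumes globals from the previous examples:
# model, r (Redis), index (document FAISS index), centroid_index, vector_to_minhash

def retrieve_with_layered_cache(query_text: str, top_k: int = 5) -> dict:
    # L1: exact-match cache keyed on the raw query string
    l1_key = f"raw_query:{query_text}"
    if (hit := r.get(l1_key)) is not None:
        return json.loads(hit)

    query_vector = model.encode(query_text, convert_to_tensor=False)

    # L2: approximate-match cache keyed on the LSH signature of the query vector
    l2_key = f"lsh_cache:{vector_to_minhash(query_vector).hashvalues[0]}"
    if (hit := r.get(l2_key)) is not None:
        r.set(l1_key, hit, ex=300)  # promote to L1 with a short TTL
        return json.loads(hit)

    search_vector = np.float32([query_vector])
    faiss.normalize_L2(search_vector)

    # L3: proactive cluster cache keyed on the nearest query centroid
    _, nearest = centroid_index.search(search_vector, 1)
    if (hit := r.get(f"cluster_cache:{nearest[0][0]}")) is not None:
        return json.loads(hit)

    # Fallback: full vector DB search, then populate L2 so similar future queries hit
    distances, indices = index.search(search_vector, top_k)
    results = {"ids": indices[0].tolist(), "distances": distances[0].tolist()}
    payload = json.dumps(results)
    r.set(l2_key, payload, ex=3600)
    r.set(l1_key, payload, ex=300)
    return results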

Production Considerations and Advanced Edge Cases

Deploying this system requires more than just the core logic. Here are critical factors for senior engineers to consider:

  • Cache Invalidation at Scale: As mentioned, event-driven invalidation is superior to TTLs for dynamic data. Your CDC pipeline must be robust. When a document doc_123 is updated, you need a reverse mapping to find which LSH and cluster cache keys need to be purged. This could be another Redis data structure (e.g., a Set for each doc_id storing the cache keys that contain it).
  • Thundering Herd Problem: What happens when a popular cache key (e.g., a cluster centroid key) expires or is invalidated? Multiple concurrent requests will all miss the cache and hammer the vector database simultaneously. Implement a distributed lock (e.g., using Redis's SETNX or a library like redlock-py) around the vector DB call: the first process to acquire the lock performs the search and repopulates the cache, while the others wait briefly and then re-read the cache. A sketch of this guard appears after the observability example below.
  • Handling Per-User Data Filters: This is a major challenge. If your vector search includes a metadata filter like where user_id = 'user_abc', the cached results are only valid for that user. The cache key must incorporate the user context, e.g. CACHE_KEY = f"lsh_cache:{lsh_hash}:u:{user_id}". This dramatically increases cache cardinality, which can crater your hit ratio. For this scenario, you might only cache results for high-traffic users or for public, non-filtered data. The cluster cache is often less effective here unless you can create user-group-specific clusters.

  • Observability is Non-Negotiable: You must instrument every layer of this system. Using Prometheus and Grafana, track:
    - cache_hit_ratio (with labels for each layer: l1_raw, l2_lsh, l3_cluster)
    - latency_histogram_seconds (again, with labels for cache hits vs. misses)
    - lsh_collisions_total (to monitor the effectiveness of your LSH tuning)
    - cache_memory_usage_bytes (to prevent your Redis instance from overflowing)

Here is an example of instrumenting the FastAPI app with prometheus-fastapi-instrumentator:

python
from prometheus_client import Counter
from prometheus_fastapi_instrumentator import Instrumentator

# ... in your app setup ...
Instrumentator().instrument(app).expose(app)

# Custom metrics
LSH_CACHE_HITS = Counter("rag_lsh_cache_hits_total", "Total hits for the LSH cache")
LSH_CACHE_MISSES = Counter("rag_lsh_cache_misses_total", "Total misses for the LSH cache")

# ... inside your endpoint logic ...
if cached_result:
    LSH_CACHE_HITS.inc()
    # ...
else:
    LSH_CACHE_MISSES.inc()
    # ...
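
Finally, here is a minimal sketch of the single-flight guard described in the thundering-herd bullet above, using a plain Redis SET NX PX lock rather than full Redlock. It assumes the Redis client r from the earlier examples; do_vector_search is a stand-in for the expensive search call and key scheme you already have.

python
# herd_guard_sketch.py -- single-flight cache repopulation behind a Redis lock
import json
import time
import uuid

def get_or_repopulate(cache_key: str, do_vector_search, ttl_s: int = 3600):
    """Return cached results; on a miss, let only one process hit the vector DB."""
    cached = r.get(cache_key)
    if cached is not None:
        return json.loads(cached)

    lock_key = f"lock:{cache_key}"
    token = str(uuid.uuid4())

    # SET NX PX: only one worker acquires the lock for this key at a time.
    if r.set(lock_key, token, nx=True, px=10_000):
        try:
            results = do_vector_search()
            r.set(cache_key, json.dumps(results), ex=ttl_s)
            return results
        finally:
            # Best-effort release; a production version would use a Lua script
            # or redlock-py to make the check-and-delete atomic.
            if r.get(lock_key) == token:
                r.delete(lock_key)

    # Lock not acquired: wait briefly for the winner to repopulate the cache.
    for _ in range(50):
        time.sleep(0.1)
        cached = r.get(cache_key)
        if cached is not None:
            return json.loads(cached)

    # Give up waiting and fall through to the search ourselves.
    results = do_vector_search()
    r.set(cache_key, json.dumps(results), ex=ttl_s)
    return results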

Conclusion

For RAG systems to transition from impressive demos to production-grade, low-latency services, we must treat vector search as the critical I/O bottleneck it is. Simple response caching is insufficient. By implementing intelligent, intermediate-layer caching of the vector search results themselves, we can achieve significant performance gains.

The LSH-based and cluster-based patterns discussed here provide a powerful toolkit. They require careful tuning, a solid understanding of the trade-offs between precision and recall, and a robust strategy for invalidation and observability. By building a multi-layered caching architecture, you can systematically attack the latency problem, ensuring your RAG application is not just intelligent, but also incredibly fast.
