Advanced RAG Techniques with LangChain and ChromaDB
Master advanced retrieval strategies, query expansion, and hybrid search in RAG systems
langchain
chromadb
rag
advanced-techniques
hybrid-search
query-expansion
by Bui An Du
🔬 Advanced RAG Techniques with LangChain and ChromaDB
🚀 Beyond Basic RAG
Implement sophisticated retrieval strategies for production-grade RAG applications
Query Expansion and Rewriting
Multi-Query Retrieval
python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
class MultiQueryRetriever:
def __init__(self, base_retriever, llm):
self.base_retriever = base_retriever
self.llm = llm
# Prompt for generating multiple queries
self.query_expansion_prompt = PromptTemplate(
input_variables=["question"],
template="""
You are an AI language model assistant. Your task is to generate
five different versions of the given user question to retrieve
relevant documents from a vector database. By generating multiple
perspectives on the user question, your goal is to help the user
overcome some of the limitations of distance-based similarity search.
Provide these alternative questions separated by newlines.
Original question: {question}
"""
)
def get_relevant_documents(self, query):
# Generate multiple queries
chain = LLMChain(llm=self.llm, prompt=self.query_expansion_prompt)
expanded_queries = chain.run(question=query).strip().split('\n')
# Remove empty strings and clean
expanded_queries = [q.strip() for q in expanded_queries if q.strip()]
expanded_queries.append(query) # Include original query
# Retrieve documents for each query
all_docs = []
for expanded_query in expanded_queries:
docs = self.base_retriever.get_relevant_documents(expanded_query)
all_docs.extend(docs)
# Remove duplicates based on content
unique_docs = self._remove_duplicates(all_docs)
return unique_docs[:5] # Return top 5 unique documents
def _remove_duplicates(self, documents):
seen = set()
unique_docs = []
for doc in documents:
content = doc.page_content[:200] # Use first 200 chars as identifier
if content not in seen:
seen.add(content)
unique_docs.append(doc)
        return unique_docs
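A quick usage sketch, assuming an existing Chroma collection persisted at ./chroma_db and an OpenAI API key in the environment; the collection name and query are placeholders:
python
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI

# Existing Chroma collection (assumed to be already populated)
embeddings = OpenAIEmbeddings()
vectorstore = Chroma(
    collection_name="documents",
    embedding_function=embeddings,
    persist_directory="./chroma_db"
)

retriever = MultiQueryRetriever(
    base_retriever=vectorstore.as_retriever(search_kwargs={"k": 5}),
    llm=ChatOpenAI(temperature=0)
)
docs = retriever.get_relevant_documents("How do I rotate API keys safely?")
for doc in docs:
    print(doc.page_content[:100])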
HyDE (Hypothetical Document Embeddings)
python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
class HyDERetriever:
def __init__(self, base_retriever, llm, embeddings):
self.base_retriever = base_retriever
self.llm = llm
self.embeddings = embeddings
self.hyde_prompt = PromptTemplate(
input_variables=["question"],
template="""
Please write a passage to answer the question.
Try to include as many key details as possible.
Question: {question}
Passage:"""
)
def get_relevant_documents(self, query):
# Generate hypothetical document
chain = LLMChain(llm=self.llm, prompt=self.hyde_prompt)
hypothetical_doc = chain.run(question=query)
# Embed the hypothetical document
hyde_embedding = self.embeddings.embed_query(hypothetical_doc)
        # Searching directly with this embedding requires access to the underlying
        # vector store (see the sketch after this code block). As a simple fallback,
        # use the hypothetical document as extra context alongside the original query:
        enhanced_query = f"{query}\n\nContext: {hypothetical_doc[:500]}"
        return self.base_retriever.get_relevant_documents(enhanced_query)
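If the documents also live in a raw ChromaDB collection, the hypothetical-document embedding can drive the search directly instead of the text fallback above. A minimal sketch, assuming the collection was populated with the same embedding model as embeddings; the collection name and helper function are illustrative:
python
import chromadb

def hyde_vector_search(collection, embeddings, hypothetical_doc: str, k: int = 5):
    """Query ChromaDB with the embedding of the generated passage."""
    # Embed the hypothetical answer rather than the original question
    hyde_embedding = embeddings.embed_query(hypothetical_doc)
    return collection.query(
        query_embeddings=[hyde_embedding],
        n_results=k,
        include=["documents", "metadatas", "distances"]
    )

client = chromadb.PersistentClient(path="./chroma_db")
collection = client.get_or_create_collection("documents")
# hypothetical_doc is the passage produced by the LLMChain in HyDERetriever
# results = hyde_vector_search(collection, embeddings, hypothetical_doc)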
Hybrid Search Implementation
Combining BM25 and Semantic Search
python
from rank_bm25 import BM25Okapi
import numpy as np
from typing import List, Tuple
class HybridRetriever:
def __init__(self, documents, embeddings, vectorstore):
self.documents = documents
self.embeddings = embeddings
self.vectorstore = vectorstore
# Prepare BM25
self.corpus = [doc.page_content for doc in documents]
tokenized_corpus = [doc.split() for doc in self.corpus]
self.bm25 = BM25Okapi(tokenized_corpus)
def hybrid_search(self, query: str, k: int = 5, alpha: float = 0.5) -> List[Tuple]:
"""
Perform hybrid search combining BM25 and semantic similarity
alpha: weight for semantic search (1-alpha for BM25)
"""
# BM25 Search
tokenized_query = query.split()
bm25_scores = self.bm25.get_scores(tokenized_query)
# Semantic Search
semantic_results = self.vectorstore.similarity_search_with_score(query, k=len(self.documents))
# Combine scores
combined_scores = []
for i, doc in enumerate(self.documents):
bm25_score = bm25_scores[i]
            # Distance returned by the vector store for this document; None if it was not returned
            distance = next(
                (score for d, score in semantic_results if d.page_content == doc.page_content),
                None
            )
            semantic_score = 1 / (1 + distance) if distance is not None else 0.0  # Convert distance to similarity
# Normalize scores
bm25_normalized = bm25_score / max(bm25_scores) if max(bm25_scores) > 0 else 0
semantic_normalized = semantic_score
# Weighted combination
combined_score = alpha * semantic_normalized + (1 - alpha) * bm25_normalized
combined_scores.append((doc, combined_score))
# Sort by combined score
combined_scores.sort(key=lambda x: x[1], reverse=True)
        return combined_scores[:k]
Implementing Hybrid Search with ChromaDB
python
import chromadb
from chromadb.utils import embedding_functions
import numpy as np
from rank_bm25 import BM25Okapi
from typing import List
class ChromaHybridSearch:
def __init__(self, collection_name="documents"):
self.client = chromadb.PersistentClient(path="./chroma_db")
self.collection = self.client.get_or_create_collection(
name=collection_name,
embedding_function=embedding_functions.DefaultEmbeddingFunction()
)
# For BM25
self.documents = []
self.bm25 = None
def add_documents(self, documents: List[str], metadatas: List[dict] = None):
"""Add documents to both ChromaDB and BM25"""
self.documents = documents
# Prepare BM25
tokenized_docs = [doc.split() for doc in documents]
self.bm25 = BM25Okapi(tokenized_docs)
# Add to ChromaDB
self.collection.add(
documents=documents,
metadatas=metadatas or [{} for _ in documents],
ids=[f"doc_{i}" for i in range(len(documents))]
)
def hybrid_search(self, query: str, k: int = 5, alpha: float = 0.5):
"""Perform hybrid search"""
# BM25 scores
tokenized_query = query.split()
bm25_scores = self.bm25.get_scores(tokenized_query)
        # Semantic search (ChromaDB returns results sorted by distance, not in insertion order)
        semantic_results = self.collection.query(
            query_texts=[query],
            n_results=len(self.documents),
            include=['distances', 'metadatas', 'documents']
        )
        # Map returned ids back to their distances and metadata so scores line up with self.documents
        returned_ids = semantic_results['ids'][0]
        id_to_distance = dict(zip(returned_ids, semantic_results['distances'][0]))
        id_to_metadata = dict(zip(returned_ids, semantic_results['metadatas'][0]))
        # Combine scores
        combined_results = []
        for i in range(len(self.documents)):
            bm25_score = bm25_scores[i]
            semantic_distance = id_to_distance.get(f"doc_{i}", float('inf'))
            semantic_score = 1 / (1 + semantic_distance)  # Convert distance to similarity
# Normalize
bm25_normalized = bm25_score / max(bm25_scores) if max(bm25_scores) > 0 else 0
# Weighted combination
combined_score = alpha * semantic_score + (1 - alpha) * bm25_normalized
combined_results.append({
'document': self.documents[i],
'score': combined_score,
                'metadata': id_to_metadata.get(f"doc_{i}", {})
})
# Sort and return top k
combined_results.sort(key=lambda x: x['score'], reverse=True)
        return combined_results[:k]
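A short usage example with made-up documents, showing how alpha shifts weight between semantic and keyword matching:
python
searcher = ChromaHybridSearch(collection_name="support_articles")
searcher.add_documents(
    documents=[
        "Reset your password from the account settings page.",
        "Billing invoices are emailed on the first of each month.",
        "Two-factor authentication can be enabled under security settings."
    ],
    metadatas=[{"topic": "account"}, {"topic": "billing"}, {"topic": "security"}]
)

# alpha=0.7 leans toward semantic similarity; alpha=0.3 would favor BM25 keyword overlap
for result in searcher.hybrid_search("how do I turn on 2FA", k=2, alpha=0.7):
    print(round(result["score"], 3), result["document"])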
Advanced Retrieval Strategies
Re-ranking with Cross-Encoders
python
from sentence_transformers import CrossEncoder
class ReRankRetriever:
def __init__(self, base_retriever, cross_encoder_model="cross-encoder/ms-marco-MiniLM-L-6-v2"):
self.base_retriever = base_retriever
self.cross_encoder = CrossEncoder(cross_encoder_model)
def get_relevant_documents(self, query, initial_k=20, final_k=5):
        # First stage: retrieve a larger candidate pool.
        # Note: this assumes the base retriever accepts a k argument; with a standard
        # LangChain vector store retriever you would set search_kwargs={"k": initial_k} instead.
        candidates = self.base_retriever.get_relevant_documents(
            query, k=initial_k
        )
if len(candidates) == 0:
return []
# Second stage: Re-rank using cross-encoder
query_doc_pairs = [[query, doc.page_content] for doc in candidates]
scores = self.cross_encoder.predict(query_doc_pairs)
# Sort by cross-encoder scores
scored_docs = list(zip(candidates, scores))
scored_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, score in scored_docs[:final_k]]
Contextual Compression
python
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
def create_contextual_retriever(base_retriever, llm):
"""Create a retriever that compresses retrieved documents"""
compressor = LLMChainExtractor.from_llm(llm)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=base_retriever
)
    return compression_retriever
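Usage sketch, assuming a vector store and LLM are already set up (vectorstore here is a placeholder for any LangChain vector store):
python
from langchain.chat_models import ChatOpenAI

compression_retriever = create_contextual_retriever(
    base_retriever=vectorstore.as_retriever(search_kwargs={"k": 10}),
    llm=ChatOpenAI(temperature=0)
)

# Each returned document is trimmed down to the passages relevant to the query
compressed_docs = compression_retriever.get_relevant_documents(
    "What is our refund policy for annual plans?"
)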
Metadata Filtering and Advanced Queries
Implementing Metadata Filters in ChromaDB
python
class FilteredChromaRetriever:
def __init__(self, collection_name="documents"):
self.client = chromadb.PersistentClient(path="./chroma_db")
self.collection = self.client.get_collection(name=collection_name)
def search_with_filters(self, query: str, filters: dict = None, k: int = 5):
"""Search with metadata filters"""
        # Pass filters straight through as the ChromaDB where clause;
        # use None rather than an empty dict when no filters are given
        where_clause = filters if filters else None
results = self.collection.query(
query_texts=[query],
n_results=k,
where=where_clause,
include=['documents', 'metadatas', 'distances']
)
return results
def advanced_search(self, query: str, **kwargs):
"""Advanced search with multiple filter options"""
# Date range filtering
date_from = kwargs.get('date_from')
date_to = kwargs.get('date_to')
# Category filtering
categories = kwargs.get('categories', [])
# Author filtering
authors = kwargs.get('authors', [])
        # Build the filter. ChromaDB expects each condition as a single-operator
        # clause; multiple conditions are combined with $and.
        conditions = []
        if date_from:
            conditions.append({"publish_date": {"$gte": date_from}})
        if date_to:
            conditions.append({"publish_date": {"$lte": date_to}})
        if categories:
            conditions.append({"category": {"$in": categories}})
        if authors:
            conditions.append({"author": {"$in": authors}})
        if not conditions:
            where_clause = None
        elif len(conditions) == 1:
            where_clause = conditions[0]
        else:
            where_clause = {"$and": conditions}
        return self.search_with_filters(query, where_clause, kwargs.get('k', 5))
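An example call, assuming documents were stored with publish_date (as a numeric YYYYMMDD value), category, and author metadata fields; the field names and values are illustrative:
python
retriever = FilteredChromaRetriever(collection_name="documents")

results = retriever.advanced_search(
    "revenue guidance for next quarter",
    date_from=20240101,  # numeric dates keep ChromaDB's $gte/$lte comparisons simple
    date_to=20240630,
    categories=["finance", "earnings"],
    k=3
)
print(results["documents"][0])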
Query Routing and Ensemble Methods
Intelligent Query Router
python
class QueryRouter:
def __init__(self, retrievers, classifier_llm):
self.retrievers = retrievers # Dict of retriever_name -> retriever
self.classifier_llm = classifier_llm
self.routing_prompt = PromptTemplate(
input_variables=["query"],
template="""
Analyze the following query and determine which retrieval strategy would be most effective:
Query: {query}
Available strategies:
- web_search: For current events, news, or real-time information
- document_search: For company policies, documentation, or internal knowledge
- code_search: For programming questions, API references, or technical documentation
- general_search: For general knowledge or broad questions
Return only the strategy name:
"""
)
def route_query(self, query: str):
"""Route query to appropriate retriever"""
chain = LLMChain(llm=self.classifier_llm, prompt=self.routing_prompt)
strategy = chain.run(query=query).strip().lower()
# Map to retriever
retriever_mapping = {
'web_search': 'web_retriever',
'document_search': 'doc_retriever',
'code_search': 'code_retriever',
'general_search': 'general_retriever'
}
retriever_name = retriever_mapping.get(strategy, 'general_retriever')
        return self.retrievers[retriever_name]
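Wiring the router up, assuming the four retrievers already exist (the retriever variables below are placeholders):
python
from langchain.chat_models import ChatOpenAI

router = QueryRouter(
    retrievers={
        "web_retriever": web_retriever,
        "doc_retriever": doc_retriever,
        "code_retriever": code_retriever,
        "general_retriever": general_retriever
    },
    classifier_llm=ChatOpenAI(temperature=0)
)

query = "How do I paginate results with the requests library?"
selected_retriever = router.route_query(query)  # should route to code_retriever
docs = selected_retriever.get_relevant_documents(query)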
Ensemble Retrieval
python
class EnsembleRetriever:
def __init__(self, retrievers, weights=None):
self.retrievers = retrievers
self.weights = weights or [1.0] * len(retrievers)
def get_relevant_documents(self, query: str, k: int = 5):
"""Combine results from multiple retrievers"""
all_results = []
for retriever, weight in zip(self.retrievers, self.weights):
try:
results = retriever.get_relevant_documents(query, k=k*2)
                # Weighted reciprocal-rank score: earlier hits from higher-weight
                # retrievers rank higher
                for rank, doc in enumerate(results):
                    doc.metadata['ensemble_score'] = weight / (rank + 1)
                    all_results.append(doc)
except Exception as e:
print(f"Retriever failed: {e}")
continue
# Remove duplicates and re-rank
unique_results = self._deduplicate_results(all_results)
# Sort by ensemble score and return top k
unique_results.sort(
key=lambda x: x.metadata.get('ensemble_score', 0),
reverse=True
)
return unique_results[:k]
def _deduplicate_results(self, results):
"""Remove duplicate documents based on content"""
seen = set()
unique = []
for doc in results:
content_hash = hash(doc.page_content[:500])
if content_hash not in seen:
seen.add(content_hash)
unique.append(doc)
        return unique
Performance Optimization Techniques
Caching Strategies
python
import hashlib
class CachedRetriever:
def __init__(self, base_retriever, cache_size=1000):
self.base_retriever = base_retriever
self.cache_size = cache_size
self.cache = {}
def _get_cache_key(self, query: str, k: int) -> str:
"""Generate cache key from query and parameters"""
key_data = f"{query}_{k}"
return hashlib.md5(key_data.encode()).hexdigest()
    def get_relevant_documents(self, query: str, k: int = 5):
        """Cached document retrieval using the size-managed cache below"""
cache_key = self._get_cache_key(query, k)
if cache_key in self.cache:
return self.cache[cache_key]
# Retrieve and cache
results = self.base_retriever.get_relevant_documents(query, k)
self.cache[cache_key] = results
# Cache size management
if len(self.cache) > self.cache_size:
# Remove oldest entries (simple FIFO)
oldest_keys = list(self.cache.keys())[:100]
for key in oldest_keys:
del self.cache[key]
        return results
Batch Processing for Embeddings
python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List
class BatchEmbeddingRetriever:
def __init__(self, base_retriever, batch_size=32):
self.base_retriever = base_retriever
self.batch_size = batch_size
self.executor = ThreadPoolExecutor(max_workers=4)
async def batch_retrieve(self, queries: List[str], k: int = 5):
"""Retrieve documents for multiple queries in parallel"""
        loop = asyncio.get_running_loop()  # a loop is already running inside this coroutine
# Create tasks for parallel execution
tasks = [
loop.run_in_executor(
self.executor,
self.base_retriever.get_relevant_documents,
query, k
)
for query in queries
]
# Wait for all tasks to complete
results = await asyncio.gather(*tasks)
        return results
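Calling the async method from synchronous code, assuming base_retriever is any retriever that exposes get_relevant_documents:
python
import asyncio

batch_retriever = BatchEmbeddingRetriever(base_retriever, batch_size=32)

queries = [
    "password reset steps",
    "refund policy for annual plans",
    "enable two-factor authentication"
]
# results[i] holds the documents retrieved for queries[i]
results = asyncio.run(batch_retriever.batch_retrieve(queries, k=5))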
Monitoring and Analytics
Retrieval Performance Metrics
python
import time
from typing import Dict, List
class RetrievalMetrics:
def __init__(self):
self.metrics = {
'query_count': 0,
'avg_response_time': 0,
'cache_hit_rate': 0,
'error_rate': 0
}
def track_query(self, query: str, response_time: float, cache_hit: bool, error: bool = False):
"""Track individual query metrics"""
self.metrics['query_count'] += 1
# Update average response time
current_avg = self.metrics['avg_response_time']
self.metrics['avg_response_time'] = (
(current_avg * (self.metrics['query_count'] - 1)) + response_time
) / self.metrics['query_count']
# Update cache hit rate
if cache_hit:
current_cache_hits = self.metrics.get('cache_hits', 0) + 1
self.metrics['cache_hits'] = current_cache_hits
self.metrics['cache_hit_rate'] = current_cache_hits / self.metrics['query_count']
# Track errors
if error:
current_errors = self.metrics.get('errors', 0) + 1
self.metrics['errors'] = current_errors
self.metrics['error_rate'] = current_errors / self.metrics['query_count']
def get_metrics(self) -> Dict:
"""Get current metrics"""
        return self.metrics.copy()
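Tracking a retrieval call, with timing and the cache-hit flag supplied by the caller (retriever is a placeholder for any retriever from this post):
python
import time

metrics = RetrievalMetrics()

query = "how to rotate api keys"
start = time.perf_counter()
try:
    docs = retriever.get_relevant_documents(query)
    metrics.track_query(query, response_time=time.perf_counter() - start, cache_hit=False)
except Exception:
    metrics.track_query(query, response_time=time.perf_counter() - start,
                        cache_hit=False, error=True)

print(metrics.get_metrics())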
Production Deployment Considerations
Scalability Patterns
python
# Horizontal scaling with multiple ChromaDB instances
import chromadb
from typing import Dict, List
class ShardedChromaRetriever:
    def __init__(self, shard_configs: List[Dict]):
self.shards = []
for config in shard_configs:
client = chromadb.PersistentClient(path=config['path'])
collection = client.get_collection(config['collection'])
self.shards.append({
'collection': collection,
'weight': config.get('weight', 1.0)
})
def distributed_search(self, query: str, k: int = 5):
"""Search across multiple shards"""
all_results = []
for shard in self.shards:
results = shard['collection'].query(
query_texts=[query],
n_results=k,
include=['documents', 'metadatas', 'distances']
)
            # Query results are nested per query text; index into the first (only) query
            for i, doc in enumerate(results['documents'][0]):
                all_results.append({
                    'document': doc,
                    'score': (1 / (1 + results['distances'][0][i])) * shard['weight'],
                    'metadata': results['metadatas'][0][i]
                })
# Sort by weighted score
all_results.sort(key=lambda x: x['score'], reverse=True)
        return all_results[:k]
Conclusion
Advanced RAG techniques significantly improve retrieval quality and system performance. Key takeaways:
- Query Expansion: Generate multiple query variations for better recall
- Hybrid Search: Combine lexical and semantic search methods
- Re-ranking: Use cross-encoders for improved relevance
- Metadata Filtering: Enable precise document selection
- Ensemble Methods: Combine multiple retrieval strategies
- Caching: Improve response times for repeated queries
- Monitoring: Track system performance and user satisfaction
Next Steps
- Experiment: Try different combinations of these techniques
- A/B Test: Compare performance of different approaches
- Monitor: Track metrics and user feedback
- Scale: Implement distributed retrieval for production workloads
The key to successful RAG systems lies in understanding your data and use cases, then applying the right combination of these advanced techniques. 🧠✨