Building Custom Document Retrieval Systems with LangChain: A Comprehensive Guide to Implementing BaseRetriever
Document retrieval is a critical component in many modern AI applications, particularly those involving Large Language Models (LLMs). LangChain provides a powerful framework for building retrieval systems through its BaseRetriever abstract class. In this guide, we’ll explore how to implement custom document retrievers by extending this class, enabling you to create tailored retrieval solutions for your specific use cases.
Understanding BaseRetriever
At its core, BaseRetriever is an abstract base class that defines the interface for document retrieval systems in LangChain. It inherits from RunnableSerializable and Python’s ABC (Abstract Base Class), making it both a runnable component in LangChain pipelines and a template for custom implementations.
The primary purpose of a retriever is straightforward: take a string query and return the most relevant documents from a source. This simple concept powers many advanced applications, from question-answering systems to research assistants.
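As a rough illustration of that contract, independent of LangChain, the sketch below ranks a handful of strings by how many query words they share. The function name `retrieve` and the word-overlap scoring are assumptions made for this example, not part of the library.

```python
from typing import List


def retrieve(query: str, corpus: List[str], k: int = 2) -> List[str]:
    """Return up to k texts sharing at least one word with the query, best first."""
    query_words = set(query.lower().split())

    def overlap(text: str) -> int:
        # Number of distinct query words that also appear in the text
        return len(query_words & set(text.lower().split()))

    ranked = sorted(corpus, key=overlap, reverse=True)
    return [text for text in ranked if overlap(text) > 0][:k]


corpus = [
    "Paris is the capital of France",
    "Berlin is the capital of Germany",
    "Bananas are rich in potassium",
]
hits = retrieve("capital of France", corpus)
```

A real retriever replaces the word-overlap score with something stronger, but the shape stays the same: a string goes in, a ranked list of documents comes out.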
The BaseRetriever Interface
Before diving into implementation, let’s understand the key methods and interfaces of BaseRetriever:
from typing import List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class MyCustomRetriever(BaseRetriever):
    """A custom retriever implementation."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Core implementation to retrieve relevant documents."""
        # Your custom retrieval logic goes here; this placeholder returns nothing
        return []
The most important method to implement is _get_relevant_documents, which contains the actual retrieval logic. Optionally, you can also implement _aget_relevant_documents for asynchronous retrieval.
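One common way to offer an async path when the underlying logic is synchronous is to hand the call to a worker thread so the event loop stays free. The sketch below shows that pattern with the standard library only; `search_sync` and `search_async` are hypothetical stand-ins for the two retriever methods, not LangChain APIs.

```python
import asyncio
from typing import List


def search_sync(query: str) -> List[str]:
    """Stand-in for blocking retrieval logic (hypothetical)."""
    return [f"result for {query}"]


async def search_async(query: str) -> List[str]:
    """Run the blocking function in a worker thread so the event loop is not blocked."""
    return await asyncio.to_thread(search_sync, query)


async def main() -> List[List[str]]:
    # Both lookups are scheduled concurrently; gather preserves argument order
    return await asyncio.gather(search_async("France"), search_async("Italy"))


async_results = asyncio.run(main())
```

Overriding the async method mainly pays off when your data source has a natively async client; otherwise a thread-based wrapper like this one is usually enough.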
Usage Patterns
BaseRetriever follows the standard Runnable interface in LangChain, which means you can use it with the following methods:
- invoke: For synchronous retrieval
- ainvoke: For asynchronous retrieval
- batch: For processing multiple queries in batches
- abatch: For asynchronous batch processing
Here’s a simple example of using a retriever:
import asyncio

# Synchronous usage
documents = retriever.invoke("What is the capital of France?")

# Asynchronous usage (inside a coroutine, use `await retriever.ainvoke(...)` instead)
documents = asyncio.run(retriever.ainvoke("What is the capital of France?"))

# Batch processing: returns one list of documents per query
results = retriever.batch(["Query 1", "Query 2", "Query 3"])
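Conceptually, batch maps a single-query call over several inputs and returns the results in input order. The sketch below imitates that with a thread pool from the standard library; `fake_invoke` is a hypothetical stand-in for `retriever.invoke`, and the real method also handles config and callbacks that are omitted here.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List


def fake_invoke(query: str) -> List[str]:
    """Hypothetical stand-in for retriever.invoke: one list of hits per query."""
    return [f"doc matching '{query}'"]


def batch(queries: List[str], max_workers: int = 4) -> List[List[str]]:
    """Run fake_invoke on every query in parallel, keeping input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map yields results in the order of the inputs
        return list(pool.map(fake_invoke, queries))


batch_results = batch(["Query 1", "Query 2", "Query 3"])
```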
Implementing a Custom Retriever
Let’s implement a simple custom retriever that demonstrates the basic pattern:
from typing import Any, List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class SimpleVectorRetriever(BaseRetriever):
    """A simple vector-based retriever."""

    # BaseRetriever is a Pydantic model, so instance state is declared as fields
    documents: List[Document]
    vectorizer: Any  # an already-fitted vectorizer, e.g. scikit-learn's TfidfVectorizer
    doc_vectors: Any = None

    def __init__(self, documents: List[Document], vectorizer: Any, **kwargs: Any):
        """Initialize with documents and a fitted vectorizer."""
        super().__init__(documents=documents, vectorizer=vectorizer, **kwargs)
        # Pre-compute document vectors
        doc_contents = [doc.page_content for doc in documents]
        self.doc_vectors = vectorizer.transform(doc_contents).toarray()

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to the query."""
        # Vectorize the query
        query_vector = self.vectorizer.transform([query]).toarray()[0]
        # Compute similarity scores (using dot product for simplicity)
        scores = self.doc_vectors.dot(query_vector)
        # Get the indices of the five highest scores, best first
        top_indices = scores.argsort()[-5:][::-1]
        return [self.documents[i] for i in top_indices]
This example uses a simple vector-based approach, but you could implement more sophisticated retrieval mechanisms based on your needs.
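The dot product is sensitive to vector length, so a long document can outrank a shorter one that points in the same direction as the query. The sketch below, using plain Python lists as made-up term-count vectors, contrasts it with cosine similarity, which divides out the lengths.

```python
import math
from typing import List


def dot(a: List[float], b: List[float]) -> float:
    """Sum of element-wise products."""
    return sum(x * y for x, y in zip(a, b))


def cosine(a: List[float], b: List[float]) -> float:
    """Dot product divided by both vector norms; 0.0 if either vector is all zeros."""
    norm = math.sqrt(dot(a, a)) * math.sqrt(dot(b, b))
    return dot(a, b) / norm if norm else 0.0


query = [1.0, 1.0, 0.0]
short_doc = [1.0, 1.0, 0.0]  # same direction as the query
long_doc = [3.0, 3.0, 6.0]   # repeats the query terms, but is mostly about something else

dot_scores = (dot(query, short_doc), dot(query, long_doc))
cos_scores = (cosine(query, short_doc), cosine(query, long_doc))
```

With these numbers the dot product prefers the long document while cosine similarity prefers the short one. If the vectorizer already L2-normalizes its output, as scikit-learn’s TfidfVectorizer does by default, the two measures give the same ranking.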
A More Complete Example
Let’s create a more practical example using scikit-learn’s TF-IDF vectorizer:
import asyncio
from typing import Any, List

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


class TfidfRetriever(BaseRetriever):
    """Retriever based on TF-IDF vectorization and cosine similarity."""

    # BaseRetriever is a Pydantic model, so instance state is declared as fields
    documents: List[Document]
    top_k: int = 5
    vectorizer: Any = None
    doc_vectors: Any = None

    def __init__(self, documents: List[Document], top_k: int = 5, **kwargs: Any):
        """Initialize the retriever with documents."""
        # Extra keyword arguments (e.g. tags, metadata) are forwarded to BaseRetriever
        super().__init__(documents=documents, top_k=top_k, **kwargs)
        # Extract text from documents
        texts = [doc.page_content for doc in documents]
        # Create and fit vectorizer
        self.vectorizer = TfidfVectorizer()
        self.doc_vectors = self.vectorizer.fit_transform(texts)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve documents based on cosine similarity."""
        # Vectorize the query
        query_vector = self.vectorizer.transform([query])
        # Calculate cosine similarity
        similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()
        # Get indices of top-k documents, most similar first
        top_indices = similarities.argsort()[-self.top_k:][::-1]
        # Return the top documents
        return [self.documents[i] for i in top_indices]

    # Optional: override the async version. The scoring is CPU-bound, so it runs
    # in a worker thread to avoid blocking the event loop.
    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Async implementation of document retrieval."""
        return await asyncio.to_thread(
            self._get_relevant_documents, query, run_manager=run_manager.get_sync()
        )
To use this retriever:
# Create some sample documents
documents = [
    Document(page_content="Paris is the capital of France."),
    Document(page_content="Berlin is the capital of Germany."),
    Document(page_content="Rome is the capital of Italy."),
    Document(page_content="Madrid is the capital of Spain."),
    Document(page_content="France is known for the Eiffel Tower and fine cuisine."),
]

# Initialize our retriever
retriever = TfidfRetriever(documents)

# Retrieve relevant documents
results = retriever.invoke("Tell me about France")

# Print the results
for doc in results:
    print(doc.page_content)
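The slice `[-top_k:]` used in the retriever has an edge case worth knowing: with `top_k = 0` it becomes `[-0:]`, which selects every element rather than none. A hedged alternative is `heapq.nlargest`, sketched below on a plain list of made-up scores; the helper name `top_k_indices` is an assumption for this example.

```python
import heapq
from typing import List, Sequence


def top_k_indices(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k highest scores, best first; empty when k <= 0."""
    if k <= 0:
        return []
    return heapq.nlargest(k, range(len(scores)), key=lambda i: scores[i])


scores = [0.10, 0.75, 0.00, 0.42]
best_two = top_k_indices(scores, 2)
none_requested = top_k_indices(scores, 0)
more_than_available = top_k_indices(scores, 10)

# The slicing pitfall: -0 is 0, so this keeps the whole list
slice_with_zero = sorted(scores)[-0:]
```

Asking for more results than there are documents is harmless in both approaches: you simply get every document, ranked.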
Advanced Features of BaseRetriever
The BaseRetriever class comes with several advanced features that make it powerful and flexible:
Metadata and Tags
You can associate metadata and tags with your retriever, which are passed along during callbacks:
retriever = TfidfRetriever(
    documents,
    metadata={"source": "wikipedia", "purpose": "travel_assistant"},
    tags=["geography", "travel"],
)
Streaming Results
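BaseRetriever inherits stream and astream from the Runnable interface; by default they yield the complete result list as a single chunk. For retrievers that can produce results incrementally, you can override these methods to yield documents one at a time: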
class StreamingRetriever(BaseRetriever):
    # ... other methods ...

    def stream(self, query: str, config=None, **kwargs):
        """Yield documents one at a time."""
        # Go through invoke() so that config and callbacks are still applied
        for doc in self.invoke(query, config=config, **kwargs):
            yield doc
Integration with Callbacks
BaseRetriever integrates with LangChain’s callback system, allowing you to track and monitor retrieval operations:
from langchain_core.callbacks import StdOutCallbackHandler

handler = StdOutCallbackHandler()
# Callbacks are passed through the config argument
results = retriever.invoke("France", config={"callbacks": [handler]})
Fallbacks and Retry Logic
You can add fallbacks and retry logic to your retrievers:
# Create a retriever with fallbacks: backup_retriever is used if primary_retriever raises
robust_retriever = primary_retriever.with_fallbacks([backup_retriever])

# Add retry logic: up to three attempts, with jittered exponential backoff between them
retry_retriever = retriever.with_retry(
    stop_after_attempt=3,
    wait_exponential_jitter=True,
)
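To make the behaviour behind these helpers concrete, here is a standard-library sketch of the two ideas: retry a flaky call with exponentially growing, jittered waits, then fall back to a second source if all attempts fail. The function names, the delays, and the injected `sleep` parameter are assumptions for illustration; LangChain’s own implementation may differ in its details.

```python
import random
import time
from typing import Callable, List


def retry_then_fallback(
    primary: Callable[[str], List[str]],
    backup: Callable[[str], List[str]],
    query: str,
    attempts: int = 3,
    base_delay: float = 0.1,
    sleep: Callable[[float], None] = time.sleep,
) -> List[str]:
    """Try primary up to `attempts` times with jittered backoff, then use backup."""
    for attempt in range(attempts):
        try:
            return primary(query)
        except Exception:
            if attempt < attempts - 1:
                # The base wait doubles on each attempt and is scaled by a
                # random factor in [1, 2) so that clients do not retry in lockstep
                sleep(base_delay * (2 ** attempt) * (1 + random.random()))
    return backup(query)


calls = {"primary": 0}


def flaky_primary(query: str) -> List[str]:
    """Hypothetical primary source that always fails."""
    calls["primary"] += 1
    raise ConnectionError("primary index unavailable")


def steady_backup(query: str) -> List[str]:
    """Hypothetical backup source that always succeeds."""
    return [f"backup hit for {query}"]


# Record the waits instead of sleeping so the example runs instantly
waits: List[float] = []
outcome = retry_then_fallback(flaky_primary, steady_backup, "France", sleep=waits.append)
```

Here the primary source is tried three times, two waits are recorded between the attempts, and the backup supplies the final answer.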
Real-World Use Case: Hybrid Retrieval
Let’s implement a more sophisticated retriever that combines keyword search and semantic search:
from typing import Any, List

import numpy as np
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer


class HybridRetriever(BaseRetriever):
    """A hybrid retriever that combines keyword-based and semantic search."""

    # BaseRetriever is a Pydantic model, so instance state is declared as fields
    documents: List[Document]
    alpha: float = 0.5  # Weight of the TF-IDF score; (1 - alpha) goes to the semantic score
    top_k: int = 5
    tfidf: Any = None
    tfidf_matrix: Any = None
    encoder: Any = None
    document_embeddings: Any = None

    def __init__(
        self,
        documents: List[Document],
        embedding_model: str = "all-MiniLM-L6-v2",
        alpha: float = 0.5,
        top_k: int = 5,
        **kwargs: Any,
    ):
        """Initialize the hybrid retriever."""
        super().__init__(documents=documents, alpha=alpha, top_k=top_k, **kwargs)
        # Extract text from documents
        texts = [doc.page_content for doc in documents]
        # Set up TF-IDF for keyword search
        self.tfidf = TfidfVectorizer()
        self.tfidf_matrix = self.tfidf.fit_transform(texts)
        # Set up embedding model for semantic search
        self.encoder = SentenceTransformer(embedding_model)
        self.document_embeddings = self.encoder.encode(texts)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents using hybrid retrieval approach."""
        # Keyword scores: TfidfVectorizer L2-normalizes rows by default,
        # so this dot product equals the cosine similarity
        query_tfidf = self.tfidf.transform([query])
        tfidf_scores = query_tfidf.dot(self.tfidf_matrix.T).toarray()[0]
        # Semantic scores: cosine similarity between embeddings
        query_embedding = self.encoder.encode(query)
        semantic_scores = np.dot(self.document_embeddings, query_embedding) / (
            np.linalg.norm(self.document_embeddings, axis=1)
            * np.linalg.norm(query_embedding)
        )
        # Combine scores
        combined_scores = self.alpha * tfidf_scores + (1 - self.alpha) * semantic_scores
        # Get top-k indices, best first
        top_indices = combined_scores.argsort()[-self.top_k:][::-1]
        # Return new Document objects with scores in metadata, so the stored
        # documents are not mutated between queries
        results = []
        for idx in top_indices:
            source = self.documents[idx]
            scores = {
                "score": float(combined_scores[idx]),
                "tfidf_score": float(tfidf_scores[idx]),
                "semantic_score": float(semantic_scores[idx]),
            }
            results.append(
                Document(
                    page_content=source.page_content,
                    metadata={**source.metadata, **scores},
                )
            )
        return results
To use this hybrid retriever:
# Install required packages
# pip install langchain-core scikit-learn sentence-transformers

# Create documents
documents = [
    Document(page_content="Paris is the capital of France and known for the Eiffel Tower."),
    Document(page_content="Berlin is the capital of Germany and home to the Brandenburg Gate."),
    Document(page_content="Rome is the capital of Italy and features the Colosseum."),
    Document(page_content="France has a rich history of art, literature, and cuisine."),
    Document(page_content="The Louvre Museum in Paris houses the Mona Lisa painting."),
]

# Initialize the hybrid retriever
hybrid_retriever = HybridRetriever(documents)

# Retrieve documents
results = hybrid_retriever.invoke("art and culture in France")

# Print results with scores
for doc in results:
    print(f"Document: {doc.page_content}")
    print(f"Combined Score: {doc.metadata['score']:.4f}")
    print(f"TF-IDF Score: {doc.metadata['tfidf_score']:.4f}")
    print(f"Semantic Score: {doc.metadata['semantic_score']:.4f}")
    print("-" * 50)
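One thing to keep in mind with a weighted sum is that the two score types need not share a scale: TF-IDF cosine scores lie in [0, 1], while embedding cosine scores can range over [-1, 1] and often cluster in a narrow band. A common remedy, sketched below with made-up numbers, is to min-max normalize each list before combining. This is an optional variation, not something the HybridRetriever code does; `min_max` and `combine` are hypothetical helper names.

```python
from typing import List


def min_max(scores: List[float]) -> List[float]:
    """Rescale scores to [0, 1]; a constant list maps to all zeros."""
    low, high = min(scores), max(scores)
    if high == low:
        return [0.0 for _ in scores]
    return [(s - low) / (high - low) for s in scores]


def combine(keyword: List[float], semantic: List[float], alpha: float = 0.5) -> List[float]:
    """Weighted sum of the two normalized score lists."""
    k_norm, s_norm = min_max(keyword), min_max(semantic)
    return [alpha * k + (1 - alpha) * s for k, s in zip(k_norm, s_norm)]


tfidf_scores = [0.0, 0.5, 0.25]      # made-up keyword scores
semantic_scores = [0.25, 0.75, 1.0]  # made-up embedding scores
fused = combine(tfidf_scores, semantic_scores, alpha=0.5)
```

After normalization, alpha expresses the relative weight of the two signals more directly, because neither one dominates simply by having a wider numeric range.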
Best Practices for Custom Retrievers
When implementing your own retrievers, consider these best practices:
- Implement both synchronous and asynchronous methods for maximum flexibility
- Include proper error handling to ensure robustness
- Add meaningful metadata to retrieved documents to provide context
- Consider performance implications for large document collections
- Use appropriate similarity metrics for your specific use case
- Leverage caching for frequently accessed documents
- Add proper documentation for your custom retriever
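Two of these practices, caching and error handling, can be sketched with the standard library alone. In the example below, `slow_search` is a hypothetical stand-in for an expensive lookup, results are memoized per query with `functools.lru_cache`, and failures are turned into an empty result rather than an exception. Whether swallowing errors like this is appropriate depends on your application.

```python
import logging
from functools import lru_cache
from typing import List, Tuple

logger = logging.getLogger(__name__)
lookups = {"count": 0}


def slow_search(query: str) -> List[str]:
    """Hypothetical expensive lookup; rejects empty queries."""
    lookups["count"] += 1
    if not query.strip():
        raise ValueError("query must not be empty")
    return [f"hit for {query}"]


@lru_cache(maxsize=128)
def cached_search(query: str) -> Tuple[str, ...]:
    """Memoize results per query; a tuple keeps the cached value immutable."""
    return tuple(slow_search(query))


def safe_search(query: str) -> List[str]:
    """Return cached results, or an empty list if the lookup fails."""
    try:
        return list(cached_search(query))
    except ValueError:
        logger.warning("retrieval failed for query %r", query)
        return []


first = safe_search("France")
second = safe_search("France")  # served from the cache
empty = safe_search("   ")      # lookup raises, so an empty list is returned
```

Note that `lru_cache` does not store exceptions, so a failed query is attempted again the next time it is requested.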
Conclusion
The BaseRetriever abstract class in LangChain provides a powerful foundation for building custom document retrieval systems. By implementing the _get_relevant_documents method, you can create retrievers tailored to your specific needs, whether you’re building a simple keyword-based system or a sophisticated hybrid retrieval engine.
Custom retrievers can be seamlessly integrated into LangChain pipelines, combined with other components, and enhanced with features like fallbacks and retry logic. This flexibility makes LangChain an excellent choice for building robust retrieval-augmented generation (RAG) systems and other applications that require efficient document retrieval.
By understanding the BaseRetriever interface and following the patterns shown in this guide, you’ll be well-equipped to create powerful, custom retrieval solutions for your AI applications.
This post was originally written in my native language and then translated using an LLM. I apologize if there are any grammatical inconsistencies.