Building Custom Document Retrieval Systems with LangChain: A Comprehensive Guide to Implementing BaseRetriever

Document retrieval is a critical component in many modern AI applications, particularly those involving Large Language Models (LLMs). LangChain provides a powerful framework for building retrieval systems through its BaseRetriever abstract class. In this guide, we’ll explore how to implement custom document retrievers by extending this class, enabling you to create tailored retrieval solutions for your specific use cases.

Understanding BaseRetriever

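At its core, BaseRetriever is an abstract base class that defines the interface for document retrieval systems in LangChain. It inherits from RunnableSerializable and Python’s ABC (Abstract Base Class), making it both a runnable component in LangChain pipelines and a template for custom implementations. Because RunnableSerializable is a Pydantic model, a retriever’s state has to be declared as typed class fields; assigning undeclared attributes inside __init__ raises an error.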

The primary purpose of a retriever is straightforward: take a string query and return the most relevant documents from a source. This simple concept powers many advanced applications, from question-answering systems to research assistants.
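
To make that contract concrete, here is a minimal sketch in plain Python with no LangChain dependency. The Doc class and the words and retrieve functions are hypothetical stand-ins invented for this illustration (LangChain’s own Document class plays the role of Doc), and counting shared words is only a placeholder scoring rule.

```python
import re
from dataclasses import dataclass, field


@dataclass
class Doc:
    """Hypothetical stand-in for LangChain's Document class."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def words(text: str) -> set:
    """Lowercase the text and split it into a set of words."""
    return set(re.findall(r"\w+", text.lower()))


def retrieve(query: str, docs: list, k: int = 2) -> list:
    """Return up to k docs sharing at least one word with the query, best first."""
    query_words = words(query)
    scored = [(len(query_words & words(d.page_content)), d) for d in docs]
    matches = [(score, d) for score, d in scored if score > 0]
    matches.sort(key=lambda pair: pair[0], reverse=True)
    return [d for _, d in matches[:k]]


docs = [
    Doc("Paris is the capital of France."),
    Doc("Berlin is the capital of Germany."),
    Doc("Bread is baked in an oven."),
]
hits = retrieve("capital of France", docs)
for doc in hits:
    print(doc.page_content)
```

A string goes in and a ranked list of documents comes out; everything else in this guide is a refinement of how the ranking is computed.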

The BaseRetriever Interface

Before diving into implementation, let’s understand the key methods and interfaces of BaseRetriever:

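from typing import List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class MyCustomRetriever(BaseRetriever):
    """A custom retriever implementation."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Core implementation to retrieve relevant documents."""
        # Your custom retrieval logic goes here.
        # run_manager lets the method report progress to callbacks.
        raise NotImplementedError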

The only method you must implement is _get_relevant_documents, which contains the actual retrieval logic. You can also override _aget_relevant_documents to provide a native asynchronous implementation; if you don’t, the base class falls back to running the synchronous method in an executor.
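
One common way to give synchronous retrieval code an asynchronous entry point is to run it in a worker thread so the event loop stays free. The sketch below shows that pattern with the standard library only. The names blocking_search and async_search are hypothetical, time.sleep stands in for real retrieval work, and this is not LangChain’s internal code, only the general idea.

```python
import asyncio
import time


def blocking_search(query: str) -> list:
    """Hypothetical stand-in for a synchronous retrieval call."""
    time.sleep(0.05)  # simulate I/O or computation
    return [f"result for {query}"]


async def async_search(query: str) -> list:
    """Run the blocking function in a worker thread (Python 3.9+)."""
    return await asyncio.to_thread(blocking_search, query)


async def main() -> list:
    # Both searches run concurrently; gather keeps the input order
    return await asyncio.gather(async_search("a"), async_search("b"))


results = asyncio.run(main())
print(results)
```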

Usage Patterns

BaseRetriever follows the standard Runnable interface in LangChain, which means you can use it with the following methods:

  • invoke: For synchronous retrieval
  • ainvoke: For asynchronous retrieval
  • batch: For processing multiple queries in batches
  • abatch: For asynchronous batch processing

Here’s a simple example of using a retriever:

# Using a retriever
documents = retriever.invoke("What is the capital of France?")

# Asynchronous usage
import asyncio
documents = asyncio.run(retriever.ainvoke("What is the capital of France?"))

# Batch processing
results = retriever.batch(["Query 1", "Query 2", "Query 3"])
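
Conceptually, batch applies the same retrieval to several inputs and returns the results in input order. Here is a rough standard-library sketch of that idea, assuming a thread pool provides the parallelism. The names search and batch_search are hypothetical and are not LangChain functions.

```python
from concurrent.futures import ThreadPoolExecutor


def search(query: str) -> list:
    """Hypothetical single-query retrieval."""
    return [f"hit for {query}"]


def batch_search(queries: list, max_workers: int = 4) -> list:
    """Run search for every query in parallel, preserving input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(search, queries))


results = batch_search(["Query 1", "Query 2", "Query 3"])
print(results)
```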

Implementing a Custom Retriever

Let’s implement a simple custom retriever that demonstrates the basic pattern:

from typing import Any, List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class SimpleVectorRetriever(BaseRetriever):
    """A simple vector-based retriever."""

    # BaseRetriever is a Pydantic model, so state is declared as fields
    documents: List[Document]
    vectorizer: Any  # an already-fitted vectorizer, e.g. from scikit-learn
    doc_vectors: Any  # dense matrix with one row per document

    def __init__(self, documents: List[Document], vectorizer: Any, **kwargs: Any):
        """Initialize with documents and a fitted vectorizer."""
        # Pre-compute document vectors before handing the fields to Pydantic
        doc_contents = [doc.page_content for doc in documents]
        doc_vectors = vectorizer.transform(doc_contents).toarray()
        super().__init__(
            documents=documents,
            vectorizer=vectorizer,
            doc_vectors=doc_vectors,
            **kwargs,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to the query."""
        # Vectorize the query
        query_vector = self.vectorizer.transform([query]).toarray()[0]

        # Compute similarity scores (using dot product for simplicity)
        scores = self.doc_vectors.dot(query_vector)

        # Get the top 5 documents, highest score first
        top_indices = scores.argsort()[-5:][::-1]
        return [self.documents[i] for i in top_indices]

This example uses a simple vector-based approach, but you could implement more sophisticated retrieval mechanisms based on your needs.
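
The argsort()[-5:][::-1] idiom in the code above sorts all scores in ascending order, keeps the last five indices, and reverses them so that the best match comes first. The same selection can be sketched without NumPy. The helper top_k_indices is a hypothetical name used only for this illustration.

```python
import heapq


def top_k_indices(scores: list, k: int) -> list:
    """Return the indices of the k highest scores, best first."""
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)


scores = [0.10, 0.80, 0.05, 0.45]
print(top_k_indices(scores, 2))   # [1, 3]
print(top_k_indices(scores, 10))  # [1, 3, 0, 2]
```

Note the second call: when k exceeds the number of documents, every document is returned, including those with a very low score. The NumPy slice behaves the same way, so a score threshold may be worth adding in practice.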

A More Complete Example

Let’s create a more practical example using scikit-learn’s TF-IDF vectorizer:

import asyncio
from typing import Any, List

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


class TfidfRetriever(BaseRetriever):
    """Retriever based on TF-IDF vectorization and cosine similarity."""

    # Declared as Pydantic fields because BaseRetriever is a Pydantic model
    documents: List[Document]
    vectorizer: Any  # fitted TfidfVectorizer
    doc_vectors: Any  # sparse TF-IDF matrix, one row per document
    top_k: int = 5

    def __init__(self, documents: List[Document], top_k: int = 5, **kwargs: Any):
        """Initialize the retriever with documents."""
        # Extract text from documents
        texts = [doc.page_content for doc in documents]

        # Create and fit vectorizer
        vectorizer = TfidfVectorizer()
        doc_vectors = vectorizer.fit_transform(texts)

        # Extra keyword arguments (e.g. tags, metadata) go to BaseRetriever
        super().__init__(
            documents=documents,
            vectorizer=vectorizer,
            doc_vectors=doc_vectors,
            top_k=top_k,
            **kwargs,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve documents based on cosine similarity."""
        # Vectorize the query
        query_vector = self.vectorizer.transform([query])

        # Calculate cosine similarity
        similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()

        # Get indices of top-k documents, highest similarity first
        top_indices = similarities.argsort()[-self.top_k:][::-1]

        # Return the top documents
        return [self.documents[i] for i in top_indices]

    # Optional: async version. The work is CPU-bound, so it is offloaded to a
    # worker thread instead of blocking the event loop.
    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Async implementation of document retrieval."""
        return await asyncio.to_thread(
            self._get_relevant_documents, query, run_manager=run_manager.get_sync()
        )

To use this retriever:

# Create some sample documents
documents = [
    Document(page_content="Paris is the capital of France."),
    Document(page_content="Berlin is the capital of Germany."),
    Document(page_content="Rome is the capital of Italy."),
    Document(page_content="Madrid is the capital of Spain."),
    Document(page_content="France is known for the Eiffel Tower and fine cuisine.")
]

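# Initialize our retriever (top_k=2 so that only the best matches come back;
# with the default of 5, all five sample documents would be returned)
retriever = TfidfRetriever(documents, top_k=2)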

# Retrieve relevant documents
results = retriever.invoke("Tell me about France")

# Print the results
for doc in results:
    print(doc.page_content)
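
To show what the vectorizer and the cosine similarity are doing under the hood, here is a simplified pure-Python sketch of TF-IDF scoring. It uses the smoothed IDF formula that, as far as I know, scikit-learn documents as its default, but the tokenization is simpler than TfidfVectorizer’s, so the exact numbers will differ. The function names tokenize, weigh, tfidf_vectors and cosine are hypothetical.

```python
import math
import re
from collections import Counter


def tokenize(text: str) -> list:
    """Lowercase and split into word tokens (simplified tokenizer)."""
    return re.findall(r"\w+", text.lower())


def weigh(tokens: list, idf: dict) -> dict:
    """Build an L2-normalized {term: tf * idf} vector, ignoring unknown terms."""
    counts = Counter(t for t in tokens if t in idf)
    raw = {term: count * idf[term] for term, count in counts.items()}
    norm = math.sqrt(sum(w * w for w in raw.values())) or 1.0
    return {term: w / norm for term, w in raw.items()}


def tfidf_vectors(texts: list) -> tuple:
    """Fit IDF weights on the texts and return (document vectors, idf)."""
    tokenized = [tokenize(t) for t in texts]
    n = len(texts)
    df = Counter(term for tokens in tokenized for term in set(tokens))
    idf = {term: math.log((1 + n) / (1 + count)) + 1 for term, count in df.items()}
    return [weigh(tokens, idf) for tokens in tokenized], idf


def cosine(a: dict, b: dict) -> float:
    """Dot product of two L2-normalized sparse vectors."""
    return sum(w * b.get(term, 0.0) for term, w in a.items())


texts = [
    "Paris is the capital of France.",
    "Berlin is the capital of Germany.",
    "France is known for the Eiffel Tower and fine cuisine.",
]
doc_vecs, idf = tfidf_vectors(texts)
query_vec = weigh(tokenize("Tell me about France"), idf)
scores = [cosine(query_vec, vec) for vec in doc_vecs]
best = max(range(len(texts)), key=scores.__getitem__)
print(texts[best])
```

Only the word "france" from the query is known to the fitted vocabulary, so the two documents that mention France get a positive score and the Berlin document scores zero. The shorter Paris sentence ranks first because "france" carries a larger share of its normalized vector.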

Advanced Features of BaseRetriever

The BaseRetriever class comes with several advanced features that make it powerful and flexible:

Metadata and Tags

You can associate metadata and tags with your retriever, which are passed along during callbacks:

retriever = TfidfRetriever(
    documents,
    metadata={"source": "wikipedia", "purpose": "travel_assistant"},
    tags=["geography", "travel"]
)

Streaming Results

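By default, stream and astream call invoke and yield the complete result list as a single chunk. For retrievers that can produce results incrementally, you can override them:

class StreamingRetriever(BaseRetriever):
    # ... other methods ...

    def stream(self, query: str, config=None, **kwargs):
        """Stream results one document at a time."""
        # Going through invoke() keeps callbacks and tracing intact. It still
        # computes the full list first; a truly incremental source would
        # yield from inside its own retrieval loop instead.
        for doc in self.invoke(query, config, **kwargs):
            yield doc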
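
The benefit of streaming shows up when the source itself can be scanned lazily. The generator sketch below, using hypothetical names and a plain substring match, yields each document as soon as it is found rather than after the whole collection has been scanned. It is independent of LangChain and only illustrates the idea.

```python
from typing import Iterable, Iterator

scanned = []  # records which texts were inspected, to make the laziness visible


def stream_matches(query: str, texts: Iterable[str]) -> Iterator[str]:
    """Yield each text containing the query as soon as it is found."""
    needle = query.lower()
    for text in texts:
        scanned.append(text)
        if needle in text.lower():
            yield text


texts = [
    "Paris is the capital of France.",
    "Berlin is the capital of Germany.",
    "France is known for the Eiffel Tower and fine cuisine.",
]
matches = stream_matches("france", texts)
first = next(matches)  # only the first text has been scanned at this point
rest = list(matches)   # scanning resumes where it left off
print(first)
print(rest)
```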

Integration with Callbacks

BaseRetriever integrates with LangChain’s callback system, allowing you to track and monitor retrieval operations:

from langchain_core.callbacks import StdOutCallbackHandler

handler = StdOutCallbackHandler()
# Callbacks are passed through the config argument, not as a keyword argument
results = retriever.invoke("France", config={"callbacks": [handler]})

Fallbacks and Retry Logic

Because retrievers are Runnables, you can wrap them with fallbacks and retry logic:

# Create a retriever with fallbacks
# (primary_retriever and backup_retriever are any two retriever instances)
robust_retriever = primary_retriever.with_fallbacks([backup_retriever])

# Add retry logic
retry_retriever = retriever.with_retry(
    stop_after_attempt=3,
    wait_exponential_jitter=True
)
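
To clarify what these two wrappers do conceptually, here is a standard-library sketch of retrying with exponential backoff plus jitter, and of falling back to a second source. It is not LangChain’s implementation; call_with_retry, call_with_fallbacks, and the three toy retrieval functions are hypothetical names invented for the example.

```python
import random
import time


def call_with_retry(fn, query, attempts=3, base_delay=0.01):
    """Call fn(query), retrying on any exception with backoff and jitter."""
    for attempt in range(attempts):
        try:
            return fn(query)
        except Exception:
            if attempt == attempts - 1:
                raise  # out of attempts: surface the last error
            # Exponential backoff (1x, 2x, 4x, ...) plus random jitter
            time.sleep(base_delay * 2 ** attempt + random.uniform(0, base_delay))


def call_with_fallbacks(primary, fallbacks, query):
    """Try primary first, then each fallback in order."""
    last_error = None
    for fn in [primary, *fallbacks]:
        try:
            return fn(query)
        except Exception as exc:
            last_error = exc
    raise last_error


calls = {"flaky": 0}


def flaky(query):
    """Fails twice, then succeeds."""
    calls["flaky"] += 1
    if calls["flaky"] < 3:
        raise ConnectionError("temporary failure")
    return [f"primary: {query}"]


def broken(query):
    raise RuntimeError("index unavailable")


def backup(query):
    return [f"backup: {query}"]


retried = call_with_retry(flaky, "France")
fallen_back = call_with_fallbacks(broken, [backup], "France")
print(retried, fallen_back)
```

Retries suit transient failures such as network timeouts, while fallbacks suit the case where the primary source is down for longer and a second index can serve degraded results.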

Real-World Use Case: Hybrid Retrieval

Let’s implement a more sophisticated retriever that combines keyword search and semantic search:

from typing import Any, List

import numpy as np
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer


class HybridRetriever(BaseRetriever):
    """A hybrid retriever that combines keyword-based and semantic search."""

    # Declared as Pydantic fields because BaseRetriever is a Pydantic model
    documents: List[Document]
    tfidf: Any  # fitted TfidfVectorizer
    tfidf_matrix: Any  # sparse TF-IDF matrix, one row per document
    embedding_model: Any  # loaded SentenceTransformer
    document_embeddings: Any  # dense matrix, one row per document
    alpha: float = 0.5  # Weight of the keyword score in the combined score
    top_k: int = 5

    def __init__(
        self,
        documents: List[Document],
        embedding_model: str = "all-MiniLM-L6-v2",
        alpha: float = 0.5,
        top_k: int = 5,
        **kwargs: Any,
    ):
        """Initialize the hybrid retriever."""
        # Extract text from documents
        texts = [doc.page_content for doc in documents]

        # Set up TF-IDF for keyword search
        tfidf = TfidfVectorizer()
        tfidf_matrix = tfidf.fit_transform(texts)

        # Set up embedding model for semantic search
        model = SentenceTransformer(embedding_model)
        document_embeddings = model.encode(texts)

        super().__init__(
            documents=documents,
            tfidf=tfidf,
            tfidf_matrix=tfidf_matrix,
            embedding_model=model,
            document_embeddings=document_embeddings,
            alpha=alpha,
            top_k=top_k,
            **kwargs,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents using hybrid retrieval approach."""
        # Get keyword-based scores (TF-IDF rows are L2-normalized by default,
        # so this dot product is a cosine similarity)
        query_tfidf = self.tfidf.transform([query])
        tfidf_scores = query_tfidf.dot(self.tfidf_matrix.T).toarray()[0]

        # Get semantic search scores (cosine similarity of embeddings)
        query_embedding = self.embedding_model.encode(query)
        semantic_scores = np.dot(self.document_embeddings, query_embedding) / (
            np.linalg.norm(self.document_embeddings, axis=1)
            * np.linalg.norm(query_embedding)
        )

        # Combine scores
        combined_scores = (
            self.alpha * tfidf_scores + (1 - self.alpha) * semantic_scores
        )

        # Get top-k indices, highest combined score first
        top_indices = combined_scores.argsort()[-self.top_k:][::-1]

        # Return new Document objects with scores in metadata, so the stored
        # documents are not mutated between queries
        results = []
        for idx in top_indices:
            doc = self.documents[idx]
            metadata = {
                **doc.metadata,
                "score": float(combined_scores[idx]),
                "tfidf_score": float(tfidf_scores[idx]),
                "semantic_score": float(semantic_scores[idx]),
            }
            results.append(Document(page_content=doc.page_content, metadata=metadata))

        return results

To use this hybrid retriever:

# Install required packages
# pip install langchain-core scikit-learn sentence-transformers

# Create documents
documents = [
    Document(page_content="Paris is the capital of France and known for the Eiffel Tower."),
    Document(page_content="Berlin is the capital of Germany and home to the Brandenburg Gate."),
    Document(page_content="Rome is the capital of Italy and features the Colosseum."),
    Document(page_content="France has a rich history of art, literature, and cuisine."),
    Document(page_content="The Louvre Museum in Paris houses the Mona Lisa painting.")
]

# Initialize the hybrid retriever
hybrid_retriever = HybridRetriever(documents)

# Retrieve documents
results = hybrid_retriever.invoke("art and culture in France")

# Print results with scores
for doc in results:
    print(f"Document: {doc.page_content}")
    print(f"Combined Score: {doc.metadata['score']:.4f}")
    print(f"TF-IDF Score: {doc.metadata['tfidf_score']:.4f}")
    print(f"Semantic Score: {doc.metadata['semantic_score']:.4f}")
    print("-" * 50)
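
The alpha parameter controls the balance in the formula combined = alpha * tfidf + (1 - alpha) * semantic. A small worked example with made-up scores shows how the ranking shifts as alpha moves from pure semantic search (0.0) to pure keyword search (1.0). The helper names combine and ranking, and the score values, are invented for illustration.

```python
def combine(keyword: list, semantic: list, alpha: float) -> list:
    """Weighted sum of keyword and semantic scores, per document."""
    return [alpha * k + (1 - alpha) * s for k, s in zip(keyword, semantic)]


def ranking(scores: list) -> list:
    """Document indices ordered from highest to lowest score."""
    return sorted(range(len(scores)), key=scores.__getitem__, reverse=True)


# Made-up scores for three documents
keyword_scores = [0.90, 0.10, 0.00]   # document 0 matches the query words
semantic_scores = [0.20, 0.70, 0.60]  # documents 1 and 2 match the meaning

for alpha in (0.0, 0.5, 1.0):
    order = ranking(combine(keyword_scores, semantic_scores, alpha))
    print(f"alpha={alpha}: {order}")
# alpha=0.0: [1, 2, 0]
# alpha=0.5: [0, 1, 2]
# alpha=1.0: [0, 1, 2]
```

Because the two score types can live on different scales, it may be worth normalizing each of them (for example to the 0 to 1 range) before combining, so that alpha really expresses the intended balance.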

Best Practices for Custom Retrievers

When implementing your own retrievers, consider these best practices:

  1. Implement both synchronous and asynchronous methods for maximum flexibility
  2. Include proper error handling to ensure robustness
  3. Add meaningful metadata to retrieved documents to provide context
  4. Consider performance implications for large document collections
  5. Use appropriate similarity metrics for your specific use case
  6. Leverage caching for frequently accessed documents
  7. Add proper documentation for your custom retriever
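
As an illustration of the caching point, here is a minimal sketch that memoizes query results with functools.lru_cache from the standard library. The functions expensive_search and cached_search are hypothetical, and the approach assumes the underlying document collection does not change between calls; otherwise the cache would have to be cleared.

```python
from functools import lru_cache

call_count = 0  # counts how often the expensive search really runs


def expensive_search(query: str) -> tuple:
    """Hypothetical slow retrieval; returns a tuple so results are immutable."""
    global call_count
    call_count += 1
    return (f"document about {query}",)


@lru_cache(maxsize=256)
def cached_search(query: str) -> tuple:
    """Return cached results for queries that were seen before."""
    return expensive_search(query)


first = cached_search("France")
second = cached_search("France")  # served from the cache
print(call_count, cached_search.cache_info().hits)
```

Returning an immutable tuple avoids a subtle problem: if callers could modify a cached list in place, every later caller would see the modified result.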

Conclusion

The BaseRetriever abstract class in LangChain provides a powerful foundation for building custom document retrieval systems. By implementing the _get_relevant_documents method, you can create retrievers tailored to your specific needs, whether you’re building a simple keyword-based system or a sophisticated hybrid retrieval engine.

Custom retrievers can be seamlessly integrated into LangChain pipelines, combined with other components, and enhanced with features like fallbacks and retry logic. This flexibility makes LangChain an excellent choice for building robust retrieval-augmented generation (RAG) systems and other applications that require efficient document retrieval.

By understanding the BaseRetriever interface and following the patterns shown in this guide, you’ll be well-equipped to create powerful, custom retrieval solutions for your AI applications.

This post was originally written in my native language and then translated using an LLM. I apologize if there are any grammatical inconsistencies.
