""" AsyncParentDocumentRetriever is used by RAG to split up documents into smaller *and* larger related chunks. """ import pickle import uuid from typing import Any, ClassVar, Collection, List, Optional, Tuple from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) from langchain.retrievers import ParentDocumentRetriever from langchain.schema.document import Document from langchain.schema.storage import BaseStore from langchain.storage import LocalFileStore from langchain_community.storage import RedisStore from langchain.vectorstores.base import VectorStore class AsyncParentDocumentRetriever(ParentDocumentRetriever): """Retrieve small chunks then retrieve their parent documents. When splitting documents for retrieval, there are often conflicting desires: 1. You may want to have small documents, so that their embeddings can most accurately reflect their meaning. If too long, then the embeddings can lose meaning. 2. You want to have long enough documents that the context of each chunk is retained. The ParentDocumentRetriever strikes that balance by splitting and storing small chunks of data. During retrieval, it first fetches the small chunks but then looks up the parent ids for those chunks and returns those larger documents. Note that "parent document" refers to the document that a small chunk originated from. This can either be the whole raw document OR a larger chunk. Examples: .. code-block:: python # Imports from langchain.vectorstores import Chroma from langchain.embeddings import OpenAIEmbeddings from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.storage import InMemoryStore # This text splitter is used to create the parent documents parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000) # This text splitter is used to create the child documents # It should create documents smaller than the parent child_splitter = RecursiveCharacterTextSplitter(chunk_size=400) # The vectorstore to use to index the child chunks vectorstore = Chroma(embedding_function=OpenAIEmbeddings()) # The storage layer for the parent documents store = InMemoryStore() # Initialize the retriever retriever = AsyncParentDocumentRetriever( vectorstore=vectorstore, docstore=store, child_splitter=child_splitter, parent_splitter=parent_splitter, ) """ docstore: LocalFileStore | RedisStore | BaseStore[str, Document] search_type: str = "similarity" """Type of search to perform. Defaults to "similarity".""" allowed_search_types: ClassVar[Collection[str]] = ( "similarity", "similarity_score_threshold", "mmr", ) def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: """Get documents relevant to a query. Args: query: String to find relevant documents for run_manager: The callbacks handler to use Returns: List of relevant documents """ if self.search_type == "similarity": sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs) elif self.search_type == "similarity_score_threshold": docs_and_similarities = ( self.vectorstore.similarity_search_with_relevance_scores( query, **self.search_kwargs ) ) sub_docs = [doc for doc, _ in docs_and_similarities] elif self.search_type == "mmr": sub_docs = self.vectorstore.max_marginal_relevance_search( query, **self.search_kwargs ) else: raise ValueError(f"search_type of {self.search_type} not allowed.") # We do this to maintain the order of the ids that are returned ids: List[str] = [] for d in sub_docs: if d.metadata[self.id_key] not in ids: ids.append(d.metadata[self.id_key]) if isinstance(self.docstore, (RedisStore, LocalFileStore)): serialized_docs = self.docstore.mget(ids) docs: List[Document] = [ pickle.loads(doc) for doc in serialized_docs if doc is not None ] else: docs: List[Document] = [ doc for doc in self.docstore.mget(ids) if doc is not None ] return docs async def _aget_relevant_documents( self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: """Get documents relevant to a query. Args: query: String to find relevant documents for run_manager: The callbacks handler to use Returns: List of relevant documents """ if self.search_type == "similarity": sub_docs = await self.vectorstore.asimilarity_search( query, **self.search_kwargs ) elif self.search_type == "similarity_score_threshold": docs_and_similarities = ( await self.vectorstore.asimilarity_search_with_relevance_scores( query, **self.search_kwargs ) ) sub_docs = [doc for doc, _ in docs_and_similarities] elif self.search_type == "mmr": sub_docs = await self.vectorstore.amax_marginal_relevance_search( query, **self.search_kwargs ) else: raise ValueError(f"search_type of {self.search_type} not allowed.") # We do this to maintain the order of the ids that are returned ids: List[str] = [] for d in sub_docs: if d.metadata[self.id_key] not in ids: ids.append(d.metadata[self.id_key]) if isinstance(self.docstore, (RedisStore, LocalFileStore)): # deserialize documents from bytes serialized_docs = self.docstore.mget(ids) docs: List[Document] = [ pickle.loads(doc) for doc in serialized_docs if doc is not None ] else: docs: List[Document] = [ doc for doc in self.docstore.mget(ids) if doc is not None ] return docs def add_documents( self, documents: List[Document], ids: Optional[List[str]] = None, add_to_docstore: bool = True, **kwargs: Any ) -> None: """Adds documents to the docstore and vectorstores. Args: documents: List of documents to add ids: Optional list of ids for documents. If provided should be the same length as the list of documents. Can provided if parent documents are already in the document store and you don't want to re-add to the docstore. If not provided, random UUIDs will be used as ids. add_to_docstore: Boolean of whether to add documents to docstore. This can be false if and only if `ids` are provided. You may want to set this to False if the documents are already in the docstore and you don't want to re-add them. """ if self.parent_splitter is not None: documents = self.parent_splitter.split_documents(documents) if ids is None: doc_ids = [str(uuid.uuid4()) for _ in documents] if not add_to_docstore: raise ValueError( "If ids are not passed in, `add_to_docstore` MUST be True" ) else: if len(documents) != len(ids): raise ValueError( "Got uneven list of documents and ids. " "If `ids` is provided, should be same length as `documents`." ) doc_ids = ids docs: List[Document] = [] full_docs: List[Tuple[str, Document]] = [] for i, doc in enumerate(documents): _id = doc_ids[i] sub_docs = self.child_splitter.split_documents([doc]) for _doc in sub_docs: _doc.metadata[self.id_key] = _id docs.extend(sub_docs) full_docs.append((_id, doc)) self.vectorstore.add_documents(docs) if add_to_docstore: if isinstance(self.docstore, (RedisStore, LocalFileStore)): # serialize documents to bytes serialized_docs = [(id, pickle.dumps(doc)) for id, doc in full_docs] self.docstore.mset(serialized_docs) else: self.docstore.mset(full_docs) async def aadd_documents( self, documents: List[Document], ids: Optional[List[str]] = None, add_to_docstore: bool = True, **kwargs: Any ) -> None: """Adds documents to the docstore and vectorstores. Args: documents: List of documents to add ids: Optional list of ids for documents. If provided should be the same length as the list of documents. Can provided if parent documents are already in the document store and you don't want to re-add to the docstore. If not provided, random UUIDs will be used as ids. add_to_docstore: Boolean of whether to add documents to docstore. This can be false if and only if `ids` are provided. You may want to set this to False if the documents are already in the docstore and you don't want to re-add them. """ if self.parent_splitter is not None: documents = self.parent_splitter.split_documents(documents) if ids is None: doc_ids = [str(uuid.uuid4()) for _ in documents] if not add_to_docstore: raise ValueError( "If ids are not passed in, `add_to_docstore` MUST be True" ) else: if len(documents) != len(ids): raise ValueError( "Got uneven list of documents and ids. " "If `ids` is provided, should be same length as `documents`." ) doc_ids = ids docs: List[Document] = [] full_docs: List[Tuple[str, Document]] = [] if len(documents) < 1: return for i, doc in enumerate(documents): _id = doc_ids[i] sub_docs = self.child_splitter.split_documents([doc]) for _doc in sub_docs: _doc.metadata[self.id_key] = _id docs.extend(sub_docs) full_docs.append((_id, doc)) # check if vectorstore supports async adds if ( type(self.vectorstore).aadd_documents != VectorStore.aadd_documents and type(self.vectorstore).aadd_texts != VectorStore.aadd_texts ): await self.vectorstore.aadd_documents(docs) else: self.vectorstore.add_documents(docs) if add_to_docstore: if isinstance(self.docstore, (RedisStore, LocalFileStore)): # serialize documents to bytes serialized_docs = [(id, pickle.dumps(doc)) for id, doc in full_docs] self.docstore.mset(serialized_docs) else: self.docstore.mset(full_docs)