diff --git a/swarms/server/async_parent_document_retriever.py b/swarms/server/async_parent_document_retriever.py
new file mode 100644
index 00000000..f6c24bf9
--- /dev/null
+++ b/swarms/server/async_parent_document_retriever.py
@@ -0,0 +1,274 @@
+import pickle
+import uuid
+from typing import ClassVar, Collection, List, Optional, Tuple
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForRetrieverRun,
+    CallbackManagerForRetrieverRun,
+)
+from langchain.retrievers import ParentDocumentRetriever
+from langchain.schema.document import Document
+from langchain.schema.storage import BaseStore
+from langchain.storage import LocalFileStore, RedisStore
+from langchain.vectorstores.base import VectorStore
+
+
+class AsyncParentDocumentRetriever(ParentDocumentRetriever):
+    """Retrieve small chunks then retrieve their parent documents.
+
+    When splitting documents for retrieval, there are often conflicting desires:
+
+    1. You may want to have small documents, so that their embeddings can most
+        accurately reflect their meaning. If too long, then the embeddings can
+        lose meaning.
+    2. You want to have long enough documents that the context of each chunk is
+        retained.
+
+    The ParentDocumentRetriever strikes that balance by splitting and storing
+    small chunks of data. During retrieval, it first fetches the small chunks
+    but then looks up the parent ids for those chunks and returns those larger
+    documents.
+
+    Note that "parent document" refers to the document that a small chunk
+    originated from. This can either be the whole raw document OR a larger
+    chunk.
+
+    Examples:
+
+        .. code-block:: python
+
+            # Imports
+            from langchain.vectorstores import Chroma
+            from langchain.embeddings import OpenAIEmbeddings
+            from langchain.text_splitter import RecursiveCharacterTextSplitter
+            from langchain.storage import InMemoryStore
+
+            # This text splitter is used to create the parent documents
+            parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
+            # This text splitter is used to create the child documents
+            # It should create documents smaller than the parent
+            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
+            # The vectorstore to use to index the child chunks
+            vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
+            # The storage layer for the parent documents
+            store = InMemoryStore()
+
+            # Initialize the retriever
+            retriever = AsyncParentDocumentRetriever(
+                vectorstore=vectorstore,
+                docstore=store,
+                child_splitter=child_splitter,
+                parent_splitter=parent_splitter,
+            )
+    """
+
+    docstore: LocalFileStore | RedisStore | BaseStore[str, Document]
+    search_type: str = "similarity"
+    """Type of search to perform. Defaults to "similarity"."""
+    allowed_search_types: ClassVar[Collection[str]] = (
+        "similarity",
+        "similarity_score_threshold",
+        "mmr",
+    )
+
+    def _get_relevant_documents(
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+    ) -> List[Document]:
+        """Get documents relevant to a query.
+        Args:
+            query: String to find relevant documents for
+            run_manager: The callbacks handler to use
+        Returns:
+            List of relevant documents
+        """
+        if self.search_type == "similarity":
+            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
+        elif self.search_type == "similarity_score_threshold":
+            docs_and_similarities = (
+                self.vectorstore.similarity_search_with_relevance_scores(
+                    query, **self.search_kwargs
+                )
+            )
+            sub_docs = [doc for doc, _ in docs_and_similarities]
+        elif self.search_type == "mmr":
+            sub_docs = self.vectorstore.max_marginal_relevance_search(
+                query, **self.search_kwargs
+            )
+        else:
+            raise ValueError(f"search_type of {self.search_type} not allowed.")
+        # We do this to maintain the order of the ids that are returned
+        ids: List[str] = []
+        for d in sub_docs:
+            if d.metadata[self.id_key] not in ids:
+                ids.append(d.metadata[self.id_key])
+        if isinstance(self.docstore, (RedisStore, LocalFileStore)):
+            # deserialize documents from bytes
+            serialized_docs = self.docstore.mget(ids)
+            docs: List[Document] = [
+                pickle.loads(doc) for doc in serialized_docs if doc is not None
+            ]
+        else:
+            docs: List[Document] = [
+                doc for doc in self.docstore.mget(ids) if doc is not None
+            ]
+        return docs
+
+    async def _aget_relevant_documents(
+        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
+    ) -> List[Document]:
+        """Get documents relevant to a query.
+        Args:
+            query: String to find relevant documents for
+            run_manager: The callbacks handler to use
+        Returns:
+            List of relevant documents
+        """
+        if self.search_type == "similarity":
+            sub_docs = await self.vectorstore.asimilarity_search(
+                query, **self.search_kwargs
+            )
+        elif self.search_type == "similarity_score_threshold":
+            docs_and_similarities = (
+                await self.vectorstore.asimilarity_search_with_relevance_scores(
+                    query, **self.search_kwargs
+                )
+            )
+            sub_docs = [doc for doc, _ in docs_and_similarities]
+        elif self.search_type == "mmr":
+            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
+                query, **self.search_kwargs
+            )
+        else:
+            raise ValueError(f"search_type of {self.search_type} not allowed.")
+        # We do this to maintain the order of the ids that are returned
+        ids: List[str] = []
+        for d in sub_docs:
+            if d.metadata[self.id_key] not in ids:
+                ids.append(d.metadata[self.id_key])
+        if isinstance(self.docstore, (RedisStore, LocalFileStore)):
+            # deserialize documents from bytes
+            serialized_docs = self.docstore.mget(ids)
+            docs: List[Document] = [
+                pickle.loads(doc) for doc in serialized_docs if doc is not None
+            ]
+        else:
+            docs: List[Document] = [
+                doc for doc in self.docstore.mget(ids) if doc is not None
+            ]
+        return docs
+
+    def add_documents(
+        self,
+        documents: List[Document],
+        ids: Optional[List[str]] = None,
+        add_to_docstore: bool = True,
+    ) -> None:
+        """Adds documents to the docstore and vectorstore.
+
+        Args:
+            documents: List of documents to add
+            ids: Optional list of ids for documents. If provided, should be the
+                same length as the list of documents. Can be provided if parent
+                documents are already in the document store and you don't want
+                to re-add them to the docstore. If not provided, random UUIDs
+                will be used as ids.
+            add_to_docstore: Boolean of whether to add documents to docstore.
+                This can be false if and only if `ids` are provided. You may want
+                to set this to False if the documents are already in the docstore
+                and you don't want to re-add them.
+ """ + if self.parent_splitter is not None: + documents = self.parent_splitter.split_documents(documents) + if ids is None: + doc_ids = [str(uuid.uuid4()) for _ in documents] + if not add_to_docstore: + raise ValueError( + "If ids are not passed in, `add_to_docstore` MUST be True" + ) + else: + if len(documents) != len(ids): + raise ValueError( + "Got uneven list of documents and ids. " + "If `ids` is provided, should be same length as `documents`." + ) + doc_ids = ids + + docs: List[Document] = [] + full_docs: List[Tuple[str, Document]] = [] + for i, doc in enumerate(documents): + _id = doc_ids[i] + sub_docs = self.child_splitter.split_documents([doc]) + for _doc in sub_docs: + _doc.metadata[self.id_key] = _id + docs.extend(sub_docs) + full_docs.append((_id, doc)) + self.vectorstore.add_documents(docs) + if add_to_docstore: + if isinstance(self.docstore, (RedisStore, LocalFileStore)): + # serialize documents to bytes + serialized_docs = [(id, pickle.dumps(doc)) for id, doc in full_docs] + self.docstore.mset(serialized_docs) + else: + self.docstore.mset(full_docs) + + async def aadd_documents( + self, + documents: List[Document], + ids: Optional[List[str]] = None, + add_to_docstore: bool = True, + ) -> None: + """Adds documents to the docstore and vectorstores. + + Args: + documents: List of documents to add + ids: Optional list of ids for documents. If provided should be the same + length as the list of documents. Can provided if parent documents + are already in the document store and you don't want to re-add + to the docstore. If not provided, random UUIDs will be used as + ids. + add_to_docstore: Boolean of whether to add documents to docstore. + This can be false if and only if `ids` are provided. You may want + to set this to False if the documents are already in the docstore + and you don't want to re-add them. + """ + if self.parent_splitter is not None: + documents = self.parent_splitter.split_documents(documents) + if ids is None: + doc_ids = [str(uuid.uuid4()) for _ in documents] + if not add_to_docstore: + raise ValueError( + "If ids are not passed in, `add_to_docstore` MUST be True" + ) + else: + if len(documents) != len(ids): + raise ValueError( + "Got uneven list of documents and ids. " + "If `ids` is provided, should be same length as `documents`." + ) + doc_ids = ids + + docs: List[Document] = [] + full_docs: List[Tuple[str, Document]] = [] + for i, doc in enumerate(documents): + _id = doc_ids[i] + sub_docs = self.child_splitter.split_documents([doc]) + for _doc in sub_docs: + _doc.metadata[self.id_key] = _id + docs.extend(sub_docs) + full_docs.append((_id, doc)) + + # check if vectorstore supports async adds + if ( + type(self.vectorstore).aadd_documents != VectorStore.aadd_documents + and type(self.vectorstore).aadd_texts != VectorStore.aadd_texts + ): + await self.vectorstore.aadd_documents(docs) + else: + self.vectorstore.add_documents(docs) + + if add_to_docstore: + if isinstance(self.docstore, (RedisStore, LocalFileStore)): + # serialize documents to bytes + serialized_docs = [(id, pickle.dumps(doc)) for id, doc in full_docs] + self.docstore.mset(serialized_docs) + else: + self.docstore.mset(full_docs) \ No newline at end of file