You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/swarms/server/async_parent_document_retri...

284 lines
12 KiB

""" AsyncParentDocumentRetriever is used by RAG
to split up documents into smaller *and* larger related chunks. """
import pickle
import uuid
from typing import Any, ClassVar, Collection, List, Optional, Tuple
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.retrievers import ParentDocumentRetriever
from langchain.schema.document import Document
from langchain.schema.storage import BaseStore
from langchain.storage import LocalFileStore
from langchain_community.storage import RedisStore
from langchain.vectorstores.base import VectorStore
class AsyncParentDocumentRetriever(ParentDocumentRetriever):
    """Retrieve small chunks then retrieve their parent documents.

    When splitting documents for retrieval, there are often conflicting desires:

    1. You may want to have small documents, so that their embeddings can most
       accurately reflect their meaning. If too long, then the embeddings can
       lose meaning.
    2. You want to have long enough documents that the context of each chunk is
       retained.

    The ParentDocumentRetriever strikes that balance by splitting and storing
    small chunks of data. During retrieval, it first fetches the small chunks
    but then looks up the parent ids for those chunks and returns those larger
    documents.

    Note that "parent document" refers to the document that a small chunk
    originated from. This can either be the whole raw document OR a larger
    chunk.

    This subclass adds async retrieval/indexing on top of the vector store's
    async API, and pickles parent ``Document`` objects to bytes whenever the
    docstore is a ``RedisStore`` or ``LocalFileStore``.

    Examples:
        .. code-block:: python

            # Imports
            from langchain.vectorstores import Chroma
            from langchain.embeddings import OpenAIEmbeddings
            from langchain.text_splitter import RecursiveCharacterTextSplitter
            from langchain.storage import InMemoryStore

            # This text splitter is used to create the parent documents
            parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
            # This text splitter is used to create the child documents
            # It should create documents smaller than the parent
            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
            # The vectorstore to use to index the child chunks
            vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
            # The storage layer for the parent documents
            store = InMemoryStore()

            # Initialize the retriever
            retriever = AsyncParentDocumentRetriever(
                vectorstore=vectorstore,
                docstore=store,
                child_splitter=child_splitter,
                parent_splitter=parent_splitter,
            )
    """

    # Parent-document store. For RedisStore / LocalFileStore the values are
    # pickled Documents; any other BaseStore holds Document objects directly.
    docstore: LocalFileStore | RedisStore | BaseStore[str, Document]
    search_type: str = "similarity"
    """Type of search to perform. Defaults to "similarity"."""
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
    )

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get parent documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            run_manager: The callbacks handler to use (not used here).

        Returns:
            Parent documents whose child chunks matched ``query``, ordered by
            the first appearance of each parent id in the vector-store results.
            Parents missing from the docstore are skipped.

        Raises:
            ValueError: If ``search_type`` is not one of the supported types.
        """
        if self.search_type == "similarity":
            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            sub_docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            sub_docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return self._load_parent_docs(self._parent_ids_in_order(sub_docs))

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Async variant of ``_get_relevant_documents``.

        The vector-store search is awaited; the docstore lookup itself is still
        performed with the synchronous ``mget``.

        Args:
            query: String to find relevant documents for.
            run_manager: The callbacks handler to use (not used here).

        Returns:
            Parent documents whose child chunks matched ``query``, ordered by
            the first appearance of each parent id in the vector-store results.
            Parents missing from the docstore are skipped.

        Raises:
            ValueError: If ``search_type`` is not one of the supported types.
        """
        if self.search_type == "similarity":
            sub_docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs
            )
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            sub_docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return self._load_parent_docs(self._parent_ids_in_order(sub_docs))

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------
    def add_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
        **kwargs: Any,
    ) -> None:
        """Adds documents to the docstore and vectorstores.

        Args:
            documents: List of documents to add.
            ids: Optional list of ids for documents. If provided should be the
                same length as the list of documents (after ``parent_splitter``
                has run, if one is configured). Can be provided if parent
                documents are already in the document store and you don't want
                to re-add to the docstore. If not provided, random UUIDs will be
                used as ids.
            add_to_docstore: Boolean of whether to add documents to docstore.
                This can be false if and only if `ids` are provided. You may
                want to set this to False if the documents are already in the
                docstore and you don't want to re-add them.
            **kwargs: Accepted for interface compatibility; currently unused.

        Raises:
            ValueError: If ``ids`` is omitted while ``add_to_docstore`` is
                False, or if ``ids`` and the documents differ in length.
        """
        docs, full_docs = self._build_child_and_parent_docs(
            documents, ids, add_to_docstore
        )
        if not full_docs:
            # Nothing to index; skip the vector-store call on an empty batch
            # (same behaviour as ``aadd_documents``).
            return
        self.vectorstore.add_documents(docs)
        if add_to_docstore:
            self._store_parent_docs(full_docs)

    async def aadd_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
        **kwargs: Any,
    ) -> None:
        """Async variant of ``add_documents``.

        Child chunks are written with the vector store's async API when it has
        a native one, otherwise with the synchronous API. Docstore writes use
        the synchronous ``mset``.

        Args:
            documents: List of documents to add.
            ids: Optional list of ids for documents. If provided should be the
                same length as the list of documents (after ``parent_splitter``
                has run, if one is configured). Can be provided if parent
                documents are already in the document store and you don't want
                to re-add to the docstore. If not provided, random UUIDs will be
                used as ids.
            add_to_docstore: Boolean of whether to add documents to docstore.
                This can be false if and only if `ids` are provided. You may
                want to set this to False if the documents are already in the
                docstore and you don't want to re-add them.
            **kwargs: Accepted for interface compatibility; currently unused.

        Raises:
            ValueError: If ``ids`` is omitted while ``add_to_docstore`` is
                False, or if ``ids`` and the documents differ in length.
        """
        docs, full_docs = self._build_child_and_parent_docs(
            documents, ids, add_to_docstore
        )
        if not full_docs:
            return
        if self._vectorstore_has_native_async_add():
            await self.vectorstore.aadd_documents(docs)
        else:
            self.vectorstore.add_documents(docs)
        if add_to_docstore:
            self._store_parent_docs(full_docs)

    # ------------------------------------------------------------------
    # Private helpers shared by the sync and async paths
    # ------------------------------------------------------------------
    def _parent_ids_in_order(self, sub_docs: List[Document]) -> List[str]:
        """Return the distinct parent ids of ``sub_docs`` in first-seen order."""
        # dict.fromkeys dedups in O(n) while keeping insertion order.
        return list(dict.fromkeys(d.metadata[self.id_key] for d in sub_docs))

    def _uses_byte_store(self) -> bool:
        """Return True when the docstore holds bytes and Documents are pickled."""
        return isinstance(self.docstore, (RedisStore, LocalFileStore))

    def _load_parent_docs(self, ids: List[str]) -> List[Document]:
        """Fetch parent documents for ``ids``, dropping ids with no entry."""
        stored = self.docstore.mget(ids)
        if self._uses_byte_store():
            # Deserialize documents from bytes.
            # NOTE(review): pickle.loads executes code embedded in the payload;
            # only point this retriever at trusted Redis / filesystem stores.
            return [pickle.loads(raw) for raw in stored if raw is not None]
        return [doc for doc in stored if doc is not None]

    def _store_parent_docs(self, full_docs: List[Tuple[str, Document]]) -> None:
        """Write ``(parent_id, document)`` pairs, pickling for byte stores."""
        if self._uses_byte_store():
            # Serialize documents to bytes.
            self.docstore.mset(
                [(doc_id, pickle.dumps(doc)) for doc_id, doc in full_docs]
            )
        else:
            self.docstore.mset(full_docs)

    def _build_child_and_parent_docs(
        self,
        documents: List[Document],
        ids: Optional[List[str]],
        add_to_docstore: bool,
    ) -> Tuple[List[Document], List[Tuple[str, Document]]]:
        """Split into parents and children, tagging each child with its parent id.

        Returns:
            ``(child_docs, full_docs)`` where ``full_docs`` is the list of
            ``(parent_id, parent_document)`` pairs.

        Raises:
            ValueError: If ``ids`` is omitted while ``add_to_docstore`` is
                False, or if ``ids`` and the documents differ in length.
        """
        if self.parent_splitter is not None:
            documents = self.parent_splitter.split_documents(documents)
        if ids is None:
            doc_ids = [str(uuid.uuid4()) for _ in documents]
            if not add_to_docstore:
                raise ValueError(
                    "If ids are not passed in, `add_to_docstore` MUST be True"
                )
        else:
            # Checked after parent splitting: with a parent_splitter, ``ids``
            # must match the number of parent chunks, not the raw inputs.
            if len(documents) != len(ids):
                raise ValueError(
                    "Got uneven list of documents and ids. "
                    "If `ids` is provided, should be same length as `documents`."
                )
            doc_ids = ids

        docs: List[Document] = []
        full_docs: List[Tuple[str, Document]] = []
        for parent_id, doc in zip(doc_ids, documents):
            sub_docs = self.child_splitter.split_documents([doc])
            for sub_doc in sub_docs:
                # Link each child chunk back to its parent for retrieval.
                sub_doc.metadata[self.id_key] = parent_id
            docs.extend(sub_docs)
            full_docs.append((parent_id, doc))
        return docs, full_docs

    def _vectorstore_has_native_async_add(self) -> bool:
        """Return True if the vector store overrides either async add method.

        The base ``VectorStore.aadd_documents`` forwards to ``aadd_texts``, so an
        override of either one yields a real async write path; otherwise the
        caller falls back to the synchronous ``add_documents``.
        """
        store_cls = type(self.vectorstore)
        return (
            store_cls.aadd_documents is not VectorStore.aadd_documents
            or store_cls.aadd_texts is not VectorStore.aadd_texts
        )