""" AsyncParentDocumentRetriever is used by RAG
|
|
to split up documents into smaller *and* larger related chunks. """
|
|
import pickle
|
|
import uuid
|
|
from typing import Any, ClassVar, Collection, List, Optional, Tuple
|
|
|
|
from langchain.callbacks.manager import (
|
|
AsyncCallbackManagerForRetrieverRun,
|
|
CallbackManagerForRetrieverRun,
|
|
)
|
|
from langchain.retrievers import ParentDocumentRetriever
|
|
from langchain.schema.document import Document
|
|
from langchain.schema.storage import BaseStore
|
|
from langchain.storage import LocalFileStore
|
|
from langchain_community.storage import RedisStore
|
|
from langchain.vectorstores.base import VectorStore
|
|
|
|
|
|
class AsyncParentDocumentRetriever(ParentDocumentRetriever):
    """Retrieve small chunks, then retrieve their parent documents.

    When splitting documents for retrieval, there are often conflicting
    desires:

    1. You may want to have small documents, so that their embeddings can
       most accurately reflect their meaning. If documents are too long,
       the embeddings can lose meaning.
    2. You want to have documents long enough that the context of each
       chunk is retained.

    The ParentDocumentRetriever strikes that balance by splitting and
    storing small chunks of data. During retrieval, it first fetches the
    small chunks, then looks up the parent ids for those chunks and returns
    those larger documents.

    Note that "parent document" refers to the document that a small chunk
    originated from. This can be either the whole raw document OR a larger
    chunk.

    Examples:

        .. code-block:: python

            # Imports
            from langchain.vectorstores import Chroma
            from langchain.embeddings import OpenAIEmbeddings
            from langchain.text_splitter import RecursiveCharacterTextSplitter
            from langchain.storage import InMemoryStore

            # This text splitter is used to create the parent documents
            parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
            # This text splitter is used to create the child documents
            # It should create documents smaller than the parent
            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
            # The vectorstore to use to index the child chunks
            vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
            # The storage layer for the parent documents
            store = InMemoryStore()

            # Initialize the retriever
            retriever = AsyncParentDocumentRetriever(
                vectorstore=vectorstore,
                docstore=store,
                child_splitter=child_splitter,
                parent_splitter=parent_splitter,
            )
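
        This retriever also accepts the byte-oriented ``RedisStore`` and
        ``LocalFileStore`` docstores; in that case parent documents are
        pickled to and from bytes. A minimal sketch of that variant (the
        ``./parent_docstore`` path is illustrative):

        .. code-block:: python

            from langchain.storage import LocalFileStore

            # Parent documents will be stored pickled on local disk
            file_store = LocalFileStore("./parent_docstore")

            retriever = AsyncParentDocumentRetriever(
                vectorstore=vectorstore,
                docstore=file_store,
                child_splitter=child_splitter,
                parent_splitter=parent_splitter,
            )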
"""
|
|
|
|
docstore: LocalFileStore | RedisStore | BaseStore[str, Document]
|
|
search_type: str = "similarity"
|
|
"""Type of search to perform. Defaults to "similarity"."""
|
|
allowed_search_types: ClassVar[Collection[str]] = (
|
|
"similarity",
|
|
"similarity_score_threshold",
|
|
"mmr",
|
|
)
|
|
|
|
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
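
        Typically invoked through the public retriever API rather than
        directly, e.g. (the query string here is illustrative):

        .. code-block:: python

            docs = retriever.get_relevant_documents("query about my data")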
"""
|
|
if self.search_type == "similarity":
|
|
sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
|
elif self.search_type == "similarity_score_threshold":
|
|
docs_and_similarities = (
|
|
self.vectorstore.similarity_search_with_relevance_scores(
|
|
query, **self.search_kwargs
|
|
)
|
|
)
|
|
sub_docs = [doc for doc, _ in docs_and_similarities]
|
|
elif self.search_type == "mmr":
|
|
sub_docs = self.vectorstore.max_marginal_relevance_search(
|
|
query, **self.search_kwargs
|
|
)
|
|
else:
|
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
|
# We do this to maintain the order of the ids that are returned
|
|
ids: List[str] = []
|
|
for d in sub_docs:
|
|
if d.metadata[self.id_key] not in ids:
|
|
ids.append(d.metadata[self.id_key])
|
|
if isinstance(self.docstore, (RedisStore, LocalFileStore)):
|
|
serialized_docs = self.docstore.mget(ids)
|
|
docs: List[Document] = [
|
|
pickle.loads(doc) for doc in serialized_docs if doc is not None
|
|
]
|
|
else:
|
|
docs: List[Document] = [
|
|
doc for doc in self.docstore.mget(ids) if doc is not None
|
|
]
|
|
return docs
|
|
|
|
    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
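
        Typically invoked through the public async retriever API rather than
        directly, e.g. (the query string here is illustrative):

        .. code-block:: python

            docs = await retriever.aget_relevant_documents("query about my data")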
"""
|
|
if self.search_type == "similarity":
|
|
sub_docs = await self.vectorstore.asimilarity_search(
|
|
query, **self.search_kwargs
|
|
)
|
|
elif self.search_type == "similarity_score_threshold":
|
|
docs_and_similarities = (
|
|
await self.vectorstore.asimilarity_search_with_relevance_scores(
|
|
query, **self.search_kwargs
|
|
)
|
|
)
|
|
sub_docs = [doc for doc, _ in docs_and_similarities]
|
|
elif self.search_type == "mmr":
|
|
sub_docs = await self.vectorstore.amax_marginal_relevance_search(
|
|
query, **self.search_kwargs
|
|
)
|
|
else:
|
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
|
# We do this to maintain the order of the ids that are returned
|
|
ids: List[str] = []
|
|
for d in sub_docs:
|
|
if d.metadata[self.id_key] not in ids:
|
|
ids.append(d.metadata[self.id_key])
|
|
if isinstance(self.docstore, (RedisStore, LocalFileStore)):
|
|
# deserialize documents from bytes
|
|
serialized_docs = self.docstore.mget(ids)
|
|
docs: List[Document] = [
|
|
pickle.loads(doc) for doc in serialized_docs if doc is not None
|
|
]
|
|
else:
|
|
docs: List[Document] = [
|
|
doc for doc in self.docstore.mget(ids) if doc is not None
|
|
]
|
|
return docs
|
|
|
|
    def add_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
        **kwargs: Any
    ) -> None:
        """Add documents to the docstore and vectorstore.

        Args:
            documents: List of documents to add
            ids: Optional list of ids for documents. If provided, should be
                the same length as the list of documents. Can be provided if
                parent documents are already in the document store and you
                don't want to re-add them. If not provided, random UUIDs
                will be used as ids.
            add_to_docstore: Boolean of whether to add documents to docstore.
                This can be False if and only if `ids` are provided. You may
                want to set this to False if the documents are already in
                the docstore and you don't want to re-add them.
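
        Example (``docs`` and ``known_ids`` are illustrative):

            .. code-block:: python

                # Ids are generated and parent documents stored automatically
                retriever.add_documents(docs)

                # Reuse parent ids already present in the docstore,
                # re-indexing only the child chunks
                retriever.add_documents(docs, ids=known_ids, add_to_docstore=False)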
"""
|
|
if self.parent_splitter is not None:
|
|
documents = self.parent_splitter.split_documents(documents)
|
|
if ids is None:
|
|
doc_ids = [str(uuid.uuid4()) for _ in documents]
|
|
if not add_to_docstore:
|
|
raise ValueError(
|
|
"If ids are not passed in, `add_to_docstore` MUST be True"
|
|
)
|
|
else:
|
|
if len(documents) != len(ids):
|
|
raise ValueError(
|
|
"Got uneven list of documents and ids. "
|
|
"If `ids` is provided, should be same length as `documents`."
|
|
)
|
|
doc_ids = ids
|
|
|
|
docs: List[Document] = []
|
|
full_docs: List[Tuple[str, Document]] = []
|
|
for i, doc in enumerate(documents):
|
|
_id = doc_ids[i]
|
|
sub_docs = self.child_splitter.split_documents([doc])
|
|
for _doc in sub_docs:
|
|
_doc.metadata[self.id_key] = _id
|
|
docs.extend(sub_docs)
|
|
full_docs.append((_id, doc))
|
|
self.vectorstore.add_documents(docs)
|
|
if add_to_docstore:
|
|
if isinstance(self.docstore, (RedisStore, LocalFileStore)):
|
|
# serialize documents to bytes
|
|
serialized_docs = [(id, pickle.dumps(doc)) for id, doc in full_docs]
|
|
self.docstore.mset(serialized_docs)
|
|
else:
|
|
self.docstore.mset(full_docs)
|
|
|
|
    async def aadd_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
        **kwargs: Any
    ) -> None:
        """Add documents to the docstore and vectorstore.

        Args:
            documents: List of documents to add
            ids: Optional list of ids for documents. If provided, should be
                the same length as the list of documents. Can be provided if
                parent documents are already in the document store and you
                don't want to re-add them. If not provided, random UUIDs
                will be used as ids.
            add_to_docstore: Boolean of whether to add documents to docstore.
                This can be False if and only if `ids` are provided. You may
                want to set this to False if the documents are already in
                the docstore and you don't want to re-add them.
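
        Example (``docs`` is illustrative):

            .. code-block:: python

                # Within an async context
                await retriever.aadd_documents(docs)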
"""
|
|
if self.parent_splitter is not None:
|
|
documents = self.parent_splitter.split_documents(documents)
|
|
if ids is None:
|
|
doc_ids = [str(uuid.uuid4()) for _ in documents]
|
|
if not add_to_docstore:
|
|
raise ValueError(
|
|
"If ids are not passed in, `add_to_docstore` MUST be True"
|
|
)
|
|
else:
|
|
if len(documents) != len(ids):
|
|
raise ValueError(
|
|
"Got uneven list of documents and ids. "
|
|
"If `ids` is provided, should be same length as `documents`."
|
|
)
|
|
doc_ids = ids
|
|
|
|
docs: List[Document] = []
|
|
full_docs: List[Tuple[str, Document]] = []
|
|
|
|
if len(documents) < 1:
|
|
return
|
|
|
|
for i, doc in enumerate(documents):
|
|
_id = doc_ids[i]
|
|
sub_docs = self.child_splitter.split_documents([doc])
|
|
for _doc in sub_docs:
|
|
_doc.metadata[self.id_key] = _id
|
|
docs.extend(sub_docs)
|
|
full_docs.append((_id, doc))
|
|
|
|
# check if vectorstore supports async adds
|
|
if (
|
|
type(self.vectorstore).aadd_documents != VectorStore.aadd_documents
|
|
and type(self.vectorstore).aadd_texts != VectorStore.aadd_texts
|
|
):
|
|
await self.vectorstore.aadd_documents(docs)
|
|
else:
|
|
self.vectorstore.add_documents(docs)
|
|
|
|
if add_to_docstore:
|
|
if isinstance(self.docstore, (RedisStore, LocalFileStore)):
|
|
# serialize documents to bytes
|
|
serialized_docs = [(id, pickle.dumps(doc)) for id, doc in full_docs]
|
|
self.docstore.mset(serialized_docs)
|
|
else:
|
|
self.docstore.mset(full_docs)
|