parent 48322af3a2
commit c86e62400a
@@ -0,0 +1,274 @@
import pickle
import uuid
from typing import ClassVar, Collection, List, Optional, Tuple, Union

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.retrievers import ParentDocumentRetriever
from langchain.schema.document import Document
from langchain.schema.storage import BaseStore
from langchain.storage import LocalFileStore, RedisStore
from langchain.vectorstores.base import VectorStore


class AsyncParentDocumentRetriever(ParentDocumentRetriever):
    """Retrieve small chunks then retrieve their parent documents.

    When splitting documents for retrieval, there are often conflicting desires:

    1. You may want to have small documents, so that their embeddings can most
       accurately reflect their meaning. If too long, then the embeddings can
       lose meaning.
    2. You want to have long enough documents that the context of each chunk is
       retained.

    The ParentDocumentRetriever strikes that balance by splitting and storing
    small chunks of data. During retrieval, it first fetches the small chunks
    but then looks up the parent ids for those chunks and returns those larger
    documents.

    Note that "parent document" refers to the document that a small chunk
    originated from. This can either be the whole raw document OR a larger
    chunk.

    Examples:

        .. code-block:: python

            # Imports
            from langchain.vectorstores import Chroma
            from langchain.embeddings import OpenAIEmbeddings
            from langchain.text_splitter import RecursiveCharacterTextSplitter
            from langchain.storage import InMemoryStore

            # This text splitter is used to create the parent documents
            parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
            # This text splitter is used to create the child documents
            # It should create documents smaller than the parent
            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
            # The vectorstore to use to index the child chunks
            vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
            # The storage layer for the parent documents
            store = InMemoryStore()

            # Initialize the retriever
            retriever = AsyncParentDocumentRetriever(
                vectorstore=vectorstore,
                docstore=store,
                child_splitter=child_splitter,
                parent_splitter=parent_splitter,
            )
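
        Because parent documents are pickled to bytes when the docstore is a
        ``RedisStore`` or ``LocalFileStore``, parents can also be persisted
        across processes. A sketch of indexing and retrieval follows; ``docs``
        is assumed to be a list of ``Document`` objects and the Redis URL is
        illustrative:

        .. code-block:: python

            # Optionally back the docstore with Redis for persistence
            # (assumes a Redis server reachable at this URL)
            from langchain.storage import RedisStore
            store = RedisStore(redis_url="redis://localhost:6379")

            # Index parent documents and their child chunks
            retriever.add_documents(docs)

            # Synchronous retrieval
            relevant = retriever.get_relevant_documents("my query")

            # Async retrieval, e.g. from within a coroutine
            relevant = await retriever.aget_relevant_documents("my query")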
    """

    docstore: Union[LocalFileStore, RedisStore, BaseStore[str, Document]]
    search_type: str = "similarity"
    """Type of search to perform. Defaults to "similarity"."""
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            run_manager: The callbacks handler to use.

        Returns:
            List of relevant documents.
        """
        if self.search_type == "similarity":
            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            sub_docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            sub_docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        # We do this to maintain the order of the ids that are returned
        ids: List[str] = []
        for d in sub_docs:
            if d.metadata[self.id_key] not in ids:
                ids.append(d.metadata[self.id_key])
        docs: List[Document]
        if isinstance(self.docstore, (RedisStore, LocalFileStore)):
            # deserialize documents from bytes
            serialized_docs = self.docstore.mget(ids)
            docs = [pickle.loads(doc) for doc in serialized_docs if doc is not None]
        else:
            docs = [doc for doc in self.docstore.mget(ids) if doc is not None]
        return docs

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            run_manager: The callbacks handler to use.

        Returns:
            List of relevant documents.
        """
        if self.search_type == "similarity":
            sub_docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs
            )
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            sub_docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        # We do this to maintain the order of the ids that are returned
        ids: List[str] = []
        for d in sub_docs:
            if d.metadata[self.id_key] not in ids:
                ids.append(d.metadata[self.id_key])
        docs: List[Document]
        if isinstance(self.docstore, (RedisStore, LocalFileStore)):
            # deserialize documents from bytes
            serialized_docs = self.docstore.mget(ids)
            docs = [pickle.loads(doc) for doc in serialized_docs if doc is not None]
        else:
            docs = [doc for doc in self.docstore.mget(ids) if doc is not None]
        return docs

    def add_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
    ) -> None:
        """Adds documents to the docstore and vectorstores.

        Args:
            documents: List of documents to add.
            ids: Optional list of ids for documents. If provided, should be the
                same length as the list of documents. Can be provided if parent
                documents are already in the document store and you don't want
                to re-add to the docstore. If not provided, random UUIDs will
                be used as ids.
            add_to_docstore: Boolean of whether to add documents to docstore.
                This can be False if and only if `ids` are provided. You may
                want to set this to False if the documents are already in the
                docstore and you don't want to re-add them.
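
        For example, re-indexing child chunks while reusing parents already in
        the docstore might look like this sketch (``documents`` and
        ``existing_ids`` are assumed to already exist and line up):

        .. code-block:: python

            retriever.add_documents(
                documents,
                ids=existing_ids,
                add_to_docstore=False,
            )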
"""
|
||||
if self.parent_splitter is not None:
|
||||
documents = self.parent_splitter.split_documents(documents)
|
||||
if ids is None:
|
||||
doc_ids = [str(uuid.uuid4()) for _ in documents]
|
||||
if not add_to_docstore:
|
||||
raise ValueError(
|
||||
"If ids are not passed in, `add_to_docstore` MUST be True"
|
||||
)
|
||||
else:
|
||||
if len(documents) != len(ids):
|
||||
raise ValueError(
|
||||
"Got uneven list of documents and ids. "
|
||||
"If `ids` is provided, should be same length as `documents`."
|
||||
)
|
||||
doc_ids = ids
|
||||
|
||||
docs: List[Document] = []
|
||||
full_docs: List[Tuple[str, Document]] = []
|
||||
for i, doc in enumerate(documents):
|
||||
_id = doc_ids[i]
|
||||
sub_docs = self.child_splitter.split_documents([doc])
|
||||
for _doc in sub_docs:
|
||||
_doc.metadata[self.id_key] = _id
|
||||
docs.extend(sub_docs)
|
||||
full_docs.append((_id, doc))
|
||||
self.vectorstore.add_documents(docs)
|
||||
if add_to_docstore:
|
||||
if isinstance(self.docstore, (RedisStore, LocalFileStore)):
|
||||
# serialize documents to bytes
|
||||
serialized_docs = [(id, pickle.dumps(doc)) for id, doc in full_docs]
|
||||
self.docstore.mset(serialized_docs)
|
||||
else:
|
||||
self.docstore.mset(full_docs)
|
||||
|
||||
    async def aadd_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
    ) -> None:
        """Adds documents to the docstore and vectorstores.

        Args:
            documents: List of documents to add.
            ids: Optional list of ids for documents. If provided, should be the
                same length as the list of documents. Can be provided if parent
                documents are already in the document store and you don't want
                to re-add to the docstore. If not provided, random UUIDs will
                be used as ids.
            add_to_docstore: Boolean of whether to add documents to docstore.
                This can be False if and only if `ids` are provided. You may
                want to set this to False if the documents are already in the
                docstore and you don't want to re-add them.
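
        Unlike :meth:`add_documents`, this coroutine prefers the vectorstore's
        async API when one is implemented, e.g. (a sketch, awaited from inside
        a coroutine):

        .. code-block:: python

            await retriever.aadd_documents(documents)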
"""
|
||||
if self.parent_splitter is not None:
|
||||
documents = self.parent_splitter.split_documents(documents)
|
||||
if ids is None:
|
||||
doc_ids = [str(uuid.uuid4()) for _ in documents]
|
||||
if not add_to_docstore:
|
||||
raise ValueError(
|
||||
"If ids are not passed in, `add_to_docstore` MUST be True"
|
||||
)
|
||||
else:
|
||||
if len(documents) != len(ids):
|
||||
raise ValueError(
|
||||
"Got uneven list of documents and ids. "
|
||||
"If `ids` is provided, should be same length as `documents`."
|
||||
)
|
||||
doc_ids = ids
|
||||
|
||||
docs: List[Document] = []
|
||||
full_docs: List[Tuple[str, Document]] = []
|
||||
for i, doc in enumerate(documents):
|
||||
_id = doc_ids[i]
|
||||
sub_docs = self.child_splitter.split_documents([doc])
|
||||
for _doc in sub_docs:
|
||||
_doc.metadata[self.id_key] = _id
|
||||
docs.extend(sub_docs)
|
||||
full_docs.append((_id, doc))
|
||||
|
||||
# check if vectorstore supports async adds
|
||||
if (
|
||||
type(self.vectorstore).aadd_documents != VectorStore.aadd_documents
|
||||
and type(self.vectorstore).aadd_texts != VectorStore.aadd_texts
|
||||
):
|
||||
await self.vectorstore.aadd_documents(docs)
|
||||
else:
|
||||
self.vectorstore.add_documents(docs)
|
||||
|
||||
if add_to_docstore:
|
||||
if isinstance(self.docstore, (RedisStore, LocalFileStore)):
|
||||
# serialize documents to bytes
|
||||
serialized_docs = [(id, pickle.dumps(doc)) for id, doc in full_docs]
|
||||
self.docstore.mset(serialized_docs)
|
||||
else:
|
||||
self.docstore.mset(full_docs)
|