parent 48322af3a2
commit c86e62400a

@@ -0,0 +1,274 @@
import pickle
import uuid
from typing import ClassVar, Collection, List, Optional, Tuple

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.retrievers import ParentDocumentRetriever
from langchain.schema.document import Document
from langchain.schema.storage import BaseStore
from langchain.storage import LocalFileStore, RedisStore
from langchain.vectorstores.base import VectorStore


class AsyncParentDocumentRetriever(ParentDocumentRetriever):
    """Retrieve small chunks then retrieve their parent documents.

    When splitting documents for retrieval, there are often conflicting desires:

    1. You may want to have small documents, so that their embeddings can most
        accurately reflect their meaning. If too long, then the embeddings can
        lose meaning.
    2. You want to have long enough documents that the context of each chunk is
        retained.

    The ParentDocumentRetriever strikes that balance by splitting and storing
    small chunks of data. During retrieval, it first fetches the small chunks
    but then looks up the parent ids for those chunks and returns those larger
    documents.

    Note that "parent document" refers to the document that a small chunk
    originated from. This can either be the whole raw document OR a larger
    chunk.

    Examples:

        .. code-block:: python

            # Imports
            from langchain.vectorstores import Chroma
            from langchain.embeddings import OpenAIEmbeddings
            from langchain.text_splitter import RecursiveCharacterTextSplitter
            from langchain.storage import InMemoryStore

            # This text splitter is used to create the parent documents
            parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
            # This text splitter is used to create the child documents
            # It should create documents smaller than the parent
            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
            # The vectorstore to use to index the child chunks
            vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
            # The storage layer for the parent documents
            store = InMemoryStore()

            # Initialize the retriever
            retriever = AsyncParentDocumentRetriever(
                vectorstore=vectorstore,
                docstore=store,
                child_splitter=child_splitter,
                parent_splitter=parent_splitter,
            )
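
        For byte-backed stores such as ``LocalFileStore`` or ``RedisStore``,
        parent documents are pickled to and from bytes automatically. A minimal
        async sketch reusing the objects above (the store path, the ``docs``
        list of ``Document`` objects, and the query string are illustrative
        only):

        .. code-block:: python

            from langchain.storage import LocalFileStore

            # parent documents are persisted on disk as pickled bytes
            file_store = LocalFileStore("./parent_docs")
            retriever = AsyncParentDocumentRetriever(
                vectorstore=vectorstore,
                docstore=file_store,
                child_splitter=child_splitter,
                parent_splitter=parent_splitter,
            )

            # inside an async function: index and query without blocking
            await retriever.aadd_documents(docs)
            relevant = await retriever.aget_relevant_documents("my query")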
    """

    docstore: LocalFileStore | RedisStore | BaseStore[str, Document]
    search_type: str = "similarity"
    """Type of search to perform. Defaults to "similarity"."""
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.
        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """
        if self.search_type == "similarity":
            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            sub_docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            sub_docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        # We do this to maintain the order of the ids that are returned
        ids: List[str] = []
        for d in sub_docs:
            if d.metadata[self.id_key] not in ids:
                ids.append(d.metadata[self.id_key])
        if isinstance(self.docstore, (RedisStore, LocalFileStore)):
            # deserialize documents from bytes
            serialized_docs = self.docstore.mget(ids)
            docs: List[Document] = [
                pickle.loads(doc) for doc in serialized_docs if doc is not None
            ]
        else:
            docs: List[Document] = [
                doc for doc in self.docstore.mget(ids) if doc is not None
            ]
        return docs

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.
        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """
        if self.search_type == "similarity":
            sub_docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs
            )
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            sub_docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        # We do this to maintain the order of the ids that are returned
        ids: List[str] = []
        for d in sub_docs:
            if d.metadata[self.id_key] not in ids:
                ids.append(d.metadata[self.id_key])
        if isinstance(self.docstore, (RedisStore, LocalFileStore)):
            # deserialize documents from bytes
            serialized_docs = self.docstore.mget(ids)
            docs: List[Document] = [
                pickle.loads(doc) for doc in serialized_docs if doc is not None
            ]
        else:
            docs: List[Document] = [
                doc for doc in self.docstore.mget(ids) if doc is not None
            ]
        return docs

    def add_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
    ) -> None:
        """Adds documents to the docstore and vectorstores.

        Args:
            documents: List of documents to add
            ids: Optional list of ids for documents. If provided should be the same
                length as the list of documents. Can be provided if parent documents
                are already in the document store and you don't want to re-add
                them to the docstore. If not provided, random UUIDs will be used as
                ids.
            add_to_docstore: Boolean of whether to add documents to docstore.
                This can be false if and only if `ids` are provided. You may want
                to set this to False if the documents are already in the docstore
                and you don't want to re-add them.
        """
        if self.parent_splitter is not None:
            documents = self.parent_splitter.split_documents(documents)
        if ids is None:
            doc_ids = [str(uuid.uuid4()) for _ in documents]
            if not add_to_docstore:
                raise ValueError(
                    "If ids are not passed in, `add_to_docstore` MUST be True"
                )
        else:
            if len(documents) != len(ids):
                raise ValueError(
                    "Got uneven list of documents and ids. "
                    "If `ids` is provided, should be same length as `documents`."
                )
            doc_ids = ids

        docs: List[Document] = []
        full_docs: List[Tuple[str, Document]] = []
        for i, doc in enumerate(documents):
            _id = doc_ids[i]
            sub_docs = self.child_splitter.split_documents([doc])
            for _doc in sub_docs:
                _doc.metadata[self.id_key] = _id
            docs.extend(sub_docs)
            full_docs.append((_id, doc))
        self.vectorstore.add_documents(docs)
        if add_to_docstore:
            if isinstance(self.docstore, (RedisStore, LocalFileStore)):
                # serialize documents to bytes
                serialized_docs = [(id, pickle.dumps(doc)) for id, doc in full_docs]
                self.docstore.mset(serialized_docs)
            else:
                self.docstore.mset(full_docs)

    async def aadd_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
    ) -> None:
        """Adds documents to the docstore and vectorstores.

        Uses the vectorstore's async add methods when the concrete class
        implements them; otherwise falls back to the synchronous add.

        Args:
            documents: List of documents to add
            ids: Optional list of ids for documents. If provided should be the same
                length as the list of documents. Can be provided if parent documents
                are already in the document store and you don't want to re-add
                them to the docstore. If not provided, random UUIDs will be used as
                ids.
            add_to_docstore: Boolean of whether to add documents to docstore.
                This can be false if and only if `ids` are provided. You may want
                to set this to False if the documents are already in the docstore
                and you don't want to re-add them.
        """
        if self.parent_splitter is not None:
            documents = self.parent_splitter.split_documents(documents)
        if ids is None:
            doc_ids = [str(uuid.uuid4()) for _ in documents]
            if not add_to_docstore:
                raise ValueError(
                    "If ids are not passed in, `add_to_docstore` MUST be True"
                )
        else:
            if len(documents) != len(ids):
                raise ValueError(
                    "Got uneven list of documents and ids. "
                    "If `ids` is provided, should be same length as `documents`."
                )
            doc_ids = ids

        docs: List[Document] = []
        full_docs: List[Tuple[str, Document]] = []
        for i, doc in enumerate(documents):
            _id = doc_ids[i]
            sub_docs = self.child_splitter.split_documents([doc])
            for _doc in sub_docs:
                _doc.metadata[self.id_key] = _id
            docs.extend(sub_docs)
            full_docs.append((_id, doc))

        # check if vectorstore supports async adds: only take the async path
        # when the concrete class overrides both base methods, otherwise fall
        # back to the synchronous add
        if (
            type(self.vectorstore).aadd_documents != VectorStore.aadd_documents
            and type(self.vectorstore).aadd_texts != VectorStore.aadd_texts
        ):
            await self.vectorstore.aadd_documents(docs)
        else:
            self.vectorstore.add_documents(docs)

        if add_to_docstore:
            if isinstance(self.docstore, (RedisStore, LocalFileStore)):
                # serialize documents to bytes
                serialized_docs = [(id, pickle.dumps(doc)) for id, doc in full_docs]
                self.docstore.mset(serialized_docs)
            else:
                self.docstore.mset(full_docs)
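

# Illustrative end-to-end usage (a sketch, not exercised by this module; the
# store path, the query, and the ``texts`` name are placeholders):
#
#     import asyncio
#
#     from langchain.embeddings import OpenAIEmbeddings
#     from langchain.schema.document import Document
#     from langchain.storage import LocalFileStore
#     from langchain.text_splitter import RecursiveCharacterTextSplitter
#     from langchain.vectorstores import Chroma
#
#     async def main() -> None:
#         retriever = AsyncParentDocumentRetriever(
#             vectorstore=Chroma(embedding_function=OpenAIEmbeddings()),
#             docstore=LocalFileStore("./parent_docs"),  # parents stored as pickled bytes
#             child_splitter=RecursiveCharacterTextSplitter(chunk_size=400),
#             parent_splitter=RecursiveCharacterTextSplitter(chunk_size=2000),
#         )
#         # index parent documents; child chunks go to the vectorstore
#         await retriever.aadd_documents([Document(page_content=t) for t in texts])
#         # retrieve the parents whose child chunks match the query
#         results = await retriever.aget_relevant_documents("my query")
#
#     asyncio.run(main())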