fix deprecated chain, update vectorstore to handle markdown

pull/570/head
Richard Anthony Hein 8 months ago
parent f002293b9e
commit 5c46393ee1

@@ -41,6 +41,8 @@ from swarms.prompts.conversational_RAG import (
     E_INST,
     E_SYS,
     QA_PROMPT_TEMPLATE,
+    QA_PROMPT_TEMPLATE_STR,
+    QA_CONDENSE_TEMPLATE_STR,
     SUMMARY_PROMPT_TEMPLATE,
 )
@@ -109,7 +111,7 @@ tiktoken.model.MODEL_TO_ENCODING.update(
 print("Logging in to huggingface.co...")
 login(token=hf_token)  # login to huggingface.co
-# langchain.debug = True
+langchain.debug = True
 langchain.verbose = True
 from contextlib import asynccontextmanager
@@ -179,36 +181,31 @@ async def create_chain(
     retriever = await vector_store.getRetriever(os.path.join(file.username, file.filename))

-    # chat_memory = ChatMessageHistory()
-    # for message in messages:
-    #     if message.role == Role.USER:
-    #         human_msg = HumanMessage(message.content)
-    #         chat_memory.add_user_message(human_msg)
-    #     elif message.role == Role.ASSISTANT:
-    #         ai_msg = AIMessage(message.content)
-    #         chat_memory.add_ai_message(ai_msg)
-    #     elif message.role == Role.SYSTEM:
-    #         system_msg = SystemMessage(message.content)
-    #         chat_memory.add_message(system_msg)
+    chat_memory = ChatMessageHistory()
+    for message in messages:
+        if message.role == Role.USER:
+            human_msg = HumanMessage(message.content)
+            chat_memory.add_user_message(human_msg)
+        elif message.role == Role.ASSISTANT:
+            ai_msg = AIMessage(message.content)
+            chat_memory.add_ai_message(ai_msg)
+        elif message.role == Role.SYSTEM:
+            system_msg = SystemMessage(message.content)
+            chat_memory.add_message(system_msg)

     ### Contextualize question ###
     contextualize_q_system_prompt = """Given a chat history and the latest user question \
 which might reference context in the chat history, formulate a standalone question \
 which can be understood without the chat history. Do NOT answer the question, \
 just reformulate it if needed and otherwise return it as is."""
-    contextualize_q_prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", contextualize_q_system_prompt),
-            MessagesPlaceholder("chat_history"),
-            ("human", "{input}"),
-        ]
-    )
+    contextualize_q_prompt = QA_PROMPT_TEMPLATE

     history_aware_retriever = create_history_aware_retriever(
         llm, retriever, contextualize_q_prompt
     )

     ### Answer question ###
     qa_system_prompt = """You are an assistant for question-answering tasks. \
 Use the following pieces of retrieved context to answer the question. \
@@ -216,14 +213,8 @@ async def create_chain(
 Use three sentences maximum and keep the answer concise.\
 {context}"""
-    qa_prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", qa_system_prompt),
-            MessagesPlaceholder("chat_history"),
-            ("human", "{input}"),
-        ]
-    )
-    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+    qa_prompt = QA_PROMPT_TEMPLATE
+    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt, document_prompt=DOCUMENT_PROMPT_TEMPLATE)

     from langchain_core.runnables import RunnablePassthrough
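Note on the chain rework above: the deprecated chain style is replaced by composing a history-aware retriever with a stuff-documents chain. A minimal sketch of how these pieces usually fit together in LangChain 0.1+ is below; the prompt and document-prompt arguments are placeholders standing in for the swarms QA_PROMPT_TEMPLATE / DOCUMENT_PROMPT_TEMPLATE constants, so this illustrates the pattern rather than the exact chain built in this file.

from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

def build_rag_chain(llm, retriever, contextualize_q_prompt, qa_prompt, document_prompt):
    # Rewrite the latest user question into a standalone question using the chat history.
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )
    # Stuff the retrieved documents (each formatted with document_prompt) into qa_prompt.
    question_answer_chain = create_stuff_documents_chain(
        llm, qa_prompt, document_prompt=document_prompt
    )
    # The resulting chain takes {"input": ..., "chat_history": [...]} and returns a dict
    # containing "answer" plus the retrieved "context".
    return create_retrieval_chain(history_aware_retriever, question_answer_chain)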

@@ -5,7 +5,7 @@ import glob
 from datetime import datetime
 from typing import Dict, Literal
 from chromadb.config import Settings
-from langchain.document_loaders import UnstructuredHTMLLoader
+from langchain.document_loaders.markdown import UnstructuredMarkdownLoader
 from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain.storage import LocalFileStore
 from langchain.text_splitter import RecursiveCharacterTextSplitter
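The import swap above moves ingestion from HTML to markdown via the unstructured package. Standalone usage of the markdown loader looks roughly like this sketch (the path is a placeholder, and it assumes the unstructured markdown extras are installed):

from langchain.document_loaders.markdown import UnstructuredMarkdownLoader

# mode="elements" yields one Document per markdown element (heading, paragraph,
# list item, ...) instead of a single Document for the whole file.
docs = UnstructuredMarkdownLoader("docs/example.md", mode="elements").load()
print(f"Loaded {len(docs)} markdown elements")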
@@ -15,7 +15,6 @@ from swarms.server.async_parent_document_retriever import AsyncParentDocumentRet
 store_type = "local"  # "redis" or "local"

 class VectorStorage:
     def __init__(self, directory):
         self.embeddings = HuggingFaceBgeEmbeddings(
@ -72,8 +71,7 @@ class VectorStorage:
] ]
print(f"{len(dirs)} subdirectories to load: {dirs}") print(f"{len(dirs)} subdirectories to load: {dirs}")
for subdir in dirs: self.retrievers[self.directory] = await self.initRetriever(self.directory)
self.retrievers[subdir] = await self.initRetriever(subdir)
end_time = datetime.now() end_time = datetime.now()
print("Vectorstore initialization complete.") print("Vectorstore initialization complete.")
@@ -93,13 +91,25 @@ class VectorStorage:
         collections = self.client.list_collections()
         print(f"Existing collections: {collections}")

-        # load all .html documents in the subdirectory and ignore any .processed files
         # Initialize an empty list to hold the documents
         documents = []
         # Define the maximum number of files to load at a time
         max_files = 1000
-        # Get a list of all files in the directory
-        all_files = glob.glob(f"{self.directory}/{subdir}/*.html", recursive=False)
+
+        # Load existing metadata
+        metadata_file = f"{self.directory}/metadata.json"
+        metadata = {"processDate": str(datetime.now()), "processed_files": []}
+        processed_files = set()  # Track processed files
+        if os.path.isfile(metadata_file):
+            with open(metadata_file, "r") as metadataFile:
+                metadata = dict[str, str](json.load(metadataFile))
+                processed_files = {entry["file"] for entry in metadata.get("processed_files", [])}
+
+        # Get a list of all files in the directory and exclude processed files
+        all_files = [
+            file for file in glob.glob(f"{self.directory}/**/*.md", recursive=True)
+            if file not in processed_files
+        ]

         print(f"Loading {len(all_files)} documents for title version {subdir}.")

         # Load files in chunks of max_files
@@ -107,38 +117,42 @@ class VectorStorage:
             chunksStartTime = datetime.now()
             chunk_files = all_files[i : i + max_files]
             for file in chunk_files:
-                loader = UnstructuredHTMLLoader(
+                loader = UnstructuredMarkdownLoader(
                     file,
-                    encoding="utf-8",
+                    mode="elements",
+                    strategy="fast"
                 )
-                print(f"Loaded {file} in {subdir} ...")
                 documents.extend(loader.load())
+                # Record the file as processed in metadata
+                metadata["processed_files"].append({
+                    "file": file,
+                    "processed_at": str(datetime.now())
+                })

-            print(f"Loaded {len(documents)} documents for title version {subdir}.")
+            # Save metadata to the metadata.json file
+            with open(metadata_file, "w") as metadataFile:
+                json.dump(metadata, metadataFile, indent=4)
+
+            print(f"Loaded {len(documents)} documents for directory '{subdir}'.")

             chunksEndTime = datetime.now()
             print(
-                f"{max_files} html file chunks processing time: {chunksEndTime - chunksStartTime}"
+                f"{max_files} markdown file chunks processing time: {chunksEndTime - chunksStartTime}"
             )

-            print(f"Creating new collection for {subdir}...")
-            # create a new collection
-            # if metadata file named metadata.json exists, use that as metadata
-            # otherwise, default to using a metadata with just the date processed.
-            metadata = {"processDate": str(datetime.now())}
-            metadata_file = f"{self.directory}/{subdir}/metadata.json"
-            if os.path.isfile(metadata_file):
-                # load the metadata.json into a dict
-                with open(metadata_file, "r") as metadataFile:
-                    metadata = dict[str, str](json.load(metadataFile))
+            print(f"Creating new collection for {self.directory}...")
+            # Create or get the collection
             collection = self.client.create_collection(
-                name=subdir,
-                get_or_create=True,  # create if it doesn't exist, otherwise get it
+                name=self.directory,
+                get_or_create=True,
                 metadata=metadata,
             )
-            # reload vectorstore based on collection to pass to parent doc retriever
+            # Reload vectorstore based on collection
             vectorstore = self.getVectorStore(collection_name=collection.name)
-            # create a new parent document retriever
+            # Create a new parent document retriever
             retriever = AsyncParentDocumentRetriever(
                 docstore=self.store,
                 vectorstore=vectorstore,
@ -146,7 +160,7 @@ class VectorStorage:
parent_splitter=self.parent_splitter, parent_splitter=self.parent_splitter,
) )
# add documents to the collection and docstore # Add documents to the collection and docstore
print(f"Adding {len(documents)} documents to collection...") print(f"Adding {len(documents)} documents to collection...")
add_docs_start_time = datetime.now() add_docs_start_time = datetime.now()
await retriever.aadd_documents( await retriever.aadd_documents(
@@ -157,20 +171,14 @@ class VectorStorage:
                 f"Adding {len(documents)} documents to collection took: {add_docs_end_time - add_docs_start_time}"
             )

-            # rename all files to .processed so they don't get loaded again
-            # but allows us to do a manual reload if needed, or future
-            # processing of the files
-            for file in chunk_files:
-                os.rename(file, f"{file}.processed")
-
             documents = []  # clear documents list for next chunk

         subdir_end_time = datetime.now()
-        print(f"Subdir {subdir } processing end time: {subdir_end_time}")
+        print(f"Subdir {subdir} processing end time: {subdir_end_time}")
         print(f"Time taken: {subdir_end_time - subdir_start_time}")

-        # reload vectorstore based on collection to pass to parent doc retriever (it may have changed or be None)
-        collection = self.client.get_collection(name=subdir)
+        # Reload vectorstore based on collection to pass to parent doc retriever
+        collection = self.client.get_collection(name=self.directory)
         vectorstore = self.getVectorStore(collection_name=collection.name)
         retriever = AsyncParentDocumentRetriever(
             docstore=self.store,

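A note on the new bookkeeping in the vectorstore diff: instead of renaming ingested files to *.processed, the loader now records them in a metadata.json at the root of the storage directory and skips anything already listed on the next run. Given the fields written above, that file plausibly round-trips data shaped like the sketch below; the paths and timestamps are made up for illustration, not real output.

# Hypothetical contents of <directory>/metadata.json after one indexing run;
# field names come from the diff above, the values are illustrative only.
metadata = {
    "processDate": "2024-05-01 12:00:00.000000",
    "processed_files": [
        {"file": "markdown/intro.md", "processed_at": "2024-05-01 12:00:01.000000"},
        {"file": "markdown/agents.md", "processed_at": "2024-05-01 12:00:02.000000"},
    ],
}

# On the next initRetriever() call, files already listed are skipped, so only new
# markdown gets indexed. candidate_files stands in for the glob result.
candidate_files = ["markdown/intro.md", "markdown/new_page.md"]
processed = {entry["file"] for entry in metadata.get("processed_files", [])}
new_files = [f for f in candidate_files if f not in processed]
print(new_files)  # ['markdown/new_page.md']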