diff --git a/swarms/server/server.py b/swarms/server/server.py
index 95e1185b..0080f04f 100644
--- a/swarms/server/server.py
+++ b/swarms/server/server.py
@@ -1,38 +1,36 @@
+""" Chatbot with RAG Server """
+
 import asyncio
-import json
 import logging
 import os
-from datetime import datetime
-from typing import List
+
+# import torch
+from contextlib import asynccontextmanager
 
 import langchain
-from pydantic import ValidationError, parse_obj_as
-from swarms.prompts.chat_prompt import Message
-from swarms.server.callback_handlers import SourceDocumentsStreamingCallbackHandler, TokenStreamingCallbackHandler
 import tiktoken
-
-# import torch
 from dotenv import load_dotenv
-from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import FileResponse, JSONResponse
 from fastapi.routing import APIRouter
 from fastapi.staticfiles import StaticFiles
 from huggingface_hub import login
-from langchain.callbacks import StreamingStdOutCallbackHandler
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+from langchain.chains.conversational_retrieval.base import (
+    ConversationalRetrievalChain,
+)
+from langchain.chains.llm import LLMChain
 from langchain.memory import ConversationBufferMemory
-from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
-from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
-from langchain.chains.history_aware_retriever import create_history_aware_retriever
-from langchain.chains.retrieval import create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.memory.chat_message_histories.in_memory import (
+    ChatMessageHistory,
+)
 from langchain.prompts.prompt import PromptTemplate
 from langchain_community.chat_models import ChatOpenAI
-from swarms.server.responses import LangchainStreamingResponse
-from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
-from langchain.chains.llm import LLMChain
-from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+
+# from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+# from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+
+from swarms.prompts.chat_prompt import Message
 from swarms.prompts.conversational_RAG import (
     B_INST,
     B_SYS,
@@ -41,40 +39,35 @@ from swarms.prompts.conversational_RAG import (
     E_INST,
     E_SYS,
     QA_PROMPT_TEMPLATE,
-    QA_PROMPT_TEMPLATE_STR,
-    QA_CONDENSE_TEMPLATE_STR,
 )
-
+from swarms.server.responses import LangchainStreamingResponse
+from swarms.server.server_models import ChatRequest, Role
 from swarms.server.vector_store import VectorStorage
-from swarms.server.server_models import (
-    ChatRequest,
-    LogMessage,
-    AIModel,
-    AIModels,
-    RAGFile,
-    RAGFiles,
-    Role,
-    State,
-    GetRAGFileStateRequest,
-    ProcessRAGFileRequest
-)
-
 # Explicitly specify the path to the .env file
 # Two folders above the current file's directory
-dotenv_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env')
+dotenv_path = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env"
+)
 
 load_dotenv(dotenv_path)
 
-hf_token = os.environ.get("HUGGINFACEHUB_API_KEY") # Get the Huggingface API Token
-uploads = os.environ.get("UPLOADS") # Directory where user uploads files to be parsed for RAG
+hf_token = os.environ.get(
+    "HUGGINFACEHUB_API_KEY"
+)  # Get the Huggingface API Token
+uploads = os.environ.get(
+    "UPLOADS"
+)  # Directory where user uploads files to be parsed for RAG
 model_dir = os.environ.get("MODEL_DIR") # hugginface.co model (eg. meta-llama/Llama-2-70b-hf)
 model_name = os.environ.get("MODEL_NAME")
 
-# Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server, or set them to OpenAI API key and base URL.
+# Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server
+# or set them to OpenAI API key and base URL.
 openai_api_key = os.environ.get("OPENAI_API_KEY") or "EMPTY"
-openai_api_base = os.environ.get("OPENAI_API_BASE") or "http://localhost:8000/v1"
+openai_api_base = (
+    os.environ.get("OPENAI_API_BASE") or "http://localhost:8000/v1"
+)
 
 env_vars = [
     hf_token,
@@ -93,13 +86,13 @@ if missing_vars:
     exit(1)
 
 useMetal = os.environ.get("USE_METAL", "False") == "True"
-useGPU = os.environ.get("USE_GPU", "False") == "True"
+use_gpu = os.environ.get("USE_GPU", "False") == "True"
 
 print(f"Uploads={uploads}")
 print(f"MODEL_DIR={model_dir}")
 print(f"MODEL_NAME={model_name}")
 print(f"USE_METAL={useMetal}")
-print(f"USE_GPU={useGPU}")
+print(f"USE_GPU={use_gpu}")
 print(f"OPENAI_API_KEY={openai_api_key}")
 print(f"OPENAI_API_BASE={openai_api_base}")
 
@@ -116,23 +109,25 @@ login(token=hf_token) # login to huggingface.co
 
 langchain.debug = True
 langchain.verbose = True
 
-from contextlib import asynccontextmanager
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    asyncio.create_task(vector_store.initRetrievers())
+    """Initializes the vector store in a background task."""
+    print(f"Initializing vector store retrievers for {app.title}.")
+    asyncio.create_task(vector_store.init_retrievers())
     yield
 
-app = FastAPI(title="Chatbot", lifespan=lifespan)
+
+chatbot = FastAPI(title="Chatbot", lifespan=lifespan)
 router = APIRouter()
 
 current_dir = os.path.dirname(__file__)
 print("current_dir: " + current_dir)
 static_dir = os.path.join(current_dir, "static")
 print("static_dir: " + static_dir)
-app.mount(static_dir, StaticFiles(directory=static_dir), name="static")
+chatbot.mount(static_dir, StaticFiles(directory=static_dir), name="static")
 
-app.add_middleware(
+chatbot.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
     allow_credentials=True,
@@ -147,21 +142,15 @@ if not os.path.exists(uploads):
     os.makedirs(uploads)
 
 # Initialize the vector store
-vector_store = VectorStorage(directoryOrUrl=uploads, useGPU=useGPU)
+vector_store = VectorStorage(directory=uploads, use_gpu=use_gpu)
 
 
 async def create_chain(
     messages: list[Message],
-    model=model_dir,
-    max_tokens_to_gen=2048,
-    temperature=0.5,
     prompt: PromptTemplate = QA_PROMPT_TEMPLATE,
-    file: RAGFile | None = None,
-    key: str | None = None,
 ):
-    print(
-        f"Creating chain with key={key}, model={model}, max_tokens={max_tokens_to_gen}, temperature={temperature}, prompt={prompt}, file={file.title}"
-    )
+    """Creates the RAG Langchain conversational retrieval chain."""
+    print("Creating chain ...")
 
     llm = ChatOpenAI(
         api_key=openai_api_key,
@@ -181,7 +170,7 @@ async def create_chain(
     # if llm is VLLMAsync:
     #     llm.max_tokens = max_tokens_to_gen
 
-    retriever = await vector_store.getRetriever()
+    retriever = await vector_store.get_retriever()
 
     chat_memory = ChatMessageHistory()
     for message in messages:
@@ -236,26 +225,26 @@ async def create_chain(
 
 router = APIRouter()
 
+
 @router.post(
     "/chat",
     summary="Chatbot",
    description="Chatbot AI Service",
 )
 async def chat(request: ChatRequest):
+    """ Handles chatbot chat POST requests """
     chain = await create_chain(
-        file=request.file,
         messages=request.messages[:-1],
-        model=request.model.id,
-        max_tokens_to_gen=request.maxTokens,
-        temperature=request.temperature,
         prompt=PromptTemplate.from_template(
             f"{B_INST}{B_SYS}{request.prompt.strip()}{E_SYS}{E_INST}"
         ),
     )
 
-    json = {
+    json_config = {
         "question": request.messages[-1].content,
-        "chat_history": [message.content for message in request.messages[:-1]],
+        "chat_history": [
+            message.content for message in request.messages[:-1]
+        ],
         # "callbacks": [
         #     StreamingStdOutCallbackHandler(),
         #     TokenStreamingCallbackHandler(output_key="answer"),
@@ -264,178 +253,41 @@ async def chat(request: ChatRequest):
     }
     return LangchainStreamingResponse(
         chain,
-        config=json,
+        config=json_config,
     )
 
 
-app.include_router(router, tags=["chat"])
+chatbot.include_router(router, tags=["chat"])
 
 
-@app.get("/")
+@chatbot.get("/")
 def root():
-    return {"message": "Chatbot API"}
+    """Swarms Chatbot API Root"""
+    return {"message": "Swarms Chatbot API"}
 
 
-@app.get("/favicon.ico")
+@chatbot.get("/favicon.ico")
 def favicon():
+    """ Returns a favicon """
     file_name = "favicon.ico"
-    file_path = os.path.join(app.root_path, "static", file_name)
+    file_path = os.path.join(chatbot.root_path, "static", file_name)
     return FileResponse(
         path=file_path,
-        headers={"Content-Disposition": "attachment; filename=" + file_name},
+        headers={
+            "Content-Disposition": "attachment; filename=" + file_name
+        },
     )
 
 
-@app.post("/log")
-def log_message(log_message: LogMessage):
-    try:
-        with open("log.txt", "a") as log_file:
-            log_file.write(log_message.message + "\n")
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error saving log: {e}")
-    return {"message": "Log saved successfully"}
-
-
-@app.get("/models")
-def get_models():
-    # llama7B = AIModel(
-    #     id="llama-2-7b-chat-ggml-q4_0",
-    #     name="llama-2-7b-chat-ggml-q4_0",
-    #     maxLength=2048,
-    #     tokenLimit=2048,
-    # )
-    # llama13B = AIModel(
-    #     id="llama-2-13b-chat-ggml-q4_0",
-    #     name="llama-2-13b-chat-ggml-q4_0",
-    #     maxLength=2048,
-    #     tokenLimit=2048,
-    # )
-    llama70B = AIModel(
-        id="llama-2-70b.Q5_K_M",
-        name="llama-2-70b.Q5_K_M",
-        maxLength=2048,
-        tokenLimit=2048,
-    )
-    models = AIModels(models=[llama70B])
-    return models
-
-
-@app.get("/titles")
-def getTitles():
-    titles = RAGFiles(
-        titles=[
-            # RAGFile(
-            #     versionId="d8ad3b1d-c33c-4524-9691-e93967d4d863",
-            #     title="d8ad3b1d-c33c-4524-9691-e93967d4d863",
-            #     state=State.Unavailable,
-            # ),
-            RAGFile(
-                versionId=collection.name,
-                title=collection.name,
-                state=State.InProcess
-                if collection.name in processing_books
-                else State.Processed,
-            )
-            for collection in vector_store.list_collections()
-            if collection.name != "langchain"
-        ]
-    )
-    return titles
-
-
-processing_books: list[str] = []
-processing_books_lock = asyncio.Lock()
-
 logging.basicConfig(level=logging.ERROR)
 
 
-@app.post("/titleState")
-async def getTitleState(request: GetRAGFileStateRequest):
-    # FastAPI + Pydantic will throw a 422 Unprocessable Entity if the request isn't the right type.
-    # try:
-    logging.debug(f"Received getTitleState request: {request}")
-    titleStateRequest: GetRAGFileStateRequest = request
-    # except ValidationError as e:
-    #     print(f"Error validating JSON: {e}")
-    #     raise HTTPException(status_code=422, detail=str(e))
-    # except json.JSONDecodeError as e:
-    #     print(f"Error parsing JSON: {e}")
-    #     raise HTTPException(status_code=422, detail="Invalid JSON format")
-    # check to see if the book has already been processed.
-    # return the proper State directly to response.
-    matchingCollection = next(
-        (
-            x
-            for x in vector_store.list_collections()
-            if x.name == titleStateRequest.versionRef
-        ),
-        None,
+@chatbot.exception_handler(HTTPException)
+async def http_exception_handler(r: Request, exc: HTTPException):
+    """Log and return exception details in response."""
+    logging.error(
+        "HTTPException: %s executing request: %s", exc.detail, r.base_url
+    )
+    return JSONResponse(
+        status_code=exc.status_code, content={"detail": exc.detail}
     )
-    print("Got a Title State request for version " + titleStateRequest.versionRef)
-    if titleStateRequest.versionRef in processing_books:
-        return {"message": State.InProcess}
-    elif matchingCollection is not None:
-        return {"message": State.Processed}
-    else:
-        return {"message": State.Unavailable}
-
-
-@app.post("/processRAGFile")
-async def processRAGFile(
-    request: str = Form(...),
-    files: List[UploadFile] = File(...),
-):
-    try:
-        logging.debug(f"Received processBook request: {request}")
-        # Parse the JSON string into a ProcessBookRequest object
-        fileRAGRequest: ProcessRAGFileRequest = parse_obj_as(
-            ProcessRAGFileRequest, json.loads(request)
-        )
-    except ValidationError as e:
-        print(f"Error validating JSON: {e}")
-        raise HTTPException(status_code=422, detail=str(e))
-    except json.JSONDecodeError as e:
-        print(f"Error parsing JSON: {e}")
-        raise HTTPException(status_code=422, detail="Invalid JSON format")
-
-    try:
-        print(
-            f"Processing file {fileRAGRequest.filename} for user {fileRAGRequest.username}."
-        )
-        # check to see if the file has already been processed.
-        # write html to subfolder
-        print(f"Writing file to path: {fileRAGRequest.username}/{fileRAGRequest.filename}...")
-
-        for index, segment in enumerate(files):
-            filename = segment.filename if segment.filename else str(index)
-            subDir = f"{fileRAGRequest.username}"
-            with open(os.path.join(subDir, filename), "wb") as htmlFile:
-                htmlFile.write(await segment.read())
-
-        # write metadata to subfolder
-        print(f"Writing metadata to subfolder {fileRAGRequest.username}...")
-        with open(os.path.join({fileRAGRequest.username}, "metadata.json"), "w") as metadataFile:
-            metaData = {
-                "filename": fileRAGRequest.filename,
-                "username": fileRAGRequest.username,
-                "processDate": datetime.now().isoformat(),
-            }
-            metadataFile.write(json.dumps(metaData))
-
-        vector_store.retrievers[
-            f"{fileRAGRequest.username}/{fileRAGRequest.filename}"
-        ] = await vector_store.initRetriever(f"{fileRAGRequest.username}/{fileRAGRequest.filename}")
-
-        return {
-            "message": f"File {fileRAGRequest.filename} processed successfully."
-        }
-    except Exception as e:
-        logging.error(f"Error processing book: {e}")
-        return {"message": f"Error processing book: {e}"}
-
-
-@app.exception_handler(HTTPException)
-async def http_exception_handler(bookRequest: Request, exc: HTTPException):
-    logging.error(f"HTTPException: {exc.detail}")
-    return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
-
diff --git a/swarms/server/vector_store.py b/swarms/server/vector_store.py
index 16f853f5..a4783584 100644
--- a/swarms/server/vector_store.py
+++ b/swarms/server/vector_store.py
@@ -1,9 +1,11 @@
+""" Vector storage with RAG (Retrieval Augmented Generation) support for Markdown."""
+
 import asyncio
 import json
 import os
 import glob
 from datetime import datetime
-from typing import Dict, Literal
+from typing import Dict
 from chromadb.config import Settings
 from langchain.document_loaders.markdown import UnstructuredMarkdownLoader
 from langchain.embeddings import HuggingFaceBgeEmbeddings
@@ -11,34 +13,39 @@ from langchain.storage import LocalFileStore
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores.chroma import Chroma
 from langchain.schema import BaseRetriever
-from swarms.server.async_parent_document_retriever import AsyncParentDocumentRetriever
+from swarms.server.async_parent_document_retriever import (
+    AsyncParentDocumentRetriever,
+)
+
+STORE_TYPE = "local"  # "redis" or "local"
 
-store_type = "local"  # "redis" or "local"
 
 class VectorStorage:
-    def __init__(self, directoryOrUrl, useGPU=False):
+    """Vector storage class handles loading documents from a given directory."""
+
+    def __init__(self, directory, use_gpu=False):
         self.embeddings = HuggingFaceBgeEmbeddings(
             cache_folder="./.embeddings",
             model_name="BAAI/bge-large-en",
-            model_kwargs={"device": "cuda" if useGPU else "cpu"},
+            model_kwargs={"device": "cuda" if use_gpu else "cpu"},
             encode_kwargs={"normalize_embeddings": True},
             query_instruction="Represent this sentence for searching relevant passages: ",
         )
-        self.directoryOrUrl = directoryOrUrl
+        self.directory = directory
         self.child_splitter = RecursiveCharacterTextSplitter(
             chunk_size=200, chunk_overlap=20
         )
         self.parent_splitter = RecursiveCharacterTextSplitter(
             chunk_size=2000, chunk_overlap=200
         )
-        if store_type == "redis":
+        if STORE_TYPE == "redis":
             from langchain.storage import RedisStore
             from langchain.utilities.redis import get_client
 
             username = r"username"
             password = r"password"
             client = get_client(
-                redis_url=f"redis://{username}:{password}@redis-10854.c282.east-us-mz.azure.cloud.redislabs.com:10854"
+                redis_url=f"redis://{username}:{password}@localhost:6239"
             )
             self.store = RedisStore(client=client)
         else:
@@ -49,7 +56,7 @@ class VectorStorage:
             anonymized_telemetry=False,
         )
         # create a new vectorstore or get an existing one, with default collection
-        self.vectorstore = self.getVectorStore()
+        self.vectorstore = self.get_vector_store()
         self.client = self.vectorstore._client
         self.retrievers: Dict[str, BaseRetriever] = {}
         # default retriever for when no collection title is specified
@@ -57,22 +64,25 @@ class VectorStorage:
         self.retrievers[
             str(self.vectorstore._LANGCHAIN_DEFAULT_COLLECTION_NAME)
         ] = self.vectorstore.as_retriever()
 
-    async def initRetrievers(self, directories: list[str] | None = None):
+    async def init_retrievers(self, directories: list[str] | None = None):
+        """Initializes the vector storage retrievers."""
         start_time = datetime.now()
         print(f"Start vectorstore initialization time: {start_time}")
         # for each subdirectory in the directory, create a new collection if it doesn't exist
-        dirs = directories or os.listdir(self.directoryOrUrl)
+        dirs = directories or os.listdir(self.directory)
         # make sure the subdir is not a file on MacOS (which has a hidden .DS_Store file)
         dirs = [
             subdir
             for subdir in dirs
-            if not os.path.isfile(f"{self.directoryOrUrl}/{subdir}")
+            if not os.path.isfile(f"{self.directory}/{subdir}")
         ]
         print(f"{len(dirs)} subdirectories to load: {dirs}")
 
-        self.retrievers[self.directoryOrUrl] = await self.initRetriever(self.directoryOrUrl)
-
+        self.retrievers[self.directory] = await self.init_retriever(
+            self.directory
+        )
+
         end_time = datetime.now()
         print("Vectorstore initialization complete.")
         print(f"Vectorstore initialization end time: {end_time}")
@@ -80,110 +90,140 @@ class VectorStorage:
         return self.retrievers
 
-    async def initRetriever(self, subdir: str) -> BaseRetriever:
+    async def init_retriever(self, subdir: str) -> BaseRetriever:
+        """ Initialize each retriever. """
         # Ensure only one process/thread is executing this method at a time
         lock = asyncio.Lock()
         async with lock:
-            # subdir_start_time = datetime.now()
-            # print(f"Start {subdir} processing time: {subdir_start_time}")
-
-            # # get all existing collections
-            # collections = self.client.list_collections()
-            # print(f"Existing collections: {collections}")
-
-            # # Initialize an empty list to hold the documents
-            # documents = []
-            # # Define the maximum number of files to load at a time
-            # max_files = 1000
-
-            # # Load existing metadata
-            # metadata_file = f"{self.directoryOrUrl}/metadata.json"
-            # metadata = {"processDate": str(datetime.now()), "processed_files": []}
-            # processed_files = set()  # Track processed files
-            # if os.path.isfile(metadata_file):
-            #     with open(metadata_file, "r") as metadataFile:
-            #         metadata = dict[str, str](json.load(metadataFile))
-            #         processed_files = {entry["file"] for entry in metadata.get("processed_files", [])}
-
-            # # Get a list of all files in the directory and exclude processed files
-            # all_files = [
-            #     file for file in glob.glob(f"{self.directoryOrUrl}/**/*.md", recursive=True)
-            #     if file not in processed_files
-            # ]
-
-            # print(f"Loading {len(all_files)} documents for title version {subdir}.")
-            # # Load files in chunks of max_files
-            # for i in range(0, len(all_files), max_files):
-            #     chunksStartTime = datetime.now()
-            #     chunk_files = all_files[i : i + max_files]
-            #     for file in chunk_files:
-            #         loader = UnstructuredMarkdownLoader(
-            #             file,
-            #             mode="single",
-            #             strategy="fast"
-            #         )
-            #         print(f"Loaded {file} in {subdir} ...")
-            #         documents.extend(loader.load())
-
-            #         # Record the file as processed in metadata
-            #         metadata["processed_files"].append({
-            #             "file": file,
-            #             "processed_at": str(datetime.now())
-            #         })
-
-            #     print(f"Creating new collection for {self.directoryOrUrl}...")
-            #     # Create or get the collection
-            #     collection = self.client.create_collection(
-            #         name=self.directoryOrUrl,
-            #         get_or_create=True,
-            #         metadata={"processDate": metadata["processDate"]},
-            #     )
-
-            #     # Reload vectorstore based on collection
-            #     vectorstore = self.getVectorStore(collection_name=self.directoryOrUrl)
-
-            #     # Create a new parent document retriever
-            #     retriever = AsyncParentDocumentRetriever(
-            #         docstore=self.store,
-            #         vectorstore=vectorstore,
-            #         child_splitter=self.child_splitter,
-            #         parent_splitter=self.parent_splitter,
-            #     )
-
-            #     # force reload of collection to make sure we don't have the default langchain collection
-            #     collection = self.client.get_collection(name=self.directoryOrUrl)
-            #     vectorstore = self.getVectorStore(collection_name=self.directoryOrUrl)
-
-            #     # Add documents to the collection and docstore
-            #     print(f"Adding {len(documents)} documents to collection...")
-            #     add_docs_start_time = datetime.now()
-            #     await retriever.aadd_documents(
-            #         documents=documents, add_to_docstore=True
-            #     )
-            #     add_docs_end_time = datetime.now()
-            #     print(
-            #         f"Adding {len(documents)} documents to collection took: {add_docs_end_time - add_docs_start_time}"
-            #     )
-
-            #     documents = []  # clear documents list for next chunk
-
-            #     # Save metadata to the metadata.json file
-            #     with open(metadata_file, "w") as metadataFile:
-            #         json.dump(metadata, metadataFile, indent=4)
-
-            # print(f"Loaded {len(documents)} documents for directory '{subdir}'.")
-            # chunksEndTime = datetime.now()
-            # print(
-            #     f"{max_files} markdown file chunks processing time: {chunksEndTime - chunksStartTime}"
-            # )
-
-            # subdir_end_time = datetime.now()
-            # print(f"Subdir {subdir} processing end time: {subdir_end_time}")
-            # print(f"Time taken: {subdir_end_time - subdir_start_time}")
+            subdir_start_time = datetime.now()
+            print(f"Start {subdir} processing time: {subdir_start_time}")
+
+            # get all existing collections
+            collections = self.client.list_collections()
+            print(f"Existing collections: {collections}")
+
+            # Initialize an empty list to hold the documents
+            documents = []
+            # Define the maximum number of files to load at a time
+            max_files = 1000
+
+            # Load existing metadata
+            metadata_file = f"{self.directory}/metadata.json"
+            metadata = {
+                "processDate": str(datetime.now()),
+                "processed_files": [],
+            }
+            processed_files = set()  # Track processed files
+            if os.path.isfile(metadata_file):
+                with open(
+                    metadata_file, "r", encoding="utf-8"
+                ) as metadata_reader:
+                    metadata = dict[str, str](json.load(metadata_reader))
+                    processed_files = {
+                        entry["file"]
+                        for entry in metadata.get("processed_files", [])
+                    }
+
+            # Get a list of all files in the directory and exclude processed files
+            all_files = [
+                file
+                for file in glob.glob(
+                    f"{self.directory}/**/*.md", recursive=True
+                )
+                if file not in processed_files
+            ]
+
+            print(
+                f"Loading {len(all_files)} documents for title version {subdir}."
+            )
+            # Load files in chunks of max_files
+            for i in range(0, len(all_files), max_files):
+                chunks_start_time = datetime.now()
+                chunk_files = all_files[i : i + max_files]
+                for file in chunk_files:
+                    loader = UnstructuredMarkdownLoader(
+                        file, mode="single", strategy="fast"
+                    )
+                    print(f"Loaded {file} in {subdir} ...")
+                    documents.extend(loader.load())
+
+                    # Record the file as processed in metadata
+                    metadata["processed_files"].append(
+                        {"file": file, "processed_at": str(datetime.now())}
+                    )
+
+                print(
+                    f"Creating new collection for {self.directory}..."
+                )
+                # Create or get the collection
+                collection = self.client.create_collection(
+                    name=self.directory,
+                    get_or_create=True,
+                    metadata={"processDate": metadata["processDate"]},
+                )
+
+                # Reload vectorstore based on collection
+                vectorstore = self.get_vector_store(
+                    collection_name=collection.name
+                )
+
+                # Create a new parent document retriever
+                retriever = AsyncParentDocumentRetriever(
+                    docstore=self.store,
+                    vectorstore=vectorstore,
+                    child_splitter=self.child_splitter,
+                    parent_splitter=self.parent_splitter,
+                )
+
+                # force reload of collection to make sure we don't have
+                # the default langchain collection
+                collection = self.client.get_collection(
+                    name=self.directory
+                )
+                vectorstore = self.get_vector_store(
+                    collection_name=self.directory
+                )
+
+                # Add documents to the collection and docstore
+                print(
+                    f"Adding {len(documents)} documents to collection..."
+                )
+                add_docs_start_time = datetime.now()
+                await retriever.aadd_documents(
+                    documents=documents, add_to_docstore=True
+                )
+                add_docs_end_time = datetime.now()
+                total_time = add_docs_end_time - add_docs_start_time
+                print(
+                    f"Adding {len(documents)} documents to collection took: {total_time}"
+                )
+
+                documents = []  # clear documents list for next chunk
+
+                # Save metadata to the metadata.json file
+                with open(
+                    metadata_file, "w", encoding="utf-8"
+                ) as metadata_writer:
+                    json.dump(metadata, metadata_writer, indent=4)
+
+                print(
+                    f"Loaded {len(documents)} documents for directory '{subdir}'."
+                )
+                chunks_end_time = datetime.now()
+                chunk_time = chunks_end_time - chunks_start_time
+                print(
+                    f"{max_files} markdown file chunks processing time: {chunk_time}"
+                )
+
+            subdir_end_time = datetime.now()
+            print(
+                f"Subdir {subdir} processing end time: {subdir_end_time}"
+            )
+            print(f"Time taken: {subdir_end_time - subdir_start_time}")
 
         # Reload vectorstore based on collection to pass to parent doc retriever
-        # collection = self.client.get_collection(name=self.directoryOrUrl)
-        vectorstore = self.getVectorStore()
+        # collection = self.client.get_collection(name=self.directory)
+        vectorstore = self.get_vector_store()
         retriever = AsyncParentDocumentRetriever(
             docstore=self.store,
             vectorstore=vectorstore,
@@ -192,8 +232,9 @@ class VectorStorage:
         )
         return retriever
 
-    def getVectorStore(self, collection_name: str | None = None) -> Chroma:
-        if collection_name is None or "" or "None" :
+    def get_vector_store(self, collection_name: str | None = None) -> Chroma:
+        """ get a specific vector store for a collection """
+        if collection_name in (None, "", "None"):
             collection_name = "langchain"
         print("collection_name: " + collection_name)
         vectorstore = Chroma(
@@ -204,21 +245,24 @@ class VectorStorage:
         return vectorstore
 
     def list_collections(self):
+        """ Get a list of all collections in the vectorstore """
         vectorstore = Chroma(
-            client_settings=self.settings, embedding_function=self.embeddings
+            client_settings=self.settings,
+            embedding_function=self.embeddings,
         )
         return vectorstore._client.list_collections()
 
-    async def getRetriever(self, collection_name: str | None = None):
+    async def get_retriever(self, collection_name: str | None = None):
+        """ get a specific retriever for a collection in the vectorstore """
         if self.retrievers is None:
-            self.retrievers = await self.initRetrievers()
+            self.retrievers = await self.init_retrievers()
         if (
             collection_name is None
            or collection_name == ""
             or collection_name == "None"
         ):
-            name = str(Chroma._LANGCHAIN_DEFAULT_COLLECTION_NAME)
+            name = "swarms"
         else:
             name = collection_name
 
@@ -226,6 +270,8 @@ class VectorStorage:
             retriever = self.retrievers[name]
         except KeyError:
             print(f"Retriever for {name} not found, using default...")
-            retriever = self.retrievers[Chroma._LANGCHAIN_DEFAULT_COLLECTION_NAME]
+            retriever = self.retrievers[
+                "swarms"
+            ]
         return retriever
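
Reviewer note: a minimal smoke test for the renamed chat route, offered as a sketch only. It assumes the app is served with uvicorn on localhost:8888 and that ChatRequest still carries the prompt and messages fields referenced in the diff; the payload shape and the "human" role value are illustrative assumptions, not a verified contract.

# smoke_test_chat.py -- hypothetical client for the refactored /chat route.
# The port, role value, and payload shape are assumptions, not a verified API.
import requests

BASE_URL = "http://localhost:8888"  # assumed uvicorn host/port

payload = {
    "prompt": "You are a helpful assistant. Answer from the retrieved context.",
    "messages": [
        {"role": "human", "content": "What do my uploaded documents say?"},
    ],
}

# The route streams its answer via LangchainStreamingResponse, so consume
# the body incrementally instead of waiting for one JSON object.
with requests.post(
    f"{BASE_URL}/chat", json=payload, stream=True, timeout=120
) as response:
    response.raise_for_status()
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)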
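
Likewise, a sketch of driving the renamed VectorStorage surface (directory/use_gpu, init_retrievers, get_retriever) outside FastAPI. The uploads path and the query string are placeholders; the no-argument get_retriever() call exercises the new "swarms" fallback introduced above.

# vector_store_demo.py -- illustrative use of the snake_case VectorStorage API.
import asyncio

from swarms.server.vector_store import VectorStorage


async def main():
    # Mirrors the server setup: documents live under subdirectories of uploads/.
    storage = VectorStorage(directory="uploads", use_gpu=False)
    await storage.init_retrievers()

    # With no collection name, get_retriever() now falls back to "swarms".
    retriever = await storage.get_retriever()
    docs = await retriever.aget_relevant_documents("What is in my documents?")
    for doc in docs:
        print(doc.metadata, doc.page_content[:80])


if __name__ == "__main__":
    asyncio.run(main())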