fix linting errors

pull/570/head
Richard Anthony Hein 8 months ago
parent 4fae6839eb
commit f37223d49e

@@ -1,38 +1,36 @@
+""" Chatbot with RAG Server """
 import asyncio
-import json
 import logging
 import os
-from datetime import datetime
-from typing import List
-# import torch
+from contextlib import asynccontextmanager
 import langchain
-from pydantic import ValidationError, parse_obj_as
+from swarms.prompts.chat_prompt import Message
+from swarms.server.callback_handlers import SourceDocumentsStreamingCallbackHandler, TokenStreamingCallbackHandler
 import tiktoken
+# import torch
 from dotenv import load_dotenv
-from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import FileResponse, JSONResponse
 from fastapi.routing import APIRouter
 from fastapi.staticfiles import StaticFiles
 from huggingface_hub import login
-from langchain.callbacks import StreamingStdOutCallbackHandler
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+from langchain.chains.conversational_retrieval.base import (
+    ConversationalRetrievalChain,
+)
+from langchain.chains.llm import LLMChain
 from langchain.memory import ConversationBufferMemory
-from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
-from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
-from langchain.chains.history_aware_retriever import create_history_aware_retriever
-from langchain.chains.retrieval import create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.memory.chat_message_histories.in_memory import (
+    ChatMessageHistory,
+)
 from langchain.prompts.prompt import PromptTemplate
 from langchain_community.chat_models import ChatOpenAI
-from swarms.server.responses import LangchainStreamingResponse
-from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
-from langchain.chains.llm import LLMChain
-from langchain.chains.combine_documents.stuff import StuffDocumentsChain
-from swarms.prompts.chat_prompt import Message
+# from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+# from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from swarms.prompts.conversational_RAG import (
     B_INST,
     B_SYS,
@@ -41,40 +39,35 @@ from swarms.prompts.conversational_RAG import (
     E_INST,
     E_SYS,
     QA_PROMPT_TEMPLATE,
-    QA_PROMPT_TEMPLATE_STR,
-    QA_CONDENSE_TEMPLATE_STR,
 )
+from swarms.server.responses import LangchainStreamingResponse
+from swarms.server.server_models import ChatRequest, Role
 from swarms.server.vector_store import VectorStorage
-from swarms.server.server_models import (
-    ChatRequest,
-    LogMessage,
-    AIModel,
-    AIModels,
-    RAGFile,
-    RAGFiles,
-    Role,
-    State,
-    GetRAGFileStateRequest,
-    ProcessRAGFileRequest
-)

 # Explicitly specify the path to the .env file
 # Two folders above the current file's directory
-dotenv_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env')
+dotenv_path = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env"
+)
 load_dotenv(dotenv_path)

-hf_token = os.environ.get("HUGGINFACEHUB_API_KEY")  # Get the Huggingface API Token
-uploads = os.environ.get("UPLOADS")  # Directory where user uploads files to be parsed for RAG
+hf_token = os.environ.get(
+    "HUGGINFACEHUB_API_KEY"
+)  # Get the Huggingface API Token
+uploads = os.environ.get(
+    "UPLOADS"
+)  # Directory where user uploads files to be parsed for RAG
 model_dir = os.environ.get("MODEL_DIR")

 # hugginface.co model (eg. meta-llama/Llama-2-70b-hf)
 model_name = os.environ.get("MODEL_NAME")

-# Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server, or set them to OpenAI API key and base URL.
+# Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server
+# or set them to OpenAI API key and base URL.
 openai_api_key = os.environ.get("OPENAI_API_KEY") or "EMPTY"
-openai_api_base = os.environ.get("OPENAI_API_BASE") or "http://localhost:8000/v1"
+openai_api_base = (
+    os.environ.get("OPENAI_API_BASE") or "http://localhost:8000/v1"
+)

 env_vars = [
     hf_token,
@@ -93,13 +86,13 @@ if missing_vars:
     exit(1)

 useMetal = os.environ.get("USE_METAL", "False") == "True"
-useGPU = os.environ.get("USE_GPU", "False") == "True"
+use_gpu = os.environ.get("USE_GPU", "False") == "True"

 print(f"Uploads={uploads}")
 print(f"MODEL_DIR={model_dir}")
 print(f"MODEL_NAME={model_name}")
 print(f"USE_METAL={useMetal}")
-print(f"USE_GPU={useGPU}")
+print(f"USE_GPU={use_gpu}")
 print(f"OPENAI_API_KEY={openai_api_key}")
 print(f"OPENAI_API_BASE={openai_api_base}")
@@ -116,23 +109,25 @@ login(token=hf_token)  # login to huggingface.co
 langchain.debug = True
 langchain.verbose = True

-from contextlib import asynccontextmanager

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    asyncio.create_task(vector_store.initRetrievers())
+    """Initializes the vector store in a background task."""
+    print(f"Initializing vector store retrievers for {app.title}.")
+    asyncio.create_task(vector_store.init_retrievers())
     yield

-app = FastAPI(title="Chatbot", lifespan=lifespan)
+chatbot = FastAPI(title="Chatbot", lifespan=lifespan)

 router = APIRouter()

 current_dir = os.path.dirname(__file__)
 print("current_dir: " + current_dir)
 static_dir = os.path.join(current_dir, "static")
 print("static_dir: " + static_dir)
-app.mount(static_dir, StaticFiles(directory=static_dir), name="static")
+chatbot.mount(static_dir, StaticFiles(directory=static_dir), name="static")

-app.add_middleware(
+chatbot.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
     allow_credentials=True,
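
Review note: the lifespan hook now schedules retriever initialization as a background task instead of awaiting it, so the server can start accepting requests while the vector store warms up. A minimal sketch of the pattern, assuming a stand-in init_slow_index coroutine in place of vector_store.init_retrievers():

import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI

async def init_slow_index() -> None:
    # Stand-in for loading embeddings and documents at startup.
    await asyncio.sleep(5)

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Schedule initialization without blocking startup; early requests
    # must tolerate an index that is still warming up.
    task = asyncio.create_task(init_slow_index())
    yield
    task.cancel()  # best-effort cleanup on shutdown

app = FastAPI(lifespan=lifespan)

One caveat: keep a reference to the create_task result (as above); a task with no remaining reference can be garbage-collected before it finishes.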
@@ -147,21 +142,15 @@ if not os.path.exists(uploads):
     os.makedirs(uploads)

 # Initialize the vector store
-vector_store = VectorStorage(directoryOrUrl=uploads, useGPU=useGPU)
+vector_store = VectorStorage(directory=uploads, use_gpu=use_gpu)


 async def create_chain(
     messages: list[Message],
-    model=model_dir,
-    max_tokens_to_gen=2048,
-    temperature=0.5,
     prompt: PromptTemplate = QA_PROMPT_TEMPLATE,
-    file: RAGFile | None = None,
-    key: str | None = None,
 ):
-    print(
-        f"Creating chain with key={key}, model={model}, max_tokens={max_tokens_to_gen}, temperature={temperature}, prompt={prompt}, file={file.title}"
-    )
+    """Creates the RAG Langchain conversational retrieval chain."""
+    print("Creating chain ...")

     llm = ChatOpenAI(
         api_key=openai_api_key,
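
The body of create_chain between this hunk and the next is not shown, but the imports that survive the cleanup (LLMChain, StuffDocumentsChain, ConversationalRetrievalChain) suggest the classic three-part assembly. A hedged sketch under that assumption; the prompt templates and the build_chain helper below are illustrative, not code from this commit:

from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.chains.llm import LLMChain
from langchain.prompts.prompt import PromptTemplate

condense_prompt = PromptTemplate.from_template(
    "Rephrase the follow-up question as a standalone question.\n"
    "Chat history: {chat_history}\nFollow-up: {question}"
)
document_prompt = PromptTemplate.from_template("{page_content}")

def build_chain(llm, retriever, qa_prompt):
    # qa_prompt must expose a {context} variable for the stuffed documents.
    question_generator = LLMChain(llm=llm, prompt=condense_prompt)
    combine_docs_chain = StuffDocumentsChain(
        llm_chain=LLMChain(llm=llm, prompt=qa_prompt),
        document_prompt=document_prompt,
        document_variable_name="context",
    )
    return ConversationalRetrievalChain(
        retriever=retriever,
        question_generator=question_generator,
        combine_docs_chain=combine_docs_chain,
        return_source_documents=True,
    )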
@@ -181,7 +170,7 @@ async def create_chain(
     # if llm is VLLMAsync:
     #     llm.max_tokens = max_tokens_to_gen

-    retriever = await vector_store.getRetriever()
+    retriever = await vector_store.get_retriever()

     chat_memory = ChatMessageHistory()
     for message in messages:
@@ -236,26 +225,26 @@ async def create_chain(
 router = APIRouter()


 @router.post(
     "/chat",
     summary="Chatbot",
     description="Chatbot AI Service",
 )
 async def chat(request: ChatRequest):
+    """ Handles chatbot chat POST requests """
     chain = await create_chain(
-        file=request.file,
         messages=request.messages[:-1],
-        model=request.model.id,
-        max_tokens_to_gen=request.maxTokens,
-        temperature=request.temperature,
         prompt=PromptTemplate.from_template(
             f"{B_INST}{B_SYS}{request.prompt.strip()}{E_SYS}{E_INST}"
         ),
     )

-    json = {
+    json_config = {
         "question": request.messages[-1].content,
-        "chat_history": [message.content for message in request.messages[:-1]],
+        "chat_history": [
+            message.content for message in request.messages[:-1]
+        ],
         # "callbacks": [
         #     StreamingStdOutCallbackHandler(),
         #     TokenStreamingCallbackHandler(output_key="answer"),
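
For reference, a hypothetical client for the /chat endpoint; the request shape is inferred from how the handler reads ChatRequest (a system prompt plus a message list) and may not match the full Pydantic model, and the port is illustrative:

import httpx

payload = {
    "prompt": "You are a helpful assistant answering from the indexed docs.",
    "messages": [{"role": "human", "content": "What is Swarms?"}],
}

with httpx.stream("POST", "http://localhost:8888/chat", json=payload, timeout=60) as response:
    for chunk in response.iter_text():
        print(chunk, end="", flush=True)  # tokens arrive as they stream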
@@ -264,178 +253,41 @@ async def chat(request: ChatRequest):
     }

     return LangchainStreamingResponse(
         chain,
-        config=json,
+        config=json_config,
     )

-app.include_router(router, tags=["chat"])
+chatbot.include_router(router, tags=["chat"])


-@app.get("/")
+@chatbot.get("/")
 def root():
-    return {"message": "Chatbot API"}
+    """Swarms Chatbot API Root"""
+    return {"message": "Swarms Chatbot API"}


-@app.get("/favicon.ico")
+@chatbot.get("/favicon.ico")
 def favicon():
+    """ Returns a favicon """
     file_name = "favicon.ico"
-    file_path = os.path.join(app.root_path, "static", file_name)
+    file_path = os.path.join(chatbot.root_path, "static", file_name)
     return FileResponse(
         path=file_path,
-        headers={"Content-Disposition": "attachment; filename=" + file_name},
+        headers={
+            "Content-Disposition": "attachment; filename=" + file_name
+        },
     )


-@app.post("/log")
-def log_message(log_message: LogMessage):
-    try:
-        with open("log.txt", "a") as log_file:
-            log_file.write(log_message.message + "\n")
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error saving log: {e}")
-    return {"message": "Log saved successfully"}
-
-
-@app.get("/models")
-def get_models():
-    # llama7B = AIModel(
-    #     id="llama-2-7b-chat-ggml-q4_0",
-    #     name="llama-2-7b-chat-ggml-q4_0",
-    #     maxLength=2048,
-    #     tokenLimit=2048,
-    # )
-    # llama13B = AIModel(
-    #     id="llama-2-13b-chat-ggml-q4_0",
-    #     name="llama-2-13b-chat-ggml-q4_0",
-    #     maxLength=2048,
-    #     tokenLimit=2048,
-    # )
-    llama70B = AIModel(
-        id="llama-2-70b.Q5_K_M",
-        name="llama-2-70b.Q5_K_M",
-        maxLength=2048,
-        tokenLimit=2048,
-    )
-    models = AIModels(models=[llama70B])
-    return models
-
-
-@app.get("/titles")
-def getTitles():
-    titles = RAGFiles(
-        titles=[
-            # RAGFile(
-            #     versionId="d8ad3b1d-c33c-4524-9691-e93967d4d863",
-            #     title="d8ad3b1d-c33c-4524-9691-e93967d4d863",
-            #     state=State.Unavailable,
-            # ),
-            RAGFile(
-                versionId=collection.name,
-                title=collection.name,
-                state=State.InProcess
-                if collection.name in processing_books
-                else State.Processed,
-            )
-            for collection in vector_store.list_collections()
-            if collection.name != "langchain"
-        ]
-    )
-    return titles
-
-
-processing_books: list[str] = []
-processing_books_lock = asyncio.Lock()

 logging.basicConfig(level=logging.ERROR)

-@app.post("/titleState")
-async def getTitleState(request: GetRAGFileStateRequest):
-    # FastAPI + Pydantic will throw a 422 Unprocessable Entity if the request isn't the right type.
-    # try:
-    logging.debug(f"Received getTitleState request: {request}")
-    titleStateRequest: GetRAGFileStateRequest = request
-    # except ValidationError as e:
-    #     print(f"Error validating JSON: {e}")
-    #     raise HTTPException(status_code=422, detail=str(e))
-    # except json.JSONDecodeError as e:
-    #     print(f"Error parsing JSON: {e}")
-    #     raise HTTPException(status_code=422, detail="Invalid JSON format")
-
-    # check to see if the book has already been processed.
-    # return the proper State directly to response.
-    matchingCollection = next(
-        (
-            x
-            for x in vector_store.list_collections()
-            if x.name == titleStateRequest.versionRef
-        ),
-        None,
-    )
-    print("Got a Title State request for version " + titleStateRequest.versionRef)
-    if titleStateRequest.versionRef in processing_books:
-        return {"message": State.InProcess}
-    elif matchingCollection is not None:
-        return {"message": State.Processed}
-    else:
-        return {"message": State.Unavailable}
-
-
-@app.post("/processRAGFile")
-async def processRAGFile(
-    request: str = Form(...),
-    files: List[UploadFile] = File(...),
-):
-    try:
-        logging.debug(f"Received processBook request: {request}")
-        # Parse the JSON string into a ProcessBookRequest object
-        fileRAGRequest: ProcessRAGFileRequest = parse_obj_as(
-            ProcessRAGFileRequest, json.loads(request)
-        )
-    except ValidationError as e:
-        print(f"Error validating JSON: {e}")
-        raise HTTPException(status_code=422, detail=str(e))
-    except json.JSONDecodeError as e:
-        print(f"Error parsing JSON: {e}")
-        raise HTTPException(status_code=422, detail="Invalid JSON format")
-
-    try:
-        print(
-            f"Processing file {fileRAGRequest.filename} for user {fileRAGRequest.username}."
-        )
-        # check to see if the file has already been processed.
-        # write html to subfolder
-        print(f"Writing file to path: {fileRAGRequest.username}/{fileRAGRequest.filename}...")
-        for index, segment in enumerate(files):
-            filename = segment.filename if segment.filename else str(index)
-            subDir = f"{fileRAGRequest.username}"
-            with open(os.path.join(subDir, filename), "wb") as htmlFile:
-                htmlFile.write(await segment.read())
-
-        # write metadata to subfolder
-        print(f"Writing metadata to subfolder {fileRAGRequest.username}...")
-        with open(os.path.join({fileRAGRequest.username}, "metadata.json"), "w") as metadataFile:
-            metaData = {
-                "filename": fileRAGRequest.filename,
-                "username": fileRAGRequest.username,
-                "processDate": datetime.now().isoformat(),
-            }
-            metadataFile.write(json.dumps(metaData))
-
-        vector_store.retrievers[
-            f"{fileRAGRequest.username}/{fileRAGRequest.filename}"
-        ] = await vector_store.initRetriever(f"{fileRAGRequest.username}/{fileRAGRequest.filename}")
-
-        return {
-            "message": f"File {fileRAGRequest.filename} processed successfully."
-        }
-    except Exception as e:
-        logging.error(f"Error processing book: {e}")
-        return {"message": f"Error processing book: {e}"}
-
-
-@app.exception_handler(HTTPException)
-async def http_exception_handler(bookRequest: Request, exc: HTTPException):
-    logging.error(f"HTTPException: {exc.detail}")
-    return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
+
+@chatbot.exception_handler(HTTPException)
+async def http_exception_handler(r: Request, exc: HTTPException):
+    """Log and return exception details in response."""
+    logging.error(
+        "HTTPException: %s executing request: %s", exc.detail, r.base_url
+    )
+    return JSONResponse(
+        status_code=exc.status_code, content={"detail": exc.detail}
+    )
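
The rewritten handler also switches from an f-string to %s placeholders, the change pylint's logging-fstring-interpolation check (W1203) asks for: the message is only interpolated when the record is actually emitted.

import logging

detail = "Not Found"  # illustrative value
logging.error(f"HTTPException: {detail}")   # eager; flagged by W1203
logging.error("HTTPException: %s", detail)  # lazy; lint-clean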

@@ -1,9 +1,11 @@
+""" Vector storage with RAG (Retrieval Augmented Generation) support for Markdown."""
 import asyncio
 import json
 import os
 import glob
 from datetime import datetime
-from typing import Dict, Literal
+from typing import Dict
 from chromadb.config import Settings
 from langchain.document_loaders.markdown import UnstructuredMarkdownLoader
 from langchain.embeddings import HuggingFaceBgeEmbeddings
@@ -11,34 +13,39 @@ from langchain.storage import LocalFileStore
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores.chroma import Chroma
 from langchain.schema import BaseRetriever
-from swarms.server.async_parent_document_retriever import AsyncParentDocumentRetriever
+from swarms.server.async_parent_document_retriever import (
+    AsyncParentDocumentRetriever,
+)

-store_type = "local"  # "redis" or "local"
+STORE_TYPE = "local"  # "redis" or "local"


 class VectorStorage:
-    def __init__(self, directoryOrUrl, useGPU=False):
+    """Vector storage class handles loading documents from a given directory."""
+
+    def __init__(self, directory, use_gpu=False):
         self.embeddings = HuggingFaceBgeEmbeddings(
             cache_folder="./.embeddings",
             model_name="BAAI/bge-large-en",
-            model_kwargs={"device": "cuda" if useGPU else "cpu"},
+            model_kwargs={"device": "cuda" if use_gpu else "cpu"},
             encode_kwargs={"normalize_embeddings": True},
             query_instruction="Represent this sentence for searching relevant passages: ",
         )
-        self.directoryOrUrl = directoryOrUrl
+        self.directory = directory
         self.child_splitter = RecursiveCharacterTextSplitter(
             chunk_size=200, chunk_overlap=20
         )
         self.parent_splitter = RecursiveCharacterTextSplitter(
             chunk_size=2000, chunk_overlap=200
         )
-        if store_type == "redis":
+        if STORE_TYPE == "redis":
             from langchain.storage import RedisStore
             from langchain.utilities.redis import get_client

             username = r"username"
             password = r"password"
             client = get_client(
-                redis_url=f"redis://{username}:{password}@redis-10854.c282.east-us-mz.azure.cloud.redislabs.com:10854"
+                redis_url=f"redis://{username}:{password}@localhost:6239"
             )
             self.store = RedisStore(client=client)
         else:
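
The else branch that pairs with this STORE_TYPE check sits outside the hunk, but given the LocalFileStore import visible in the hunk header above, it presumably falls back to a local docstore. A sketch of the switch under that assumption; the ./.parent_documents path is illustrative:

from langchain.storage import LocalFileStore

STORE_TYPE = "local"  # "redis" or "local"

if STORE_TYPE == "redis":
    from langchain.storage import RedisStore
    from langchain.utilities.redis import get_client

    client = get_client(redis_url="redis://username:password@localhost:6239")
    store = RedisStore(client=client)
else:
    store = LocalFileStore("./.parent_documents")  # path is illustrative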
@@ -49,7 +56,7 @@ class VectorStorage:
             anonymized_telemetry=False,
         )
         # create a new vectorstore or get an existing one, with default collection
-        self.vectorstore = self.getVectorStore()
+        self.vectorstore = self.get_vector_store()
         self.client = self.vectorstore._client
         self.retrievers: Dict[str, BaseRetriever] = {}
         # default retriever for when no collection title is specified
@@ -57,22 +64,25 @@ class VectorStorage:
             str(self.vectorstore._LANGCHAIN_DEFAULT_COLLECTION_NAME)
         ] = self.vectorstore.as_retriever()

-    async def initRetrievers(self, directories: list[str] | None = None):
+    async def init_retrievers(self, directories: list[str] | None = None):
+        """Initializes the vector storage retrievers."""
         start_time = datetime.now()
         print(f"Start vectorstore initialization time: {start_time}")

         # for each subdirectory in the directory, create a new collection if it doesn't exist
-        dirs = directories or os.listdir(self.directoryOrUrl)
+        dirs = directories or os.listdir(self.directory)
         # make sure the subdir is not a file on MacOS (which has a hidden .DS_Store file)
         dirs = [
             subdir
             for subdir in dirs
-            if not os.path.isfile(f"{self.directoryOrUrl}/{subdir}")
+            if not os.path.isfile(f"{self.directory}/{subdir}")
         ]
         print(f"{len(dirs)} subdirectories to load: {dirs}")

-        self.retrievers[self.directoryOrUrl] = await self.initRetriever(self.directoryOrUrl)
+        self.retrievers[self.directory] = await self.init_retriever(
+            self.directory
+        )

         end_time = datetime.now()
         print("Vectorstore initialization complete.")
         print(f"Vectorstore initialization end time: {end_time}")
@@ -80,110 +90,140 @@ class VectorStorage:
         return self.retrievers

-    async def initRetriever(self, subdir: str) -> BaseRetriever:
+    async def init_retriever(self, subdir: str) -> BaseRetriever:
+        """ Initialize each retriever. """
         # Ensure only one process/thread is executing this method at a time
         lock = asyncio.Lock()
         async with lock:
-            # subdir_start_time = datetime.now()
-            # print(f"Start {subdir} processing time: {subdir_start_time}")
-            # # get all existing collections
-            # collections = self.client.list_collections()
-            # print(f"Existing collections: {collections}")
-            # # Initialize an empty list to hold the documents
-            # documents = []
-            # # Define the maximum number of files to load at a time
-            # max_files = 1000
-            # # Load existing metadata
-            # metadata_file = f"{self.directoryOrUrl}/metadata.json"
-            # metadata = {"processDate": str(datetime.now()), "processed_files": []}
-            # processed_files = set()  # Track processed files
-            # if os.path.isfile(metadata_file):
-            #     with open(metadata_file, "r") as metadataFile:
-            #         metadata = dict[str, str](json.load(metadataFile))
-            #         processed_files = {entry["file"] for entry in metadata.get("processed_files", [])}
-
-            # # Get a list of all files in the directory and exclude processed files
-            # all_files = [
-            #     file for file in glob.glob(f"{self.directoryOrUrl}/**/*.md", recursive=True)
-            #     if file not in processed_files
-            # ]
-
-            # print(f"Loading {len(all_files)} documents for title version {subdir}.")
-            # # Load files in chunks of max_files
-            # for i in range(0, len(all_files), max_files):
-            #     chunksStartTime = datetime.now()
-            #     chunk_files = all_files[i : i + max_files]
-            #     for file in chunk_files:
-            #         loader = UnstructuredMarkdownLoader(
-            #             file,
-            #             mode="single",
-            #             strategy="fast"
-            #         )
-            #         print(f"Loaded {file} in {subdir} ...")
-            #         documents.extend(loader.load())
-
-            #         # Record the file as processed in metadata
-            #         metadata["processed_files"].append({
-            #             "file": file,
-            #             "processed_at": str(datetime.now())
-            #         })
-
-            #     print(f"Creating new collection for {self.directoryOrUrl}...")
-            #     # Create or get the collection
-            #     collection = self.client.create_collection(
-            #         name=self.directoryOrUrl,
-            #         get_or_create=True,
-            #         metadata={"processDate": metadata["processDate"]},
-            #     )
-
-            #     # Reload vectorstore based on collection
-            #     vectorstore = self.getVectorStore(collection_name=self.directoryOrUrl)
-
-            #     # Create a new parent document retriever
-            #     retriever = AsyncParentDocumentRetriever(
-            #         docstore=self.store,
-            #         vectorstore=vectorstore,
-            #         child_splitter=self.child_splitter,
-            #         parent_splitter=self.parent_splitter,
-            #     )
-
-            #     # force reload of collection to make sure we don't have the default langchain collection
-            #     collection = self.client.get_collection(name=self.directoryOrUrl)
-            #     vectorstore = self.getVectorStore(collection_name=self.directoryOrUrl)
-
-            #     # Add documents to the collection and docstore
-            #     print(f"Adding {len(documents)} documents to collection...")
-            #     add_docs_start_time = datetime.now()
-            #     await retriever.aadd_documents(
-            #         documents=documents, add_to_docstore=True
-            #     )
-            #     add_docs_end_time = datetime.now()
-            #     print(
-            #         f"Adding {len(documents)} documents to collection took: {add_docs_end_time - add_docs_start_time}"
-            #     )
-
-            #     documents = []  # clear documents list for next chunk
-
-            #     # Save metadata to the metadata.json file
-            #     with open(metadata_file, "w") as metadataFile:
-            #         json.dump(metadata, metadataFile, indent=4)
-
-            #     print(f"Loaded {len(documents)} documents for directory '{subdir}'.")
-            #     chunksEndTime = datetime.now()
-            #     print(
-            #         f"{max_files} markdown file chunks processing time: {chunksEndTime - chunksStartTime}"
-            #     )
-
-            # subdir_end_time = datetime.now()
-            # print(f"Subdir {subdir} processing end time: {subdir_end_time}")
-            # print(f"Time taken: {subdir_end_time - subdir_start_time}")
+            subdir_start_time = datetime.now()
+            print(f"Start {subdir} processing time: {subdir_start_time}")
+            # get all existing collections
+            collections = self.client.list_collections()
+            print(f"Existing collections: {collections}")
+            # Initialize an empty list to hold the documents
+            documents = []
+            # Define the maximum number of files to load at a time
+            max_files = 1000
+            # Load existing metadata
+            metadata_file = f"{self.directory}/metadata.json"
+            metadata = {
+                "processDate": str(datetime.now()),
+                "processed_files": [],
+            }
+            processed_files = set()  # Track processed files
+            if os.path.isfile(metadata_file):
+                with open(
+                    metadata_file, "r", encoding="utf-8"
+                ) as metadata_file:
+                    metadata = dict[str, str](json.load(metadata_file))
+                    processed_files = {
+                        entry["file"]
+                        for entry in metadata.get("processed_files", [])
+                    }
+
+            # Get a list of all files in the directory and exclude processed files
+            all_files = [
+                file
+                for file in glob.glob(
+                    f"{self.directory}/**/*.md", recursive=True
+                )
+                if file not in processed_files
+            ]
+
+            print(
+                f"Loading {len(all_files)} documents for title version {subdir}."
+            )
+            # Load files in chunks of max_files
+            for i in range(0, len(all_files), max_files):
+                chunks_start_time = datetime.now()
+                chunk_files = all_files[i : i + max_files]
+                for file in chunk_files:
+                    loader = UnstructuredMarkdownLoader(
+                        file, mode="single", strategy="fast"
+                    )
+                    print(f"Loaded {file} in {subdir} ...")
+                    documents.extend(loader.load())
+
+                    # Record the file as processed in metadata
+                    metadata["processed_files"].append(
+                        {"file": file, "processed_at": str(datetime.now())}
+                    )
+
+                print(
+                    f"Creating new collection for {self.directory}..."
+                )
+                # Create or get the collection
+                collection = self.client.create_collection(
+                    name=self.directory,
+                    get_or_create=True,
+                    metadata={"processDate": metadata["processDate"]},
+                )
+
+                # Reload vectorstore based on collection
+                vectorstore = self.get_vector_store(
+                    collection_name=collection.name
+                )
+
+                # Create a new parent document retriever
+                retriever = AsyncParentDocumentRetriever(
+                    docstore=self.store,
+                    vectorstore=vectorstore,
+                    child_splitter=self.child_splitter,
+                    parent_splitter=self.parent_splitter,
+                )
+
+                # force reload of collection to make sure we don't have
+                # the default langchain collection
+                collection = self.client.get_collection(
+                    name=self.directory
+                )
+                vectorstore = self.get_vector_store(
+                    collection_name=self.directory
+                )
+
+                # Add documents to the collection and docstore
+                print(
+                    f"Adding {len(documents)} documents to collection..."
+                )
+                add_docs_start_time = datetime.now()
+                await retriever.aadd_documents(
+                    documents=documents, add_to_docstore=True
+                )
+                add_docs_end_time = datetime.now()
+                total_time = add_docs_end_time - add_docs_start_time
+                print(
+                    f"Adding {len(documents)} documents to collection took: {total_time}"
+                )
+
+                documents = []  # clear documents list for next chunk
+
+                # Save metadata to the metadata.json file
+                with open(
+                    metadata_file, "w", encoding="utf-8"
+                ) as metadata_file:
+                    json.dump(metadata, metadata_file, indent=4)
+
+                print(
+                    f"Loaded {len(documents)} documents for directory '{subdir}'."
+                )
+                chunks_end_time = datetime.now()
+                chunk_time = chunks_end_time - chunks_start_time
+                print(
+                    f"{max_files} markdown file chunks processing time: {chunk_time}"
+                )
+
+            subdir_end_time = datetime.now()
+            print(
+                f"Subdir {subdir} processing end time: {subdir_end_time}"
+            )
+            print(f"Time taken: {subdir_end_time - subdir_start_time}")

         # Reload vectorstore based on collection to pass to parent doc retriever
-        # collection = self.client.get_collection(name=self.directoryOrUrl)
-        vectorstore = self.getVectorStore()
+        # collection = self.client.get_collection(name=self.directory)
+        vectorstore = self.get_vector_store()
         retriever = AsyncParentDocumentRetriever(
             docstore=self.store,
             vectorstore=vectorstore,
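
One problem survives the rename in this hunk: the with statements bind the open file handle to the same metadata_file name that previously held the path, so the later open(metadata_file, "w", ...) inside the loop receives a closed handle rather than a path and raises TypeError. A corrected sketch with a distinct handle name; the path and keys below are illustrative:

import json
import os

metadata_path = "metadata.json"
metadata = {"processed_files": []}
if os.path.isfile(metadata_path):
    with open(metadata_path, "r", encoding="utf-8") as fh:
        metadata = json.load(fh)

metadata["processed_files"].append({"file": "example.md"})

with open(metadata_path, "w", encoding="utf-8") as fh:  # path is still a str
    json.dump(metadata, fh, indent=4)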
@@ -192,8 +232,9 @@ class VectorStorage:
         )
         return retriever

-    def getVectorStore(self, collection_name: str | None = None) -> Chroma:
-        if collection_name is None or "" or "None" :
+    def get_vector_store(self, collection_name: str | None = None) -> Chroma:
+        """ get a specific vector store for a collection """
+        if collection_name is None or "" or "None":
             collection_name = "langchain"
         print("collection_name: " + collection_name)
         vectorstore = Chroma(
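
Note that the touched condition is still always true: the bare "" and "None" literals are evaluated for truthiness rather than compared, and "None" is a non-empty string, so every call falls through to the "langchain" default. Each value has to be compared explicitly, as get_retriever below already does. A corrected sketch:

def normalize_collection_name(collection_name):
    # Compare each value explicitly; a bare "None" operand is always truthy.
    if collection_name is None or collection_name == "" or collection_name == "None":
        return "langchain"
    return collection_name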
@@ -204,21 +245,24 @@ class VectorStorage:
         return vectorstore

     def list_collections(self):
+        """ Get a list of all collections in the vectorstore """
         vectorstore = Chroma(
-            client_settings=self.settings, embedding_function=self.embeddings
+            client_settings=self.settings,
+            embedding_function=self.embeddings,
         )
         return vectorstore._client.list_collections()

-    async def getRetriever(self, collection_name: str | None = None):
+    async def get_retriever(self, collection_name: str | None = None):
+        """ get a specific retriever for a collection in the vectorstore """
         if self.retrievers is None:
-            self.retrievers = await self.initRetrievers()
+            self.retrievers = await self.init_retrievers()
         if (
             collection_name is None
             or collection_name == ""
             or collection_name == "None"
         ):
-            name = str(Chroma._LANGCHAIN_DEFAULT_COLLECTION_NAME)
+            name = "swarms"
         else:
             name = collection_name
@@ -226,6 +270,8 @@ class VectorStorage:
             retriever = self.retrievers[name]
         except KeyError:
             print(f"Retriever for {name} not found, using default...")
-            retriever = self.retrievers[Chroma._LANGCHAIN_DEFAULT_COLLECTION_NAME]
+            retriever = self.retrievers[
+                "swarms"
+            ]
         return retriever
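
The fallback key changes here from Chroma's default collection name ("langchain") to the literal "swarms", while __init__ above still registers the default retriever under _LANGCHAIN_DEFAULT_COLLECTION_NAME; if no "swarms" retriever was ever registered, the fallback itself raises KeyError. A sketch of a total lookup, assuming "swarms" is the intended registration key:

DEFAULT_COLLECTION = "swarms"  # assumed registration key

def lookup_retriever(retrievers: dict, name: str):
    retriever = retrievers.get(name)
    if retriever is None:
        print(f"Retriever for {name} not found, using default...")
        retriever = retrievers[DEFAULT_COLLECTION]
    return retriever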
