fix linting errors

pull/570/head
Richard Anthony Hein 8 months ago
parent 4fae6839eb
commit f37223d49e

@ -1,38 +1,36 @@
""" Chatbot with RAG Server """
import asyncio
import json
import logging
import os
from datetime import datetime
from typing import List
# import torch
from contextlib import asynccontextmanager
import langchain
from pydantic import ValidationError, parse_obj_as
from swarms.prompts.chat_prompt import Message
from swarms.server.callback_handlers import SourceDocumentsStreamingCallbackHandler, TokenStreamingCallbackHandler
import tiktoken
# import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.routing import APIRouter
from fastapi.staticfiles import StaticFiles
from huggingface_hub import login
from langchain.callbacks import StreamingStdOutCallbackHandler
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.conversational_retrieval.base import (
ConversationalRetrievalChain,
)
from langchain.chains.llm import LLMChain
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.memory.chat_message_histories.in_memory import (
ChatMessageHistory,
)
from langchain.prompts.prompt import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from swarms.server.responses import LangchainStreamingResponse
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.chains.llm import LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
# from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
# from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from swarms.prompts.chat_prompt import Message
from swarms.prompts.conversational_RAG import (
B_INST,
B_SYS,
@ -41,40 +39,35 @@ from swarms.prompts.conversational_RAG import (
E_INST,
E_SYS,
QA_PROMPT_TEMPLATE,
QA_PROMPT_TEMPLATE_STR,
QA_CONDENSE_TEMPLATE_STR,
)
from swarms.server.responses import LangchainStreamingResponse
from swarms.server.server_models import ChatRequest, Role
from swarms.server.vector_store import VectorStorage
from swarms.server.server_models import (
ChatRequest,
LogMessage,
AIModel,
AIModels,
RAGFile,
RAGFiles,
Role,
State,
GetRAGFileStateRequest,
ProcessRAGFileRequest
)
# Explicitly specify the path to the .env file
# Two folders above the current file's directory
dotenv_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.env')
dotenv_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env"
)
load_dotenv(dotenv_path)
hf_token = os.environ.get("HUGGINFACEHUB_API_KEY") # Get the Huggingface API Token
uploads = os.environ.get("UPLOADS") # Directory where user uploads files to be parsed for RAG
hf_token = os.environ.get(
"HUGGINFACEHUB_API_KEY"
) # Get the Huggingface API Token
uploads = os.environ.get(
"UPLOADS"
) # Directory where user uploads files to be parsed for RAG
model_dir = os.environ.get("MODEL_DIR")
# hugginface.co model (eg. meta-llama/Llama-2-70b-hf)
model_name = os.environ.get("MODEL_NAME")
# Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server, or set them to OpenAI API key and base URL.
# Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server
# or set them to OpenAI API key and base URL.
openai_api_key = os.environ.get("OPENAI_API_KEY") or "EMPTY"
openai_api_base = os.environ.get("OPENAI_API_BASE") or "http://localhost:8000/v1"
openai_api_base = (
os.environ.get("OPENAI_API_BASE") or "http://localhost:8000/v1"
)
env_vars = [
hf_token,
@ -93,13 +86,13 @@ if missing_vars:
exit(1)
useMetal = os.environ.get("USE_METAL", "False") == "True"
useGPU = os.environ.get("USE_GPU", "False") == "True"
use_gpu = os.environ.get("USE_GPU", "False") == "True"
print(f"Uploads={uploads}")
print(f"MODEL_DIR={model_dir}")
print(f"MODEL_NAME={model_name}")
print(f"USE_METAL={useMetal}")
print(f"USE_GPU={useGPU}")
print(f"USE_GPU={use_gpu}")
print(f"OPENAI_API_KEY={openai_api_key}")
print(f"OPENAI_API_BASE={openai_api_base}")
@ -116,23 +109,25 @@ login(token=hf_token) # login to huggingface.co
langchain.debug = True
langchain.verbose = True
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
asyncio.create_task(vector_store.initRetrievers())
"""Initializes the vector store in a background task."""
print(f"Initializing vector store retrievers for {app.title}.")
asyncio.create_task(vector_store.init_retrievers())
yield
app = FastAPI(title="Chatbot", lifespan=lifespan)
chatbot = FastAPI(title="Chatbot", lifespan=lifespan)
router = APIRouter()
current_dir = os.path.dirname(__file__)
print("current_dir: " + current_dir)
static_dir = os.path.join(current_dir, "static")
print("static_dir: " + static_dir)
app.mount(static_dir, StaticFiles(directory=static_dir), name="static")
chatbot.mount(static_dir, StaticFiles(directory=static_dir), name="static")
app.add_middleware(
chatbot.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
@ -147,21 +142,15 @@ if not os.path.exists(uploads):
os.makedirs(uploads)
# Initialize the vector store
vector_store = VectorStorage(directoryOrUrl=uploads, useGPU=useGPU)
vector_store = VectorStorage(directory=uploads, use_gpu=use_gpu)
async def create_chain(
messages: list[Message],
model=model_dir,
max_tokens_to_gen=2048,
temperature=0.5,
prompt: PromptTemplate = QA_PROMPT_TEMPLATE,
file: RAGFile | None = None,
key: str | None = None,
):
print(
f"Creating chain with key={key}, model={model}, max_tokens={max_tokens_to_gen}, temperature={temperature}, prompt={prompt}, file={file.title}"
)
"""Creates the RAG Langchain conversational retrieval chain."""
print("Creating chain ...")
llm = ChatOpenAI(
api_key=openai_api_key,
@ -181,7 +170,7 @@ async def create_chain(
# if llm is VLLMAsync:
# llm.max_tokens = max_tokens_to_gen
retriever = await vector_store.getRetriever()
retriever = await vector_store.get_retriever()
chat_memory = ChatMessageHistory()
for message in messages:
@ -236,26 +225,26 @@ async def create_chain(
router = APIRouter()
@router.post(
"/chat",
summary="Chatbot",
description="Chatbot AI Service",
)
async def chat(request: ChatRequest):
""" Handles chatbot chat POST requests """
chain = await create_chain(
file=request.file,
messages=request.messages[:-1],
model=request.model.id,
max_tokens_to_gen=request.maxTokens,
temperature=request.temperature,
prompt=PromptTemplate.from_template(
f"{B_INST}{B_SYS}{request.prompt.strip()}{E_SYS}{E_INST}"
),
)
json = {
json_config = {
"question": request.messages[-1].content,
"chat_history": [message.content for message in request.messages[:-1]],
"chat_history": [
message.content for message in request.messages[:-1]
],
# "callbacks": [
# StreamingStdOutCallbackHandler(),
# TokenStreamingCallbackHandler(output_key="answer"),
@ -264,178 +253,41 @@ async def chat(request: ChatRequest):
}
return LangchainStreamingResponse(
chain,
config=json,
config=json_config,
)
app.include_router(router, tags=["chat"])
chatbot.include_router(router, tags=["chat"])
@app.get("/")
@chatbot.get("/")
def root():
return {"message": "Chatbot API"}
"""Swarms Chatbot API Root"""
return {"message": "Swarms Chatbot API"}
@app.get("/favicon.ico")
@chatbot.get("/favicon.ico")
def favicon():
""" Returns a favicon """
file_name = "favicon.ico"
file_path = os.path.join(app.root_path, "static", file_name)
file_path = os.path.join(chatbot.root_path, "static", file_name)
return FileResponse(
path=file_path,
headers={"Content-Disposition": "attachment; filename=" + file_name},
headers={
"Content-Disposition": "attachment; filename=" + file_name
},
)
@app.post("/log")
def log_message(log_message: LogMessage):
try:
with open("log.txt", "a") as log_file:
log_file.write(log_message.message + "\n")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error saving log: {e}")
return {"message": "Log saved successfully"}
@app.get("/models")
def get_models():
# llama7B = AIModel(
# id="llama-2-7b-chat-ggml-q4_0",
# name="llama-2-7b-chat-ggml-q4_0",
# maxLength=2048,
# tokenLimit=2048,
# )
# llama13B = AIModel(
# id="llama-2-13b-chat-ggml-q4_0",
# name="llama-2-13b-chat-ggml-q4_0",
# maxLength=2048,
# tokenLimit=2048,
# )
llama70B = AIModel(
id="llama-2-70b.Q5_K_M",
name="llama-2-70b.Q5_K_M",
maxLength=2048,
tokenLimit=2048,
)
models = AIModels(models=[llama70B])
return models
@app.get("/titles")
def getTitles():
titles = RAGFiles(
titles=[
# RAGFile(
# versionId="d8ad3b1d-c33c-4524-9691-e93967d4d863",
# title="d8ad3b1d-c33c-4524-9691-e93967d4d863",
# state=State.Unavailable,
# ),
RAGFile(
versionId=collection.name,
title=collection.name,
state=State.InProcess
if collection.name in processing_books
else State.Processed,
)
for collection in vector_store.list_collections()
if collection.name != "langchain"
]
)
return titles
processing_books: list[str] = []
processing_books_lock = asyncio.Lock()
logging.basicConfig(level=logging.ERROR)
@app.post("/titleState")
async def getTitleState(request: GetRAGFileStateRequest):
# FastAPI + Pydantic will throw a 422 Unprocessable Entity if the request isn't the right type.
# try:
logging.debug(f"Received getTitleState request: {request}")
titleStateRequest: GetRAGFileStateRequest = request
# except ValidationError as e:
# print(f"Error validating JSON: {e}")
# raise HTTPException(status_code=422, detail=str(e))
# except json.JSONDecodeError as e:
# print(f"Error parsing JSON: {e}")
# raise HTTPException(status_code=422, detail="Invalid JSON format")
# check to see if the book has already been processed.
# return the proper State directly to response.
matchingCollection = next(
(
x
for x in vector_store.list_collections()
if x.name == titleStateRequest.versionRef
),
None,
@chatbot.exception_handler(HTTPException)
async def http_exception_handler(r: Request, exc: HTTPException):
"""Log and return exception details in response."""
logging.error(
"HTTPException: %s executing request: %s", exc.detail, r.base_url
)
return JSONResponse(
status_code=exc.status_code, content={"detail": exc.detail}
)
print("Got a Title State request for version " + titleStateRequest.versionRef)
if titleStateRequest.versionRef in processing_books:
return {"message": State.InProcess}
elif matchingCollection is not None:
return {"message": State.Processed}
else:
return {"message": State.Unavailable}
@app.post("/processRAGFile")
async def processRAGFile(
request: str = Form(...),
files: List[UploadFile] = File(...),
):
try:
logging.debug(f"Received processBook request: {request}")
# Parse the JSON string into a ProcessBookRequest object
fileRAGRequest: ProcessRAGFileRequest = parse_obj_as(
ProcessRAGFileRequest, json.loads(request)
)
except ValidationError as e:
print(f"Error validating JSON: {e}")
raise HTTPException(status_code=422, detail=str(e))
except json.JSONDecodeError as e:
print(f"Error parsing JSON: {e}")
raise HTTPException(status_code=422, detail="Invalid JSON format")
try:
print(
f"Processing file {fileRAGRequest.filename} for user {fileRAGRequest.username}."
)
# check to see if the file has already been processed.
# write html to subfolder
print(f"Writing file to path: {fileRAGRequest.username}/{fileRAGRequest.filename}...")
for index, segment in enumerate(files):
filename = segment.filename if segment.filename else str(index)
subDir = f"{fileRAGRequest.username}"
with open(os.path.join(subDir, filename), "wb") as htmlFile:
htmlFile.write(await segment.read())
# write metadata to subfolder
print(f"Writing metadata to subfolder {fileRAGRequest.username}...")
with open(os.path.join({fileRAGRequest.username}, "metadata.json"), "w") as metadataFile:
metaData = {
"filename": fileRAGRequest.filename,
"username": fileRAGRequest.username,
"processDate": datetime.now().isoformat(),
}
metadataFile.write(json.dumps(metaData))
vector_store.retrievers[
f"{fileRAGRequest.username}/{fileRAGRequest.filename}"
] = await vector_store.initRetriever(f"{fileRAGRequest.username}/{fileRAGRequest.filename}")
return {
"message": f"File {fileRAGRequest.filename} processed successfully."
}
except Exception as e:
logging.error(f"Error processing book: {e}")
return {"message": f"Error processing book: {e}"}
@app.exception_handler(HTTPException)
async def http_exception_handler(bookRequest: Request, exc: HTTPException):
logging.error(f"HTTPException: {exc.detail}")
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})

@ -1,9 +1,11 @@
""" Vector storage with RAG (Retrieval Augmented Generation) support for Markdown."""
import asyncio
import json
import os
import glob
from datetime import datetime
from typing import Dict, Literal
from typing import Dict
from chromadb.config import Settings
from langchain.document_loaders.markdown import UnstructuredMarkdownLoader
from langchain.embeddings import HuggingFaceBgeEmbeddings
@ -11,34 +13,39 @@ from langchain.storage import LocalFileStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain.schema import BaseRetriever
from swarms.server.async_parent_document_retriever import AsyncParentDocumentRetriever
from swarms.server.async_parent_document_retriever import (
AsyncParentDocumentRetriever,
)
STORE_TYPE = "local" # "redis" or "local"
store_type = "local" # "redis" or "local"
class VectorStorage:
def __init__(self, directoryOrUrl, useGPU=False):
"""Vector storage class handles loading documents from a given directory."""
def __init__(self, directory, use_gpu=False):
self.embeddings = HuggingFaceBgeEmbeddings(
cache_folder="./.embeddings",
model_name="BAAI/bge-large-en",
model_kwargs={"device": "cuda" if useGPU else "cpu"},
model_kwargs={"device": "cuda" if use_gpu else "cpu"},
encode_kwargs={"normalize_embeddings": True},
query_instruction="Represent this sentence for searching relevant passages: ",
)
self.directoryOrUrl = directoryOrUrl
self.directory = directory
self.child_splitter = RecursiveCharacterTextSplitter(
chunk_size=200, chunk_overlap=20
)
self.parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=2000, chunk_overlap=200
)
if store_type == "redis":
if STORE_TYPE == "redis":
from langchain.storage import RedisStore
from langchain.utilities.redis import get_client
username = r"username"
password = r"password"
client = get_client(
redis_url=f"redis://{username}:{password}@redis-10854.c282.east-us-mz.azure.cloud.redislabs.com:10854"
redis_url=f"redis://{username}:{password}@localhost:6239"
)
self.store = RedisStore(client=client)
else:
@ -49,7 +56,7 @@ class VectorStorage:
anonymized_telemetry=False,
)
# create a new vectorstore or get an existing one, with default collection
self.vectorstore = self.getVectorStore()
self.vectorstore = self.get_vector_store()
self.client = self.vectorstore._client
self.retrievers: Dict[str, BaseRetriever] = {}
# default retriever for when no collection title is specified
@ -57,22 +64,25 @@ class VectorStorage:
str(self.vectorstore._LANGCHAIN_DEFAULT_COLLECTION_NAME)
] = self.vectorstore.as_retriever()
async def initRetrievers(self, directories: list[str] | None = None):
async def init_retrievers(self, directories: list[str] | None = None):
"""Initializes the vector storage retrievers."""
start_time = datetime.now()
print(f"Start vectorstore initialization time: {start_time}")
# for each subdirectory in the directory, create a new collection if it doesn't exist
dirs = directories or os.listdir(self.directoryOrUrl)
dirs = directories or os.listdir(self.directory)
# make sure the subdir is not a file on MacOS (which has a hidden .DS_Store file)
dirs = [
subdir
for subdir in dirs
if not os.path.isfile(f"{self.directoryOrUrl}/{subdir}")
if not os.path.isfile(f"{self.directory}/{subdir}")
]
print(f"{len(dirs)} subdirectories to load: {dirs}")
self.retrievers[self.directoryOrUrl] = await self.initRetriever(self.directoryOrUrl)
self.retrievers[self.directory] = await self.init_retriever(
self.directory
)
end_time = datetime.now()
print("Vectorstore initialization complete.")
print(f"Vectorstore initialization end time: {end_time}")
@ -80,110 +90,140 @@ class VectorStorage:
return self.retrievers
async def initRetriever(self, subdir: str) -> BaseRetriever:
async def init_retriever(self, subdir: str) -> BaseRetriever:
""" Initialize each retriever. """
# Ensure only one process/thread is executing this method at a time
lock = asyncio.Lock()
async with lock:
# subdir_start_time = datetime.now()
# print(f"Start {subdir} processing time: {subdir_start_time}")
# # get all existing collections
# collections = self.client.list_collections()
# print(f"Existing collections: {collections}")
# # Initialize an empty list to hold the documents
# documents = []
# # Define the maximum number of files to load at a time
# max_files = 1000
# # Load existing metadata
# metadata_file = f"{self.directoryOrUrl}/metadata.json"
# metadata = {"processDate": str(datetime.now()), "processed_files": []}
# processed_files = set() # Track processed files
# if os.path.isfile(metadata_file):
# with open(metadata_file, "r") as metadataFile:
# metadata = dict[str, str](json.load(metadataFile))
# processed_files = {entry["file"] for entry in metadata.get("processed_files", [])}
# # Get a list of all files in the directory and exclude processed files
# all_files = [
# file for file in glob.glob(f"{self.directoryOrUrl}/**/*.md", recursive=True)
# if file not in processed_files
# ]
# print(f"Loading {len(all_files)} documents for title version {subdir}.")
# # Load files in chunks of max_files
# for i in range(0, len(all_files), max_files):
# chunksStartTime = datetime.now()
# chunk_files = all_files[i : i + max_files]
# for file in chunk_files:
# loader = UnstructuredMarkdownLoader(
# file,
# mode="single",
# strategy="fast"
# )
# print(f"Loaded {file} in {subdir} ...")
# documents.extend(loader.load())
# # Record the file as processed in metadata
# metadata["processed_files"].append({
# "file": file,
# "processed_at": str(datetime.now())
# })
# print(f"Creating new collection for {self.directoryOrUrl}...")
# # Create or get the collection
# collection = self.client.create_collection(
# name=self.directoryOrUrl,
# get_or_create=True,
# metadata={"processDate": metadata["processDate"]},
# )
# # Reload vectorstore based on collection
# vectorstore = self.getVectorStore(collection_name=self.directoryOrUrl)
# # Create a new parent document retriever
# retriever = AsyncParentDocumentRetriever(
# docstore=self.store,
# vectorstore=vectorstore,
# child_splitter=self.child_splitter,
# parent_splitter=self.parent_splitter,
# )
# # force reload of collection to make sure we don't have the default langchain collection
# collection = self.client.get_collection(name=self.directoryOrUrl)
# vectorstore = self.getVectorStore(collection_name=self.directoryOrUrl)
# # Add documents to the collection and docstore
# print(f"Adding {len(documents)} documents to collection...")
# add_docs_start_time = datetime.now()
# await retriever.aadd_documents(
# documents=documents, add_to_docstore=True
# )
# add_docs_end_time = datetime.now()
# print(
# f"Adding {len(documents)} documents to collection took: {add_docs_end_time - add_docs_start_time}"
# )
# documents = [] # clear documents list for next chunk
# # Save metadata to the metadata.json file
# with open(metadata_file, "w") as metadataFile:
# json.dump(metadata, metadataFile, indent=4)
# print(f"Loaded {len(documents)} documents for directory '{subdir}'.")
# chunksEndTime = datetime.now()
# print(
# f"{max_files} markdown file chunks processing time: {chunksEndTime - chunksStartTime}"
# )
# subdir_end_time = datetime.now()
# print(f"Subdir {subdir} processing end time: {subdir_end_time}")
# print(f"Time taken: {subdir_end_time - subdir_start_time}")
subdir_start_time = datetime.now()
print(f"Start {subdir} processing time: {subdir_start_time}")
# get all existing collections
collections = self.client.list_collections()
print(f"Existing collections: {collections}")
# Initialize an empty list to hold the documents
documents = []
# Define the maximum number of files to load at a time
max_files = 1000
# Load existing metadata
metadata_file = f"{self.directory}/metadata.json"
metadata = {
"processDate": str(datetime.now()),
"processed_files": [],
}
processed_files = set() # Track processed files
if os.path.isfile(metadata_file):
with open(
metadata_file, "r", encoding="utf-8"
) as metadata_file:
metadata = dict[str, str](json.load(metadata_file))
processed_files = {
entry["file"]
for entry in metadata.get("processed_files", [])
}
# Get a list of all files in the directory and exclude processed files
all_files = [
file
for file in glob.glob(
f"{self.directory}/**/*.md", recursive=True
)
if file not in processed_files
]
print(
f"Loading {len(all_files)} documents for title version {subdir}."
)
# Load files in chunks of max_files
for i in range(0, len(all_files), max_files):
chunks_start_time = datetime.now()
chunk_files = all_files[i : i + max_files]
for file in chunk_files:
loader = UnstructuredMarkdownLoader(
file, mode="single", strategy="fast"
)
print(f"Loaded {file} in {subdir} ...")
documents.extend(loader.load())
# Record the file as processed in metadata
metadata["processed_files"].append(
{"file": file, "processed_at": str(datetime.now())}
)
print(
f"Creating new collection for {self.directory}..."
)
# Create or get the collection
collection = self.client.create_collection(
name=self.directory,
get_or_create=True,
metadata={"processDate": metadata["processDate"]},
)
# Reload vectorstore based on collection
vectorstore = self.get_vector_store(
collection_name=collection.name
)
# Create a new parent document retriever
retriever = AsyncParentDocumentRetriever(
docstore=self.store,
vectorstore=vectorstore,
child_splitter=self.child_splitter,
parent_splitter=self.parent_splitter,
)
# force reload of collection to make sure we don't have
# the default langchain collection
collection = self.client.get_collection(
name=self.directory
)
vectorstore = self.get_vector_store(
collection_name=self.directory
)
# Add documents to the collection and docstore
print(
f"Adding {len(documents)} documents to collection..."
)
add_docs_start_time = datetime.now()
await retriever.aadd_documents(
documents=documents, add_to_docstore=True
)
add_docs_end_time = datetime.now()
total_time = add_docs_end_time - add_docs_start_time
print(
f"Adding {len(documents)} documents to collection took: {total_time}"
)
documents = [] # clear documents list for next chunk
# Save metadata to the metadata.json file
with open(
metadata_file, "w", encoding="utf-8"
) as metadata_file:
json.dump(metadata, metadata_file, indent=4)
print(
f"Loaded {len(documents)} documents for directory '{subdir}'."
)
chunks_end_time = datetime.now()
chunk_time = chunks_end_time - chunks_start_time
print(
f"{max_files} markdown file chunks processing time: {chunk_time}"
)
subdir_end_time = datetime.now()
print(
f"Subdir {subdir} processing end time: {subdir_end_time}"
)
print(f"Time taken: {subdir_end_time - subdir_start_time}")
# Reload vectorstore based on collection to pass to parent doc retriever
# collection = self.client.get_collection(name=self.directoryOrUrl)
vectorstore = self.getVectorStore()
# collection = self.client.get_collection(name=self.directory)
vectorstore = self.get_vector_store()
retriever = AsyncParentDocumentRetriever(
docstore=self.store,
vectorstore=vectorstore,
@ -192,8 +232,9 @@ class VectorStorage:
)
return retriever
def getVectorStore(self, collection_name: str | None = None) -> Chroma:
if collection_name is None or "" or "None" :
def get_vector_store(self, collection_name: str | None = None) -> Chroma:
""" get a specific vector store for a collection """
if collection_name is None or "" or "None":
collection_name = "langchain"
print("collection_name: " + collection_name)
vectorstore = Chroma(
@ -204,21 +245,24 @@ class VectorStorage:
return vectorstore
def list_collections(self):
""" Get a list of all collections in the vectorstore """
vectorstore = Chroma(
client_settings=self.settings, embedding_function=self.embeddings
client_settings=self.settings,
embedding_function=self.embeddings,
)
return vectorstore._client.list_collections()
async def getRetriever(self, collection_name: str | None = None):
async def get_retriever(self, collection_name: str | None = None):
""" get a specific retriever for a collection in the vectorstore """
if self.retrievers is None:
self.retrievers = await self.initRetrievers()
self.retrievers = await self.init_retrievers()
if (
collection_name is None
or collection_name == ""
or collection_name == "None"
):
name = str(Chroma._LANGCHAIN_DEFAULT_COLLECTION_NAME)
name = "swarms"
else:
name = collection_name
@ -226,6 +270,8 @@ class VectorStorage:
retriever = self.retrievers[name]
except KeyError:
print(f"Retriever for {name} not found, using default...")
retriever = self.retrievers[Chroma._LANGCHAIN_DEFAULT_COLLECTION_NAME]
retriever = self.retrievers[
"swarms"
]
return retriever

Loading…
Cancel
Save