You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/playground/memory/chroma_db.py

200 lines
5.7 KiB

import logging
11 months ago
import os
12 months ago
import uuid
11 months ago
from typing import Callable, List, Optional
1 year ago
import chromadb
11 months ago
import numpy as np
from dotenv import load_dotenv
12 months ago
from swarms.utils.data_to_text import data_to_text
from swarms.utils.markdown_message import display_markdown_message
9 months ago
from swarms.memory.base_vectordb import BaseVectorDatabase
1 year ago
12 months ago
# Load environment variables from a local .env file at import time (python-dotenv).
# NOTE(review): nothing in this module reads os.environ directly; this appears to
# exist for configuration consumed elsewhere (e.g. API keys) — confirm it is needed here.
load_dotenv()
# Results storage using local ChromaDB
9 months ago
class ChromaDB(BaseVectorDatabase):
    """
    Vector store backed by a persistent, local ChromaDB collection.

    Args:
        metric (str): Distance metric for the collection's HNSW index, stored
            as the ``"hnsw:space"`` collection metadata. Defaults to "cosine".
        output_dir (str): Name of the ChromaDB collection to create or reuse.
            Despite its name this is a collection name, not a filesystem path.
            Defaults to "swarms".
        limit_tokens (int, optional): Stored on the instance; not used by this
            class. Defaults to 1000.
        n_results (int): Default number of documents returned by ``query``.
            Defaults to 2.
        embedding_function (Callable, optional): Embedding function for the
            collection. When omitted, ChromaDB's own default is used.
        docs_folder (str, optional): If given, every file under this directory
            is ingested during construction via ``traverse_directory``.
        verbose (bool): If True, set the ``chromadb`` logger to INFO.
        *args, **kwargs: Forwarded to ``chromadb.PersistentClient``
            (e.g. ``path``).

    Methods:
        add: Add a single record (text and/or images) and return its ID.
        query: Return the documents most similar to a text query.
        traverse_directory: Ingest every file under ``docs_folder``.

    Example:
        db = ChromaDB(metric="cosine", output_dir="results")
        doc_id = db.add("some text")
        docs = db.query("some")
    """

    def __init__(
        self,
        metric: str = "cosine",
        output_dir: str = "swarms",
        limit_tokens: Optional[int] = 1000,
        n_results: int = 2,
        embedding_function: Optional[Callable] = None,
        docs_folder: Optional[str] = None,
        verbose: bool = False,
        *args,
        **kwargs,
    ):
        self.metric = metric
        self.output_dir = output_dir
        self.limit_tokens = limit_tokens
        self.n_results = n_results
        self.docs_folder = docs_folder
        self.verbose = verbose
        self.embedding_function = embedding_function

        # Surface ChromaDB's own log output only when verbose is requested.
        if verbose:
            logging.getLogger("chromadb").setLevel(logging.INFO)

        # The collection lives in this persistent client, so expose it as
        # self.client (previously an unrelated, unused in-memory client).
        # Extra constructor args/kwargs are client options only; they are no
        # longer also sent to get_or_create_collection, where they collided
        # with its parameters.
        chroma_persist_dir = "chroma"
        self.client = chromadb.PersistentClient(
            *args,
            settings=chromadb.config.Settings(
                persist_directory=chroma_persist_dir,
            ),
            **kwargs,
        )

        # Pass an embedding function only when one was supplied: an explicit
        # embedding_function=None disables ChromaDB's default embedder, after
        # which adding plain text documents fails.
        collection_kwargs = {}
        if embedding_function is not None:
            collection_kwargs["embedding_function"] = embedding_function

        self.collection = self.client.get_or_create_collection(
            name=output_dir,
            metadata={"hnsw:space": metric},
            **collection_kwargs,
        )
        display_markdown_message(
            "ChromaDB collection created:"
            f" {self.collection.name} with metric: {self.metric} and"
            f" output directory: {self.output_dir}"
        )

        if docs_folder:
            display_markdown_message(
                f"Traversing directory: {docs_folder}"
            )
            self.traverse_directory()

    def add(
        self,
        document: Optional[str] = None,
        images: List[np.ndarray] = None,
        img_urls: List[str] = None,
        *args,
        **kwargs,
    ):
        """
        Add a single record to the ChromaDB collection.

        Each call generates exactly one ID, so ``images`` / ``img_urls``
        should hold a single entry for that record.

        Args:
            document (str, optional): Text of the record. May be omitted for
                image-only records.
            images (List[np.ndarray], optional): Image data, forwarded as
                ``images=``.
            img_urls (List[str], optional): Image URIs, forwarded as ``uris=``.
            *args: Accepted for interface compatibility and ignored; positional
                values would collide with ChromaDB's ``ids`` argument.
            **kwargs: Forwarded to ``Collection.add`` (e.g. ``metadatas``).

        Returns:
            str: The generated ID of the added record.

        Raises:
            Exception: If ChromaDB rejects the record (original error chained).
        """
        doc_id = str(uuid.uuid4())
        try:
            self.collection.add(
                ids=[doc_id],
                documents=[document] if document is not None else None,
                images=images,
                uris=img_urls,
                **kwargs,
            )
        except Exception as e:
            raise Exception(f"Failed to add document: {str(e)}") from e
        return doc_id

    def query(
        self,
        query_text: str,
        query_images: List[np.ndarray] = None,
        *args,
        **kwargs,
    ):
        """
        Query documents from the ChromaDB collection.

        Args:
            query_text (str): The query string.
            query_images (List[np.ndarray], optional): Images to query with,
                forwarded as ``query_images=``.
            *args: Accepted for interface compatibility and ignored.
            **kwargs: Forwarded to ``Collection.query``. An ``n_results``
                entry overrides ``self.n_results`` for this call.

        Returns:
            list: The documents retrieved for ``query_text``.

        Raises:
            Exception: If the query fails (original error chained).
        """
        n_results = kwargs.pop("n_results", self.n_results)
        try:
            result = self.collection.query(
                query_texts=[query_text],
                query_images=query_images,
                n_results=n_results,
                **kwargs,
            )
            # One query text was sent, so take the first (only) result list.
            return result["documents"][0]
        except Exception as e:
            raise Exception(f"Failed to query documents: {str(e)}") from e

    def traverse_directory(self):
        """
        Ingest every file under ``self.docs_folder``, recursively.

        Files ending in .jpg, .jpeg or .png (case-insensitive) are added by
        path as image URIs, one record each. Every other file is converted
        with ``data_to_text`` and added as a text document.

        Returns:
            str | bool: The ID returned by the last ``add`` call, or ``False``
            if no files were found.
        """
        image_extensions = {".jpg", ".jpeg", ".png"}
        added_to_db = False
        image_paths = []

        for root, _dirs, files in os.walk(self.docs_folder):
            for file in files:
                # os.walk yields names relative to ``root``; join them so the
                # file is found regardless of the current working directory.
                path = os.path.join(root, file)
                if os.path.splitext(file)[1].lower() in image_extensions:
                    image_paths.append(path)
                    continue
                # Assumes data_to_text accepts a file path — confirm.
                data = data_to_text(path)
                added_to_db = self.add(data)
                print(f"{file} added to Database")

        # NOTE(review): adding by URI only works if the collection can load
        # and embed the URIs (ChromaDB needs a data loader for that; none is
        # configured in __init__) — confirm image ingestion in your version.
        for path in image_paths:
            added_to_db = self.add(img_urls=[path])
        if image_paths:
            print(f"{len(image_paths)} images added to Database ")
        return added_to_db