You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/playground/demos/agent_in_5/chroma_db.py

186 lines
5.3 KiB

6 months ago
import logging
import os
import uuid
from typing import List, Optional
import chromadb
import numpy as np
from dotenv import load_dotenv
from swarms.utils.data_to_text import data_to_text
from swarms.utils.markdown_message import display_markdown_message
# Load environment variables
load_dotenv()
# Results storage using local ChromaDB
class ChromaDB:
"""
ChromaDB database
Args:
metric (str): The similarity metric to use.
output (str): The name of the collection to store the results in.
limit_tokens (int, optional): The maximum number of tokens to use for the query. Defaults to 1000.
n_results (int, optional): The number of results to retrieve. Defaults to 2.
Methods:
add: _description_
query: _description_
Examples:
>>> chromadb = ChromaDB(
>>> metric="cosine",
>>> output="results",
>>> llm="gpt3",
>>> openai_api_key=OPENAI_API_KEY,
>>> )
>>> chromadb.add(task, result, result_id)
"""
def __init__(
self,
metric: str = "cosine",
output_dir: str = "swarms",
limit_tokens: Optional[int] = 1000,
n_results: int = 2,
docs_folder: Optional[str] = None,
verbose: bool = False,
*args,
**kwargs,
):
self.metric = metric
self.output_dir = output_dir
self.limit_tokens = limit_tokens
self.n_results = n_results
self.docs_folder = docs_folder
self.verbose = verbose
# Disable ChromaDB logging
if verbose:
logging.getLogger("chromadb").setLevel(logging.INFO)
# Create Chroma collection
chroma_persist_dir = "chroma"
chroma_client = chromadb.PersistentClient(
settings=chromadb.config.Settings(
persist_directory=chroma_persist_dir,
),
*args,
**kwargs,
)
# Create ChromaDB client
self.client = chromadb.Client()
# Create Chroma collection
self.collection = chroma_client.get_or_create_collection(
name=output_dir,
metadata={"hnsw:space": metric},
*args,
**kwargs,
)
display_markdown_message(
"ChromaDB collection created:"
f" {self.collection.name} with metric: {self.metric} and"
f" output directory: {self.output_dir}"
)
# If docs
if docs_folder:
display_markdown_message(
f"Traversing directory: {docs_folder}"
)
self.traverse_directory()
def add(
self,
document: str,
images: List[np.ndarray] = None,
img_urls: List[str] = None,
*args,
**kwargs,
):
"""
Add a document to the ChromaDB collection.
Args:
document (str): The document to be added.
condition (bool, optional): The condition to check before adding the document. Defaults to True.
Returns:
str: The ID of the added document.
"""
try:
doc_id = str(uuid.uuid4())
self.collection.add(
ids=[doc_id],
documents=[document],
images=images,
uris=img_urls,
*args,
**kwargs,
)
return doc_id
except Exception as e:
raise Exception(f"Failed to add document: {str(e)}")
def query(
self,
query_text: str,
query_images: List[np.ndarray],
*args,
**kwargs,
):
"""
Query documents from the ChromaDB collection.
Args:
query (str): The query string.
n_docs (int, optional): The number of documents to retrieve. Defaults to 1.
Returns:
dict: The retrieved documents.
"""
try:
docs = self.collection.query(
query_texts=[query_text],
query_images=query_images,
n_results=self.n_docs,
*args,
**kwargs,
)["documents"]
return docs[0]
except Exception as e:
raise Exception(f"Failed to query documents: {str(e)}")
def traverse_directory(self):
"""
Traverse through every file in the given directory and its subdirectories,
and return the paths of all files.
Parameters:
- directory_name (str): The name of the directory to traverse.
Returns:
- list: A list of paths to each file in the directory and its subdirectories.
"""
image_extensions = [
".jpg",
".jpeg",
".png",
]
images = []
for root, dirs, files in os.walk(self.docs_folder):
for file in files:
_, ext = os.path.splitext(file)
if ext.lower() in image_extensions:
images.append(os.path.join(root, file))
else:
data = data_to_text(file)
added_to_db = self.add([data])
print(f"{file} added to Database")
if images:
added_to_db = self.add(img_urls=[images])
print(f"{len(images)} images added to Database ")
return added_to_db