You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/playground/memory/chromadb_example.py

187 lines
5.3 KiB

import logging
11 months ago
import os
11 months ago
import uuid
7 months ago
from typing import Optional
1 year ago
import chromadb
from dotenv import load_dotenv
11 months ago
from swarms.utils.data_to_text import data_to_text
from swarms.utils.markdown_message import display_markdown_message
7 months ago
from swarms.memory.base_vectordb import BaseVectorDatabase
1 year ago
11 months ago
# Load variables from a local .env file (if one exists) into the process
# environment at import time.
load_dotenv()

# Results storage using local ChromaDB
class ChromaDB(BaseVectorDatabase):
    """
    Vector store backed by a persistent local ChromaDB collection.

    Args:
        metric (str): Distance function for the collection's index, stored
            as the ``"hnsw:space"`` collection metadata. Defaults to
            ``"cosine"``.
        output_dir (str): Name of the ChromaDB collection to create or reuse.
            Defaults to ``"swarms"``.
        limit_tokens (int, optional): Stored as ``self.limit_tokens``; not
            used by any method of this class. Defaults to 1000.
        n_results (int): Number of documents returned by :meth:`query`.
            Defaults to 1.
        docs_folder (str, optional): If given, every file under this
            directory is converted to text and added during construction.
        verbose (bool): If True, set the ``"chromadb"`` logger to INFO.
        **kwargs: Forwarded to ``get_or_create_collection``.

    Raises:
        TypeError: If extra positional arguments are supplied.

    Methods:
        add: Store a document and return its generated id.
        query: Return the closest stored documents for a text query.
        traverse_directory: Add every file under ``docs_folder``.

    Examples:
        >>> db = ChromaDB(metric="cosine", output_dir="results")
        >>> doc_id = db.add("some text")
        >>> db.query("text")
    """

    def __init__(
        self,
        metric: str = "cosine",
        output_dir: str = "swarms",
        limit_tokens: Optional[int] = 1000,
        n_results: int = 1,
        docs_folder: Optional[str] = None,
        verbose: bool = False,
        *args,
        **kwargs,
    ):
        if args:
            # Positional extras cannot be routed unambiguously; fail loudly.
            raise TypeError(
                "ChromaDB() got unexpected positional arguments; pass extra "
                "collection options as keyword arguments"
            )

        self.metric = metric
        self.output_dir = output_dir
        self.limit_tokens = limit_tokens
        self.n_results = n_results
        self.docs_folder = docs_folder
        self.verbose = verbose

        # Raise ChromaDB's log verbosity only on request; otherwise the
        # "chromadb" logger keeps whatever level it already has.
        if verbose:
            logging.getLogger("chromadb").setLevel(logging.INFO)

        # On-disk client rooted at ./chroma.
        chroma_persist_dir = "chroma"
        chroma_client = chromadb.PersistentClient(path=chroma_persist_dir)

        # Expose the same client that owns ``self.collection``.
        self.client = chroma_client

        # Extra keyword options (if any) go to the collection factory only.
        self.collection = chroma_client.get_or_create_collection(
            name=output_dir,
            metadata={"hnsw:space": metric},
            **kwargs,
        )
        display_markdown_message(
            "ChromaDB collection created:"
            f" {self.collection.name} with metric: {self.metric} and"
            f" output directory: {self.output_dir}"
        )

        # Optionally pre-load every file found under docs_folder.
        if docs_folder:
            display_markdown_message(
                f"Traversing directory: {docs_folder}"
            )
            self.traverse_directory()

    def add(
        self,
        document: str,
        *args,
        **kwargs,
    ):
        """
        Add a document to the ChromaDB collection under a fresh UUID4 id.

        Args:
            document (str): The document text to store.
            *args, **kwargs: Forwarded to ``Collection.add``.

        Returns:
            str: The id assigned to the added document.

        Raises:
            Exception: If ``Collection.add`` fails; the underlying error is
                chained as ``__cause__``.
        """
        doc_id = str(uuid.uuid4())
        try:
            self.collection.add(
                ids=[doc_id],
                documents=[document],
                *args,
                **kwargs,
            )
        except Exception as e:
            raise Exception(f"Failed to add document: {str(e)}") from e

        print("-----------------")
        print("Document added successfully")
        print("-----------------")
        return doc_id

    def query(
        self,
        query_text: str,
        *args,
        **kwargs,
    ) -> str:
        """
        Return the collection's closest matches for ``query_text`` as text.

        Args:
            query_text (str): The text to search for.
            *args, **kwargs: Forwarded to ``Collection.query``.

        Returns:
            str: Up to ``self.n_results`` matched documents, each followed by
            a newline; an empty string when nothing matched.

        Raises:
            Exception: If ``Collection.query`` fails; the underlying error is
                chained as ``__cause__``.
        """
        logging.info("Querying documents for: %s", query_text)
        try:
            results = self.collection.query(
                query_texts=[query_text],
                n_results=self.n_results,
                *args,
                **kwargs,
            )
        except Exception as e:
            raise Exception(f"Failed to query documents: {str(e)}") from e

        # ``documents`` holds one list of hits per query text. Exactly one
        # query text is sent, so use the first inner list. Treat a missing
        # value (None) as "no hits".
        per_query = results.get("documents") or []
        hits = per_query[0] if per_query else []

        # Convert into a string, one newline-terminated entry per document.
        out = "".join(f"{doc}\n" for doc in hits)

        # Display the retrieved document
        display_markdown_message(f"Query: {query_text}")
        display_markdown_message(f"Retrieved Document: {out}")
        return out

    def traverse_directory(self):
        """
        Add every file under ``self.docs_folder`` (recursively) to the collection.

        Each file is converted with ``data_to_text`` and stored via
        :meth:`add`; a line is printed for every file added.

        Returns:
            str | bool: The id returned by :meth:`add` for the last file
            added, or ``False`` if no files were found.
        """
        added_to_db = False
        for root, _dirs, files in os.walk(self.docs_folder):
            for file in files:
                file_path = os.path.join(root, file)
                data = data_to_text(file_path)
                added_to_db = self.add(str(data))
                print(f"{file_path} added to Database")
        return added_to_db