[FEAT][Chroma]

pull/378/merge
Kye 11 months ago
parent 075f6320e1
commit 593b9b104e

@ -0,0 +1,39 @@
import os
from dotenv import load_dotenv
# Import the OpenAIChat model and the Agent struct
from swarms import Agent, OpenAIChat, ChromaDB
# Load the environment variables
load_dotenv()
# Get the API key from the environment
api_key = os.environ.get("OPENAI_API_KEY")
# Initilaize the chromadb client
chromadb = ChromaDB(
metric="cosine",
output="results",
)
# Initialize the language model
llm = OpenAIChat(
temperature=0.5,
model_name="gpt-4",
openai_api_key=api_key,
max_tokens=1000,
)
## Initialize the workflow
agent = Agent(
llm=llm,
max_loops=4,
autosave=True,
dashboard=True,
long_term_memory=ChromaDB(),
)
# Run the workflow on a task
agent.run("Generate a 10,000 word blog on health and wellness.")

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "swarms"
version = "4.0.3"
version = "4.0.4"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]

@ -5,6 +5,7 @@ from swarms.memory.sqlite import SQLiteDB
from swarms.memory.weaviate_db import WeaviateDB
from swarms.memory.visual_memory import VisualShortTermMemory
from swarms.memory.action_subtask import ActionSubtaskEntry
from swarms.memory.chroma_db import ChromaDB
__all__ = [
"AbstractVectorDatabase",
@ -14,4 +15,5 @@ __all__ = [
"WeaviateDB",
"VisualShortTermMemory",
"ActionSubtaskEntry",
"ChromaDB",
]

@ -1,30 +1,19 @@
import numpy as np
import logging
import os
from typing import Dict, List, Optional
import uuid
from typing import Optional, Callable, List
import chromadb
import tiktoken as tiktoken
from chromadb.config import Settings
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
from dotenv import load_dotenv
from termcolor import colored
from chromadb.utils.data_loaders import ImageLoader
from chromadb.utils.embedding_functions import (
OpenCLIPEmbeddingFunction,
)
from swarms.utils.token_count_tiktoken import limit_tokens_from_string
# Load environment variables
load_dotenv()
# ChromaDB settings
client = chromadb.Client(Settings(anonymized_telemetry=False))
# ChromaDB client
def get_chromadb_client():
return client
# OpenAI API key
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Results storage using local ChromaDB
class ChromaDB:
@ -33,10 +22,10 @@ class ChromaDB:
ChromaDB database
Args:
metric (str): _description_
RESULTS_STORE_NAME (str): _description_
LLM_MODEL (str): _description_
openai_api_key (str): _description_
metric (str): The similarity metric to use.
output (str): The name of the collection to store the results in.
limit_tokens (int, optional): The maximum number of tokens to use for the query. Defaults to 1000.
n_results (int, optional): The number of results to retrieve. Defaults to 2.
Methods:
add: _description_
@ -45,135 +34,129 @@ class ChromaDB:
Examples:
>>> chromadb = ChromaDB(
>>> metric="cosine",
>>> RESULTS_STORE_NAME="results",
>>> LLM_MODEL="gpt3",
>>> output="results",
>>> llm="gpt3",
>>> openai_api_key=OPENAI_API_KEY,
>>> )
>>> chromadb.add(task, result, result_id)
>>> chromadb.query(query, top_results_num)
"""
def __init__(
self,
metric: str,
RESULTS_STORE_NAME: str,
LLM_MODEL: str,
openai_api_key: str = OPENAI_API_KEY,
top_results_num: int = 3,
output_dir: str,
limit_tokens: Optional[int] = 1000,
n_results: int = 2,
embedding_function: Callable = None,
data_loader: Callable = None,
multimodal: bool = False,
*args,
**kwargs,
):
self.metric = metric
self.RESULTS_STORE_NAME = RESULTS_STORE_NAME
self.LLM_MODEL = LLM_MODEL
self.openai_api_key = openai_api_key
self.top_results_num = top_results_num
self.output_dir = output_dir
self.limit_tokens = limit_tokens
self.n_results = n_results
# Disable ChromaDB logging
logging.getLogger("chromadb").setLevel(logging.ERROR)
logging.getLogger("chromadb").setLevel(logging.INFO)
# Create Chroma collection
chroma_persist_dir = "chroma"
chroma_client = chromadb.PersistentClient(
settings=chromadb.config.Settings(
persist_directory=chroma_persist_dir,
)
),
*args,
**kwargs,
)
# Create embedding function
embedding_function = OpenAIEmbeddingFunction(
api_key=openai_api_key
)
# Data loader
if data_loader:
self.data_loader = data_loader
else:
self.data_loader = ImageLoader()
# Embedding model
if embedding_function:
self.embedding_function = embedding_function
else:
self.embedding_function = None
# If multimodal set the embedding model to OpenCLIP
if multimodal:
self.embedding_function = OpenCLIPEmbeddingFunction()
# Create ChromaDB client
self.client = chromadb.Client()
# Create Chroma collection
self.collection = chroma_client.get_or_create_collection(
name=RESULTS_STORE_NAME,
name=output_dir,
metadata={"hnsw:space": metric},
embedding_function=embedding_function,
embedding_function=self.embedding_function,
data_loader=self.data_loader,
*args,
**kwargs,
)
def add(
self, task: Dict, result: str, result_id: str, *args, **kwargs
self,
document: str,
images: List[np.ndarray] = None,
img_urls: List[str] = None,
*args,
**kwargs,
):
"""Adds a result to the ChromaDB collection
"""
Add a document to the ChromaDB collection.
Args:
task (Dict): _description_
result (str): _description_
result_id (str): _description_
"""
document (str): The document to be added.
condition (bool, optional): The condition to check before adding the document. Defaults to True.
Returns:
str: The ID of the added document.
"""
try:
# Embed the result
embeddings = (
self.collection.embedding_function.embed([result])[0]
.tolist()
.copy()
)
# If the result is a list, flatten it
if (
len(
self.collection.get(ids=[result_id], include=[])[
"ids"
]
)
> 0
): # Check if the result already exists
self.collection.update(
ids=result_id,
embeddings=embeddings,
documents=result,
metadatas={
"task": task["task_name"],
"result": result,
},
)
# If the result is not a list, add it
else:
doc_id = str(uuid.uuid4())
self.collection.add(
ids=result_id,
embeddings=embeddings,
documents=result,
metadatas={
"task": task["task_name"],
"result": result,
},
ids=[doc_id],
documents=[document],
images=images,
uris=img_urls,
*args,
**kwargs,
)
except Exception as error:
print(
colored(f"Error adding to ChromaDB: {error}", "red")
)
return doc_id
except Exception as e:
raise Exception(f"Failed to add document: {str(e)}")
def query(self, query: str, *args, **kwargs) -> List[dict]:
"""Queries the ChromaDB collection with a query for the top results
def query(
self,
query_text: str,
query_images: List[np.ndarray],
*args,
**kwargs,
):
"""
Query documents from the ChromaDB collection.
Args:
query (str): _description_
top_results_num (int): _description_
query (str): The query string.
n_docs (int, optional): The number of documents to retrieve. Defaults to 1.
Returns:
List[dict]: _description_
dict: The retrieved documents.
"""
try:
count: int = self.collection.count()
if count == 0:
return []
results = self.collection.query(
query_texts=query,
n_results=min(self.top_results_num, count),
include=["metadatas"],
docs = self.collection.query(
query_texts=[query_text],
query_images=query_images,
n_results=self.n_docs,
*args,
**kwargs,
)
out = [item["task"] for item in results["metadatas"][0]]
out = limit_tokens_from_string(
out, "gpt-4", self.limit_tokens
)
return out
except Exception as error:
print(colored(f"Error querying ChromaDB: {error}", "red"))
)["documents"]
return docs[0]
except Exception as e:
raise Exception(f"Failed to query documents: {str(e)}")

@ -1,69 +1,67 @@
import functools
import os
import logging
from dotenv import load_dotenv
from posthog import Posthog
from swarms.telemetry.user_utils import generate_unique_identifier
# Load environment variables
load_dotenv()
logger = logging.getLogger(__name__)
# # Initialize Posthog client
api_key = os.getenv("POSTHOG_API_KEY") or None
host = os.getenv("POSTHOG_HOST") or None
posthog = Posthog(api_key, host=host)
posthog.debug = True
# return posthog
def log_activity_posthog(event_name: str, **event_properties):
"""Log activity to Posthog.
Args:
event_name (str): Name of the event to log.
**event_properties: Properties of the event to log.
Examples:
>>> from swarms.telemetry.posthog_utils import log_activity_posthog
>>> @log_activity_posthog("test_event", test_property="test_value")
... def test_function():
... print("Hello, world!")
>>> test_function()
Hello, world!
>>> # Check Posthog dashboard for event "test_event" with property
>>> # "test_property" set to "test_value".
"""
def decorator_log_activity(func):
@functools.wraps(func)
def wrapper_log_activity(*args, **kwargs):
result = func(*args, **kwargs)
class PosthogWrapper:
def __init__(
self, api_key, instance_address, debug=False, disabled=False
):
self.posthog = Posthog(api_key, host=instance_address)
self.posthog.debug = debug
self.posthog.disabled = disabled
# Assuming you have a way to get the user id
distinct_user_id = generate_unique_identifier()
def capture_event(self, distinct_id, event_name, properties=None):
self.posthog.capture(distinct_id, event_name, properties)
# Capture the event
posthog.capture(
distinct_user_id, event_name, event_properties
def capture_pageview(self, distinct_id, url):
self.posthog.capture(
distinct_id, "$pageview", {"$current_url": url}
)
return result
return wrapper_log_activity
def set_user_properties(
self, distinct_id, event_name, properties
):
self.posthog.capture(
distinct_id, event=event_name, properties=properties
)
return decorator_log_activity
def is_feature_enabled(
self, flag_key, distinct_id, send_feature_flag_events=True
):
return self.posthog.feature_enabled(
flag_key, distinct_id, send_feature_flag_events
)
def get_feature_flag_payload(self, flag_key, distinct_id):
return self.posthog.get_feature_flag_payload(
flag_key, distinct_id
)
# @log_activity_posthog(
# "function_executed", function_name="my_function"
# )
# def my_function():
# # Function logic here
# return "Function executed successfully!"
def get_feature_flag(self, flag_key, distinct_id):
return self.posthog.get_feature_flag(flag_key, distinct_id)
def capture_with_feature_flag(
self, distinct_id, event_name, flag_key, variant_key
):
self.posthog.capture(
distinct_id,
event_name,
{"$feature/{}".format(flag_key): variant_key},
)
# out = my_function()
# print(out)
def capture_with_feature_flags(
self, distinct_id, event_name, send_feature_flags=True
):
self.posthog.capture(
distinct_id,
event_name,
send_feature_flags=send_feature_flags,
)

Loading…
Cancel
Save