From 593b9b104efc21451625f2055ffeb0ef2e01493a Mon Sep 17 00:00:00 2001 From: Kye Date: Sun, 4 Feb 2024 08:34:37 -0800 Subject: [PATCH] [FEAT][Chroma] --- playground/structs/agent_with_longterm.py | 39 +++++ pyproject.toml | 2 +- swarms/memory/__init__.py | 2 + swarms/memory/chroma_db.py | 203 ++++++++++------------ swarms/telemetry/posthog_utils.py | 122 +++++++------ 5 files changed, 195 insertions(+), 173 deletions(-) create mode 100644 playground/structs/agent_with_longterm.py diff --git a/playground/structs/agent_with_longterm.py b/playground/structs/agent_with_longterm.py new file mode 100644 index 00000000..e803d095 --- /dev/null +++ b/playground/structs/agent_with_longterm.py @@ -0,0 +1,39 @@ +import os + +from dotenv import load_dotenv + +# Import the OpenAIChat model and the Agent struct +from swarms import Agent, OpenAIChat, ChromaDB + +# Load the environment variables +load_dotenv() + +# Get the API key from the environment +api_key = os.environ.get("OPENAI_API_KEY") + + +# Initialize the chromadb client +chromadb = ChromaDB( + metric="cosine", + output_dir="results", +) + +# Initialize the language model +llm = OpenAIChat( + temperature=0.5, + model_name="gpt-4", + openai_api_key=api_key, + max_tokens=1000, +) + +## Initialize the workflow +agent = Agent( + llm=llm, + max_loops=4, + autosave=True, + dashboard=True, + long_term_memory=chromadb, +) + +# Run the workflow on a task +agent.run("Generate a 10,000 word blog on health and wellness.") diff --git a/pyproject.toml b/pyproject.toml index 2a89e00f..9e81bf20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "4.0.3" +version = "4.0.4" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] diff --git a/swarms/memory/__init__.py b/swarms/memory/__init__.py index 4d58ea5a..2dca8172 100644 --- a/swarms/memory/__init__.py +++ b/swarms/memory/__init__.py @@ -5,6 +5,7 @@ from 
swarms.memory.sqlite import SQLiteDB from swarms.memory.weaviate_db import WeaviateDB from swarms.memory.visual_memory import VisualShortTermMemory from swarms.memory.action_subtask import ActionSubtaskEntry +from swarms.memory.chroma_db import ChromaDB __all__ = [ "AbstractVectorDatabase", @@ -14,4 +15,5 @@ __all__ = [ "WeaviateDB", "VisualShortTermMemory", "ActionSubtaskEntry", + "ChromaDB", ] diff --git a/swarms/memory/chroma_db.py b/swarms/memory/chroma_db.py index 8e200974..3d355b4f 100644 --- a/swarms/memory/chroma_db.py +++ b/swarms/memory/chroma_db.py @@ -1,30 +1,19 @@ +import numpy as np import logging -import os -from typing import Dict, List, Optional +import uuid +from typing import Optional, Callable, List import chromadb -import tiktoken as tiktoken -from chromadb.config import Settings -from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction from dotenv import load_dotenv -from termcolor import colored +from chromadb.utils.data_loaders import ImageLoader +from chromadb.utils.embedding_functions import ( + OpenCLIPEmbeddingFunction, +) -from swarms.utils.token_count_tiktoken import limit_tokens_from_string +# Load environment variables load_dotenv() -# ChromaDB settings -client = chromadb.Client(Settings(anonymized_telemetry=False)) - - -# ChromaDB client -def get_chromadb_client(): - return client - - -# OpenAI API key -OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") - # Results storage using local ChromaDB class ChromaDB: @@ -33,10 +22,10 @@ class ChromaDB: ChromaDB database Args: - metric (str): _description_ - RESULTS_STORE_NAME (str): _description_ - LLM_MODEL (str): _description_ - openai_api_key (str): _description_ + metric (str): The similarity metric to use. + output (str): The name of the collection to store the results in. + limit_tokens (int, optional): The maximum number of tokens to use for the query. Defaults to 1000. + n_results (int, optional): The number of results to retrieve. Defaults to 2. 
Methods: add: _description_ @@ -45,135 +34,129 @@ class ChromaDB: Examples: >>> chromadb = ChromaDB( >>> metric="cosine", - >>> RESULTS_STORE_NAME="results", - >>> LLM_MODEL="gpt3", + >>> output="results", + >>> llm="gpt3", >>> openai_api_key=OPENAI_API_KEY, >>> ) >>> chromadb.add(task, result, result_id) - >>> chromadb.query(query, top_results_num) """ def __init__( self, metric: str, - RESULTS_STORE_NAME: str, - LLM_MODEL: str, - openai_api_key: str = OPENAI_API_KEY, - top_results_num: int = 3, + output_dir: str, limit_tokens: Optional[int] = 1000, + n_results: int = 2, + embedding_function: Callable = None, + data_loader: Callable = None, + multimodal: bool = False, *args, **kwargs, ): self.metric = metric - self.RESULTS_STORE_NAME = RESULTS_STORE_NAME - self.LLM_MODEL = LLM_MODEL - self.openai_api_key = openai_api_key - self.top_results_num = top_results_num + self.output_dir = output_dir self.limit_tokens = limit_tokens + self.n_results = n_results # Disable ChromaDB logging - logging.getLogger("chromadb").setLevel(logging.ERROR) + logging.getLogger("chromadb").setLevel(logging.INFO) + # Create Chroma collection chroma_persist_dir = "chroma" chroma_client = chromadb.PersistentClient( settings=chromadb.config.Settings( persist_directory=chroma_persist_dir, - ) + ), + *args, + **kwargs, ) - # Create embedding function - embedding_function = OpenAIEmbeddingFunction( - api_key=openai_api_key - ) + # Data loader + if data_loader: + self.data_loader = data_loader + else: + self.data_loader = ImageLoader() + + # Embedding model + if embedding_function: + self.embedding_function = embedding_function + else: + self.embedding_function = None + + # If multimodal set the embedding model to OpenCLIP + if multimodal: + self.embedding_function = OpenCLIPEmbeddingFunction() + + # Create ChromaDB client + self.client = chromadb.Client() # Create Chroma collection self.collection = chroma_client.get_or_create_collection( - name=RESULTS_STORE_NAME, + name=output_dir, 
metadata={"hnsw:space": metric}, - embedding_function=embedding_function, + embedding_function=self.embedding_function, + data_loader=self.data_loader, + *args, + **kwargs, ) def add( - self, task: Dict, result: str, result_id: str, *args, **kwargs + self, + document: str, + images: List[np.ndarray] = None, + img_urls: List[str] = None, + *args, + **kwargs, ): - """Adds a result to the ChromaDB collection + """ + Add a document to the ChromaDB collection. Args: - task (Dict): _description_ - result (str): _description_ - result_id (str): _description_ - """ + document (str): The document to be added. + condition (bool, optional): The condition to check before adding the document. Defaults to True. + Returns: + str: The ID of the added document. + """ try: - # Embed the result - embeddings = ( - self.collection.embedding_function.embed([result])[0] - .tolist() - .copy() - ) - - # If the result is a list, flatten it - if ( - len( - self.collection.get(ids=[result_id], include=[])[ - "ids" - ] - ) - > 0 - ): # Check if the result already exists - self.collection.update( - ids=result_id, - embeddings=embeddings, - documents=result, - metadatas={ - "task": task["task_name"], - "result": result, - }, - ) - - # If the result is not a list, add it - else: - self.collection.add( - ids=result_id, - embeddings=embeddings, - documents=result, - metadatas={ - "task": task["task_name"], - "result": result, - }, - *args, - **kwargs, - ) - except Exception as error: - print( - colored(f"Error adding to ChromaDB: {error}", "red") + doc_id = str(uuid.uuid4()) + self.collection.add( + ids=[doc_id], + documents=[document], + images=images, + uris=img_urls, + *args, + **kwargs, ) + return doc_id + except Exception as e: + raise Exception(f"Failed to add document: {str(e)}") - def query(self, query: str, *args, **kwargs) -> List[dict]: - """Queries the ChromaDB collection with a query for the top results + def query( + self, + query_text: str, + query_images: List[np.ndarray], + *args, 
+ **kwargs, + ): + """ + Query documents from the ChromaDB collection. Args: - query (str): _description_ - top_results_num (int): _description_ + query_text (str): The query string. + query_images (List[np.ndarray]): The query images. Returns: - List[dict]: _description_ + dict: The retrieved documents. """ try: - count: int = self.collection.count() - if count == 0: - return [] - results = self.collection.query( - query_texts=query, - n_results=min(self.top_results_num, count), - include=["metadatas"], + docs = self.collection.query( + query_texts=[query_text], + query_images=query_images, + n_results=self.n_results, *args, **kwargs, - ) - out = [item["task"] for item in results["metadatas"][0]] - out = limit_tokens_from_string( - out, "gpt-4", self.limit_tokens - ) - return out - except Exception as error: - print(colored(f"Error querying ChromaDB: {error}", "red")) + )["documents"] + return docs[0] + except Exception as e: + raise Exception(f"Failed to query documents: {str(e)}") diff --git a/swarms/telemetry/posthog_utils.py b/swarms/telemetry/posthog_utils.py index a6a520b5..66bd0e6c 100644 --- a/swarms/telemetry/posthog_utils.py +++ b/swarms/telemetry/posthog_utils.py @@ -1,69 +1,67 @@ -import functools -import os +import logging from dotenv import load_dotenv from posthog import Posthog -from swarms.telemetry.user_utils import generate_unique_identifier + # Load environment variables load_dotenv() - -# # Initialize Posthog client -api_key = os.getenv("POSTHOG_API_KEY") or None -host = os.getenv("POSTHOG_HOST") or None -posthog = Posthog(api_key, host=host) -posthog.debug = True - -# return posthog - - -def log_activity_posthog(event_name: str, **event_properties): - """Log activity to Posthog. - - - Args: - event_name (str): Name of the event to log. - **event_properties: Properties of the event to log. 
- - Examples: - >>> from swarms.telemetry.posthog_utils import log_activity_posthog - >>> @log_activity_posthog("test_event", test_property="test_value") - ... def test_function(): - ... print("Hello, world!") - >>> test_function() - Hello, world! - >>> # Check Posthog dashboard for event "test_event" with property - >>> # "test_property" set to "test_value". - """ - - def decorator_log_activity(func): - @functools.wraps(func) - def wrapper_log_activity(*args, **kwargs): - result = func(*args, **kwargs) - - # Assuming you have a way to get the user id - distinct_user_id = generate_unique_identifier() - - # Capture the event - posthog.capture( - distinct_user_id, event_name, event_properties - ) - - return result - - return wrapper_log_activity - - return decorator_log_activity - - -# @log_activity_posthog( -# "function_executed", function_name="my_function" -# ) -# def my_function(): -# # Function logic here -# return "Function executed successfully!" - - -# out = my_function() -# print(out) +logger = logging.getLogger(__name__) + + +class PosthogWrapper: + def __init__( + self, api_key, instance_address, debug=False, disabled=False + ): + self.posthog = Posthog(api_key, host=instance_address) + self.posthog.debug = debug + self.posthog.disabled = disabled + + def capture_event(self, distinct_id, event_name, properties=None): + self.posthog.capture(distinct_id, event_name, properties) + + def capture_pageview(self, distinct_id, url): + self.posthog.capture( + distinct_id, "$pageview", {"$current_url": url} + ) + + def set_user_properties( + self, distinct_id, event_name, properties + ): + self.posthog.capture( + distinct_id, event=event_name, properties=properties + ) + + def is_feature_enabled( + self, flag_key, distinct_id, send_feature_flag_events=True + ): + return self.posthog.feature_enabled( + flag_key, distinct_id, send_feature_flag_events + ) + + def get_feature_flag_payload(self, flag_key, distinct_id): + return self.posthog.get_feature_flag_payload( + 
flag_key, distinct_id + ) + + def get_feature_flag(self, flag_key, distinct_id): + return self.posthog.get_feature_flag(flag_key, distinct_id) + + def capture_with_feature_flag( + self, distinct_id, event_name, flag_key, variant_key + ): + self.posthog.capture( + distinct_id, + event_name, + {"$feature/{}".format(flag_key): variant_key}, + ) + + def capture_with_feature_flags( + self, distinct_id, event_name, send_feature_flags=True + ): + self.posthog.capture( + distinct_id, + event_name, + send_feature_flags=send_feature_flags, + )