parent
02cecfc281
commit
498ff905b0
@@ -0,0 +1,28 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from swarms.models.revgptV4 import RevChatGPTModel
|
||||
from swarms.workers.worker import Worker
|
||||
|
||||
load_dotenv()
|
||||
|
||||
config = {
|
||||
"model": os.getenv("REVGPT_MODEL"),
|
||||
"plugin_ids": [os.getenv("REVGPT_PLUGIN_IDS")],
|
||||
"disable_history": os.getenv("REVGPT_DISABLE_HISTORY") == "True",
|
||||
"PUID": os.getenv("REVGPT_PUID"),
|
||||
"unverified_plugin_domains": [
|
||||
os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS")
|
||||
],
|
||||
}
|
||||
|
||||
llm = RevChatGPTModel(access_token=os.getenv("ACCESS_TOKEN"), **config)
|
||||
|
||||
worker = Worker(ai_name="Optimus Prime", llm=llm)
|
||||
|
||||
task = (
|
||||
"What were the winning boston marathon times for the past 5 years (ending"
|
||||
" in 2022)? Generate a table of the year, name, country of origin, and"
|
||||
" times."
|
||||
)
|
||||
response = worker.run(task)
|
||||
print(response)
|
@@ -0,0 +1,120 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from swarms.models import Anthropic, OpenAIChat
|
||||
from swarms.prompts.accountant_swarm_prompts import (
|
||||
DECISION_MAKING_PROMPT,
|
||||
DOC_ANALYZER_AGENT_PROMPT,
|
||||
FRAUD_DETECTION_AGENT_PROMPT,
|
||||
SUMMARY_GENERATOR_AGENT_PROMPT,
|
||||
)
|
||||
from swarms.structs import Flow
|
||||
from swarms.utils.pdf_to_text import pdf_to_text
|
||||
|
||||
# Environment variables
|
||||
load_dotenv()
|
||||
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
# Base llms
|
||||
llm1 = OpenAIChat(
|
||||
openai_api_key=openai_api_key,
|
||||
)
|
||||
|
||||
llm2 = Anthropic(
|
||||
anthropic_api_key=anthropic_api_key,
|
||||
)
|
||||
|
||||
|
||||
# Agents
|
||||
doc_analyzer_agent = Flow(
|
||||
llm=llm1,
|
||||
sop=DOC_ANALYZER_AGENT_PROMPT,
|
||||
)
|
||||
summary_generator_agent = Flow(
|
||||
llm=llm2,
|
||||
sop=SUMMARY_GENERATOR_AGENT_PROMPT,
|
||||
)
|
||||
decision_making_support_agent = Flow(
|
||||
llm=llm2,
|
||||
sop=DECISION_MAKING_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
class AccountantSwarms:
|
||||
"""
|
||||
Accountant Swarms is a collection of agents that work together to help
|
||||
accountants with their work.
|
||||
|
||||
Flow: analyze doc -> detect fraud -> generate summary -> decision making support
|
||||
|
||||
The agents are:
|
||||
- User Consultant: Asks the user many questions
|
||||
- Document Analyzer: Extracts text from the image of the financial document
|
||||
- Fraud Detection: Detects fraud in the document
|
||||
- Summary Agent: Generates an actionable summary of the document
|
||||
- Decision Making Support: Provides decision making support to the accountant
|
||||
|
||||
The agents are connected together in a workflow that is defined in the
|
||||
run method.
|
||||
|
||||
The workflow is as follows:
|
||||
1. The Document Analyzer agent extracts text from the image of the
|
||||
financial document.
|
||||
2. The Fraud Detection agent detects fraud in the document.
|
||||
3. The Summary Agent generates an actionable summary of the document.
|
||||
    4. The Decision Making Support agent provides decision making support to the accountant.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pdf_path: str,
|
||||
list_pdfs: List[str] = None,
|
||||
fraud_detection_instructions: str = None,
|
||||
summary_agent_instructions: str = None,
|
||||
decision_making_support_agent_instructions: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.pdf_path = pdf_path
|
||||
self.list_pdfs = list_pdfs
|
||||
self.fraud_detection_instructions = fraud_detection_instructions
|
||||
self.summary_agent_instructions = summary_agent_instructions
|
||||
self.decision_making_support_agent_instructions = (
|
||||
decision_making_support_agent_instructions
|
||||
)
|
||||
|
||||
def run(self):
|
||||
# Transform the pdf to text
|
||||
pdf_text = pdf_to_text(self.pdf_path)
|
||||
|
||||
        # Analyze the document and check for fraud (the doc-analyzer agent is reused with the fraud-detection instructions)
|
||||
fraud_detection_agent_output = doc_analyzer_agent.run(
|
||||
f"{self.fraud_detection_instructions}: {pdf_text}"
|
||||
)
|
||||
|
||||
# Generate an actionable summary of the document
|
||||
summary_agent_output = summary_generator_agent.run(
|
||||
f"{self.summary_agent_instructions}: {fraud_detection_agent_output}"
|
||||
)
|
||||
|
||||
# Provide decision making support to the accountant
|
||||
decision_making_support_agent_output = decision_making_support_agent.run(
|
||||
f"{self.decision_making_support_agent_instructions}:"
|
||||
f" {summary_agent_output}"
|
||||
)
|
||||
|
||||
return decision_making_support_agent_output
|
||||
|
||||
|
||||
swarm = AccountantSwarms(
|
||||
pdf_path="tesla.pdf",
|
||||
fraud_detection_instructions="Detect fraud in the document",
|
||||
summary_agent_instructions="Generate an actionable summary of the document",
|
||||
decision_making_support_agent_instructions=(
|
||||
"Provide decision making support to the business owner:"
|
||||
),
|
||||
)
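
# Minimal usage sketch: run() walks the analyze -> fraud-check -> summarize ->
# decision-support chain defined above and returns the decision-support text.
# It assumes a local "tesla.pdf" and valid API keys loaded from the .env file.
response = swarm.run()
print(response)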
|
@@ -0,0 +1,65 @@
|
||||
from swarms.models import Anthropic
|
||||
from swarms.structs import Flow
|
||||
from swarms.tools.tool import tool
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
llm = Anthropic(
|
||||
anthropic_api_key="",
|
||||
)
|
||||
|
||||
|
||||
async def async_load_playwright(url: str) -> str:
|
||||
"""Load the specified URLs using Playwright and parse using BeautifulSoup."""
|
||||
from bs4 import BeautifulSoup
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
results = ""
|
||||
async with async_playwright() as p:
|
||||
browser = await p.chromium.launch(headless=True)
|
||||
try:
|
||||
page = await browser.new_page()
|
||||
await page.goto(url)
|
||||
|
||||
page_source = await page.content()
|
||||
soup = BeautifulSoup(page_source, "html.parser")
|
||||
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
|
||||
text = soup.get_text()
|
||||
lines = (line.strip() for line in text.splitlines())
|
||||
chunks = (
|
||||
                phrase.strip() for line in lines for phrase in line.split("  ")
|
||||
)
|
||||
results = "\n".join(chunk for chunk in chunks if chunk)
|
||||
except Exception as e:
|
||||
results = f"Error: {e}"
|
||||
await browser.close()
|
||||
return results
|
||||
|
||||
|
||||
def run_async(coro):
|
||||
event_loop = asyncio.get_event_loop()
|
||||
return event_loop.run_until_complete(coro)
|
||||
|
||||
|
||||
@tool
|
||||
def browse_web_page(url: str) -> str:
|
||||
"""Verbose way to scrape a whole webpage. Likely to cause issues parsing."""
|
||||
return run_async(async_load_playwright(url))
|
||||
|
||||
|
||||
## Initialize the workflow
|
||||
flow = Flow(
|
||||
llm=llm,
|
||||
max_loops=5,
|
||||
tools=[browse_web_page],
|
||||
dashboard=True,
|
||||
)
|
||||
|
||||
out = flow.run(
|
||||
"Generate a 10,000 word blog on mental clarity and the benefits of"
|
||||
" meditation."
|
||||
)
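
# The browse_web_page tool above drives headless Chromium through Playwright;
# in a fresh environment this typically needs `pip install playwright
# beautifulsoup4` followed by `playwright install chromium` before running.
print(out)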
|
@@ -0,0 +1,123 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent import futures
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
from attr import define, field, Factory
|
||||
from swarms.utils.futures import execute_futures_dict
|
||||
from griptape.artifacts import TextArtifact
|
||||
|
||||
|
||||
@define
|
||||
class BaseVectorStore(ABC):
|
||||
""" """
|
||||
|
||||
DEFAULT_QUERY_COUNT = 5
|
||||
|
||||
@dataclass
|
||||
class QueryResult:
|
||||
id: str
|
||||
vector: list[float]
|
||||
score: float
|
||||
meta: Optional[dict] = None
|
||||
namespace: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class Entry:
|
||||
id: str
|
||||
vector: list[float]
|
||||
meta: Optional[dict] = None
|
||||
namespace: Optional[str] = None
|
||||
|
||||
embedding_driver: Any
|
||||
futures_executor: futures.Executor = field(
|
||||
default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True
|
||||
)
|
||||
|
||||
def upsert_text_artifacts(
|
||||
self,
|
||||
artifacts: dict[str, list[TextArtifact]],
|
||||
meta: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
execute_futures_dict(
|
||||
{
|
||||
namespace: self.futures_executor.submit(
|
||||
self.upsert_text_artifact, a, namespace, meta, **kwargs
|
||||
)
|
||||
for namespace, artifact_list in artifacts.items()
|
||||
for a in artifact_list
|
||||
}
|
||||
)
|
||||
|
||||
def upsert_text_artifact(
|
||||
self,
|
||||
artifact: TextArtifact,
|
||||
namespace: Optional[str] = None,
|
||||
meta: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not meta:
|
||||
meta = {}
|
||||
|
||||
meta["artifact"] = artifact.to_json()
|
||||
|
||||
if artifact.embedding:
|
||||
vector = artifact.embedding
|
||||
else:
|
||||
vector = artifact.generate_embedding(self.embedding_driver)
|
||||
|
||||
return self.upsert_vector(
|
||||
vector,
|
||||
vector_id=artifact.id,
|
||||
namespace=namespace,
|
||||
meta=meta,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def upsert_text(
|
||||
self,
|
||||
string: str,
|
||||
vector_id: Optional[str] = None,
|
||||
namespace: Optional[str] = None,
|
||||
meta: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
return self.upsert_vector(
|
||||
self.embedding_driver.embed_string(string),
|
||||
vector_id=vector_id,
|
||||
namespace=namespace,
|
||||
meta=meta if meta else {},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def upsert_vector(
|
||||
self,
|
||||
vector: list[float],
|
||||
vector_id: Optional[str] = None,
|
||||
namespace: Optional[str] = None,
|
||||
meta: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def load_entry(
|
||||
self, vector_id: str, namespace: Optional[str] = None
|
||||
) -> Entry:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def load_entries(self, namespace: Optional[str] = None) -> list[Entry]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def query(
|
||||
self,
|
||||
query: str,
|
||||
count: Optional[int] = None,
|
||||
namespace: Optional[str] = None,
|
||||
include_vectors: bool = False,
|
||||
**kwargs,
|
||||
) -> list[QueryResult]:
|
||||
...
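

# --- Illustrative sketch only (not part of the library): a toy in-memory
# subclass showing how the abstract interface above fits together. The class
# name and its `entries` field are assumptions made for this example, e.g.
# store = InMemoryVectorStore(embedding_driver=my_embedding_driver).
@define
class InMemoryVectorStore(BaseVectorStore):
    entries: dict = field(default=Factory(dict), kw_only=True)

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        # Store (or overwrite) the vector under its id.
        vector_id = vector_id if vector_id else str(len(self.entries))
        self.entries[vector_id] = self.Entry(
            id=vector_id, vector=vector, meta=meta, namespace=namespace
        )
        return vector_id

    def load_entry(
        self, vector_id: str, namespace: Optional[str] = None
    ) -> BaseVectorStore.Entry:
        return self.entries[vector_id]

    def load_entries(
        self, namespace: Optional[str] = None
    ) -> list[BaseVectorStore.Entry]:
        return [
            e
            for e in self.entries.values()
            if namespace is None or e.namespace == namespace
        ]

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStore.QueryResult]:
        # A real driver would embed `query` with self.embedding_driver and
        # rank the stored vectors by similarity; the toy version returns [].
        return []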
|
@@ -0,0 +1,723 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from swarms.structs.document import Document
|
||||
from swarms.models.embeddings_base import Embeddings
|
||||
from langchain.schema.vectorstore import VectorStore
|
||||
from langchain.utils import xor_args
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import chromadb
|
||||
import chromadb.config
|
||||
from chromadb.api.types import ID, OneOrMany, Where, WhereDocument
|
||||
|
||||
logger = logging.getLogger()
|
||||
DEFAULT_K = 4 # Number of Documents to return.
|
||||
|
||||
|
||||
def _results_to_docs(results: Any) -> List[Document]:
|
||||
return [doc for doc, _ in _results_to_docs_and_scores(results)]
|
||||
|
||||
|
||||
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
|
||||
return [
|
||||
# TODO: Chroma can do batch querying,
|
||||
# we shouldn't hard code to the 1st result
|
||||
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
|
||||
for result in zip(
|
||||
results["documents"][0],
|
||||
results["metadatas"][0],
|
||||
results["distances"][0],
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class Chroma(VectorStore):
|
||||
"""`ChromaDB` vector store.
|
||||
|
||||
To use, you should have the ``chromadb`` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
vectorstore = Chroma("langchain_store", embeddings)
|
||||
"""
|
||||
|
||||
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
embedding_function: Optional[Embeddings] = None,
|
||||
persist_directory: Optional[str] = None,
|
||||
client_settings: Optional[chromadb.config.Settings] = None,
|
||||
collection_metadata: Optional[Dict] = None,
|
||||
client: Optional[chromadb.Client] = None,
|
||||
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
||||
) -> None:
|
||||
"""Initialize with a Chroma client."""
|
||||
try:
|
||||
import chromadb
|
||||
import chromadb.config
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import chromadb python package. "
|
||||
"Please install it with `pip install chromadb`."
|
||||
)
|
||||
|
||||
if client is not None:
|
||||
self._client_settings = client_settings
|
||||
self._client = client
|
||||
self._persist_directory = persist_directory
|
||||
else:
|
||||
if client_settings:
|
||||
# If client_settings is provided with persist_directory specified,
|
||||
# then it is "in-memory and persisting to disk" mode.
|
||||
client_settings.persist_directory = (
|
||||
persist_directory or client_settings.persist_directory
|
||||
)
|
||||
if client_settings.persist_directory is not None:
|
||||
# Maintain backwards compatibility with chromadb < 0.4.0
|
||||
major, minor, _ = chromadb.__version__.split(".")
|
||||
if int(major) == 0 and int(minor) < 4:
|
||||
client_settings.chroma_db_impl = "duckdb+parquet"
|
||||
|
||||
_client_settings = client_settings
|
||||
elif persist_directory:
|
||||
# Maintain backwards compatibility with chromadb < 0.4.0
|
||||
major, minor, _ = chromadb.__version__.split(".")
|
||||
if int(major) == 0 and int(minor) < 4:
|
||||
_client_settings = chromadb.config.Settings(
|
||||
chroma_db_impl="duckdb+parquet",
|
||||
)
|
||||
else:
|
||||
_client_settings = chromadb.config.Settings(
|
||||
is_persistent=True
|
||||
)
|
||||
_client_settings.persist_directory = persist_directory
|
||||
else:
|
||||
_client_settings = chromadb.config.Settings()
|
||||
self._client_settings = _client_settings
|
||||
self._client = chromadb.Client(_client_settings)
|
||||
self._persist_directory = (
|
||||
_client_settings.persist_directory or persist_directory
|
||||
)
|
||||
|
||||
self._embedding_function = embedding_function
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=(
|
||||
self._embedding_function.embed_documents
|
||||
if self._embedding_function is not None
|
||||
else None
|
||||
),
|
||||
metadata=collection_metadata,
|
||||
)
|
||||
self.override_relevance_score_fn = relevance_score_fn
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return self._embedding_function
|
||||
|
||||
@xor_args(("query_texts", "query_embeddings"))
|
||||
def __query_collection(
|
||||
self,
|
||||
query_texts: Optional[List[str]] = None,
|
||||
query_embeddings: Optional[List[List[float]]] = None,
|
||||
n_results: int = 4,
|
||||
where: Optional[Dict[str, str]] = None,
|
||||
where_document: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Query the chroma collection."""
|
||||
try:
|
||||
import chromadb # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import chromadb python package. "
|
||||
"Please install it with `pip install chromadb`."
|
||||
)
|
||||
return self._collection.query(
|
||||
query_texts=query_texts,
|
||||
query_embeddings=query_embeddings,
|
||||
n_results=n_results,
|
||||
where=where,
|
||||
where_document=where_document,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
Args:
|
||||
texts (Iterable[str]): Texts to add to the vectorstore.
|
||||
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
|
||||
ids (Optional[List[str]], optional): Optional list of IDs.
|
||||
|
||||
Returns:
|
||||
List[str]: List of IDs of the added texts.
|
||||
"""
|
||||
# TODO: Handle the case where the user doesn't provide ids on the Collection
|
||||
if ids is None:
|
||||
ids = [str(uuid.uuid1()) for _ in texts]
|
||||
embeddings = None
|
||||
texts = list(texts)
|
||||
if self._embedding_function is not None:
|
||||
embeddings = self._embedding_function.embed_documents(texts)
|
||||
if metadatas:
|
||||
# fill metadatas with empty dicts if somebody
|
||||
# did not specify metadata for all texts
|
||||
length_diff = len(texts) - len(metadatas)
|
||||
if length_diff:
|
||||
metadatas = metadatas + [{}] * length_diff
|
||||
empty_ids = []
|
||||
non_empty_ids = []
|
||||
for idx, m in enumerate(metadatas):
|
||||
if m:
|
||||
non_empty_ids.append(idx)
|
||||
else:
|
||||
empty_ids.append(idx)
|
||||
if non_empty_ids:
|
||||
metadatas = [metadatas[idx] for idx in non_empty_ids]
|
||||
texts_with_metadatas = [texts[idx] for idx in non_empty_ids]
|
||||
embeddings_with_metadatas = (
|
||||
[embeddings[idx] for idx in non_empty_ids]
|
||||
if embeddings
|
||||
else None
|
||||
)
|
||||
ids_with_metadata = [ids[idx] for idx in non_empty_ids]
|
||||
try:
|
||||
self._collection.upsert(
|
||||
metadatas=metadatas,
|
||||
embeddings=embeddings_with_metadatas,
|
||||
documents=texts_with_metadatas,
|
||||
ids=ids_with_metadata,
|
||||
)
|
||||
except ValueError as e:
|
||||
if "Expected metadata value to be" in str(e):
|
||||
msg = (
|
||||
"Try filtering complex metadata from the document"
|
||||
" using "
|
||||
"langchain.vectorstores.utils.filter_complex_metadata."
|
||||
)
|
||||
raise ValueError(e.args[0] + "\n\n" + msg)
|
||||
else:
|
||||
raise e
|
||||
if empty_ids:
|
||||
texts_without_metadatas = [texts[j] for j in empty_ids]
|
||||
embeddings_without_metadatas = (
|
||||
[embeddings[j] for j in empty_ids] if embeddings else None
|
||||
)
|
||||
ids_without_metadatas = [ids[j] for j in empty_ids]
|
||||
self._collection.upsert(
|
||||
embeddings=embeddings_without_metadatas,
|
||||
documents=texts_without_metadatas,
|
||||
ids=ids_without_metadatas,
|
||||
)
|
||||
else:
|
||||
self._collection.upsert(
|
||||
embeddings=embeddings,
|
||||
documents=texts,
|
||||
ids=ids,
|
||||
)
|
||||
return ids
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = DEFAULT_K,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Run similarity search with Chroma.
|
||||
|
||||
Args:
|
||||
query (str): Query text to search for.
|
||||
k (int): Number of results to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[Document]: List of documents most similar to the query text.
|
||||
"""
|
||||
docs_and_scores = self.similarity_search_with_score(
|
||||
query, k, filter=filter
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = DEFAULT_K,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
where_document: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to embedding vector.
|
||||
Args:
|
||||
embedding (List[float]): Embedding to look up documents similar to.
|
||||
k (int): Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
Returns:
|
||||
List of Documents most similar to the query vector.
|
||||
"""
|
||||
results = self.__query_collection(
|
||||
query_embeddings=embedding,
|
||||
n_results=k,
|
||||
where=filter,
|
||||
where_document=where_document,
|
||||
)
|
||||
return _results_to_docs(results)
|
||||
|
||||
def similarity_search_by_vector_with_relevance_scores(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = DEFAULT_K,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
where_document: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""
|
||||
Return docs most similar to embedding vector and similarity score.
|
||||
|
||||
Args:
|
||||
embedding (List[float]): Embedding to look up documents similar to.
|
||||
k (int): Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[Tuple[Document, float]]: List of documents most similar to
|
||||
the query text and cosine distance in float for each.
|
||||
Lower score represents more similarity.
|
||||
"""
|
||||
results = self.__query_collection(
|
||||
query_embeddings=embedding,
|
||||
n_results=k,
|
||||
where=filter,
|
||||
where_document=where_document,
|
||||
)
|
||||
return _results_to_docs_and_scores(results)
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = DEFAULT_K,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
where_document: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Run similarity search with Chroma with distance.
|
||||
|
||||
Args:
|
||||
query (str): Query text to search for.
|
||||
k (int): Number of results to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[Tuple[Document, float]]: List of documents most similar to
|
||||
the query text and cosine distance in float for each.
|
||||
Lower score represents more similarity.
|
||||
"""
|
||||
if self._embedding_function is None:
|
||||
results = self.__query_collection(
|
||||
query_texts=[query],
|
||||
n_results=k,
|
||||
where=filter,
|
||||
where_document=where_document,
|
||||
)
|
||||
else:
|
||||
query_embedding = self._embedding_function.embed_query(query)
|
||||
results = self.__query_collection(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=k,
|
||||
where=filter,
|
||||
where_document=where_document,
|
||||
)
|
||||
|
||||
return _results_to_docs_and_scores(results)
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
"""
|
||||
The 'correct' relevance function
|
||||
may differ depending on a few things, including:
|
||||
- the distance / similarity metric used by the VectorStore
|
||||
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
|
||||
- embedding dimensionality
|
||||
- etc.
|
||||
"""
|
||||
if self.override_relevance_score_fn:
|
||||
return self.override_relevance_score_fn
|
||||
|
||||
distance = "l2"
|
||||
distance_key = "hnsw:space"
|
||||
metadata = self._collection.metadata
|
||||
|
||||
if metadata and distance_key in metadata:
|
||||
distance = metadata[distance_key]
|
||||
|
||||
if distance == "cosine":
|
||||
return self._cosine_relevance_score_fn
|
||||
elif distance == "l2":
|
||||
return self._euclidean_relevance_score_fn
|
||||
elif distance == "ip":
|
||||
return self._max_inner_product_relevance_score_fn
|
||||
else:
|
||||
raise ValueError(
|
||||
"No supported normalization function"
|
||||
f" for distance metric of type: {distance}."
|
||||
"Consider providing relevance_score_fn to Chroma constructor."
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = DEFAULT_K,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
where_document: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
|
||||
results = self.__query_collection(
|
||||
query_embeddings=embedding,
|
||||
n_results=fetch_k,
|
||||
where=filter,
|
||||
where_document=where_document,
|
||||
include=["metadatas", "documents", "distances", "embeddings"],
|
||||
)
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
np.array(embedding, dtype=np.float32),
|
||||
results["embeddings"][0],
|
||||
k=k,
|
||||
lambda_mult=lambda_mult,
|
||||
)
|
||||
|
||||
candidates = _results_to_docs(results)
|
||||
|
||||
selected_results = [
|
||||
r for i, r in enumerate(candidates) if i in mmr_selected
|
||||
]
|
||||
return selected_results
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = DEFAULT_K,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
where_document: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
if self._embedding_function is None:
|
||||
raise ValueError(
|
||||
"For MMR search, you must specify an embedding function"
|
||||
" oncreation."
|
||||
)
|
||||
|
||||
embedding = self._embedding_function.embed_query(query)
|
||||
docs = self.max_marginal_relevance_search_by_vector(
|
||||
embedding,
|
||||
k,
|
||||
fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
where_document=where_document,
|
||||
)
|
||||
return docs
|
||||
|
||||
def delete_collection(self) -> None:
|
||||
"""Delete the collection."""
|
||||
self._client.delete_collection(self._collection.name)
|
||||
|
||||
def get(
|
||||
self,
|
||||
ids: Optional[OneOrMany[ID]] = None,
|
||||
where: Optional[Where] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[WhereDocument] = None,
|
||||
include: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Gets the collection.
|
||||
|
||||
Args:
|
||||
ids: The ids of the embeddings to get. Optional.
|
||||
where: A Where type dict used to filter results by.
|
||||
E.g. `{"color" : "red", "price": 4.20}`. Optional.
|
||||
limit: The number of documents to return. Optional.
|
||||
offset: The offset to start returning results from.
|
||||
Useful for paging results with limit. Optional.
|
||||
where_document: A WhereDocument type dict used to filter by the documents.
|
||||
E.g. `{$contains: "hello"}`. Optional.
|
||||
include: A list of what to include in the results.
|
||||
Can contain `"embeddings"`, `"metadatas"`, `"documents"`.
|
||||
Ids are always included.
|
||||
Defaults to `["metadatas", "documents"]`. Optional.
|
||||
"""
|
||||
kwargs = {
|
||||
"ids": ids,
|
||||
"where": where,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"where_document": where_document,
|
||||
}
|
||||
|
||||
if include is not None:
|
||||
kwargs["include"] = include
|
||||
|
||||
return self._collection.get(**kwargs)
|
||||
|
||||
def persist(self) -> None:
|
||||
"""Persist the collection.
|
||||
|
||||
This can be used to explicitly persist the data to disk.
|
||||
It will also be called automatically when the object is destroyed.
|
||||
"""
|
||||
if self._persist_directory is None:
|
||||
raise ValueError(
|
||||
"You must specify a persist_directory on"
|
||||
"creation to persist the collection."
|
||||
)
|
||||
import chromadb
|
||||
|
||||
# Maintain backwards compatibility with chromadb < 0.4.0
|
||||
major, minor, _ = chromadb.__version__.split(".")
|
||||
if int(major) == 0 and int(minor) < 4:
|
||||
self._client.persist()
|
||||
|
||||
def update_document(self, document_id: str, document: Document) -> None:
|
||||
"""Update a document in the collection.
|
||||
|
||||
Args:
|
||||
document_id (str): ID of the document to update.
|
||||
document (Document): Document to update.
|
||||
"""
|
||||
return self.update_documents([document_id], [document])
|
||||
|
||||
def update_documents(
|
||||
self, ids: List[str], documents: List[Document]
|
||||
) -> None:
|
||||
"""Update a document in the collection.
|
||||
|
||||
Args:
|
||||
ids (List[str]): List of ids of the document to update.
|
||||
documents (List[Document]): List of documents to update.
|
||||
"""
|
||||
text = [document.page_content for document in documents]
|
||||
metadata = [document.metadata for document in documents]
|
||||
if self._embedding_function is None:
|
||||
raise ValueError(
|
||||
"For update, you must specify an embedding function on"
|
||||
" creation."
|
||||
)
|
||||
embeddings = self._embedding_function.embed_documents(text)
|
||||
|
||||
if hasattr(
|
||||
self._collection._client, "max_batch_size"
|
||||
): # for Chroma 0.4.10 and above
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
for batch in create_batches(
|
||||
api=self._collection._client,
|
||||
ids=ids,
|
||||
metadatas=metadata,
|
||||
documents=text,
|
||||
embeddings=embeddings,
|
||||
):
|
||||
self._collection.update(
|
||||
ids=batch[0],
|
||||
embeddings=batch[1],
|
||||
documents=batch[3],
|
||||
metadatas=batch[2],
|
||||
)
|
||||
else:
|
||||
self._collection.update(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=text,
|
||||
metadatas=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls: Type[Chroma],
|
||||
texts: List[str],
|
||||
embedding: Optional[Embeddings] = None,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
persist_directory: Optional[str] = None,
|
||||
client_settings: Optional[chromadb.config.Settings] = None,
|
||||
client: Optional[chromadb.Client] = None,
|
||||
collection_metadata: Optional[Dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> Chroma:
|
||||
"""Create a Chroma vectorstore from a raw documents.
|
||||
|
||||
If a persist_directory is specified, the collection will be persisted there.
|
||||
Otherwise, the data will be ephemeral in-memory.
|
||||
|
||||
Args:
|
||||
texts (List[str]): List of texts to add to the collection.
|
||||
collection_name (str): Name of the collection to create.
|
||||
persist_directory (Optional[str]): Directory to persist the collection.
|
||||
embedding (Optional[Embeddings]): Embedding function. Defaults to None.
|
||||
metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
|
||||
ids (Optional[List[str]]): List of document IDs. Defaults to None.
|
||||
client_settings (Optional[chromadb.config.Settings]): Chroma client settings
|
||||
collection_metadata (Optional[Dict]): Collection configurations.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Chroma: Chroma vectorstore.
|
||||
"""
|
||||
chroma_collection = cls(
|
||||
collection_name=collection_name,
|
||||
embedding_function=embedding,
|
||||
persist_directory=persist_directory,
|
||||
client_settings=client_settings,
|
||||
client=client,
|
||||
collection_metadata=collection_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
if ids is None:
|
||||
ids = [str(uuid.uuid1()) for _ in texts]
|
||||
if hasattr(
|
||||
chroma_collection._client, "max_batch_size"
|
||||
): # for Chroma 0.4.10 and above
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
for batch in create_batches(
|
||||
api=chroma_collection._client,
|
||||
ids=ids,
|
||||
metadatas=metadatas,
|
||||
documents=texts,
|
||||
):
|
||||
chroma_collection.add_texts(
|
||||
texts=batch[3] if batch[3] else [],
|
||||
metadatas=batch[2] if batch[2] else None,
|
||||
ids=batch[0],
|
||||
)
|
||||
else:
|
||||
chroma_collection.add_texts(
|
||||
texts=texts, metadatas=metadatas, ids=ids
|
||||
)
|
||||
return chroma_collection
|
||||
|
||||
@classmethod
|
||||
def from_documents(
|
||||
cls: Type[Chroma],
|
||||
documents: List[Document],
|
||||
embedding: Optional[Embeddings] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
persist_directory: Optional[str] = None,
|
||||
client_settings: Optional[chromadb.config.Settings] = None,
|
||||
        client: Optional[chromadb.Client] = None,
|
||||
collection_metadata: Optional[Dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> Chroma:
|
||||
"""Create a Chroma vectorstore from a list of documents.
|
||||
|
||||
If a persist_directory is specified, the collection will be persisted there.
|
||||
Otherwise, the data will be ephemeral in-memory.
|
||||
|
||||
Args:
|
||||
collection_name (str): Name of the collection to create.
|
||||
persist_directory (Optional[str]): Directory to persist the collection.
|
||||
ids (Optional[List[str]]): List of document IDs. Defaults to None.
|
||||
documents (List[Document]): List of documents to add to the vectorstore.
|
||||
embedding (Optional[Embeddings]): Embedding function. Defaults to None.
|
||||
client_settings (Optional[chromadb.config.Settings]): Chroma client settings
|
||||
collection_metadata (Optional[Dict]): Collection configurations.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Chroma: Chroma vectorstore.
|
||||
"""
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
return cls.from_texts(
|
||||
texts=texts,
|
||||
embedding=embedding,
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
collection_name=collection_name,
|
||||
persist_directory=persist_directory,
|
||||
client_settings=client_settings,
|
||||
client=client,
|
||||
collection_metadata=collection_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
|
||||
"""Delete by vector IDs.
|
||||
|
||||
Args:
|
||||
ids: List of ids to delete.
|
||||
"""
|
||||
self._collection.delete(ids=ids)
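

# Illustrative usage sketch (not from the original source): build an
# ephemeral store from raw texts and run a similarity search. It assumes
# `pip install chromadb` and an embedding object such as the OpenAIEmbeddings
# shown in the class docstring.
#
# from langchain.embeddings.openai import OpenAIEmbeddings
#
# embeddings = OpenAIEmbeddings()
# vectorstore = Chroma.from_texts(
#     texts=["hello world", "goodbye world"],
#     embedding=embeddings,
# )
# docs = vectorstore.similarity_search("hello", k=1)
# print(docs[0].page_content)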
|
@@ -0,0 +1,177 @@
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from swarms.memory.schemas import Artifact, Status
|
||||
from swarms.memory.schemas import Step as APIStep
|
||||
from swarms.memory.schemas import Task as APITask
|
||||
|
||||
|
||||
class Step(APIStep):
|
||||
additional_properties: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
class Task(APITask):
|
||||
steps: List[Step] = []
|
||||
|
||||
|
||||
class NotFoundException(Exception):
|
||||
"""
|
||||
Exception raised when a resource is not found.
|
||||
"""
|
||||
|
||||
def __init__(self, item_name: str, item_id: str):
|
||||
self.item_name = item_name
|
||||
self.item_id = item_id
|
||||
super().__init__(f"{item_name} with {item_id} not found.")
|
||||
|
||||
|
||||
class TaskDB(ABC):
|
||||
async def create_task(
|
||||
self,
|
||||
input: Optional[str],
|
||||
additional_input: Any = None,
|
||||
artifacts: Optional[List[Artifact]] = None,
|
||||
steps: Optional[List[Step]] = None,
|
||||
) -> Task:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_step(
|
||||
self,
|
||||
task_id: str,
|
||||
name: Optional[str] = None,
|
||||
input: Optional[str] = None,
|
||||
is_last: bool = False,
|
||||
additional_properties: Optional[Dict[str, str]] = None,
|
||||
) -> Step:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
file_name: str,
|
||||
relative_path: Optional[str] = None,
|
||||
step_id: Optional[str] = None,
|
||||
) -> Artifact:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_task(self, task_id: str) -> Task:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_step(self, task_id: str, step_id: str) -> Step:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
|
||||
raise NotImplementedError
|
||||
|
||||
async def list_tasks(self) -> List[Task]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def list_steps(
|
||||
self, task_id: str, status: Optional[Status] = None
|
||||
) -> List[Step]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class InMemoryTaskDB(TaskDB):
|
||||
_tasks: Dict[str, Task] = {}
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
input: Optional[str],
|
||||
additional_input: Any = None,
|
||||
artifacts: Optional[List[Artifact]] = None,
|
||||
steps: Optional[List[Step]] = None,
|
||||
) -> Task:
|
||||
if not steps:
|
||||
steps = []
|
||||
if not artifacts:
|
||||
artifacts = []
|
||||
task_id = str(uuid.uuid4())
|
||||
task = Task(
|
||||
task_id=task_id,
|
||||
input=input,
|
||||
steps=steps,
|
||||
artifacts=artifacts,
|
||||
additional_input=additional_input,
|
||||
)
|
||||
self._tasks[task_id] = task
|
||||
return task
|
||||
|
||||
async def create_step(
|
||||
self,
|
||||
task_id: str,
|
||||
name: Optional[str] = None,
|
||||
input: Optional[str] = None,
|
||||
is_last=False,
|
||||
additional_properties: Optional[Dict[str, Any]] = None,
|
||||
) -> Step:
|
||||
step_id = str(uuid.uuid4())
|
||||
step = Step(
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
name=name,
|
||||
input=input,
|
||||
status=Status.created,
|
||||
is_last=is_last,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
task = await self.get_task(task_id)
|
||||
task.steps.append(step)
|
||||
return step
|
||||
|
||||
async def get_task(self, task_id: str) -> Task:
|
||||
task = self._tasks.get(task_id, None)
|
||||
if not task:
|
||||
raise NotFoundException("Task", task_id)
|
||||
return task
|
||||
|
||||
async def get_step(self, task_id: str, step_id: str) -> Step:
|
||||
task = await self.get_task(task_id)
|
||||
        step = next(filter(lambda s: s.step_id == step_id, task.steps), None)
|
||||
if not step:
|
||||
raise NotFoundException("Step", step_id)
|
||||
return step
|
||||
|
||||
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
|
||||
task = await self.get_task(task_id)
|
||||
artifact = next(
|
||||
filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None
|
||||
)
|
||||
if not artifact:
|
||||
raise NotFoundException("Artifact", artifact_id)
|
||||
return artifact
|
||||
|
||||
async def create_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
file_name: str,
|
||||
relative_path: Optional[str] = None,
|
||||
step_id: Optional[str] = None,
|
||||
) -> Artifact:
|
||||
artifact_id = str(uuid.uuid4())
|
||||
artifact = Artifact(
|
||||
artifact_id=artifact_id,
|
||||
file_name=file_name,
|
||||
relative_path=relative_path,
|
||||
)
|
||||
task = await self.get_task(task_id)
|
||||
task.artifacts.append(artifact)
|
||||
|
||||
if step_id:
|
||||
step = await self.get_step(task_id, step_id)
|
||||
step.artifacts.append(artifact)
|
||||
|
||||
return artifact
|
||||
|
||||
async def list_tasks(self) -> List[Task]:
|
||||
return [task for task in self._tasks.values()]
|
||||
|
||||
async def list_steps(
|
||||
self, task_id: str, status: Optional[Status] = None
|
||||
) -> List[Step]:
|
||||
task = await self.get_task(task_id)
|
||||
steps = task.steps
|
||||
if status:
|
||||
steps = list(filter(lambda s: s.status == status, steps))
|
||||
return steps
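

# Illustrative sketch (not part of the module): drive the in-memory task
# database end to end; the sample input string is a placeholder.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        db = InMemoryTaskDB()
        task = await db.create_task(input="Summarize the quarterly report")
        await db.create_step(task.task_id, name="draft", is_last=True)
        print(await db.get_task(task.task_id))
        print(await db.list_steps(task.task_id))

    asyncio.run(_demo())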
|
@@ -0,0 +1,157 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import oceandb
|
||||
from oceandb.utils.embedding_function import MultiModalEmbeddingFunction
|
||||
|
||||
|
||||
class OceanDB:
|
||||
"""
|
||||
A class to interact with OceanDB.
|
||||
|
||||
...
|
||||
|
||||
Attributes
|
||||
----------
|
||||
client : oceandb.Client
|
||||
a client to interact with OceanDB
|
||||
|
||||
Methods
|
||||
-------
|
||||
create_collection(collection_name: str, modality: str):
|
||||
Creates a new collection in OceanDB.
|
||||
append_document(collection, document: str, id: str):
|
||||
Appends a document to a collection in OceanDB.
|
||||
add_documents(collection, documents: List[str], ids: List[str]):
|
||||
Adds multiple documents to a collection in OceanDB.
|
||||
query(collection, query_texts: list[str], n_results: int):
|
||||
Queries a collection in OceanDB.
|
||||
"""
|
||||
|
||||
def __init__(self, client: oceandb.Client = None):
|
||||
"""
|
||||
Constructs all the necessary attributes for the OceanDB object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client : oceandb.Client, optional
|
||||
a client to interact with OceanDB (default is None, which creates a new client)
|
||||
"""
|
||||
try:
|
||||
self.client = client if client else oceandb.Client()
|
||||
print(self.client.heartbeat())
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize OceanDB client. Error: {e}")
|
||||
raise
|
||||
|
||||
def create_collection(self, collection_name: str, modality: str):
|
||||
"""
|
||||
Creates a new collection in OceanDB.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
collection_name : str
|
||||
the name of the new collection
|
||||
modality : str
|
||||
the modality of the new collection
|
||||
|
||||
Returns
|
||||
-------
|
||||
collection
|
||||
the created collection
|
||||
"""
|
||||
try:
|
||||
embedding_function = MultiModalEmbeddingFunction(modality=modality)
|
||||
collection = self.client.create_collection(
|
||||
collection_name, embedding_function=embedding_function
|
||||
)
|
||||
return collection
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to create collection. Error {e}")
|
||||
raise
|
||||
|
||||
def append_document(self, collection, document: str, id: str):
|
||||
"""
|
||||
Appends a document to a collection in OceanDB.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
collection
|
||||
the collection to append the document to
|
||||
document : str
|
||||
the document to append
|
||||
id : str
|
||||
the id of the document
|
||||
|
||||
Returns
|
||||
-------
|
||||
result
|
||||
the result of the append operation
|
||||
"""
|
||||
try:
|
||||
return collection.add(documents=[document], ids=[id])
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Failed to append document to the collection. Error {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def add_documents(self, collection, documents: List[str], ids: List[str]):
|
||||
"""
|
||||
Adds multiple documents to a collection in OceanDB.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
collection
|
||||
the collection to add the documents to
|
||||
documents : List[str]
|
||||
the documents to add
|
||||
ids : List[str]
|
||||
the ids of the documents
|
||||
|
||||
Returns
|
||||
-------
|
||||
result
|
||||
the result of the add operation
|
||||
"""
|
||||
try:
|
||||
return collection.add(documents=documents, ids=ids)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to add documents to collection. Error: {e}")
|
||||
raise
|
||||
|
||||
def query(self, collection, query_texts: list[str], n_results: int):
|
||||
"""
|
||||
Queries a collection in OceanDB.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
collection
|
||||
the collection to query
|
||||
query_texts : list[str]
|
||||
the texts to query
|
||||
n_results : int
|
||||
the number of results to return
|
||||
|
||||
Returns
|
||||
-------
|
||||
results
|
||||
the results of the query
|
||||
"""
|
||||
try:
|
||||
results = collection.query(
|
||||
query_texts=query_texts, n_results=n_results
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to query the collection. Error {e}")
|
||||
raise
|
||||
|
||||
|
||||
# Example
|
||||
# ocean = OceanDB()
|
||||
# collection = ocean.create_collection("test", "text")
|
||||
# ocean.append_document(collection, "hello world", "1")
|
||||
# ocean.add_documents(collection, ["hello world", "hello world"], ["2", "3"])
|
||||
# results = ocean.query(collection, ["hello world"], 3)
|
||||
# print(results)
|
@@ -0,0 +1,261 @@
|
||||
import os
|
||||
import asyncio
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import openai
|
||||
import requests
|
||||
from cachetools import TTLCache
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
from ratelimit import limits, sleep_and_retry
|
||||
from termcolor import colored
|
||||
|
||||
# ENV
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPT4VisionResponse:
|
||||
"""A response structure for GPT-4"""
|
||||
|
||||
answer: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPT4Vision:
|
||||
"""
|
||||
GPT4Vision model class
|
||||
|
||||
Attributes:
|
||||
-----------
|
||||
max_retries: int
|
||||
The maximum number of retries to make to the API
|
||||
backoff_factor: float
|
||||
The backoff factor to use for exponential backoff
|
||||
timeout_seconds: int
|
||||
The timeout in seconds for the API request
|
||||
api_key: str
|
||||
The API key to use for the API request
|
||||
quality: str
|
||||
The quality of the image to generate
|
||||
max_tokens: int
|
||||
The maximum number of tokens to use for the API request
|
||||
|
||||
Methods:
|
||||
--------
|
||||
process_img(self, img_path: str) -> str:
|
||||
Processes the image to be used for the API request
|
||||
    run(self, task: str, img: str):
|
||||
        Makes a call to the GPT-4 Vision API and returns the model's answer for the image
|
||||
|
||||
Example:
|
||||
>>> gpt4vision = GPT4Vision()
|
||||
>>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
||||
>>> tasks = ["A painting of a dog"]
|
||||
    >>> answer = gpt4vision.run(tasks[0], img)
|
||||
>>> print(answer)
|
||||
|
||||
"""
|
||||
|
||||
max_retries: int = 3
|
||||
model: str = "gpt-4-vision-preview"
|
||||
backoff_factor: float = 2.0
|
||||
timeout_seconds: int = 10
|
||||
    openai_api_key: Optional[str] = os.getenv("OPENAI_API_KEY")
|
||||
    # 'low' or 'high': respectively faster or higher quality, but 'high' uses more tokens
|
||||
quality: str = "low"
|
||||
    # Max tokens to use for the API request; the upper limit is reportedly around 3,000 but is not confirmed
|
||||
max_tokens: int = 200
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
)
|
||||
dashboard: bool = True
|
||||
call_limit: int = 1
|
||||
period_seconds: int = 60
|
||||
|
||||
# Cache for storing API Responses
|
||||
cache = TTLCache(maxsize=100, ttl=600) # Cache for 10 minutes
|
||||
|
||||
class Config:
|
||||
"""Config class for the GPT4Vision model"""
|
||||
|
||||
        arbitrary_types_allowed = True
|
||||
|
||||
def process_img(self, img: str) -> str:
|
||||
"""Processes the image to be used for the API request"""
|
||||
with open(img, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
@sleep_and_retry
|
||||
@limits(
|
||||
calls=call_limit, period=period_seconds
|
||||
    )  # Rate limit of call_limit calls per period_seconds seconds (1 call per minute by default)
|
||||
def run(self, task: str, img: str):
|
||||
"""
|
||||
Run the GPT-4 Vision model
|
||||
|
||||
Task: str
|
||||
The task to run
|
||||
Img: str
|
||||
The image to run the task on
|
||||
|
||||
"""
|
||||
if self.dashboard:
|
||||
self.print_dashboard()
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model="gpt-4-vision-preview",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": task},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": str(img),
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
            out = response.choices[0].message.content
|
||||
# out = self.clean_output(out)
|
||||
return out
|
||||
except openai.OpenAIError as e:
|
||||
# logger.error(f"OpenAI API error: {e}")
|
||||
return f"OpenAI API error: Could not process the image. {e}"
|
||||
except Exception as e:
|
||||
return f"Unexpected error occurred while processing the image. {e}"
|
||||
|
||||
def clean_output(self, output: str):
|
||||
# Regex pattern to find the Choice object representation in the output
|
||||
pattern = r"Choice\(.*?\(content=\"(.*?)\".*?\)\)"
|
||||
match = re.search(pattern, output, re.DOTALL)
|
||||
|
||||
if match:
|
||||
# Extract the content from the matched pattern
|
||||
content = match.group(1)
|
||||
# Replace escaped quotes to get the clean content
|
||||
content = content.replace(r"\"", '"')
|
||||
print(content)
|
||||
else:
|
||||
print("No content found in the output.")
|
||||
|
||||
async def arun(self, task: str, img: str):
|
||||
"""
|
||||
Arun is an async version of run
|
||||
|
||||
Task: str
|
||||
The task to run
|
||||
Img: str
|
||||
The image to run the task on
|
||||
|
||||
"""
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model="gpt-4-vision-preview",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": task},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": img,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
            return response.choices[0].message.content
|
||||
except openai.OpenAIError as e:
|
||||
# logger.error(f"OpenAI API error: {e}")
|
||||
return f"OpenAI API error: Could not process the image. {e}"
|
||||
except Exception as e:
|
||||
return f"Unexpected error occurred while processing the image. {e}"
|
||||
|
||||
def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]:
|
||||
"""Process a batch of tasks and images"""
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(self.run, task, img)
|
||||
for task, img in tasks_images
|
||||
]
|
||||
results = [future.result() for future in futures]
|
||||
return results
|
||||
|
||||
async def run_batch_async(
|
||||
self, tasks_images: List[Tuple[str, str]]
|
||||
) -> List[str]:
|
||||
"""Process a batch of tasks and images asynchronously"""
|
||||
loop = asyncio.get_event_loop()
|
||||
futures = [
|
||||
loop.run_in_executor(None, self.run, task, img)
|
||||
for task, img in tasks_images
|
||||
]
|
||||
return await asyncio.gather(*futures)
|
||||
|
||||
async def run_batch_async_with_retries(
|
||||
self, tasks_images: List[Tuple[str, str]]
|
||||
) -> List[str]:
|
||||
"""Process a batch of tasks and images asynchronously with retries"""
|
||||
loop = asyncio.get_event_loop()
|
||||
futures = [
|
||||
loop.run_in_executor(None, self.run_with_retries, task, img)
|
||||
for task, img in tasks_images
|
||||
]
|
||||
return await asyncio.gather(*futures)
|
||||
|
||||
def print_dashboard(self):
|
||||
dashboard = print(
|
||||
colored(
|
||||
f"""
|
||||
GPT4Vision Dashboard
|
||||
-------------------
|
||||
Max Retries: {self.max_retries}
|
||||
Model: {self.model}
|
||||
Backoff Factor: {self.backoff_factor}
|
||||
Timeout Seconds: {self.timeout_seconds}
|
||||
Image Quality: {self.quality}
|
||||
Max Tokens: {self.max_tokens}
|
||||
|
||||
""",
|
||||
"green",
|
||||
)
|
||||
)
|
||||
return dashboard
|
||||
|
||||
def health_check(self):
|
||||
"""Health check for the GPT4Vision model"""
|
||||
try:
|
||||
response = requests.get("https://api.openai.com/v1/engines")
|
||||
return response.status_code == 200
|
||||
except requests.RequestException as error:
|
||||
print(f"Health check failed: {error}")
|
||||
return False
|
||||
|
||||
def sanitize_input(self, text: str) -> str:
|
||||
"""
|
||||
Sanitize input to prevent injection attacks.
|
||||
|
||||
Parameters:
|
||||
text: str - The input text to be sanitized.
|
||||
|
||||
Returns:
|
||||
The sanitized text.
|
||||
"""
|
||||
        # Example of simple sanitization; this should be expanded based on the context and usage
|
||||
sanitized_text = re.sub(r"[^\w\s]", "", text)
|
||||
return sanitized_text
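

# Illustrative usage sketch (commented out; not part of the original file).
# It assumes OPENAI_API_KEY is set in the environment and uses a placeholder
# image URL.
# gpt4vision = GPT4Vision()
# answer = gpt4vision.run(
#     "Describe what is shown in this image.",
#     "https://example.com/sample.png",
# )
# print(answer)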
|
@@ -0,0 +1,93 @@
|
||||
import os
|
||||
from typing import Callable, List
|
||||
|
||||
|
||||
class DialogueSimulator:
|
||||
"""
|
||||
Dialogue Simulator
|
||||
------------------
|
||||
|
||||
Args:
|
||||
------
|
||||
agents: List[Callable]
|
||||
max_iters: int
|
||||
name: str
|
||||
|
||||
Usage:
|
||||
------
|
||||
>>> from swarms import DialogueSimulator
|
||||
>>> from swarms.structs.flow import Flow
|
||||
>>> agents = Flow()
|
||||
>>> agents1 = Flow()
|
||||
>>> model = DialogueSimulator([agents, agents1], max_iters=10, name="test")
|
||||
>>> model.run("test")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, agents: List[Callable], max_iters: int = 10, name: str = None
|
||||
):
|
||||
self.agents = agents
|
||||
self.max_iters = max_iters
|
||||
self.name = name
|
||||
|
||||
def run(self, message: str = None):
|
||||
"""Run the dialogue simulator"""
|
||||
try:
|
||||
step = 0
|
||||
            prompt = message
            if self.name and message:
|
||||
prompt = f"Name {self.name} and message: {message}"
|
||||
for agent in self.agents:
|
||||
agent.run(prompt)
|
||||
step += 1
|
||||
|
||||
while step < self.max_iters:
|
||||
speaker_idx = step % len(self.agents)
|
||||
speaker = self.agents[speaker_idx]
|
||||
speaker_message = speaker.run(prompt)
|
||||
|
||||
for receiver in self.agents:
|
||||
message_history = (
|
||||
f"Speaker Name: {speaker.name} and message:"
|
||||
f" {speaker_message}"
|
||||
)
|
||||
receiver.run(message_history)
|
||||
|
||||
print(f"({speaker.name}): {speaker_message}")
|
||||
print("\n")
|
||||
step += 1
|
||||
except Exception as error:
|
||||
print(f"Error running dialogue simulator: {error}")
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"DialogueSimulator({self.agents}, {self.max_iters}, {self.name})"
|
||||
)
|
||||
|
||||
def save_state(self):
|
||||
"""Save the state of the dialogue simulator"""
|
||||
try:
|
||||
if self.name:
|
||||
filename = f"{self.name}.txt"
|
||||
with open(filename, "w") as file:
|
||||
file.write(str(self))
|
||||
except Exception as error:
|
||||
print(f"Error saving state: {error}")
|
||||
|
||||
def load_state(self):
|
||||
"""Load the state of the dialogue simulator"""
|
||||
try:
|
||||
if self.name:
|
||||
filename = f"{self.name}.txt"
|
||||
with open(filename, "r") as file:
|
||||
return file.read()
|
||||
except Exception as error:
|
||||
print(f"Error loading state: {error}")
|
||||
|
||||
def delete_state(self):
|
||||
"""Delete the state of the dialogue simulator"""
|
||||
try:
|
||||
if self.name:
|
||||
filename = f"{self.name}.txt"
|
||||
os.remove(filename)
|
||||
except Exception as error:
|
||||
print(f"Error deleting state: {error}")
|
@@ -0,0 +1,287 @@
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import chromadb
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
QUEUED = 1
|
||||
RUNNING = 2
|
||||
COMPLETED = 3
|
||||
FAILED = 4
|
||||
|
||||
|
||||
class Orchestrator:
|
||||
"""
|
||||
The Orchestrator takes in an agent, worker, or boss as input
|
||||
then handles all the logic for
|
||||
- task creation,
|
||||
- task assignment,
|
||||
    - and task completion.
|
||||
|
||||
    And the communication for millions of agents to chat with each other through
|
||||
a vector database that each agent has access to chat with.
|
||||
|
||||
Each LLM agent chats with the orchestrator through a dedicated
|
||||
communication layer. The orchestrator assigns tasks to each LLM agent,
|
||||
which the agents then complete and return.
|
||||
|
||||
This setup allows for a high degree of flexibility, scalability, and robustness.
|
||||
|
||||
In the context of swarm LLMs, one could consider an **Omni-Vector Embedding Database**
|
||||
for communication. This database could store and manage
|
||||
the high-dimensional vectors produced by each LLM agent.
|
||||
|
||||
Strengths: This approach would allow for similarity-based lookup and matching of
|
||||
LLM-generated vectors, which can be particularly useful for tasks that involve finding similar outputs or recognizing patterns.
|
||||
|
||||
Weaknesses: An Omni-Vector Embedding Database might add complexity to the system in terms of setup and maintenance.
|
||||
It might also require significant computational resources,
|
||||
depending on the volume of data being handled and the complexity of the vectors.
|
||||
The handling and transmission of high-dimensional vectors could also pose challenges
|
||||
in terms of network load.
|
||||
|
||||
# Orchestrator
|
||||
* Takes in an agent class with vector store,
|
||||
then handles all the communication and scales
|
||||
up a swarm with number of agents and handles task assignment and task completion
|
||||
|
||||
```
from swarms import OpenAI, Orchestrator, Swarm
|
||||
|
||||
orchestrated = Orchestrator(OpenAI, nodes=40)  # handles all the task assignment, allocation, and agent communication using a vectorstore as a universal communication layer, and also handles the task completion logic
|
||||
|
||||
Objective = "Make a business website for a marketing consultancy"
|
||||
|
||||
swarm = Swarm(orchestrated, Objective, auto=True)
|
||||
```
|
||||
|
||||
In terms of architecture, the swarm might look something like this:
|
||||
|
||||
```
|
||||
(Orchestrator)
|
||||
/ \
|
||||
Tools + Vector DB -- (LLM Agent)---(Communication Layer) (Communication Layer)---(LLM Agent)-- Tools + Vector DB
|
||||
/ | | \
|
||||
(Task Assignment) (Task Completion) (Task Assignment) (Task Completion)
|
||||
|
||||
|
||||
```

### Usage
|
||||
```
|
||||
from swarms import Orchestrator
|
||||
|
||||
# Instantiate the Orchestrator with 10 agents
|
||||
orchestrator = Orchestrator(llm, agent_list=[llm]*10, task_queue=[])
|
||||
|
||||
# Add tasks to the Orchestrator
|
||||
tasks = [{"content": f"Write a short story about a {animal}."} for animal in ["cat", "dog", "bird", "fish", "lion", "tiger", "elephant", "giraffe", "monkey", "zebra"]]
|
||||
orchestrator.assign_tasks(tasks)
|
||||
|
||||
# Run the Orchestrator
|
||||
orchestrator.run()
|
||||
|
||||
# Retrieve the results
|
||||
for task in tasks:
|
||||
print(orchestrator.retrieve_results(id(task)))
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent,
|
||||
agent_list: List[Any],
|
||||
task_queue: List[Any],
|
||||
collection_name: str = "swarm",
|
||||
api_key: str = None,
|
||||
model_name: str = None,
|
||||
embed_func=None,
|
||||
worker=None,
|
||||
):
|
||||
self.agent = agent
|
||||
self.agents = queue.Queue()
|
||||
|
||||
for _ in range(len(agent_list)):
|
||||
self.agents.put(agent())
|
||||
|
||||
self.task_queue = queue.Queue()
|
||||
|
||||
self.chroma_client = chromadb.Client()
|
||||
|
||||
self.collection = self.chroma_client.create_collection(
|
||||
name=collection_name
|
||||
)
|
||||
|
||||
self.current_tasks = {}
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.condition = threading.Condition(self.lock)
|
||||
self.executor = ThreadPoolExecutor(max_workers=len(agent_list))
|
||||
|
||||
# store the credentials, model name, and worker used by assign_task, embed, and chat
self.api_key = api_key
self.model_name = model_name
self.worker = worker
self.embed_func = embed_func if embed_func else self.embed
|
||||
|
||||
# @abstractmethod
|
||||
|
||||
def assign_task(self, agent_id: int, task: Dict[str, Any]) -> None:
|
||||
"""Assign a task to a specific agent"""
|
||||
|
||||
while True:
|
||||
with self.condition:
|
||||
while self.task_queue.empty():
|
||||
self.condition.wait()
|
||||
agent = self.agents.get()
|
||||
task = self.task_queue.get()
|
||||
|
||||
try:
|
||||
result = self.worker.run(task["content"])
|
||||
|
||||
# using the embed method to get the vector representation of the result
|
||||
vector_representation = self.embed(
|
||||
result, self.api_key, self.model_name
|
||||
)
|
||||
|
||||
self.collection.add(
|
||||
embeddings=[vector_representation],
|
||||
documents=[str(id(task))],
|
||||
ids=[str(id(task))],
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"Task {id(str)} has been processed by agent"
|
||||
f" {id(agent)} with"
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
logging.error(
|
||||
f"Failed to process task {id(task)} by agent {id(agent)}."
|
||||
f" Error: {error}"
|
||||
)
|
||||
finally:
|
||||
with self.condition:
|
||||
self.agents.put(agent)
|
||||
self.condition.notify()
|
||||
|
||||
def embed(self, input, api_key, model_name):
|
||||
openai = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=api_key, model_name=model_name
|
||||
)
|
||||
embedding = openai(input)
|
||||
return embedding
|
||||
|
||||
# @abstractmethod
|
||||
|
||||
def retrieve_results(self, agent_id: int) -> Any:
|
||||
"""Retrieve results from a specific agent"""
|
||||
|
||||
try:
|
||||
# Query the vector database for documents created by the agents
|
||||
results = self.collection.query(
|
||||
query_texts=[str(agent_id)], n_results=10
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Failed to retrieve results from agent {agent_id}. Error {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
# @abstractmethod
|
||||
def update_vector_db(self, data) -> None:
|
||||
"""Update the vector database"""
|
||||
|
||||
try:
|
||||
self.collection.add(
|
||||
embeddings=[data["vector"]],
|
||||
documents=[str(data["task_id"])],
|
||||
ids=[str(data["task_id"])],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to update the vector database. Error: {e}")
|
||||
raise
|
||||
|
||||
# @abstractmethod
|
||||
|
||||
def get_vector_db(self):
|
||||
"""Retrieve the vector database"""
|
||||
return self.collection
|
||||
|
||||
def append_to_db(self, result: str):
|
||||
"""append the result of the swarm to a specifici collection in the database"""
|
||||
|
||||
try:
|
||||
self.collection.add(documents=[result], ids=[str(id(result))])
|
||||
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Failed to append the agent output to database. Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def run(self, objective: str):
|
||||
"""Runs"""
|
||||
if not objective or not isinstance(objective, str):
|
||||
logging.error("Invalid objective")
|
||||
raise ValueError("A valid objective is required")
|
||||
|
||||
try:
|
||||
self.task_queue.put(objective)
|
||||
|
||||
results = [
|
||||
self.assign_task(agent_id, task)
|
||||
for agent_id, task in zip(
|
||||
range(self.agents.qsize()), self.task_queue.queue
|
||||
)
|
||||
]
|
||||
|
||||
for result in results:
|
||||
self.append_to_db(result)
|
||||
|
||||
logging.info(f"Successfully ran swarms with results: {results}")
|
||||
return results
|
||||
except Exception as e:
|
||||
logging.error(f"An error occured in swarm: {e}")
|
||||
return None
|
||||
|
||||
def chat(self, sender_id: int, receiver_id: int, message: str):
|
||||
"""
|
||||
|
||||
Allows the agents to chat with each other through the vector database.
|
||||
|
||||
# Instantiate the Orchestrator with 10 agents
|
||||
orchestrator = Orchestrator(
|
||||
llm,
|
||||
agent_list=[llm]*10,
|
||||
task_queue=[]
|
||||
)
|
||||
|
||||
# Agent 1 sends a message to Agent 2
|
||||
orchestrator.chat(sender_id=1, receiver_id=2, message="Hello, Agent 2!")
|
||||
|
||||
"""
|
||||
|
||||
message_vector = self.embed(message, self.api_key, self.model_name)
|
||||
|
||||
# store the message in the vector database
|
||||
self.collection.add(
|
||||
embeddings=[message_vector],
|
||||
documents=[message],
|
||||
ids=[f"{sender_id}_to_{receiver_id}"],
|
||||
)
|
||||
|
||||
self.run(objective=f"chat with agent {receiver_id} about {message}")
|
||||
|
||||
def add_agents(self, num_agents: int):
|
||||
for _ in range(num_agents):
|
||||
self.agents.put(self.agent())
|
||||
self.executor = ThreadPoolExecutor(max_workers=self.agents.qsize())
|
||||
|
||||
def remove_agents(self, num_agents):
|
||||
for _ in range(num_agents):
|
||||
if not self.agents.empty():
|
||||
self.agents.get()
|
||||
self.executor = ThreadPoolExecutor(max_workers=self.agents.qsize())
|
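The docstring above treats the vector store as the agents' shared message bus. The sketch below illustrates that idea with `chromadb` alone, independent of the `Orchestrator` class; the collection name and message text are placeholders, and the default embedding function is assumed to be acceptable.

```
import chromadb

client = chromadb.Client()
channel = client.create_collection(name="agent_channel")

# Agent 1 publishes a message; the document id encodes sender and receiver.
channel.add(documents=["Draft the landing page copy."], ids=["1_to_2"])

# Agent 2 retrieves the most relevant message by similarity search.
inbox = channel.query(query_texts=["landing page"], n_results=1)
print(inbox["documents"])
```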
@ -0,0 +1,200 @@
|
||||
import asyncio
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from langchain.agents import tool
|
||||
from langchain.agents.agent_toolkits.pandas.base import (
|
||||
create_pandas_dataframe_agent,
|
||||
)
|
||||
from langchain.chains.qa_with_sources.loading import (
|
||||
BaseCombineDocumentsChain,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.tools import BaseTool
|
||||
from PIL import Image
|
||||
from pydantic import Field
|
||||
from transformers import (
|
||||
BlipForQuestionAnswering,
|
||||
BlipProcessor,
|
||||
)
|
||||
|
||||
from swarms.utils.logger import logger
|
||||
|
||||
ROOT_DIR = "./data/"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pushd(new_dir):
|
||||
"""Context manager for changing the current working directory."""
|
||||
prev_dir = os.getcwd()
|
||||
os.chdir(new_dir)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.chdir(prev_dir)
|
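For example, assuming a `./data` directory exists, the working directory is restored even if the body raises:

```
import os

print(os.getcwd())        # e.g. /home/user/project
with pushd("./data"):
    print(os.getcwd())    # /home/user/project/data
print(os.getcwd())        # restored to /home/user/project
```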
||||
|
||||
|
||||
@tool
|
||||
def process_csv(
|
||||
llm,
|
||||
csv_file_path: str,
|
||||
instructions: str,
|
||||
output_path: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Process a CSV by with pandas in a limited REPL.\
|
||||
Only use this after writing data to disk as a csv file.\
|
||||
Any figures must be saved to disk to be viewed by the human.\
|
||||
Instructions should be written in natural language, not code. Assume the dataframe is already loaded."""
|
||||
with pushd(ROOT_DIR):
|
||||
try:
|
||||
df = pd.read_csv(csv_file_path)
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
agent = create_pandas_dataframe_agent(
|
||||
llm, df, max_iterations=30, verbose=False
|
||||
)
|
||||
if output_path is not None:
|
||||
instructions += f" Save output to disk at {output_path}"
|
||||
try:
|
||||
result = agent.run(instructions)
|
||||
return result
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
async def async_load_playwright(url: str) -> str:
|
||||
"""Load the specified URLs using Playwright and parse using BeautifulSoup."""
|
||||
from bs4 import BeautifulSoup
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
results = ""
|
||||
async with async_playwright() as p:
|
||||
browser = await p.chromium.launch(headless=True)
|
||||
try:
|
||||
page = await browser.new_page()
|
||||
await page.goto(url)
|
||||
|
||||
page_source = await page.content()
|
||||
soup = BeautifulSoup(page_source, "html.parser")
|
||||
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
|
||||
text = soup.get_text()
|
||||
lines = (line.strip() for line in text.splitlines())
|
||||
chunks = (
|
||||
phrase.strip() for line in lines for phrase in line.split(" ")
|
||||
)
|
||||
results = "\n".join(chunk for chunk in chunks if chunk)
|
||||
except Exception as e:
|
||||
results = f"Error: {e}"
|
||||
await browser.close()
|
||||
return results
|
||||
|
||||
|
||||
def run_async(coro):
|
||||
event_loop = asyncio.get_event_loop()
|
||||
return event_loop.run_until_complete(coro)
|
||||
|
||||
|
||||
@tool
|
||||
def browse_web_page(url: str) -> str:
|
||||
"""Verbose way to scrape a whole webpage. Likely to cause issues parsing."""
|
||||
return run_async(async_load_playwright(url))
|
||||
|
||||
|
||||
def _get_text_splitter():
|
||||
return RecursiveCharacterTextSplitter(
|
||||
# Set a really small chunk size, just to show.
|
||||
chunk_size=500,
|
||||
chunk_overlap=20,
|
||||
length_function=len,
|
||||
)
|
||||
|
||||
|
||||
class WebpageQATool(BaseTool):
|
||||
name = "query_webpage"
|
||||
description = (
|
||||
"Browse a webpage and retrieve the information relevant to the"
|
||||
" question."
|
||||
)
|
||||
text_splitter: RecursiveCharacterTextSplitter = Field(
|
||||
default_factory=_get_text_splitter
|
||||
)
|
||||
qa_chain: BaseCombineDocumentsChain
|
||||
|
||||
def _run(self, url: str, question: str) -> str:
|
||||
"""Useful for browsing websites and scraping the text information."""
|
||||
result = browse_web_page.run(url)
|
||||
docs = [Document(page_content=result, metadata={"source": url})]
|
||||
web_docs = self.text_splitter.split_documents(docs)
|
||||
results = []
|
||||
# TODO: Handle this with a MapReduceChain
|
||||
for i in range(0, len(web_docs), 4):
|
||||
input_docs = web_docs[i : i + 4]
|
||||
window_result = self.qa_chain(
|
||||
{"input_documents": input_docs, "question": question},
|
||||
return_only_outputs=True,
|
||||
)
|
||||
results.append(f"Response from window {i} - {window_result}")
|
||||
results_docs = [
|
||||
Document(page_content="\n".join(results), metadata={"source": url})
|
||||
]
|
||||
return self.qa_chain(
|
||||
{"input_documents": results_docs, "question": question},
|
||||
return_only_outputs=True,
|
||||
)
|
||||
|
||||
async def _arun(self, url: str, question: str) -> str:
|
||||
raise NotImplementedError
|
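A sketch of wiring the tool to a combine-documents chain. `load_qa_with_sources_chain`, `ChatOpenAI`, and the example URL are assumptions about the caller's setup, and `_run` is invoked directly purely for illustration.

```
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI(temperature=0)
query_website_tool = WebpageQATool(qa_chain=load_qa_with_sources_chain(llm))

# Scrapes the page, splits it into ~500-character chunks, and runs QA over 4-chunk windows.
answer = query_website_tool._run(
    url="https://en.wikipedia.org/wiki/Marathon",
    question="How long is a marathon?",
)
print(answer)
```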
||||
|
||||
|
||||
class EdgeGPTTool:
|
||||
# Initialize the custom tool
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
name="EdgeGPTTool",
|
||||
description="Tool that uses EdgeGPTModel to generate responses",
|
||||
):
|
||||
self.name = name
self.description = description
|
||||
self.model = model
|
||||
|
||||
def _run(self, prompt):
|
||||
return self.model.__call__(prompt)
|
||||
|
||||
|
||||
@tool
|
||||
def VQAinference(inputs):
|
||||
"""
|
||||
Answer Question About The Image - VQA Multi-Modal Worker agent.

Useful when you need an answer for a question based on an image, e.g. "what is the
background color of the last image", "how many cats are in this figure", "what is in
this figure". The input to this tool should be a comma separated string of two values,
representing the image_path and the question.
|
||||
|
||||
"""
|
||||
device = "cuda:0"
|
||||
torch_dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
||||
model = BlipForQuestionAnswering.from_pretrained(
|
||||
"Salesforce/blip-vqa-base", torch_dtype=torch_dtype
|
||||
).to(device)
|
||||
|
||||
image_path, question = inputs.split(",")
|
||||
raw_image = Image.open(image_path).convert("RGB")
|
||||
inputs = processor(raw_image, question, return_tensors="pt").to(
|
||||
device, torch_dtype
|
||||
)
|
||||
out = model.generate(**inputs)
|
||||
answer = processor.decode(out[0], skip_special_tokens=True)
|
||||
|
||||
logger.debug(
|
||||
f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input"
|
||||
f" Question: {question}, Output Answer: {answer}"
|
||||
)
|
||||
|
||||
return answer
|
@ -0,0 +1,284 @@
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
BlipForConditionalGeneration,
|
||||
BlipForQuestionAnswering,
|
||||
BlipProcessor,
|
||||
CLIPSegForImageSegmentation,
|
||||
CLIPSegProcessor,
|
||||
)
|
||||
|
||||
from swarms.prompts.prebuild.multi_modal_prompts import IMAGE_PROMPT
|
||||
from swarms.tools.tool import tool
|
||||
from swarms.utils.logger import logger
|
||||
from swarms.utils.main import BaseHandler, get_new_image_name
|
||||
|
||||
|
||||
class MaskFormer:
|
||||
def __init__(self, device):
|
||||
print("Initializing MaskFormer to %s" % device)
|
||||
self.device = device
|
||||
self.processor = CLIPSegProcessor.from_pretrained(
|
||||
"CIDAS/clipseg-rd64-refined"
|
||||
)
|
||||
self.model = CLIPSegForImageSegmentation.from_pretrained(
|
||||
"CIDAS/clipseg-rd64-refined"
|
||||
).to(device)
|
||||
|
||||
def inference(self, image_path, text):
|
||||
threshold = 0.5
|
||||
min_area = 0.02
|
||||
padding = 20
|
||||
original_image = Image.open(image_path)
|
||||
image = original_image.resize((512, 512))
|
||||
inputs = self.processor(
|
||||
text=text, images=image, padding="max_length", return_tensors="pt"
|
||||
).to(self.device)
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
|
||||
area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
|
||||
if area_ratio < min_area:
|
||||
return None
|
||||
true_indices = np.argwhere(mask)
|
||||
mask_array = np.zeros_like(mask, dtype=bool)
|
||||
for idx in true_indices:
|
||||
padded_slice = tuple(
|
||||
slice(max(0, i - padding), i + padding + 1) for i in idx
|
||||
)
|
||||
mask_array[padded_slice] = True
|
||||
visual_mask = (mask_array * 255).astype(np.uint8)
|
||||
image_mask = Image.fromarray(visual_mask)
|
||||
return image_mask.resize(original_image.size)
|
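A small sketch of calling the segmenter directly; the image path is a placeholder and `"cpu"` is used to avoid the CUDA/fp16 branch used elsewhere in this file.

```
mask_former = MaskFormer(device="cpu")

# Returns a PIL mask resized to the original image, or None when the matched
# region covers less than 2% of the 512x512 working resolution.
mask = mask_former.inference("photo.png", "the dog")
if mask is not None:
    mask.save("dog_mask.png")
```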
||||
|
||||
|
||||
class ImageEditing:
|
||||
def __init__(self, device):
|
||||
print("Initializing ImageEditing to %s" % device)
|
||||
self.device = device
|
||||
self.mask_former = MaskFormer(device=self.device)
|
||||
self.revision = "fp16" if "cuda" in device else None
|
||||
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
revision=self.revision,
|
||||
torch_dtype=self.torch_dtype,
|
||||
).to(device)
|
||||
|
||||
@tool(
|
||||
name="Remove Something From The Photo",
|
||||
description=(
|
||||
"useful when you want to remove and object or something from the"
|
||||
" photo from its description or location. The input to this tool"
|
||||
" should be a comma separated string of two, representing the"
|
||||
" image_path and the object need to be removed. "
|
||||
),
|
||||
)
|
||||
def inference_remove(self, inputs):
|
||||
image_path, to_be_removed_txt = inputs.split(",")
|
||||
return self.inference_replace(
|
||||
f"{image_path},{to_be_removed_txt},background"
|
||||
)
|
||||
|
||||
@tool(
|
||||
name="Replace Something From The Photo",
|
||||
description=(
|
||||
"useful when you want to replace an object from the object"
|
||||
" description or location with another object from its description."
|
||||
" The input to this tool should be a comma separated string of"
|
||||
" three, representing the image_path, the object to be replaced,"
|
||||
" the object to be replaced with "
|
||||
),
|
||||
)
|
||||
def inference_replace(self, inputs):
|
||||
image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
|
||||
original_image = Image.open(image_path)
|
||||
original_size = original_image.size
|
||||
mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
|
||||
updated_image = self.inpaint(
|
||||
prompt=replace_with_txt,
|
||||
image=original_image.resize((512, 512)),
|
||||
mask_image=mask_image.resize((512, 512)),
|
||||
).images[0]
|
||||
updated_image_path = get_new_image_name(
|
||||
image_path, func_name="replace-something"
|
||||
)
|
||||
updated_image = updated_image.resize(original_size)
|
||||
updated_image.save(updated_image_path)
|
||||
|
||||
logger.debug(
|
||||
f"\nProcessed ImageEditing, Input Image: {image_path}, Replace"
|
||||
f" {to_be_replaced_txt} to {replace_with_txt}, Output Image:"
|
||||
f" {updated_image_path}"
|
||||
)
|
||||
|
||||
return updated_image_path
|
||||
|
||||
|
||||
class InstructPix2Pix:
|
||||
def __init__(self, device):
|
||||
print("Initializing InstructPix2Pix to %s" % device)
|
||||
self.device = device
|
||||
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||
"timbrooks/instruct-pix2pix",
|
||||
safety_checker=None,
|
||||
torch_dtype=self.torch_dtype,
|
||||
).to(device)
|
||||
self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
self.pipe.scheduler.config
|
||||
)
|
||||
|
||||
@tool(
|
||||
name="Instruct Image Using Text",
|
||||
description=(
|
||||
"useful when you want to the style of the image to be like the"
|
||||
" text. like: make it look like a painting. or make it like a"
|
||||
" robot. The input to this tool should be a comma separated string"
|
||||
" of two, representing the image_path and the text. "
|
||||
),
|
||||
)
|
||||
def inference(self, inputs):
|
||||
"""Change style of image."""
|
||||
logger.debug("===> Starting InstructPix2Pix Inference")
|
||||
image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:])
|
||||
original_image = Image.open(image_path)
|
||||
image = self.pipe(
|
||||
text,
|
||||
image=original_image,
|
||||
num_inference_steps=40,
|
||||
image_guidance_scale=1.2,
|
||||
).images[0]
|
||||
updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
|
||||
image.save(updated_image_path)
|
||||
|
||||
logger.debug(
|
||||
f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct"
|
||||
f" Text: {text}, Output Image: {updated_image_path}"
|
||||
)
|
||||
|
||||
return updated_image_path
|
||||
|
||||
|
||||
class Text2Image:
|
||||
def __init__(self, device):
|
||||
print("Initializing Text2Image to %s" % device)
|
||||
self.device = device
|
||||
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
self.pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype
|
||||
)
|
||||
self.pipe.to(device)
|
||||
self.a_prompt = "best quality, extremely detailed"
|
||||
self.n_prompt = (
|
||||
"longbody, lowres, bad anatomy, bad hands, missing fingers, extra"
|
||||
" digit, fewer digits, cropped, worst quality, low quality"
|
||||
)
|
||||
|
||||
@tool(
|
||||
name="Generate Image From User Input Text",
|
||||
description=(
|
||||
"useful when you want to generate an image from a user input text"
|
||||
" and save it to a file. like: generate an image of an object or"
|
||||
" something, or generate an image that includes some objects. The"
|
||||
" input to this tool should be a string, representing the text used"
|
||||
" to generate image. "
|
||||
),
|
||||
)
|
||||
def inference(self, text):
|
||||
image_filename = os.path.join("image", str(uuid.uuid4())[0:8] + ".png")
|
||||
prompt = text + ", " + self.a_prompt
|
||||
image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
|
||||
image.save(image_filename)
|
||||
|
||||
logger.debug(
|
||||
f"\nProcessed Text2Image, Input Text: {text}, Output Image:"
|
||||
f" {image_filename}"
|
||||
)
|
||||
|
||||
return image_filename
|
||||
|
||||
|
||||
class VisualQuestionAnswering:
|
||||
def __init__(self, device):
|
||||
print("Initializing VisualQuestionAnswering to %s" % device)
|
||||
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
self.device = device
|
||||
self.processor = BlipProcessor.from_pretrained(
|
||||
"Salesforce/blip-vqa-base"
|
||||
)
|
||||
self.model = BlipForQuestionAnswering.from_pretrained(
|
||||
"Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype
|
||||
).to(self.device)
|
||||
|
||||
@tool(
|
||||
name="Answer Question About The Image",
|
||||
description=(
|
||||
"useful when you need an answer for a question based on an image."
|
||||
" like: what is the background color of the last image, how many"
|
||||
" cats in this figure, what is in this figure. The input to this"
|
||||
" tool should be a comma separated string of two, representing the"
|
||||
" image_path and the question"
|
||||
),
|
||||
)
|
||||
def inference(self, inputs):
|
||||
image_path, question = inputs.split(",")
|
||||
raw_image = Image.open(image_path).convert("RGB")
|
||||
inputs = self.processor(raw_image, question, return_tensors="pt").to(
|
||||
self.device, self.torch_dtype
|
||||
)
|
||||
out = self.model.generate(**inputs)
|
||||
answer = self.processor.decode(out[0], skip_special_tokens=True)
|
||||
|
||||
logger.debug(
|
||||
f"\nProcessed VisualQuestionAnswering, Input Image: {image_path},"
|
||||
f" Input Question: {question}, Output Answer: {answer}"
|
||||
)
|
||||
|
||||
return answer
|
||||
|
||||
|
||||
class ImageCaptioning(BaseHandler):
|
||||
def __init__(self, device):
|
||||
print("Initializing ImageCaptioning to %s" % device)
|
||||
self.device = device
|
||||
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
|
||||
self.processor = BlipProcessor.from_pretrained(
|
||||
"Salesforce/blip-image-captioning-base"
|
||||
)
|
||||
self.model = BlipForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip-image-captioning-base",
|
||||
torch_dtype=self.torch_dtype,
|
||||
).to(self.device)
|
||||
|
||||
def handle(self, filename: str):
|
||||
img = Image.open(filename)
|
||||
width, height = img.size
|
||||
ratio = min(512 / width, 512 / height)
|
||||
width_new, height_new = (round(width * ratio), round(height * ratio))
|
||||
img = img.resize((width_new, height_new))
|
||||
img = img.convert("RGB")
|
||||
img.save(filename, "PNG")
|
||||
print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
|
||||
|
||||
inputs = self.processor(Image.open(filename), return_tensors="pt").to(
|
||||
self.device, self.torch_dtype
|
||||
)
|
||||
out = self.model.generate(**inputs)
|
||||
description = self.processor.decode(out[0], skip_special_tokens=True)
|
||||
print(
|
||||
f"\nProcessed ImageCaptioning, Input Image: {filename}, Output"
|
||||
f" Text: {description}"
|
||||
)
|
||||
|
||||
return IMAGE_PROMPT.format(filename=filename, description=description)
|
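A sketch of the handler in isolation; `photo.png` is a placeholder, and note that `handle` resizes and overwrites the file in place before captioning, as the code above shows.

```
captioner = ImageCaptioning(device="cpu")

# Captions the (resized) image with BLIP and returns IMAGE_PROMPT formatted
# with the filename and the generated description.
prompt_snippet = captioner.handle("photo.png")
print(prompt_snippet)
```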
@ -0,0 +1,12 @@
|
||||
from concurrent import futures
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]:
|
||||
futures.wait(
|
||||
fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED
|
||||
)
|
||||
|
||||
return {key: future.result() for key, future in fs_dict.items()}
|
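A quick usage sketch with a thread pool; the `slow_square` worker is a placeholder:

```
import time
from concurrent.futures import ThreadPoolExecutor


def slow_square(x: int) -> int:
    time.sleep(0.1)
    return x * x


with ThreadPoolExecutor(max_workers=2) as pool:
    fs = {
        "four": pool.submit(slow_square, 2),
        "nine": pool.submit(slow_square, 3),
    }
    results = execute_futures_dict(fs)

print(results)  # {'four': 4, 'nine': 9}
```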
@ -0,0 +1,40 @@
|
||||
import pytest
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
|
||||
from swarms.agents.omni_modal_agent import (
|
||||
OmniModalAgent,
|
||||
)
|
||||
|
||||
|
||||
# Mock objects or set up fixtures for dependent classes or external methods
|
||||
@pytest.fixture
|
||||
def mock_llm():
|
||||
# For this mock, we are assuming the BaseLanguageModel has a method named "process"
|
||||
class MockLLM(BaseLanguageModel):
|
||||
def process(self, input):
|
||||
return "mock response"
|
||||
|
||||
return MockLLM()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def omni_agent(mock_llm):
|
||||
return OmniModalAgent(mock_llm)
|
||||
|
||||
|
||||
def test_omnimodalagent_initialization(omni_agent):
|
||||
assert omni_agent.llm is not None, "LLM initialization failed"
|
||||
assert len(omni_agent.tools) > 0, "Tools initialization failed"
|
||||
|
||||
|
||||
def test_omnimodalagent_run(omni_agent):
|
||||
input_string = "Hello, how are you?"
|
||||
response = omni_agent.run(input_string)
|
||||
assert response is not None, "Response generation failed"
|
||||
assert isinstance(response, str), "Response should be a string"
|
||||
|
||||
|
||||
def test_task_executor_initialization(omni_agent):
|
||||
assert (
|
||||
omni_agent.task_executor is not None
|
||||
), "TaskExecutor initialization failed"
|
@ -0,0 +1,97 @@
|
||||
import pytest
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
from swarms.memory.ocean import OceanDB
|
||||
|
||||
|
||||
def test_init():
|
||||
with patch("oceandb.Client") as MockClient:
|
||||
MockClient.return_value.heartbeat.return_value = "OK"
|
||||
db = OceanDB(MockClient)
|
||||
MockClient.assert_called_once()
|
||||
assert db.client == MockClient
|
||||
|
||||
|
||||
def test_init_exception():
|
||||
with patch("oceandb.Client") as MockClient:
|
||||
MockClient.side_effect = Exception("Client error")
|
||||
with pytest.raises(Exception) as e:
|
||||
OceanDB(MockClient)
|
||||
assert str(e.value) == "Client error"
|
||||
|
||||
|
||||
def test_create_collection():
|
||||
with patch("oceandb.Client") as MockClient:
|
||||
db = OceanDB(MockClient)
|
||||
db.create_collection("test", "modality")
|
||||
MockClient.create_collection.assert_called_once_with(
|
||||
"test", embedding_function=Mock.ANY
|
||||
)
|
||||
|
||||
|
||||
def test_create_collection_exception():
|
||||
with patch("oceandb.Client") as MockClient:
|
||||
MockClient.create_collection.side_effect = Exception(
|
||||
"Create collection error"
|
||||
)
|
||||
db = OceanDB(MockClient)
|
||||
with pytest.raises(Exception) as e:
|
||||
db.create_collection("test", "modality")
|
||||
assert str(e.value) == "Create collection error"
|
||||
|
||||
|
||||
def test_append_document():
|
||||
with patch("oceandb.Client") as MockClient:
|
||||
db = OceanDB(MockClient)
|
||||
collection = Mock()
|
||||
db.append_document(collection, "doc", "id")
|
||||
collection.add.assert_called_once_with(documents=["doc"], ids=["id"])
|
||||
|
||||
|
||||
def test_append_document_exception():
|
||||
with patch("oceandb.Client") as MockClient:
|
||||
db = OceanDB(MockClient)
|
||||
collection = Mock()
|
||||
collection.add.side_effect = Exception("Append document error")
|
||||
with pytest.raises(Exception) as e:
|
||||
db.append_document(collection, "doc", "id")
|
||||
assert str(e.value) == "Append document error"
|
||||
|
||||
|
||||
def test_add_documents():
|
||||
with patch("oceandb.Client") as MockClient:
|
||||
db = OceanDB(MockClient)
|
||||
collection = Mock()
|
||||
db.add_documents(collection, ["doc1", "doc2"], ["id1", "id2"])
|
||||
collection.add.assert_called_once_with(
|
||||
documents=["doc1", "doc2"], ids=["id1", "id2"]
|
||||
)
|
||||
|
||||
|
||||
def test_add_documents_exception():
|
||||
with patch("oceandb.Client") as MockClient:
|
||||
db = OceanDB(MockClient)
|
||||
collection = Mock()
|
||||
collection.add.side_effect = Exception("Add documents error")
|
||||
with pytest.raises(Exception) as e:
|
||||
db.add_documents(collection, ["doc1", "doc2"], ["id1", "id2"])
|
||||
assert str(e.value) == "Add documents error"
|
||||
|
||||
|
||||
def test_query():
|
||||
with patch("oceandb.Client") as MockClient:
|
||||
db = OceanDB(MockClient)
|
||||
collection = Mock()
|
||||
db.query(collection, ["query1", "query2"], 2)
|
||||
collection.query.assert_called_once_with(
|
||||
query_texts=["query1", "query2"], n_results=2
|
||||
)
|
||||
|
||||
|
||||
def test_query_exception():
|
||||
with patch("oceandb.Client") as MockClient:
|
||||
db = OceanDB(MockClient)
|
||||
collection = Mock()
|
||||
collection.query.side_effect = Exception("Query error")
|
||||
with pytest.raises(Exception) as e:
|
||||
db.query(collection, ["query1", "query2"], 2)
|
||||
assert str(e.value) == "Query error"
|
@ -0,0 +1,61 @@
|
||||
import unittest
|
||||
import json
|
||||
import os
|
||||
|
||||
# Assuming the BingChat class is in a file named "bing_chat.py"
|
||||
from bing_chat import BingChat
|
||||
|
||||
|
||||
class TestBingChat(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Path to a mock cookies file for testing
|
||||
self.mock_cookies_path = "./mock_cookies.json"
|
||||
with open(self.mock_cookies_path, "w") as file:
|
||||
json.dump({"mock_cookie": "mock_value"}, file)
|
||||
|
||||
self.chat = BingChat(cookies_path=self.mock_cookies_path)
|
||||
|
||||
def tearDown(self):
|
||||
os.remove(self.mock_cookies_path)
|
||||
|
||||
def test_init(self):
|
||||
self.assertIsInstance(self.chat, BingChat)
|
||||
self.assertIsNotNone(self.chat.bot)
|
||||
|
||||
def test_call(self):
|
||||
# Mocking the asynchronous behavior for the purpose of the test
|
||||
self.chat.bot.ask = lambda *args, **kwargs: {"text": "Hello, Test!"}
|
||||
response = self.chat("Test prompt")
|
||||
self.assertEqual(response, "Hello, Test!")
|
||||
|
||||
def test_create_img(self):
|
||||
# Mocking the ImageGen behavior for the purpose of the test
|
||||
class MockImageGen:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def get_images(self, *args, **kwargs):
|
||||
return [{"path": "mock_image.png"}]
|
||||
|
||||
@staticmethod
|
||||
def save_images(*args, **kwargs):
|
||||
pass
|
||||
|
||||
original_image_gen = BingChat.ImageGen
|
||||
BingChat.ImageGen = MockImageGen
|
||||
|
||||
img_path = self.chat.create_img(
|
||||
"Test prompt", auth_cookie="mock_auth_cookie"
|
||||
)
|
||||
self.assertEqual(img_path, "./output/mock_image.png")
|
||||
|
||||
BingChat.ImageGen = original_image_gen
|
||||
|
||||
def test_set_cookie_dir_path(self):
|
||||
test_path = "./test_path"
|
||||
BingChat.set_cookie_dir_path(test_path)
|
||||
self.assertEqual(BingChat.Cookie.dir_path, test_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -0,0 +1,414 @@
|
||||
import logging
|
||||
import os
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from requests.exceptions import (
|
||||
ConnectionError,
|
||||
HTTPError,
|
||||
RequestException,
|
||||
Timeout,
|
||||
)
|
||||
|
||||
from swarms.models.gpt4v import GPT4Vision, GPT4VisionResponse
|
||||
|
||||
load_dotenv()
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
# Mock the OpenAI client
|
||||
@pytest.fixture
|
||||
def mock_openai_client():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gpt4vision(mock_openai_client):
|
||||
return GPT4Vision(client=mock_openai_client)
|
||||
|
||||
|
||||
def test_gpt4vision_default_values():
|
||||
# Arrange and Act
|
||||
gpt4vision = GPT4Vision()
|
||||
|
||||
# Assert
|
||||
assert gpt4vision.max_retries == 3
|
||||
assert gpt4vision.model == "gpt-4-vision-preview"
|
||||
assert gpt4vision.backoff_factor == 2.0
|
||||
assert gpt4vision.timeout_seconds == 10
|
||||
assert gpt4vision.api_key is None
|
||||
assert gpt4vision.quality == "low"
|
||||
assert gpt4vision.max_tokens == 200
|
||||
|
||||
|
||||
def test_gpt4vision_api_key_from_env_variable():
|
||||
# Arrange
|
||||
api_key = os.environ["OPENAI_API_KEY"]
|
||||
|
||||
# Act
|
||||
gpt4vision = GPT4Vision()
|
||||
|
||||
# Assert
|
||||
assert gpt4vision.api_key == api_key
|
||||
|
||||
|
||||
def test_gpt4vision_set_api_key():
|
||||
# Arrange
|
||||
gpt4vision = GPT4Vision(api_key=api_key)
|
||||
|
||||
# Assert
|
||||
assert gpt4vision.api_key == api_key
|
||||
|
||||
|
||||
def test_gpt4vision_invalid_max_retries():
|
||||
# Arrange and Act
|
||||
with pytest.raises(ValueError):
|
||||
GPT4Vision(max_retries=-1)
|
||||
|
||||
|
||||
def test_gpt4vision_invalid_backoff_factor():
|
||||
# Arrange and Act
|
||||
with pytest.raises(ValueError):
|
||||
GPT4Vision(backoff_factor=-1)
|
||||
|
||||
|
||||
def test_gpt4vision_invalid_timeout_seconds():
|
||||
# Arrange and Act
|
||||
with pytest.raises(ValueError):
|
||||
GPT4Vision(timeout_seconds=-1)
|
||||
|
||||
|
||||
def test_gpt4vision_invalid_max_tokens():
|
||||
# Arrange and Act
|
||||
with pytest.raises(ValueError):
|
||||
GPT4Vision(max_tokens=-1)
|
||||
|
||||
|
||||
def test_gpt4vision_logger_initialized():
|
||||
# Arrange
|
||||
gpt4vision = GPT4Vision()
|
||||
|
||||
# Assert
|
||||
assert isinstance(gpt4vision.logger, logging.Logger)
|
||||
|
||||
|
||||
def test_gpt4vision_process_img_nonexistent_file():
|
||||
# Arrange
|
||||
gpt4vision = GPT4Vision()
|
||||
img_path = "nonexistent_image.jpg"
|
||||
|
||||
# Act and Assert
|
||||
with pytest.raises(FileNotFoundError):
|
||||
gpt4vision.process_img(img_path)
|
||||
|
||||
|
||||
def test_gpt4vision_call_single_task_single_image_no_openai_client(gpt4vision):
|
||||
# Arrange
|
||||
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
task = "Describe this image."
|
||||
|
||||
# Act and Assert
|
||||
with pytest.raises(AttributeError):
|
||||
gpt4vision(img_url, [task])
|
||||
|
||||
|
||||
def test_gpt4vision_call_single_task_single_image_empty_response(
|
||||
gpt4vision, mock_openai_client
|
||||
):
|
||||
# Arrange
|
||||
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
task = "Describe this image."
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value.choices = []
|
||||
|
||||
# Act
|
||||
response = gpt4vision(img_url, [task])
|
||||
|
||||
# Assert
|
||||
assert response.answer == ""
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
|
||||
|
||||
def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(
|
||||
gpt4vision, mock_openai_client
|
||||
):
|
||||
# Arrange
|
||||
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
tasks = ["Describe this image.", "What's in this picture?"]
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value.choices = []
|
||||
|
||||
# Act
|
||||
responses = gpt4vision(img_url, tasks)
|
||||
|
||||
# Assert
|
||||
assert all(response.answer == "" for response in responses)
|
||||
assert (
|
||||
mock_openai_client.chat.completions.create.call_count == 1
|
||||
) # Should be called only once
|
||||
|
||||
|
||||
def test_gpt4vision_call_single_task_single_image_timeout(
|
||||
gpt4vision, mock_openai_client
|
||||
):
|
||||
# Arrange
|
||||
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
task = "Describe this image."
|
||||
|
||||
mock_openai_client.chat.completions.create.side_effect = Timeout(
|
||||
"Request timed out"
|
||||
)
|
||||
|
||||
# Act and Assert
|
||||
with pytest.raises(Timeout):
|
||||
gpt4vision(img_url, [task])
|
||||
|
||||
|
||||
def test_gpt4vision_call_retry_with_success_after_timeout(
|
||||
gpt4vision, mock_openai_client
|
||||
):
|
||||
# Arrange
|
||||
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
task = "Describe this image."
|
||||
|
||||
# Simulate success after a timeout and retry
|
||||
mock_openai_client.chat.completions.create.side_effect = [
|
||||
Timeout("Request timed out"),
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": {"text": "A description of the image."}
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
# Act
|
||||
response = gpt4vision(img_url, [task])
|
||||
|
||||
# Assert
|
||||
assert response.answer == "A description of the image."
|
||||
assert (
|
||||
mock_openai_client.chat.completions.create.call_count == 2
|
||||
) # Should be called twice
|
||||
|
||||
|
||||
def test_gpt4vision_process_img():
|
||||
# Arrange
|
||||
img_path = "test_image.jpg"
|
||||
gpt4vision = GPT4Vision()
|
||||
|
||||
# Act
|
||||
img_data = gpt4vision.process_img(img_path)
|
||||
|
||||
# Assert
|
||||
assert img_data.startswith("/9j/") # Base64-encoded image data
|
||||
|
||||
|
||||
def test_gpt4vision_call_single_task_single_image(
|
||||
gpt4vision, mock_openai_client
|
||||
):
|
||||
# Arrange
|
||||
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
task = "Describe this image."
|
||||
|
||||
expected_response = GPT4VisionResponse(answer="A description of the image.")
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value.choices[0].text = (
|
||||
expected_response.answer
|
||||
)
|
||||
|
||||
# Act
|
||||
response = gpt4vision(img_url, [task])
|
||||
|
||||
# Assert
|
||||
assert response == expected_response
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
|
||||
|
||||
def test_gpt4vision_call_single_task_multiple_images(
|
||||
gpt4vision, mock_openai_client
|
||||
):
|
||||
# Arrange
|
||||
img_urls = [
|
||||
"https://example.com/image1.jpg",
|
||||
"https://example.com/image2.jpg",
|
||||
]
|
||||
task = "Describe these images."
|
||||
|
||||
expected_response = GPT4VisionResponse(answer="Descriptions of the images.")
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value.choices[0].text = (
|
||||
expected_response.answer
|
||||
)
|
||||
|
||||
# Act
|
||||
response = gpt4vision(img_urls, [task])
|
||||
|
||||
# Assert
|
||||
assert response == expected_response
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
|
||||
|
||||
def test_gpt4vision_call_multiple_tasks_single_image(
|
||||
gpt4vision, mock_openai_client
|
||||
):
|
||||
# Arrange
|
||||
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
tasks = ["Describe this image.", "What's in this picture?"]
|
||||
|
||||
expected_responses = [
|
||||
GPT4VisionResponse(answer="A description of the image."),
|
||||
GPT4VisionResponse(answer="It contains various objects."),
|
||||
]
|
||||
|
||||
def create_mock_response(response):
|
||||
return {
|
||||
"choices": [{"message": {"content": {"text": response.answer}}}]
|
||||
}
|
||||
|
||||
mock_openai_client.chat.completions.create.side_effect = [
|
||||
create_mock_response(response) for response in expected_responses
|
||||
]
|
||||
|
||||
# Act
|
||||
responses = gpt4vision(img_url, tasks)
|
||||
|
||||
# Assert
|
||||
assert responses == expected_responses
|
||||
assert (
|
||||
mock_openai_client.chat.completions.create.call_count == 1
|
||||
) # Should be called only once
|
||||
|
||||
def test_gpt4vision_call_multiple_tasks_single_image_dict_responses(
|
||||
gpt4vision, mock_openai_client
|
||||
):
|
||||
# Arrange
|
||||
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
tasks = ["Describe this image.", "What's in this picture?"]
|
||||
|
||||
expected_responses = [
|
||||
GPT4VisionResponse(answer="A description of the image."),
|
||||
GPT4VisionResponse(answer="It contains various objects."),
|
||||
]
|
||||
|
||||
mock_openai_client.chat.completions.create.side_effect = [
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": {"text": expected_responses[i].answer}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
for i in range(len(expected_responses))
|
||||
]
|
||||
|
||||
# Act
|
||||
responses = gpt4vision(img_url, tasks)
|
||||
|
||||
# Assert
|
||||
assert responses == expected_responses
|
||||
assert (
|
||||
mock_openai_client.chat.completions.create.call_count == 1
|
||||
) # Should be called only once
|
||||
|
||||
|
||||
def test_gpt4vision_call_multiple_tasks_multiple_images(
|
||||
gpt4vision, mock_openai_client
|
||||
):
|
||||
# Arrange
|
||||
img_urls = [
|
||||
"https://images.unsplash.com/photo-1694734479857-626882b6db37?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
|
||||
"https://images.unsplash.com/photo-1694734479898-6ac4633158ac?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
|
||||
]
|
||||
tasks = ["Describe these images.", "What's in these pictures?"]
|
||||
|
||||
expected_responses = [
|
||||
GPT4VisionResponse(answer="Descriptions of the images."),
|
||||
GPT4VisionResponse(answer="They contain various objects."),
|
||||
]
|
||||
|
||||
mock_openai_client.chat.completions.create.side_effect = [
|
||||
{"choices": [{"message": {"content": {"text": response.answer}}}]}
|
||||
for response in expected_responses
|
||||
]
|
||||
|
||||
# Act
|
||||
responses = gpt4vision(img_urls, tasks)
|
||||
|
||||
# Assert
|
||||
assert responses == expected_responses
|
||||
assert (
|
||||
mock_openai_client.chat.completions.create.call_count == 1
|
||||
) # Should be called only once
|
||||
|
||||
|
||||
def test_gpt4vision_call_http_error(gpt4vision, mock_openai_client):
|
||||
# Arrange
|
||||
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
task = "Describe this image."
|
||||
|
||||
mock_openai_client.chat.completions.create.side_effect = HTTPError(
|
||||
"HTTP Error"
|
||||
)
|
||||
|
||||
# Act and Assert
|
||||
with pytest.raises(HTTPError):
|
||||
gpt4vision(img_url, [task])
|
||||
|
||||
|
||||
def test_gpt4vision_call_request_error(gpt4vision, mock_openai_client):
|
||||
# Arrange
|
||||
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
task = "Describe this image."
|
||||
|
||||
mock_openai_client.chat.completions.create.side_effect = RequestException(
|
||||
"Request Error"
|
||||
)
|
||||
|
||||
# Act and Assert
|
||||
with pytest.raises(RequestException):
|
||||
gpt4vision(img_url, [task])
|
||||
|
||||
|
||||
def test_gpt4vision_call_connection_error(gpt4vision, mock_openai_client):
|
||||
# Arrange
|
||||
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
task = "Describe this image."
|
||||
|
||||
mock_openai_client.chat.completions.create.side_effect = ConnectionError(
|
||||
"Connection Error"
|
||||
)
|
||||
|
||||
# Act and Assert
|
||||
with pytest.raises(ConnectionError):
|
||||
gpt4vision(img_url, [task])
|
||||
|
||||
|
||||
def test_gpt4vision_call_retry_with_success(gpt4vision, mock_openai_client):
|
||||
# Arrange
|
||||
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
task = "Describe this image."
|
||||
|
||||
# Simulate success after a retry
|
||||
mock_openai_client.chat.completions.create.side_effect = [
|
||||
RequestException("Temporary error"),
|
||||
{
|
||||
"choices": [{"text": "A description of the image."}]
|
||||
}, # fixed dictionary syntax
|
||||
]
|
||||
|
||||
# Act
|
||||
response = gpt4vision(img_url, [task])
|
||||
|
||||
# Assert
|
||||
assert response.answer == "A description of the image."
|
||||
assert (
|
||||
mock_openai_client.chat.completions.create.call_count == 2
|
||||
) # Should be called twice
|
@ -0,0 +1,90 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from swarms.models.revgptv1 import RevChatGPTModelv1
|
||||
|
||||
|
||||
class TestRevChatGPT(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.access_token = "<your_access_token>"
|
||||
self.model = RevChatGPTModelv1(access_token=self.access_token)
|
||||
|
||||
def test_run(self):
|
||||
prompt = "What is the capital of France?"
|
||||
response = self.model.run(prompt)
|
||||
self.assertEqual(response, "The capital of France is Paris.")
|
||||
|
||||
def test_run_time(self):
|
||||
prompt = "Generate a 300 word essay about technology."
|
||||
self.model.run(prompt)
|
||||
self.assertLess(self.model.end_time - self.model.start_time, 60)
|
||||
|
||||
def test_generate_summary(self):
|
||||
text = (
|
||||
"This is a sample text to summarize. It has multiple sentences and"
|
||||
" details. The summary should be concise."
|
||||
)
|
||||
summary = self.model.generate_summary(text)
|
||||
self.assertLess(len(summary), len(text) / 2)
|
||||
|
||||
def test_enable_plugin(self):
|
||||
plugin_id = "some_plugin_id"
|
||||
self.model.enable_plugin(plugin_id)
|
||||
self.assertIn(plugin_id, self.model.config["plugin_ids"])
|
||||
|
||||
def test_list_plugins(self):
|
||||
plugins = self.model.list_plugins()
|
||||
self.assertGreater(len(plugins), 0)
|
||||
self.assertIsInstance(plugins[0], dict)
|
||||
self.assertIn("id", plugins[0])
|
||||
self.assertIn("name", plugins[0])
|
||||
|
||||
def test_get_conversations(self):
|
||||
conversations = self.model.chatbot.get_conversations()
|
||||
self.assertIsInstance(conversations, list)
|
||||
|
||||
@patch("RevChatGPTModelv1.Chatbot.get_msg_history")
|
||||
def test_get_msg_history(self, mock_get_msg_history):
|
||||
conversation_id = "convo_id"
|
||||
self.model.chatbot.get_msg_history(conversation_id)
|
||||
mock_get_msg_history.assert_called_with(conversation_id)
|
||||
|
||||
@patch("RevChatGPTModelv1.Chatbot.share_conversation")
|
||||
def test_share_conversation(self, mock_share_conversation):
|
||||
self.model.chatbot.share_conversation()
|
||||
mock_share_conversation.assert_called()
|
||||
|
||||
def test_gen_title(self):
|
||||
convo_id = "123"
|
||||
message_id = "456"
|
||||
title = self.model.chatbot.gen_title(convo_id, message_id)
|
||||
self.assertIsInstance(title, str)
|
||||
|
||||
def test_change_title(self):
|
||||
convo_id = "123"
|
||||
title = "New Title"
|
||||
self.model.chatbot.change_title(convo_id, title)
|
||||
self.assertEqual(
|
||||
self.model.chatbot.get_msg_history(convo_id)["title"], title
|
||||
)
|
||||
|
||||
def test_delete_conversation(self):
|
||||
convo_id = "123"
|
||||
self.model.chatbot.delete_conversation(convo_id)
|
||||
with self.assertRaises(Exception):
|
||||
self.model.chatbot.get_msg_history(convo_id)
|
||||
|
||||
def test_clear_conversations(self):
|
||||
self.model.chatbot.clear_conversations()
|
||||
conversations = self.model.chatbot.get_conversations()
|
||||
self.assertEqual(len(conversations), 0)
|
||||
|
||||
def test_rollback_conversation(self):
|
||||
original_convo_id = self.model.chatbot.conversation_id
|
||||
self.model.chatbot.rollback_conversation(1)
|
||||
self.assertNotEqual(
|
||||
original_convo_id, self.model.chatbot.conversation_id
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -0,0 +1,66 @@
|
||||
from unittest.mock import patch
|
||||
from swarms.swarms.multi_agent_debate import (
|
||||
MultiAgentDebate,
|
||||
Worker,
|
||||
select_speaker,
|
||||
)
|
||||
|
||||
|
||||
def test_multiagentdebate_initialization():
|
||||
multiagentdebate = MultiAgentDebate(
|
||||
agents=[Worker] * 5, selection_func=select_speaker
|
||||
)
|
||||
assert isinstance(multiagentdebate, MultiAgentDebate)
|
||||
assert len(multiagentdebate.agents) == 5
|
||||
assert multiagentdebate.selection_func == select_speaker
|
||||
|
||||
|
||||
@patch("swarms.workers.Worker.reset")
|
||||
def test_multiagentdebate_reset_agents(mock_reset):
|
||||
multiagentdebate = MultiAgentDebate(
|
||||
agents=[Worker] * 5, selection_func=select_speaker
|
||||
)
|
||||
multiagentdebate.reset_agents()
|
||||
assert mock_reset.call_count == 5
|
||||
|
||||
|
||||
def test_multiagentdebate_inject_agent():
|
||||
multiagentdebate = MultiAgentDebate(
|
||||
agents=[Worker] * 5, selection_func=select_speaker
|
||||
)
|
||||
multiagentdebate.inject_agent(Worker)
|
||||
assert len(multiagentdebate.agents) == 6
|
||||
|
||||
|
||||
@patch("swarms.workers.Worker.run")
|
||||
def test_multiagentdebate_run(mock_run):
|
||||
multiagentdebate = MultiAgentDebate(
|
||||
agents=[Worker] * 5, selection_func=select_speaker
|
||||
)
|
||||
results = multiagentdebate.run("Write a short story.")
|
||||
assert len(results) == 5
|
||||
assert mock_run.call_count == 5
|
||||
|
||||
|
||||
def test_multiagentdebate_update_task():
|
||||
multiagentdebate = MultiAgentDebate(
|
||||
agents=[Worker] * 5, selection_func=select_speaker
|
||||
)
|
||||
multiagentdebate.update_task("Write a short story.")
|
||||
assert multiagentdebate.task == "Write a short story."
|
||||
|
||||
|
||||
def test_multiagentdebate_format_results():
|
||||
multiagentdebate = MultiAgentDebate(
|
||||
agents=[Worker] * 5, selection_func=select_speaker
|
||||
)
|
||||
results = [
|
||||
{"agent": "Agent 1", "response": "Hello, world!"},
|
||||
{"agent": "Agent 2", "response": "Goodbye, world!"},
|
||||
]
|
||||
formatted_results = multiagentdebate.format_results(results)
|
||||
assert (
|
||||
formatted_results
|
||||
== "Agent Agent 1 responded: Hello, world!\nAgent Agent 2 responded:"
|
||||
" Goodbye, world!"
|
||||
)
|