parent 02cecfc281
commit 498ff905b0
@@ -0,0 +1,28 @@
import os

from dotenv import load_dotenv

from swarms.models.revgptV4 import RevChatGPTModel
from swarms.workers.worker import Worker

load_dotenv()

config = {
    "model": os.getenv("REVGPT_MODEL"),
    "plugin_ids": [os.getenv("REVGPT_PLUGIN_IDS")],
    "disable_history": os.getenv("REVGPT_DISABLE_HISTORY") == "True",
    "PUID": os.getenv("REVGPT_PUID"),
    "unverified_plugin_domains": [
        os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS")
    ],
}

llm = RevChatGPTModel(access_token=os.getenv("ACCESS_TOKEN"), **config)

worker = Worker(ai_name="Optimus Prime", llm=llm)

task = (
    "What were the winning boston marathon times for the past 5 years (ending"
    " in 2022)? Generate a table of the year, name, country of origin, and"
    " times."
)
response = worker.run(task)
print(response)
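Note: os.getenv returns a single string, so "plugin_ids" and "unverified_plugin_domains" above are always one-element lists. If several values ever need to ride in one variable, a comma-separated split is a common pattern — a sketch of that, not something the script currently does:

# Assumption: the env vars may hold comma-separated lists.
plugin_ids = (os.getenv("REVGPT_PLUGIN_IDS") or "").split(",")
unverified_domains = (
    os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS") or ""
).split(",")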
@@ -0,0 +1,120 @@
import os
from typing import List

from dotenv import load_dotenv

from swarms.models import Anthropic, OpenAIChat
from swarms.prompts.accountant_swarm_prompts import (
    DECISION_MAKING_PROMPT,
    DOC_ANALYZER_AGENT_PROMPT,
    FRAUD_DETECTION_AGENT_PROMPT,
    SUMMARY_GENERATOR_AGENT_PROMPT,
)
from swarms.structs import Flow
from swarms.utils.pdf_to_text import pdf_to_text

# Environment variables
load_dotenv()
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")


# Base llms
llm1 = OpenAIChat(
    openai_api_key=openai_api_key,
)

llm2 = Anthropic(
    anthropic_api_key=anthropic_api_key,
)


# Agents
doc_analyzer_agent = Flow(
    llm=llm1,
    sop=DOC_ANALYZER_AGENT_PROMPT,
)
summary_generator_agent = Flow(
    llm=llm2,
    sop=SUMMARY_GENERATOR_AGENT_PROMPT,
)
decision_making_support_agent = Flow(
    llm=llm2,
    sop=DECISION_MAKING_PROMPT,
)


class AccountantSwarms:
    """
    Accountant Swarms is a collection of agents that work together to help
    accountants with their work.

    Flow: analyze doc -> detect fraud -> generate summary -> decision making support

    The agents are:
    - User Consultant: Asks the user many questions
    - Document Analyzer: Extracts text from the image of the financial document
    - Fraud Detection: Detects fraud in the document
    - Summary Agent: Generates an actionable summary of the document
    - Decision Making Support: Provides decision making support to the accountant

    The agents are connected together in a workflow that is defined in the
    run method.

    The workflow is as follows:
    1. The Document Analyzer agent extracts text from the image of the
       financial document.
    2. The Fraud Detection agent detects fraud in the document.
    3. The Summary Agent generates an actionable summary of the document.
    4. The Decision Making Support agent provides decision making support
       to the accountant.
    """

    def __init__(
        self,
        pdf_path: str,
        list_pdfs: List[str] = None,
        fraud_detection_instructions: str = None,
        summary_agent_instructions: str = None,
        decision_making_support_agent_instructions: str = None,
    ):
        super().__init__()
        self.pdf_path = pdf_path
        self.list_pdfs = list_pdfs
        self.fraud_detection_instructions = fraud_detection_instructions
        self.summary_agent_instructions = summary_agent_instructions
        self.decision_making_support_agent_instructions = (
            decision_making_support_agent_instructions
        )

    def run(self):
        # Transform the pdf to text
        pdf_text = pdf_to_text(self.pdf_path)

        # Detect fraud in the document (note: this step currently runs on the
        # document-analyzer agent; FRAUD_DETECTION_AGENT_PROMPT is imported
        # but unused)
        fraud_detection_agent_output = doc_analyzer_agent.run(
            f"{self.fraud_detection_instructions}: {pdf_text}"
        )

        # Generate an actionable summary of the document
        summary_agent_output = summary_generator_agent.run(
            f"{self.summary_agent_instructions}: {fraud_detection_agent_output}"
        )

        # Provide decision making support to the accountant
        decision_making_support_agent_output = decision_making_support_agent.run(
            f"{self.decision_making_support_agent_instructions}:"
            f" {summary_agent_output}"
        )

        return decision_making_support_agent_output


swarm = AccountantSwarms(
    pdf_path="tesla.pdf",
    fraud_detection_instructions="Detect fraud in the document",
    summary_agent_instructions="Generate an actionable summary of the document",
    decision_making_support_agent_instructions=(
        "Provide decision making support to the business owner:"
    ),
)
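Note: the hunk builds the swarm but never starts it. Presumably the pipeline is driven by a call like the sketch below — run() is the method defined above; printing the result is our addition:

output = swarm.run()
print(output)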
@@ -0,0 +1,65 @@
import asyncio

from swarms.models import Anthropic
from swarms.structs import Flow
from swarms.tools.tool import tool

llm = Anthropic(
    anthropic_api_key="",
)


async def async_load_playwright(url: str) -> str:
    """Load the specified URL using Playwright and parse it with BeautifulSoup."""
    from bs4 import BeautifulSoup
    from playwright.async_api import async_playwright

    results = ""
    async with async_playwright() as p:
        browser = await p.chromium.launch(headless=True)
        try:
            page = await browser.new_page()
            await page.goto(url)

            page_source = await page.content()
            soup = BeautifulSoup(page_source, "html.parser")

            # Drop script and style tags before extracting visible text
            for script in soup(["script", "style"]):
                script.extract()

            text = soup.get_text()
            lines = (line.strip() for line in text.splitlines())
            chunks = (
                phrase.strip() for line in lines for phrase in line.split(" ")
            )
            results = "\n".join(chunk for chunk in chunks if chunk)
        except Exception as e:
            results = f"Error: {e}"
        await browser.close()
    return results


def run_async(coro):
    # Note: asyncio.get_event_loop() is deprecated for this use on
    # Python 3.10+; see the alternative sketched after this hunk.
    event_loop = asyncio.get_event_loop()
    return event_loop.run_until_complete(coro)


@tool
def browse_web_page(url: str) -> str:
    """Verbose way to scrape a whole webpage. Likely to cause issues parsing."""
    return run_async(async_load_playwright(url))


# Initialize the workflow
flow = Flow(
    llm=llm,
    max_loops=5,
    tools=[browse_web_page],
    dashboard=True,
)

out = flow.run(
    "Generate a 10,000 word blog on mental clarity and the benefits of"
    " meditation."
)
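Note: run_async above relies on asyncio.get_event_loop(), deprecated outside a running loop on Python 3.10+. A drop-in alternative, assuming browse_web_page is only ever invoked from synchronous code (asyncio.run raises if a loop is already running):

def run_async(coro):
    # Creates and closes a fresh event loop per call.
    return asyncio.run(coro)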
@@ -0,0 +1,123 @@
from abc import ABC, abstractmethod
from concurrent import futures
from dataclasses import dataclass
from typing import Optional, Any

from attr import define, field, Factory

from swarms.utils.futures import execute_futures_dict
from griptape.artifacts import TextArtifact


@define
class BaseVectorStore(ABC):
    """Abstract base class for vector stores."""

    DEFAULT_QUERY_COUNT = 5

    @dataclass
    class QueryResult:
        id: str
        vector: list[float]
        score: float
        meta: Optional[dict] = None
        namespace: Optional[str] = None

    @dataclass
    class Entry:
        id: str
        vector: list[float]
        meta: Optional[dict] = None
        namespace: Optional[str] = None

    embedding_driver: Any
    futures_executor: futures.Executor = field(
        default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True
    )

    def upsert_text_artifacts(
        self,
        artifacts: dict[str, list[TextArtifact]],
        meta: Optional[dict] = None,
        **kwargs,
    ) -> None:
        execute_futures_dict(
            {
                namespace: self.futures_executor.submit(
                    self.upsert_text_artifact, a, namespace, meta, **kwargs
                )
                for namespace, artifact_list in artifacts.items()
                for a in artifact_list
            }
        )

    def upsert_text_artifact(
        self,
        artifact: TextArtifact,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        if not meta:
            meta = {}

        meta["artifact"] = artifact.to_json()

        if artifact.embedding:
            vector = artifact.embedding
        else:
            vector = artifact.generate_embedding(self.embedding_driver)

        return self.upsert_vector(
            vector,
            vector_id=artifact.id,
            namespace=namespace,
            meta=meta,
            **kwargs,
        )

    def upsert_text(
        self,
        string: str,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        return self.upsert_vector(
            self.embedding_driver.embed_string(string),
            vector_id=vector_id,
            namespace=namespace,
            meta=meta if meta else {},
            **kwargs,
        )

    @abstractmethod
    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        ...

    @abstractmethod
    def load_entry(
        self, vector_id: str, namespace: Optional[str] = None
    ) -> Entry:
        ...

    @abstractmethod
    def load_entries(self, namespace: Optional[str] = None) -> list[Entry]:
        ...

    @abstractmethod
    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[QueryResult]:
        ...
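Note: a concrete driver only needs the four abstract methods; the upsert helpers above are template logic. A toy in-memory sketch for illustration — the dict store and dot-product scoring are our assumptions, not part of the diff:

import uuid


@define
class InMemoryVectorStore(BaseVectorStore):
    """Toy driver: keeps entries in a dict, scores by dot product."""

    store: dict = field(default=Factory(dict), kw_only=True)

    def upsert_vector(
        self, vector, vector_id=None, namespace=None, meta=None, **kwargs
    ) -> str:
        vector_id = vector_id or uuid.uuid4().hex
        self.store[vector_id] = BaseVectorStore.Entry(
            id=vector_id, vector=vector, meta=meta, namespace=namespace
        )
        return vector_id

    def load_entry(self, vector_id, namespace=None):
        return self.store[vector_id]

    def load_entries(self, namespace=None):
        return [
            e
            for e in self.store.values()
            if namespace is None or e.namespace == namespace
        ]

    def query(
        self, query, count=None, namespace=None, include_vectors=False, **kwargs
    ):
        # Embed the query string, then rank stored vectors by dot product.
        q = self.embedding_driver.embed_string(query)
        results = [
            BaseVectorStore.QueryResult(
                id=e.id,
                vector=e.vector if include_vectors else [],
                score=sum(a * b for a, b in zip(q, e.vector)),
                meta=e.meta,
                namespace=e.namespace,
            )
            for e in self.load_entries(namespace)
        ]
        results.sort(key=lambda r: r.score, reverse=True)
        return results[: count or BaseVectorStore.DEFAULT_QUERY_COUNT]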
@@ -0,0 +1,723 @@
from __future__ import annotations

import logging
import uuid
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
)

import numpy as np

from swarms.structs.document import Document
from swarms.models.embeddings_base import Embeddings
from langchain.schema.vectorstore import VectorStore
from langchain.utils import xor_args
from langchain.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
    import chromadb
    import chromadb.config
    from chromadb.api.types import ID, OneOrMany, Where, WhereDocument

logger = logging.getLogger()
DEFAULT_K = 4  # Number of Documents to return.


def _results_to_docs(results: Any) -> List[Document]:
    return [doc for doc, _ in _results_to_docs_and_scores(results)]


def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    return [
        # TODO: Chroma can do batch querying,
        # we shouldn't hard code to the 1st result
        (Document(page_content=result[0], metadata=result[1] or {}), result[2])
        for result in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
    ]


class Chroma(VectorStore):
    """`ChromaDB` vector store.

    To use, you should have the ``chromadb`` python package installed.

    Example:
        .. code-block:: python

            from langchain.vectorstores import Chroma
            from langchain.embeddings.openai import OpenAIEmbeddings

            embeddings = OpenAIEmbeddings()
            vectorstore = Chroma("langchain_store", embeddings)
    """

    _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"

    def __init__(
        self,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        embedding_function: Optional[Embeddings] = None,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        collection_metadata: Optional[Dict] = None,
        client: Optional[chromadb.Client] = None,
        relevance_score_fn: Optional[Callable[[float], float]] = None,
    ) -> None:
        """Initialize with a Chroma client."""
        try:
            import chromadb
            import chromadb.config
        except ImportError:
            raise ImportError(
                "Could not import chromadb python package. "
                "Please install it with `pip install chromadb`."
            )

        if client is not None:
            self._client_settings = client_settings
            self._client = client
            self._persist_directory = persist_directory
        else:
            if client_settings:
                # If client_settings is provided with persist_directory specified,
                # then it is "in-memory and persisting to disk" mode.
                client_settings.persist_directory = (
                    persist_directory or client_settings.persist_directory
                )
                if client_settings.persist_directory is not None:
                    # Maintain backwards compatibility with chromadb < 0.4.0
                    major, minor, _ = chromadb.__version__.split(".")
                    if int(major) == 0 and int(minor) < 4:
                        client_settings.chroma_db_impl = "duckdb+parquet"

                _client_settings = client_settings
            elif persist_directory:
                # Maintain backwards compatibility with chromadb < 0.4.0
                major, minor, _ = chromadb.__version__.split(".")
                if int(major) == 0 and int(minor) < 4:
                    _client_settings = chromadb.config.Settings(
                        chroma_db_impl="duckdb+parquet",
                    )
                else:
                    _client_settings = chromadb.config.Settings(
                        is_persistent=True
                    )
                _client_settings.persist_directory = persist_directory
            else:
                _client_settings = chromadb.config.Settings()
            self._client_settings = _client_settings
            self._client = chromadb.Client(_client_settings)
            self._persist_directory = (
                _client_settings.persist_directory or persist_directory
            )

        self._embedding_function = embedding_function
        self._collection = self._client.get_or_create_collection(
            name=collection_name,
            embedding_function=(
                self._embedding_function.embed_documents
                if self._embedding_function is not None
                else None
            ),
            metadata=collection_metadata,
        )
        self.override_relevance_score_fn = relevance_score_fn

    @property
    def embeddings(self) -> Optional[Embeddings]:
        return self._embedding_function

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
        self,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = 4,
        where: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Query the chroma collection."""
        try:
            import chromadb  # noqa: F401
        except ImportError:
            raise ValueError(
                "Could not import chromadb python package. "
                "Please install it with `pip install chromadb`."
            )
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            **kwargs,
        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts (Iterable[str]): Texts to add to the vectorstore.
            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
            ids (Optional[List[str]], optional): Optional list of IDs.

        Returns:
            List[str]: List of IDs of the added texts.
        """
        # TODO: Handle the case where the user doesn't provide ids on the Collection
        if ids is None:
            ids = [str(uuid.uuid1()) for _ in texts]
        embeddings = None
        texts = list(texts)
        if self._embedding_function is not None:
            embeddings = self._embedding_function.embed_documents(texts)
        if metadatas:
            # fill metadatas with empty dicts if somebody
            # did not specify metadata for all texts
            length_diff = len(texts) - len(metadatas)
            if length_diff:
                metadatas = metadatas + [{}] * length_diff
            empty_ids = []
            non_empty_ids = []
            for idx, m in enumerate(metadatas):
                if m:
                    non_empty_ids.append(idx)
                else:
                    empty_ids.append(idx)
            if non_empty_ids:
                metadatas = [metadatas[idx] for idx in non_empty_ids]
                texts_with_metadatas = [texts[idx] for idx in non_empty_ids]
                embeddings_with_metadatas = (
                    [embeddings[idx] for idx in non_empty_ids]
                    if embeddings
                    else None
                )
                ids_with_metadata = [ids[idx] for idx in non_empty_ids]
                try:
                    self._collection.upsert(
                        metadatas=metadatas,
                        embeddings=embeddings_with_metadatas,
                        documents=texts_with_metadatas,
                        ids=ids_with_metadata,
                    )
                except ValueError as e:
                    if "Expected metadata value to be" in str(e):
                        msg = (
                            "Try filtering complex metadata from the document"
                            " using "
                            "langchain.vectorstores.utils.filter_complex_metadata."
                        )
                        raise ValueError(e.args[0] + "\n\n" + msg)
                    else:
                        raise e
            if empty_ids:
                texts_without_metadatas = [texts[j] for j in empty_ids]
                embeddings_without_metadatas = (
                    [embeddings[j] for j in empty_ids] if embeddings else None
                )
                ids_without_metadatas = [ids[j] for j in empty_ids]
                self._collection.upsert(
                    embeddings=embeddings_without_metadatas,
                    documents=texts_without_metadatas,
                    ids=ids_without_metadatas,
                )
        else:
            self._collection.upsert(
                embeddings=embeddings,
                documents=texts,
                ids=ids,
            )
        return ids

    def similarity_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run similarity search with Chroma.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Document]: List of documents most similar to the query text.
        """
        docs_and_scores = self.similarity_search_with_score(
            query, k, filter=filter
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding (List[float]): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents most similar to the query vector.
        """
        results = self.__query_collection(
            query_embeddings=embedding,
            n_results=k,
            where=filter,
            where_document=where_document,
        )
        return _results_to_docs(results)

    def similarity_search_by_vector_with_relevance_scores(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """
        Return docs most similar to embedding vector and similarity score.

        Args:
            embedding (List[float]): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of documents most similar to
            the query text and cosine distance in float for each.
            Lower score represents more similarity.
        """
        results = self.__query_collection(
            query_embeddings=embedding,
            n_results=k,
            where=filter,
            where_document=where_document,
        )
        return _results_to_docs_and_scores(results)

    def similarity_search_with_score(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with Chroma with distance.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of documents most similar to
            the query text and cosine distance in float for each.
            Lower score represents more similarity.
        """
        if self._embedding_function is None:
            results = self.__query_collection(
                query_texts=[query],
                n_results=k,
                where=filter,
                where_document=where_document,
            )
        else:
            query_embedding = self._embedding_function.embed_query(query)
            results = self.__query_collection(
                query_embeddings=[query_embedding],
                n_results=k,
                where=filter,
                where_document=where_document,
            )

        return _results_to_docs_and_scores(results)

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """
        The 'correct' relevance function
        may differ depending on a few things, including:
        - the distance / similarity metric used by the VectorStore
        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
        - embedding dimensionality
        - etc.
        """
        if self.override_relevance_score_fn:
            return self.override_relevance_score_fn

        distance = "l2"
        distance_key = "hnsw:space"
        metadata = self._collection.metadata

        if metadata and distance_key in metadata:
            distance = metadata[distance_key]

        if distance == "cosine":
            return self._cosine_relevance_score_fn
        elif distance == "l2":
            return self._euclidean_relevance_score_fn
        elif distance == "ip":
            return self._max_inner_product_relevance_score_fn
        else:
            raise ValueError(
                "No supported normalization function"
                f" for distance metric of type: {distance}."
                " Consider providing relevance_score_fn to Chroma constructor."
            )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.
        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """

        results = self.__query_collection(
            query_embeddings=embedding,
            n_results=fetch_k,
            where=filter,
            where_document=where_document,
            include=["metadatas", "documents", "distances", "embeddings"],
        )
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            results["embeddings"][0],
            k=k,
            lambda_mult=lambda_mult,
        )

        candidates = _results_to_docs(results)

        selected_results = [
            r for i, r in enumerate(candidates) if i in mmr_selected
        ]
        return selected_results

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.
        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        if self._embedding_function is None:
            raise ValueError(
                "For MMR search, you must specify an embedding function"
                " on creation."
            )

        embedding = self._embedding_function.embed_query(query)
        docs = self.max_marginal_relevance_search_by_vector(
            embedding,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            where_document=where_document,
        )
        return docs

    def delete_collection(self) -> None:
        """Delete the collection."""
        self._client.delete_collection(self._collection.name)

    def get(
        self,
        ids: Optional[OneOrMany[ID]] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Gets the collection.

        Args:
            ids: The ids of the embeddings to get. Optional.
            where: A Where type dict used to filter results by.
                E.g. `{"color" : "red", "price": 4.20}`. Optional.
            limit: The number of documents to return. Optional.
            offset: The offset to start returning results from.
                Useful for paging results with limit. Optional.
            where_document: A WhereDocument type dict used to filter by the documents.
                E.g. `{$contains: "hello"}`. Optional.
            include: A list of what to include in the results.
                Can contain `"embeddings"`, `"metadatas"`, `"documents"`.
                Ids are always included.
                Defaults to `["metadatas", "documents"]`. Optional.
        """
        kwargs = {
            "ids": ids,
            "where": where,
            "limit": limit,
            "offset": offset,
            "where_document": where_document,
        }

        if include is not None:
            kwargs["include"] = include

        return self._collection.get(**kwargs)

    def persist(self) -> None:
        """Persist the collection.

        This can be used to explicitly persist the data to disk.
        It will also be called automatically when the object is destroyed.
        """
        if self._persist_directory is None:
            raise ValueError(
                "You must specify a persist_directory on"
                " creation to persist the collection."
            )
        import chromadb

        # Maintain backwards compatibility with chromadb < 0.4.0
        major, minor, _ = chromadb.__version__.split(".")
        if int(major) == 0 and int(minor) < 4:
            self._client.persist()

    def update_document(self, document_id: str, document: Document) -> None:
        """Update a document in the collection.

        Args:
            document_id (str): ID of the document to update.
            document (Document): Document to update.
        """
        return self.update_documents([document_id], [document])

    def update_documents(
        self, ids: List[str], documents: List[Document]
    ) -> None:
        """Update a document in the collection.

        Args:
            ids (List[str]): List of ids of the document to update.
            documents (List[Document]): List of documents to update.
        """
        text = [document.page_content for document in documents]
        metadata = [document.metadata for document in documents]
        if self._embedding_function is None:
            raise ValueError(
                "For update, you must specify an embedding function on"
                " creation."
            )
        embeddings = self._embedding_function.embed_documents(text)

        if hasattr(
            self._collection._client, "max_batch_size"
        ):  # for Chroma 0.4.10 and above
            from chromadb.utils.batch_utils import create_batches

            for batch in create_batches(
                api=self._collection._client,
                ids=ids,
                metadatas=metadata,
                documents=text,
                embeddings=embeddings,
            ):
                self._collection.update(
                    ids=batch[0],
                    embeddings=batch[1],
                    documents=batch[3],
                    metadatas=batch[2],
                )
        else:
            self._collection.update(
                ids=ids,
                embeddings=embeddings,
                documents=text,
                metadatas=metadata,
            )

    @classmethod
    def from_texts(
        cls: Type[Chroma],
        texts: List[str],
        embedding: Optional[Embeddings] = None,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        client: Optional[chromadb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Chroma:
        """Create a Chroma vectorstore from raw documents.

        If a persist_directory is specified, the collection will be persisted there.
        Otherwise, the data will be ephemeral in-memory.

        Args:
            texts (List[str]): List of texts to add to the collection.
            collection_name (str): Name of the collection to create.
            persist_directory (Optional[str]): Directory to persist the collection.
            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
            metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
            ids (Optional[List[str]]): List of document IDs. Defaults to None.
            client_settings (Optional[chromadb.config.Settings]): Chroma client settings
            collection_metadata (Optional[Dict]): Collection configurations.
                Defaults to None.

        Returns:
            Chroma: Chroma vectorstore.
        """
        chroma_collection = cls(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )
        if ids is None:
            ids = [str(uuid.uuid1()) for _ in texts]
        if hasattr(
            chroma_collection._client, "max_batch_size"
        ):  # for Chroma 0.4.10 and above
            from chromadb.utils.batch_utils import create_batches

            for batch in create_batches(
                api=chroma_collection._client,
                ids=ids,
                metadatas=metadatas,
                documents=texts,
            ):
                chroma_collection.add_texts(
                    texts=batch[3] if batch[3] else [],
                    metadatas=batch[2] if batch[2] else None,
                    ids=batch[0],
                )
        else:
            chroma_collection.add_texts(
                texts=texts, metadatas=metadatas, ids=ids
            )
        return chroma_collection

    @classmethod
    def from_documents(
        cls: Type[Chroma],
        documents: List[Document],
        embedding: Optional[Embeddings] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        client: Optional[chromadb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Chroma:
        """Create a Chroma vectorstore from a list of documents.

        If a persist_directory is specified, the collection will be persisted there.
        Otherwise, the data will be ephemeral in-memory.

        Args:
            collection_name (str): Name of the collection to create.
            persist_directory (Optional[str]): Directory to persist the collection.
            ids (Optional[List[str]]): List of document IDs. Defaults to None.
            documents (List[Document]): List of documents to add to the vectorstore.
            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
            client_settings (Optional[chromadb.config.Settings]): Chroma client settings
            collection_metadata (Optional[Dict]): Collection configurations.
                Defaults to None.

        Returns:
            Chroma: Chroma vectorstore.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return cls.from_texts(
            texts=texts,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            collection_name=collection_name,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
        """
        self._collection.delete(ids=ids)
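Note: a minimal end-to-end use of the class, adapted from its own docstring; the texts and collection name are illustrative, and an OpenAI key plus the chromadb package are assumed:

from langchain.embeddings.openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
db = Chroma.from_texts(
    texts=["chroma stores vectors", "swarms wraps langchain"],
    embedding=embeddings,
    collection_name="demo",
)
print(db.similarity_search("vector store", k=1))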
@@ -0,0 +1,177 @@
import uuid
from abc import ABC
from typing import Any, Dict, List, Optional

from swarms.memory.schemas import Artifact, Status
from swarms.memory.schemas import Step as APIStep
from swarms.memory.schemas import Task as APITask


class Step(APIStep):
    additional_properties: Optional[Dict[str, str]] = None


class Task(APITask):
    steps: List[Step] = []


class NotFoundException(Exception):
    """
    Exception raised when a resource is not found.
    """

    def __init__(self, item_name: str, item_id: str):
        self.item_name = item_name
        self.item_id = item_id
        super().__init__(f"{item_name} with {item_id} not found.")


class TaskDB(ABC):
    async def create_task(
        self,
        input: Optional[str],
        additional_input: Any = None,
        artifacts: Optional[List[Artifact]] = None,
        steps: Optional[List[Step]] = None,
    ) -> Task:
        raise NotImplementedError

    async def create_step(
        self,
        task_id: str,
        name: Optional[str] = None,
        input: Optional[str] = None,
        is_last: bool = False,
        additional_properties: Optional[Dict[str, str]] = None,
    ) -> Step:
        raise NotImplementedError

    async def create_artifact(
        self,
        task_id: str,
        file_name: str,
        relative_path: Optional[str] = None,
        step_id: Optional[str] = None,
    ) -> Artifact:
        raise NotImplementedError

    async def get_task(self, task_id: str) -> Task:
        raise NotImplementedError

    async def get_step(self, task_id: str, step_id: str) -> Step:
        raise NotImplementedError

    async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
        raise NotImplementedError

    async def list_tasks(self) -> List[Task]:
        raise NotImplementedError

    async def list_steps(
        self, task_id: str, status: Optional[Status] = None
    ) -> List[Step]:
        raise NotImplementedError


class InMemoryTaskDB(TaskDB):
    _tasks: Dict[str, Task] = {}

    async def create_task(
        self,
        input: Optional[str],
        additional_input: Any = None,
        artifacts: Optional[List[Artifact]] = None,
        steps: Optional[List[Step]] = None,
    ) -> Task:
        if not steps:
            steps = []
        if not artifacts:
            artifacts = []
        task_id = str(uuid.uuid4())
        task = Task(
            task_id=task_id,
            input=input,
            steps=steps,
            artifacts=artifacts,
            additional_input=additional_input,
        )
        self._tasks[task_id] = task
        return task

    async def create_step(
        self,
        task_id: str,
        name: Optional[str] = None,
        input: Optional[str] = None,
        is_last=False,
        additional_properties: Optional[Dict[str, Any]] = None,
    ) -> Step:
        step_id = str(uuid.uuid4())
        step = Step(
            task_id=task_id,
            step_id=step_id,
            name=name,
            input=input,
            status=Status.created,
            is_last=is_last,
            additional_properties=additional_properties,
        )
        task = await self.get_task(task_id)
        task.steps.append(step)
        return step

    async def get_task(self, task_id: str) -> Task:
        task = self._tasks.get(task_id, None)
        if not task:
            raise NotFoundException("Task", task_id)
        return task

    async def get_step(self, task_id: str, step_id: str) -> Step:
        task = await self.get_task(task_id)
        # Match on step_id; the original filtered on task_id, which returned
        # the first step of the task no matter which step was requested.
        step = next(filter(lambda s: s.step_id == step_id, task.steps), None)
        if not step:
            raise NotFoundException("Step", step_id)
        return step

    async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
        task = await self.get_task(task_id)
        artifact = next(
            filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None
        )
        if not artifact:
            raise NotFoundException("Artifact", artifact_id)
        return artifact

    async def create_artifact(
        self,
        task_id: str,
        file_name: str,
        relative_path: Optional[str] = None,
        step_id: Optional[str] = None,
    ) -> Artifact:
        artifact_id = str(uuid.uuid4())
        artifact = Artifact(
            artifact_id=artifact_id,
            file_name=file_name,
            relative_path=relative_path,
        )
        task = await self.get_task(task_id)
        task.artifacts.append(artifact)

        if step_id:
            step = await self.get_step(task_id, step_id)
            step.artifacts.append(artifact)

        return artifact

    async def list_tasks(self) -> List[Task]:
        return list(self._tasks.values())

    async def list_steps(
        self, task_id: str, status: Optional[Status] = None
    ) -> List[Step]:
        task = await self.get_task(task_id)
        steps = task.steps
        if status:
            steps = list(filter(lambda s: s.status == status, steps))
        return steps
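Note: a quick illustration of the in-memory store in use (our sketch; the task input and step name are placeholders):

import asyncio


async def demo():
    db = InMemoryTaskDB()
    task = await db.create_task("index the quarterly report")
    step = await db.create_step(
        task.task_id, name="extract text", is_last=True
    )
    print(task.task_id, step.step_id, step.status)


asyncio.run(demo())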
@@ -0,0 +1,157 @@
import logging
from typing import List

import oceandb
from oceandb.utils.embedding_function import MultiModalEmbeddingFunction


class OceanDB:
    """
    A class to interact with OceanDB.

    ...

    Attributes
    ----------
    client : oceandb.Client
        a client to interact with OceanDB

    Methods
    -------
    create_collection(collection_name: str, modality: str):
        Creates a new collection in OceanDB.
    append_document(collection, document: str, id: str):
        Appends a document to a collection in OceanDB.
    add_documents(collection, documents: List[str], ids: List[str]):
        Adds multiple documents to a collection in OceanDB.
    query(collection, query_texts: list[str], n_results: int):
        Queries a collection in OceanDB.
    """

    def __init__(self, client: oceandb.Client = None):
        """
        Constructs all the necessary attributes for the OceanDB object.

        Parameters
        ----------
        client : oceandb.Client, optional
            a client to interact with OceanDB (default is None, which
            creates a new client)
        """
        try:
            self.client = client if client else oceandb.Client()
            print(self.client.heartbeat())
        except Exception as e:
            logging.error(f"Failed to initialize OceanDB client. Error: {e}")
            raise

    def create_collection(self, collection_name: str, modality: str):
        """
        Creates a new collection in OceanDB.

        Parameters
        ----------
        collection_name : str
            the name of the new collection
        modality : str
            the modality of the new collection

        Returns
        -------
        collection
            the created collection
        """
        try:
            embedding_function = MultiModalEmbeddingFunction(modality=modality)
            collection = self.client.create_collection(
                collection_name, embedding_function=embedding_function
            )
            return collection
        except Exception as e:
            logging.error(f"Failed to create collection. Error: {e}")
            raise

    def append_document(self, collection, document: str, id: str):
        """
        Appends a document to a collection in OceanDB.

        Parameters
        ----------
        collection
            the collection to append the document to
        document : str
            the document to append
        id : str
            the id of the document

        Returns
        -------
        result
            the result of the append operation
        """
        try:
            return collection.add(documents=[document], ids=[id])
        except Exception as e:
            logging.error(
                f"Failed to append document to the collection. Error: {e}"
            )
            raise

    def add_documents(self, collection, documents: List[str], ids: List[str]):
        """
        Adds multiple documents to a collection in OceanDB.

        Parameters
        ----------
        collection
            the collection to add the documents to
        documents : List[str]
            the documents to add
        ids : List[str]
            the ids of the documents

        Returns
        -------
        result
            the result of the add operation
        """
        try:
            return collection.add(documents=documents, ids=ids)
        except Exception as e:
            logging.error(f"Failed to add documents to collection. Error: {e}")
            raise

    def query(self, collection, query_texts: list[str], n_results: int):
        """
        Queries a collection in OceanDB.

        Parameters
        ----------
        collection
            the collection to query
        query_texts : list[str]
            the texts to query
        n_results : int
            the number of results to return

        Returns
        -------
        results
            the results of the query
        """
        try:
            results = collection.query(
                query_texts=query_texts, n_results=n_results
            )
            return results
        except Exception as e:
            logging.error(f"Failed to query the collection. Error: {e}")
            raise


# Example
# ocean = OceanDB()
# collection = ocean.create_collection("test", "text")
# ocean.append_document(collection, "hello world", "1")
# ocean.add_documents(collection, ["hello world", "hello world"], ["2", "3"])
# results = ocean.query(collection, ["hello world"], 3)
# print(results)
@ -0,0 +1,261 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import concurrent.futures
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import requests
|
||||||
|
from cachetools import TTLCache
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from openai import OpenAI
|
||||||
|
from ratelimit import limits, sleep_and_retry
|
||||||
|
from termcolor import colored
|
||||||
|
|
||||||
|
# ENV
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GPT4VisionResponse:
|
||||||
|
"""A response structure for GPT-4"""
|
||||||
|
|
||||||
|
answer: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GPT4Vision:
|
||||||
|
"""
|
||||||
|
GPT4Vision model class
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
-----------
|
||||||
|
max_retries: int
|
||||||
|
The maximum number of retries to make to the API
|
||||||
|
backoff_factor: float
|
||||||
|
The backoff factor to use for exponential backoff
|
||||||
|
timeout_seconds: int
|
||||||
|
The timeout in seconds for the API request
|
||||||
|
api_key: str
|
||||||
|
The API key to use for the API request
|
||||||
|
quality: str
|
||||||
|
The quality of the image to generate
|
||||||
|
max_tokens: int
|
||||||
|
The maximum number of tokens to use for the API request
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
--------
|
||||||
|
process_img(self, img_path: str) -> str:
|
||||||
|
Processes the image to be used for the API request
|
||||||
|
run(self, img: Union[str, List[str]], tasks: List[str]) -> GPT4VisionResponse:
|
||||||
|
Makes a call to the GPT-4 Vision API and returns the image url
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> gpt4vision = GPT4Vision()
|
||||||
|
>>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
|
||||||
|
>>> tasks = ["A painting of a dog"]
|
||||||
|
>>> answer = gpt4vision(img, tasks)
|
||||||
|
>>> print(answer)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_retries: int = 3
|
||||||
|
model: str = "gpt-4-vision-preview"
|
||||||
|
backoff_factor: float = 2.0
|
||||||
|
timeout_seconds: int = 10
|
||||||
|
openai_api_key: Optional[str] = None or os.getenv("OPENAI_API_KEY")
|
||||||
|
# 'Low' or 'High' for respesctively fast or high quality, but high more token usage
|
||||||
|
quality: str = "low"
|
||||||
|
# Max tokens to use for the API request, the maximum might be 3,000 but we don't know
|
||||||
|
max_tokens: int = 200
|
||||||
|
    client = OpenAI(
        api_key=openai_api_key,
    )
    dashboard: bool = True
    call_limit: int = 1
    period_seconds: int = 60

    # Cache for storing API responses
    cache = TTLCache(maxsize=100, ttl=600)  # Cache for 10 minutes

    class Config:
        """Config class for the GPT4Vision model"""

        arbitrary_types_allowed = True

    def process_img(self, img: str) -> str:
        """Processes the image to be used for the API request"""
        with open(img, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    @sleep_and_retry
    @limits(
        calls=call_limit, period=period_seconds
    )  # Rate limit: call_limit calls per period_seconds
    def run(self, task: str, img: str):
        """
        Run the GPT-4 Vision model

        Task: str
            The task to run
        Img: str
            The image to run the task on

        """
        if self.dashboard:
            self.print_dashboard()
        try:
            response = self.client.chat.completions.create(
                model="gpt-4-vision-preview",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": task},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": str(img),
                                },
                            },
                        ],
                    }
                ],
                max_tokens=self.max_tokens,
            )

            # print() returns None, so capture the choice first and then log it
            out = response.choices[0]
            print(out)
            # out = self.clean_output(out)
            return out
        except openai.OpenAIError as e:
            # logger.error(f"OpenAI API error: {e}")
            return f"OpenAI API error: Could not process the image. {e}"
        except Exception as e:
            return f"Unexpected error occurred while processing the image. {e}"

    def clean_output(self, output: str):
        # Regex pattern to find the Choice object representation in the output
        pattern = r"Choice\(.*?\(content=\"(.*?)\".*?\)\)"
        match = re.search(pattern, output, re.DOTALL)

        if match:
            # Extract the content from the matched pattern
            content = match.group(1)
            # Replace escaped quotes to get the clean content
            content = content.replace(r"\"", '"')
            return content
        else:
            print("No content found in the output.")

    async def arun(self, task: str, img: str):
        """
        Arun is an async version of run

        Task: str
            The task to run
        Img: str
            The image to run the task on

        """
        try:
            response = await self.client.chat.completions.create(
                model="gpt-4-vision-preview",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": task},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": img,
                                },
                            },
                        ],
                    }
                ],
                max_tokens=self.max_tokens,
            )

            return response.choices[0]
        except openai.OpenAIError as e:
            # logger.error(f"OpenAI API error: {e}")
            return f"OpenAI API error: Could not process the image. {e}"
        except Exception as e:
            return f"Unexpected error occurred while processing the image. {e}"

    def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]:
        """Process a batch of tasks and images"""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.run, task, img)
                for task, img in tasks_images
            ]
            results = [future.result() for future in futures]
            return results

    async def run_batch_async(
        self, tasks_images: List[Tuple[str, str]]
    ) -> List[str]:
        """Process a batch of tasks and images asynchronously"""
        loop = asyncio.get_event_loop()
        futures = [
            loop.run_in_executor(None, self.run, task, img)
            for task, img in tasks_images
        ]
        return await asyncio.gather(*futures)

    async def run_batch_async_with_retries(
        self, tasks_images: List[Tuple[str, str]]
    ) -> List[str]:
        """Process a batch of tasks and images asynchronously with retries"""
        loop = asyncio.get_event_loop()
        futures = [
            loop.run_in_executor(None, self.run_with_retries, task, img)
            for task, img in tasks_images
        ]
        return await asyncio.gather(*futures)

    def print_dashboard(self):
        dashboard = print(
            colored(
                f"""
                GPT4Vision Dashboard
                -------------------
                Max Retries: {self.max_retries}
                Model: {self.model}
                Backoff Factor: {self.backoff_factor}
                Timeout Seconds: {self.timeout_seconds}
                Image Quality: {self.quality}
                Max Tokens: {self.max_tokens}

                """,
                "green",
            )
        )
        return dashboard

    def health_check(self):
        """Health check for the GPT4Vision model"""
        try:
            response = requests.get("https://api.openai.com/v1/engines")
            return response.status_code == 200
        except requests.RequestException as error:
            print(f"Health check failed: {error}")
            return False

    def sanitize_input(self, text: str) -> str:
        """
        Sanitize input to prevent injection attacks.

        Parameters:
        text: str - The input text to be sanitized.

        Returns:
        The sanitized text.
        """
        # Example of simple sanitization; this should be expanded based on the context and usage
        sanitized_text = re.sub(r"[^\w\s]", "", text)
        return sanitized_text
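
# Usage sketch (an illustrative assumption, not part of the original file):
# running a single vision task and a small batch. Assumes a valid
# OPENAI_API_KEY in the environment and that this class is GPT4Vision
# from swarms.models.gpt4v.
#
#     gpt4vision = GPT4Vision()
#     answer = gpt4vision.run("Describe this image.", "photo.jpg")
#     answers = gpt4vision.run_batch(
#         [("Describe this image.", "a.jpg"), ("Count the cats.", "b.jpg")]
#     )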
@ -0,0 +1,93 @@
import os
from typing import Callable, List


class DialogueSimulator:
    """
    Dialogue Simulator
    ------------------

    Args:
    ------
    agents: List[Callable]
    max_iters: int
    name: str

    Usage:
    ------
    >>> from swarms import DialogueSimulator
    >>> from swarms.structs.flow import Flow
    >>> agents = Flow()
    >>> agents1 = Flow()
    >>> model = DialogueSimulator([agents, agents1], max_iters=10, name="test")
    >>> model.run("test")
    """

    def __init__(
        self, agents: List[Callable], max_iters: int = 10, name: str = None
    ):
        self.agents = agents
        self.max_iters = max_iters
        self.name = name

    def run(self, message: str = None):
        """Run the dialogue simulator"""
        try:
            step = 0
            # Default the prompt so the loop below never reads an unbound variable
            prompt = message
            if self.name and message:
                prompt = f"Name {self.name} and message: {message}"
                for agent in self.agents:
                    agent.run(prompt)
                step += 1

            while step < self.max_iters:
                speaker_idx = step % len(self.agents)
                speaker = self.agents[speaker_idx]
                speaker_message = speaker.run(prompt)

                for receiver in self.agents:
                    message_history = (
                        f"Speaker Name: {speaker.name} and message:"
                        f" {speaker_message}"
                    )
                    receiver.run(message_history)

                print(f"({speaker.name}): {speaker_message}")
                print("\n")
                step += 1
        except Exception as error:
            print(f"Error running dialogue simulator: {error}")

    def __repr__(self):
        return (
            f"DialogueSimulator({self.agents}, {self.max_iters}, {self.name})"
        )

    def save_state(self):
        """Save the state of the dialogue simulator"""
        try:
            if self.name:
                filename = f"{self.name}.txt"
                with open(filename, "w") as file:
                    file.write(str(self))
        except Exception as error:
            print(f"Error saving state: {error}")

    def load_state(self):
        """Load the state of the dialogue simulator"""
        try:
            if self.name:
                filename = f"{self.name}.txt"
                with open(filename, "r") as file:
                    return file.read()
        except Exception as error:
            print(f"Error loading state: {error}")

    def delete_state(self):
        """Delete the state of the dialogue simulator"""
        try:
            if self.name:
                filename = f"{self.name}.txt"
                os.remove(filename)
        except Exception as error:
            print(f"Error deleting state: {error}")
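
# A minimal sketch of the round-robin broadcast pattern that run() implements,
# using plain stand-in agents instead of Flow (illustrative assumption only):
#
#     class EchoAgent:
#         def __init__(self, name):
#             self.name = name
#
#         def run(self, prompt):
#             return f"{self.name} heard: {prompt}"
#
#     sim = DialogueSimulator([EchoAgent("a"), EchoAgent("b")], max_iters=4)
#     sim.run("hello")  # speakers alternate a, b, a, b; every agent sees each message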
@ -0,0 +1,287 @@
import logging
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from typing import Any, Dict, List

import chromadb
from chromadb.utils import embedding_functions


class TaskStatus(Enum):
    QUEUED = 1
    RUNNING = 2
    COMPLETED = 3
    FAILED = 4


class Orchestrator:
    """
    The Orchestrator takes in an agent, worker, or boss as input
    then handles all the logic for
    - task creation,
    - task assignment,
    - and task completion.

    It also handles the communication that lets millions of agents chat with
    each other through a vector database that every agent has access to.

    Each LLM agent chats with the orchestrator through a dedicated
    communication layer. The orchestrator assigns tasks to each LLM agent,
    which the agents then complete and return.

    This setup allows for a high degree of flexibility, scalability, and robustness.

    In the context of swarm LLMs, one could consider an Omni-Vector Embedding Database
    for communication. This database could store and manage
    the high-dimensional vectors produced by each LLM agent.

    Strengths: This approach would allow for similarity-based lookup and matching of
    LLM-generated vectors, which can be particularly useful for tasks that involve finding similar outputs or recognizing patterns.

    Weaknesses: An Omni-Vector Embedding Database might add complexity to the system in terms of setup and maintenance.
    It might also require significant computational resources,
    depending on the volume of data being handled and the complexity of the vectors.
    The handling and transmission of high-dimensional vectors could also pose challenges
    in terms of network load.

    # Orchestrator
    * Takes in an agent class with vector store,
    then handles all the communication and scales
    up a swarm with number of agents and handles task assignment and task completion

    ```
    from swarms import OpenAI, Orchestrator, Swarm

    # Handles all the task assignment, allocation, and agent communication using a
    # vector store as a universal communication layer, plus the task completion logic
    orchestrated = Orchestrate(OpenAI, nodes=40)

    Objective = "Make a business website for a marketing consultancy"

    Swarms = Swarms(orchestrated, auto=True, Objective)
    ```

    In terms of architecture, the swarm might look something like this:

    ```
                                           (Orchestrator)
                                           /            \\
    Tools + Vector DB -- (LLM Agent)---(Communication Layer)       (Communication Layer)---(LLM Agent)-- Tools + Vector DB
                       /                  |                           |                      \\
    (Task Assignment)      (Task Completion)                    (Task Assignment)       (Task Completion)
    ```

    ### Usage
    ```
    from swarms import Orchestrator

    # Instantiate the Orchestrator with 10 agents
    orchestrator = Orchestrator(llm, agent_list=[llm]*10, task_queue=[])

    # Add tasks to the Orchestrator
    tasks = [{"content": f"Write a short story about a {animal}."} for animal in ["cat", "dog", "bird", "fish", "lion", "tiger", "elephant", "giraffe", "monkey", "zebra"]]
    orchestrator.assign_tasks(tasks)

    # Run the Orchestrator
    orchestrator.run()

    # Retrieve the results
    for task in tasks:
        print(orchestrator.retrieve_result(id(task)))
    ```
    """

    def __init__(
        self,
        agent,
        agent_list: List[Any],
        task_queue: List[Any],
        collection_name: str = "swarm",
        api_key: str = None,
        model_name: str = None,
        embed_func=None,
        worker=None,
    ):
        self.agent = agent
        self.agents = queue.Queue()

        # Create one agent instance per entry in agent_list
        for _ in range(len(agent_list)):
            self.agents.put(agent())

        self.task_queue = queue.Queue()

        self.chroma_client = chromadb.Client()

        self.collection = self.chroma_client.create_collection(
            name=collection_name
        )

        self.current_tasks = {}

        self.lock = threading.Lock()
        self.condition = threading.Condition(self.lock)
        self.executor = ThreadPoolExecutor(max_workers=len(agent_list))

        # Keep the credentials and worker around; embed() and assign_task() use them
        self.api_key = api_key
        self.model_name = model_name
        self.worker = worker

        self.embed_func = embed_func if embed_func else self.embed

    # @abstractmethod
    def assign_task(self, agent_id: int, task: Dict[str, Any]) -> None:
        """Assign a task to a specific agent"""

        while True:
            with self.condition:
                while not self.task_queue:
                    self.condition.wait()
                agent = self.agents.get()
                task = self.task_queue.get()

            try:
                result = self.worker.run(task["content"])

                # Use the embed method to get the vector representation of the result
                vector_representation = self.embed(
                    result, self.api_key, self.model_name
                )

                self.collection.add(
                    embeddings=[vector_representation],
                    documents=[str(id(task))],
                    ids=[str(id(task))],
                )

                logging.info(
                    f"Task {id(task)} has been processed by agent"
                    f" {id(agent)}"
                )

            except Exception as error:
                logging.error(
                    f"Failed to process task {id(task)} by agent {id(agent)}."
                    f" Error: {error}"
                )
            finally:
                with self.condition:
                    self.agents.put(agent)
                    self.condition.notify()

    def embed(self, input, api_key, model_name):
        openai = embedding_functions.OpenAIEmbeddingFunction(
            api_key=api_key, model_name=model_name
        )
        embedding = openai(input)
        return embedding

    # @abstractmethod
    def retrieve_results(self, agent_id: int) -> Any:
        """Retrieve results from a specific agent"""

        try:
            # Query the vector database for documents created by the agents
            results = self.collection.query(
                query_texts=[str(agent_id)], n_results=10
            )

            return results
        except Exception as e:
            logging.error(
                f"Failed to retrieve results from agent {agent_id}. Error {e}"
            )
            raise

    # @abstractmethod
    def update_vector_db(self, data) -> None:
        """Update the vector database"""

        try:
            self.collection.add(
                embeddings=[data["vector"]],
                documents=[str(data["task_id"])],
                ids=[str(data["task_id"])],
            )

        except Exception as e:
            logging.error(f"Failed to update the vector database. Error: {e}")
            raise

    # @abstractmethod
    def get_vector_db(self):
        """Retrieve the vector database"""
        return self.collection

    def append_to_db(self, result: str):
        """Append the result of the swarm to a specific collection in the database"""

        try:
            self.collection.add(documents=[result], ids=[str(id(result))])

        except Exception as e:
            logging.error(
                f"Failed to append the agent output to database. Error: {e}"
            )
            raise

    def run(self, objective: str):
        """Run the swarm against a single objective"""
        if not objective or not isinstance(objective, str):
            logging.error("Invalid objective")
            raise ValueError("A valid objective is required")

        try:
            # queue.Queue has no append(); use put()
            self.task_queue.put(objective)

            results = [
                self.assign_task(agent_id, task)
                for agent_id, task in zip(
                    range(self.agents.qsize()), list(self.task_queue.queue)
                )
            ]

            for result in results:
                self.append_to_db(result)

            logging.info(f"Successfully ran swarms with results: {results}")
            return results
        except Exception as e:
            logging.error(f"An error occurred in swarm: {e}")
            return None

    def chat(self, sender_id: int, receiver_id: int, message: str):
        """
        Allows the agents to chat with each other through the vector database

        # Instantiate the Orchestrator with 10 agents
        orchestrator = Orchestrator(
            llm,
            agent_list=[llm]*10,
            task_queue=[]
        )

        # Agent 1 sends a message to Agent 2
        orchestrator.chat(sender_id=1, receiver_id=2, message="Hello, Agent 2!")

        """

        message_vector = self.embed(message, self.api_key, self.model_name)

        # Store the message in the vector database
        self.collection.add(
            embeddings=[message_vector],
            documents=[message],
            ids=[f"{sender_id}_to_{receiver_id}"],
        )

        self.run(objective=f"chat with agent {receiver_id} about {message}")

    def add_agents(self, num_agents: int):
        for _ in range(num_agents):
            self.agents.put(self.agent())
        self.executor = ThreadPoolExecutor(max_workers=self.agents.qsize())

    def remove_agents(self, num_agents):
        for _ in range(num_agents):
            if not self.agents.empty():
                self.agents.get()
        self.executor = ThreadPoolExecutor(max_workers=self.agents.qsize())
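
# A minimal sketch of the condition-variable worker pattern that assign_task()
# relies on, stripped of the vector-store details (illustrative assumption only):
#
#     import queue
#     import threading
#
#     tasks = queue.Queue()
#     cond = threading.Condition()
#
#     def worker():
#         with cond:
#             while tasks.empty():
#                 cond.wait()  # sleep until a producer calls notify()
#             task = tasks.get()
#         print(f"processing {task}")
#
#     threading.Thread(target=worker).start()
#     with cond:
#         tasks.put("write a haiku")
#         cond.notify()  # wake one waiting worker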
@ -0,0 +1,200 @@
import asyncio
import os
from contextlib import contextmanager
from typing import Optional

import pandas as pd
import torch
from langchain.agents import tool
from langchain.agents.agent_toolkits.pandas.base import (
    create_pandas_dataframe_agent,
)
from langchain.chains.qa_with_sources.loading import (
    BaseCombineDocumentsChain,
)
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import BaseTool
from PIL import Image
from pydantic import Field
from transformers import (
    BlipForQuestionAnswering,
    BlipProcessor,
)

from swarms.utils.logger import logger

ROOT_DIR = "./data/"


@contextmanager
def pushd(new_dir):
    """Context manager for changing the current working directory."""
    prev_dir = os.getcwd()
    os.chdir(new_dir)
    try:
        yield
    finally:
        os.chdir(prev_dir)


@tool
def process_csv(
    llm,
    csv_file_path: str,
    instructions: str,
    output_path: Optional[str] = None,
) -> str:
    """Process a CSV with pandas in a limited REPL.\
 Only use this after writing data to disk as a csv file.\
 Any figures must be saved to disk to be viewed by the human.\
 Instructions should be written in natural language, not code. Assume the dataframe is already loaded."""
    with pushd(ROOT_DIR):
        try:
            df = pd.read_csv(csv_file_path)
        except Exception as e:
            return f"Error: {e}"
        agent = create_pandas_dataframe_agent(
            llm, df, max_iterations=30, verbose=False
        )
        if output_path is not None:
            instructions += f" Save output to disk at {output_path}"
        try:
            result = agent.run(instructions)
            return result
        except Exception as e:
            return f"Error: {e}"


async def async_load_playwright(url: str) -> str:
    """Load the specified URLs using Playwright and parse using BeautifulSoup."""
    from bs4 import BeautifulSoup
    from playwright.async_api import async_playwright

    results = ""
    async with async_playwright() as p:
        browser = await p.chromium.launch(headless=True)
        try:
            page = await browser.new_page()
            await page.goto(url)

            page_source = await page.content()
            soup = BeautifulSoup(page_source, "html.parser")

            for script in soup(["script", "style"]):
                script.extract()

            text = soup.get_text()
            lines = (line.strip() for line in text.splitlines())
            chunks = (
                phrase.strip() for line in lines for phrase in line.split(" ")
            )
            results = "\n".join(chunk for chunk in chunks if chunk)
        except Exception as e:
            results = f"Error: {e}"
        await browser.close()
    return results


def run_async(coro):
    event_loop = asyncio.get_event_loop()
    return event_loop.run_until_complete(coro)


@tool
def browse_web_page(url: str) -> str:
    """Verbose way to scrape a whole webpage. Likely to cause issues parsing."""
    return run_async(async_load_playwright(url))


def _get_text_splitter():
    return RecursiveCharacterTextSplitter(
        # Set a really small chunk size, just to show.
        chunk_size=500,
        chunk_overlap=20,
        length_function=len,
    )


class WebpageQATool(BaseTool):
    name = "query_webpage"
    description = (
        "Browse a webpage and retrieve the information relevant to the"
        " question."
    )
    text_splitter: RecursiveCharacterTextSplitter = Field(
        default_factory=_get_text_splitter
    )
    qa_chain: BaseCombineDocumentsChain

    def _run(self, url: str, question: str) -> str:
        """Useful for browsing websites and scraping the text information."""
        result = browse_web_page.run(url)
        docs = [Document(page_content=result, metadata={"source": url})]
        web_docs = self.text_splitter.split_documents(docs)
        results = []
        # TODO: Handle this with a MapReduceChain
        for i in range(0, len(web_docs), 4):
            input_docs = web_docs[i : i + 4]
            window_result = self.qa_chain(
                {"input_documents": input_docs, "question": question},
                return_only_outputs=True,
            )
            results.append(f"Response from window {i} - {window_result}")
        results_docs = [
            Document(page_content="\n".join(results), metadata={"source": url})
        ]
        return self.qa_chain(
            {"input_documents": results_docs, "question": question},
            return_only_outputs=True,
        )

    async def _arun(self, url: str, question: str) -> str:
        raise NotImplementedError


class EdgeGPTTool:
    # Initialize the custom tool
    def __init__(
        self,
        model,
        name="EdgeGPTTool",
        description="Tool that uses EdgeGPTModel to generate responses",
    ):
        # Plain object: store the metadata directly instead of forwarding
        # keyword arguments to object.__init__, which rejects them
        self.name = name
        self.description = description
        self.model = model

    def _run(self, prompt):
        return self.model.__call__(prompt)


@tool
def VQAinference(self, inputs):
    """
    Answer Question About The Image, VQA Multi-Modal Worker agent
    description="useful when you need an answer for a question based on an image. "
    "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
    "The input to this tool should be a comma separated string of two, representing the image_path and the question",

    """
    device = "cuda:0"
    torch_dtype = torch.float16 if "cuda" in device else torch.float32
    processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
    model = BlipForQuestionAnswering.from_pretrained(
        "Salesforce/blip-vqa-base", torch_dtype=torch_dtype
    ).to(device)

    image_path, question = inputs.split(",")
    raw_image = Image.open(image_path).convert("RGB")
    inputs = processor(raw_image, question, return_tensors="pt").to(
        device, torch_dtype
    )
    out = model.generate(**inputs)
    answer = processor.decode(out[0], skip_special_tokens=True)

    logger.debug(
        f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input"
        f" Question: {question}, Output Answer: {answer}"
    )

    return answer
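
# A usage sketch for WebpageQATool (an assumption for illustration; building the
# qa_chain with langchain's load_qa_with_sources_chain is not shown in this file):
#
#     from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
#     from langchain.chat_models import ChatOpenAI
#
#     llm = ChatOpenAI(temperature=0)
#     query_website_tool = WebpageQATool(qa_chain=load_qa_with_sources_chain(llm))
#     answer = query_website_tool.run(
#         {"url": "https://example.com", "question": "What is this page about?"}
#     )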
@ -0,0 +1,284 @@
import os
import uuid

import numpy as np
import torch
from diffusers import (
    EulerAncestralDiscreteScheduler,
    StableDiffusionInpaintPipeline,
    StableDiffusionInstructPix2PixPipeline,
    StableDiffusionPipeline,
)
from PIL import Image
from transformers import (
    BlipForConditionalGeneration,
    BlipForQuestionAnswering,
    BlipProcessor,
    CLIPSegForImageSegmentation,
    CLIPSegProcessor,
)

from swarms.prompts.prebuild.multi_modal_prompts import IMAGE_PROMPT
from swarms.tools.tool import tool
from swarms.utils.logger import logger
from swarms.utils.main import BaseHandler, get_new_image_name


class MaskFormer:
    def __init__(self, device):
        print("Initializing MaskFormer to %s" % device)
        self.device = device
        self.processor = CLIPSegProcessor.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        )
        self.model = CLIPSegForImageSegmentation.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        ).to(device)

    def inference(self, image_path, text):
        threshold = 0.5
        min_area = 0.02
        padding = 20
        original_image = Image.open(image_path)
        image = original_image.resize((512, 512))
        inputs = self.processor(
            text=text, images=image, padding="max_length", return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
        area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
        if area_ratio < min_area:
            return None
        true_indices = np.argwhere(mask)
        mask_array = np.zeros_like(mask, dtype=bool)
        for idx in true_indices:
            padded_slice = tuple(
                slice(max(0, i - padding), i + padding + 1) for i in idx
            )
            mask_array[padded_slice] = True
        visual_mask = (mask_array * 255).astype(np.uint8)
        image_mask = Image.fromarray(visual_mask)
        return image_mask.resize(original_image.size)


class ImageEditing:
    def __init__(self, device):
        print("Initializing ImageEditing to %s" % device)
        self.device = device
        self.mask_former = MaskFormer(device=self.device)
        self.revision = "fp16" if "cuda" in device else None
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            revision=self.revision,
            torch_dtype=self.torch_dtype,
        ).to(device)

    @tool(
        name="Remove Something From The Photo",
        description=(
            "useful when you want to remove an object or something from the"
            " photo from its description or location. The input to this tool"
            " should be a comma separated string of two, representing the"
            " image_path and the object need to be removed. "
        ),
    )
    def inference_remove(self, inputs):
        image_path, to_be_removed_txt = inputs.split(",")
        return self.inference_replace(
            f"{image_path},{to_be_removed_txt},background"
        )

    @tool(
        name="Replace Something From The Photo",
        description=(
            "useful when you want to replace an object from the object"
            " description or location with another object from its description."
            " The input to this tool should be a comma separated string of"
            " three, representing the image_path, the object to be replaced,"
            " the object to be replaced with "
        ),
    )
    def inference_replace(self, inputs):
        image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
        original_image = Image.open(image_path)
        original_size = original_image.size
        mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
        updated_image = self.inpaint(
            prompt=replace_with_txt,
            image=original_image.resize((512, 512)),
            mask_image=mask_image.resize((512, 512)),
        ).images[0]
        updated_image_path = get_new_image_name(
            image_path, func_name="replace-something"
        )
        updated_image = updated_image.resize(original_size)
        updated_image.save(updated_image_path)

        logger.debug(
            f"\nProcessed ImageEditing, Input Image: {image_path}, Replace"
            f" {to_be_replaced_txt} to {replace_with_txt}, Output Image:"
            f" {updated_image_path}"
        )

        return updated_image_path


class InstructPix2Pix:
    def __init__(self, device):
        print("Initializing InstructPix2Pix to %s" % device)
        self.device = device
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            "timbrooks/instruct-pix2pix",
            safety_checker=None,
            torch_dtype=self.torch_dtype,
        ).to(device)
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )

    @tool(
        name="Instruct Image Using Text",
        description=(
            "useful when you want the style of the image to be like the"
            " text. like: make it look like a painting. or make it like a"
            " robot. The input to this tool should be a comma separated string"
            " of two, representing the image_path and the text. "
        ),
    )
    def inference(self, inputs):
        """Change style of image."""
        logger.debug("===> Starting InstructPix2Pix Inference")
        image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:])
        original_image = Image.open(image_path)
        image = self.pipe(
            text,
            image=original_image,
            num_inference_steps=40,
            image_guidance_scale=1.2,
        ).images[0]
        updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
        image.save(updated_image_path)

        logger.debug(
            f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct"
            f" Text: {text}, Output Image: {updated_image_path}"
        )

        return updated_image_path


class Text2Image:
    def __init__(self, device):
        print("Initializing Text2Image to %s" % device)
        self.device = device
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype
        )
        self.pipe.to(device)
        self.a_prompt = "best quality, extremely detailed"
        self.n_prompt = (
            "longbody, lowres, bad anatomy, bad hands, missing fingers, extra"
            " digit, fewer digits, cropped, worst quality, low quality"
        )

    @tool(
        name="Generate Image From User Input Text",
        description=(
            "useful when you want to generate an image from a user input text"
            " and save it to a file. like: generate an image of an object or"
            " something, or generate an image that includes some objects. The"
            " input to this tool should be a string, representing the text used"
            " to generate image. "
        ),
    )
    def inference(self, text):
        image_filename = os.path.join("image", str(uuid.uuid4())[0:8] + ".png")
        prompt = text + ", " + self.a_prompt
        image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
        image.save(image_filename)

        logger.debug(
            f"\nProcessed Text2Image, Input Text: {text}, Output Image:"
            f" {image_filename}"
        )

        return image_filename


class VisualQuestionAnswering:
    def __init__(self, device):
        print("Initializing VisualQuestionAnswering to %s" % device)
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.device = device
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-vqa-base"
        )
        self.model = BlipForQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    @tool(
        name="Answer Question About The Image",
        description=(
            "useful when you need an answer for a question based on an image."
            " like: what is the background color of the last image, how many"
            " cats in this figure, what is in this figure. The input to this"
            " tool should be a comma separated string of two, representing the"
            " image_path and the question"
        ),
    )
    def inference(self, inputs):
        image_path, question = inputs.split(",")
        raw_image = Image.open(image_path).convert("RGB")
        inputs = self.processor(raw_image, question, return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs)
        answer = self.processor.decode(out[0], skip_special_tokens=True)

        logger.debug(
            f"\nProcessed VisualQuestionAnswering, Input Image: {image_path},"
            f" Input Question: {question}, Output Answer: {answer}"
        )

        return answer


class ImageCaptioning(BaseHandler):
    def __init__(self, device):
        print("Initializing ImageCaptioning to %s" % device)
        self.device = device
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            torch_dtype=self.torch_dtype,
        ).to(self.device)

    def handle(self, filename: str):
        img = Image.open(filename)
        width, height = img.size
        ratio = min(512 / width, 512 / height)
        width_new, height_new = (round(width * ratio), round(height * ratio))
        img = img.resize((width_new, height_new))
        img = img.convert("RGB")
        img.save(filename, "PNG")
        print(f"Resized image from {width}x{height} to {width_new}x{height_new}")

        inputs = self.processor(Image.open(filename), return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs)
        description = self.processor.decode(out[0], skip_special_tokens=True)
        print(
            f"\nProcessed ImageCaptioning, Input Image: {filename}, Output"
            f" Text: {description}"
        )

        return IMAGE_PROMPT.format(filename=filename, description=description)
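
# A minimal sketch chaining two of the classes above (illustrative assumption
# only; requires a CUDA device and downloadable model weights, and assumes the
# @tool-decorated methods remain directly callable):
#
#     device = "cuda:0"
#     t2i = Text2Image(device)
#     vqa = VisualQuestionAnswering(device)
#
#     img_path = t2i.inference("a red bicycle leaning against a wall")
#     # Tool inputs are comma separated strings: "<image_path>,<question>"
#     answer = vqa.inference(f"{img_path},what color is the bicycle?")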
@ -0,0 +1,12 @@
from concurrent import futures
from typing import TypeVar

T = TypeVar("T")


def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]:
    futures.wait(
        fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED
    )

    return {key: future.result() for key, future in fs_dict.items()}
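
# Usage sketch (an assumption for illustration): collect named results from a
# thread pool, blocking until every future has completed.
#
#     from concurrent.futures import ThreadPoolExecutor
#
#     def fetch(name: str) -> str:
#         return f"result for {name}"
#
#     with ThreadPoolExecutor() as pool:
#         fs = {name: pool.submit(fetch, name) for name in ("a", "b", "c")}
#         results = execute_futures_dict(fs)  # {"a": "result for a", ...}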
@ -0,0 +1,40 @@
import pytest
from langchain.base_language import BaseLanguageModel

from swarms.agents.omni_modal_agent import OmniModalAgent


# Mock objects or set up fixtures for dependent classes or external methods
@pytest.fixture
def mock_llm():
    # For this mock, we are assuming the BaseLanguageModel has a method named "process"
    class MockLLM(BaseLanguageModel):
        def process(self, input):
            return "mock response"

    return MockLLM()


@pytest.fixture
def omni_agent(mock_llm):
    return OmniModalAgent(mock_llm)


def test_omnimodalagent_initialization(omni_agent):
    assert omni_agent.llm is not None, "LLM initialization failed"
    assert len(omni_agent.tools) > 0, "Tools initialization failed"


def test_omnimodalagent_run(omni_agent):
    input_string = "Hello, how are you?"
    response = omni_agent.run(input_string)
    assert response is not None, "Response generation failed"
    assert isinstance(response, str), "Response should be a string"


def test_task_executor_initialization(omni_agent):
    assert (
        omni_agent.task_executor is not None
    ), "TaskExecutor initialization failed"
@ -0,0 +1,97 @@
import pytest
from unittest.mock import ANY, Mock, patch
from swarms.memory.ocean import OceanDB


def test_init():
    with patch("oceandb.Client") as MockClient:
        MockClient.return_value.heartbeat.return_value = "OK"
        db = OceanDB(MockClient)
        MockClient.assert_called_once()
        assert db.client == MockClient


def test_init_exception():
    with patch("oceandb.Client") as MockClient:
        MockClient.side_effect = Exception("Client error")
        with pytest.raises(Exception) as e:
            OceanDB(MockClient)
        assert str(e.value) == "Client error"


def test_create_collection():
    with patch("oceandb.Client") as MockClient:
        db = OceanDB(MockClient)
        db.create_collection("test", "modality")
        # mock.ANY matches whatever embedding function OceanDB passes through
        MockClient.create_collection.assert_called_once_with(
            "test", embedding_function=ANY
        )


def test_create_collection_exception():
    with patch("oceandb.Client") as MockClient:
        MockClient.create_collection.side_effect = Exception(
            "Create collection error"
        )
        db = OceanDB(MockClient)
        with pytest.raises(Exception) as e:
            db.create_collection("test", "modality")
        assert str(e.value) == "Create collection error"


def test_append_document():
    with patch("oceandb.Client") as MockClient:
        db = OceanDB(MockClient)
        collection = Mock()
        db.append_document(collection, "doc", "id")
        collection.add.assert_called_once_with(documents=["doc"], ids=["id"])


def test_append_document_exception():
    with patch("oceandb.Client") as MockClient:
        db = OceanDB(MockClient)
        collection = Mock()
        collection.add.side_effect = Exception("Append document error")
        with pytest.raises(Exception) as e:
            db.append_document(collection, "doc", "id")
        assert str(e.value) == "Append document error"


def test_add_documents():
    with patch("oceandb.Client") as MockClient:
        db = OceanDB(MockClient)
        collection = Mock()
        db.add_documents(collection, ["doc1", "doc2"], ["id1", "id2"])
        collection.add.assert_called_once_with(
            documents=["doc1", "doc2"], ids=["id1", "id2"]
        )


def test_add_documents_exception():
    with patch("oceandb.Client") as MockClient:
        db = OceanDB(MockClient)
        collection = Mock()
        collection.add.side_effect = Exception("Add documents error")
        with pytest.raises(Exception) as e:
            db.add_documents(collection, ["doc1", "doc2"], ["id1", "id2"])
        assert str(e.value) == "Add documents error"


def test_query():
    with patch("oceandb.Client") as MockClient:
        db = OceanDB(MockClient)
        collection = Mock()
        db.query(collection, ["query1", "query2"], 2)
        collection.query.assert_called_once_with(
            query_texts=["query1", "query2"], n_results=2
        )


def test_query_exception():
    with patch("oceandb.Client") as MockClient:
        db = OceanDB(MockClient)
        collection = Mock()
        collection.query.side_effect = Exception("Query error")
        with pytest.raises(Exception) as e:
            db.query(collection, ["query1", "query2"], 2)
        assert str(e.value) == "Query error"
@ -0,0 +1,61 @@
import unittest
import json
import os

# Assuming the BingChat class is in a file named "bing_chat.py"
from bing_chat import BingChat


class TestBingChat(unittest.TestCase):
    def setUp(self):
        # Path to a mock cookies file for testing
        self.mock_cookies_path = "./mock_cookies.json"
        with open(self.mock_cookies_path, "w") as file:
            json.dump({"mock_cookie": "mock_value"}, file)

        self.chat = BingChat(cookies_path=self.mock_cookies_path)

    def tearDown(self):
        os.remove(self.mock_cookies_path)

    def test_init(self):
        self.assertIsInstance(self.chat, BingChat)
        self.assertIsNotNone(self.chat.bot)

    def test_call(self):
        # Mocking the asynchronous behavior for the purpose of the test
        self.chat.bot.ask = lambda *args, **kwargs: {"text": "Hello, Test!"}
        response = self.chat("Test prompt")
        self.assertEqual(response, "Hello, Test!")

    def test_create_img(self):
        # Mocking the ImageGen behavior for the purpose of the test
        class MockImageGen:
            def __init__(self, *args, **kwargs):
                pass

            def get_images(self, *args, **kwargs):
                return [{"path": "mock_image.png"}]

            @staticmethod
            def save_images(*args, **kwargs):
                pass

        original_image_gen = BingChat.ImageGen
        BingChat.ImageGen = MockImageGen

        img_path = self.chat.create_img(
            "Test prompt", auth_cookie="mock_auth_cookie"
        )
        self.assertEqual(img_path, "./output/mock_image.png")

        BingChat.ImageGen = original_image_gen

    def test_set_cookie_dir_path(self):
        test_path = "./test_path"
        BingChat.set_cookie_dir_path(test_path)
        self.assertEqual(BingChat.Cookie.dir_path, test_path)


if __name__ == "__main__":
    unittest.main()
@ -0,0 +1,414 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from requests.exceptions import (
|
||||||
|
ConnectionError,
|
||||||
|
HTTPError,
|
||||||
|
RequestException,
|
||||||
|
Timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
from swarms.models.gpt4v import GPT4Vision, GPT4VisionResponse
|
||||||
|
|
||||||
|
load_dotenv
|
||||||
|
|
||||||
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
# Mock the OpenAI client
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_openai_client():
|
||||||
|
return Mock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def gpt4vision(mock_openai_client):
|
||||||
|
return GPT4Vision(client=mock_openai_client)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_default_values():
|
||||||
|
# Arrange and Act
|
||||||
|
gpt4vision = GPT4Vision()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert gpt4vision.max_retries == 3
|
||||||
|
assert gpt4vision.model == "gpt-4-vision-preview"
|
||||||
|
assert gpt4vision.backoff_factor == 2.0
|
||||||
|
assert gpt4vision.timeout_seconds == 10
|
||||||
|
assert gpt4vision.api_key is None
|
||||||
|
assert gpt4vision.quality == "low"
|
||||||
|
assert gpt4vision.max_tokens == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_api_key_from_env_variable():
|
||||||
|
# Arrange
|
||||||
|
api_key = os.environ["OPENAI_API_KEY"]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
gpt4vision = GPT4Vision()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert gpt4vision.api_key == api_key
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_set_api_key():
|
||||||
|
# Arrange
|
||||||
|
gpt4vision = GPT4Vision(api_key=api_key)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert gpt4vision.api_key == api_key
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_invalid_max_retries():
|
||||||
|
# Arrange and Act
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
GPT4Vision(max_retries=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_invalid_backoff_factor():
|
||||||
|
# Arrange and Act
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
GPT4Vision(backoff_factor=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_invalid_timeout_seconds():
|
||||||
|
# Arrange and Act
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
GPT4Vision(timeout_seconds=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_invalid_max_tokens():
|
||||||
|
# Arrange and Act
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
GPT4Vision(max_tokens=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_logger_initialized():
|
||||||
|
# Arrange
|
||||||
|
gpt4vision = GPT4Vision()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(gpt4vision.logger, logging.Logger)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_process_img_nonexistent_file():
|
||||||
|
# Arrange
|
||||||
|
gpt4vision = GPT4Vision()
|
||||||
|
img_path = "nonexistent_image.jpg"
|
||||||
|
|
||||||
|
# Act and Assert
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
gpt4vision.process_img(img_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_call_single_task_single_image_no_openai_client(gpt4vision):
|
||||||
|
# Arrange
|
||||||
|
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||||
|
task = "Describe this image."
|
||||||
|
|
||||||
|
# Act and Assert
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
gpt4vision(img_url, [task])
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_call_single_task_single_image_empty_response(
|
||||||
|
gpt4vision, mock_openai_client
|
||||||
|
):
|
||||||
|
# Arrange
|
||||||
|
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||||
|
task = "Describe this image."
|
||||||
|
|
||||||
|
mock_openai_client.chat.completions.create.return_value.choices = []
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = gpt4vision(img_url, [task])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.answer == ""
|
||||||
|
mock_openai_client.chat.completions.create.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(
|
||||||
|
gpt4vision, mock_openai_client
|
||||||
|
):
|
||||||
|
# Arrange
|
||||||
|
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||||
|
tasks = ["Describe this image.", "What's in this picture?"]
|
||||||
|
|
||||||
|
mock_openai_client.chat.completions.create.return_value.choices = []
|
||||||
|
|
||||||
|
# Act
|
||||||
|
responses = gpt4vision(img_url, tasks)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert all(response.answer == "" for response in responses)
|
||||||
|
assert (
|
||||||
|
mock_openai_client.chat.completions.create.call_count == 1
|
||||||
|
) # Should be called only once
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_call_single_task_single_image_timeout(
|
||||||
|
gpt4vision, mock_openai_client
|
||||||
|
):
|
||||||
|
# Arrange
|
||||||
|
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||||
|
task = "Describe this image."
|
||||||
|
|
||||||
|
mock_openai_client.chat.completions.create.side_effect = Timeout(
|
||||||
|
"Request timed out"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act and Assert
|
||||||
|
with pytest.raises(Timeout):
|
||||||
|
gpt4vision(img_url, [task])
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_call_retry_with_success_after_timeout(
|
||||||
|
gpt4vision, mock_openai_client
|
||||||
|
):
|
||||||
|
# Arrange
|
||||||
|
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||||
|
task = "Describe this image."
|
||||||
|
|
||||||
|
# Simulate success after a timeout and retry
|
||||||
|
mock_openai_client.chat.completions.create.side_effect = [
|
||||||
|
Timeout("Request timed out"),
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": {"text": "A description of the image."}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = gpt4vision(img_url, [task])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.answer == "A description of the image."
|
||||||
|
assert (
|
||||||
|
mock_openai_client.chat.completions.create.call_count == 2
|
||||||
|
) # Should be called twice
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_process_img():
|
||||||
|
# Arrange
|
||||||
|
img_path = "test_image.jpg"
|
||||||
|
gpt4vision = GPT4Vision()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
img_data = gpt4vision.process_img(img_path)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert img_data.startswith("/9j/") # Base64-encoded image data
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_call_single_task_single_image(
|
||||||
|
gpt4vision, mock_openai_client
|
||||||
|
):
|
||||||
|
# Arrange
|
||||||
|
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||||
|
task = "Describe this image."
|
||||||
|
|
||||||
|
expected_response = GPT4VisionResponse(answer="A description of the image.")
|
||||||
|
|
||||||
|
mock_openai_client.chat.completions.create.return_value.choices[0].text = (
|
||||||
|
expected_response.answer
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = gpt4vision(img_url, [task])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response == expected_response
|
||||||
|
mock_openai_client.chat.completions.create.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_call_single_task_multiple_images(
|
||||||
|
gpt4vision, mock_openai_client
|
||||||
|
):
|
||||||
|
# Arrange
|
||||||
|
img_urls = [
|
||||||
|
"https://example.com/image1.jpg",
|
||||||
|
"https://example.com/image2.jpg",
|
||||||
|
]
|
||||||
|
task = "Describe these images."
|
||||||
|
|
||||||
|
expected_response = GPT4VisionResponse(answer="Descriptions of the images.")
|
||||||
|
|
||||||
|
mock_openai_client.chat.completions.create.return_value.choices[0].text = (
|
||||||
|
expected_response.answer
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = gpt4vision(img_urls, [task])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response == expected_response
|
||||||
|
mock_openai_client.chat.completions.create.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_call_multiple_tasks_single_image(
|
||||||
|
gpt4vision, mock_openai_client
|
||||||
|
):
|
||||||
|
# Arrange
|
||||||
|
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||||
|
tasks = ["Describe this image.", "What's in this picture?"]
|
||||||
|
|
||||||
|
expected_responses = [
|
||||||
|
GPT4VisionResponse(answer="A description of the image."),
|
||||||
|
GPT4VisionResponse(answer="It contains various objects."),
|
||||||
|
]
|
||||||
|
|
||||||
|
def create_mock_response(response):
|
||||||
|
return {
|
||||||
|
"choices": [{"message": {"content": {"text": response.answer}}}]
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_openai_client.chat.completions.create.side_effect = [
|
||||||
|
create_mock_response(response) for response in expected_responses
|
||||||
|
]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
responses = gpt4vision(img_url, tasks)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert responses == expected_responses
|
||||||
|
assert (
|
||||||
|
mock_openai_client.chat.completions.create.call_count == 1
|
||||||
|
) # Should be called only once
|
||||||
|
|
||||||
|
def test_gpt4vision_call_multiple_tasks_single_image(
|
||||||
|
gpt4vision, mock_openai_client
|
||||||
|
):
|
||||||
|
# Arrange
|
||||||
|
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||||
|
tasks = ["Describe this image.", "What's in this picture?"]
|
||||||
|
|
||||||
|
expected_responses = [
|
||||||
|
GPT4VisionResponse(answer="A description of the image."),
|
||||||
|
GPT4VisionResponse(answer="It contains various objects."),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_openai_client.chat.completions.create.side_effect = [
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": {"text": expected_responses[i].answer}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
for i in range(len(expected_responses))
|
||||||
|
]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
responses = gpt4vision(img_url, tasks)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert responses == expected_responses
|
||||||
|
assert (
|
||||||
|
mock_openai_client.chat.completions.create.call_count == 1
|
||||||
|
) # Should be called only once
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt4vision_call_multiple_tasks_multiple_images(
    gpt4vision, mock_openai_client
):
    # Arrange
    img_urls = [
        "https://images.unsplash.com/photo-1694734479857-626882b6db37?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
        "https://images.unsplash.com/photo-1694734479898-6ac4633158ac?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
    ]
    tasks = ["Describe these images.", "What's in these pictures?"]

    expected_responses = [
        GPT4VisionResponse(answer="Descriptions of the images."),
        GPT4VisionResponse(answer="They contain various objects."),
    ]

    mock_openai_client.chat.completions.create.side_effect = [
        {"choices": [{"message": {"content": {"text": response.answer}}}]}
        for response in expected_responses
    ]

    # Act
    responses = gpt4vision(img_urls, tasks)

    # Assert
    assert responses == expected_responses
    assert (
        mock_openai_client.chat.completions.create.call_count == 1
    )  # Should be called only once


def test_gpt4vision_call_http_error(gpt4vision, mock_openai_client):
    # Arrange
    img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
    task = "Describe this image."

    mock_openai_client.chat.completions.create.side_effect = HTTPError(
        "HTTP Error"
    )

    # Act and Assert
    with pytest.raises(HTTPError):
        gpt4vision(img_url, [task])


def test_gpt4vision_call_request_error(gpt4vision, mock_openai_client):
    # Arrange
    img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
    task = "Describe this image."

    mock_openai_client.chat.completions.create.side_effect = RequestException(
        "Request Error"
    )

    # Act and Assert
    with pytest.raises(RequestException):
        gpt4vision(img_url, [task])


def test_gpt4vision_call_connection_error(gpt4vision, mock_openai_client):
    # Arrange
    img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
    task = "Describe this image."

    mock_openai_client.chat.completions.create.side_effect = ConnectionError(
        "Connection Error"
    )

    # Act and Assert
    with pytest.raises(ConnectionError):
        gpt4vision(img_url, [task])


def test_gpt4vision_call_retry_with_success(gpt4vision, mock_openai_client):
    # Arrange
    img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
    task = "Describe this image."

    # Simulate success after a retry; the success payload uses the same
    # response shape as the other mocks in this file.
    mock_openai_client.chat.completions.create.side_effect = [
        RequestException("Temporary error"),
        {
            "choices": [
                {"message": {"content": {"text": "A description of the image."}}}
            ]
        },
    ]

    # Act
    response = gpt4vision(img_url, [task])

    # Assert
    assert response.answer == "A description of the image."
    assert (
        mock_openai_client.chat.completions.create.call_count == 2
    )  # Should be called twice
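
# Note: the retry test above assumes gpt4vision traps the initial
# RequestException and retries internally, which is why exactly two calls
# to the mocked client are expected.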
@ -0,0 +1,90 @@
import unittest
from unittest.mock import patch

from swarms.models.revgptv1 import RevChatGPTModelv1


class TestRevChatGPT(unittest.TestCase):
    def setUp(self):
        self.access_token = "<your_access_token>"
        self.model = RevChatGPTModelv1(access_token=self.access_token)

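    # Note: most of the tests below exercise the live RevChatGPT service
    # through self.model.chatbot, so they need a real access token in
    # place of the "<your_access_token>" placeholder.
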
    def test_run(self):
        prompt = "What is the capital of France?"
        response = self.model.run(prompt)
        self.assertEqual(response, "The capital of France is Paris.")

    def test_run_time(self):
        prompt = "Generate a 300 word essay about technology."
        self.model.run(prompt)
        self.assertLess(self.model.end_time - self.model.start_time, 60)

    def test_generate_summary(self):
        text = (
            "This is a sample text to summarize. It has multiple sentences and"
            " details. The summary should be concise."
        )
        summary = self.model.generate_summary(text)
        self.assertLess(len(summary), len(text) / 2)

    def test_enable_plugin(self):
        plugin_id = "some_plugin_id"
        self.model.enable_plugin(plugin_id)
        self.assertIn(plugin_id, self.model.config["plugin_ids"])

    def test_list_plugins(self):
        plugins = self.model.list_plugins()
        self.assertGreater(len(plugins), 0)
        self.assertIsInstance(plugins[0], dict)
        self.assertIn("id", plugins[0])
        self.assertIn("name", plugins[0])

    def test_get_conversations(self):
        conversations = self.model.chatbot.get_conversations()
        self.assertIsInstance(conversations, list)

    # Patch targets must be importable module paths; Chatbot is assumed
    # to be referenced from swarms.models.revgptv1 here.
    @patch("swarms.models.revgptv1.Chatbot.get_msg_history")
    def test_get_msg_history(self, mock_get_msg_history):
        conversation_id = "convo_id"
        self.model.chatbot.get_msg_history(conversation_id)
        mock_get_msg_history.assert_called_with(conversation_id)

    @patch("swarms.models.revgptv1.Chatbot.share_conversation")
    def test_share_conversation(self, mock_share_conversation):
        self.model.chatbot.share_conversation()
        mock_share_conversation.assert_called()

    def test_gen_title(self):
        convo_id = "123"
        message_id = "456"
        title = self.model.chatbot.gen_title(convo_id, message_id)
        self.assertIsInstance(title, str)

    def test_change_title(self):
        convo_id = "123"
        title = "New Title"
        self.model.chatbot.change_title(convo_id, title)
        self.assertEqual(
            self.model.chatbot.get_msg_history(convo_id)["title"], title
        )

    def test_delete_conversation(self):
        convo_id = "123"
        self.model.chatbot.delete_conversation(convo_id)
        with self.assertRaises(Exception):
            self.model.chatbot.get_msg_history(convo_id)

    def test_clear_conversations(self):
        self.model.chatbot.clear_conversations()
        conversations = self.model.chatbot.get_conversations()
        self.assertEqual(len(conversations), 0)

    def test_rollback_conversation(self):
        original_convo_id = self.model.chatbot.conversation_id
        self.model.chatbot.rollback_conversation(1)
        self.assertNotEqual(
            original_convo_id, self.model.chatbot.conversation_id
        )


if __name__ == "__main__":
    unittest.main()
@ -0,0 +1,66 @@
from unittest.mock import patch

from swarms.swarms.multi_agent_debate import (
    MultiAgentDebate,
    Worker,
    select_speaker,
)


def test_multiagentdebate_initialization():
    multiagentdebate = MultiAgentDebate(
        agents=[Worker] * 5, selection_func=select_speaker
    )
    assert isinstance(multiagentdebate, MultiAgentDebate)
    assert len(multiagentdebate.agents) == 5
    assert multiagentdebate.selection_func == select_speaker


@patch("swarms.workers.Worker.reset")
def test_multiagentdebate_reset_agents(mock_reset):
    multiagentdebate = MultiAgentDebate(
        agents=[Worker] * 5, selection_func=select_speaker
    )
    multiagentdebate.reset_agents()
    assert mock_reset.call_count == 5


def test_multiagentdebate_inject_agent():
    multiagentdebate = MultiAgentDebate(
        agents=[Worker] * 5, selection_func=select_speaker
    )
    multiagentdebate.inject_agent(Worker)
    assert len(multiagentdebate.agents) == 6


@patch("swarms.workers.Worker.run")
def test_multiagentdebate_run(mock_run):
    multiagentdebate = MultiAgentDebate(
        agents=[Worker] * 5, selection_func=select_speaker
    )
    results = multiagentdebate.run("Write a short story.")
    assert len(results) == 5
    assert mock_run.call_count == 5


def test_multiagentdebate_update_task():
    multiagentdebate = MultiAgentDebate(
        agents=[Worker] * 5, selection_func=select_speaker
    )
    multiagentdebate.update_task("Write a short story.")
    assert multiagentdebate.task == "Write a short story."


def test_multiagentdebate_format_results():
    multiagentdebate = MultiAgentDebate(
        agents=[Worker] * 5, selection_func=select_speaker
    )
    results = [
        {"agent": "Agent 1", "response": "Hello, world!"},
        {"agent": "Agent 2", "response": "Goodbye, world!"},
    ]
    formatted_results = multiagentdebate.format_results(results)
    assert (
        formatted_results
        == "Agent Agent 1 responded: Hello, world!\nAgent Agent 2 responded:"
        " Goodbye, world!"
    )
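
# A minimal sketch of the join that the assertion above implies; this is
# an inference from the expected string, not MultiAgentDebate's actual
# implementation (the name _format_results_sketch is hypothetical).
def _format_results_sketch(results):
    return "\n".join(
        f"Agent {r['agent']} responded: {r['response']}" for r in results
    )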