DictInternalMemory
    DictSharedMemory
    LangchainChromaVectorMemory
    synchronized_queue
    TaskQueueBase]
pull/386/head
Kye 11 months ago
parent 76a140508f
commit f2912babc5

323
Cargo.lock generated

@ -0,0 +1,323 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "crossbeam-deque"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
[[package]]
name = "either"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
[[package]]
name = "engine"
version = "0.1.0"
dependencies = [
"log",
"pyo3",
"rayon",
]
[[package]]
name = "indoc"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47741a8bc60fb26eb8d6e0238bbb26d8575ff623fdc97b1a2c00c050b9684ed8"
dependencies = [
"indoc-impl",
"proc-macro-hack",
]
[[package]]
name = "indoc-impl"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce046d161f000fffde5f432a0d034d0341dc152643b2598ed5bfce44c4f3a8f0"
dependencies = [
"proc-macro-hack",
"proc-macro2",
"quote",
"syn",
"unindent",
]
[[package]]
name = "instant"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c"
dependencies = [
"cfg-if",
]
[[package]]
name = "libc"
version = "0.2.153"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
[[package]]
name = "lock_api"
version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45"
dependencies = [
"autocfg",
"scopeguard",
]
[[package]]
name = "log"
version = "0.4.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
[[package]]
name = "once_cell"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "parking_lot"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99"
dependencies = [
"instant",
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc"
dependencies = [
"cfg-if",
"instant",
"libc",
"redox_syscall",
"smallvec",
"winapi",
]
[[package]]
name = "paste"
version = "0.1.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880"
dependencies = [
"paste-impl",
"proc-macro-hack",
]
[[package]]
name = "paste-impl"
version = "0.1.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6"
dependencies = [
"proc-macro-hack",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.20+deprecated"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068"
[[package]]
name = "proc-macro2"
version = "1.0.78"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae"
dependencies = [
"unicode-ident",
]
[[package]]
name = "pyo3"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d41d50a7271e08c7c8a54cd24af5d62f73ee3a6f6a314215281ebdec421d5752"
dependencies = [
"cfg-if",
"indoc",
"libc",
"parking_lot",
"paste",
"pyo3-build-config",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "779239fc40b8e18bc8416d3a37d280ca9b9fb04bda54b98037bb6748595c2410"
dependencies = [
"once_cell",
]
[[package]]
name = "pyo3-macros"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b247e8c664be87998d8628e86f282c25066165f1f8dda66100c48202fdb93a"
dependencies = [
"pyo3-macros-backend",
"quote",
"syn",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a8c2812c412e00e641d99eeb79dd478317d981d938aa60325dfa7157b607095"
dependencies = [
"proc-macro2",
"pyo3-build-config",
"quote",
"syn",
]
[[package]]
name = "quote"
version = "1.0.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rayon"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]]
name = "redox_syscall"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
dependencies = [
"bitflags",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "smallvec"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7"
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "unicode-ident"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "unindent"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"

@ -5,13 +5,10 @@ edition = "2018"
[lib]
name = "engine"
path = "src/my_lib.rs"
path = "runtime/concurrent_exec.rs"
crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.15", features = ["extension-module"] }
rayon = "1.5.1"
log = "0.4.14"
rustcuda = "0.1.0"
rustcuda_derive = "*"
rustcuda_core = "0.1"

@ -1,10 +1,12 @@
[build-system]
requires = ["poetry-core>=1.0.0"]
requires = ["poetry-core>=1.0.0", "maturin"]
build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "swarms"
version = "4.1.4"
version = "4.1.5"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]

@ -1,7 +1,7 @@
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3::types::IntoPyDict;
use rayon::{ThreadPool, ThreadPoolBuilder, prelude::*};
use rayon::{ThreadPool, ThreadPoolBuilder};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::thread;
@ -13,16 +13,6 @@ fn rust_module(py: Python, m: &PyModule) -> PyResult<()> {
Ok(())
}
#[pyfunction]
pub fn concurrent_exec<F, G, H>(
py_codes: Vec<&str>,
timeout: Option<Duration>,
num_threads: usize,
error_handler: F,
log_function: G,
result_handler: H,
) -> PyResult<Vec<PyResult<()>>>
/// This function wraps Python code in Rust concurrency for ultra high performance.
///
/// # Arguments
@ -45,6 +35,16 @@ pub fn concurrent_exec<F, G, H>(
/// let result_handler = |r| println!("Result: {:?}", r);
/// execute_python_codes(py_codes, timeout, num_threads, error_handler, log_function, result_handler);
/// ```
#[pyfunction]
pub fn concurrent_exec<F, G, H>(
py_codes: Vec<&str>,
timeout: Option<Duration>,
num_threads: usize,
error_handler: F,
log_function: G,
result_handler: H,
) -> PyResult<Vec<PyResult<()>>>
where
F: Fn(&str),
G: Fn(&str),
@ -83,7 +83,7 @@ where
None => {}
}
results.lock().unwrap().push(result.clone());
results.lock().unwrap().push(result.clone(result));
result_handler(&result);
});
});

@ -6,6 +6,9 @@ from swarms.memory.weaviate_db import WeaviateDB
from swarms.memory.visual_memory import VisualShortTermMemory
from swarms.memory.action_subtask import ActionSubtaskEntry
from swarms.memory.chroma_db import ChromaDB
from swarms.memory.dict_internal_memory import DictInternalMemory
from swarms.memory.dict_shared_memory import DictSharedMemory
from swarms.memory.lanchain_chroma import LangchainChromaVectorMemory
__all__ = [
"AbstractVectorDatabase",
@ -16,4 +19,7 @@ __all__ = [
"VisualShortTermMemory",
"ActionSubtaskEntry",
"ChromaDB",
"DictInternalMemory",
"DictSharedMemory",
"LangchainChromaVectorMemory",
]

@ -0,0 +1,86 @@
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
class InternalMemoryBase(ABC):
"""Abstract base class for internal memory of agents in the swarm."""
def __init__(self, n_entries):
"""Initialize the internal memory. In the current architecture the memory always consists of a set of soltuions or evaluations.
During the operation, the agent should retrivie best solutions from it's internal memory based on the score.
Moreover, the project is designed around LLMs for the proof of concepts, so we treat all entry content as a string.
"""
self.n_entries = n_entries
@abstractmethod
def add(self, score, entry):
"""Add an entry to the internal memory."""
raise NotImplementedError
@abstractmethod
def get_top_n(self, n):
"""Get the top n entries from the internal memory."""
raise NotImplementedError
class DictInternalMemory(InternalMemoryBase):
def __init__(self, n_entries: int):
"""
Initialize the internal memory. In the current architecture the memory always consists of a set of solutions or evaluations.
Simple key-value store for now.
Args:
n_entries (int): The maximum number of entries to keep in the internal memory.
"""
super().__init__(n_entries)
self.data: Dict[str, Dict[str, Any]] = {}
def add(self, score: float, content: Any) -> None:
"""
Add an entry to the internal memory.
Args:
score (float): The score or fitness value associated with the entry.
content (Any): The content of the entry.
Returns:
None
"""
random_key: str = str(uuid.uuid4())
self.data[random_key] = {"score": score, "content": content}
# keep only the best n entries
sorted_data: List[Tuple[str, Dict[str, Any]]] = sorted(
self.data.items(),
key=lambda x: x[1]["score"],
reverse=True,
)
self.data = dict(sorted_data[: self.n_entries])
def get_top_n(self, n: int) -> List[Tuple[str, Dict[str, Any]]]:
"""
Get the top n entries from the internal memory.
Args:
n (int): The number of top entries to retrieve.
Returns:
List[Tuple[str, Dict[str, Any]]]: A list of tuples containing the random keys and corresponding entry data.
"""
sorted_data: List[Tuple[str, Dict[str, Any]]] = sorted(
self.data.items(),
key=lambda x: x[1]["score"],
reverse=True,
)
return sorted_data[:n]
def len(self) -> int:
"""
Get the number of entries in the internal memory.
Returns:
int: The number of entries in the internal memory.
"""
return len(self.data)

@ -0,0 +1,98 @@
import datetime
import json
import os
import threading
import uuid
from pathlib import Path
from typing import Dict, Any
class DictSharedMemory:
"""A class representing a shared memory that stores entries as a dictionary.
Attributes:
file_loc (Path): The file location where the memory is stored.
lock (threading.Lock): A lock used for thread synchronization.
Methods:
__init__(self, file_loc: str = None) -> None: Initializes the shared memory.
add_entry(self, score: float, agent_id: str, agent_cycle: int, entry: Any) -> bool: Adds an entry to the internal memory.
get_top_n(self, n: int) -> None: Gets the top n entries from the internal memory.
write_to_file(self, data: Dict[str, Dict[str, Any]]) -> bool: Writes the internal memory to a file.
"""
def __init__(self, file_loc: str = None) -> None:
"""Initialize the shared memory. In the current architecture the memory always consists of a set of soltuions or evaluations.
Moreover, the project is designed around LLMs for the proof of concepts, so we treat all entry content as a string.
"""
if file_loc is not None:
self.file_loc = Path(file_loc)
if not self.file_loc.exists():
self.file_loc.touch()
self.lock = threading.Lock()
def add(
self,
score: float,
agent_id: str,
agent_cycle: int,
entry: Any,
) -> bool:
"""Add an entry to the internal memory."""
with self.lock:
entry_id = str(uuid.uuid4())
data = {}
epoch = datetime.datetime.utcfromtimestamp(0)
epoch = (
datetime.datetime.utcnow() - epoch
).total_seconds()
data[entry_id] = {
"agent": agent_id,
"epoch": epoch,
"score": score,
"cycle": agent_cycle,
"content": entry,
}
status = self.write_to_file(data)
self.plot_performance()
return status
def get_top_n(self, n: int) -> None:
"""Get the top n entries from the internal memory."""
with self.lock:
with open(self.file_loc, "r") as f:
try:
file_data = json.load(f)
except Exception as e:
file_data = {}
raise e
sorted_data = dict(
sorted(
file_data.items(),
key=lambda item: item[1]["score"],
reverse=True,
)
)
top_n = dict(list(sorted_data.items())[:n])
return top_n
def write_to_file(self, data: Dict[str, Dict[str, Any]]) -> bool:
"""Write the internal memory to a file."""
if self.file_loc is not None:
with open(self.file_loc, "r") as f:
try:
file_data = json.load(f)
except Exception as e:
file_data = {}
raise e
file_data = file_data | data
with open(self.file_loc, "w") as f:
json.dump(file_data, f, indent=4)
f.flush()
os.fsync(f.fileno())
return True

@ -0,0 +1,194 @@
import threading
from pathlib import Path
from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from swarms.models.openai_models import OpenAIChat
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
def synchronized_mem(method):
"""
Decorator that synchronizes access to a method using a lock.
Args:
method: The method to be decorated.
Returns:
The decorated method.
"""
def wrapper(self, *args, **kwargs):
with self.lock:
try:
return method(self, *args, **kwargs)
except Exception as e:
print(f"Failed to execute {method.__name__}: {e}")
return wrapper
class LangchainChromaVectorMemory:
"""
A class representing a vector memory for storing and retrieving text entries.
Attributes:
loc (str): The location of the vector memory.
chunk_size (int): The size of each text chunk.
chunk_overlap_frac (float): The fraction of overlap between text chunks.
embeddings (OpenAIEmbeddings): The embeddings used for text representation.
count (int): The current count of text entries in the vector memory.
lock (threading.Lock): A lock for thread safety.
db (Chroma): The Chroma database for storing text entries.
qa (RetrievalQA): The retrieval QA system for answering questions.
Methods:
__init__: Initializes the VectorMemory object.
_init_db: Initializes the Chroma database.
_init_retriever: Initializes the retrieval QA system.
add_entry: Adds an entry to the vector memory.
search_memory: Searches the vector memory for similar entries.
ask_question: Asks a question to the vector memory.
"""
def __init__(
self,
loc=None,
chunk_size: int = 1000,
chunk_overlap_frac: float = 0.1,
*args,
**kwargs,
):
"""
Initializes the VectorMemory object.
Args:
loc (str): The location of the vector memory. If None, defaults to "./tmp/vector_memory".
chunk_size (int): The size of each text chunk.
chunk_overlap_frac (float): The fraction of overlap between text chunks.
"""
if loc is None:
loc = "./tmp/vector_memory"
self.loc = Path(loc)
self.chunk_size = chunk_size
self.chunk_overlap = chunk_size * chunk_overlap_frac
self.embeddings = OpenAIEmbeddings()
self.count = 0
self.lock = threading.Lock()
self.db = self._init_db()
self.qa = self._init_retriever()
def _init_db(self):
"""
Initializes the Chroma database.
Returns:
Chroma: The initialized Chroma database.
"""
texts = [
"init"
] # TODO find how to initialize Chroma without any text
chroma_db = Chroma.from_texts(
texts=texts,
embedding=self.embeddings,
persist_directory=str(self.loc),
)
self.count = chroma_db._collection.count()
return chroma_db
def _init_retriever(self):
"""
Initializes the retrieval QA system.
Returns:
RetrievalQA: The initialized retrieval QA system.
"""
model = OpenAIChat(
model_name="gpt-3.5-turbo",
)
qa_chain = load_qa_chain(model, chain_type="stuff")
retriever = self.db.as_retriever(
search_type="mmr", search_kwargs={"k": 10}
)
qa = RetrievalQA(
combine_documents_chain=qa_chain, retriever=retriever
)
return qa
@synchronized_mem
def add(self, entry: str):
"""
Add an entry to the internal memory.
Args:
entry (str): The entry to be added.
Returns:
bool: True if the entry was successfully added, False otherwise.
"""
text_splitter = CharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
separator=" ",
)
texts = text_splitter.split_text(entry)
self.db.add_texts(texts)
self.count += self.db._collection.count()
self.db.persist()
return True
@synchronized_mem
def search_memory(
self, query: str, k=10, type="mmr", distance_threshold=0.5
):
"""
Searching the vector memory for similar entries.
Args:
query (str): The query to search for.
k (int): The number of results to return.
type (str): The type of search to perform: "cos" or "mmr".
distance_threshold (float): The similarity threshold to use for the search. Results with distance > similarity_threshold will be dropped.
Returns:
list[str]: A list of the top k results.
"""
self.count = self.db._collection.count()
if k > self.count:
k = self.count - 1
if k <= 0:
return None
if type == "mmr":
texts = self.db.max_marginal_relevance_search(
query=query, k=k, fetch_k=min(20, self.count)
)
texts = [text.page_content for text in texts]
elif type == "cos":
texts = self.db.similarity_search_with_score(
query=query, k=k
)
texts = [
text[0].page_content
for text in texts
if text[-1] < distance_threshold
]
return texts
@synchronized_mem
def query(self, question: str):
"""
Ask a question to the vector memory.
Args:
question (str): The question to ask.
Returns:
str: The answer to the question.
"""
answer = self.qa.run(question)
return answer

@ -17,7 +17,6 @@ from swarms.structs.recursive_workflow import RecursiveWorkflow
from swarms.structs.schemas import (
Artifact,
ArtifactUpload,
Step,
StepInput,
StepOutput,
StepRequestBody,

@ -1,85 +0,0 @@
from dataclasses import dataclass
from typing import List, Optional
from swarms.memory.base_vectordatabase import AbstractVectorDatabase
from swarms.structs.agent import Agent
@dataclass
class MultiAgentRag:
"""
Represents a multi-agent RAG (Relational Agent Graph) structure.
Attributes:
agents (List[Agent]): List of agents in the multi-agent RAG.
db (AbstractVectorDatabase): Database used for querying.
verbose (bool): Flag indicating whether to print verbose output.
"""
agents: List[Agent]
db: AbstractVectorDatabase
verbose: bool = False
def query_database(self, query: str):
"""
Queries the database using the given query string.
Args:
query (str): The query string.
Returns:
List: The list of results from the database.
"""
results = []
for agent in self.agents:
agent_results = agent.long_term_memory_prompt(query)
results.extend(agent_results)
return results
def get_agent_by_id(self, agent_id) -> Optional[Agent]:
"""
Retrieves an agent from the multi-agent RAG by its ID.
Args:
agent_id: The ID of the agent to retrieve.
Returns:
Agent or None: The agent with the specified ID, or None if not found.
"""
for agent in self.agents:
if agent.agent_id == agent_id:
return agent
return None
def add_message(
self, sender: Agent, message: str, *args, **kwargs
):
"""
Adds a message to the database.
Args:
sender (Agent): The agent sending the message.
message (str): The message to add.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
int: The ID of the added message.
"""
doc = f"{sender.ai_name}: {message}"
return self.db.add(doc)
def query(self, message: str, *args, **kwargs):
"""
Queries the database using the given message.
Args:
message (str): The message to query.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
List: The list of results from the database.
"""
return self.db.query(message)

@ -59,8 +59,6 @@ class StackOverflowSwarm(BaseMultiAgentStructure):
# Forum for the agents to interact
self.forum = []
def run(self, task: str, *args, **kwargs):
"""
Run the swarm to solve a problem or answer a question like stack overflow
@ -85,9 +83,7 @@ class StackOverflowSwarm(BaseMultiAgentStructure):
**kwargs,
)
# Add to the conversation
self.conversation.add(
agent.ai_name, f"{response}"
)
self.conversation.add(agent.ai_name, f"{response}")
logger.info(f"[{agent.ai_name}]: [{response}]")
return self.conversation.return_history_as_string()

@ -0,0 +1,81 @@
import threading
from abc import ABC, abstractmethod
from swarms.structs.agent import Agent
from swarms.structs.task import Task
def synchronized_queue(method):
"""
Decorator that synchronizes access to the decorated method using a lock.
The lock is acquired before executing the method and released afterwards.
Args:
method: The method to be decorated.
Returns:
The decorated method.
"""
timeout_sec = 5
def wrapper(self, *args, **kwargs):
with self.lock:
self.lock.acquire(timeout=timeout_sec)
try:
return method(self, *args, **kwargs)
except Exception as e:
print(f"Failed to execute {method.__name__}: {e}")
finally:
self.lock.release()
return wrapper
class TaskQueueBase(ABC):
def __init__(self):
self.lock = threading.Lock()
@synchronized_queue
@abstractmethod
def add_task(self, task: Task) -> bool:
"""Adds a task to the queue.
Args:
task (Task): The task to be added to the queue.
Returns:
bool: True if the task was successfully added, False otherwise.
"""
raise NotImplementedError
@synchronized_queue
@abstractmethod
def get_task(self, agent: Agent) -> Task:
"""Gets the next task from the queue.
Args:
agent (Agent): The agent requesting the task.
Returns:
Task: The next task from the queue.
"""
raise NotImplementedError
@synchronized_queue
@abstractmethod
def complete_task(self, task_id: str):
"""Sets the task as completed.
Args:
task_id (str): The ID of the task to be marked as completed.
"""
raise NotImplementedError
@synchronized_queue
@abstractmethod
def reset_task(self, task_id: str):
"""Resets the task if the agent failed to complete it.
Args:
task_id (str): The ID of the task to be reset.
"""
raise NotImplementedError

@ -0,0 +1,151 @@
from unittest.mock import MagicMock
import pytest
from swarms.structs.agent import Agent
from swarms.structs.majority_voting import MajorityVoting
def test_majority_voting_run_concurrent(mocker):
# Create mock agents
agent1 = MagicMock(spec=Agent)
agent2 = MagicMock(spec=Agent)
agent3 = MagicMock(spec=Agent)
# Create mock majority voting
mv = MajorityVoting(
agents=[agent1, agent2, agent3],
concurrent=True,
multithreaded=False,
)
# Create mock conversation
conversation = MagicMock()
mv.conversation = conversation
# Create mock results
results = ["Paris", "Paris", "Lyon"]
# Mock agent.run method
agent1.run.return_value = results[0]
agent2.run.return_value = results[1]
agent3.run.return_value = results[2]
# Run majority voting
majority_vote = mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with(
"What is the capital of France?"
)
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0])
conversation.add.assert_any_call(agent2.agent_name, results[1])
conversation.add.assert_any_call(agent3.agent_name, results[2])
# Assert majority vote is correct
assert majority_vote is not None
def test_majority_voting_run_multithreaded(mocker):
# Create mock agents
agent1 = MagicMock(spec=Agent)
agent2 = MagicMock(spec=Agent)
agent3 = MagicMock(spec=Agent)
# Create mock majority voting
mv = MajorityVoting(
agents=[agent1, agent2, agent3],
concurrent=False,
multithreaded=True,
)
# Create mock conversation
conversation = MagicMock()
mv.conversation = conversation
# Create mock results
results = ["Paris", "Paris", "Lyon"]
# Mock agent.run method
agent1.run.return_value = results[0]
agent2.run.return_value = results[1]
agent3.run.return_value = results[2]
# Run majority voting
majority_vote = mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with(
"What is the capital of France?"
)
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0])
conversation.add.assert_any_call(agent2.agent_name, results[1])
conversation.add.assert_any_call(agent3.agent_name, results[2])
# Assert majority vote is correct
assert majority_vote is not None
@pytest.mark.asyncio
async def test_majority_voting_run_asynchronous(mocker):
# Create mock agents
agent1 = MagicMock(spec=Agent)
agent2 = MagicMock(spec=Agent)
agent3 = MagicMock(spec=Agent)
# Create mock majority voting
mv = MajorityVoting(
agents=[agent1, agent2, agent3],
concurrent=False,
multithreaded=False,
asynchronous=True,
)
# Create mock conversation
conversation = MagicMock()
mv.conversation = conversation
# Create mock results
results = ["Paris", "Paris", "Lyon"]
# Mock agent.run method
agent1.run.return_value = results[0]
agent2.run.return_value = results[1]
agent3.run.return_value = results[2]
# Run majority voting
majority_vote = await mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with(
"What is the capital of France?"
)
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0])
conversation.add.assert_any_call(agent2.agent_name, results[1])
conversation.add.assert_any_call(agent3.agent_name, results[2])
# Assert majority vote is correct
assert majority_vote is not None

@ -0,0 +1,151 @@
from unittest.mock import MagicMock
import pytest
from swarms.structs.agent import Agent
from swarms.structs.majority_voting import MajorityVoting
def test_majority_voting_run_concurrent(mocker):
# Create mock agents
agent1 = MagicMock(spec=Agent)
agent2 = MagicMock(spec=Agent)
agent3 = MagicMock(spec=Agent)
# Create mock majority voting
mv = MajorityVoting(
agents=[agent1, agent2, agent3],
concurrent=True,
multithreaded=False,
)
# Create mock conversation
conversation = MagicMock()
mv.conversation = conversation
# Create mock results
results = ["Paris", "Paris", "Lyon"]
# Mock agent.run method
agent1.run.return_value = results[0]
agent2.run.return_value = results[1]
agent3.run.return_value = results[2]
# Run majority voting
majority_vote = mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with(
"What is the capital of France?"
)
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0])
conversation.add.assert_any_call(agent2.agent_name, results[1])
conversation.add.assert_any_call(agent3.agent_name, results[2])
# Assert majority vote is correct
assert majority_vote is not None
def test_majority_voting_run_multithreaded(mocker):
# Create mock agents
agent1 = MagicMock(spec=Agent)
agent2 = MagicMock(spec=Agent)
agent3 = MagicMock(spec=Agent)
# Create mock majority voting
mv = MajorityVoting(
agents=[agent1, agent2, agent3],
concurrent=False,
multithreaded=True,
)
# Create mock conversation
conversation = MagicMock()
mv.conversation = conversation
# Create mock results
results = ["Paris", "Paris", "Lyon"]
# Mock agent.run method
agent1.run.return_value = results[0]
agent2.run.return_value = results[1]
agent3.run.return_value = results[2]
# Run majority voting
majority_vote = mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with(
"What is the capital of France?"
)
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0])
conversation.add.assert_any_call(agent2.agent_name, results[1])
conversation.add.assert_any_call(agent3.agent_name, results[2])
# Assert majority vote is correct
assert majority_vote is not None
@pytest.mark.asyncio
async def test_majority_voting_run_asynchronous(mocker):
# Create mock agents
agent1 = MagicMock(spec=Agent)
agent2 = MagicMock(spec=Agent)
agent3 = MagicMock(spec=Agent)
# Create mock majority voting
mv = MajorityVoting(
agents=[agent1, agent2, agent3],
concurrent=False,
multithreaded=False,
asynchronous=True,
)
# Create mock conversation
conversation = MagicMock()
mv.conversation = conversation
# Create mock results
results = ["Paris", "Paris", "Lyon"]
# Mock agent.run method
agent1.run.return_value = results[0]
agent2.run.return_value = results[1]
agent3.run.return_value = results[2]
# Run majority voting
majority_vote = await mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with(
"What is the capital of France?"
)
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0])
conversation.add.assert_any_call(agent2.agent_name, results[1])
conversation.add.assert_any_call(agent3.agent_name, results[2])
# Assert majority vote is correct
assert majority_vote is not None
Loading…
Cancel
Save