DictInternalMemory DictSharedMemory LangchainChromaVectorMemory synchronized_queue TaskQueueBase]pull/386/head
parent
76a140508f
commit
f2912babc5
@ -0,0 +1,323 @@
|
|||||||
|
# This file is automatically @generated by Cargo.
|
||||||
|
# It is not intended for manual editing.
|
||||||
|
version = 3
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "autocfg"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bitflags"
|
||||||
|
version = "1.3.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cfg-if"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-deque"
|
||||||
|
version = "0.8.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-epoch",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-epoch"
|
||||||
|
version = "0.9.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-utils"
|
||||||
|
version = "0.8.19"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "either"
|
||||||
|
version = "1.10.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "engine"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"log",
|
||||||
|
"pyo3",
|
||||||
|
"rayon",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "indoc"
|
||||||
|
version = "0.3.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "47741a8bc60fb26eb8d6e0238bbb26d8575ff623fdc97b1a2c00c050b9684ed8"
|
||||||
|
dependencies = [
|
||||||
|
"indoc-impl",
|
||||||
|
"proc-macro-hack",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "indoc-impl"
|
||||||
|
version = "0.3.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ce046d161f000fffde5f432a0d034d0341dc152643b2598ed5bfce44c4f3a8f0"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro-hack",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
"unindent",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "instant"
|
||||||
|
version = "0.1.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libc"
|
||||||
|
version = "0.2.153"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lock_api"
|
||||||
|
version = "0.4.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"scopeguard",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "log"
|
||||||
|
version = "0.4.20"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "once_cell"
|
||||||
|
version = "1.19.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "parking_lot"
|
||||||
|
version = "0.11.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99"
|
||||||
|
dependencies = [
|
||||||
|
"instant",
|
||||||
|
"lock_api",
|
||||||
|
"parking_lot_core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "parking_lot_core"
|
||||||
|
version = "0.8.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"instant",
|
||||||
|
"libc",
|
||||||
|
"redox_syscall",
|
||||||
|
"smallvec",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "paste"
|
||||||
|
version = "0.1.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880"
|
||||||
|
dependencies = [
|
||||||
|
"paste-impl",
|
||||||
|
"proc-macro-hack",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "paste-impl"
|
||||||
|
version = "0.1.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro-hack",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro-hack"
|
||||||
|
version = "0.5.20+deprecated"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro2"
|
||||||
|
version = "1.0.78"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae"
|
||||||
|
dependencies = [
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyo3"
|
||||||
|
version = "0.15.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d41d50a7271e08c7c8a54cd24af5d62f73ee3a6f6a314215281ebdec421d5752"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"indoc",
|
||||||
|
"libc",
|
||||||
|
"parking_lot",
|
||||||
|
"paste",
|
||||||
|
"pyo3-build-config",
|
||||||
|
"pyo3-macros",
|
||||||
|
"unindent",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyo3-build-config"
|
||||||
|
version = "0.15.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "779239fc40b8e18bc8416d3a37d280ca9b9fb04bda54b98037bb6748595c2410"
|
||||||
|
dependencies = [
|
||||||
|
"once_cell",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyo3-macros"
|
||||||
|
version = "0.15.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "00b247e8c664be87998d8628e86f282c25066165f1f8dda66100c48202fdb93a"
|
||||||
|
dependencies = [
|
||||||
|
"pyo3-macros-backend",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyo3-macros-backend"
|
||||||
|
version = "0.15.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5a8c2812c412e00e641d99eeb79dd478317d981d938aa60325dfa7157b607095"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"pyo3-build-config",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quote"
|
||||||
|
version = "1.0.35"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rayon"
|
||||||
|
version = "1.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
"rayon-core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rayon-core"
|
||||||
|
version = "1.12.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-deque",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "redox_syscall"
|
||||||
|
version = "0.2.16"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "scopeguard"
|
||||||
|
version = "1.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "smallvec"
|
||||||
|
version = "1.13.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "syn"
|
||||||
|
version = "1.0.109"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-ident"
|
||||||
|
version = "1.0.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unindent"
|
||||||
|
version = "0.1.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi"
|
||||||
|
version = "0.3.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||||
|
dependencies = [
|
||||||
|
"winapi-i686-pc-windows-gnu",
|
||||||
|
"winapi-x86_64-pc-windows-gnu",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-i686-pc-windows-gnu"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-x86_64-pc-windows-gnu"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
@ -0,0 +1,86 @@
|
|||||||
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class InternalMemoryBase(ABC):
|
||||||
|
"""Abstract base class for internal memory of agents in the swarm."""
|
||||||
|
|
||||||
|
def __init__(self, n_entries):
|
||||||
|
"""Initialize the internal memory. In the current architecture the memory always consists of a set of soltuions or evaluations.
|
||||||
|
During the operation, the agent should retrivie best solutions from it's internal memory based on the score.
|
||||||
|
|
||||||
|
Moreover, the project is designed around LLMs for the proof of concepts, so we treat all entry content as a string.
|
||||||
|
"""
|
||||||
|
self.n_entries = n_entries
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add(self, score, entry):
|
||||||
|
"""Add an entry to the internal memory."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_top_n(self, n):
|
||||||
|
"""Get the top n entries from the internal memory."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class DictInternalMemory(InternalMemoryBase):
|
||||||
|
def __init__(self, n_entries: int):
|
||||||
|
"""
|
||||||
|
Initialize the internal memory. In the current architecture the memory always consists of a set of solutions or evaluations.
|
||||||
|
Simple key-value store for now.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_entries (int): The maximum number of entries to keep in the internal memory.
|
||||||
|
"""
|
||||||
|
super().__init__(n_entries)
|
||||||
|
self.data: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
def add(self, score: float, content: Any) -> None:
|
||||||
|
"""
|
||||||
|
Add an entry to the internal memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
score (float): The score or fitness value associated with the entry.
|
||||||
|
content (Any): The content of the entry.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
random_key: str = str(uuid.uuid4())
|
||||||
|
self.data[random_key] = {"score": score, "content": content}
|
||||||
|
|
||||||
|
# keep only the best n entries
|
||||||
|
sorted_data: List[Tuple[str, Dict[str, Any]]] = sorted(
|
||||||
|
self.data.items(),
|
||||||
|
key=lambda x: x[1]["score"],
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
self.data = dict(sorted_data[: self.n_entries])
|
||||||
|
|
||||||
|
def get_top_n(self, n: int) -> List[Tuple[str, Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Get the top n entries from the internal memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (int): The number of top entries to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Tuple[str, Dict[str, Any]]]: A list of tuples containing the random keys and corresponding entry data.
|
||||||
|
"""
|
||||||
|
sorted_data: List[Tuple[str, Dict[str, Any]]] = sorted(
|
||||||
|
self.data.items(),
|
||||||
|
key=lambda x: x[1]["score"],
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
return sorted_data[:n]
|
||||||
|
|
||||||
|
def len(self) -> int:
|
||||||
|
"""
|
||||||
|
Get the number of entries in the internal memory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The number of entries in the internal memory.
|
||||||
|
"""
|
||||||
|
return len(self.data)
|
@ -0,0 +1,98 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class DictSharedMemory:
|
||||||
|
"""A class representing a shared memory that stores entries as a dictionary.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
file_loc (Path): The file location where the memory is stored.
|
||||||
|
lock (threading.Lock): A lock used for thread synchronization.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
__init__(self, file_loc: str = None) -> None: Initializes the shared memory.
|
||||||
|
add_entry(self, score: float, agent_id: str, agent_cycle: int, entry: Any) -> bool: Adds an entry to the internal memory.
|
||||||
|
get_top_n(self, n: int) -> None: Gets the top n entries from the internal memory.
|
||||||
|
write_to_file(self, data: Dict[str, Dict[str, Any]]) -> bool: Writes the internal memory to a file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, file_loc: str = None) -> None:
|
||||||
|
"""Initialize the shared memory. In the current architecture the memory always consists of a set of soltuions or evaluations.
|
||||||
|
Moreover, the project is designed around LLMs for the proof of concepts, so we treat all entry content as a string.
|
||||||
|
"""
|
||||||
|
if file_loc is not None:
|
||||||
|
self.file_loc = Path(file_loc)
|
||||||
|
if not self.file_loc.exists():
|
||||||
|
self.file_loc.touch()
|
||||||
|
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
score: float,
|
||||||
|
agent_id: str,
|
||||||
|
agent_cycle: int,
|
||||||
|
entry: Any,
|
||||||
|
) -> bool:
|
||||||
|
"""Add an entry to the internal memory."""
|
||||||
|
with self.lock:
|
||||||
|
entry_id = str(uuid.uuid4())
|
||||||
|
data = {}
|
||||||
|
epoch = datetime.datetime.utcfromtimestamp(0)
|
||||||
|
epoch = (
|
||||||
|
datetime.datetime.utcnow() - epoch
|
||||||
|
).total_seconds()
|
||||||
|
data[entry_id] = {
|
||||||
|
"agent": agent_id,
|
||||||
|
"epoch": epoch,
|
||||||
|
"score": score,
|
||||||
|
"cycle": agent_cycle,
|
||||||
|
"content": entry,
|
||||||
|
}
|
||||||
|
status = self.write_to_file(data)
|
||||||
|
self.plot_performance()
|
||||||
|
return status
|
||||||
|
|
||||||
|
def get_top_n(self, n: int) -> None:
|
||||||
|
"""Get the top n entries from the internal memory."""
|
||||||
|
with self.lock:
|
||||||
|
with open(self.file_loc, "r") as f:
|
||||||
|
try:
|
||||||
|
file_data = json.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
file_data = {}
|
||||||
|
raise e
|
||||||
|
|
||||||
|
sorted_data = dict(
|
||||||
|
sorted(
|
||||||
|
file_data.items(),
|
||||||
|
key=lambda item: item[1]["score"],
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
top_n = dict(list(sorted_data.items())[:n])
|
||||||
|
return top_n
|
||||||
|
|
||||||
|
def write_to_file(self, data: Dict[str, Dict[str, Any]]) -> bool:
|
||||||
|
"""Write the internal memory to a file."""
|
||||||
|
if self.file_loc is not None:
|
||||||
|
with open(self.file_loc, "r") as f:
|
||||||
|
try:
|
||||||
|
file_data = json.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
file_data = {}
|
||||||
|
raise e
|
||||||
|
|
||||||
|
file_data = file_data | data
|
||||||
|
with open(self.file_loc, "w") as f:
|
||||||
|
json.dump(file_data, f, indent=4)
|
||||||
|
|
||||||
|
f.flush()
|
||||||
|
os.fsync(f.fileno())
|
||||||
|
|
||||||
|
return True
|
@ -0,0 +1,194 @@
|
|||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from langchain.chains import RetrievalQA
|
||||||
|
from langchain.chains.question_answering import load_qa_chain
|
||||||
|
from swarms.models.openai_models import OpenAIChat
|
||||||
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
|
from langchain.text_splitter import CharacterTextSplitter
|
||||||
|
from langchain.vectorstores import Chroma
|
||||||
|
|
||||||
|
|
||||||
|
def synchronized_mem(method):
|
||||||
|
"""
|
||||||
|
Decorator that synchronizes access to a method using a lock.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: The method to be decorated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The decorated method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
with self.lock:
|
||||||
|
try:
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to execute {method.__name__}: {e}")
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class LangchainChromaVectorMemory:
|
||||||
|
"""
|
||||||
|
A class representing a vector memory for storing and retrieving text entries.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
loc (str): The location of the vector memory.
|
||||||
|
chunk_size (int): The size of each text chunk.
|
||||||
|
chunk_overlap_frac (float): The fraction of overlap between text chunks.
|
||||||
|
embeddings (OpenAIEmbeddings): The embeddings used for text representation.
|
||||||
|
count (int): The current count of text entries in the vector memory.
|
||||||
|
lock (threading.Lock): A lock for thread safety.
|
||||||
|
db (Chroma): The Chroma database for storing text entries.
|
||||||
|
qa (RetrievalQA): The retrieval QA system for answering questions.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
__init__: Initializes the VectorMemory object.
|
||||||
|
_init_db: Initializes the Chroma database.
|
||||||
|
_init_retriever: Initializes the retrieval QA system.
|
||||||
|
add_entry: Adds an entry to the vector memory.
|
||||||
|
search_memory: Searches the vector memory for similar entries.
|
||||||
|
ask_question: Asks a question to the vector memory.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
loc=None,
|
||||||
|
chunk_size: int = 1000,
|
||||||
|
chunk_overlap_frac: float = 0.1,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initializes the VectorMemory object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loc (str): The location of the vector memory. If None, defaults to "./tmp/vector_memory".
|
||||||
|
chunk_size (int): The size of each text chunk.
|
||||||
|
chunk_overlap_frac (float): The fraction of overlap between text chunks.
|
||||||
|
"""
|
||||||
|
if loc is None:
|
||||||
|
loc = "./tmp/vector_memory"
|
||||||
|
self.loc = Path(loc)
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.chunk_overlap = chunk_size * chunk_overlap_frac
|
||||||
|
self.embeddings = OpenAIEmbeddings()
|
||||||
|
self.count = 0
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
self.db = self._init_db()
|
||||||
|
self.qa = self._init_retriever()
|
||||||
|
|
||||||
|
def _init_db(self):
|
||||||
|
"""
|
||||||
|
Initializes the Chroma database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Chroma: The initialized Chroma database.
|
||||||
|
"""
|
||||||
|
texts = [
|
||||||
|
"init"
|
||||||
|
] # TODO find how to initialize Chroma without any text
|
||||||
|
chroma_db = Chroma.from_texts(
|
||||||
|
texts=texts,
|
||||||
|
embedding=self.embeddings,
|
||||||
|
persist_directory=str(self.loc),
|
||||||
|
)
|
||||||
|
self.count = chroma_db._collection.count()
|
||||||
|
return chroma_db
|
||||||
|
|
||||||
|
def _init_retriever(self):
|
||||||
|
"""
|
||||||
|
Initializes the retrieval QA system.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RetrievalQA: The initialized retrieval QA system.
|
||||||
|
"""
|
||||||
|
model = OpenAIChat(
|
||||||
|
model_name="gpt-3.5-turbo",
|
||||||
|
)
|
||||||
|
qa_chain = load_qa_chain(model, chain_type="stuff")
|
||||||
|
retriever = self.db.as_retriever(
|
||||||
|
search_type="mmr", search_kwargs={"k": 10}
|
||||||
|
)
|
||||||
|
qa = RetrievalQA(
|
||||||
|
combine_documents_chain=qa_chain, retriever=retriever
|
||||||
|
)
|
||||||
|
return qa
|
||||||
|
|
||||||
|
@synchronized_mem
|
||||||
|
def add(self, entry: str):
|
||||||
|
"""
|
||||||
|
Add an entry to the internal memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entry (str): The entry to be added.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the entry was successfully added, False otherwise.
|
||||||
|
"""
|
||||||
|
text_splitter = CharacterTextSplitter(
|
||||||
|
chunk_size=self.chunk_size,
|
||||||
|
chunk_overlap=self.chunk_overlap,
|
||||||
|
separator=" ",
|
||||||
|
)
|
||||||
|
texts = text_splitter.split_text(entry)
|
||||||
|
|
||||||
|
self.db.add_texts(texts)
|
||||||
|
self.count += self.db._collection.count()
|
||||||
|
self.db.persist()
|
||||||
|
return True
|
||||||
|
|
||||||
|
@synchronized_mem
|
||||||
|
def search_memory(
|
||||||
|
self, query: str, k=10, type="mmr", distance_threshold=0.5
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Searching the vector memory for similar entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The query to search for.
|
||||||
|
k (int): The number of results to return.
|
||||||
|
type (str): The type of search to perform: "cos" or "mmr".
|
||||||
|
distance_threshold (float): The similarity threshold to use for the search. Results with distance > similarity_threshold will be dropped.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: A list of the top k results.
|
||||||
|
"""
|
||||||
|
self.count = self.db._collection.count()
|
||||||
|
if k > self.count:
|
||||||
|
k = self.count - 1
|
||||||
|
if k <= 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if type == "mmr":
|
||||||
|
texts = self.db.max_marginal_relevance_search(
|
||||||
|
query=query, k=k, fetch_k=min(20, self.count)
|
||||||
|
)
|
||||||
|
texts = [text.page_content for text in texts]
|
||||||
|
elif type == "cos":
|
||||||
|
texts = self.db.similarity_search_with_score(
|
||||||
|
query=query, k=k
|
||||||
|
)
|
||||||
|
texts = [
|
||||||
|
text[0].page_content
|
||||||
|
for text in texts
|
||||||
|
if text[-1] < distance_threshold
|
||||||
|
]
|
||||||
|
|
||||||
|
return texts
|
||||||
|
|
||||||
|
@synchronized_mem
|
||||||
|
def query(self, question: str):
|
||||||
|
"""
|
||||||
|
Ask a question to the vector memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
question (str): The question to ask.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The answer to the question.
|
||||||
|
"""
|
||||||
|
answer = self.qa.run(question)
|
||||||
|
return answer
|
@ -1,85 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from swarms.memory.base_vectordatabase import AbstractVectorDatabase
|
|
||||||
from swarms.structs.agent import Agent
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MultiAgentRag:
|
|
||||||
"""
|
|
||||||
Represents a multi-agent RAG (Relational Agent Graph) structure.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
agents (List[Agent]): List of agents in the multi-agent RAG.
|
|
||||||
db (AbstractVectorDatabase): Database used for querying.
|
|
||||||
verbose (bool): Flag indicating whether to print verbose output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
agents: List[Agent]
|
|
||||||
db: AbstractVectorDatabase
|
|
||||||
verbose: bool = False
|
|
||||||
|
|
||||||
def query_database(self, query: str):
|
|
||||||
"""
|
|
||||||
Queries the database using the given query string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (str): The query string.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List: The list of results from the database.
|
|
||||||
"""
|
|
||||||
results = []
|
|
||||||
for agent in self.agents:
|
|
||||||
agent_results = agent.long_term_memory_prompt(query)
|
|
||||||
results.extend(agent_results)
|
|
||||||
return results
|
|
||||||
|
|
||||||
def get_agent_by_id(self, agent_id) -> Optional[Agent]:
|
|
||||||
"""
|
|
||||||
Retrieves an agent from the multi-agent RAG by its ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_id: The ID of the agent to retrieve.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Agent or None: The agent with the specified ID, or None if not found.
|
|
||||||
"""
|
|
||||||
for agent in self.agents:
|
|
||||||
if agent.agent_id == agent_id:
|
|
||||||
return agent
|
|
||||||
return None
|
|
||||||
|
|
||||||
def add_message(
|
|
||||||
self, sender: Agent, message: str, *args, **kwargs
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Adds a message to the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sender (Agent): The agent sending the message.
|
|
||||||
message (str): The message to add.
|
|
||||||
*args: Additional positional arguments.
|
|
||||||
**kwargs: Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: The ID of the added message.
|
|
||||||
"""
|
|
||||||
doc = f"{sender.ai_name}: {message}"
|
|
||||||
|
|
||||||
return self.db.add(doc)
|
|
||||||
|
|
||||||
def query(self, message: str, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Queries the database using the given message.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message (str): The message to query.
|
|
||||||
*args: Additional positional arguments.
|
|
||||||
**kwargs: Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List: The list of results from the database.
|
|
||||||
"""
|
|
||||||
return self.db.query(message)
|
|
@ -0,0 +1,81 @@
|
|||||||
|
import threading
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from swarms.structs.agent import Agent
|
||||||
|
from swarms.structs.task import Task
|
||||||
|
|
||||||
|
|
||||||
|
def synchronized_queue(method):
|
||||||
|
"""
|
||||||
|
Decorator that synchronizes access to the decorated method using a lock.
|
||||||
|
The lock is acquired before executing the method and released afterwards.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: The method to be decorated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The decorated method.
|
||||||
|
"""
|
||||||
|
timeout_sec = 5
|
||||||
|
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
with self.lock:
|
||||||
|
self.lock.acquire(timeout=timeout_sec)
|
||||||
|
try:
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to execute {method.__name__}: {e}")
|
||||||
|
finally:
|
||||||
|
self.lock.release()
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class TaskQueueBase(ABC):
|
||||||
|
def __init__(self):
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
@synchronized_queue
|
||||||
|
@abstractmethod
|
||||||
|
def add_task(self, task: Task) -> bool:
|
||||||
|
"""Adds a task to the queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (Task): The task to be added to the queue.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the task was successfully added, False otherwise.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@synchronized_queue
|
||||||
|
@abstractmethod
|
||||||
|
def get_task(self, agent: Agent) -> Task:
|
||||||
|
"""Gets the next task from the queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent (Agent): The agent requesting the task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Task: The next task from the queue.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@synchronized_queue
|
||||||
|
@abstractmethod
|
||||||
|
def complete_task(self, task_id: str):
|
||||||
|
"""Sets the task as completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id (str): The ID of the task to be marked as completed.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@synchronized_queue
|
||||||
|
@abstractmethod
|
||||||
|
def reset_task(self, task_id: str):
|
||||||
|
"""Resets the task if the agent failed to complete it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id (str): The ID of the task to be reset.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
@ -0,0 +1,151 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from swarms.structs.agent import Agent
|
||||||
|
from swarms.structs.majority_voting import MajorityVoting
|
||||||
|
|
||||||
|
|
||||||
|
def test_majority_voting_run_concurrent(mocker):
|
||||||
|
# Create mock agents
|
||||||
|
agent1 = MagicMock(spec=Agent)
|
||||||
|
agent2 = MagicMock(spec=Agent)
|
||||||
|
agent3 = MagicMock(spec=Agent)
|
||||||
|
|
||||||
|
# Create mock majority voting
|
||||||
|
mv = MajorityVoting(
|
||||||
|
agents=[agent1, agent2, agent3],
|
||||||
|
concurrent=True,
|
||||||
|
multithreaded=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create mock conversation
|
||||||
|
conversation = MagicMock()
|
||||||
|
mv.conversation = conversation
|
||||||
|
|
||||||
|
# Create mock results
|
||||||
|
results = ["Paris", "Paris", "Lyon"]
|
||||||
|
|
||||||
|
# Mock agent.run method
|
||||||
|
agent1.run.return_value = results[0]
|
||||||
|
agent2.run.return_value = results[1]
|
||||||
|
agent3.run.return_value = results[2]
|
||||||
|
|
||||||
|
# Run majority voting
|
||||||
|
majority_vote = mv.run("What is the capital of France?")
|
||||||
|
|
||||||
|
# Assert agent.run method was called with the correct task
|
||||||
|
agent1.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
agent2.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
agent3.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert conversation.add method was called with the correct responses
|
||||||
|
conversation.add.assert_any_call(agent1.agent_name, results[0])
|
||||||
|
conversation.add.assert_any_call(agent2.agent_name, results[1])
|
||||||
|
conversation.add.assert_any_call(agent3.agent_name, results[2])
|
||||||
|
|
||||||
|
# Assert majority vote is correct
|
||||||
|
assert majority_vote is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_majority_voting_run_multithreaded(mocker):
|
||||||
|
# Create mock agents
|
||||||
|
agent1 = MagicMock(spec=Agent)
|
||||||
|
agent2 = MagicMock(spec=Agent)
|
||||||
|
agent3 = MagicMock(spec=Agent)
|
||||||
|
|
||||||
|
# Create mock majority voting
|
||||||
|
mv = MajorityVoting(
|
||||||
|
agents=[agent1, agent2, agent3],
|
||||||
|
concurrent=False,
|
||||||
|
multithreaded=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create mock conversation
|
||||||
|
conversation = MagicMock()
|
||||||
|
mv.conversation = conversation
|
||||||
|
|
||||||
|
# Create mock results
|
||||||
|
results = ["Paris", "Paris", "Lyon"]
|
||||||
|
|
||||||
|
# Mock agent.run method
|
||||||
|
agent1.run.return_value = results[0]
|
||||||
|
agent2.run.return_value = results[1]
|
||||||
|
agent3.run.return_value = results[2]
|
||||||
|
|
||||||
|
# Run majority voting
|
||||||
|
majority_vote = mv.run("What is the capital of France?")
|
||||||
|
|
||||||
|
# Assert agent.run method was called with the correct task
|
||||||
|
agent1.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
agent2.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
agent3.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert conversation.add method was called with the correct responses
|
||||||
|
conversation.add.assert_any_call(agent1.agent_name, results[0])
|
||||||
|
conversation.add.assert_any_call(agent2.agent_name, results[1])
|
||||||
|
conversation.add.assert_any_call(agent3.agent_name, results[2])
|
||||||
|
|
||||||
|
# Assert majority vote is correct
|
||||||
|
assert majority_vote is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_majority_voting_run_asynchronous(mocker):
|
||||||
|
# Create mock agents
|
||||||
|
agent1 = MagicMock(spec=Agent)
|
||||||
|
agent2 = MagicMock(spec=Agent)
|
||||||
|
agent3 = MagicMock(spec=Agent)
|
||||||
|
|
||||||
|
# Create mock majority voting
|
||||||
|
mv = MajorityVoting(
|
||||||
|
agents=[agent1, agent2, agent3],
|
||||||
|
concurrent=False,
|
||||||
|
multithreaded=False,
|
||||||
|
asynchronous=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create mock conversation
|
||||||
|
conversation = MagicMock()
|
||||||
|
mv.conversation = conversation
|
||||||
|
|
||||||
|
# Create mock results
|
||||||
|
results = ["Paris", "Paris", "Lyon"]
|
||||||
|
|
||||||
|
# Mock agent.run method
|
||||||
|
agent1.run.return_value = results[0]
|
||||||
|
agent2.run.return_value = results[1]
|
||||||
|
agent3.run.return_value = results[2]
|
||||||
|
|
||||||
|
# Run majority voting
|
||||||
|
majority_vote = await mv.run("What is the capital of France?")
|
||||||
|
|
||||||
|
# Assert agent.run method was called with the correct task
|
||||||
|
agent1.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
agent2.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
agent3.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert conversation.add method was called with the correct responses
|
||||||
|
conversation.add.assert_any_call(agent1.agent_name, results[0])
|
||||||
|
conversation.add.assert_any_call(agent2.agent_name, results[1])
|
||||||
|
conversation.add.assert_any_call(agent3.agent_name, results[2])
|
||||||
|
|
||||||
|
# Assert majority vote is correct
|
||||||
|
assert majority_vote is not None
|
@ -0,0 +1,151 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from swarms.structs.agent import Agent
|
||||||
|
from swarms.structs.majority_voting import MajorityVoting
|
||||||
|
|
||||||
|
|
||||||
|
def test_majority_voting_run_concurrent(mocker):
|
||||||
|
# Create mock agents
|
||||||
|
agent1 = MagicMock(spec=Agent)
|
||||||
|
agent2 = MagicMock(spec=Agent)
|
||||||
|
agent3 = MagicMock(spec=Agent)
|
||||||
|
|
||||||
|
# Create mock majority voting
|
||||||
|
mv = MajorityVoting(
|
||||||
|
agents=[agent1, agent2, agent3],
|
||||||
|
concurrent=True,
|
||||||
|
multithreaded=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create mock conversation
|
||||||
|
conversation = MagicMock()
|
||||||
|
mv.conversation = conversation
|
||||||
|
|
||||||
|
# Create mock results
|
||||||
|
results = ["Paris", "Paris", "Lyon"]
|
||||||
|
|
||||||
|
# Mock agent.run method
|
||||||
|
agent1.run.return_value = results[0]
|
||||||
|
agent2.run.return_value = results[1]
|
||||||
|
agent3.run.return_value = results[2]
|
||||||
|
|
||||||
|
# Run majority voting
|
||||||
|
majority_vote = mv.run("What is the capital of France?")
|
||||||
|
|
||||||
|
# Assert agent.run method was called with the correct task
|
||||||
|
agent1.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
agent2.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
agent3.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert conversation.add method was called with the correct responses
|
||||||
|
conversation.add.assert_any_call(agent1.agent_name, results[0])
|
||||||
|
conversation.add.assert_any_call(agent2.agent_name, results[1])
|
||||||
|
conversation.add.assert_any_call(agent3.agent_name, results[2])
|
||||||
|
|
||||||
|
# Assert majority vote is correct
|
||||||
|
assert majority_vote is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_majority_voting_run_multithreaded(mocker):
|
||||||
|
# Create mock agents
|
||||||
|
agent1 = MagicMock(spec=Agent)
|
||||||
|
agent2 = MagicMock(spec=Agent)
|
||||||
|
agent3 = MagicMock(spec=Agent)
|
||||||
|
|
||||||
|
# Create mock majority voting
|
||||||
|
mv = MajorityVoting(
|
||||||
|
agents=[agent1, agent2, agent3],
|
||||||
|
concurrent=False,
|
||||||
|
multithreaded=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create mock conversation
|
||||||
|
conversation = MagicMock()
|
||||||
|
mv.conversation = conversation
|
||||||
|
|
||||||
|
# Create mock results
|
||||||
|
results = ["Paris", "Paris", "Lyon"]
|
||||||
|
|
||||||
|
# Mock agent.run method
|
||||||
|
agent1.run.return_value = results[0]
|
||||||
|
agent2.run.return_value = results[1]
|
||||||
|
agent3.run.return_value = results[2]
|
||||||
|
|
||||||
|
# Run majority voting
|
||||||
|
majority_vote = mv.run("What is the capital of France?")
|
||||||
|
|
||||||
|
# Assert agent.run method was called with the correct task
|
||||||
|
agent1.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
agent2.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
agent3.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert conversation.add method was called with the correct responses
|
||||||
|
conversation.add.assert_any_call(agent1.agent_name, results[0])
|
||||||
|
conversation.add.assert_any_call(agent2.agent_name, results[1])
|
||||||
|
conversation.add.assert_any_call(agent3.agent_name, results[2])
|
||||||
|
|
||||||
|
# Assert majority vote is correct
|
||||||
|
assert majority_vote is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_majority_voting_run_asynchronous(mocker):
|
||||||
|
# Create mock agents
|
||||||
|
agent1 = MagicMock(spec=Agent)
|
||||||
|
agent2 = MagicMock(spec=Agent)
|
||||||
|
agent3 = MagicMock(spec=Agent)
|
||||||
|
|
||||||
|
# Create mock majority voting
|
||||||
|
mv = MajorityVoting(
|
||||||
|
agents=[agent1, agent2, agent3],
|
||||||
|
concurrent=False,
|
||||||
|
multithreaded=False,
|
||||||
|
asynchronous=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create mock conversation
|
||||||
|
conversation = MagicMock()
|
||||||
|
mv.conversation = conversation
|
||||||
|
|
||||||
|
# Create mock results
|
||||||
|
results = ["Paris", "Paris", "Lyon"]
|
||||||
|
|
||||||
|
# Mock agent.run method
|
||||||
|
agent1.run.return_value = results[0]
|
||||||
|
agent2.run.return_value = results[1]
|
||||||
|
agent3.run.return_value = results[2]
|
||||||
|
|
||||||
|
# Run majority voting
|
||||||
|
majority_vote = await mv.run("What is the capital of France?")
|
||||||
|
|
||||||
|
# Assert agent.run method was called with the correct task
|
||||||
|
agent1.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
agent2.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
agent3.run.assert_called_once_with(
|
||||||
|
"What is the capital of France?"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert conversation.add method was called with the correct responses
|
||||||
|
conversation.add.assert_any_call(agent1.agent_name, results[0])
|
||||||
|
conversation.add.assert_any_call(agent2.agent_name, results[1])
|
||||||
|
conversation.add.assert_any_call(agent3.agent_name, results[2])
|
||||||
|
|
||||||
|
# Assert majority vote is correct
|
||||||
|
assert majority_vote is not None
|
Loading…
Reference in new issue